Comments (5)
Hi @mysterefrank ,
TabNet is indeed an order of magnitude slower than XGBoost (edit: this is especially true for regression and binary classification; if you have a large number of classes in multi-class classification, you may actually see TabNet run faster than XGBoost).
I think it's difficult to compare TabNet to an MLP: the steps are sequential, so with n_steps=3 you pass through three MLPs with attention. A TabNet architecture is by design much bigger than a classical MLP.
It's pretty rare to have a terabyte of clean tabular data to train on, and with that much data TabNet is going to be slow. One option is to tweak the dataloaders so that one epoch goes over only a limited number of random samples instead of the full dataset; another is to train on a subsample you chose. Also, do you have a terabyte of RAM? This implementation does not have an external file-reading system like HDF5, so if you start hitting SWAP everything could slow down.
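The per-epoch subsampling idea above can be sketched with plain Python (a minimal illustration, not this library's API; `epoch_indices` and the commented slicing of `X_train`/`y_train` are hypothetical names):

```python
import random

def epoch_indices(n_rows, sample_size, seed=None):
    """Pick a random subset of row indices so one 'epoch' only
    touches sample_size rows instead of the full dataset."""
    rng = random.Random(seed)
    return rng.sample(range(n_rows), sample_size)

# For each epoch, slice the training arrays with a fresh subsample, e.g.:
# idx = epoch_indices(len(X_train), 100_000, seed=epoch)
# clf.fit(X_train[idx], y_train[idx], ...)
```

Varying the seed per epoch means the model still sees different parts of the data over time, while each epoch stays cheap.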
About the GPU: TabNet uses large batches, so high GPU usage is a good thing (personally, when training a neural network on a GPU I'm happy with 100% usage). If you have memory issues, you can reduce the batch size for training.
About speed-ups, I see some possible ways of speeding up training (possibly impacting the final score):
- increase your batch size (this will be heavier on the GPU)
- match virtual_batch_size with batch_size; virtual_batch_size controls ghost batch norm, and setting virtual_batch_size=batch_size will disable ghost batch norm but should speed up training
- lower n_steps
- lower n_independent or n_shared
- decrease your embedding sizes
- EDIT: this would require changing the source code (we should expose this as a parameter), but increasing num_workers in the DataLoader might speed up training. I'll create an issue about this.
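The knobs above can be gathered into one snippet (the parameter names n_steps, n_independent, n_shared, batch_size, and virtual_batch_size are pytorch-tabnet's; the values here are illustrative, not tuned recommendations, and the constructor/fit calls are shown commented out):

```python
# Speed-oriented settings for a TabNet model: smaller architecture,
# larger batches, ghost batch norm disabled.
fast_model_params = dict(
    n_steps=3,        # fewer sequential decision steps -> fewer MLP passes
    n_independent=1,  # fewer independent GLU layers per step
    n_shared=1,       # fewer shared GLU layers per step
)
fast_fit_params = dict(
    batch_size=16384,          # larger batches keep the GPU busy
    virtual_batch_size=16384,  # == batch_size disables ghost batch norm
)
# from pytorch_tabnet.tab_model import TabNetClassifier
# clf = TabNetClassifier(**fast_model_params)
# clf.fit(X_train, y_train, **fast_fit_params)
```

Each of these trades some model capacity (or regularization, in the ghost-batch-norm case) for speed, so it's worth checking the validation score after changing them.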
On the code side, I think we could try to speed things up in ghost batch norm and in the way we deal with embeddings; those improvements are for later.
Hope this answer helps; please let me know if you need more information.
Cheers!
from tabnet.
Hey @mysterefrank,
I hope the changes we made will be useful to you (we'll make a release soon).
I'm closing this since it seems your question was answered; feel free to reopen it if you'd like to discuss this further.
Cheers
Thanks for the detailed reply!
Hey gang, I'm getting a much greater speedup by increasing the batch_size and the virtual_batch_size. Doubling my batch sizes costs me roughly 50% worse results on the 1st epoch, but only 10% worse on the second epoch. Obviously this is one sample, but it seems like a big improvement in speed with larger data sets.
hello @mdancho84,
The original paper recommends going as high as 10% of your training data when choosing the batch size.
But the training time for one epoch is not necessarily the right way to measure speed, since you might need more epochs to reach the same results.
Choosing your number of epochs and batch size is mostly trial and error; once you've settled on them, you can stick with them and tune the other parameters.
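The "up to 10% of training data" rule of thumb can be written as a small helper (a hypothetical sketch: the function name, the cap, and rounding down to a power of two are my illustrative choices, not from the paper or the library):

```python
def suggest_batch_size(n_train, frac=0.10, cap=16384):
    """Heuristic starting point: batch size up to ~frac of the training
    rows, capped, rounded down to a power of two (illustrative choice)."""
    target = max(1, min(int(n_train * frac), cap))
    bs = 1
    while bs * 2 <= target:
        bs *= 2
    return bs

# e.g. suggest_batch_size(100_000) -> 8192 (10% is 10_000, rounded down)
```

This only gives a starting point; as noted above, the final choice is still trial and error against validation results.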