Comments (18)
Do you observe the same pattern with XGBoost or any other ML model? If so, this is data-related and not model-related.

My learning-rate strategy is optimizer_params = dict(lr=1e-1) with the Adam optimizer, scheduler_fn=CosineAnnealingWarmRestarts with scheduler_params = dict(T_0=100, T_mult=1, eta_min=1e-2), and patience=10.

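For reference, a minimal runnable sketch of the configuration described above, assuming pytorch-tabnet's TabNetClassifier; the synthetic arrays are placeholders for the actual dataset, which is not shared in this thread:

```python
import numpy as np
import torch
from pytorch_tabnet.tab_model import TabNetClassifier

# Placeholder data so the sketch runs end to end; substitute the real dataset.
rng = np.random.default_rng(0)
X_train = rng.random((65536, 20)).astype(np.float32)
y_train = rng.integers(0, 2, 65536)
X_valid = rng.random((8192, 20)).astype(np.float32)
y_valid = rng.integers(0, 2, 8192)

clf = TabNetClassifier(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=1e-1),
    scheduler_fn=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
    scheduler_params=dict(T_0=100, T_mult=1, eta_min=1e-2),
)
clf.fit(
    X_train, y_train,
    eval_set=[(X_valid, y_valid)],
    patience=10,             # stop after 10 epochs without eval improvement
    batch_size=16384,        # the large batch size used in this thread
    virtual_batch_size=512,  # ghost batch norm chunk size
)
```
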
Your learning rate is probably too high; also, start with a simple learning-rate decay like OneCycleLR.

> Your learning rate is probably too high; also, start with a simple learning-rate decay like OneCycleLR.

I tried smaller initial learning rates like 2e-2 and 5e-2, but the loss on the training data didn't even decrease.

Maybe it's worth having a look at the explanation matrices to check whether some of the features are causing the overfit, for example some sort of index that has not been dropped. Without more details about the data, it's probably going to be hard to diagnose exactly what might be happening.

Other than that, the batch size strikes me as pretty large. Maybe that's the reason you have to use such a high learning rate.

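For anyone following along, a sketch of how the explanation matrices can be pulled out of a fitted model via pytorch-tabnet's explain method; the feature names below are placeholders:

```python
# Aggregated per-sample attributions plus the per-step attention masks.
explain_matrix, masks = clf.explain(X_valid)

# Global importances sum to 1; one feature dominating (e.g. a row index
# that was never dropped) is a typical sign of a leaky column.
feature_names = [f"f{i}" for i in range(X_valid.shape[1])]  # placeholder names
for name, imp in sorted(zip(feature_names, clf.feature_importances_),
                        key=lambda kv: -kv[1])[:10]:
    print(f"{name}\t{imp:.3f}")
```
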
> Maybe it's worth having a look at the explanation matrices to check whether some of the features are causing the overfit, for example some sort of index that has not been dropped.
> Other than that, the batch size strikes me as pretty large. Maybe that's the reason you have to use such a high learning rate.

It is indeed possible that some of the features cause the overfit, but I've already configured lambda_sparse and gamma for regularization. Regarding the batch_size, I followed the original article's recommendation, setting it between 1% and 10% of the training set. Should I reduce it?

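For context, both of those knobs are constructor arguments; a minimal sketch (the values shown are illustrative, the poster's actual settings are not given in the thread):

```python
# gamma > 1 discourages re-using the same feature across decision steps;
# lambda_sparse adds an entropy penalty that pushes the masks to be sparse.
clf = TabNetClassifier(
    gamma=1.5,           # illustrative value; the library default is 1.3
    lambda_sparse=1e-3,  # pytorch-tabnet's default
)
```
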
It is true that they use large batch sizes in the paper, up to 16K. The virtual batch size is always much smaller though, at 512 max.

> Do you observe the same pattern with XGBoost or any other ML model? If so, this is data-related and not model-related.

LightGBM performs much better with the same loss and evaluation metric.

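A sketch of the kind of side-by-side baseline being described, using LightGBM's scikit-learn interface; the log-loss metric and the variable names are assumptions, reusing the placeholder split from the earlier sketch:

```python
from lightgbm import LGBMClassifier
from sklearn.metrics import log_loss

# Same split for both models so the comparison is apples to apples.
lgbm = LGBMClassifier()
lgbm.fit(X_train, y_train)
print("lightgbm:", log_loss(y_valid, lgbm.predict_proba(X_valid)))
print("tabnet:  ", log_loss(y_valid, clf.predict_proba(X_valid)))
```
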
It is also worth mentioning that the training is extremely slow, around 9-10 minutes per epoch. Any advice on this?

Do you have a GPU?

What happens with batch size = 2048 and virtual batch size = 256?

> Do you have a GPU?

Yes, I'm training on an Nvidia 3090. I haven't tried a batch_size smaller than 16384. Are training speed and the learning-rate strategy related to batch size? If I use a smaller batch size, should I lower the learning rate correspondingly? Thank you!

Training speed is directly proportional to your batch size as long as 1) your GPU is not already reaching 100% usage and 2) your CPU is not the bottleneck. Beyond that point, a larger batch size will make the training slower.

Batch size and learning rate are related in theory, yes. lr=1e-2, batch_size=1024, virtual_batch_size=256, with nothing else specified in the parameters, has never let me down. If this does not work at all, I can't help you more unless you give access to your dataset.

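A sketch of those suggested settings, with everything else left at the library defaults (data placeholders as in the earlier sketch):

```python
clf = TabNetClassifier(optimizer_params=dict(lr=1e-2))  # everything else default
clf.fit(
    X_train, y_train,
    eval_set=[(X_valid, y_valid)],
    batch_size=1024,
    virtual_batch_size=256,
)
```
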
Perhaps what you said above contradicts what you mentioned in this issue? [https://github.com//issues/391#issuecomment-1113099435] After I reduced the batch size to 4096 and the virtual batch size to 512, the training speed was slower.

The larger your batch size, the faster your training is; where is the contradiction here?

Sorry, I misunderstood what you said.

> What happens with batch size = 2048 and virtual batch size = 256?

I just tried a few epochs with batch size 4096 and virtual batch size 512. The performance is worse and the overfitting is heavier.

> Training speed is directly proportional to your batch size as long as 1) your GPU is not already reaching 100% usage and 2) your CPU is not the bottleneck. Beyond that point, a larger batch size will make the training slower.
> Batch size and learning rate are related in theory, yes. lr=1e-2, batch_size=1024, virtual_batch_size=256, with nothing else specified in the parameters, has never let me down. If this does not work at all, I can't help you more unless you give access to your dataset.

Did you try these settings on large datasets with a large number of features?