
Comments (5)

Optimox commented on July 26, 2024

Hi @mysterefrank ,

TabNet is indeed an order of magnitude slower than XGBoost (edit: this is especially true for regression and binary classification; with a large number of classes in multi-class classification, you may actually see TabNet run faster than XGBoost).
I think it's difficult to compare TabNet to an MLP: the steps are sequential, so with n_steps=3 you pass through 3 MLPs with attention. A TabNet architecture is therefore by design much bigger than a classical MLP.

It's pretty rare to have a terabyte of clean tabular data to train on, and with terabytes of data TabNet is going to be slow. One option would be to tweak the dataloaders so that one epoch passes over only a limited number of random samples rather than the full dataset (see the sketch below). Or you could train on a subsample of your choosing. Also, do you have a terabyte of RAM? This implementation has no out-of-core file reading (e.g. HDF5), so if you are hitting swap, that alone could slow everything down.
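A minimal sketch of that per-epoch subsampling idea, using plain PyTorch utilities; the dataset, sizes, and training loop here are placeholders, not part of this repository:

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

# Placeholder dataset: swap in your own features/targets.
X = torch.randn(1_000_000, 50)
y = torch.randint(0, 2, (1_000_000,))
dataset = TensorDataset(X, y)

# Draw only 100k random rows per epoch instead of iterating the full dataset.
# replacement=True lets us fix num_samples independently of len(dataset).
sampler = RandomSampler(dataset, replacement=True, num_samples=100_000)
loader = DataLoader(dataset, batch_size=1024, sampler=sampler, num_workers=4)

for epoch in range(10):
    for xb, yb in loader:  # one "epoch" now touches only the subsample
        ...  # forward/backward pass
```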

About the GPU: TabNet uses large batches, so high GPU usage is a good thing (personally, when training a neural network on a GPU, I'm happy with 100% utilization). If you run into memory issues, you can reduce the batch size for training.

About speed-ups, I see a few ways to speed up training (each possibly impacting the final score); a parameter sketch follows the list:

  • increase your batch size (this will be heavier on the GPU)
  • match virtual_batch_size with batch_size: virtual_batch_size controls ghost batch norm, and setting virtual_batch_size=batch_size disables ghost batch norm but should speed up training
  • lower n_steps
  • lower n_independent or n_shared
  • decrease your embedding sizes
  • EDIT: this currently requires changing the source code (we should expose it as a parameter), but increasing num_workers in the DataLoader might speed up training. I'll create an issue about this.
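Putting those knobs together, a hedged sketch with the pytorch-tabnet API (the data is synthetic and the parameter values are illustrative; num_workers is omitted since it is not yet exposed as a parameter):

```python
import numpy as np
from pytorch_tabnet.tab_model import TabNetClassifier

# Synthetic placeholder data: replace with your own arrays.
X_train = np.random.randn(10_000, 20).astype(np.float32)
y_train = np.random.randint(0, 2, 10_000)

clf = TabNetClassifier(
    n_steps=2,        # fewer sequential steps -> fewer MLP+attention passes
    n_independent=1,  # fewer independent GLU layers per step
    n_shared=1,       # fewer shared GLU layers per step
    cat_emb_dim=1,    # smaller embeddings for categorical features
)

clf.fit(
    X_train, y_train,
    max_epochs=50,
    batch_size=16384,          # larger batches: faster epochs, more GPU memory
    virtual_batch_size=16384,  # equal to batch_size -> ghost batch norm disabled
)
```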

On the code side, I think we could speed things up in ghost batch norm and in the way we handle embeddings; those are improvements for later.

Hope this answer helps; please let me know if you need more information.

Cheers!


Optimox commented on July 26, 2024

Hey @mysterefrank,

I hope the changes we made will be useful to you (we'll cut a release soon).
I'm closing this since it seems your question has been answered; feel free to reopen if you'd like to discuss further.

Cheers


mysterefrank commented on July 26, 2024

Thanks for the detailed reply!


mdancho84 commented on July 26, 2024

Hey gang, I'm getting a much greater speedup by increasing the batch_size and the virtual_batch_size. Doubling my batch sizes costs me roughly 50% worse on the 1st epoch, but only 10% worse on the 2nd epoch. Obviously this is one sample, but it seems to be a big improvement in speed with larger data sets.


Optimox commented on July 26, 2024

Hello @mdancho84,

The original paper recommends going as high as 10% of your training data when choosing the batch size.
But the training time for one epoch is not necessarily the right way to measure speed, since you might need more epochs to reach the same results.

It's mostly trial and error to choose your number of epochs and batch size; once chosen, you can stick with them and tune the other parameters. A rough illustration of the 10% rule of thumb is below.
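A back-of-the-envelope sketch of that 10% rule of thumb (the training-set size and the power-of-two rounding are my own conventions, not from the paper):

```python
# Rule of thumb from the paper: batch size up to ~10% of the training set.
n_train = 1_000_000                           # hypothetical training-set size
max_bs = n_train // 10                        # 100_000
# Round down to a power of two for GPU friendliness (a common convention).
batch_size = 1 << (max_bs.bit_length() - 1)   # 65536
print(batch_size)
```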

