
mnist_jax_w_hparams_gui

MLP classifier on the MNIST dataset implemented in JAX with a GUI for entering hyperparameters.

Summary

The Python script mnist_main.py (built on JAX) trains a model of your choice for classification on the MNIST dataset, with hyperparameters (batch size, learning rate, number of epochs, seed, optimizer) that you specify through a GUI. At the moment the module models.py contains only a small MLP; to add other models, simply make the necessary additions to that module. The dataset is loaded using torch.

Training data is logged to TensorBoard using PyTorch. The relevant logdir is results/the_name_of_the_model_that_you_train, and individual runs are identified by their time stamp and hyperparameter values. In the TensorBoard HPARAMS tab, the runs with different hyperparameter values are listed side by side; for each run, plots show the evolution of training loss, training accuracy, test loss, and test accuracy during training. This is not a default feature of the PyTorch TensorBoard SummaryWriter class: I have subclassed it and added a variation of the add_hparams function to provide it.

Detailed instructions

Create the necessary conda environment using:

conda env create -f utils/create_env.yml

Install the necessary dependencies using:

pip install -r utils/requirements.txt

Install JAX:

pip install --upgrade "jax[cpu]"             # use the relevant instructions for the gpu version if you need that

Adding models to models.py: Models in models.py are implemented in JAX using the Equinox library. The model already implemented there is called 'mlp_small'. To add a new model, simply create a class for that model. It is recommended that you add an attribute called descript with a short description of the model. Then register the new model_name: model_class key-value pair in the function model_names_dict(key, model_name) in the same module. This enables mnist_main.py to import the chosen model.

For now let's say we want to train the model named 'mlp_small'.

To train: you simply run python mnist_main.py from within the correct conda env. This will open a GUI for entering hyperparameters.

Entering hyperparameters: The GUI is based on tkinter and is implemented in input_gui.py. The box has six self-explanatory fields: the model name (as a string), batch size, learning rate, number of epochs to train, seed, and the optimizer name (as a string, e.g. 'adam'). You should use one of the optimizers implemented in the Optax library. Press Save and then Done. The hyperparameter values are written to parameters/gui_parameters_temp.py (temporarily, until they are overwritten), from where the main script reads them. Training then starts, and a progress bar (implemented with tqdm) shows progress through the steps of each epoch. At the end of each epoch the training loss, training accuracy, test loss, and test accuracy are displayed.
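For concreteness, the temporary parameters file written by the GUI might look roughly like this. The variable names below are illustrative guesses, not the script's actual output; inspect parameters/gui_parameters_temp.py after a run to see the real format.

```python
# parameters/gui_parameters_temp.py -- illustrative contents only;
# the field names here are assumptions, not the actual file layout
model_name = 'mlp_small'
batch_size = 128
learning_rate = 1e-3
num_epochs = 10
seed = 0
optimizer_name = 'adam'
```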

Results: The results should be viewed through TensorBoard. The SummaryWriter logs data to results/model_name, so for mlp_small it is results/mlp_small. To get a link to the interface, run tensorboard --logdir results/mlp_small and follow the link. Go to the HPARAMS tab to see the list of runs with different hyperparameters; click the desired one to view the training plots of training loss, training accuracy, test loss, and test accuracy.

This visualization is not offered by default by the SummaryWriter class of PyTorch TensorBoard, so I have subclassed it in SummaryWriter_mod.py and added a function called add_hparams_plot that provides it.
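A minimal sketch of this kind of subclass, assuming the standard torch.utils.tensorboard API (the actual implementation in SummaryWriter_mod.py may differ):

```python
from torch.utils.tensorboard import SummaryWriter
from torch.utils.tensorboard.summary import hparams


class SummaryWriterMod(SummaryWriter):
    """Sketch: write hparams into the SAME run directory as the scalar
    plots, so each entry in the HPARAMS tab links to its own run's
    training curves (the stock add_hparams creates a separate sub-run).
    """

    def add_hparams_plot(self, hparam_dict, metric_dict):
        # Build the three hparams summary protobufs.
        exp, ssi, sei = hparams(hparam_dict, metric_dict)
        # Write them with this run's own file writer instead of
        # opening a new sub-run directory.
        self.file_writer.add_summary(exp)
        self.file_writer.add_summary(ssi)
        self.file_writer.add_summary(sei)
        # Log the metrics as ordinary scalars in the same run.
        for name, value in metric_dict.items():
            self.add_scalar(name, value)
```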

License and copyright

© Sourav Sarkar

Licensed under the MIT License
