fwdgrad

This repo contains an unofficial implementation of the paper Gradients without backpropagation.

Repository description

The paper describes a way to compute gradients using only the forward pass and provides two use cases:

  1. The global optimization of two famous functions: the Beale and Rosenbrock functions.
  2. The well-known classification task on the MNIST dataset.
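For reference, the two test functions can be written down in a few lines. This is a minimal sketch that assumes the standard textbook definitions (with the usual Rosenbrock constants a = 1 and b = 100); the exact parameterization and starting points used in the repository's configuration are not stated here, so treat the names and defaults below as illustrative.

```python
def beale(x: float, y: float) -> float:
    """Beale function (textbook form); global minimum f(3, 0.5) = 0."""
    return (
        (1.5 - x + x * y) ** 2
        + (2.25 - x + x * y**2) ** 2
        + (2.625 - x + x * y**3) ** 2
    )


def rosenbrock(x: float, y: float, a: float = 1.0, b: float = 100.0) -> float:
    """Rosenbrock function (textbook form); global minimum f(a, a**2) = 0."""
    return (a - x) ** 2 + b * (y - x**2) ** 2


# Both functions vanish at their known global minimizers.
print(beale(3.0, 0.5))       # -> 0.0
print(rosenbrock(1.0, 1.0))  # -> 0.0
```

Both functions have a single global minimum with value 0, which makes it easy to judge how close an optimizer has come to the solution.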

In our implementation, we use functorch. Thanks to this library, the model can be treated as a function of its parameters. With this change of perspective, the Jvp (Jacobian-vector product) of the model can be computed (see the Jvp section below for a brief explanation).

We tested this implementation on the same examples used in the paper.

Jvp

The Jvp (Jacobian-vector product) is the product of the Jacobian J of the model, taken with respect to its parameters, and a vector v, called the perturbation vector, which has the same size as the parameter vector. The entries of v are independent of each other and sampled from a normal distribution with mean 0 and variance 1.
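To make the definition concrete, here is a minimal, dependency-free sketch of a Jvp computed in a single forward pass with dual numbers. It is not the repository's functorch-based code: the `Dual` class, the `jvp` helper, and the choice of the Rosenbrock function as the "model" are assumptions made purely for illustration.

```python
import random
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """A value (primal) carried together with its directional derivative (tangent)."""

    primal: float
    tangent: float = 0.0

    @staticmethod
    def lift(x):
        # Plain numbers are constants, so their tangent is zero.
        return x if isinstance(x, Dual) else Dual(float(x))

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__

    def __sub__(self, other):
        other = Dual.lift(other)
        return Dual(self.primal - other.primal, self.tangent - other.tangent)

    def __rsub__(self, other):
        return Dual.lift(other) - self

    def __mul__(self, other):
        other = Dual.lift(other)
        # Product rule: (u * w)' = u' * w + u * w'
        return Dual(
            self.primal * other.primal,
            self.tangent * other.primal + self.primal * other.tangent,
        )

    __rmul__ = __mul__


def jvp(f, primals, tangents):
    """Return (f(primals), J_f(primals) @ tangents) without building the Jacobian."""
    out = f(*(Dual(p, t) for p, t in zip(primals, tangents)))
    return out.primal, out.tangent


def rosenbrock(x, y):
    # Written with +, -, * only, so it works on floats and on Dual values.
    return (1 - x) * (1 - x) + 100 * (y - x * x) * (y - x * x)


# Deterministic check: the gradient at (-1, 2) is (396, 200), so J @ (0.5, -1) = -2.
print(jvp(rosenbrock, (-1.0, 2.0), (0.5, -1.0)))  # -> (104.0, -2.0)

# Perturbation vector with independent N(0, 1) entries, as described above.
rng = random.Random(0)
theta = (-1.0, 2.0)
v = tuple(rng.gauss(0.0, 1.0) for _ in theta)
value, directional = jvp(rosenbrock, theta, v)
```

For a scalar-valued function such as a loss, the Jvp is a single number: the derivative of the function along the direction v.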

Running the optimization

To run the optimization, clone the repository, move to the repo root, and use the command python <name_of_the_example.py>.

For both MNIST examples, you can visualize the training progress with TensorBoard: tensorboard --logdir outputs

All the examples are implemented both with the new forward gradient method and with backpropagation. To keep the two implementations as similar as possible, the backpropagation examples also use the rewritten PyTorch operations (even though, in this case, PyTorch would support those operators).
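The difference between the two kinds of update can be sketched as follows. The paper defines the forward gradient as the directional derivative along a random perturbation v, multiplied by v itself, which is an unbiased estimate of the exact gradient. In this sketch the directional derivative is formed from a hand-written analytic gradient only to keep the code short; in the actual method it comes from a Jvp, and the exact gradient comes from backpropagation. All function names are hypothetical.

```python
import random


def rosenbrock_grad(x, y):
    """Analytic gradient of (1 - x)^2 + 100 (y - x^2)^2 (stand-in for backprop)."""
    return (-2.0 * (1.0 - x) - 400.0 * x * (y - x * x), 200.0 * (y - x * x))


def forward_gradient(grad, v):
    """Forward-gradient estimate (grad . v) * v.

    The scalar grad . v is what a Jvp returns; it is computed from the
    analytic gradient here purely for brevity (assumption of this sketch).
    """
    directional = sum(g * vi for g, vi in zip(grad, v))
    return tuple(directional * vi for vi in v)


def sgd_step(params, update, lr):
    """Plain gradient-descent step, identical for both kinds of update."""
    return tuple(p - lr * u for p, u in zip(params, update))


rng = random.Random(0)
params = (-1.0, 2.0)
grad = rosenbrock_grad(*params)

# Average many single-sample estimates: the mean approaches the exact gradient.
n = 20_000
mean = [0.0, 0.0]
for _ in range(n):
    v = (rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0))
    estimate = forward_gradient(grad, v)
    mean = [m + e / n for m, e in zip(mean, estimate)]

print("exact gradient:       ", grad)
print("mean forward gradient:", mean)

# One optimization step of each kind from the same starting point.
backprop_params = sgd_step(params, grad, lr=1e-3)
fwdgrad_params = sgd_step(
    params, forward_gradient(grad, (rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0))), lr=1e-3
)
```

A single forward-gradient step is noisy, but its inner product with the exact gradient is never negative, so it never points uphill along the true gradient direction.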

Running more examples

All the configurations are managed using Hydra. Configuration files are stored in the /configs folder.

To change the global optimization example, modify the global_optim_config.yaml file.

To change the MNIST example, add configuration files in the respective subfolder and update the config.yaml file. These configurations can be used to change:

  1. Dataset specifics, such as the batch size or the shuffling parameter (/dataset subfolder).
  2. The neural network architecture (/model subfolder). Here two models are provided:
    1. A Multilayer Perceptron inside the default configuration
    2. A Convolutional NN inside conv.yaml
  3. The optimization process parameters (/optimization subfolder).

See Hydra's documentation to learn more about managing configurations.

Performance comparison

Although the paper reports the forward gradient implementation to be faster, we did not observe a speed-up in our case. All the examples still converge in roughly the same number of steps, but each backpropagation step is faster than a fwdgrad step.

The next graph shows the loss on the training dataset obtained with the following command:

python mnist_<optimization_type>.py dataset.batch_size=1024 optimization.learning_rate=5e-4

where <optimization_type> is either fwdgrad or backprop.

The accuracies on the test set are:

| Method   | Test accuracy |
| -------- | ------------- |
| Backprop | 0.8839        |
| Fwdgrad  | 0.9091        |

Changelog

  • 15/04/2023: updated functorch, which is now used directly from torch-2.0.0

License

This project is licensed under the MIT License

Copyright (c) 2022 Federico Belotti, Davide Angioni, Orobix Srl (www.orobix.com).
