
Introduction

This repository contains code for the paper Studying K-FAC Heuristics by Viewing Adam through a Second-Order Lens, published at ICML 2024.

Installation

Our complete development environment under Python 3.10 is specified in local_requirements.txt, with a list of top-level requirements given in Pipfile. In theory, pipenv install in a fresh virtual environment will set everything up; in practice, JAX in particular may need manual intervention depending on your local CUDA and cuDNN versions.
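For reference, a minimal setup along these lines (assuming pipenv itself is not yet installed; GPU builds of JAX may still need installing by hand) might be:

$ pip install pipenv
$ pipenv install    # resolve the Pipfile in the project root
$ pipenv shell      # activate the resulting environment

Alternatively, the fully pinned environment in local_requirements.txt should be reproducible with pip install -r local_requirements.txt inside a Python 3.10 virtual environment.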

At the time of writing, we depend on a bugfix to the KFAC-JAX library, which is specified in kfac_jax.patch. This can be applied from the project root with

$ patch -p0 -i kfac_jax.patch

Datasets are not bundled with the repository; before first use, they must be downloaded by calling the dataset constructors with download=True.
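As a hypothetical illustration, if the image datasets are loaded through torchvision-style constructors (an assumption on our part; check the data-loading code for the exact classes and root directory used), Fashion-MNIST could be pre-downloaded with:

$ # hypothetical: assumes a torchvision FashionMNIST dataset and a ./data root
$ python -c "from torchvision.datasets import FashionMNIST; FashionMNIST('./data', download=True)"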

Running

Each dataset and algorithm is specified by a YAML configuration file in configs/: AdamQLR_Damped_NoLRClipping.yaml corresponds to the AdamQLR (Tuned) algorithm described in our paper, and AdamQLR_NoHPO_NoLRClipping.yaml to the AdamQLR (Untuned) setting. To perform a single training run, pass the relevant files to train.py with the -c flag, e.g.:

$ python train.py -c ./configs/fashion_mnist.yaml ./configs/AdamQLR_Damped_NoLRClipping.yaml

A complete hyperparameter optimisation routine, including 50 repeated training runs with the best hyperparameters found, can be performed by calling hyperparameter_optimisation.py with the corresponding configuration files:

$ python hyperparameter_optimisation.py -c ./configs/fashion_mnist.yaml ./configs/AdamQLR_Damped_NoLRClipping.yaml ./configs/ASHA.yaml

The same script also contains helper functions for running sensitivity studies. To budget hyperparameter optimisation by overall runtime rather than by number of epochs, substitute ./configs/ASHA_time_training.yaml or ./configs/ASHA_time_validation.yaml for ./configs/ASHA.yaml.
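For example, a Fashion-MNIST sweep budgeted by training time would be launched as:

$ python hyperparameter_optimisation.py -c ./configs/fashion_mnist.yaml ./configs/AdamQLR_Damped_NoLRClipping.yaml ./configs/ASHA_time_training.yaml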

To replicate all our experimental results, the various run_*.sh scripts may be useful.

Analysis

Logs are written in TensorBoard format to a runs/ directory by default; the path can be changed via the --log-root config entry or command-line flag.
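To inspect the logs, point TensorBoard at the log directory:

$ tensorboard --logdir runs/

If a run logged elsewhere, pass the corresponding --log-root value to --logdir instead.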

All our experimental plots are produced using paper_plots.py, though you may need to update the paths to match your local configuration.
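Assuming the script needs no additional command-line arguments (an assumption; check its argument parsing), the plots can then be regenerated with:

$ python paper_plots.py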
