@Article{berrada2019deep,
author = {Berrada, Leonard and Zisserman, Andrew and Kumar, M Pawan},
title = {Deep Frank-Wolfe For Neural Network Optimization},
journal = {International Conference on Learning Representations},
year = {2019},
}
Requirements
This code should work with pytorch >= 1.0 in python 3. Detailed requirements are available in requirements.txt.
Installation
Clone this repository: git clone --recursive https://github.com/oval-group/dfw (note that the --recursive option is necessary to clone the submodules; these are needed to reproduce the experiments but not for the DFW implementation itself).
Go to the directory and install the requirements: cd dfw && pip install -r requirements.txt
Install the DFW package: python setup.py install
Example of Usage
Simple usage example:
from dfw import DFW
from dfw.losses import MultiClassHingeLoss

# boilerplate code:
# `model` is a nn.Module
# `x` is an input sample, `y` is a label

# create loss function
svm = MultiClassHingeLoss()

# create DFW optimizer with learning rate of 0.1
optimizer = DFW(model.parameters(), eta=0.1)

# DFW can be used with standard pytorch syntax
optimizer.zero_grad()
loss = svm(model(x), y)
loss.backward()

# NB: DFW needs to have access to the current loss value
# (this syntax is compatible with standard pytorch optimizers too)
optimizer.step(lambda: float(loss))
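For context, here is how these calls might sit inside a full training loop. This is only a minimal sketch: the model, data, number of epochs and eta value below are placeholders chosen to make the snippet self-contained, not values prescribed by the package.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from dfw import DFW
from dfw.losses import MultiClassHingeLoss

# placeholder model and random data, only to make the sketch runnable
model = nn.Linear(784, 10)
dataset = TensorDataset(torch.randn(256, 784), torch.randint(0, 10, (256,)))
train_loader = DataLoader(dataset, batch_size=32)

svm = MultiClassHingeLoss()
optimizer = DFW(model.parameters(), eta=0.1)

for epoch in range(2):  # a couple of epochs for illustration
    for x, y in train_loader:
        optimizer.zero_grad()
        loss = svm(model(x), y)
        loss.backward()
        # DFW needs access to the current loss value at each step
        optimizer.step(lambda: float(loss))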
Technical requirement: DFW uses a custom step-size at each step. For this update to make sense, the loss function must be convex and piecewise linear.
For instance, one can use a multi-class SVM loss or an l1 regression loss.
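To make the l1 regression case concrete, here is a minimal sketch pairing DFW with torch.nn.L1Loss, which is convex and piecewise linear in the model output; the model, data and eta value are placeholders, not part of the package.

import torch
import torch.nn as nn
from dfw import DFW

# placeholder regression model and data
model = nn.Linear(10, 1)
x, y = torch.randn(32, 10), torch.randn(32, 1)

# l1 loss is convex and piecewise linear, so it is compatible with DFW
l1 = nn.L1Loss()
optimizer = DFW(model.parameters(), eta=0.1)

optimizer.zero_grad()
loss = l1(model(x), y)
loss.backward()
optimizer.step(lambda: float(loss))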
Smoothing: sometimes the multi-class SVM loss does not fare well with a large number of classes.
This issue can be alleviated by using dual smoothing, which is easy to plug into the code.
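The snippet below sketches what this can look like. It assumes a set_smoothing_enabled context manager in dfw.losses (treat this name as an assumption and check dfw/losses.py for the exact interface), and it reuses the svm, model, x, y and optimizer objects from the usage example above.

from dfw.losses import set_smoothing_enabled

optimizer.zero_grad()
# compute the loss with dual smoothing enabled
with set_smoothing_enabled(True):
    loss = svm(model(x), y)
loss.backward()
optimizer.step(lambda: float(loss))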
Reproducing the Results
To reproduce the CIFAR experiments: VISION_DATA=[path/to/your/cifar/data] python reproduce/cifar.py
To reproduce the SNLI experiments: follow the preparation instructions and run python reproduce/snli.py
Results
DFW largely outperforms all baselines that do not use a manual schedule for the learning rate.
The tables below show the performance on the CIFAR data sets when using data augmentation (AMSGrad, a variant of Adam, is the strongest baseline in our experiments), and on the SNLI data set.