
Neural Architecture Optimization

This is the code for the paper Neural Architecture Optimization.

Authors: Renqian Luo*, Fei Tian*, Tao Qin, En-Hong Chen, Tie-Yan Liu. *=equal contribution

License

The code and models in this repo are released under the GNU GPLv3 license.

Citation

If you find this work helpful in your research, please use the following BibTeX entry to cite our paper.

@inproceedings{NAO,
  title={Neural Architecture Optimization},
  author={Renqian Luo and Fei Tian and Tao Qin and En-Hong Chen and Tie-Yan Liu},
  booktitle={Advances in Neural Information Processing Systems},
  year={2018}
}

This is not an official Microsoft product.

Requirements and Dependencies

TensorFlow >= 1.4.0

PyTorch == 0.3.1
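If you want to confirm the installed versions before running any of the scripts, the snippet below is a minimal sketch that compares dotted version strings with `sort -V`. That is a GNU coreutils option, so the sketch assumes a GNU userland. The helper name `version_ge` is hypothetical and not part of this repo; the PyTorch check uses a prefix match in case the build reports a suffix after 0.3.1.

```shell
# Hypothetical helper (not part of this repo): succeed if version $1 >= version $2.
# Relies on GNU `sort -V`, which orders dotted version strings numerically,
# so 1.10.0 correctly sorts after 1.4.0.
version_ge() {
  [ "$(printf '%s\n%s\n' "$1" "$2" | sort -V | head -n 1)" = "$2" ]
}

# Example checks, assuming `python` is the interpreter you will train with:
# tf_ver=$(python -c 'import tensorflow as tf; print(tf.__version__)')
# version_ge "$tf_ver" 1.4.0 || echo "TensorFlow >= 1.4.0 is required" >&2
# torch_ver=$(python -c 'import torch; print(torch.__version__)')
# case "$torch_ver" in 0.3.1*) ;; *) echo "PyTorch 0.3.1 is required" >&2 ;; esac
```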

CIFAR-10

With Weight Sharing

To Search Architectures

To search the CNN architectures for CIFAR-10 with weight sharing, please refer to:

Script                       | Data                    | GPU    | Search Time
./NAO-WS/cnn/train_search.sh | Google Drive / Baidu Pan | 1 V100 | 7.5 hours

cd NAO-WS/cnn
bash train_search.sh

Once the search is done, the final pool of architectures will be in models/child/arch_pool. You can then train the top-5 architectures with train_final.sh, passing each one in through the fixed_arc argument.
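Picking the top-k entries can be scripted. This is a minimal sketch that assumes arch_pool stores one architecture per line, ordered best first; that file format is an assumption, so inspect the file before relying on it. `top_k_archs` is a hypothetical helper, not part of this repo.

```shell
# Hypothetical helper (not part of this repo): print the first K architectures
# of an arch_pool file. ASSUMPTION: one architecture per line, best first.
top_k_archs() {
  local pool_file="$1"
  local k="${2:-5}"   # default to the top-5, as used for the CNN search
  if [ ! -f "$pool_file" ]; then
    echo "arch pool not found: $pool_file" >&2
    return 1
  fi
  head -n "$k" "$pool_file"
}
```

Each printed line is then a candidate value for fixed_arc, for example: `top_k_archs models/child/arch_pool 5 | while IFS= read -r arch; do echo "fixed_arc: $arch"; done`.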

To obtain our best architecture, we performed a grid search over the hyper-parameters of the top-5 discovered architectures.
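The README does not list the grid that was searched, so the hyper-parameter names and values below (a learning rate and a keep probability) are purely hypothetical placeholders. The sketch only shows how the (architecture, hyper-parameter) combinations of such a grid can be enumerated, again assuming one architecture per line in arch_pool, best first.

```shell
# Hypothetical sketch: enumerate every (architecture, lr, keep_prob) combination
# for the top-K architectures. The grid values are placeholders, not the
# values used in the paper.
grid_configs() {
  local pool_file="$1"
  local k="$2"
  head -n "$k" "$pool_file" | while IFS= read -r arch; do
    for lr in 0.05 0.1; do            # hypothetical learning rates
      for keep_prob in 0.6 0.8; do    # hypothetical keep probabilities
        # Tab-separated so that architectures containing spaces stay intact.
        printf '%s\t%s\t%s\n' "$arch" "$lr" "$keep_prob"
      done
    done
  done
}
```

With five architectures and the two-by-two placeholder grid, this yields 20 runs; the cost grows multiplicatively with every hyper-parameter added.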

To Train Discovered Architectures

To train a fixed CNN architecture, for example, our best architecture discovered, please refer to:

Script                      | GPU   | Time     | Model Checkpoint        | Parameter Size | Error Rate
./NAO-WS/cnn/train_final.sh | 1 P40 | 42 hours | Google Drive / Baidu Pan | 2.5M           | 3.50%

and run:

cd NAO-WS/cnn
bash train_final.sh

To train with cutout, add --child_cutout_size=16 to the arguments in the script.

To Directly Evaluate an Architecture

To directly evaluate an architecture, for example our best discovered architecture, download the checkpoint above, move all the files to the NAO-WS/cnn/models folder, and run:

cd NAO-WS/cnn
bash test_final.sh    #This should give you an accuracy of 96.50% (error rate of 3.50%) without cutout
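The "move all the files" step can be scripted. This is a minimal sketch that assumes the checkpoint has already been downloaded and unpacked into a local directory; `place_checkpoint` and the download path in the usage line are hypothetical.

```shell
# Hypothetical helper (not part of this repo): copy every file from an unpacked
# checkpoint directory into the models folder that the test script reads from.
place_checkpoint() {
  local src_dir="$1"
  local dest_dir="$2"
  if [ ! -d "$src_dir" ]; then
    echo "no such directory: $src_dir" >&2
    return 1
  fi
  mkdir -p "$dest_dir" || return 1
  # "$src_dir/." copies the directory's contents, including hidden files.
  cp -R "$src_dir/." "$dest_dir/"
}
```

Usage, with a hypothetical download location: `place_checkpoint ~/Downloads/nao_ws_cnn NAO-WS/cnn/models`, then run test_final.sh. Copying leaves the download intact; `mv` works just as well if you prefer to move the files as described.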

Without Weight Sharing

To Search Architectures

Please refer to details in ./NAO/README.md

To Train Discovered Architectures

Please download the data from Google Drive or Baidu Pan.

You can train the best discovered architecture (shown in Fig. 1 in the Appendix of the paper) using:

Dataset   | Script                             | GPU   | Time   | Checkpoint              | Error Rate (Test)
CIFAR-10  | ./NAO/cnn/train_cifar10_final.sh  | 2 P40 | 5 days | Google Drive / Baidu Pan | 2.10%
CIFAR-100 | ./NAO/cnn/train_cifar100_final.sh | 2 P40 | 5 days | Google Drive / Baidu Pan | 14.80%

by running:

cd NAO/cnn
bash train_cifar10_final.sh
bash train_cifar100_final.sh

To Directly Evaluate an Architecture

To directly evaluate an architecture, for example our best discovered architecture, download the checkpoint above, move all the files to NAO/cnn/models/cifar10 or NAO/cnn/models/cifar100, and run:

cd NAO/cnn
bash test_cifar10.sh     #This should give you an accuracy of 97.94% (error rate of 2.06%)
bash test_cifar100.sh    #This should give you an accuracy of 85.20% (error rate of 14.80%)
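Each error rate quoted in this README should equal 100% minus the corresponding test accuracy. A one-line awk helper makes that check quick; the name `acc_to_err` is hypothetical.

```shell
# Hypothetical helper: convert a test accuracy (in %) into the error rate (in %),
# rounded to two decimals.
acc_to_err() {
  awk -v acc="$1" 'BEGIN { printf "%.2f\n", 100 - acc }'
}
```

For example, `acc_to_err 97.94` prints 2.06, matching the CIFAR-10 comment above.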

PTB

To Search Architectures

To search the RNN architectures for PTB with weight sharing, please refer to:

Script                       | GPU    | Search Time
./NAO-WS/rnn/train_search.sh | 1 V100 | 8 hours

cd NAO-WS/rnn
bash train_search.sh

Once the search is done, the final pool of architectures will be in models/child/arch_pool. You can then train the top-10 architectures with train_final.sh, passing each one in through the arch argument.

To Train Discovered Architectures

To train a fixed RNN architecture, for example, our best architecture discovered, please refer to:

Script                      | Model Checkpoint        | GPU    | Time   | PPL (Test)
./NAO-WS/rnn/train_final.sh | Google Drive / Baidu Pan | 1 V100 | 4 days | 56.80

cd NAO-WS/rnn
bash train_final.sh   #This should give you a test ppl of 56.66 at the end of training
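Test perplexity is the exponential of the average per-token cross-entropy loss measured in nats (natural log, which is what PyTorch's cross-entropy loss uses). A minimal conversion sketch, with a hypothetical helper name:

```shell
# Hypothetical helper: perplexity = exp(average per-token cross-entropy in nats).
loss_to_ppl() {
  awk -v loss="$1" 'BEGIN { printf "%.2f\n", exp(loss) }'
}
```

For instance, ln(56.66) is about 4.04, so a test loss near 4.04 nats corresponds to the perplexity quoted above.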

To Directly Evaluate an Architecture

To directly evaluate an architecture, for example our best discovered architecture, download the checkpoint above, move all the files to the NAO-WS/rnn/models folder, and run:

cd NAO-WS/rnn
bash test_final.sh    #This should give you a test ppl of 56.66

Without Weight Sharing

To Search Architectures

Please refer to details in NAO/README.md

To Train Discovered Architectures

You can train the best discovered architecture (shown in Fig. 2 in the Appendix of the paper) using:

Dataset    | Script                        | GPU    | Time   | Checkpoint              | PPL (Test)
PTB        | ./NAO/rnn/train_ptb_final.sh | 1 V100 | 4 days | Google Drive / Baidu Pan | 56.02
WikiText-2 | ./NAO/rnn/train_wt2_final.sh | 1 V100 | 4 days | Google Drive / Baidu Pan | 67.10

To Directly Evaluate an Architecture

To directly evaluate an architecture, for example our best discovered architecture, download the checkpoint above, move all the files to NAO/rnn/models/ptb or NAO/rnn/models/wt2, and run:

cd NAO/rnn
bash test_ptb.sh    #This should give you a test ppl of 56.02
bash test_wt2.sh    #This should give you a test ppl of 67.10

Acknowledgements

We thank Hieu Pham for discussions on some details of the ENAS implementation, and Hanxiao Liu for the language modeling code base in DARTS. We furthermore thank the anonymous reviewers for their constructive comments.
