tegnas's Introduction

Understanding and Accelerating Neural Architecture Search with Training-Free and Theory-Grounded Metrics [PDF]

MIT licensed

Wuyang Chen*, Xinyu Gong*, Yunchao Wei, Humphrey Shi, Zhicheng Yan, Yi Yang, and Zhangyang Wang

Note

  1. This repo is still under development. Scripts are executable, but some CUDA errors may occur.
  2. Due to IP issues, we can only release the code for NAS via reinforcement learning and evolution, not for FP-NAS.

Overview

We present TEG-NAS, a generalized training-free neural architecture search method that significantly reduces the time cost of popular search methods (no gradient descent at all!) while still finding high-quality architectures.

Highlights:

  • Training-free NAS: for three popular NAS paradigms (reinforcement learning, evolution, and differentiable search), we adapt our TEG-NAS method to each of them and achieve extremely fast neural architecture search without a single gradient descent step.
  • Bridging the theory-application gap: we identified three training-free indicators for ranking the quality of deep networks: the condition number of their NTKs ("trainability"), the number of linear regions in their input space ("expressivity"), and the error of NTK kernel regression ("generalization"). A minimal sketch of the trainability indicator follows this list.
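
The following is a minimal sketch of the NTK condition-number ("trainability") indicator, not the repo's implementation (the real code lives in lib/procedures/ntk.py); the scalar-summed output, the tiny stand-in network, and the random batch are illustrative assumptions.

import torch
import torch.nn as nn

def ntk_condition_number(net: nn.Module, inputs: torch.Tensor) -> float:
    # Empirical NTK: Theta[i, j] = <grad_w f(x_i), grad_w f(x_j)> over a small batch.
    grads = []
    for x in inputs:                          # one sample at a time
        net.zero_grad()
        net(x.unsqueeze(0)).sum().backward()  # scalar output for this sample
        g = torch.cat([p.grad.detach().flatten()
                       for p in net.parameters() if p.grad is not None])
        grads.append(g)
    jac = torch.stack(grads)                  # (batch, num_params) Jacobian
    ntk = jac @ jac.t()                       # empirical NTK, (batch, batch)
    eigvals = torch.linalg.eigvalsh(ntk)      # eigenvalues in ascending order
    return (eigvals[-1] / eigvals[0]).item()  # condition number = lambda_max / lambda_min

# Example: score a tiny stand-in network on random data (illustrative only).
net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
print(ntk_condition_number(net, torch.randn(8, 16)))

A lower condition number indicates a better-conditioned NTK and, per the paper, better trainability; the expressivity and generalization indicators are computed analogously from linear-region counts and NTK kernel regression.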

Prerequisites

  • Ubuntu 16.04
  • Python 3.6.9
  • CUDA 11.0 (lower versions may work but were not tested)
  • NVIDIA GPU + CuDNN v7.6

This repository has been tested on GTX 1080Ti. Configurations may need to be changed on different platforms.

Installation

  • Clone this repo:
git clone https://github.com/chenwydj/TEGNAS.git
cd TEGNAS
  • Install dependencies:
pip install -r requirements.txt

Usage

0. Prepare the dataset

  • Please follow the guideline here to prepare the CIFAR-10/100 and ImageNet datasets, as well as the NAS-Bench-201 database.
  • Remember to properly set TORCH_HOME and data_paths in prune_launch.py; a sketch of the expected layout follows this list.
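
A minimal sketch of the kind of settings referred to above; the exact variable layout inside prune_launch.py may differ, and the paths below are placeholders.

import os

os.environ['TORCH_HOME'] = '/path/to/torch_home'  # where the NAS-Bench-201 file is stored
data_paths = {
    'cifar10':        '/path/to/cifar10',
    'cifar100':       '/path/to/cifar100',
    'ImageNet16-120': '/path/to/ImageNet16-120',
    'imagenet-1k':    '/path/to/imagenet-1k',
}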

1. Search

NAS-Bench-201 space

Reinforcement Learning
python reinforce_launch.py --space nas-bench-201 --dataset cifar10 --gpu 0
python reinforce_launch.py --space nas-bench-201 --dataset cifar100 --gpu 0
python reinforce_launch.py --space nas-bench-201 --dataset ImageNet16-120 --gpu 0

Evolution
python evolution_launch.py --space nas-bench-201 --dataset cifar10 --gpu 0
python evolution_launch.py --space nas-bench-201 --dataset cifar100 --gpu 0
python evolution_launch.py --space nas-bench-201 --dataset ImageNet16-120 --gpu 0

DARTS space

Reinforcement Learning
python reinforce_launch.py --space darts --dataset cifar10 --gpu 0
python reinforce_launch.py --space darts --dataset imagenet-1k --gpu 0

Evolution
python evolution_launch.py --space darts --dataset cifar10 --gpu 0
python evolution_launch.py --space darts --dataset imagenet-1k --gpu 0

2. Evaluation

  • For architectures searched on nas-bench-201, the accuracies are immediately available at the end of the search (from the console output).
  • For architectures searched on darts, please use DARTS_evaluation to train the searched architecture from scratch and evaluate it. Genotypes of our searched architectures are listed in genotypes.py; a sketch of the genotype format follows this list.
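
Entries in genotypes.py follow the standard DARTS-style genotype format sketched below; the cell shown here is a made-up illustration, not a reported TEG-NAS result (the searched genotypes are in the file itself).

from collections import namedtuple

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

# Each cell lists (operation, input node) pairs plus the nodes whose outputs are concatenated.
example = Genotype(
    normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
            ('skip_connect', 0), ('sep_conv_3x3', 1),
            ('sep_conv_3x3', 1), ('skip_connect', 0),
            ('skip_connect', 0), ('dil_conv_3x3', 2)],
    normal_concat=[2, 3, 4, 5],
    reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1),
            ('skip_connect', 2), ('max_pool_3x3', 1),
            ('max_pool_3x3', 0), ('skip_connect', 2),
            ('skip_connect', 2), ('max_pool_3x3', 1)],
    reduce_concat=[2, 3, 4, 5],
)
print(example.normal[:2])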

Citation

@inproceedings{chen2021tegnas,
  title={Understanding and Accelerating Neural Architecture Search with Training-Free and Theory-Grounded Metrics},
  author={Chen, Wuyang and Gong, Xinyu and Wei, Yunchao and Shi, Humphrey and Yan, Zhicheng and Yang, Yi and Wang, Zhangyang},
  year={2021}
}

Acknowledgement

tegnas's People

Contributors

chenwydj, gongxinyuu


tegnas's Issues

Update the repo?

Hi,

I noticed you released the paper for this repo at https://arxiv.org/pdf/2108.11939.pdf; congrats on the acceptance to TPAMI!

The README says the repo is still under development (as of last month). Could you update it so that people can follow and reproduce your work more easily? Thanks!

I randomly hit a CUDA error when running your code

The error is:

RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

Strangely, if I simply restart the search, I am able to process the networks that previously triggered the CUDA error.

Is there some sort of garbage collection that is missing?
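
A hedged sketch of explicit cleanup between evaluated architectures, which is the kind of garbage collection the question refers to; this is a common mitigation when many networks are built and discarded in one process, but it is not confirmed to resolve this particular illegal-memory-access error.

import gc
import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = nn.Linear(8, 8).to(device)        # stand-in for one candidate architecture
net(torch.randn(4, 8, device=device))   # ... evaluate the candidate here ...
net = None                  # drop the reference to the finished model
gc.collect()                # reclaim Python-side objects
torch.cuda.empty_cache()    # no-op without CUDA; otherwise releases cached blocks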

The number of linear regions always equals batch_size*sample_batch ?

I have used your code to test the number of linear regions of several existing models, such as ResNet18, ResNet50, and ResNet101. I find that the resulting number of linear regions always equals batch_size * sample_batch.
Is this normal? If so, I wonder whether the linear-region measurement is meaningful.
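
The behavior has a simple explanation under one assumption about how the counter works: if linear regions are approximated by counting distinct ReLU activation patterns over the sampled inputs (as in lib/procedures/linear_region_counter.py), the count is capped at the number of samples, and expressive networks such as ResNets typically give every sample its own pattern, so the count saturates at batch_size * sample_batch. The sketch below is illustrative, not the repo's code.

import torch
import torch.nn as nn

def count_activation_patterns(net: nn.Module, samples: torch.Tensor) -> int:
    # Count distinct ReLU sign patterns among the sampled inputs; the result
    # can never exceed samples.shape[0], so it saturates for expressive nets.
    patterns, signs, hooks = set(), [], []

    def record(module, inputs, output):
        signs.append((output > 0).flatten())

    for m in net.modules():
        if isinstance(m, nn.ReLU):
            hooks.append(m.register_forward_hook(record))
    with torch.no_grad():
        for x in samples:
            signs.clear()
            net(x.unsqueeze(0))
            patterns.add(tuple(torch.cat(signs).tolist()))
    for h in hooks:
        h.remove()
    return len(patterns)

net = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU())
print(count_activation_patterns(net, torch.randn(32, 8)))  # usually prints 32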

Tried running your evolutionary code and hit an error

I cloned your repo and ran
python evolution_launch.py --space darts --dataset cifar10 --gpu 0

I got the following error:

/workspace/TEGNAS/lib/procedures/ntk.py:60: UserWarning: torch.symeig is deprecated in favor of torch.linalg.eigh and will be removed in a future PyTorch release.
The default behavior has changed from using the upper triangular portion of the matrix by default to using the lower triangular portion.
L, _ = torch.symeig(A, upper=upper)
should be replaced with
L = torch.linalg.eigvalsh(A, UPLO='U' if upper else 'L')
and
L, V = torch.symeig(A, eigenvectors=True)
should be replaced with
L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L') (Triggered internally at /opt/conda/conda-bld/pytorch_1634272178570/work/aten/src/ATen/native/BatchLinearAlgebra.cpp:2499.)
eigenvalues, _ = torch.symeig(ntk) # ascending
Traceback (most recent call last):
  File "./R_EA.py", line 458, in <module>
    main(args, nas_bench)
  File "./R_EA.py", line 408, in main
    model.accuracy, _, _time_cost_training = proxy_inference(xargs, model.arch, nas_bench, logger, -1, dataname, te_reward_generator)
  File "./R_EA.py", line 224, in proxy_inference
    _ = te_reward_generator.step(genotype2mask_darts(arch))
  File "/workspace/TEGNAS/lib/procedures/te_reward_generator.py", line 244, in step
    results = self.get_ntk_region_mse(self._xargs, arch_parameters, self._loader, self._region_model)
  File "/workspace/TEGNAS/lib/procedures/te_reward_generator.py", line 192, in get_ntk_region_mse
    LRs = region_model.forward_batch_sample()
  File "/workspace/TEGNAS/lib/procedures/linear_region_counter.py", line 266, in forward_batch_sample
    return [LRCount.getLinearReginCount() for LRCount in self.LRCounts]
  File "/workspace/TEGNAS/lib/procedures/linear_region_counter.py", line 266, in <listcomp>
    return [LRCount.getLinearReginCount() for LRCount in self.LRCounts]
  File "/workspace/TEGNAS/lib/procedures/linear_region_counter.py", line 169, in getLinearReginCount
    self.calc_LR()
  File "/opt/conda/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/TEGNAS/lib/procedures/linear_region_counter.py", line 145, in calc_LR
    res += res.T
RuntimeError: unsupported operation: some elements of the input tensor and the written-to tensor refer to a single memory location. Please clone() the tensor before performing the operation.
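
The two messages above point at two localized changes; the sketch below shows equivalents that run on recent PyTorch, using stand-in tensors since the exact surrounding code in ntk.py and linear_region_counter.py is not reproduced here.

import torch

# 1) ntk.py: torch.symeig is deprecated/removed; torch.linalg.eigvalsh returns
#    the same ascending eigenvalues for a symmetric matrix.
ntk = torch.randn(8, 8)
ntk = ntk @ ntk.t()                       # stand-in for the real NTK matrix
eigenvalues = torch.linalg.eigvalsh(ntk)  # replaces: eigenvalues, _ = torch.symeig(ntk)

# 2) linear_region_counter.py: `res += res.T` reads and writes overlapping
#    memory; compute the symmetrized sum out of place (or clone the transpose).
res = torch.randn(4, 4)
res = res + res.t()                       # replaces: res += res.T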
