
Searching for A Robust Neural Architecture in Four GPU Hours

We propose GDAS, a gradient-based neural architecture search approach that uses a Differentiable Architecture Sampler.
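
To make the idea of a differentiable architecture sampler concrete, here is a minimal, framework-free sketch of one common way to build one: Gumbel-softmax sampling over the candidate operations on a single edge. It assumes one learnable logit per candidate operation, and the helper names (`gumbel_softmax_sample`, `mixed_edge`) are hypothetical. It shows only the forward computation and is an illustration, not the code in this repository.

```python
import math
import random


def gumbel_softmax_sample(logits, tau=1.0, rng=random):
    """Sample one operation for a single edge.

    logits: one unnormalised score per candidate operation (a hypothetical
            stand-in for the learnable architecture parameters).
    tau:    temperature; smaller values push `soft` towards a one-hot vector.
    Returns (index, hard_one_hot, soft_weights).
    """
    perturbed = []
    for logit in logits:
        # Gumbel(0, 1) noise g = -log(-log(u)), u ~ Uniform(0, 1); clamp u away from 0.
        u = max(rng.random(), 1e-20)
        perturbed.append(logit - math.log(-math.log(u)))
    # Gumbel-Max: the argmax of the perturbed logits is an exact sample from softmax(logits).
    index = max(range(len(perturbed)), key=perturbed.__getitem__)
    hard = [1.0 if i == index else 0.0 for i in range(len(logits))]
    # Relaxed (differentiable) weights: softmax(perturbed / tau), computed stably.
    scaled = [p / tau for p in perturbed]
    top = max(scaled)
    exps = [math.exp(s - top) for s in scaled]
    total = sum(exps)
    soft = [e / total for e in exps]
    return index, hard, soft


def mixed_edge(x, ops, logits, tau=1.0, rng=random):
    """Evaluate only the sampled operation on an edge.

    In an autograd framework the forward value is typically hard[i] * op(x),
    with the backward pass routed through the soft weights (straight-through
    estimator); this pure-Python sketch only computes the forward value.
    """
    index, hard, _ = gumbel_softmax_sample(logits, tau, rng)
    return hard[index] * ops[index](x)
```

Because the Gumbel-Max step yields an exact sample from softmax(logits), only one operation per edge needs to be evaluated in the forward pass, which is what keeps this style of search cheap. Recent PyTorch versions provide `torch.nn.functional.gumbel_softmax(logits, tau, hard=True)` for the same purpose.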

Requirements

  • PyTorch 1.0.1
  • Python 3.6
  • OpenCV
conda install pytorch torchvision cuda100 -c pytorch
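
Before running the scripts, it can help to compare the local environment with this list. The sketch below reports differences without failing; it assumes the import names `torch`, `torchvision` and `cv2` for PyTorch, torchvision and OpenCV, and `check_environment` is a hypothetical helper. A newer Python or PyTorch may still work, so treat its output as a hint rather than a verdict.

```python
import importlib.util
import sys

# Assumed import names for the listed requirements (and torchvision from the conda line).
REQUIRED_MODULES = {"torch": "PyTorch", "torchvision": "torchvision", "cv2": "OpenCV"}


def check_environment(python_version=(3, 6), torch_version="1.0.1"):
    """Return a list of human-readable notes; an empty list means no differences."""
    notes = []
    found = sys.version_info[:2]
    if found != python_version:
        notes.append("Python {}.{} found; the README lists {}.{}".format(*found, *python_version))
    for module, label in REQUIRED_MODULES.items():
        # find_spec returns None when the module cannot be located.
        if importlib.util.find_spec(module) is None:
            notes.append("{} (import name '{}') is not installed".format(label, module))
    if importlib.util.find_spec("torch") is not None:
        import torch  # imported lazily so the check still runs without PyTorch

        if not torch.__version__.startswith(torch_version):
            notes.append("PyTorch {} found; the README lists {}".format(torch.__version__, torch_version))
        if not torch.cuda.is_available():
            notes.append("PyTorch cannot see a CUDA device")
    return notes


if __name__ == "__main__":
    for note in check_environment():
        print("-", note)
```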

Usage

Train the searched CNN on CIFAR

CUDA_VISIBLE_DEVICES=0 bash ./scripts-cnn/train-cifar.sh GDAS_FG cifar10  cut
CUDA_VISIBLE_DEVICES=0 bash ./scripts-cnn/train-cifar.sh GDAS_F1 cifar10  cut
CUDA_VISIBLE_DEVICES=0 bash ./scripts-cnn/train-cifar.sh GDAS_V1 cifar100 cut
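
The third argument, `cut`, is not explained here. A plausible reading, and it is only an assumption, is that it enables Cutout augmentation, which erases one randomly placed square patch of each training image. A minimal pure-Python sketch of that augmentation on a single-channel image (the `cutout` helper is hypothetical, and the patch length this repository uses is not stated here):

```python
import random


def cutout(image, length, center=None, rng=random):
    """Return a copy of `image` with one length x length square set to zero.

    image:  nested lists indexed as image[row][col] (one channel; apply the
            same mask to every channel of an RGB image).
    length: side of the square; it is clipped at the image border, so the
            erased area is smaller when the centre lies near an edge.
    center: optional (row, col); drawn uniformly at random when omitted.
    """
    height, width = len(image), len(image[0])
    if center is None:
        center = (rng.randrange(height), rng.randrange(width))
    cy, cx = center
    # Clip the square to the image bounds.
    y0, y1 = max(0, cy - length // 2), min(height, cy + length // 2)
    x0, x1 = max(0, cx - length // 2), min(width, cx + length // 2)
    out = [list(row) for row in image]  # leave the input untouched
    for y in range(y0, y1):
        for x in range(x0, x1):
            out[y][x] = 0
    return out
```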

Train the searched CNN on ImageNet

CUDA_VISIBLE_DEVICES=0 bash ./scripts-cnn/train-imagenet.sh GDAS_F1 52 14
CUDA_VISIBLE_DEVICES=0 bash ./scripts-cnn/train-imagenet.sh GDAS_V1 50 14
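
The ImageNet script takes two numeric arguments after the architecture name (52 and 14 for GDAS_F1, 50 and 14 for GDAS_V1) whose meaning is not spelled out here. The sketch below assumes, without confirmation from this README, that they are the initial channel count and the number of layers, and gives them those names in a hypothetical launcher so the two runs can be scripted:

```python
import os
import shlex
import subprocess


def imagenet_train_argv(arch, channels, layers):
    """argv for ./scripts-cnn/train-imagenet.sh.

    ASSUMPTION: the second and third positional arguments are the initial
    channel count and the number of layers; check the script itself.
    """
    return ["bash", "./scripts-cnn/train-imagenet.sh", arch, str(channels), str(layers)]


def launch(argv, gpu="0", dry_run=True):
    """Return the equivalent shell line; also run it when dry_run is False."""
    line = "CUDA_VISIBLE_DEVICES={} {}".format(gpu, " ".join(shlex.quote(a) for a in argv))
    if not dry_run:
        # Set the GPU for the child process only, without invoking a shell.
        env = dict(os.environ, CUDA_VISIBLE_DEVICES=gpu)
        subprocess.run(argv, env=env, check=True)
    return line


# The two configurations listed above (argument names are assumptions).
CONFIGS = [("GDAS_F1", 52, 14), ("GDAS_V1", 50, 14)]

if __name__ == "__main__":
    for arch, channels, layers in CONFIGS:
        print(launch(imagenet_train_argv(arch, channels, layers)))
```

Calling `launch(argv, dry_run=False)` runs the script through `subprocess.run` with `CUDA_VISIBLE_DEVICES` set in the child environment instead of going through a shell.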

Evaluate a trained CNN model

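Set CHECKPOINT_PATH to the trained checkpoint file first; the first command evaluates a CIFAR model and the second an ImageNet (ILSVRC2012) model.

CUDA_VISIBLE_DEVICES=0 python ./exps-cnn/evaluate.py --data_path $TORCH_HOME/cifar.python --checkpoint ${CHECKPOINT_PATH}
CUDA_VISIBLE_DEVICES=0 python ./exps-cnn/evaluate.py --data_path $TORCH_HOME/ILSVRC2012 --checkpoint ${CHECKPOINT_PATH}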
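
CIFAR and ImageNet classifiers are conventionally reported by top-1 accuracy and, for ImageNet, top-5 accuracy. What `evaluate.py` prints is not described here, so the following is only a sketch of how those numbers are usually computed from per-class scores; `topk_accuracy` is a hypothetical helper:

```python
def topk_accuracy(scores, labels, ks=(1, 5)):
    """Percentage of samples whose true label is among the k highest scores.

    scores: list of per-class score lists, one per sample.
    labels: list of integer class indices, one per sample.
    Returns {k: accuracy_in_percent}.
    """
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    if not labels:
        raise ValueError("need at least one sample")
    hits = {k: 0 for k in ks}
    for row, label in zip(scores, labels):
        # Class indices sorted by descending score.
        ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)
        for k in ks:
            if label in ranked[:k]:
                hits[k] += 1
    return {k: 100.0 * hits[k] / len(labels) for k in ks}
```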

Train the searched RNNs on PTB and WikiText-2 (WT2)

CUDA_VISIBLE_DEVICES=0 bash ./scripts-rnn/train-PTB.sh DARTS_V1
CUDA_VISIBLE_DEVICES=0 bash ./scripts-rnn/train-PTB.sh DARTS_V2
CUDA_VISIBLE_DEVICES=0 bash ./scripts-rnn/train-PTB.sh GDAS
CUDA_VISIBLE_DEVICES=0 bash ./scripts-rnn/train-WT2.sh DARTS_V1
CUDA_VISIBLE_DEVICES=0 bash ./scripts-rnn/train-WT2.sh DARTS_V2
CUDA_VISIBLE_DEVICES=0 bash ./scripts-rnn/train-WT2.sh GDAS
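
Language models on Penn Treebank (PTB) and WikiText-2 (WT2) are conventionally compared by perplexity, the exponential of the average per-token negative log-likelihood. A minimal sketch of that definition (the `perplexity` helper is hypothetical and assumes natural-log token probabilities):

```python
import math


def perplexity(token_log_probs):
    """Perplexity from the natural-log probability of each target token.

    perplexity = exp(-(1/N) * sum_i log p(w_i | w_<i))
    """
    if not token_log_probs:
        raise ValueError("need at least one token")
    avg_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(avg_nll)
```

In a PyTorch training loop that computes a mean-reduced cross-entropy over tokens, the same quantity is `math.exp(loss.item())`.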

Training Logs

Some training logs can be found in ./data/logs/, and some pre-trained models are available on Google Drive.

Citation

@inproceedings{dong2019search,
  title={Searching for A Robust Neural Architecture in Four GPU Hours},
  author={Dong, Xuanyi and Yang, Yi},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2019}
}