Merge branch 'master' of github.com:D-X-Y/AutoDL-Projects

This commit is contained in:
D-X-Y 2020-10-29 12:35:04 -07:00
commit d58b59a3f3
15 changed files with 1145 additions and 607 deletions

View File

@ -6,3 +6,4 @@
- [2019.01.31] [13e908f] GDAS codes were publicly released. - [2019.01.31] [13e908f] GDAS codes were publicly released.
- [2020.07.01] [a45808b] Upgrade NAS-API to the 2.0 version. - [2020.07.01] [a45808b] Upgrade NAS-API to the 2.0 version.
- [2020.09.16] [7052265] Create NATS-BENCH. - [2020.09.16] [7052265] Create NATS-BENCH.
- [2020.10.15] [446262a] Update NATS-BENCH to version 1.0

View File

@ -61,7 +61,7 @@ At this moment, this project provides the following algorithms and scripts to ru
</tr> </tr>
<tr> <!-- (6-th row) --> <tr> <!-- (6-th row) -->
<td align="center" valign="middle"> NATS-Bench </td> <td align="center" valign="middle"> NATS-Bench </td>
<td align="center" valign="middle"> <a href="https://xuanyidong.com/assets/projects/NATS-Bench"> NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size</a> </td> <td align="center" valign="middle"> <a href="https://xuanyidong.com/assets/projects/NATS-Bench"> NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size</a> </td>
<td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/master/docs/NATS-Bench.md">NATS-Bench.md</a> </td> <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/master/docs/NATS-Bench.md">NATS-Bench.md</a> </td>
</tr> </tr>
<tr> <!-- (7-th row) --> <tr> <!-- (7-th row) -->
@ -100,7 +100,7 @@ Some methods use knowledge distillation (KD), which require pre-trained models.
If you find that this project helps your research, please consider citing some of the following papers: If you find that this project helps your research, please consider citing some of the following papers:
``` ```
@article{dong2020nats, @article{dong2020nats,
title={NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size}, title={{NATS-Bench}: Benchmarking NAS Algorithms for Architecture Topology and Size},
author={Dong, Xuanyi and Liu, Lu and Musial, Katarzyna and Gabrys, Bogdan}, author={Dong, Xuanyi and Liu, Lu and Musial, Katarzyna and Gabrys, Bogdan},
journal={arXiv preprint arXiv:2009.00437}, journal={arXiv preprint arXiv:2009.00437},
year={2020} year={2020}

View File

@ -61,7 +61,7 @@
</tr> </tr>
<tr> <!-- (6-th row) --> <tr> <!-- (6-th row) -->
<td align="center" valign="middle"> NATS-Bench </td> <td align="center" valign="middle"> NATS-Bench </td>
<td align="center" valign="middle"> <a href="https://xuanyidong.com/assets/projects/NATS-Bench"> NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size</a> </td> <td align="center" valign="middle"> <a href="https://xuanyidong.com/assets/projects/NATS-Bench"> NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size</a> </td>
<td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/master/docs/NATS-Bench.md">NATS-Bench.md</a> </td> <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/master/docs/NATS-Bench.md">NATS-Bench.md</a> </td>
</tr> </tr>
<tr> <!-- (7-th row) --> <tr> <!-- (7-th row) -->
@ -99,7 +99,7 @@ Some methods use knowledge distillation (KD), which require pre-trained models.
如果您发现该项目对您的科研或工程有帮助,请考虑引用下列的某些文献: 如果您发现该项目对您的科研或工程有帮助,请考虑引用下列的某些文献:
``` ```
@article{dong2020nats, @article{dong2020nats,
title={NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size}, title={{NATS-Bench}: Benchmarking NAS Algorithms for Architecture Topology and Size},
author={Dong, Xuanyi and Liu, Lu and Musial, Katarzyna and Gabrys, Bogdan}, author={Dong, Xuanyi and Liu, Lu and Musial, Katarzyna and Gabrys, Bogdan},
journal={arXiv preprint arXiv:2009.00437}, journal={arXiv preprint arXiv:2009.00437},
year={2020} year={2020}

View File

@ -1,5 +1,7 @@
# [NAS-BENCH-201: Extending the Scope of Reproducible Neural Architecture Search](https://openreview.net/forum?id=HJxyZkBKDr) # [NAS-BENCH-201: Extending the Scope of Reproducible Neural Architecture Search](https://openreview.net/forum?id=HJxyZkBKDr)
**Since our NAS-BENCH-201 has been extended to NATS-Bench, this `README` is deprecated and not maintained. Please use [NATS-Bench](https://github.com/D-X-Y/AutoDL-Projects/blob/master/docs/NATS-Bench.md), which has 5x more architecture information and faster API than NAS-BENCH-201.**
We propose an algorithm-agnostic NAS benchmark (NAS-Bench-201) with a fixed search space, which provides a unified benchmark for almost any up-to-date NAS algorithms. We propose an algorithm-agnostic NAS benchmark (NAS-Bench-201) with a fixed search space, which provides a unified benchmark for almost any up-to-date NAS algorithms.
The design of our search space is inspired by that used in the most popular cell-based searching algorithms, where a cell is represented as a directed acyclic graph. The design of our search space is inspired by that used in the most popular cell-based searching algorithms, where a cell is represented as a directed acyclic graph.
Each edge here is associated with an operation selected from a predefined operation set. Each edge here is associated with an operation selected from a predefined operation set.
@ -70,17 +72,18 @@ api.show(2)
# show the mean loss and accuracy of an architecture # show the mean loss and accuracy of an architecture
info = api.query_meta_info_by_index(1) # This is an instance of `ArchResults` info = api.query_meta_info_by_index(1) # This is an instance of `ArchResults`
res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric names as keys res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric names as keys
cost_metrics = info.get_comput_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency cost_metrics = info.get_compute_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency
# get the detailed information # get the detailed information
results = api.query_by_index(1, 'cifar100') # a dict of all trials for 1st net on cifar100, where the key is the seed results = api.query_by_index(1, 'cifar100') # a dict of all trials for 1st net on cifar100, where the key is the seed
print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1])) print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1]))
print ('Latency : {:}'.format(results[0].get_latency())) for seed, result in results.items():
print ('Train Info : {:}'.format(results[0].get_train())) print ('Latency : {:}'.format(result.get_latency()))
print ('Valid Info : {:}'.format(results[0].get_eval('x-valid'))) print ('Train Info : {:}'.format(result.get_train()))
print ('Test Info : {:}'.format(results[0].get_eval('x-test'))) print ('Valid Info : {:}'.format(result.get_eval('x-valid')))
# for the metric after a specific epoch print ('Test Info : {:}'.format(result.get_eval('x-test')))
print ('Train Info [10-th epoch] : {:}'.format(results[0].get_train(10))) # for the metric after a specific epoch
print ('Train Info [10-th epoch] : {:}'.format(result.get_train(10)))
``` ```
4. Query the index of an architecture by string 4. Query the index of an architecture by string
@ -171,7 +174,7 @@ api.get_more_info(112, 'ImageNet16-120', None, hp='200', is_random=True)
If you find that NAS-Bench-201 helps your research, please consider citing it: If you find that NAS-Bench-201 helps your research, please consider citing it:
``` ```
@inproceedings{dong2020nasbench201, @inproceedings{dong2020nasbench201,
title = {NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search}, title = {{NAS-Bench-201}: Extending the Scope of Reproducible Neural Architecture Search},
author = {Dong, Xuanyi and Yang, Yi}, author = {Dong, Xuanyi and Yang, Yi},
booktitle = {International Conference on Learning Representations (ICLR)}, booktitle = {International Conference on Learning Representations (ICLR)},
url = {https://openreview.net/forum?id=HJxyZkBKDr}, url = {https://openreview.net/forum?id=HJxyZkBKDr},

View File

@ -1,5 +1,7 @@
# [NAS-BENCH-201: Extending the Scope of Reproducible Neural Architecture Search](https://openreview.net/forum?id=HJxyZkBKDr) # [NAS-BENCH-201: Extending the Scope of Reproducible Neural Architecture Search](https://openreview.net/forum?id=HJxyZkBKDr)
**Since our NAS-BENCH-201 has been extended to NATS-Bench, this README is deprecated and not maintained. Please use [NATS-Bench](https://github.com/D-X-Y/AutoDL-Projects/blob/master/docs/NATS-Bench.md), which has 5x more architecture information and faster API than NAS-BENCH-201.**
We propose an algorithm-agnostic NAS benchmark (NAS-Bench-201) with a fixed search space, which provides a unified benchmark for almost any up-to-date NAS algorithms. We propose an algorithm-agnostic NAS benchmark (NAS-Bench-201) with a fixed search space, which provides a unified benchmark for almost any up-to-date NAS algorithms.
The design of our search space is inspired by that used in the most popular cell-based searching algorithms, where a cell is represented as a directed acyclic graph. The design of our search space is inspired by that used in the most popular cell-based searching algorithms, where a cell is represented as a directed acyclic graph.
Each edge here is associated with an operation selected from a predefined operation set. Each edge here is associated with an operation selected from a predefined operation set.
@ -68,17 +70,18 @@ api.show(2)
# show the mean loss and accuracy of an architecture # show the mean loss and accuracy of an architecture
info = api.query_meta_info_by_index(1) # This is an instance of `ArchResults` info = api.query_meta_info_by_index(1) # This is an instance of `ArchResults`
res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric names as keys res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric names as keys
cost_metrics = info.get_comput_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency cost_metrics = info.get_compute_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency
# get the detailed information # get the detailed information
results = api.query_by_index(1, 'cifar100') # a dict of all trials for 1st net on cifar100, where the key is the seed results = api.query_by_index(1, 'cifar100') # a dict of all trials for 1st net on cifar100, where the key is the seed
print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1])) print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1]))
print ('Latency : {:}'.format(results[0].get_latency())) for seed, result in results.items():
print ('Train Info : {:}'.format(results[0].get_train())) print ('Latency : {:}'.format(result.get_latency()))
print ('Valid Info : {:}'.format(results[0].get_eval('x-valid'))) print ('Train Info : {:}'.format(result.get_train()))
print ('Test Info : {:}'.format(results[0].get_eval('x-test'))) print ('Valid Info : {:}'.format(result.get_eval('x-valid')))
# for the metric after a specific epoch print ('Test Info : {:}'.format(result.get_eval('x-test')))
print ('Train Info [10-th epoch] : {:}'.format(results[0].get_train(10))) # for the metric after a specific epoch
print ('Train Info [10-th epoch] : {:}'.format(result.get_train(10)))
``` ```
4. Query the index of an architecture by string 4. Query the index of an architecture by string
@ -242,7 +245,7 @@ In commands [1-6], the first args `cifar10` indicates the dataset name, the seco
If you find that NAS-Bench-201 helps your research, please consider citing it: If you find that NAS-Bench-201 helps your research, please consider citing it:
``` ```
@inproceedings{dong2020nasbench201, @inproceedings{dong2020nasbench201,
title = {NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search}, title = {{NAS-Bench-201}: Extending the Scope of Reproducible Neural Architecture Search},
author = {Dong, Xuanyi and Yang, Yi}, author = {Dong, Xuanyi and Yang, Yi},
booktitle = {International Conference on Learning Representations (ICLR)}, booktitle = {International Conference on Learning Representations (ICLR)},
url = {https://openreview.net/forum?id=HJxyZkBKDr}, url = {https://openreview.net/forum?id=HJxyZkBKDr},

View File

@ -1,4 +1,4 @@
# [NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size](https://arxiv.org/pdf/2009.00437.pdf) # [NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size](https://arxiv.org/pdf/2009.00437.pdf)
Neural architecture search (NAS) has attracted a lot of attention and has been illustrated to bring tangible benefits in a large number of applications in the past few years. Network topology and network size have been regarded as two of the most important aspects for the performance of deep learning models and the community has spawned lots of searching algorithms for both of those aspects of the neural architectures. However, the performance gain from these searching algorithms is achieved under different search spaces and training setups. This makes the overall performance of the algorithms incomparable and the improvement from a sub-module of the searching model unclear. Neural architecture search (NAS) has attracted a lot of attention and has been illustrated to bring tangible benefits in a large number of applications in the past few years. Network topology and network size have been regarded as two of the most important aspects for the performance of deep learning models and the community has spawned lots of searching algorithms for both of those aspects of the neural architectures. However, the performance gain from these searching algorithms is achieved under different search spaces and training setups. This makes the overall performance of the algorithms incomparable and the improvement from a sub-module of the searching model unclear.
In this paper, we propose NATS-Bench, a unified benchmark on searching for both topology and size, for (almost) any up-to-date NAS algorithm. In this paper, we propose NATS-Bench, a unified benchmark on searching for both topology and size, for (almost) any up-to-date NAS algorithm.
@ -7,11 +7,12 @@ We analyze the validity of our benchmark in terms of various criteria and perfor
We also show the versatility of NATS-Bench by benchmarking 13 recent state-of-the-art NAS algorithms on it. All logs and diagnostic information trained using the same setup for each candidate are provided. We also show the versatility of NATS-Bench by benchmarking 13 recent state-of-the-art NAS algorithms on it. All logs and diagnostic information trained using the same setup for each candidate are provided.
This facilitates a much larger community of researchers to focus on developing better NAS algorithms in a more comparable and computationally effective environment. This facilitates a much larger community of researchers to focus on developing better NAS algorithms in a more comparable and computationally effective environment.
**You can use `pip install nats_bench` to install the library of NATS-Bench.**
The structure of this Markdown file: The structure of this Markdown file:
- [How to use NATS-Bench?](#How-to-Use-NATS-Bench) - [How to use NATS-Bench?](#How-to-Use-NATS-Bench)
- [How to re-create NATS-Bench from scratch?](#how-to-re-create-nats-bench-from-scratch) - [How to re-create NATS-Bench from scratch?](#how-to-re-create-nats-bench-from-scratch)
- [How to reproduce benchmarked results?](#to-reproduce-13-baseline-nas-algorithms-in-nas-bench-201) - [How to reproduce benchmarked results?](#to-reproduce-13-baseline-nas-algorithms-in-nats-bench)
## How to Use [NATS-Bench](https://arxiv.org/pdf/2009.00437.pdf) ## How to Use [NATS-Bench](https://arxiv.org/pdf/2009.00437.pdf)
@ -79,8 +80,12 @@ params = api.get_net_param(12, 'cifar10', None)
network.load_state_dict(next(iter(params.values()))) network.load_state_dict(next(iter(params.values())))
``` ```
## How to Re-create NATS-Bench from Scratch ## How to Re-create NATS-Bench from Scratch
You need to use the [AutoDL-Projects](https://github.com/D-X-Y/AutoDL-Projects) repo to re-create NATS-Bench from scratch.
### The Size Search Space ### The Size Search Space
The following command will train all architecture candidate in the size search space with 90 epochs and use the random seed of `777`. The following command will train all architecture candidate in the size search space with 90 epochs and use the random seed of `777`.
@ -110,7 +115,9 @@ python exps/NATS-Bench/tss-collect.py
``` ```
## To Reproduce 13 Baseline NAS Algorithms in NAS-Bench-201 ## To Reproduce 13 Baseline NAS Algorithms in NATS-Bench
You need to use the [AutoDL-Projects](https://github.com/D-X-Y/AutoDL-Projects) repo to run 13 baseline NAS methods.
### Reproduce NAS methods on the topology search space ### Reproduce NAS methods on the topology search space
@ -171,18 +178,18 @@ python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HO
python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777 python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777
Run the search strategy in FBNet-V2 Run the channel search strategy in FBNet-V2 -- masking + Gumbel-Softmax :
python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777 python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777 python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed 777 python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_gumbel --rand_seed 777
Run the search strategy in TuNAS: Run the channel search strategy in TuNAS -- masking + sampling :
python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 --use_api 0 python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777 --use_api 0
python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --rand_seed 777 python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_rl --arch_weight_decay 0 --rand_seed 777
``` ```
### Final Discovered Architectures for Each Algorithm ### Final Discovered Architectures for Each Algorithm
@ -246,7 +253,7 @@ GDAS:
If you find that NATS-Bench helps your research, please consider citing it: If you find that NATS-Bench helps your research, please consider citing it:
``` ```
@article{dong2020nats, @article{dong2020nats,
title={NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size}, title={{NATS-Bench}: Benchmarking NAS Algorithms for Architecture Topology and Size},
author={Dong, Xuanyi and Liu, Lu and Musial, Katarzyna and Gabrys, Bogdan}, author={Dong, Xuanyi and Liu, Lu and Musial, Katarzyna and Gabrys, Bogdan},
journal={arXiv preprint arXiv:2009.00437}, journal={arXiv preprint arXiv:2009.00437},
year={2020} year={2020}

View File

@ -1,28 +1,30 @@
################################################## ##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
########################################################################################################################################### ###########################################################################################################################################
#
# In this file, we aims to evaluate three kinds of channel searching strategies: # In this file, we aims to evaluate three kinds of channel searching strategies:
# - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019" # - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019"
# - masking + Gumbel-Softmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020" # - masking + Gumbel-Softmax (mask_gumbel) from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
# - masking + sampling from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020" # - masking + sampling (mask_rl) from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
# For simplicity, we use tas, fbv2, and tunas to refer these three strategies. Their official implementations are at the following links: #
# For simplicity, we use tas, mask_gumbel, and mask_rl to refer these three strategies. Their official implementations are at the following links:
# - TAS: https://github.com/D-X-Y/AutoDL-Projects/blob/master/docs/NeurIPS-2019-TAS.md # - TAS: https://github.com/D-X-Y/AutoDL-Projects/blob/master/docs/NeurIPS-2019-TAS.md
# - FBNetV2: https://github.com/facebookresearch/mobile-vision # - FBNetV2: https://github.com/facebookresearch/mobile-vision
# - TuNAS: https://github.com/google-research/google-research/tree/master/tunas # - TuNAS: https://github.com/google-research/google-research/tree/master/tunas
#### ####
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio 0.25 # python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --warmup_ratio 0.25
#### ####
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777 # python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777 # python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777 # python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777
#### ####
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777 # python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777 # python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed 777 # python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_gumbel --rand_seed 777
#### ####
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 --use_api 0 # python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777 --use_api 0
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 # python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --rand_seed 777 # python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_rl --arch_weight_decay 0 --rand_seed 777
########################################################################################################################################### ###########################################################################################################################################
import os, sys, time, random, argparse import os, sys, time, random, argparse
import numpy as np import numpy as np
@ -41,7 +43,7 @@ from models import get_cell_based_tiny_net, get_search_spaces
from nats_bench import create from nats_bench import create
# Ad-hoc for TuNAS # Ad-hoc for RL algorithms.
class ExponentialMovingAverage(object): class ExponentialMovingAverage(object):
"""Class that maintains an exponential moving average.""" """Class that maintains an exponential moving average."""
@ -94,13 +96,13 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
a_optimizer.zero_grad() a_optimizer.zero_grad()
_, logits, log_probs = network(arch_inputs) _, logits, log_probs = network(arch_inputs)
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
if algo == 'tunas': if algo == 'mask_rl':
with torch.no_grad(): with torch.no_grad():
RL_BASELINE_EMA.update(arch_prec1.item()) RL_BASELINE_EMA.update(arch_prec1.item())
rl_advantage = arch_prec1 - RL_BASELINE_EMA.value rl_advantage = arch_prec1 - RL_BASELINE_EMA.value
rl_log_prob = sum(log_probs) rl_log_prob = sum(log_probs)
arch_loss = - rl_advantage * rl_log_prob arch_loss = - rl_advantage * rl_log_prob
elif algo == 'tas' or algo == 'fbv2': elif algo == 'tas' or algo == 'mask_gumbel':
arch_loss = criterion(logits, arch_targets) arch_loss = criterion(logits, arch_targets)
else: else:
raise ValueError('invalid algorightm name: {:}'.format(algo)) raise ValueError('invalid algorightm name: {:}'.format(algo))
@ -231,7 +233,7 @@ def main(xargs):
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}, controller-warmup={:}, enable_controller={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()), network.warmup_ratio, enable_controller)) logger.log('\n[Search the {:}-th epoch] {:}, LR={:}, controller-warmup={:}, enable_controller={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()), network.warmup_ratio, enable_controller))
if xargs.algo == 'fbv2' or xargs.algo == 'tas': if xargs.algo == 'mask_gumbel' or xargs.algo == 'tas':
network.set_tau(xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1)) network.set_tau(xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1))
logger.log('[RESET tau as : {:}]'.format(network.tau)) logger.log('[RESET tau as : {:}]'.format(network.tau))
search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \ search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
@ -291,7 +293,7 @@ if __name__ == '__main__':
parser.add_argument('--data_path' , type=str, help='Path to dataset') parser.add_argument('--data_path' , type=str, help='Path to dataset')
parser.add_argument('--dataset' , type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') parser.add_argument('--dataset' , type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
parser.add_argument('--search_space', type=str, default='sss', choices=['sss'], help='The search space name.') parser.add_argument('--search_space', type=str, default='sss', choices=['sss'], help='The search space name.')
parser.add_argument('--algo' , type=str, choices=['tas', 'fbv2', 'tunas'], help='The search space name.') parser.add_argument('--algo' , type=str, choices=['tas', 'mask_gumbel', 'mask_rl'], help='The search space name.')
parser.add_argument('--genotype' , type=str, default='|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|', help='The genotype.') parser.add_argument('--genotype' , type=str, default='|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|', help='The genotype.')
parser.add_argument('--use_api' , type=int, default=1, choices=[0,1], help='Whether use API or not (which will cost much memory).') parser.add_argument('--use_api' , type=int, default=1, choices=[0,1], help='Whether use API or not (which will cost much memory).')
# FOR GDAS # FOR GDAS

View File

@ -43,9 +43,9 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suf
# alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix) # alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix)
# alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix) # alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix)
# alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix) # alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix)
alg2name['channel-wise interpaltion'] = 'tas-affine0_BN0-AWD0.001{:}'.format(suffix) alg2name['channel-wise interpolation'] = 'tas-affine0_BN0-AWD0.001{:}'.format(suffix)
alg2name['masking + Gumbel-Softmax'] = 'fbv2-affine0_BN0-AWD0.001{:}'.format(suffix) alg2name['masking + Gumbel-Softmax'] = 'mask_gumbel-affine0_BN0-AWD0.001{:}'.format(suffix)
alg2name['masking + sampling'] = 'tunas-affine0_BN0-AWD0.0{:}'.format(suffix) alg2name['masking + sampling'] = 'mask_rl-affine0_BN0-AWD0.0{:}'.format(suffix)
for alg, name in alg2name.items(): for alg, name in alg2name.items():
alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth') alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth')
alg2data = OrderedDict() alg2data = OrderedDict()

View File

@ -3,8 +3,8 @@
##################################################### #####################################################
# Here, we utilized three techniques to search for the number of channels: # Here, we utilized three techniques to search for the number of channels:
# - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019" # - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019"
# - masking + Gumbel-Softmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020" # - masking + Gumbel-Softmax (mask_gumbel) from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
# - masking + sampling from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020" # - masking + sampling (mask_rl) from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
from typing import List, Text, Any from typing import List, Text, Any
import random, torch import random, torch
import torch.nn as nn import torch.nn as nn
@ -52,10 +52,10 @@ class GenericNAS301Model(nn.Module):
def set_algo(self, algo: Text): def set_algo(self, algo: Text):
# used for searching # used for searching
assert self._algo is None, 'This functioin can only be called once.' assert self._algo is None, 'This functioin can only be called once.'
assert algo in ['fbv2', 'tunas', 'tas'], 'invalid algo : {:}'.format(algo) assert algo in ['mask_gumbel', 'mask_rl', 'tas'], 'invalid algo : {:}'.format(algo)
self._algo = algo self._algo = algo
self._arch_parameters = nn.Parameter(1e-3*torch.randn(self._max_num_Cs, len(self._candidate_Cs))) self._arch_parameters = nn.Parameter(1e-3*torch.randn(self._max_num_Cs, len(self._candidate_Cs)))
# if algo == 'fbv2' or algo == 'tunas': # if algo == 'mask_gumbel' or algo == 'mask_rl':
self.register_buffer('_masks', torch.zeros(len(self._candidate_Cs), max(self._candidate_Cs))) self.register_buffer('_masks', torch.zeros(len(self._candidate_Cs), max(self._candidate_Cs)))
for i in range(len(self._candidate_Cs)): for i in range(len(self._candidate_Cs)):
self._masks.data[i, :self._candidate_Cs[i]] = 1 self._masks.data[i, :self._candidate_Cs[i]] = 1
@ -130,7 +130,7 @@ class GenericNAS301Model(nn.Module):
else: else:
mask = self._masks[random.randint(0, len(self._masks)-1)] mask = self._masks[random.randint(0, len(self._masks)-1)]
feature = feature * mask.view(1, -1, 1, 1) feature = feature * mask.view(1, -1, 1, 1)
elif self._algo == 'fbv2': elif self._algo == 'mask_gumbel':
weights = nn.functional.gumbel_softmax(self._arch_parameters[idx:idx+1], tau=self.tau, dim=-1) weights = nn.functional.gumbel_softmax(self._arch_parameters[idx:idx+1], tau=self.tau, dim=-1)
mask = torch.matmul(weights, self._masks).view(1, -1, 1, 1) mask = torch.matmul(weights, self._masks).view(1, -1, 1, 1)
feature = feature * mask feature = feature * mask
@ -148,7 +148,7 @@ class GenericNAS301Model(nn.Module):
else: else:
miss = torch.zeros(feature.shape[0], feature.shape[1]-out.shape[1], feature.shape[2], feature.shape[3], device=feature.device) miss = torch.zeros(feature.shape[0], feature.shape[1]-out.shape[1], feature.shape[2], feature.shape[3], device=feature.device)
feature = torch.cat((out, miss), dim=1) feature = torch.cat((out, miss), dim=1)
elif self._algo == 'tunas': elif self._algo == 'mask_rl':
prob = nn.functional.softmax(self._arch_parameters[idx:idx+1], dim=-1) prob = nn.functional.softmax(self._arch_parameters[idx:idx+1], dim=-1)
dist = torch.distributions.Categorical(prob) dist = torch.distributions.Categorical(prob)
action = dist.sample() action = dist.sample()

View File

@ -3,15 +3,18 @@
############################################################################## ##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size # # NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
############################################################################## ##############################################################################
# The official Application Programming Interface (API) for NATS-Bench. # """The official Application Programming Interface (API) for NATS-Bench."""
############################################################################## from nats_bench.api_size import NATSsize
from .api_utils import pickle_save, pickle_load from nats_bench.api_topology import NATStopology
from .api_utils import ArchResults, ResultsCount from nats_bench.api_utils import ArchResults
from .api_topology import NATStopology from nats_bench.api_utils import pickle_load
from .api_size import NATSsize from nats_bench.api_utils import pickle_save
from nats_bench.api_utils import ResultsCount
NATS_BENCH_API_VERSIONs = ['v1.0'] # [2020.08.31] NATS_BENCH_API_VERSIONs = ['v1.0'] # [2020.08.31]
NATS_BENCH_SSS_NAMEs = ('sss', 'size')
NATS_BENCH_TSS_NAMEs = ('tss', 'topology')
def version(): def version():
@ -24,13 +27,43 @@ def create(file_path_or_dict, search_space, fast_mode=False, verbose=True):
Args: Args:
file_path_or_dict: None or a file path or a directory path. file_path_or_dict: None or a file path or a directory path.
search_space: This is a string indicates the search space in NATS-Bench. search_space: This is a string indicates the search space in NATS-Bench.
fast_mode: If True, we will not load all the data at initialization, instead, the data for each candidate architecture will be loaded when quering it; fast_mode: If True, we will not load all the data at initialization,
If False, we will load all the data during initialization. instead, the data for each candidate architecture will be loaded when
quering it; If False, we will load all the data during initialization.
verbose: This is a flag to indicate whether log additional information. verbose: This is a flag to indicate whether log additional information.
Raises:
ValueError: If not find the matched serach space description.
Returns:
The created NATS-Bench API.
""" """
if search_space in ['tss', 'topology']: if search_space in NATS_BENCH_TSS_NAMEs:
return NATStopology(file_path_or_dict, fast_mode, verbose) return NATStopology(file_path_or_dict, fast_mode, verbose)
elif search_space in ['sss', 'size']: elif search_space in NATS_BENCH_SSS_NAMEs:
return NATSsize(file_path_or_dict, fast_mode, verbose) return NATSsize(file_path_or_dict, fast_mode, verbose)
else: else:
raise ValueError('invalid search space : {:}'.format(search_space)) raise ValueError('invalid search space : {:}'.format(search_space))
def search_space_info(main_tag, aux_tag):
"""Obtain the search space information."""
nats_sss = dict(candidates=[8, 16, 24, 32, 40, 48, 56, 64],
num_layers=5)
nats_tss = dict(op_names=['none', 'skip_connect',
'nor_conv_1x1', 'nor_conv_3x3',
'avg_pool_3x3'],
num_nodes=4)
if main_tag == 'nats-bench':
if aux_tag in NATS_BENCH_SSS_NAMEs:
return nats_sss
elif aux_tag in NATS_BENCH_TSS_NAMEs:
return nats_tss
else:
raise ValueError('Unknown auxiliary tag: {:}'.format(aux_tag))
elif main_tag == 'nas-bench-201':
if aux_tag is not None:
raise ValueError('For NAS-Bench-201, the auxiliary tag should be None.')
return nats_tss
else:
raise ValueError('Unknown main tag: {:}'.format(main_tag))

View File

@ -1,65 +1,84 @@
##################################################### #####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
############################################################################## ##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size # # NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
##################################################################################### ##############################################################################
# The history of benchmark files (the name is NATS-sss-[version]-[md5].pickle.pbz2) # # The history of benchmark files are as follows, #
# [2020.08.31] NATS-sss-v1_0-50262.pickle.pbz2 # # where the format is (the name is NATS-sss-[version]-[md5].pickle.pbz2) #
##################################################################################### # [2020.08.31] NATS-sss-v1_0-50262.pickle.pbz2 #
import os, copy, random, numpy as np ##############################################################################
from typing import List, Text, Union, Dict, Optional # pylint: disable=line-too-long
from collections import OrderedDict, defaultdict """The API for size search space in NATS-Bench."""
from .api_utils import time_string import collections
from .api_utils import pickle_load import copy
from .api_utils import ArchResults import os
from .api_utils import NASBenchMetaAPI import random
from .api_utils import remap_dataset_set_names from typing import Dict, Optional, Text, Union, Any
from .api_utils import nats_is_dir
from .api_utils import nats_is_file from nats_bench.api_utils import ArchResults
from .api_utils import PICKLE_EXT from nats_bench.api_utils import NASBenchMetaAPI
from nats_bench.api_utils import nats_is_dir
from nats_bench.api_utils import nats_is_file
from nats_bench.api_utils import PICKLE_EXT
from nats_bench.api_utils import pickle_load
from nats_bench.api_utils import time_string
ALL_BASE_NAMES = ['NATS-sss-v1_0-50262'] ALL_BASE_NAMES = ['NATS-sss-v1_0-50262']
def print_information(information, extra_info=None, show=False): def print_information(information, extra_info=None, show=False):
"""print out the information of a given ArchResults."""
dataset_names = information.get_dataset_names() dataset_names = information.get_dataset_names()
strings = [information.arch_str, 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)] strings = [
information.arch_str,
'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)
]
def metric2str(loss, acc): def metric2str(loss, acc):
return 'loss = {:.3f} & top1 = {:.2f}%'.format(loss, acc) return 'loss = {:.3f} & top1 = {:.2f}%'.format(loss, acc)
for ida, dataset in enumerate(dataset_names): for dataset in dataset_names:
metric = information.get_compute_costs(dataset) metric = information.get_compute_costs(dataset)
flop, param, latency = metric['flops'], metric['params'], metric['latency'] flop, param, latency = metric['flops'], metric['params'], metric['latency']
str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency is not None and latency > 0 else None) str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(
dataset, flop, param,
'{:.2f}'.format(latency *
1000) if latency is not None and latency > 0 else None)
train_info = information.get_metrics(dataset, 'train') train_info = information.get_metrics(dataset, 'train')
if dataset == 'cifar10-valid': if dataset == 'cifar10-valid':
valid_info = information.get_metrics(dataset, 'x-valid') valid_info = information.get_metrics(dataset, 'x-valid')
test__info = information.get_metrics(dataset, 'ori-test') test__info = information.get_metrics(dataset, 'ori-test')
str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format( str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(
dataset, metric2str(train_info['loss'], train_info['accuracy']), dataset, metric2str(train_info['loss'], train_info['accuracy']),
metric2str(valid_info['loss'], valid_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']),
metric2str(test__info['loss'], test__info['accuracy'])) metric2str(test__info['loss'], test__info['accuracy']))
elif dataset == 'cifar10': elif dataset == 'cifar10':
test__info = information.get_metrics(dataset, 'ori-test') test__info = information.get_metrics(dataset, 'ori-test')
str2 = '{:14s} train : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy'])) str2 = '{:14s} train : [{:}], test : [{:}]'.format(
dataset, metric2str(train_info['loss'], train_info['accuracy']),
metric2str(test__info['loss'], test__info['accuracy']))
else: else:
valid_info = information.get_metrics(dataset, 'x-valid') valid_info = information.get_metrics(dataset, 'x-valid')
test__info = information.get_metrics(dataset, 'x-test') test__info = information.get_metrics(dataset, 'x-test')
str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy'])) str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(
dataset, metric2str(train_info['loss'], train_info['accuracy']),
metric2str(valid_info['loss'], valid_info['accuracy']),
metric2str(test__info['loss'], test__info['accuracy']))
strings += [str1, str2] strings += [str1, str2]
if show: print('\n'.join(strings)) if show: print('\n'.join(strings))
return strings return strings
"""
This is the class for the API of size search space in NATS-Bench.
"""
class NATSsize(NASBenchMetaAPI): class NATSsize(NASBenchMetaAPI):
"""This is the class for the API of size search space in NATS-Bench."""
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """ def __init__(self,
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, fast_mode: bool=False, verbose: bool=True): file_path_or_dict: Optional[Union[Text, Dict[Text, Any]]] = None,
self.ALL_BASE_NAMES = ALL_BASE_NAMES fast_mode: bool = False,
verbose: bool = True):
"""The initialization function that takes the dataset file path (or a dict loaded from that path) as input."""
self._all_base_names = ALL_BASE_NAMES
self.filename = None self.filename = None
self._search_space_name = 'size' self._search_space_name = 'size'
self._fast_mode = fast_mode self._fast_mode = fast_mode
@ -67,25 +86,36 @@ class NATSsize(NASBenchMetaAPI):
self.reset_time() self.reset_time()
if file_path_or_dict is None: if file_path_or_dict is None:
if self._fast_mode: if self._fast_mode:
self._archive_dir = os.path.join(os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1])) self._archive_dir = os.path.join(
os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
else: else:
file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], '{:}.{:}'.format(ALL_BASE_NAMES[-1], PICKLE_EXT)) file_path_or_dict = os.path.join(
print ('{:} Try to use the default NATS-Bench (size) path from fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode, file_path_or_dict)) os.environ['TORCH_HOME'], '{:}.{:}'.format(
ALL_BASE_NAMES[-1], PICKLE_EXT))
print('{:} Try to use the default NATS-Bench (size) path from '
'fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode,
file_path_or_dict))
if isinstance(file_path_or_dict, str): if isinstance(file_path_or_dict, str):
file_path_or_dict = str(file_path_or_dict) file_path_or_dict = str(file_path_or_dict)
if verbose: if verbose:
print('{:} Try to create the NATS-Bench (size) api from {:} with fast_mode={:}'.format(time_string(), file_path_or_dict, fast_mode)) print('{:} Try to create the NATS-Bench (size) api '
if not nats_is_file(file_path_or_dict) and not nats_is_dir(file_path_or_dict): 'from {:} with fast_mode={:}'.format(
raise ValueError('{:} is neither a file or a dir.'.format(file_path_or_dict)) time_string(), file_path_or_dict, fast_mode))
if not nats_is_file(file_path_or_dict) and not nats_is_dir(
file_path_or_dict):
raise ValueError('{:} is neither a file or a dir.'.format(
file_path_or_dict))
self.filename = os.path.basename(file_path_or_dict) self.filename = os.path.basename(file_path_or_dict)
if fast_mode: if fast_mode:
if nats_is_file(file_path_or_dict): if nats_is_file(file_path_or_dict):
raise ValueError('fast_mode={:} must feed the path for directory : {:}'.format(fast_mode, file_path_or_dict)) raise ValueError('fast_mode={:} must feed the path for directory '
': {:}'.format(fast_mode, file_path_or_dict))
else: else:
self._archive_dir = file_path_or_dict self._archive_dir = file_path_or_dict
else: else:
if nats_is_dir(file_path_or_dict): if nats_is_dir(file_path_or_dict):
raise ValueError('fast_mode={:} must feed the path for file : {:}'.format(fast_mode, file_path_or_dict)) raise ValueError('fast_mode={:} must feed the path for file '
': {:}'.format(fast_mode, file_path_or_dict))
else: else:
file_path_or_dict = pickle_load(file_path_or_dict) file_path_or_dict = pickle_load(file_path_or_dict)
elif isinstance(file_path_or_dict, dict): elif isinstance(file_path_or_dict, dict):
@ -93,68 +123,95 @@ class NATSsize(NASBenchMetaAPI):
self.verbose = verbose self.verbose = verbose
if isinstance(file_path_or_dict, dict): if isinstance(file_path_or_dict, dict):
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes') keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key) for key in keys:
if key not in file_path_or_dict:
raise ValueError('Can not find key[{:}] in the dict'.format(key))
self.meta_archs = copy.deepcopy(file_path_or_dict['meta_archs']) self.meta_archs = copy.deepcopy(file_path_or_dict['meta_archs'])
# This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults # NOTE(xuanyidong): This is a dict mapping each architecture to a dict,
self.arch2infos_dict = OrderedDict() # where the key is #epochs and the value is ArchResults
self.arch2infos_dict = collections.OrderedDict()
self._avaliable_hps = set() self._avaliable_hps = set()
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())): for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
all_infos = file_path_or_dict['arch2infos'][xkey] all_infos = file_path_or_dict['arch2infos'][xkey]
hp2archres = OrderedDict() hp2archres = collections.OrderedDict()
for hp_key, results in all_infos.items(): for hp_key, results in all_infos.items():
hp2archres[hp_key] = ArchResults.create_from_state_dict(results) hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
self._avaliable_hps.add(hp_key) # save the avaliable hyper-parameter self._avaliable_hps.add(hp_key) # save the avaliable hyper-parameter
self.arch2infos_dict[xkey] = hp2archres self.arch2infos_dict[xkey] = hp2archres
self.evaluated_indexes = set(file_path_or_dict['evaluated_indexes']) self.evaluated_indexes = set(file_path_or_dict['evaluated_indexes'])
elif self.archive_dir is not None: elif self.archive_dir is not None:
benchmark_meta = pickle_load('{:}/meta.{:}'.format(self.archive_dir, PICKLE_EXT)) benchmark_meta = pickle_load('{:}/meta.{:}'.format(
self.archive_dir, PICKLE_EXT))
self.meta_archs = copy.deepcopy(benchmark_meta['meta_archs']) self.meta_archs = copy.deepcopy(benchmark_meta['meta_archs'])
self.arch2infos_dict = OrderedDict() self.arch2infos_dict = collections.OrderedDict()
self._avaliable_hps = set() self._avaliable_hps = set()
self.evaluated_indexes = set() self.evaluated_indexes = set()
else: else:
raise ValueError('file_path_or_dict [{:}] must be a dict or archive_dir must be set'.format(type(file_path_or_dict))) raise ValueError('file_path_or_dict [{:}] must be a dict or archive_dir '
'must be set'.format(type(file_path_or_dict)))
self.archstr2index = {} self.archstr2index = {}
for idx, arch in enumerate(self.meta_archs): for idx, arch in enumerate(self.meta_archs):
assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch]) if arch in self.archstr2index:
raise ValueError('This [{:}]-th arch {:} already in the '
'dict ({:}).'.format(
idx, arch, self.archstr2index[arch]))
self.archstr2index[arch] = idx self.archstr2index[arch] = idx
if self.verbose: if self.verbose:
print('{:} Create NATS-Bench (size) done with {:}/{:} architectures avaliable.'.format( print('{:} Create NATS-Bench (size) done with {:}/{:} architectures '
time_string(), len(self.evaluated_indexes), len(self.meta_archs))) 'avaliable.'.format(time_string(),
len(self.evaluated_indexes),
len(self.meta_archs)))
def query_info_str_by_arch(self, arch, hp: Text='12'): def query_info_str_by_arch(self, arch, hp: Text = '12'):
""" This function is used to query the information of a specific architecture """Query the information of a specific architecture.
'arch' can be an architecture index or an architecture string
When hp=01, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/01E.config' Args:
When hp=12, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/12E.config' arch: it can be an architecture index or an architecture string.
When hp=90, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/90E.config'
The difference between these three configurations are the number of training epochs. hp: the hyperparamete indicator, could be 01, 12, or 90. The difference
between these three configurations are the number of training epochs.
Returns:
ArchResults instance
""" """
if self.verbose: if self.verbose:
print('{:} Call query_info_str_by_arch with arch={:} and hp={:}'.format(time_string(), arch, hp)) print('{:} Call query_info_str_by_arch with arch={:}'
'and hp={:}'.format(time_string(), arch, hp))
return self._query_info_str_by_arch(arch, hp, print_information) return self._query_info_str_by_arch(arch, hp, print_information)
def get_more_info(self, index, dataset: Text, iepoch=None, hp='12', is_random=True): def get_more_info(self,
"""This function will return the metric for the `index`-th architecture index,
`dataset` indicates the dataset: dataset,
iepoch=None,
hp: Text = '12',
is_random: bool = True):
"""Return the metric for the `index`-th architecture.
Args:
index: the architecture index.
dataset:
'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set 'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set
'cifar10' : using the proposed train+valid set of CIFAR-10 as the training set 'cifar10' : using the proposed train+valid set of CIFAR-10 as the training set
'cifar100' : using the proposed train set of CIFAR-100 as the training set 'cifar100' : using the proposed train set of CIFAR-100 as the training set
'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set 'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set
`iepoch` indicates the index of training epochs from 0 to 11/199. iepoch: the index of training epochs from 0 to 11/199.
When iepoch=None, it will return the metric for the last training epoch When iepoch=None, it will return the metric for the last training epoch
When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0) When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0)
`hp` indicates different hyper-parameters for training hp: indicates different hyper-parameters for training
When hp=01, it trains the network with 01 epochs and the LR decayed from 0.1 to 0 within 01 epochs When hp=01, it trains the network with 01 epochs and the LR decayed from 0.1 to 0 within 01 epochs
When hp=12, it trains the network with 01 epochs and the LR decayed from 0.1 to 0 within 12 epochs When hp=12, it trains the network with 01 epochs and the LR decayed from 0.1 to 0 within 12 epochs
When hp=90, it trains the network with 01 epochs and the LR decayed from 0.1 to 0 within 90 epochs When hp=90, it trains the network with 01 epochs and the LR decayed from 0.1 to 0 within 90 epochs
`is_random` is_random:
When is_random=True, the performance of a random architecture will be returned When is_random=True, the performance of a random architecture will be returned
When is_random=False, the performanceo of all trials will be averaged. When is_random=False, the performanceo of all trials will be averaged.
Returns:
a dict, where key is the metric name and value is its value.
""" """
if self.verbose: if self.verbose:
print('{:} Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format( print('{:} Call the get_more_info function with index={:}, dataset={:}, '
time_string(), index, dataset, iepoch, hp, is_random)) 'iepoch={:}, hp={:}, and is_random={:}.'.format(
time_string(), index, dataset, iepoch, hp, is_random))
index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object
self._prepare_info(index) self._prepare_info(index)
if index not in self.arch2infos_dict: if index not in self.arch2infos_dict:
@ -165,38 +222,47 @@ class NATSsize(NASBenchMetaAPI):
seeds = archresult.get_dataset_seeds(dataset) seeds = archresult.get_dataset_seeds(dataset)
is_random = random.choice(seeds) is_random = random.choice(seeds)
# collect the training information # collect the training information
train_info = archresult.get_metrics(dataset, 'train', iepoch=iepoch, is_random=is_random) train_info = archresult.get_metrics(
dataset, 'train', iepoch=iepoch, is_random=is_random)
total = train_info['iepoch'] + 1 total = train_info['iepoch'] + 1
xinfo = {'train-loss' : train_info['loss'], xinfo = {
'train-accuracy': train_info['accuracy'], 'train-loss': train_info['loss'],
'train-per-time': train_info['all_time'] / total, 'train-accuracy': train_info['accuracy'],
'train-all-time': train_info['all_time']} 'train-per-time': train_info['all_time'] / total,
'train-all-time': train_info['all_time']
}
# collect the evaluation information # collect the evaluation information
if dataset == 'cifar10-valid': if dataset == 'cifar10-valid':
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random) valid_info = archresult.get_metrics(
dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
try: try:
test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) test_info = archresult.get_metrics(
except: dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
except Exception as unused_e: # pylint: disable=broad-except
test_info = None test_info = None
valtest_info = None valtest_info = None
else: else:
try: # collect results on the proposed test set try: # collect results on the proposed test set
if dataset == 'cifar10': if dataset == 'cifar10':
test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) test_info = archresult.get_metrics(
dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else: else:
test_info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random) test_info = archresult.get_metrics(
except: dataset, 'x-test', iepoch=iepoch, is_random=is_random)
except Exception as unused_e: # pylint: disable=broad-except
test_info = None test_info = None
try: # collect results on the proposed validation set try: # collect results on the proposed validation set
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random) valid_info = archresult.get_metrics(
except: dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
except Exception as unused_e: # pylint: disable=broad-except
valid_info = None valid_info = None
try: try:
if dataset != 'cifar10': if dataset != 'cifar10':
valtest_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) valtest_info = archresult.get_metrics(
dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else: else:
valtest_info = None valtest_info = None
except: except Exception as unused_e: # pylint: disable=broad-except
valtest_info = None valtest_info = None
if valid_info is not None: if valid_info is not None:
xinfo['valid-loss'] = valid_info['loss'] xinfo['valid-loss'] = valid_info['loss']
@ -216,11 +282,5 @@ class NATSsize(NASBenchMetaAPI):
return xinfo return xinfo
def show(self, index: int = -1) -> None: def show(self, index: int = -1) -> None:
""" """Print the information of a specific (or all) architecture(s)."""
This function will print the information of a specific (or all) architecture(s).
:param index: If the index < 0: it will loop for all architectures and print their information one by one.
else: it will print the information of the 'index'-th architecture.
:return: nothing
"""
self._show(index, print_information) self._show(index, print_information)

View File

@ -0,0 +1,59 @@
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 ##########################
##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
##############################################################################
"""This file is used to quickly test the API."""
import random
from nats_bench.api_size import NATSsize
from nats_bench.api_topology import NATStopology
def test_nats_bench_tss(benchmark_dir):
return test_nats_bench(benchmark_dir, True)
def test_nats_bench_sss(benchmark_dir):
return test_nats_bench(benchmark_dir, False)
def test_nats_bench(benchmark_dir, is_tss, verbose=False):
if is_tss:
api = NATStopology(benchmark_dir, True, verbose)
else:
api = NATSsize(benchmark_dir, True, verbose)
test_indexes = [random.randint(0, len(api) - 1) for _ in range(10)]
key2dataset = {'cifar10': 'CIFAR-10',
'cifar100': 'CIFAR-100',
'ImageNet16-120': 'ImageNet16-120'}
for index in test_indexes:
print('\n\nEvaluate the {:5d}-th architecture.'.format(index))
for key, dataset in key2dataset.items():
# Query the loss / accuracy / time for the `index`-th candidate
# architecture on CIFAR-10
# info is a dict, where you can easily figure out the meaning by key
info = api.get_more_info(index, key)
print(' -->> The performance on {:}: {:}'.format(dataset, info))
# Query the flops, params, latency. info is a dict.
info = api.get_cost_info(index, key)
print(' -->> The cost info on {:}: {:}'.format(dataset, info))
# Simulate the training of the `index`-th candidate:
validation_accuracy, latency, time_cost, current_total_time_cost = api.simulate_train_eval(
index, dataset=key, hp='12')
print(' -->> The validation accuracy={:}, latency={:}, '
'the current time cost={:} s, accumulated time cost={:} s'
.format(validation_accuracy, latency, time_cost,
current_total_time_cost))
# Print the configuration of the `index`-th architecture on CIFAR-10
config = api.get_net_config(index, key)
print(' -->> The configuration on {:} is {:}'.format(dataset, config))
# Show the information of the `index`-th architecture
api.show(index)

View File

@ -2,61 +2,83 @@
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
############################################################################## ##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size # # NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
##################################################################################### ##############################################################################
# The history of benchmark files (the name is NATS-tss-[version]-[md5].pickle.pbz2) # # The history of benchmark files are as follows, #
# [2020.08.31] NATS-tss-v1_0-3ffb9.pickle.pbz2 # # where the format is (the name is NATS-tss-[version]-[md5].pickle.pbz2) #
##################################################################################### # [2020.08.31] NATS-tss-v1_0-3ffb9.pickle.pbz2 #
import os, copy, random, numpy as np ##############################################################################
from typing import List, Text, Union, Dict, Optional # pylint: disable=line-too-long
from collections import OrderedDict, defaultdict """The API for topology search space in NATS-Bench."""
import warnings import collections
from .api_utils import time_string import copy
from .api_utils import pickle_load import os
from .api_utils import ArchResults import random
from .api_utils import NASBenchMetaAPI from typing import Any, Dict, List, Optional, Text, Union
from .api_utils import remap_dataset_set_names
from .api_utils import nats_is_dir from nats_bench.api_utils import ArchResults
from .api_utils import nats_is_file from nats_bench.api_utils import NASBenchMetaAPI
from .api_utils import PICKLE_EXT from nats_bench.api_utils import nats_is_dir
from nats_bench.api_utils import nats_is_file
from nats_bench.api_utils import PICKLE_EXT
from nats_bench.api_utils import pickle_load
from nats_bench.api_utils import time_string
import numpy as np
ALL_BASE_NAMES = ['NATS-tss-v1_0-3ffb9'] ALL_BASE_NAMES = ['NATS-tss-v1_0-3ffb9']
def print_information(information, extra_info=None, show=False): def print_information(information, extra_info=None, show=False):
"""print out the information of a given ArchResults."""
dataset_names = information.get_dataset_names() dataset_names = information.get_dataset_names()
strings = [information.arch_str, 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)] strings = [
information.arch_str,
'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)
]
def metric2str(loss, acc): def metric2str(loss, acc):
return 'loss = {:.3f} & top1 = {:.2f}%'.format(loss, acc) return 'loss = {:.3f} & top1 = {:.2f}%'.format(loss, acc)
for ida, dataset in enumerate(dataset_names): for dataset in dataset_names:
metric = information.get_compute_costs(dataset) metric = information.get_compute_costs(dataset)
flop, param, latency = metric['flops'], metric['params'], metric['latency'] flop, param, latency = metric['flops'], metric['params'], metric['latency']
str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency is not None and latency > 0 else None) str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(
dataset, flop, param,
'{:.2f}'.format(latency *
1000) if latency is not None and latency > 0 else None)
train_info = information.get_metrics(dataset, 'train') train_info = information.get_metrics(dataset, 'train')
if dataset == 'cifar10-valid': if dataset == 'cifar10-valid':
valid_info = information.get_metrics(dataset, 'x-valid') valid_info = information.get_metrics(dataset, 'x-valid')
str2 = '{:14s} train : [{:}], valid : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy'])) str2 = '{:14s} train : [{:}], valid : [{:}]'.format(
dataset, metric2str(train_info['loss'], train_info['accuracy']),
metric2str(valid_info['loss'], valid_info['accuracy']))
elif dataset == 'cifar10': elif dataset == 'cifar10':
test__info = information.get_metrics(dataset, 'ori-test') test__info = information.get_metrics(dataset, 'ori-test')
str2 = '{:14s} train : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy'])) str2 = '{:14s} train : [{:}], test : [{:}]'.format(
dataset, metric2str(train_info['loss'], train_info['accuracy']),
metric2str(test__info['loss'], test__info['accuracy']))
else: else:
valid_info = information.get_metrics(dataset, 'x-valid') valid_info = information.get_metrics(dataset, 'x-valid')
test__info = information.get_metrics(dataset, 'x-test') test__info = information.get_metrics(dataset, 'x-test')
str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy'])) str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(
dataset, metric2str(train_info['loss'], train_info['accuracy']),
metric2str(valid_info['loss'], valid_info['accuracy']),
metric2str(test__info['loss'], test__info['accuracy']))
strings += [str1, str2] strings += [str1, str2]
if show: print('\n'.join(strings)) if show: print('\n'.join(strings))
return strings return strings
"""
This is the class for the API of topology search space in NATS-Bench.
"""
class NATStopology(NASBenchMetaAPI): class NATStopology(NASBenchMetaAPI):
"""This is the class for the API of topology search space in NATS-Bench."""
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """ def __init__(self,
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, fast_mode: bool=False, verbose: bool=True): file_path_or_dict: Optional[Union[Text, Dict[Text, Any]]] = None,
self.ALL_BASE_NAMES = ALL_BASE_NAMES fast_mode: bool = False,
verbose: bool = True):
"""The initialization function that takes the dataset file path (or a dict loaded from that path) as input."""
self._all_base_names = ALL_BASE_NAMES
self.filename = None self.filename = None
self._search_space_name = 'topology' self._search_space_name = 'topology'
self._fast_mode = fast_mode self._fast_mode = fast_mode
@ -64,25 +86,35 @@ class NATStopology(NASBenchMetaAPI):
self.reset_time() self.reset_time()
if file_path_or_dict is None: if file_path_or_dict is None:
if self._fast_mode: if self._fast_mode:
self._archive_dir = os.path.join(os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1])) self._archive_dir = os.path.join(
os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
else: else:
file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], '{:}.{:}'.format(ALL_BASE_NAMES[-1], PICKLE_EXT)) file_path_or_dict = os.path.join(
print ('{:} Try to use the default NATS-Bench (topology) path from {:}.'.format(time_string(), file_path_or_dict)) os.environ['TORCH_HOME'], '{:}.{:}'.format(
ALL_BASE_NAMES[-1], PICKLE_EXT))
print('{:} Try to use the default NATS-Bench (topology) path '
'from {:}.'.format(time_string(), file_path_or_dict))
if isinstance(file_path_or_dict, str): if isinstance(file_path_or_dict, str):
file_path_or_dict = str(file_path_or_dict) file_path_or_dict = str(file_path_or_dict)
if verbose: if verbose:
print('{:} Try to create the NATS-Bench (topology) api from {:} with fast_mode={:}'.format(time_string(), file_path_or_dict, fast_mode)) print('{:} Try to create the NATS-Bench (topology) api '
if not nats_is_file(file_path_or_dict) and not nats_is_dir(file_path_or_dict): 'from {:} with fast_mode={:}'.format(
raise ValueError('{:} is neither a file or a dir.'.format(file_path_or_dict)) time_string(), file_path_or_dict, fast_mode))
if not nats_is_file(file_path_or_dict) and not nats_is_dir(
file_path_or_dict):
raise ValueError('{:} is neither a file or a dir.'.format(
file_path_or_dict))
self.filename = os.path.basename(file_path_or_dict) self.filename = os.path.basename(file_path_or_dict)
if fast_mode: if fast_mode:
if nats_is_file(file_path_or_dict): if nats_is_file(file_path_or_dict):
raise ValueError('fast_mode={:} must feed the path for directory : {:}'.format(fast_mode, file_path_or_dict)) raise ValueError('fast_mode={:} must feed the path for directory '
': {:}'.format(fast_mode, file_path_or_dict))
else: else:
self._archive_dir = file_path_or_dict self._archive_dir = file_path_or_dict
else: else:
if nats_is_dir(file_path_or_dict): if nats_is_dir(file_path_or_dict):
raise ValueError('fast_mode={:} must feed the path for file : {:}'.format(fast_mode, file_path_or_dict)) raise ValueError('fast_mode={:} must feed the path for file '
': {:}'.format(fast_mode, file_path_or_dict))
else: else:
file_path_or_dict = pickle_load(file_path_or_dict) file_path_or_dict = pickle_load(file_path_or_dict)
elif isinstance(file_path_or_dict, dict): elif isinstance(file_path_or_dict, dict):
@ -90,65 +122,73 @@ class NATStopology(NASBenchMetaAPI):
self.verbose = verbose self.verbose = verbose
if isinstance(file_path_or_dict, dict): if isinstance(file_path_or_dict, dict):
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes') keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key) for key in keys:
if key not in file_path_or_dict:
raise ValueError('Can not find key[{:}] in the dict'.format(key))
self.meta_archs = copy.deepcopy(file_path_or_dict['meta_archs']) self.meta_archs = copy.deepcopy(file_path_or_dict['meta_archs'])
# This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults # NOTE(xuanyidong): This is a dict mapping each architecture to a dict,
self.arch2infos_dict = OrderedDict() # where the key is #epochs and the value is ArchResults
self.arch2infos_dict = collections.OrderedDict()
self._avaliable_hps = set() self._avaliable_hps = set()
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())): for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
all_infos = file_path_or_dict['arch2infos'][xkey] all_infos = file_path_or_dict['arch2infos'][xkey]
hp2archres = OrderedDict() hp2archres = collections.OrderedDict()
for hp_key, results in all_infos.items(): for hp_key, results in all_infos.items():
hp2archres[hp_key] = ArchResults.create_from_state_dict(results) hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
self._avaliable_hps.add(hp_key) # save the avaliable hyper-parameter self._avaliable_hps.add(hp_key) # save the avaliable hyper-parameter
self.arch2infos_dict[xkey] = hp2archres self.arch2infos_dict[xkey] = hp2archres
self.evaluated_indexes = set(file_path_or_dict['evaluated_indexes']) self.evaluated_indexes = set(file_path_or_dict['evaluated_indexes'])
elif self.archive_dir is not None: elif self.archive_dir is not None:
benchmark_meta = pickle_load('{:}/meta.{:}'.format(self.archive_dir, PICKLE_EXT)) benchmark_meta = pickle_load('{:}/meta.{:}'.format(
self.archive_dir, PICKLE_EXT))
self.meta_archs = copy.deepcopy(benchmark_meta['meta_archs']) self.meta_archs = copy.deepcopy(benchmark_meta['meta_archs'])
self.arch2infos_dict = OrderedDict() self.arch2infos_dict = collections.OrderedDict()
self._avaliable_hps = set() self._avaliable_hps = set()
self.evaluated_indexes = set() self.evaluated_indexes = set()
else: else:
raise ValueError('file_path_or_dict [{:}] must be a dict or archive_dir must be set'.format(type(file_path_or_dict))) raise ValueError('file_path_or_dict [{:}] must be a dict or archive_dir '
'must be set'.format(type(file_path_or_dict)))
self.archstr2index = {} self.archstr2index = {}
for idx, arch in enumerate(self.meta_archs): for idx, arch in enumerate(self.meta_archs):
assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch]) if arch in self.archstr2index:
raise ValueError('This [{:}]-th arch {:} already in the '
'dict ({:}).'.format(
idx, arch, self.archstr2index[arch]))
self.archstr2index[arch] = idx self.archstr2index[arch] = idx
if self.verbose: if self.verbose:
print('{:} Create NATS-Bench (topology) done with {:}/{:} architectures avaliable.'.format( print('{:} Create NATS-Bench (topology) done with {:}/{:} architectures '
time_string(), len(self.evaluated_indexes), len(self.meta_archs))) 'avaliable.'.format(time_string(),
len(self.evaluated_indexes),
len(self.meta_archs)))
def query_info_str_by_arch(self, arch, hp: Text='12'): def query_info_str_by_arch(self, arch, hp: Text = '12'):
""" This function is used to query the information of a specific architecture """Query the information of a specific architecture.
'arch' can be an architecture index or an architecture string
When hp=12, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/12E.config' Args:
When hp=200, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/200E.config' arch: it can be an architecture index or an architecture string.
The difference between these three configurations are the number of training epochs.
hp: the hyperparamete indicator, could be 12 or 200. The difference
between these three configurations are the number of training epochs.
Returns:
ArchResults instance
""" """
if self.verbose: if self.verbose:
print('{:} Call query_info_str_by_arch with arch={:} and hp={:}'.format(time_string(), arch, hp)) print('{:} Call query_info_str_by_arch with arch={:}'
'and hp={:}'.format(time_string(), arch, hp))
return self._query_info_str_by_arch(arch, hp, print_information) return self._query_info_str_by_arch(arch, hp, print_information)
# obtain the metric for the `index`-th architecture def get_more_info(self,
# `dataset` indicates the dataset: index,
# 'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set dataset,
# 'cifar10' : using the proposed train+valid set of CIFAR-10 as the training set iepoch=None,
# 'cifar100' : using the proposed train set of CIFAR-100 as the training set hp: Text = '12',
# 'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set is_random: bool = True):
# `iepoch` indicates the index of training epochs from 0 to 11/199. """Return the metric for the `index`-th architecture."""
# When iepoch=None, it will return the metric for the last training epoch
# When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0)
# `use_12epochs_result` indicates different hyper-parameters for training
# When use_12epochs_result=True, it trains the network with 12 epochs and the LR decayed from 0.1 to 0 within 12 epochs
# When use_12epochs_result=False, it trains the network with 200 epochs and the LR decayed from 0.1 to 0 within 200 epochs
# `is_random`
# When is_random=True, the performance of a random architecture will be returned
# When is_random=False, the performanceo of all trials will be averaged.
def get_more_info(self, index, dataset, iepoch=None, hp='12', is_random=True):
if self.verbose: if self.verbose:
print('{:} Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format( print('{:} Call the get_more_info function with index={:}, dataset={:}, '
time_string(), index, dataset, iepoch, hp, is_random)) 'iepoch={:}, hp={:}, and is_random={:}.'.format(
time_string(), index, dataset, iepoch, hp, is_random))
index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object
self._prepare_info(index) self._prepare_info(index)
if index not in self.arch2infos_dict: if index not in self.arch2infos_dict:
@ -161,36 +201,43 @@ class NATStopology(NASBenchMetaAPI):
# collect the training information # collect the training information
train_info = archresult.get_metrics(dataset, 'train', iepoch=iepoch, is_random=is_random) train_info = archresult.get_metrics(dataset, 'train', iepoch=iepoch, is_random=is_random)
total = train_info['iepoch'] + 1 total = train_info['iepoch'] + 1
xinfo = {'train-loss' : train_info['loss'], xinfo = {
'train-accuracy': train_info['accuracy'], 'train-loss':
'train-per-time': train_info['all_time'] / total if train_info['all_time'] is not None else None, train_info['loss'],
'train-all-time': train_info['all_time']} 'train-accuracy':
train_info['accuracy'],
'train-per-time':
train_info['all_time'] /
total if train_info['all_time'] is not None else None,
'train-all-time':
train_info['all_time']
}
# collect the evaluation information # collect the evaluation information
if dataset == 'cifar10-valid': if dataset == 'cifar10-valid':
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random) valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
try: try:
test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
except: except Exception as unused_e: # pylint: disable=broad-except
test_info = None test_info = None
valtest_info = None valtest_info = None
else: else:
try: # collect results on the proposed test set try: # collect results on the proposed test set
if dataset == 'cifar10': if dataset == 'cifar10':
test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else: else:
test_info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random) test_info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
except: except Exception as unused_e: # pylint: disable=broad-except
test_info = None test_info = None
try: # collect results on the proposed validation set try: # collect results on the proposed validation set
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random) valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
except: except Exception as unused_e: # pylint: disable=broad-except
valid_info = None valid_info = None
try: try:
if dataset != 'cifar10': if dataset != 'cifar10':
valtest_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) valtest_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else: else:
valtest_info = None valtest_info = None
except: except Exception as unused_e: # pylint: disable=broad-except
valtest_info = None valtest_info = None
if valid_info is not None: if valid_info is not None:
xinfo['valid-loss'] = valid_info['loss'] xinfo['valid-loss'] = valid_info['loss']
@ -214,46 +261,52 @@ class NATStopology(NASBenchMetaAPI):
self._show(index, print_information) self._show(index, print_information)
@staticmethod @staticmethod
def str2lists(arch_str: Text) -> List[tuple]: def str2lists(arch_str: Text) -> List[Any]:
""" """Shows how to read the string-based architecture encoding.
This function shows how to read the string-based architecture encoding.
It is the same as the `str2structure` func in `AutoDL-Projects/lib/models/cell_searchs/genotypes.py`
:param Args:
arch_str: the input is a string indicates the architecture topology, such as arch_str: the input is a string indicates the architecture topology, such as
|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2| |nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
:return: a list of tuple, contains multiple (op, input_node_index) pairs. Returns:
a list of tuple, contains multiple (op, input_node_index) pairs.
:usage [USAGE]
It is the same as the `str2structure` func in AutoDL-Projects:
`github.com/D-X-Y/AutoDL-Projects/lib/models/cell_searchs/genotypes.py`
```
arch = api.str2lists( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' ) arch = api.str2lists( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
print ('there are {:} nodes in this arch'.format(len(arch)+1)) # arch is a list print ('there are {:} nodes in this arch'.format(len(arch)+1)) # arch is a list
for i, node in enumerate(arch): for i, node in enumerate(arch):
print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node)) print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node))
```
""" """
node_strs = arch_str.split('+') node_strs = arch_str.split('+')
genotypes = [] genotypes = []
for i, node_str in enumerate(node_strs): for unused_i, node_str in enumerate(node_strs):
inputs = list(filter(lambda x: x != '', node_str.split('|'))) inputs = list(filter(lambda x: x != '', node_str.split('|'))) # pylint: disable=g-explicit-bool-comparison
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput) for xinput in inputs:
inputs = ( xi.split('~') for xi in inputs ) assert len(
input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs) xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
genotypes.append( input_infos ) inputs = (xi.split('~') for xi in inputs)
input_infos = tuple((op, int(idx)) for (op, idx) in inputs)
genotypes.append(input_infos)
return genotypes return genotypes
@staticmethod @staticmethod
def str2matrix(arch_str: Text, def str2matrix(arch_str: Text,
search_space: List[Text] = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']) -> np.ndarray: search_space: List[Text] = ('none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3')) -> np.ndarray:
""" """Convert the string-based architecture encoding to the encoding strategy in NAS-Bench-101.
This func shows how to convert the string-based architecture encoding to the encoding strategy in NAS-Bench-101.
:param Args:
arch_str: the input is a string indicates the architecture topology, such as arch_str: the input is a string indicates the architecture topology, such as
|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2| |nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
search_space: a list of operation string, the default list is the topology search space for NATS-BENCH. search_space: a list of operation string, the default list is the topology search space for NATS-BENCH.
the default value should be be consistent with this line https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/models/cell_operations.py#L24 the default value should be be consistent with this line https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/models/cell_operations.py#L24
:return
Returns:
the numpy matrix (2-D np.ndarray) representing the DAG of this architecture topology the numpy matrix (2-D np.ndarray) representing the DAG of this architecture topology
:usage
[USAGE]
matrix = api.str2matrix( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' ) matrix = api.str2matrix( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
This matrix is 4-by-4 matrix representing a cell with 4 nodes (only the lower left triangle is useful). This matrix is 4-by-4 matrix representing a cell with 4 nodes (only the lower left triangle is useful).
[ [0, 0, 0, 0], # the first line represents the input (0-th) node [ [0, 0, 0, 0], # the first line represents the input (0-th) node
@ -262,19 +315,19 @@ class NATStopology(NASBenchMetaAPI):
[0, 0, 1, 0] ] # the fourth line represents the 3-rd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node ) + 1-th-op( 2-th-node ) [0, 0, 1, 0] ] # the fourth line represents the 3-rd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node ) + 1-th-op( 2-th-node )
In the topology search space in NATS-BENCH, 0-th-op is 'none', 1-th-op is 'skip_connect', In the topology search space in NATS-BENCH, 0-th-op is 'none', 1-th-op is 'skip_connect',
2-th-op is 'nor_conv_1x1', 3-th-op is 'nor_conv_3x3', 4-th-op is 'avg_pool_3x3'. 2-th-op is 'nor_conv_1x1', 3-th-op is 'nor_conv_3x3', 4-th-op is 'avg_pool_3x3'.
:(NOTE) [NOTE]
If a node has two input-edges from the same node, this function does not work. One edge will be overlapped. If a node has two input-edges from the same node, this function does not work. One edge will be overlapped.
""" """
node_strs = arch_str.split('+') node_strs = arch_str.split('+')
num_nodes = len(node_strs) + 1 num_nodes = len(node_strs) + 1
matrix = np.zeros((num_nodes, num_nodes)) matrix = np.zeros((num_nodes, num_nodes))
for i, node_str in enumerate(node_strs): for i, node_str in enumerate(node_strs):
inputs = list(filter(lambda x: x != '', node_str.split('|'))) inputs = list(filter(lambda x: x != '', node_str.split('|'))) # pylint: disable=g-explicit-bool-comparison
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput) for xinput in inputs:
assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
for xi in inputs: for xi in inputs:
op, idx = xi.split('~') op, idx = xi.split('~')
if op not in search_space: raise ValueError('this op ({:}) is not in {:}'.format(op, search_space)) if op not in search_space: raise ValueError('this op ({:}) is not in {:}'.format(op, search_space))
op_idx, node_idx = search_space.index(op), int(idx) op_idx, node_idx = search_space.index(op), int(idx)
matrix[i+1, node_idx] = op_idx matrix[i+1, node_idx] = op_idx
return matrix return matrix

File diff suppressed because it is too large Load Diff

View File

@ -23,11 +23,11 @@ CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --warmup_ratio ${ratio} --rand_seed ${seed} CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --warmup_ratio ${ratio} --rand_seed ${seed}
# #
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --warmup_ratio ${ratio} --rand_seed ${seed} CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --warmup_ratio ${ratio} --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --warmup_ratio ${ratio} --rand_seed ${seed} CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --warmup_ratio ${ratio} --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --warmup_ratio ${ratio} --rand_seed ${seed} CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_gumbel --warmup_ratio ${ratio} --rand_seed ${seed}
# #
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed} CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed} CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed} CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_rl --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}