diff --git a/CHANGE-LOG.md b/CHANGE-LOG.md
index 2a6247e..950d0c5 100644
--- a/CHANGE-LOG.md
+++ b/CHANGE-LOG.md
@@ -6,3 +6,4 @@
- [2019.01.31] [13e908f] GDAS codes were publicly released.
- [2020.07.01] [a45808b] Upgrade NAS-API to the 2.0 version.
- [2020.09.16] [7052265] Create NATS-BENCH.
+- [2020.10.15] [446262a] Update NATS-BENCH to version 1.0
diff --git a/README.md b/README.md
index 265b63a..0441aa7 100644
--- a/README.md
+++ b/README.md
@@ -61,7 +61,7 @@ At this moment, this project provides the following algorithms and scripts to ru
NATS-Bench |
- NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size |
+ NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size |
NATS-Bench.md |
@@ -100,7 +100,7 @@ Some methods use knowledge distillation (KD), which require pre-trained models.
If you find that this project helps your research, please consider citing some of the following papers:
```
@article{dong2020nats,
- title={NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size},
+ title={{NATS-Bench}: Benchmarking NAS Algorithms for Architecture Topology and Size},
author={Dong, Xuanyi and Liu, Lu and Musial, Katarzyna and Gabrys, Bogdan},
journal={arXiv preprint arXiv:2009.00437},
year={2020}
diff --git a/README_CN.md b/README_CN.md
index cfeee6c..b919d2c 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -61,7 +61,7 @@
NATS-Bench |
- NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size |
+ NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size |
NATS-Bench.md |
@@ -99,7 +99,7 @@ Some methods use knowledge distillation (KD), which require pre-trained models.
如果您发现该项目对您的科研或工程有帮助,请考虑引用下列的某些文献:
```
@article{dong2020nats,
- title={NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size},
+ title={{NATS-Bench}: Benchmarking NAS Algorithms for Architecture Topology and Size},
author={Dong, Xuanyi and Liu, Lu and Musial, Katarzyna and Gabrys, Bogdan},
journal={arXiv preprint arXiv:2009.00437},
year={2020}
diff --git a/docs/NAS-Bench-201-PURE.md b/docs/NAS-Bench-201-PURE.md
index e9980cb..8a1ac54 100644
--- a/docs/NAS-Bench-201-PURE.md
+++ b/docs/NAS-Bench-201-PURE.md
@@ -1,5 +1,7 @@
# [NAS-BENCH-201: Extending the Scope of Reproducible Neural Architecture Search](https://openreview.net/forum?id=HJxyZkBKDr)
+**Since our NAS-BENCH-201 has been extended to NATS-Bench, this `README` is deprecated and not maintained. Please use [NATS-Bench](https://github.com/D-X-Y/AutoDL-Projects/blob/master/docs/NATS-Bench.md), which has 5x more architecture information and faster API than NAS-BENCH-201.**
+
We propose an algorithm-agnostic NAS benchmark (NAS-Bench-201) with a fixed search space, which provides a unified benchmark for almost any up-to-date NAS algorithms.
The design of our search space is inspired by that used in the most popular cell-based searching algorithms, where a cell is represented as a directed acyclic graph.
Each edge here is associated with an operation selected from a predefined operation set.
@@ -70,17 +72,18 @@ api.show(2)
# show the mean loss and accuracy of an architecture
info = api.query_meta_info_by_index(1) # This is an instance of `ArchResults`
res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric names as keys
-cost_metrics = info.get_comput_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency
+cost_metrics = info.get_compute_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency
# get the detailed information
results = api.query_by_index(1, 'cifar100') # a dict of all trials for 1st net on cifar100, where the key is the seed
print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1]))
-print ('Latency : {:}'.format(results[0].get_latency()))
-print ('Train Info : {:}'.format(results[0].get_train()))
-print ('Valid Info : {:}'.format(results[0].get_eval('x-valid')))
-print ('Test Info : {:}'.format(results[0].get_eval('x-test')))
-# for the metric after a specific epoch
-print ('Train Info [10-th epoch] : {:}'.format(results[0].get_train(10)))
+for seed, result in results.items():
+ print ('Latency : {:}'.format(result.get_latency()))
+ print ('Train Info : {:}'.format(result.get_train()))
+ print ('Valid Info : {:}'.format(result.get_eval('x-valid')))
+ print ('Test Info : {:}'.format(result.get_eval('x-test')))
+ # for the metric after a specific epoch
+ print ('Train Info [10-th epoch] : {:}'.format(result.get_train(10)))
```
4. Query the index of an architecture by string
@@ -171,7 +174,7 @@ api.get_more_info(112, 'ImageNet16-120', None, hp='200', is_random=True)
If you find that NAS-Bench-201 helps your research, please consider citing it:
```
@inproceedings{dong2020nasbench201,
- title = {NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search},
+ title = {{NAS-Bench-201}: Extending the Scope of Reproducible Neural Architecture Search},
author = {Dong, Xuanyi and Yang, Yi},
booktitle = {International Conference on Learning Representations (ICLR)},
url = {https://openreview.net/forum?id=HJxyZkBKDr},
diff --git a/docs/NAS-Bench-201.md b/docs/NAS-Bench-201.md
index d4325fc..dc233a9 100644
--- a/docs/NAS-Bench-201.md
+++ b/docs/NAS-Bench-201.md
@@ -1,5 +1,7 @@
# [NAS-BENCH-201: Extending the Scope of Reproducible Neural Architecture Search](https://openreview.net/forum?id=HJxyZkBKDr)
+**Since our NAS-BENCH-201 has been extended to NATS-Bench, this README is deprecated and not maintained. Please use [NATS-Bench](https://github.com/D-X-Y/AutoDL-Projects/blob/master/docs/NATS-Bench.md), which has 5x more architecture information and faster API than NAS-BENCH-201.**
+
We propose an algorithm-agnostic NAS benchmark (NAS-Bench-201) with a fixed search space, which provides a unified benchmark for almost any up-to-date NAS algorithms.
The design of our search space is inspired by that used in the most popular cell-based searching algorithms, where a cell is represented as a directed acyclic graph.
Each edge here is associated with an operation selected from a predefined operation set.
@@ -68,17 +70,18 @@ api.show(2)
# show the mean loss and accuracy of an architecture
info = api.query_meta_info_by_index(1) # This is an instance of `ArchResults`
res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric names as keys
-cost_metrics = info.get_comput_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency
+cost_metrics = info.get_compute_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency
# get the detailed information
results = api.query_by_index(1, 'cifar100') # a dict of all trials for 1st net on cifar100, where the key is the seed
print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1]))
-print ('Latency : {:}'.format(results[0].get_latency()))
-print ('Train Info : {:}'.format(results[0].get_train()))
-print ('Valid Info : {:}'.format(results[0].get_eval('x-valid')))
-print ('Test Info : {:}'.format(results[0].get_eval('x-test')))
-# for the metric after a specific epoch
-print ('Train Info [10-th epoch] : {:}'.format(results[0].get_train(10)))
+for seed, result in results.items():
+ print ('Latency : {:}'.format(result.get_latency()))
+ print ('Train Info : {:}'.format(result.get_train()))
+ print ('Valid Info : {:}'.format(result.get_eval('x-valid')))
+ print ('Test Info : {:}'.format(result.get_eval('x-test')))
+ # for the metric after a specific epoch
+ print ('Train Info [10-th epoch] : {:}'.format(result.get_train(10)))
```
4. Query the index of an architecture by string
@@ -242,7 +245,7 @@ In commands [1-6], the first args `cifar10` indicates the dataset name, the seco
If you find that NAS-Bench-201 helps your research, please consider citing it:
```
@inproceedings{dong2020nasbench201,
- title = {NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search},
+ title = {{NAS-Bench-201}: Extending the Scope of Reproducible Neural Architecture Search},
author = {Dong, Xuanyi and Yang, Yi},
booktitle = {International Conference on Learning Representations (ICLR)},
url = {https://openreview.net/forum?id=HJxyZkBKDr},
diff --git a/docs/NATS-Bench.md b/docs/NATS-Bench.md
index b42d816..19ea15b 100644
--- a/docs/NATS-Bench.md
+++ b/docs/NATS-Bench.md
@@ -1,4 +1,4 @@
-# [NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size](https://arxiv.org/pdf/2009.00437.pdf)
+# [NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size](https://arxiv.org/pdf/2009.00437.pdf)
Neural architecture search (NAS) has attracted a lot of attention and has been illustrated to bring tangible benefits in a large number of applications in the past few years. Network topology and network size have been regarded as two of the most important aspects for the performance of deep learning models and the community has spawned lots of searching algorithms for both of those aspects of the neural architectures. However, the performance gain from these searching algorithms is achieved under different search spaces and training setups. This makes the overall performance of the algorithms incomparable and the improvement from a sub-module of the searching model unclear.
In this paper, we propose NATS-Bench, a unified benchmark on searching for both topology and size, for (almost) any up-to-date NAS algorithm.
@@ -7,11 +7,12 @@ We analyze the validity of our benchmark in terms of various criteria and perfor
We also show the versatility of NATS-Bench by benchmarking 13 recent state-of-the-art NAS algorithms on it. All logs and diagnostic information trained using the same setup for each candidate are provided.
This facilitates a much larger community of researchers to focus on developing better NAS algorithms in a more comparable and computationally effective environment.
+**You can use `pip install nats_bench` to install the library of NATS-Bench.**
The structure of this Markdown file:
- [How to use NATS-Bench?](#How-to-Use-NATS-Bench)
- [How to re-create NATS-Bench from scratch?](#how-to-re-create-nats-bench-from-scratch)
-- [How to reproduce benchmarked results?](#to-reproduce-13-baseline-nas-algorithms-in-nas-bench-201)
+- [How to reproduce benchmarked results?](#to-reproduce-13-baseline-nas-algorithms-in-nats-bench)
## How to Use [NATS-Bench](https://arxiv.org/pdf/2009.00437.pdf)
@@ -79,8 +80,12 @@ params = api.get_net_param(12, 'cifar10', None)
network.load_state_dict(next(iter(params.values())))
```
+
+
## How to Re-create NATS-Bench from Scratch
+You need to use the [AutoDL-Projects](https://github.com/D-X-Y/AutoDL-Projects) repo to re-create NATS-Bench from scratch.
+
### The Size Search Space
The following command will train all architecture candidate in the size search space with 90 epochs and use the random seed of `777`.
@@ -110,7 +115,9 @@ python exps/NATS-Bench/tss-collect.py
```
-## To Reproduce 13 Baseline NAS Algorithms in NAS-Bench-201
+## To Reproduce 13 Baseline NAS Algorithms in NATS-Bench
+
+You need to use the [AutoDL-Projects](https://github.com/D-X-Y/AutoDL-Projects) repo to run 13 baseline NAS methods.
### Reproduce NAS methods on the topology search space
@@ -171,18 +178,18 @@ python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HO
python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777
-Run the search strategy in FBNet-V2
+Run the channel search strategy in FBNet-V2 -- masking + Gumbel-Softmax :
-python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
-python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
-python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed 777
+python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
+python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
+python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_gumbel --rand_seed 777
-Run the search strategy in TuNAS:
+Run the channel search strategy in TuNAS -- masking + sampling :
-python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 --use_api 0
-python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777
-python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --rand_seed 777
+python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777 --use_api 0
+python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777
+python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_rl --arch_weight_decay 0 --rand_seed 777
```
### Final Discovered Architectures for Each Algorithm
@@ -246,7 +253,7 @@ GDAS:
If you find that NATS-Bench helps your research, please consider citing it:
```
@article{dong2020nats,
- title={NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size},
+ title={{NATS-Bench}: Benchmarking NAS Algorithms for Architecture Topology and Size},
author={Dong, Xuanyi and Liu, Lu and Musial, Katarzyna and Gabrys, Bogdan},
journal={arXiv preprint arXiv:2009.00437},
year={2020}
diff --git a/exps/NATS-algos/search-size.py b/exps/NATS-algos/search-size.py
index 78727ee..e215523 100644
--- a/exps/NATS-algos/search-size.py
+++ b/exps/NATS-algos/search-size.py
@@ -1,28 +1,30 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
###########################################################################################################################################
+#
# In this file, we aims to evaluate three kinds of channel searching strategies:
# - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019"
-# - masking + Gumbel-Softmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
-# - masking + sampling from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
-# For simplicity, we use tas, fbv2, and tunas to refer these three strategies. Their official implementations are at the following links:
+# - masking + Gumbel-Softmax (mask_gumbel) from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
+# - masking + sampling (mask_rl) from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
+#
+# For simplicity, we use tas, mask_gumbel, and mask_rl to refer these three strategies. Their official implementations are at the following links:
# - TAS: https://github.com/D-X-Y/AutoDL-Projects/blob/master/docs/NeurIPS-2019-TAS.md
# - FBNetV2: https://github.com/facebookresearch/mobile-vision
# - TuNAS: https://github.com/google-research/google-research/tree/master/tunas
####
-# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio 0.25
+# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --warmup_ratio 0.25
####
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777
####
-# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
-# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
-# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed 777
+# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
+# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
+# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_gumbel --rand_seed 777
####
-# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 --use_api 0
-# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777
-# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --rand_seed 777
+# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777 --use_api 0
+# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777
+# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_rl --arch_weight_decay 0 --rand_seed 777
###########################################################################################################################################
import os, sys, time, random, argparse
import numpy as np
@@ -41,7 +43,7 @@ from models import get_cell_based_tiny_net, get_search_spaces
from nats_bench import create
-# Ad-hoc for TuNAS
+# Ad-hoc for RL algorithms.
class ExponentialMovingAverage(object):
"""Class that maintains an exponential moving average."""
@@ -94,13 +96,13 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
a_optimizer.zero_grad()
_, logits, log_probs = network(arch_inputs)
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
- if algo == 'tunas':
+ if algo == 'mask_rl':
with torch.no_grad():
RL_BASELINE_EMA.update(arch_prec1.item())
rl_advantage = arch_prec1 - RL_BASELINE_EMA.value
rl_log_prob = sum(log_probs)
arch_loss = - rl_advantage * rl_log_prob
- elif algo == 'tas' or algo == 'fbv2':
+ elif algo == 'tas' or algo == 'mask_gumbel':
arch_loss = criterion(logits, arch_targets)
else:
raise ValueError('invalid algorightm name: {:}'.format(algo))
@@ -231,7 +233,7 @@ def main(xargs):
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}, controller-warmup={:}, enable_controller={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()), network.warmup_ratio, enable_controller))
- if xargs.algo == 'fbv2' or xargs.algo == 'tas':
+ if xargs.algo == 'mask_gumbel' or xargs.algo == 'tas':
network.set_tau(xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1))
logger.log('[RESET tau as : {:}]'.format(network.tau))
search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
@@ -291,7 +293,7 @@ if __name__ == '__main__':
parser.add_argument('--data_path' , type=str, help='Path to dataset')
parser.add_argument('--dataset' , type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
parser.add_argument('--search_space', type=str, default='sss', choices=['sss'], help='The search space name.')
- parser.add_argument('--algo' , type=str, choices=['tas', 'fbv2', 'tunas'], help='The search space name.')
+ parser.add_argument('--algo' , type=str, choices=['tas', 'mask_gumbel', 'mask_rl'], help='The search space name.')
parser.add_argument('--genotype' , type=str, default='|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|', help='The genotype.')
parser.add_argument('--use_api' , type=int, default=1, choices=[0,1], help='Whether use API or not (which will cost much memory).')
# FOR GDAS
diff --git a/exps/experimental/vis-nats-bench-ws.py b/exps/experimental/vis-nats-bench-ws.py
index b1d5014..de4a22a 100644
--- a/exps/experimental/vis-nats-bench-ws.py
+++ b/exps/experimental/vis-nats-bench-ws.py
@@ -43,9 +43,9 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suf
# alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix)
# alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix)
# alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix)
- alg2name['channel-wise interpaltion'] = 'tas-affine0_BN0-AWD0.001{:}'.format(suffix)
- alg2name['masking + Gumbel-Softmax'] = 'fbv2-affine0_BN0-AWD0.001{:}'.format(suffix)
- alg2name['masking + sampling'] = 'tunas-affine0_BN0-AWD0.0{:}'.format(suffix)
+ alg2name['channel-wise interpolation'] = 'tas-affine0_BN0-AWD0.001{:}'.format(suffix)
+ alg2name['masking + Gumbel-Softmax'] = 'mask_gumbel-affine0_BN0-AWD0.001{:}'.format(suffix)
+ alg2name['masking + sampling'] = 'mask_rl-affine0_BN0-AWD0.0{:}'.format(suffix)
for alg, name in alg2name.items():
alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth')
alg2data = OrderedDict()
diff --git a/lib/models/shape_searchs/generic_size_tiny_cell_model.py b/lib/models/shape_searchs/generic_size_tiny_cell_model.py
index 9a3f6d0..ee887cc 100644
--- a/lib/models/shape_searchs/generic_size_tiny_cell_model.py
+++ b/lib/models/shape_searchs/generic_size_tiny_cell_model.py
@@ -3,8 +3,8 @@
#####################################################
# Here, we utilized three techniques to search for the number of channels:
# - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019"
-# - masking + Gumbel-Softmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
-# - masking + sampling from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
+# - masking + Gumbel-Softmax (mask_gumbel) from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
+# - masking + sampling (mask_rl) from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
from typing import List, Text, Any
import random, torch
import torch.nn as nn
@@ -52,10 +52,10 @@ class GenericNAS301Model(nn.Module):
def set_algo(self, algo: Text):
# used for searching
assert self._algo is None, 'This functioin can only be called once.'
- assert algo in ['fbv2', 'tunas', 'tas'], 'invalid algo : {:}'.format(algo)
+ assert algo in ['mask_gumbel', 'mask_rl', 'tas'], 'invalid algo : {:}'.format(algo)
self._algo = algo
self._arch_parameters = nn.Parameter(1e-3*torch.randn(self._max_num_Cs, len(self._candidate_Cs)))
- # if algo == 'fbv2' or algo == 'tunas':
+ # if algo == 'mask_gumbel' or algo == 'mask_rl':
self.register_buffer('_masks', torch.zeros(len(self._candidate_Cs), max(self._candidate_Cs)))
for i in range(len(self._candidate_Cs)):
self._masks.data[i, :self._candidate_Cs[i]] = 1
@@ -130,7 +130,7 @@ class GenericNAS301Model(nn.Module):
else:
mask = self._masks[random.randint(0, len(self._masks)-1)]
feature = feature * mask.view(1, -1, 1, 1)
- elif self._algo == 'fbv2':
+ elif self._algo == 'mask_gumbel':
weights = nn.functional.gumbel_softmax(self._arch_parameters[idx:idx+1], tau=self.tau, dim=-1)
mask = torch.matmul(weights, self._masks).view(1, -1, 1, 1)
feature = feature * mask
@@ -148,7 +148,7 @@ class GenericNAS301Model(nn.Module):
else:
miss = torch.zeros(feature.shape[0], feature.shape[1]-out.shape[1], feature.shape[2], feature.shape[3], device=feature.device)
feature = torch.cat((out, miss), dim=1)
- elif self._algo == 'tunas':
+ elif self._algo == 'mask_rl':
prob = nn.functional.softmax(self._arch_parameters[idx:idx+1], dim=-1)
dist = torch.distributions.Categorical(prob)
action = dist.sample()
diff --git a/lib/nats_bench/__init__.py b/lib/nats_bench/__init__.py
index 4c318b5..050aae7 100644
--- a/lib/nats_bench/__init__.py
+++ b/lib/nats_bench/__init__.py
@@ -3,15 +3,18 @@
##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
##############################################################################
-# The official Application Programming Interface (API) for NATS-Bench. #
-##############################################################################
-from .api_utils import pickle_save, pickle_load
-from .api_utils import ArchResults, ResultsCount
-from .api_topology import NATStopology
-from .api_size import NATSsize
+"""The official Application Programming Interface (API) for NATS-Bench."""
+from nats_bench.api_size import NATSsize
+from nats_bench.api_topology import NATStopology
+from nats_bench.api_utils import ArchResults
+from nats_bench.api_utils import pickle_load
+from nats_bench.api_utils import pickle_save
+from nats_bench.api_utils import ResultsCount
NATS_BENCH_API_VERSIONs = ['v1.0'] # [2020.08.31]
+NATS_BENCH_SSS_NAMEs = ('sss', 'size')
+NATS_BENCH_TSS_NAMEs = ('tss', 'topology')
def version():
@@ -24,13 +27,43 @@ def create(file_path_or_dict, search_space, fast_mode=False, verbose=True):
Args:
file_path_or_dict: None or a file path or a directory path.
search_space: This is a string indicates the search space in NATS-Bench.
- fast_mode: If True, we will not load all the data at initialization, instead, the data for each candidate architecture will be loaded when quering it;
- If False, we will load all the data during initialization.
+ fast_mode: If True, we will not load all the data at initialization,
+ instead, the data for each candidate architecture will be loaded when
+ quering it; If False, we will load all the data during initialization.
verbose: This is a flag to indicate whether log additional information.
+
+ Raises:
+ ValueError: If not find the matched serach space description.
+
+ Returns:
+ The created NATS-Bench API.
"""
- if search_space in ['tss', 'topology']:
+ if search_space in NATS_BENCH_TSS_NAMEs:
return NATStopology(file_path_or_dict, fast_mode, verbose)
- elif search_space in ['sss', 'size']:
+ elif search_space in NATS_BENCH_SSS_NAMEs:
return NATSsize(file_path_or_dict, fast_mode, verbose)
else:
raise ValueError('invalid search space : {:}'.format(search_space))
+
+
+def search_space_info(main_tag, aux_tag):
+ """Obtain the search space information."""
+ nats_sss = dict(candidates=[8, 16, 24, 32, 40, 48, 56, 64],
+ num_layers=5)
+ nats_tss = dict(op_names=['none', 'skip_connect',
+ 'nor_conv_1x1', 'nor_conv_3x3',
+ 'avg_pool_3x3'],
+ num_nodes=4)
+ if main_tag == 'nats-bench':
+ if aux_tag in NATS_BENCH_SSS_NAMEs:
+ return nats_sss
+ elif aux_tag in NATS_BENCH_TSS_NAMEs:
+ return nats_tss
+ else:
+ raise ValueError('Unknown auxiliary tag: {:}'.format(aux_tag))
+ elif main_tag == 'nas-bench-201':
+ if aux_tag is not None:
+ raise ValueError('For NAS-Bench-201, the auxiliary tag should be None.')
+ return nats_tss
+ else:
+ raise ValueError('Unknown main tag: {:}'.format(main_tag))
diff --git a/lib/nats_bench/api_size.py b/lib/nats_bench/api_size.py
index d10425c..e7400fa 100644
--- a/lib/nats_bench/api_size.py
+++ b/lib/nats_bench/api_size.py
@@ -1,65 +1,84 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
##############################################################################
-# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
-#####################################################################################
-# The history of benchmark files (the name is NATS-sss-[version]-[md5].pickle.pbz2) #
-# [2020.08.31] NATS-sss-v1_0-50262.pickle.pbz2 #
-#####################################################################################
-import os, copy, random, numpy as np
-from typing import List, Text, Union, Dict, Optional
-from collections import OrderedDict, defaultdict
-from .api_utils import time_string
-from .api_utils import pickle_load
-from .api_utils import ArchResults
-from .api_utils import NASBenchMetaAPI
-from .api_utils import remap_dataset_set_names
-from .api_utils import nats_is_dir
-from .api_utils import nats_is_file
-from .api_utils import PICKLE_EXT
+# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
+##############################################################################
+# The history of benchmark files are as follows, #
+# where the format is (the name is NATS-sss-[version]-[md5].pickle.pbz2) #
+# [2020.08.31] NATS-sss-v1_0-50262.pickle.pbz2 #
+##############################################################################
+# pylint: disable=line-too-long
+"""The API for size search space in NATS-Bench."""
+import collections
+import copy
+import os
+import random
+from typing import Dict, Optional, Text, Union, Any
+
+from nats_bench.api_utils import ArchResults
+from nats_bench.api_utils import NASBenchMetaAPI
+from nats_bench.api_utils import nats_is_dir
+from nats_bench.api_utils import nats_is_file
+from nats_bench.api_utils import PICKLE_EXT
+from nats_bench.api_utils import pickle_load
+from nats_bench.api_utils import time_string
ALL_BASE_NAMES = ['NATS-sss-v1_0-50262']
def print_information(information, extra_info=None, show=False):
+ """print out the information of a given ArchResults."""
dataset_names = information.get_dataset_names()
- strings = [information.arch_str, 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)]
+ strings = [
+ information.arch_str,
+ 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)
+ ]
+
def metric2str(loss, acc):
return 'loss = {:.3f} & top1 = {:.2f}%'.format(loss, acc)
- for ida, dataset in enumerate(dataset_names):
+ for dataset in dataset_names:
metric = information.get_compute_costs(dataset)
flop, param, latency = metric['flops'], metric['params'], metric['latency']
- str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency is not None and latency > 0 else None)
+ str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(
+ dataset, flop, param,
+ '{:.2f}'.format(latency *
+ 1000) if latency is not None and latency > 0 else None)
train_info = information.get_metrics(dataset, 'train')
if dataset == 'cifar10-valid':
valid_info = information.get_metrics(dataset, 'x-valid')
test__info = information.get_metrics(dataset, 'ori-test')
str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(
- dataset, metric2str(train_info['loss'], train_info['accuracy']),
- metric2str(valid_info['loss'], valid_info['accuracy']),
- metric2str(test__info['loss'], test__info['accuracy']))
+ dataset, metric2str(train_info['loss'], train_info['accuracy']),
+ metric2str(valid_info['loss'], valid_info['accuracy']),
+ metric2str(test__info['loss'], test__info['accuracy']))
elif dataset == 'cifar10':
test__info = information.get_metrics(dataset, 'ori-test')
- str2 = '{:14s} train : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
+ str2 = '{:14s} train : [{:}], test : [{:}]'.format(
+ dataset, metric2str(train_info['loss'], train_info['accuracy']),
+ metric2str(test__info['loss'], test__info['accuracy']))
else:
valid_info = information.get_metrics(dataset, 'x-valid')
test__info = information.get_metrics(dataset, 'x-test')
- str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
+ str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(
+ dataset, metric2str(train_info['loss'], train_info['accuracy']),
+ metric2str(valid_info['loss'], valid_info['accuracy']),
+ metric2str(test__info['loss'], test__info['accuracy']))
strings += [str1, str2]
if show: print('\n'.join(strings))
return strings
-"""
-This is the class for the API of size search space in NATS-Bench.
-"""
class NATSsize(NASBenchMetaAPI):
+ """This is the class for the API of size search space in NATS-Bench."""
- """ The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
- def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, fast_mode: bool=False, verbose: bool=True):
- self.ALL_BASE_NAMES = ALL_BASE_NAMES
+ def __init__(self,
+ file_path_or_dict: Optional[Union[Text, Dict[Text, Any]]] = None,
+ fast_mode: bool = False,
+ verbose: bool = True):
+ """The initialization function that takes the dataset file path (or a dict loaded from that path) as input."""
+ self._all_base_names = ALL_BASE_NAMES
self.filename = None
self._search_space_name = 'size'
self._fast_mode = fast_mode
@@ -67,25 +86,36 @@ class NATSsize(NASBenchMetaAPI):
self.reset_time()
if file_path_or_dict is None:
if self._fast_mode:
- self._archive_dir = os.path.join(os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
+ self._archive_dir = os.path.join(
+ os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
else:
- file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], '{:}.{:}'.format(ALL_BASE_NAMES[-1], PICKLE_EXT))
- print ('{:} Try to use the default NATS-Bench (size) path from fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode, file_path_or_dict))
+ file_path_or_dict = os.path.join(
+ os.environ['TORCH_HOME'], '{:}.{:}'.format(
+ ALL_BASE_NAMES[-1], PICKLE_EXT))
+ print('{:} Try to use the default NATS-Bench (size) path from '
+ 'fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode,
+ file_path_or_dict))
if isinstance(file_path_or_dict, str):
file_path_or_dict = str(file_path_or_dict)
if verbose:
- print('{:} Try to create the NATS-Bench (size) api from {:} with fast_mode={:}'.format(time_string(), file_path_or_dict, fast_mode))
- if not nats_is_file(file_path_or_dict) and not nats_is_dir(file_path_or_dict):
- raise ValueError('{:} is neither a file or a dir.'.format(file_path_or_dict))
+ print('{:} Try to create the NATS-Bench (size) api '
+ 'from {:} with fast_mode={:}'.format(
+ time_string(), file_path_or_dict, fast_mode))
+ if not nats_is_file(file_path_or_dict) and not nats_is_dir(
+ file_path_or_dict):
+ raise ValueError('{:} is neither a file or a dir.'.format(
+ file_path_or_dict))
self.filename = os.path.basename(file_path_or_dict)
if fast_mode:
if nats_is_file(file_path_or_dict):
- raise ValueError('fast_mode={:} must feed the path for directory : {:}'.format(fast_mode, file_path_or_dict))
+ raise ValueError('fast_mode={:} must feed the path for directory '
+ ': {:}'.format(fast_mode, file_path_or_dict))
else:
self._archive_dir = file_path_or_dict
else:
if nats_is_dir(file_path_or_dict):
- raise ValueError('fast_mode={:} must feed the path for file : {:}'.format(fast_mode, file_path_or_dict))
+ raise ValueError('fast_mode={:} must feed the path for file '
+ ': {:}'.format(fast_mode, file_path_or_dict))
else:
file_path_or_dict = pickle_load(file_path_or_dict)
elif isinstance(file_path_or_dict, dict):
@@ -93,68 +123,95 @@ class NATSsize(NASBenchMetaAPI):
self.verbose = verbose
if isinstance(file_path_or_dict, dict):
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
- for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
+ for key in keys:
+ if key not in file_path_or_dict:
+ raise ValueError('Can not find key[{:}] in the dict'.format(key))
self.meta_archs = copy.deepcopy(file_path_or_dict['meta_archs'])
- # This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
- self.arch2infos_dict = OrderedDict()
+ # NOTE(xuanyidong): This is a dict mapping each architecture to a dict,
+ # where the key is #epochs and the value is ArchResults
+ self.arch2infos_dict = collections.OrderedDict()
self._avaliable_hps = set()
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
all_infos = file_path_or_dict['arch2infos'][xkey]
- hp2archres = OrderedDict()
+ hp2archres = collections.OrderedDict()
for hp_key, results in all_infos.items():
hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
self._avaliable_hps.add(hp_key) # save the avaliable hyper-parameter
self.arch2infos_dict[xkey] = hp2archres
self.evaluated_indexes = set(file_path_or_dict['evaluated_indexes'])
elif self.archive_dir is not None:
- benchmark_meta = pickle_load('{:}/meta.{:}'.format(self.archive_dir, PICKLE_EXT))
+ benchmark_meta = pickle_load('{:}/meta.{:}'.format(
+ self.archive_dir, PICKLE_EXT))
self.meta_archs = copy.deepcopy(benchmark_meta['meta_archs'])
- self.arch2infos_dict = OrderedDict()
+ self.arch2infos_dict = collections.OrderedDict()
self._avaliable_hps = set()
self.evaluated_indexes = set()
else:
- raise ValueError('file_path_or_dict [{:}] must be a dict or archive_dir must be set'.format(type(file_path_or_dict)))
+ raise ValueError('file_path_or_dict [{:}] must be a dict or archive_dir '
+ 'must be set'.format(type(file_path_or_dict)))
self.archstr2index = {}
for idx, arch in enumerate(self.meta_archs):
- assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
+ if arch in self.archstr2index:
+ raise ValueError('This [{:}]-th arch {:} already in the '
+ 'dict ({:}).'.format(
+ idx, arch, self.archstr2index[arch]))
self.archstr2index[arch] = idx
if self.verbose:
- print('{:} Create NATS-Bench (size) done with {:}/{:} architectures avaliable.'.format(
- time_string(), len(self.evaluated_indexes), len(self.meta_archs)))
+ print('{:} Create NATS-Bench (size) done with {:}/{:} architectures '
+ 'avaliable.'.format(time_string(),
+ len(self.evaluated_indexes),
+ len(self.meta_archs)))
- def query_info_str_by_arch(self, arch, hp: Text='12'):
- """ This function is used to query the information of a specific architecture
- 'arch' can be an architecture index or an architecture string
- When hp=01, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/01E.config'
- When hp=12, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/12E.config'
- When hp=90, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/90E.config'
- The difference between these three configurations are the number of training epochs.
+ def query_info_str_by_arch(self, arch, hp: Text = '12'):
+ """Query the information of a specific architecture.
+
+ Args:
+ arch: it can be an architecture index or an architecture string.
+
+ hp: the hyperparamete indicator, could be 01, 12, or 90. The difference
+ between these three configurations are the number of training epochs.
+
+ Returns:
+ ArchResults instance
"""
if self.verbose:
- print('{:} Call query_info_str_by_arch with arch={:} and hp={:}'.format(time_string(), arch, hp))
+ print('{:} Call query_info_str_by_arch with arch={:}'
+ 'and hp={:}'.format(time_string(), arch, hp))
return self._query_info_str_by_arch(arch, hp, print_information)
- def get_more_info(self, index, dataset: Text, iepoch=None, hp='12', is_random=True):
- """This function will return the metric for the `index`-th architecture
- `dataset` indicates the dataset:
+ def get_more_info(self,
+ index,
+ dataset,
+ iepoch=None,
+ hp: Text = '12',
+ is_random: bool = True):
+ """Return the metric for the `index`-th architecture.
+
+ Args:
+ index: the architecture index.
+ dataset:
'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set
'cifar10' : using the proposed train+valid set of CIFAR-10 as the training set
'cifar100' : using the proposed train set of CIFAR-100 as the training set
'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set
- `iepoch` indicates the index of training epochs from 0 to 11/199.
+ iepoch: the index of training epochs from 0 to 11/199.
When iepoch=None, it will return the metric for the last training epoch
When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0)
- `hp` indicates different hyper-parameters for training
+ hp: indicates different hyper-parameters for training
When hp=01, it trains the network with 01 epochs and the LR decayed from 0.1 to 0 within 01 epochs
When hp=12, it trains the network with 01 epochs and the LR decayed from 0.1 to 0 within 12 epochs
When hp=90, it trains the network with 01 epochs and the LR decayed from 0.1 to 0 within 90 epochs
- `is_random`
+ is_random:
When is_random=True, the performance of a random architecture will be returned
When is_random=False, the performanceo of all trials will be averaged.
+
+ Returns:
+ a dict, where key is the metric name and value is its value.
"""
if self.verbose:
- print('{:} Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(
- time_string(), index, dataset, iepoch, hp, is_random))
+ print('{:} Call the get_more_info function with index={:}, dataset={:}, '
+ 'iepoch={:}, hp={:}, and is_random={:}.'.format(
+ time_string(), index, dataset, iepoch, hp, is_random))
index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object
self._prepare_info(index)
if index not in self.arch2infos_dict:
@@ -165,38 +222,47 @@ class NATSsize(NASBenchMetaAPI):
seeds = archresult.get_dataset_seeds(dataset)
is_random = random.choice(seeds)
# collect the training information
- train_info = archresult.get_metrics(dataset, 'train', iepoch=iepoch, is_random=is_random)
+ train_info = archresult.get_metrics(
+ dataset, 'train', iepoch=iepoch, is_random=is_random)
total = train_info['iepoch'] + 1
- xinfo = {'train-loss' : train_info['loss'],
- 'train-accuracy': train_info['accuracy'],
- 'train-per-time': train_info['all_time'] / total,
- 'train-all-time': train_info['all_time']}
+ xinfo = {
+ 'train-loss': train_info['loss'],
+ 'train-accuracy': train_info['accuracy'],
+ 'train-per-time': train_info['all_time'] / total,
+ 'train-all-time': train_info['all_time']
+ }
# collect the evaluation information
if dataset == 'cifar10-valid':
- valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
+ valid_info = archresult.get_metrics(
+ dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
try:
- test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
- except:
+ test_info = archresult.get_metrics(
+ dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
+ except Exception as unused_e: # pylint: disable=broad-except
test_info = None
valtest_info = None
else:
- try: # collect results on the proposed test set
+ try: # collect results on the proposed test set
if dataset == 'cifar10':
- test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
+ test_info = archresult.get_metrics(
+ dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else:
- test_info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
- except:
+ test_info = archresult.get_metrics(
+ dataset, 'x-test', iepoch=iepoch, is_random=is_random)
+ except Exception as unused_e: # pylint: disable=broad-except
test_info = None
- try: # collect results on the proposed validation set
- valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
- except:
+ try: # collect results on the proposed validation set
+ valid_info = archresult.get_metrics(
+ dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
+ except Exception as unused_e: # pylint: disable=broad-except
valid_info = None
try:
if dataset != 'cifar10':
- valtest_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
+ valtest_info = archresult.get_metrics(
+ dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else:
valtest_info = None
- except:
+ except Exception as unused_e: # pylint: disable=broad-except
valtest_info = None
if valid_info is not None:
xinfo['valid-loss'] = valid_info['loss']
@@ -216,11 +282,5 @@ class NATSsize(NASBenchMetaAPI):
return xinfo
def show(self, index: int = -1) -> None:
- """
- This function will print the information of a specific (or all) architecture(s).
-
- :param index: If the index < 0: it will loop for all architectures and print their information one by one.
- else: it will print the information of the 'index'-th architecture.
- :return: nothing
- """
+ """Print the information of a specific (or all) architecture(s)."""
self._show(index, print_information)
diff --git a/lib/nats_bench/api_test.py b/lib/nats_bench/api_test.py
new file mode 100644
index 0000000..a30118f
--- /dev/null
+++ b/lib/nats_bench/api_test.py
@@ -0,0 +1,59 @@
+##############################################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 ##########################
+##############################################################################
+# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
+##############################################################################
+"""This file is used to quickly test the API."""
+import random
+
+from nats_bench.api_size import NATSsize
+from nats_bench.api_topology import NATStopology
+
+
+def test_nats_bench_tss(benchmark_dir):
+ return test_nats_bench(benchmark_dir, True)
+
+
+def test_nats_bench_sss(benchmark_dir):
+ return test_nats_bench(benchmark_dir, False)
+
+
+def test_nats_bench(benchmark_dir, is_tss, verbose=False):
+ if is_tss:
+ api = NATStopology(benchmark_dir, True, verbose)
+ else:
+ api = NATSsize(benchmark_dir, True, verbose)
+
+ test_indexes = [random.randint(0, len(api) - 1) for _ in range(10)]
+ key2dataset = {'cifar10': 'CIFAR-10',
+ 'cifar100': 'CIFAR-100',
+ 'ImageNet16-120': 'ImageNet16-120'}
+
+ for index in test_indexes:
+ print('\n\nEvaluate the {:5d}-th architecture.'.format(index))
+
+ for key, dataset in key2dataset.items():
+ # Query the loss / accuracy / time for the `index`-th candidate
+ # architecture on CIFAR-10
+ # info is a dict, where you can easily figure out the meaning by key
+ info = api.get_more_info(index, key)
+ print(' -->> The performance on {:}: {:}'.format(dataset, info))
+
+ # Query the flops, params, latency. info is a dict.
+ info = api.get_cost_info(index, key)
+ print(' -->> The cost info on {:}: {:}'.format(dataset, info))
+
+ # Simulate the training of the `index`-th candidate:
+ validation_accuracy, latency, time_cost, current_total_time_cost = api.simulate_train_eval(
+ index, dataset=key, hp='12')
+ print(' -->> The validation accuracy={:}, latency={:}, '
+ 'the current time cost={:} s, accumulated time cost={:} s'
+ .format(validation_accuracy, latency, time_cost,
+ current_total_time_cost))
+
+ # Print the configuration of the `index`-th architecture on CIFAR-10
+ config = api.get_net_config(index, key)
+ print(' -->> The configuration on {:} is {:}'.format(dataset, config))
+
+ # Show the information of the `index`-th architecture
+ api.show(index)
diff --git a/lib/nats_bench/api_topology.py b/lib/nats_bench/api_topology.py
index 9b0dccb..399daf8 100644
--- a/lib/nats_bench/api_topology.py
+++ b/lib/nats_bench/api_topology.py
@@ -2,61 +2,83 @@
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
-#####################################################################################
-# The history of benchmark files (the name is NATS-tss-[version]-[md5].pickle.pbz2) #
-# [2020.08.31] NATS-tss-v1_0-3ffb9.pickle.pbz2 #
-#####################################################################################
-import os, copy, random, numpy as np
-from typing import List, Text, Union, Dict, Optional
-from collections import OrderedDict, defaultdict
-import warnings
-from .api_utils import time_string
-from .api_utils import pickle_load
-from .api_utils import ArchResults
-from .api_utils import NASBenchMetaAPI
-from .api_utils import remap_dataset_set_names
-from .api_utils import nats_is_dir
-from .api_utils import nats_is_file
-from .api_utils import PICKLE_EXT
+##############################################################################
+# The history of benchmark files are as follows, #
+# where the format is (the name is NATS-tss-[version]-[md5].pickle.pbz2) #
+# [2020.08.31] NATS-tss-v1_0-3ffb9.pickle.pbz2 #
+##############################################################################
+# pylint: disable=line-too-long
+"""The API for topology search space in NATS-Bench."""
+import collections
+import copy
+import os
+import random
+from typing import Any, Dict, List, Optional, Text, Union
+
+from nats_bench.api_utils import ArchResults
+from nats_bench.api_utils import NASBenchMetaAPI
+from nats_bench.api_utils import nats_is_dir
+from nats_bench.api_utils import nats_is_file
+from nats_bench.api_utils import PICKLE_EXT
+from nats_bench.api_utils import pickle_load
+from nats_bench.api_utils import time_string
+
+import numpy as np
ALL_BASE_NAMES = ['NATS-tss-v1_0-3ffb9']
def print_information(information, extra_info=None, show=False):
+ """print out the information of a given ArchResults."""
dataset_names = information.get_dataset_names()
- strings = [information.arch_str, 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)]
+ strings = [
+ information.arch_str,
+ 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)
+ ]
+
def metric2str(loss, acc):
return 'loss = {:.3f} & top1 = {:.2f}%'.format(loss, acc)
- for ida, dataset in enumerate(dataset_names):
+ for dataset in dataset_names:
metric = information.get_compute_costs(dataset)
flop, param, latency = metric['flops'], metric['params'], metric['latency']
- str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency is not None and latency > 0 else None)
+ str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(
+ dataset, flop, param,
+ '{:.2f}'.format(latency *
+ 1000) if latency is not None and latency > 0 else None)
train_info = information.get_metrics(dataset, 'train')
if dataset == 'cifar10-valid':
valid_info = information.get_metrics(dataset, 'x-valid')
- str2 = '{:14s} train : [{:}], valid : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']))
+ str2 = '{:14s} train : [{:}], valid : [{:}]'.format(
+ dataset, metric2str(train_info['loss'], train_info['accuracy']),
+ metric2str(valid_info['loss'], valid_info['accuracy']))
elif dataset == 'cifar10':
test__info = information.get_metrics(dataset, 'ori-test')
- str2 = '{:14s} train : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
+ str2 = '{:14s} train : [{:}], test : [{:}]'.format(
+ dataset, metric2str(train_info['loss'], train_info['accuracy']),
+ metric2str(test__info['loss'], test__info['accuracy']))
else:
valid_info = information.get_metrics(dataset, 'x-valid')
test__info = information.get_metrics(dataset, 'x-test')
- str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
+ str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(
+ dataset, metric2str(train_info['loss'], train_info['accuracy']),
+ metric2str(valid_info['loss'], valid_info['accuracy']),
+ metric2str(test__info['loss'], test__info['accuracy']))
strings += [str1, str2]
if show: print('\n'.join(strings))
return strings
-"""
-This is the class for the API of topology search space in NATS-Bench.
-"""
class NATStopology(NASBenchMetaAPI):
+ """This is the class for the API of topology search space in NATS-Bench."""
- """ The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
- def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, fast_mode: bool=False, verbose: bool=True):
- self.ALL_BASE_NAMES = ALL_BASE_NAMES
+ def __init__(self,
+ file_path_or_dict: Optional[Union[Text, Dict[Text, Any]]] = None,
+ fast_mode: bool = False,
+ verbose: bool = True):
+ """The initialization function that takes the dataset file path (or a dict loaded from that path) as input."""
+ self._all_base_names = ALL_BASE_NAMES
self.filename = None
self._search_space_name = 'topology'
self._fast_mode = fast_mode
@@ -64,25 +86,35 @@ class NATStopology(NASBenchMetaAPI):
self.reset_time()
if file_path_or_dict is None:
if self._fast_mode:
- self._archive_dir = os.path.join(os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
+ self._archive_dir = os.path.join(
+ os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
else:
- file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], '{:}.{:}'.format(ALL_BASE_NAMES[-1], PICKLE_EXT))
- print ('{:} Try to use the default NATS-Bench (topology) path from {:}.'.format(time_string(), file_path_or_dict))
+ file_path_or_dict = os.path.join(
+ os.environ['TORCH_HOME'], '{:}.{:}'.format(
+ ALL_BASE_NAMES[-1], PICKLE_EXT))
+ print('{:} Try to use the default NATS-Bench (topology) path '
+ 'from {:}.'.format(time_string(), file_path_or_dict))
if isinstance(file_path_or_dict, str):
file_path_or_dict = str(file_path_or_dict)
if verbose:
- print('{:} Try to create the NATS-Bench (topology) api from {:} with fast_mode={:}'.format(time_string(), file_path_or_dict, fast_mode))
- if not nats_is_file(file_path_or_dict) and not nats_is_dir(file_path_or_dict):
- raise ValueError('{:} is neither a file or a dir.'.format(file_path_or_dict))
+ print('{:} Try to create the NATS-Bench (topology) api '
+ 'from {:} with fast_mode={:}'.format(
+ time_string(), file_path_or_dict, fast_mode))
+ if not nats_is_file(file_path_or_dict) and not nats_is_dir(
+ file_path_or_dict):
+ raise ValueError('{:} is neither a file or a dir.'.format(
+ file_path_or_dict))
self.filename = os.path.basename(file_path_or_dict)
if fast_mode:
if nats_is_file(file_path_or_dict):
- raise ValueError('fast_mode={:} must feed the path for directory : {:}'.format(fast_mode, file_path_or_dict))
+ raise ValueError('fast_mode={:} must feed the path for directory '
+ ': {:}'.format(fast_mode, file_path_or_dict))
else:
self._archive_dir = file_path_or_dict
else:
if nats_is_dir(file_path_or_dict):
- raise ValueError('fast_mode={:} must feed the path for file : {:}'.format(fast_mode, file_path_or_dict))
+ raise ValueError('fast_mode={:} must feed the path for file '
+ ': {:}'.format(fast_mode, file_path_or_dict))
else:
file_path_or_dict = pickle_load(file_path_or_dict)
elif isinstance(file_path_or_dict, dict):
@@ -90,65 +122,73 @@ class NATStopology(NASBenchMetaAPI):
self.verbose = verbose
if isinstance(file_path_or_dict, dict):
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
- for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
+ for key in keys:
+ if key not in file_path_or_dict:
+ raise ValueError('Can not find key[{:}] in the dict'.format(key))
self.meta_archs = copy.deepcopy(file_path_or_dict['meta_archs'])
- # This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
- self.arch2infos_dict = OrderedDict()
+ # NOTE(xuanyidong): This is a dict mapping each architecture to a dict,
+ # where the key is #epochs and the value is ArchResults
+ self.arch2infos_dict = collections.OrderedDict()
self._avaliable_hps = set()
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
all_infos = file_path_or_dict['arch2infos'][xkey]
- hp2archres = OrderedDict()
+ hp2archres = collections.OrderedDict()
for hp_key, results in all_infos.items():
hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
self._avaliable_hps.add(hp_key) # save the avaliable hyper-parameter
self.arch2infos_dict[xkey] = hp2archres
self.evaluated_indexes = set(file_path_or_dict['evaluated_indexes'])
elif self.archive_dir is not None:
- benchmark_meta = pickle_load('{:}/meta.{:}'.format(self.archive_dir, PICKLE_EXT))
+ benchmark_meta = pickle_load('{:}/meta.{:}'.format(
+ self.archive_dir, PICKLE_EXT))
self.meta_archs = copy.deepcopy(benchmark_meta['meta_archs'])
- self.arch2infos_dict = OrderedDict()
+ self.arch2infos_dict = collections.OrderedDict()
self._avaliable_hps = set()
self.evaluated_indexes = set()
else:
- raise ValueError('file_path_or_dict [{:}] must be a dict or archive_dir must be set'.format(type(file_path_or_dict)))
+ raise ValueError('file_path_or_dict [{:}] must be a dict or archive_dir '
+ 'must be set'.format(type(file_path_or_dict)))
self.archstr2index = {}
for idx, arch in enumerate(self.meta_archs):
- assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
+ if arch in self.archstr2index:
+ raise ValueError('This [{:}]-th arch {:} already in the '
+ 'dict ({:}).'.format(
+ idx, arch, self.archstr2index[arch]))
self.archstr2index[arch] = idx
if self.verbose:
- print('{:} Create NATS-Bench (topology) done with {:}/{:} architectures avaliable.'.format(
- time_string(), len(self.evaluated_indexes), len(self.meta_archs)))
+ print('{:} Create NATS-Bench (topology) done with {:}/{:} architectures '
+ 'avaliable.'.format(time_string(),
+ len(self.evaluated_indexes),
+ len(self.meta_archs)))
- def query_info_str_by_arch(self, arch, hp: Text='12'):
- """ This function is used to query the information of a specific architecture
- 'arch' can be an architecture index or an architecture string
- When hp=12, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/12E.config'
- When hp=200, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/200E.config'
- The difference between these three configurations are the number of training epochs.
+ def query_info_str_by_arch(self, arch, hp: Text = '12'):
+ """Query the information of a specific architecture.
+
+ Args:
+ arch: it can be an architecture index or an architecture string.
+
+ hp: the hyperparamete indicator, could be 12 or 200. The difference
+ between these three configurations are the number of training epochs.
+
+ Returns:
+ ArchResults instance
"""
if self.verbose:
- print('{:} Call query_info_str_by_arch with arch={:} and hp={:}'.format(time_string(), arch, hp))
+ print('{:} Call query_info_str_by_arch with arch={:}'
+ 'and hp={:}'.format(time_string(), arch, hp))
return self._query_info_str_by_arch(arch, hp, print_information)
- # obtain the metric for the `index`-th architecture
- # `dataset` indicates the dataset:
- # 'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set
- # 'cifar10' : using the proposed train+valid set of CIFAR-10 as the training set
- # 'cifar100' : using the proposed train set of CIFAR-100 as the training set
- # 'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set
- # `iepoch` indicates the index of training epochs from 0 to 11/199.
- # When iepoch=None, it will return the metric for the last training epoch
- # When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0)
- # `use_12epochs_result` indicates different hyper-parameters for training
- # When use_12epochs_result=True, it trains the network with 12 epochs and the LR decayed from 0.1 to 0 within 12 epochs
- # When use_12epochs_result=False, it trains the network with 200 epochs and the LR decayed from 0.1 to 0 within 200 epochs
- # `is_random`
- # When is_random=True, the performance of a random architecture will be returned
- # When is_random=False, the performanceo of all trials will be averaged.
- def get_more_info(self, index, dataset, iepoch=None, hp='12', is_random=True):
+ def get_more_info(self,
+ index,
+ dataset,
+ iepoch=None,
+ hp: Text = '12',
+ is_random: bool = True):
+ """Return the metric for the `index`-th architecture."""
if self.verbose:
- print('{:} Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(
- time_string(), index, dataset, iepoch, hp, is_random))
+ print('{:} Call the get_more_info function with index={:}, dataset={:}, '
+ 'iepoch={:}, hp={:}, and is_random={:}.'.format(
+ time_string(), index, dataset, iepoch, hp, is_random))
index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object
self._prepare_info(index)
if index not in self.arch2infos_dict:
@@ -161,36 +201,43 @@ class NATStopology(NASBenchMetaAPI):
# collect the training information
train_info = archresult.get_metrics(dataset, 'train', iepoch=iepoch, is_random=is_random)
total = train_info['iepoch'] + 1
- xinfo = {'train-loss' : train_info['loss'],
- 'train-accuracy': train_info['accuracy'],
- 'train-per-time': train_info['all_time'] / total if train_info['all_time'] is not None else None,
- 'train-all-time': train_info['all_time']}
+ xinfo = {
+ 'train-loss':
+ train_info['loss'],
+ 'train-accuracy':
+ train_info['accuracy'],
+ 'train-per-time':
+ train_info['all_time'] /
+ total if train_info['all_time'] is not None else None,
+ 'train-all-time':
+ train_info['all_time']
+ }
# collect the evaluation information
if dataset == 'cifar10-valid':
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
try:
test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
- except:
+ except Exception as unused_e: # pylint: disable=broad-except
test_info = None
valtest_info = None
else:
- try: # collect results on the proposed test set
+ try: # collect results on the proposed test set
if dataset == 'cifar10':
test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else:
test_info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
- except:
+ except Exception as unused_e: # pylint: disable=broad-except
test_info = None
- try: # collect results on the proposed validation set
+ try: # collect results on the proposed validation set
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
- except:
+ except Exception as unused_e: # pylint: disable=broad-except
valid_info = None
try:
if dataset != 'cifar10':
valtest_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else:
valtest_info = None
- except:
+ except Exception as unused_e: # pylint: disable=broad-except
valtest_info = None
if valid_info is not None:
xinfo['valid-loss'] = valid_info['loss']
@@ -214,46 +261,52 @@ class NATStopology(NASBenchMetaAPI):
self._show(index, print_information)
@staticmethod
- def str2lists(arch_str: Text) -> List[tuple]:
- """
- This function shows how to read the string-based architecture encoding.
- It is the same as the `str2structure` func in `AutoDL-Projects/lib/models/cell_searchs/genotypes.py`
+ def str2lists(arch_str: Text) -> List[Any]:
+ """Shows how to read the string-based architecture encoding.
- :param
+ Args:
arch_str: the input is a string indicates the architecture topology, such as
|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
- :return: a list of tuple, contains multiple (op, input_node_index) pairs.
+ Returns:
+ a list of tuple, contains multiple (op, input_node_index) pairs.
- :usage
+ [USAGE]
+ It is the same as the `str2structure` func in AutoDL-Projects:
+ `github.com/D-X-Y/AutoDL-Projects/lib/models/cell_searchs/genotypes.py`
+ ```
arch = api.str2lists( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
print ('there are {:} nodes in this arch'.format(len(arch)+1)) # arch is a list
for i, node in enumerate(arch):
print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node))
+ ```
"""
node_strs = arch_str.split('+')
genotypes = []
- for i, node_str in enumerate(node_strs):
- inputs = list(filter(lambda x: x != '', node_str.split('|')))
- for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
- inputs = ( xi.split('~') for xi in inputs )
- input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs)
- genotypes.append( input_infos )
+ for unused_i, node_str in enumerate(node_strs):
+ inputs = list(filter(lambda x: x != '', node_str.split('|'))) # pylint: disable=g-explicit-bool-comparison
+ for xinput in inputs:
+ assert len(
+ xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
+ inputs = (xi.split('~') for xi in inputs)
+ input_infos = tuple((op, int(idx)) for (op, idx) in inputs)
+ genotypes.append(input_infos)
return genotypes
@staticmethod
def str2matrix(arch_str: Text,
- search_space: List[Text] = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']) -> np.ndarray:
- """
- This func shows how to convert the string-based architecture encoding to the encoding strategy in NAS-Bench-101.
+ search_space: List[Text] = ('none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3')) -> np.ndarray:
+ """Convert the string-based architecture encoding to the encoding strategy in NAS-Bench-101.
- :param
+ Args:
arch_str: the input is a string indicates the architecture topology, such as
|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
search_space: a list of operation string, the default list is the topology search space for NATS-BENCH.
the default value should be be consistent with this line https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/models/cell_operations.py#L24
- :return
+
+ Returns:
the numpy matrix (2-D np.ndarray) representing the DAG of this architecture topology
- :usage
+
+ [USAGE]
matrix = api.str2matrix( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
This matrix is 4-by-4 matrix representing a cell with 4 nodes (only the lower left triangle is useful).
[ [0, 0, 0, 0], # the first line represents the input (0-th) node
@@ -262,19 +315,19 @@ class NATStopology(NASBenchMetaAPI):
[0, 0, 1, 0] ] # the fourth line represents the 3-rd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node ) + 1-th-op( 2-th-node )
In the topology search space in NATS-BENCH, 0-th-op is 'none', 1-th-op is 'skip_connect',
2-th-op is 'nor_conv_1x1', 3-th-op is 'nor_conv_3x3', 4-th-op is 'avg_pool_3x3'.
- :(NOTE)
+ [NOTE]
If a node has two input-edges from the same node, this function does not work. One edge will be overlapped.
"""
node_strs = arch_str.split('+')
num_nodes = len(node_strs) + 1
matrix = np.zeros((num_nodes, num_nodes))
for i, node_str in enumerate(node_strs):
- inputs = list(filter(lambda x: x != '', node_str.split('|')))
- for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
+ inputs = list(filter(lambda x: x != '', node_str.split('|'))) # pylint: disable=g-explicit-bool-comparison
+ for xinput in inputs:
+ assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
for xi in inputs:
op, idx = xi.split('~')
if op not in search_space: raise ValueError('this op ({:}) is not in {:}'.format(op, search_space))
op_idx, node_idx = search_space.index(op), int(idx)
matrix[i+1, node_idx] = op_idx
return matrix
-
diff --git a/lib/nats_bench/api_utils.py b/lib/nats_bench/api_utils.py
index aa49969..433a1aa 100644
--- a/lib/nats_bench/api_utils.py
+++ b/lib/nats_bench/api_utils.py
@@ -1,56 +1,47 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 #
-############################################################################################
-# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size
-############################################################################################
-# In this Python file, we define NASBenchMetaAPI, the abstract class for benchmark APIs.
-# We also define the class ArchResults, which contains all information of a single architecture trained by one kind of hyper-parameters on three datasets.
-# We also define the class ResultsCount, which contains all information of a single trial for a single architecture.
-############################################################################################
-# History:
-# [2020.07.31] The first version, where most content reused nas_201_api/api_utils.py
-#
-import os, abc, time, copy, random, numpy as np
-import bz2, pickle
+##############################################################################
+# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
+##############################################################################
+"""In this file, we define NASBenchMetaAPI, ArchResults, and ResultsCount.
+
+ NASBenchMetaAPI is the abstract class for benchmark APIs.
+ We also define the class ArchResults, which contains all
+ information of a single architecture trained by one kind of hyper-parameters
+ on three datasets. We also define the class ResultsCount, which contains all
+ information of a single trial for a single architecture.
+"""
+import abc
+import bz2
+import collections
+import copy
+import os
+import pickle
+import random
+import time
+from typing import Any, Dict, Optional, Text, Union
import warnings
-from typing import List, Text, Union, Dict, Optional
-from collections import OrderedDict, defaultdict
+
+import numpy as np
_FILE_SYSTEM = 'default'
PICKLE_EXT = 'pickle.pbz2'
-def pickle_save(obj, file_path, ext='.pbz2', protocol=4):
- """Use pickle to save data (obj) into file_path.
- According to https://docs.python.org/3/library/pickle.html#data-stream-format, Protocol version 4 was added in Python 3.4. It adds support for very large objects, pickling more kinds of objects, and some data format optimizations. It is the default protocol starting with Python 3.8.
- """
- # with open(file_path, 'wb') as cfile:
- with bz2.BZ2File(str(file_path) + ext, 'wb') as cfile:
- pickle.dump(obj, cfile, protocol=protocol)
-
-
-def pickle_load(file_path, ext='.pbz2'):
- # return pickle.load(open(file_path, "rb"))
- if os.path.isfile(str(file_path)):
- xfile_path = str(file_path)
- else:
- xfile_path = str(file_path) + ext
- with bz2.BZ2File(xfile_path, 'rb') as cfile:
- return pickle.load(cfile)
-
-
def time_string():
- ISOTIMEFORMAT='%Y-%m-%d %X'
- string = '[{:}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
+ iso_time_format = '%Y-%m-%d %X'
+ string = '[{:}]'.format(
+ time.strftime(iso_time_format, time.gmtime(time.time())))
return string
-def reset_file_system(lib: Text='default'):
+def reset_file_system(lib: Text = 'default'):
+ global _FILE_SYSTEM
_FILE_SYSTEM = lib
-def get_file_system(lib: Text='default'):
+def get_file_system():
return _FILE_SYSTEM
@@ -58,8 +49,8 @@ def nats_is_dir(file_path):
if _FILE_SYSTEM == 'default':
return os.path.isdir(file_path)
elif _FILE_SYSTEM == 'google':
- import tensorflow as tf
- return tf.gfile.isdir(file_path)
+ import tensorflow as tf # pylint: disable=g-import-not-at-top
+ return tf.io.gfile.isdir(file_path)
else:
raise ValueError('Unknown file system lib: {:}'.format(_FILE_SYSTEM))
@@ -68,36 +59,94 @@ def nats_is_file(file_path):
if _FILE_SYSTEM == 'default':
return os.path.isfile(file_path)
elif _FILE_SYSTEM == 'google':
- import tensorflow as tf
- return tf.gfile.exists(file_path) and not tf.gfile.isdir(file_path)
+ import tensorflow as tf # pylint: disable=g-import-not-at-top
+ return tf.io.gfile.exists(file_path) and not tf.io.gfile.isdir(file_path)
+ else:
+ raise ValueError('Unknown file system lib: {:}'.format(_FILE_SYSTEM))
+
+
+def pickle_save(obj, file_path, ext='.pbz2', protocol=4):
+ """Use pickle to save data (obj) into file_path.
+
+ Args:
+ obj: The object to be saved into a path.
+ file_path: The target saving path.
+ ext: The extension of file name.
+ protocol: The pickle protocol. According to this documentation
+ (https://docs.python.org/3/library/pickle.html#data-stream-format),
+ the protocol version 4 was added in Python 3.4. It adds support for very
+ large objects, pickling more kinds of objects, and some data format
+ optimizations. It is the default protocol starting with Python 3.8.
+ """
+ # with open(file_path, 'wb') as cfile:
+ if _FILE_SYSTEM == 'default':
+ with bz2.BZ2File(str(file_path) + ext, 'wb') as cfile:
+ pickle.dump(obj, cfile, protocol=protocol) # pytype: disable=wrong-arg-types
+ else:
+ raise ValueError('Unknown file system lib: {:}'.format(_FILE_SYSTEM))
+
+
+def pickle_load(file_path, ext='.pbz2'):
+ """Use pickle to load the file on different systems."""
+ # return pickle.load(open(file_path, "rb"))
+ if nats_is_file(str(file_path)):
+ xfile_path = str(file_path)
+ else:
+ xfile_path = str(file_path) + ext
+ if _FILE_SYSTEM == 'default':
+ with bz2.BZ2File(xfile_path, 'rb') as cfile:
+ return pickle.load(cfile) # pytype: disable=wrong-arg-types
+ elif _FILE_SYSTEM == 'google':
+ import tensorflow as tf # pylint: disable=g-import-not-at-top
+ file_content = tf.io.gfile.GFile(file_path, mode='rb').read()
+ byte_content = bz2.decompress(file_content)
+ return pickle.loads(byte_content)
else:
raise ValueError('Unknown file system lib: {:}'.format(_FILE_SYSTEM))
def remap_dataset_set_names(dataset, metric_on_set, verbose=False):
- """re-map the metric_on_set to internal keys"""
+ """Re-map the metric_on_set to internal keys."""
if verbose:
- print('Call internal function _remap_dataset_set_names with dataset={:} and metric_on_set={:}'.format(dataset, metric_on_set))
+ print('Call internal function _remap_dataset_set_names with dataset={:} '
+ 'and metric_on_set={:}'.format(dataset, metric_on_set))
if dataset == 'cifar10' and metric_on_set == 'valid':
dataset, metric_on_set = 'cifar10-valid', 'x-valid'
elif dataset == 'cifar10' and metric_on_set == 'test':
dataset, metric_on_set = 'cifar10', 'ori-test'
elif dataset == 'cifar10' and metric_on_set == 'train':
dataset, metric_on_set = 'cifar10', 'train'
- elif (dataset == 'cifar100' or dataset == 'ImageNet16-120') and metric_on_set == 'valid':
+ elif (dataset == 'cifar100' or
+ dataset == 'ImageNet16-120') and metric_on_set == 'valid':
metric_on_set = 'x-valid'
- elif (dataset == 'cifar100' or dataset == 'ImageNet16-120') and metric_on_set == 'test':
+ elif (dataset == 'cifar100' or
+ dataset == 'ImageNet16-120') and metric_on_set == 'test':
metric_on_set = 'x-test'
if verbose:
- print(' return dataset={:} and metric_on_set={:}'.format(dataset, metric_on_set))
+ print(' return dataset={:} and metric_on_set={:}'.format(
+ dataset, metric_on_set))
return dataset, metric_on_set
class NASBenchMetaAPI(metaclass=abc.ABCMeta):
+ """The abstract class for NATS Bench API."""
@abc.abstractmethod
- def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, verbose: bool=True):
+ def __init__(self,
+ file_path_or_dict: Optional[Union[Text, Dict[Text, Any]]] = None,
+ fast_mode: bool = False,
+ verbose: bool = True):
"""The initialization function that takes the dataset file path (or a dict loaded from that path) as input."""
+ # NOTE(xuanyidong): the following attributes must be initilaized in subclass
+ self.meta_archs = None
+ self.verbose = None
+ self.evaluated_indexes = None
+ self.arch2infos_dict = None
+ self.filename = None
+ self._fast_mode = None
+ self._archive_dir = None
+ self._avaliable_hps = None
+ self.archstr2index = None
def __getitem__(self, index: int):
return copy.deepcopy(self.meta_archs[index])
@@ -106,16 +155,20 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
"""Return the topology structure of the `index`-th architecture."""
if self.verbose:
print('Call the arch function with index={:}'.format(index))
- assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
+ if index < 0 or index >= len(self.meta_archs):
+ raise ValueError('invalid index : {:} vs. {:}.'.format(
+ index, len(self.meta_archs)))
return copy.deepcopy(self.meta_archs[index])
def __len__(self):
return len(self.meta_archs)
def __repr__(self):
- return ('{name}({num}/{total} architectures, fast_mode={fast_mode}, file={filename})'.format(
- name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs),
- fast_mode=self.fast_mode, filename=self.filename))
+ return ('{name}({num}/{total} architectures, fast_mode={fast_mode}, '
+ 'file={filename})'.format(
+ name=self.__class__.__name__,
+ num=len(self.evaluated_indexes), total=len(self.meta_archs),
+ fast_mode=self.fast_mode, filename=self.filename))
@property
def avaliable_hps(self):
@@ -124,7 +177,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
@property
def used_time(self):
return self._used_time
-
+
@property
def search_space_name(self):
return self._search_space_name
@@ -146,15 +199,35 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
def reset_time(self):
self._used_time = 0
- def simulate_train_eval(self, arch, dataset, iepoch=None, hp='12', account_time=True):
+ @abc.abstractmethod
+ def get_more_info(self,
+ index,
+ dataset,
+ iepoch=None,
+ hp: Text = '12',
+ is_random: bool = True):
+ """Return the metric for the `index`-th architecture."""
+
+ def simulate_train_eval(self,
+ arch,
+ dataset,
+ iepoch=None,
+ hp='12',
+ account_time=True):
+ """This function is used to simulate training and evaluating an arch."""
index = self.query_index_by_arch(arch)
all_names = ('cifar10', 'cifar100', 'ImageNet16-120')
- assert dataset in all_names, 'Invalid dataset name : {:} vs {:}'.format(dataset, all_names)
+ if dataset not in all_names:
+ raise ValueError('Invalid dataset name : {:} vs {:}'.format(
+ dataset, all_names))
if dataset == 'cifar10':
- info = self.get_more_info(index, 'cifar10-valid', iepoch=iepoch, hp=hp, is_random=True)
+ info = self.get_more_info(
+ index, 'cifar10-valid', iepoch=iepoch, hp=hp, is_random=True)
else:
- info = self.get_more_info(index, dataset, iepoch=iepoch, hp=hp, is_random=True)
- valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
+ info = self.get_more_info(
+ index, dataset, iepoch=iepoch, hp=hp, is_random=True)
+ valid_acc, time_cost = info[
+ 'valid-accuracy'], info['train-all-time'] + info['valid-per-time']
latency = self.get_latency(index, dataset)
if account_time:
self._used_time += time_cost
@@ -165,18 +238,23 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
return random.randint(0, len(self.meta_archs)-1)
def reload(self, archive_root: Text = None, index: int = None):
- """Overwrite all information of the 'index'-th architecture in the search space,
- where the data will be loaded from 'archive_root'.
- If archive_root is None, it will try to load from the default path os.environ['TORCH_HOME'] / 'BASE_NAME'-full.
- If index is None, overwrite all ckps.
+ """Overwrite all information of the 'index'-th architecture in search space.
+
+ Args:
+ archive_root: If archive_root is None, it will try to load from the
+ default path os.environ['TORCH_HOME'] / 'BASE_NAME'-full.
+ index: If index is None, overwrite all ckps.
"""
if self.verbose:
print('{:} Call clear_params with archive_root={:} and index={:}'.format(
- time_string(), archive_root, index))
+ time_string(), archive_root, index))
if archive_root is None:
- archive_root = os.path.join(os.environ['TORCH_HOME'], '{:}-full'.format(self.ALL_BASE_NAMES[-1]))
+ archive_root = os.path.join(os.environ['TORCH_HOME'],
+ '{:}-full'.format(self._all_base_names[-1]))
if not nats_is_dir(archive_root):
- warnings.warn('The input archive_root is None and the default archive_root path ({:}) does not exist, try to use self.archive_dir.'.format(archive_root))
+ warnings.warn('The input archive_root is None and the default '
+ 'archive_root path ({:}) does not exist, try to use '
+ 'self.archive_dir.'.format(archive_root))
archive_root = self.archive_dir
if archive_root is None or not nats_is_dir(archive_root):
raise ValueError('Invalid archive_root : {:}'.format(archive_root))
@@ -185,47 +263,71 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
else:
indexes = [index]
for idx in indexes:
- assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx)
- xfile_path = os.path.join(archive_root, '{:06d}.{:}'.format(idx, PICKLE_EXT))
+ if not (0 <= idx < len(self.meta_archs)): # pylint: disable=superfluous-parens
+ raise ValueError('invalid index of {:}'.format(idx))
+ xfile_path = os.path.join(archive_root,
+ '{:06d}.{:}'.format(idx, PICKLE_EXT))
if not nats_is_file(xfile_path):
- xfile_path = os.path.join(archive_root, '{:d}.{:}'.format(idx, PICKLE_EXT))
- assert nats_is_file(xfile_path), 'invalid data path : {:}'.format(xfile_path)
+ xfile_path = os.path.join(archive_root,
+ '{:d}.{:}'.format(idx, PICKLE_EXT))
+ assert nats_is_file(xfile_path), 'invalid data path : {:}'.format(
+ xfile_path)
xdata = pickle_load(xfile_path)
- assert isinstance(xdata, dict), 'invalid format of data in {:}'.format(xfile_path)
+ assert isinstance(xdata, dict), 'invalid format of data in {:}'.format(
+ xfile_path)
self.evaluated_indexes.add(idx)
- hp2archres = OrderedDict()
+ hp2archres = collections.OrderedDict()
for hp_key, results in xdata.items():
hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
self._avaliable_hps.add(hp_key)
self.arch2infos_dict[idx] = hp2archres
def query_index_by_arch(self, arch):
- """ This function is used to query the index of an architecture in the search space.
- In the topology search space, the input arch can be an architecture string such as '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|';
- or an instance that has the 'tostr' function that can generate the architecture string;
- or it is directly an architecture index, in this case, we will check whether it is valid or not.
- This function will return the index.
- If return -1, it means this architecture is not in the search space.
- Otherwise, it will return an int in [0, the-number-of-candidates-in-the-search-space).
+ """Query the index of an architecture in the search space.
+
+ Args:
+ arch: For topology search space, the input arch can be an architecture
+ string such as '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|'; # pylint: disable=line-too-long
+ or an instance that has the 'tostr' function that can
+ generate the architecture string;
+ or it is directly an architecture index, in this case,
+ we will check whether it is valid or not.
+ This function will return the index.
+ If return -1, it means this architecture is not in the search space.
+ Otherwise, it will return an intenger in
+ [0, the-number-of-candidates-in-the-search-space).
+
+ Raises:
+ ValueError: If did not find the architecture in this benchmark.
+
+ Returns:
+ The index of the architcture in this benchmark.
"""
if self.verbose:
- print('{:} Call query_index_by_arch with arch={:}'.format(time_string(), arch))
+ print('{:} Call query_index_by_arch with arch={:}'.format(
+ time_string(), arch))
if isinstance(arch, int):
if 0 <= arch < len(self):
return arch
else:
- raise ValueError('Invalid architecture index {:} vs [{:}, {:}].'.format(arch, 0, len(self)))
+ raise ValueError('Invalid architecture index {:} vs [{:}, {:}].'.format(
+ arch, 0, len(self)))
elif isinstance(arch, str):
- if arch in self.archstr2index: arch_index = self.archstr2index[ arch ]
- else : arch_index = -1
+ if arch in self.archstr2index:
+ arch_index = self.archstr2index[arch]
+ else:
+ arch_index = -1
elif hasattr(arch, 'tostr'):
- if arch.tostr() in self.archstr2index: arch_index = self.archstr2index[ arch.tostr() ]
- else : arch_index = -1
- else: arch_index = -1
+ if arch.tostr() in self.archstr2index:
+ arch_index = self.archstr2index[arch.tostr()]
+ else:
+ arch_index = -1
+ else:
+ arch_index = -1
return arch_index
def query_by_arch(self, arch, hp):
- """This is to make the current version be compatible with the old version."""
+ """Make the current version be compatible with the old NAS-Bench-201 version."""
return self.query_info_str_by_arch(arch, hp)
def _prepare_info(self, index):
@@ -235,171 +337,257 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
self.reload(self.archive_dir, index)
elif not self.fast_mode:
if self.verbose:
- print('{:} Call _prepare_info with index={:} skip because it is not the fast mode.'.format(time_string(), index))
+ print('{:} Call _prepare_info with index={:} skip because it is not'
+ 'the fast mode.'.format(time_string(), index))
else:
- raise ValueError('Invalid status: fast_mode={:} and archive_dir={:}'.format(self.fast_mode, self.archive_dir))
+ raise ValueError('Invalid status: fast_mode={:} and '
+ 'archive_dir={:}'.format(
+ self.fast_mode, self.archive_dir))
else:
- assert index in self.evaluated_indexes, 'The index of {:} is not in self.evaluated_indexes, there must be something wrong.'.format(index)
+ if index not in self.evaluated_indexes:
+ raise ValueError('The index of {:} is not in self.evaluated_indexes, '
+ 'there must be something wrong.'.format(index))
if self.verbose:
- print('{:} Call _prepare_info with index={:} skip because it is in arch2infos_dict'.format(time_string(), index))
+ print('{:} Call _prepare_info with index={:} skip because it is in '
+ 'arch2infos_dict'.format(time_string(), index))
- def clear_params(self, index: int, hp: Optional[Text]=None):
+ def clear_params(self, index: int, hp: Optional[Text] = None):
"""Remove the architecture's weights to save memory.
- :arg
+
+ Args:
index: the index of the target architecture
hp: a flag to controll how to clear the parameters.
- -- None: clear all the weights in '01'/'12'/'90', which indicates the number of training epochs.
- -- '01' or '12' or '90': clear all the weights in arch2infos_dict[index][hp].
+ -- None: clear all the weights in '01'/'12'/'90', which indicates
+ the number of training epochs.
+ -- '01' or '12' or '90': clear all the weights in
+ arch2infos_dict[index][hp].
"""
if self.verbose:
- print('{:} Call clear_params with index={:} and hp={:}'.format(time_string(), index, hp))
+ print('{:} Call clear_params with index={:} and hp={:}'.format(
+ time_string(), index, hp))
if index not in self.arch2infos_dict:
- warnings.warn('The {:}-th architecture is not in the benchmark data yet, no need to clear params.'.format(index))
+ warnings.warn('The {:}-th architecture is not in the benchmark data yet, '
+ 'no need to clear params.'.format(index))
elif hp is None:
for key, result in self.arch2infos_dict[index].items():
result.clear_params()
else:
if str(hp) not in self.arch2infos_dict[index]:
- raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(index, list(self.arch2infos_dict[index].keys()), hp))
+ raise ValueError('The {:}-th architecture only has hyper-parameters '
+ 'of {:} instead of {:}.'.format(
+ index, list(self.arch2infos_dict[index].keys()),
+ hp))
self.arch2infos_dict[index][str(hp)].clear_params()
@abc.abstractmethod
- def query_info_str_by_arch(self, arch, hp: Text='12'):
+ def query_info_str_by_arch(self, arch, hp: Text = '12'):
"""This function is used to query the information of a specific architecture."""
- def _query_info_str_by_arch(self, arch, hp: Text='12', print_information=None):
+ def _query_info_str_by_arch(self,
+ arch,
+ hp: Text = '12',
+ print_information=None):
+ """Internal function to query the information of `arch` when using `hp`."""
arch_index = self.query_index_by_arch(arch)
self._prepare_info(arch_index)
if arch_index in self.arch2infos_dict:
if hp not in self.arch2infos_dict[arch_index]:
- raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(index, list(self.arch2infos_dict[arch_index].keys()), hp))
+ raise ValueError('The {:}-th architecture only has hyper-parameters of '
+ '{:} instead of {:}.'.format(
+ arch_index,
+ list(self.arch2infos_dict[arch_index].keys()), hp))
info = self.arch2infos_dict[arch_index][hp]
strings = print_information(info, 'arch-index={:}'.format(arch_index))
return '\n'.join(strings)
else:
- warnings.warn('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index))
+ warnings.warn('Find this arch-index : {:}, but this arch is not '
+ 'evaluated.'.format(arch_index))
return None
def query_meta_info_by_index(self, arch_index, hp: Text = '12'):
- """Return the ArchResults for the 'arch_index'-th architecture. This function is similar to query_by_index."""
+ """Return ArchResults for the 'arch_index'-th architecture."""
if self.verbose:
- print('Call query_meta_info_by_index with arch_index={:}, hp={:}'.format(arch_index, hp))
+ print('Call query_meta_info_by_index with arch_index={:}, hp={:}'.format(
+ arch_index, hp))
self._prepare_info(arch_index)
if arch_index in self.arch2infos_dict:
if hp not in self.arch2infos_dict[arch_index]:
- raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(arch_index, list(self.arch2infos_dict[arch_index].keys()), hp))
+ raise ValueError('The {:}-th architecture only has hyper-parameters of '
+ '{:} instead of {:}.'.format(
+ arch_index,
+ list(self.arch2infos_dict[arch_index].keys()),
+ hp))
info = self.arch2infos_dict[arch_index][hp]
else:
- raise ValueError('arch_index [{:}] does not in arch2infos'.format(arch_index))
+ raise ValueError('arch_index [{:}] does not in arch2infos'.format(
+ arch_index))
return copy.deepcopy(info)
- def query_by_index(self, arch_index: int, dataname: Union[None, Text] = None, hp: Text = '12'):
- """ This 'query_by_index' function is used to query information with the training of 01 epochs, 12 epochs, 90 epochs, or 200 epochs.
- ------
- If hp=01, we train the model by 01 epochs (see config in configs/nas-benchmark/hyper-opts/01E.config)
- If hp=12, we train the model by 01 epochs (see config in configs/nas-benchmark/hyper-opts/12E.config)
- If hp=90, we train the model by 01 epochs (see config in configs/nas-benchmark/hyper-opts/90E.config)
- If hp=200, we train the model by 01 epochs (see config in configs/nas-benchmark/hyper-opts/200E.config)
- ------
- If dataname is None, return the ArchResults
- else, return a dict with all trials on that dataset (the key is the seed)
- Options are 'cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'.
- -- cifar10-valid : training the model on the CIFAR-10 training set.
- -- cifar10 : training the model on the CIFAR-10 training + validation set.
- -- cifar100 : training the model on the CIFAR-100 training set.
- -- ImageNet16-120 : training the model on the ImageNet16-120 training set.
+ def query_by_index(self,
+ arch_index: int,
+ dataname: Union[None, Text] = None,
+ hp: Text = '12'):
+ """Query the information with the training of 01/12/90/200 epochs.
+
+ Args:
+ arch_index: The architecture index in this benchmark.
+ dataname: If dataname is None, return the ArchResults; otherwise, we will
+ return a dict with all trials on that dataset
+ (the key is the seed).
+ Options are 'cifar10-valid', 'cifar10', 'cifar100',
+ and 'ImageNet16-120'.
+ -- cifar10-valid : train the model on CIFAR-10 training set.
+ -- cifar10 : train the model on CIFAR-10 training + validation set.
+ -- cifar100 : train the model on CIFAR-100 training set.
+ -- ImageNet16-120 : train the model on ImageNet16-120 training set.
+ hp: The hyperparameters.
+ If hp=01, we train the model by 01 epochs.
+ If hp=12, we train the model by 01 epochs.
+ If hp=90, we train the model by 01 epochs.
+ If hp=200, we train the model by 01 epochs.
+ See github.com/D-X-Y/AutoDL-Projects/configs/nas-benchmark/hyper-opts
+ for more details.
+
+ Raises:
+ ValueError: If not find the matched serach space description.
+
+ Returns:
+ An instance fo ArchResults.
"""
if self.verbose:
- print('{:} Call query_by_index with arch_index={:}, dataname={:}, hp={:}'.format(time_string(), arch_index, dataname, hp))
+ print('{:} Call query_by_index with arch_index={:}, dataname={:}, '
+ 'hp={:}'.format(time_string(), arch_index, dataname, hp))
info = self.query_meta_info_by_index(arch_index, hp)
- if dataname is None: return info
+ if dataname is None:
+ return info
else:
if dataname not in info.get_dataset_names():
- raise ValueError('invalid dataset-name : {:} vs. {:}'.format(dataname, info.get_dataset_names()))
+ raise ValueError('invalid dataset-name : {:} vs. {:}'.format(
+ dataname, info.get_dataset_names()))
return info.query(dataname)
- def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None, hp: Text = '12'):
+ def find_best(self,
+ dataset,
+ metric_on_set,
+ flop_max=None,
+ param_max=None,
+ hp: Text = '12'):
"""Find the architecture with the highest accuracy based on some constraints."""
if self.verbose:
- print('{:} Call find_best with dataset={:}, metric_on_set={:}, hp={:} | with #FLOPs < {:} and #Params < {:}'.format(
- time_string(), dataset, metric_on_set, hp, FLOP_max, Param_max))
- dataset, metric_on_set = remap_dataset_set_names(dataset, metric_on_set, self.verbose)
+ print('{:} Call find_best with dataset={:}, metric_on_set={:}, hp={:} '
+ '| with #FLOPs < {:} and #Params < {:}'.format(
+ time_string(), dataset, metric_on_set, hp, flop_max, param_max))
+ dataset, metric_on_set = remap_dataset_set_names(
+ dataset, metric_on_set, self.verbose)
best_index, highest_accuracy = -1, None
evaluated_indexes = sorted(list(self.evaluated_indexes))
- for i, arch_index in enumerate(evaluated_indexes):
+ for arch_index in evaluated_indexes:
arch_info = self.arch2infos_dict[arch_index][hp]
info = arch_info.get_compute_costs(dataset) # the information of costs
flop, param, latency = info['flops'], info['params'], info['latency']
- if FLOP_max is not None and flop > FLOP_max : continue
- if Param_max is not None and param > Param_max: continue
- xinfo = arch_info.get_metrics(dataset, metric_on_set) # the information of loss and accuracy
+ if flop_max is not None and flop > flop_max:
+ continue
+ if param_max is not None and param > param_max:
+ continue
+ xinfo = arch_info.get_metrics(
+ dataset, metric_on_set) # the information of loss and accuracy
loss, accuracy = xinfo['loss'], xinfo['accuracy']
if best_index == -1:
best_index, highest_accuracy = arch_index, accuracy
elif highest_accuracy < accuracy:
best_index, highest_accuracy = arch_index, accuracy
+ del latency, loss
if self.verbose:
- print(' the best architecture : [{:}] {:} with accuracy={:.3f}%'.format(best_index, self.arch(best_index), highest_accuracy))
+ print(' the best architecture : [{:}] {:} with accuracy={:.3f}%'.format(
+ best_index, self.arch(best_index), highest_accuracy))
return best_index, highest_accuracy
def get_net_param(self, index, dataset, seed: Optional[int], hp: Text = '12'):
- """
- This function is used to obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed`
- Args [seed]:
- -- None : return a dict containing the trained weights of all trials, where each key is a seed and its corresponding value is the weights.
- -- a interger : return the weights of a specific trial, whose seed is this interger.
- Args [hp]:
+ """Obtain the trained weights of the `index`-th arch on `dataset`.
+
+ Args:
+ index: The architecture index.
+ dataset: The training dataset name.
+ seed:
+ -- None : return a dict containing the trained weights of all trials,
+ where each key is a seed and its corresponding value
+ is the weights.
+ -- Interger : return the weights of a specific trial, whose seed
+ is this interger.
+ hp:
-- 01 : train the model by 01 epochs
-- 12 : train the model by 12 epochs
-- 90 : train the model by 90 epochs
-- 200 : train the model by 200 epochs
+ Returns:
+ PyTorch weights.
"""
if self.verbose:
- print('{:} Call the get_net_param function with index={:}, dataset={:}, seed={:}, hp={:}'.format(time_string(), index, dataset, seed, hp))
+ print('{:} Call the get_net_param function with index={:}, dataset={:}, '
+ 'seed={:}, hp={:}'.format(time_string(), index, dataset, seed, hp))
info = self.query_meta_info_by_index(index, hp)
return info.get_net_param(dataset, seed)
def get_net_config(self, index: int, dataset: Text):
- """
- This function is used to obtain the configuration for the `index`-th architecture on `dataset`.
- Args [dataset] (4 possible options):
- -- cifar10-valid : training the model on the CIFAR-10 training set.
- -- cifar10 : training the model on the CIFAR-10 training + validation set.
- -- cifar100 : training the model on the CIFAR-100 training set.
- -- ImageNet16-120 : training the model on the ImageNet16-120 training set.
- This function will return a dict.
- ========= Some examlpes for using this function:
+ """Obtain the configuration for the `index`-th architecture on `dataset`.
+
+ Args:
+ index: The architecture index.
+ dataset: 4 possible options as follows,
+ -- cifar10-valid : train the model on the CIFAR-10 training set.
+ -- cifar10 : train the model on the CIFAR-10 training + validation set.
+ -- cifar100 : train the model on the CIFAR-100 training set.
+ -- ImageNet16-120 : train the model on the ImageNet16-120 training set.
+ Returns:
+ A dict.
+
+ Note: some examlpes for using this function:
config = api.get_net_config(128, 'cifar10')
"""
if self.verbose:
- print('{:} Call the get_net_config function with index={:}, dataset={:}.'.format(time_string(), index, dataset))
+ print('{:} Call the get_net_config function with index={:}, '
+ 'dataset={:}.'.format(time_string(), index, dataset))
self._prepare_info(index)
if index in self.arch2infos_dict:
info = self.arch2infos_dict[index]
else:
- raise ValueError('The arch_index={:} is not in arch2infos_dict.'.format(index))
+ raise ValueError(
+ 'The arch_index={:} is not in arch2infos_dict.'.format(index))
info = next(iter(info.values()))
results = info.query(dataset, None)
results = next(iter(results.values()))
return results.get_config(None)
-
- def get_cost_info(self, index: int, dataset: Text, hp: Text = '12') -> Dict[Text, float]:
+
+ def get_cost_info(self,
+ index: int,
+ dataset: Text,
+ hp: Text = '12') -> Dict[Text, float]:
"""To obtain the cost metric for the `index`-th architecture on a dataset."""
if self.verbose:
- print('{:} Call the get_cost_info function with index={:}, dataset={:}, and hp={:}.'.format(time_string(), index, dataset, hp))
+ print('{:} Call the get_cost_info function with index={:}, '
+ 'dataset={:}, and hp={:}.'.format(
+ time_string(), index, dataset, hp))
self._prepare_info(index)
info = self.query_meta_info_by_index(index, hp)
return info.get_compute_costs(dataset)
def get_latency(self, index: int, dataset: Text, hp: Text = '12') -> float:
- """
- To obtain the latency of the network (by default it will return the latency with the batch size of 256).
- :param index: the index of the target architecture
- :param dataset: the dataset name (cifar10-valid, cifar10, cifar100, ImageNet16-120)
- :return: return a float value in seconds
+ """Obtain the latency of the network.
+
+ Note: by default it will return the latency with the batch size of 256.
+ Args:
+ index: the index of the target architecture
+ dataset: the dataset name (cifar10-valid, cifar10, cifar100,
+ and ImageNet16-120)
+ hp: the hyperparamete indicator.
+
+ Returns:
+ return a float value in seconds
"""
if self.verbose:
- print('{:} Call the get_latency function with index={:}, dataset={:}, and hp={:}.'.format(time_string(), index, dataset, hp))
+ print('{:} Call the get_latency function with index={:}, '
+ 'dataset={:}, and hp={:}.'.format(
+ time_string(), index, dataset, hp))
cost_dict = self.get_cost_info(index, dataset, hp)
return cost_dict['latency']
@@ -408,49 +596,60 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
"""This function will print the information of a specific (or all) architecture(s)."""
def _show(self, index=-1, print_information=None) -> None:
- """
- This function will print the information of a specific (or all) architecture(s).
+ """Print the information of a specific (or all) architecture(s).
- :param index: If the index < 0: it will loop for all architectures and print their information one by one.
- else: it will print the information of the 'index'-th architecture.
- :return: nothing
+ Args:
+ index: If the index < 0: it will loop for all architectures and print
+ their information one by one. Else: it will print the information
+ of the 'index'-th architecture.
+
+ print_information: A function to print result.
+
+ Returns: None
"""
- if index < 0: # show all architectures
+ if index < 0: # show all architectures
print(self)
evaluated_indexes = sorted(list(self.evaluated_indexes))
for i, idx in enumerate(evaluated_indexes):
- print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(evaluated_indexes), idx) + '-'*10)
+ print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th '
+ 'architecture! '.format(i, len(evaluated_indexes), idx) + '-'*10)
print('arch : {:}'.format(self.meta_archs[idx]))
- for key, result in self.arch2infos_dict[index].items():
+ for unused_key, result in self.arch2infos_dict[index].items():
strings = print_information(result)
- print('>' * 40 + ' {:03d} epochs '.format(result.get_total_epoch()) + '>' * 40)
+ print('>' * 40 + ' {:03d} epochs '.format(
+ result.get_total_epoch()) + '>' * 40)
print('\n'.join(strings))
print('<' * 40 + '------------' + '<' * 40)
else:
if 0 <= index < len(self.meta_archs):
- if index not in self.evaluated_indexes: print('The {:}-th architecture has not been evaluated or not saved.'.format(index))
+ if index not in self.evaluated_indexes:
+ print('The {:}-th architecture has not been evaluated '
+ 'or not saved.'.format(index))
else:
- arch_info = self.arch2infos_dict[index]
- for key, result in self.arch2infos_dict[index].items():
+ # arch_info = self.arch2infos_dict[index]
+ for unused_key, result in self.arch2infos_dict[index].items():
strings = print_information(result)
- print('>' * 40 + ' {:03d} epochs '.format(result.get_total_epoch()) + '>' * 40)
+ print('>' * 40 + ' {:03d} epochs '.format(
+ result.get_total_epoch()) + '>' * 40)
print('\n'.join(strings))
print('<' * 40 + '------------' + '<' * 40)
else:
- print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))
+ print('This index ({:}) is out of range (0~{:}).'.format(
+ index, len(self.meta_archs)))
def statistics(self, dataset: Text, hp: Union[Text, int]) -> Dict[int, int]:
"""This function will count the number of total trials."""
if self.verbose:
- print('Call the statistics function with dataset={:} and hp={:}.'.format(dataset, hp))
+ print('Call the statistics function with dataset={:} and hp={:}.'.format(
+ dataset, hp))
valid_datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
if dataset not in valid_datasets:
raise ValueError('{:} not in {:}'.format(dataset, valid_datasets))
- nums, hp = defaultdict(lambda: 0), str(hp)
+ nums, hp = collections.defaultdict(lambda: 0), str(hp)
# for index in range(len(self)):
for index in self.evaluated_indexes:
- archInfo = self.arch2infos_dict[index][hp]
- dataset_seed = archInfo.dataset_seed
+ arch_info = self.arch2infos_dict[index][hp]
+ dataset_seed = arch_info.dataset_seed
if dataset not in dataset_seed:
nums[0] += 1
else:
@@ -459,122 +658,151 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
class ArchResults(object):
+ """A class to maintain the results of an architecture under different settings."""
def __init__(self, arch_index, arch_str):
- self.arch_index = int(arch_index)
- self.arch_str = copy.deepcopy(arch_str)
- self.all_results = dict()
+ self.arch_index = int(arch_index)
+ self.arch_str = copy.deepcopy(arch_str)
+ self.all_results = dict()
self.dataset_seed = dict()
self.clear_net_done = False
def get_compute_costs(self, dataset):
+ """Return the computation cost on the input dataset."""
x_seeds = self.dataset_seed[dataset]
- results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
+ results = [self.all_results[(dataset, seed)] for seed in x_seeds]
- flops = [result.flop for result in results]
- params = [result.params for result in results]
+ flops = [result.flop for result in results]
+ params = [result.params for result in results]
latencies = [result.get_latency() for result in results]
latencies = [x for x in latencies if x > 0]
- mean_latency = np.mean(latencies) if len(latencies) > 0 else None
- time_infos = defaultdict(list)
+ mean_latency = np.mean(latencies) if len(latencies) else None
+ time_infos = collections.defaultdict(list)
for result in results:
time_info = result.get_times()
- for key, value in time_info.items(): time_infos[key].append( value )
-
- info = {'flops' : np.mean(flops),
- 'params' : np.mean(params),
- 'latency': mean_latency}
+ for key, value in time_info.items():
+ time_infos[key].append(value)
+
+ info = {
+ 'flops': np.mean(flops),
+ 'params': np.mean(params),
+ 'latency': mean_latency
+ }
for key, value in time_infos.items():
- if len(value) > 0 and value[0] is not None:
+ if len(value) and value[0] is not None:
info[key] = np.mean(value)
- else: info[key] = None
+ else:
+ info[key] = None
return info
def get_metrics(self, dataset, setname, iepoch=None, is_random=False):
- """
- This `get_metrics` function is used to obtain obtain the loss, accuracy, etc information on a specific dataset.
- If not specify, each set refer to the proposed split in NAS-Bench-201 paper.
+ """Obtain the loss, accuracy, etc information on a specific dataset.
+
+ If not specify, each set refer to the proposed split in NAS-Bench-201.
If some args return None or raise error, then it is not avaliable.
========================================
- Args [dataset] (4 possible options):
- -- cifar10-valid : training the model on the CIFAR-10 training set.
- -- cifar10 : training the model on the CIFAR-10 training + validation set.
- -- cifar100 : training the model on the CIFAR-100 training set.
- -- ImageNet16-120 : training the model on the ImageNet16-120 training set.
- Args [setname] (each dataset has different setnames):
- -- When dataset = cifar10-valid, you can use 'train', 'x-valid', 'ori-test'
+
+ Args:
+ dataset: 4 possible options as follows
+ -- cifar10-valid : train the model on the CIFAR-10 training set.
+ -- cifar10 : train the model on the CIFAR-10 training + validation set.
+ -- cifar100 : train the model on the CIFAR-100 training set.
+ -- ImageNet16-120 : train the model on the ImageNet16-120 training set.
+ setname: each dataset has different setnames
+ -- When dataset = cifar10-valid, you can use 'train',
+ 'x-valid', and 'ori-test'
------ 'train' : the metric on the training set.
------ 'x-valid' : the metric on the validation set.
------ 'ori-test' : the metric on the test set.
-- When dataset = cifar10, you can use 'train', 'ori-test'.
------ 'train' : the metric on the training + validation set.
------ 'ori-test' : the metric on the test set.
- -- When dataset = cifar100 or ImageNet16-120, you can use 'train', 'ori-test', 'x-valid', 'x-test'
+ -- When dataset = cifar100 or ImageNet16-120, you can use 'train',
+ 'ori-test', 'x-valid', and 'x-test'
------ 'train' : the metric on the training set.
------ 'x-valid' : the metric on the validation set.
------ 'x-test' : the metric on the test set.
------ 'ori-test' : the metric on the validation + test set.
- Args [iepoch] (None or an integer in [0, the-number-of-total-training-epochs)
+ iepoch: (None or an integer in [0, the-number-of-total-training-epochs)
------ None : return the metric after the last training epoch.
------ an integer i : return the metric after the i-th training epoch.
- Args [is_random]:
+ is_random:
------ True : return the metric of a randomly selected trial.
------ False : return the averaged metric of all avaliable trials.
- ------ an integer indicating the 'seed' value : return the metric of a specific trial (whose random seed is 'is_random').
+ ------ an integer indicating the 'seed' value : return the metric of a
+ specific trial (whose random seed is 'is_random').
+
+ Returns:
+ All the metrics given the input setting.
"""
x_seeds = self.dataset_seed[dataset]
- results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
- infos = defaultdict(list)
+ results = [self.all_results[(dataset, seed)] for seed in x_seeds]
+ infos = collections.defaultdict(list)
for result in results:
if setname == 'train':
info = result.get_train(iepoch)
else:
info = result.get_eval(setname, iepoch)
- for key, value in info.items(): infos[key].append( value )
+ for key, value in info.items():
+ infos[key].append(value)
return_info = dict()
- if isinstance(is_random, bool) and is_random: # randomly select one
+ if isinstance(is_random, bool) and is_random: # randomly select one
index = random.randint(0, len(results)-1)
- for key, value in infos.items(): return_info[key] = value[index]
- elif isinstance(is_random, bool) and not is_random: # average
for key, value in infos.items():
- if len(value) > 0 and value[0] is not None:
+ return_info[key] = value[index]
+ elif isinstance(is_random, bool) and not is_random: # average
+ for key, value in infos.items():
+ if len(value) and value[0] is not None:
return_info[key] = np.mean(value)
- else: return_info[key] = None
- elif isinstance(is_random, int): # specify the seed
- if is_random not in x_seeds: raise ValueError('can not find random seed ({:}) from {:}'.format(is_random, x_seeds))
+ else:
+ return_info[key] = None
+ elif isinstance(is_random, int): # specify the seed
+ if is_random not in x_seeds:
+ raise ValueError('can not find random seed ({:}) from {:}'.format(
+ is_random, x_seeds))
index = x_seeds.index(is_random)
- for key, value in infos.items(): return_info[key] = value[index]
+ for key, value in infos.items():
+ return_info[key] = value[index]
else:
raise ValueError('invalid value for is_random: {:}'.format(is_random))
return return_info
- def show(self, is_print=False):
- return print_information(self, None, is_print)
+ # def show(self, is_print=False):
+ # return print_information(self, None, is_print)
def get_dataset_names(self):
return list(self.dataset_seed.keys())
def get_dataset_seeds(self, dataset):
- return copy.deepcopy( self.dataset_seed[dataset] )
+ return copy.deepcopy(self.dataset_seed[dataset])
- def get_net_param(self, dataset: Text, seed: Union[None, int] =None):
- """
- This function will return the trained network's weights on the 'dataset'.
- :arg
- dataset: one of 'cifar10-valid', 'cifar10', 'cifar100', and 'ImageNet16-120'.
- seed: an integer indicates the seed value or None that indicates returing all trials.
+ def get_net_param(self, dataset: Text, seed: Union[None, int] = None):
+ """Return the trained network's weights on the 'dataset'.
+
+ Args:
+ dataset: 'cifar10-valid', 'cifar10', 'cifar100', or 'ImageNet16-120'.
+ seed: an integer indicates the seed value
+ or None that indicates returing all trials.
+
+ Returns:
+ The trained weights (parameters).
"""
if seed is None:
x_seeds = self.dataset_seed[dataset]
- return {seed: self.all_results[(dataset, seed)].get_net_param() for seed in x_seeds}
+ return {
+ seed: self.all_results[(dataset, seed)].get_net_param()
+ for seed in x_seeds
+ }
else:
xkey = (dataset, seed)
if xkey in self.all_results:
return self.all_results[xkey].get_net_param()
else:
- raise ValueError('key={:} not in {:}'.format(xkey, list(self.all_results.keys())))
+ raise ValueError('key={:} not in {:}'.format(
+ xkey, list(self.all_results.keys())))
- def reset_latency(self, dataset: Text, seed: Union[None, Text], latency: float) -> None:
+ def reset_latency(self, dataset: Text, seed: Union[None, Text],
+ latency: float) -> None:
"""This function is used to reset the latency in all corresponding ResultsCount(s)."""
if seed is None:
for seed in self.dataset_seed[dataset]:
@@ -582,29 +810,37 @@ class ArchResults(object):
else:
self.all_results[(dataset, seed)].update_latency([latency])
- def reset_pseudo_train_times(self, dataset: Text, seed: Union[None, Text], estimated_per_epoch_time: float) -> None:
+ def reset_pseudo_train_times(self, dataset: Text, seed: Union[None, Text],
+ estimated_per_epoch_time: float) -> None:
"""This function is used to reset the train-times in all corresponding ResultsCount(s)."""
if seed is None:
for seed in self.dataset_seed[dataset]:
- self.all_results[(dataset, seed)].reset_pseudo_train_times(estimated_per_epoch_time)
+ self.all_results[(
+ dataset, seed)].reset_pseudo_train_times(estimated_per_epoch_time)
else:
- self.all_results[(dataset, seed)].reset_pseudo_train_times(estimated_per_epoch_time)
+ self.all_results[(
+ dataset, seed)].reset_pseudo_train_times(estimated_per_epoch_time)
- def reset_pseudo_eval_times(self, dataset: Text, seed: Union[None, Text], eval_name: Text, estimated_per_epoch_time: float) -> None:
+ def reset_pseudo_eval_times(self, dataset: Text, seed: Union[None, Text],
+ eval_name: Text,
+ estimated_per_epoch_time: float) -> None:
"""This function is used to reset the eval-times in all corresponding ResultsCount(s)."""
if seed is None:
for seed in self.dataset_seed[dataset]:
- self.all_results[(dataset, seed)].reset_pseudo_eval_times(eval_name, estimated_per_epoch_time)
+ self.all_results[(dataset, seed)].reset_pseudo_eval_times(
+ eval_name, estimated_per_epoch_time)
else:
- self.all_results[(dataset, seed)].reset_pseudo_eval_times(eval_name, estimated_per_epoch_time)
+ self.all_results[(dataset, seed)].reset_pseudo_eval_times(
+ eval_name, estimated_per_epoch_time)
def get_latency(self, dataset: Text) -> float:
- """Get the latency of a model on the target dataset. [Timestamp: 2020.03.09]"""
+ """Get the latency of a model on the target dataset."""
latencies = []
for seed in self.dataset_seed[dataset]:
latency = self.all_results[(dataset, seed)].get_latency()
if not isinstance(latency, float) or latency <= 0:
- raise ValueError('invalid latency of {:} with seed={:} : {:}'.format(dataset, seed, latency))
+ raise ValueError('invalid latency of {:} with seed={:} : {:}'.format(
+ dataset, seed, latency))
latencies.append(latency)
return sum(latencies) / len(latencies)
@@ -613,17 +849,26 @@ class ArchResults(object):
if dataset is None:
epochss = []
for xdata, x_seeds in self.dataset_seed.items():
- epochss += [self.all_results[(xdata, seed)].get_total_epoch() for seed in x_seeds]
+ epochss += [
+ self.all_results[(xdata, seed)].get_total_epoch()
+ for seed in x_seeds
+ ]
elif isinstance(dataset, str):
x_seeds = self.dataset_seed[dataset]
- epochss = [self.all_results[(dataset, seed)].get_total_epoch() for seed in x_seeds]
+ epochss = [
+ self.all_results[(dataset, seed)].get_total_epoch()
+ for seed in x_seeds
+ ]
else:
raise ValueError('invalid dataset={:}'.format(dataset))
- if len(set(epochss)) > 1: raise ValueError('Each trial mush have the same number of training epochs : {:}'.format(epochss))
+ if len(set(epochss)) > 1:
+ raise ValueError(
+ 'Each trial mush have the same number of training epochs : {:}'
+ .format(epochss))
return epochss[-1]
def query(self, dataset, seed=None):
- """Return the ResultsCount object (containing all information of a single trial) for 'dataset' and 'seed'"""
+ """Return the ResultsCount object (containing all information of a single trial) for 'dataset' and 'seed'."""
if seed is None:
x_seeds = self.dataset_seed[dataset]
return {seed: self.all_results[(dataset, seed)] for seed in x_seeds}
@@ -634,155 +879,189 @@ class ArchResults(object):
return '{:06d}'.format(self.arch_index)
def update(self, dataset_name, seed, result):
+ """Update the result for the given dataset and seed."""
if dataset_name not in self.dataset_seed:
self.dataset_seed[dataset_name] = []
- assert seed not in self.dataset_seed[dataset_name], '{:}-th arch alreadly has this seed ({:}) on {:}'.format(self.arch_index, seed, dataset_name)
- self.dataset_seed[ dataset_name ].append( seed )
- self.dataset_seed[ dataset_name ] = sorted( self.dataset_seed[ dataset_name ] )
+ if seed in self.dataset_seed[dataset_name]:
+ raise ValueError('{:}-th arch alreadly has this seed ({:}) on {:}'.format(
+ self.arch_index, seed, dataset_name))
+ self.dataset_seed[dataset_name].append(seed)
+ self.dataset_seed[dataset_name] = sorted(self.dataset_seed[dataset_name])
assert (dataset_name, seed) not in self.all_results
- self.all_results[ (dataset_name, seed) ] = result
+ self.all_results[(dataset_name, seed)] = result
self.clear_net_done = False
def state_dict(self):
+ """Return a dict that can be used to re-create this instance."""
state_dict = dict()
for key, value in self.__dict__.items():
- if key == 'all_results': # contain the class of ResultsCount
+ if key == 'all_results': # contain the class of ResultsCount
xvalue = dict()
- assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value))
- for _k, _v in value.items():
- assert isinstance(_v, ResultsCount), 'invalid type of value for {:}/{:} : {:}'.format(key, _k, type(_v))
- xvalue[_k] = _v.state_dict()
+ if not isinstance(value, dict):
+ raise ValueError('invalid type of value for {:} : {:}'.format(
+ key, type(value)))
+ for cur_k, cur_v in value.items():
+ if not isinstance(cur_v, ResultsCount):
+ raise ValueError('invalid type of value for {:}/{:} : {:}'.format(
+ key, cur_k, type(cur_v)))
+ xvalue[cur_k] = cur_v.state_dict()
else:
xvalue = value
state_dict[key] = xvalue
return state_dict
def load_state_dict(self, state_dict):
+ """Update self based on the input dict."""
new_state_dict = dict()
for key, value in state_dict.items():
- if key == 'all_results': # to convert to the class of ResultsCount
+ if key == 'all_results': # To convert to the class of ResultsCount
xvalue = dict()
- assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value))
- for _k, _v in value.items():
- xvalue[_k] = ResultsCount.create_from_state_dict(_v)
+ if not isinstance(value, dict):
+ raise ValueError('invalid type of value for {:} : {:}'.format(
+ key, type(value)))
+ for cur_k, cur_v in value.items():
+ xvalue[cur_k] = ResultsCount.create_from_state_dict(cur_v)
else: xvalue = value
new_state_dict[key] = xvalue
self.__dict__.update(new_state_dict)
@staticmethod
def create_from_state_dict(state_dict_or_file):
+ """Create the ArchResults instance from a dict or a file."""
x = ArchResults(-1, -1)
- if isinstance(state_dict_or_file, str): # a file path
+ if isinstance(state_dict_or_file, str): # a file path
state_dict = pickle_load(state_dict_or_file)
elif isinstance(state_dict_or_file, dict):
state_dict = state_dict_or_file
else:
- raise ValueError('invalid type of state_dict_or_file : {:}'.format(type(state_dict_or_file)))
+ raise ValueError('invalid type of state_dict_or_file : {:}'.format(
+ type(state_dict_or_file)))
x.load_state_dict(state_dict)
return x
- # This function is used to clear the weights saved in each 'result'
- # This can help reduce the memory footprint.
def clear_params(self):
- for key, result in self.all_results.items():
+ """Clear the weights saved in each 'result'."""
+ # NOTE(xuanyidong): This can help reduce the memory footprint.
+ for unused_key, result in self.all_results.items():
del result.net_state_dict
result.net_state_dict = None
self.clear_net_done = True
def debug_test(self):
- """This function is used for me to debug and test, which will call most methods."""
+ """Help debug and test, which will call most methods."""
all_dataset = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
for dataset in all_dataset:
print('---->>>> {:}'.format(dataset))
- print('The latency on {:} is {:} s'.format(dataset, self.get_latency(dataset)))
+ print('The latency on {:} is {:} s'.format(
+ dataset, self.get_latency(dataset)))
for seed in self.dataset_seed[dataset]:
result = self.all_results[(dataset, seed)]
print(' ==>> result = {:}'.format(result))
print(' ==>> cost = {:}'.format(result.get_times()))
def __repr__(self):
- return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done))
+ return ('{name}(arch-index={index}, arch={arch}, '
+ '{num} runs, clear={clear})'.format(
+ name=self.__class__.__name__,
+ index=self.arch_index,
+ arch=self.arch_str,
+ num=len(self.all_results),
+ clear=self.clear_net_done))
-"""
-This class (ResultsCount) is used to save the information of one trial for a single architecture.
-I did not write much comment for this class, because it is the lowest-level class in NAS-Bench-201 API, which will be rarely called.
-If you have any question regarding this class, please open an issue or email me.
-"""
class ResultsCount(object):
+ """ResultsCount is to save the information of one trial for a single architecture."""
- def __init__(self, name, state_dict, train_accs, train_losses, params, flop, arch_config, seed, epochs, latency):
- self.name = name
+ def __init__(self, name, state_dict, train_accs, train_losses, params, flop,
+ arch_config, seed, epochs, latency):
+ self.name = name
self.net_state_dict = state_dict
self.train_acc1es = copy.deepcopy(train_accs)
self.train_acc5es = None
self.train_losses = copy.deepcopy(train_losses)
- self.train_times = None
- self.arch_config = copy.deepcopy(arch_config)
- self.params = params
- self.flop = flop
- self.seed = seed
- self.epochs = epochs
- self.latency = latency
+ self.train_times = None
+ self.arch_config = copy.deepcopy(arch_config)
+ self.params = params
+ self.flop = flop
+ self.seed = seed
+ self.epochs = epochs
+ self.latency = latency
# evaluation results
self.reset_eval()
- def update_train_info(self, train_acc1es, train_acc5es, train_losses, train_times) -> None:
+ def update_train_info(self, train_acc1es, train_acc5es, train_losses,
+ train_times) -> None:
self.train_acc1es = train_acc1es
self.train_acc5es = train_acc5es
self.train_losses = train_losses
- self.train_times = train_times
+ self.train_times = train_times
def reset_pseudo_train_times(self, estimated_per_epoch_time: float) -> None:
"""Assign the training times."""
- train_times = OrderedDict()
+ train_times = collections.OrderedDict()
for i in range(self.epochs):
train_times[i] = estimated_per_epoch_time
self.train_times = train_times
- def reset_pseudo_eval_times(self, eval_name: Text, estimated_per_epoch_time: float) -> None:
+ def reset_pseudo_eval_times(
+ self, eval_name: Text, estimated_per_epoch_time: float) -> None:
"""Assign the evaluation times."""
- if eval_name not in self.eval_names: raise ValueError('invalid eval name : {:}'.format(eval_name))
+ if eval_name not in self.eval_names:
+ raise ValueError('invalid eval name : {:}'.format(eval_name))
for i in range(self.epochs):
- self.eval_times['{:}@{:}'.format(eval_name,i)] = estimated_per_epoch_time
+ self.eval_times['{:}@{:}'.format(eval_name, i)] = estimated_per_epoch_time
def reset_eval(self):
- self.eval_names = []
+ self.eval_names = []
self.eval_acc1es = {}
- self.eval_times = {}
+ self.eval_times = {}
self.eval_losses = {}
def update_latency(self, latency):
- self.latency = copy.deepcopy( latency )
+ self.latency = copy.deepcopy(latency)
def get_latency(self) -> float:
- """Return the latency value in seconds. -1 represents not avaliable ; otherwise it should be a float value"""
- if self.latency is None: return -1.0
- else: return sum(self.latency) / len(self.latency)
+ """Return the latency value in seconds."""
+ # NOTE(xuanyidong): -1 represents not avaliable,
+ # NOTE(xuanyidong): otherwise it should be a float value.
+ if self.latency is None:
+ return -1.0
+ else:
+ return sum(self.latency) / len(self.latency)
- def update_eval(self, accs, losses, times): # new version
+ def update_eval(self, accs, losses, times):
+ """To update the evaluataion results."""
data_names = set([x.split('@')[0] for x in accs.keys()])
for data_name in data_names:
- assert data_name not in self.eval_names, '{:} has already been added into eval-names'.format(data_name)
- self.eval_names.append( data_name )
+ if data_name in self.eval_names:
+ raise ValueError('{:} has already been added into '
+ 'eval-names'.format(data_name))
+ self.eval_names.append(data_name)
for iepoch in range(self.epochs):
xkey = '{:}@{:}'.format(data_name, iepoch)
- self.eval_acc1es[ xkey ] = accs[ xkey ]
- self.eval_losses[ xkey ] = losses[ xkey ]
- self.eval_times [ xkey ] = times[ xkey ]
+ self.eval_acc1es[xkey] = accs[xkey]
+ self.eval_losses[xkey] = losses[xkey]
+ self.eval_times[xkey] = times[xkey]
- def update_OLD_eval(self, name, accs, losses): # old version
+ def update_OLD_eval(self, name, accs, losses): # pylint: disable=invalid-name
+ """To update the evaluataion results (old NAS-Bench-201 version)."""
assert name not in self.eval_names, '{:} has already added'.format(name)
- self.eval_names.append( name )
+ self.eval_names.append(name)
for iepoch in range(self.epochs):
if iepoch in accs:
- self.eval_acc1es['{:}@{:}'.format(name,iepoch)] = accs[iepoch]
- self.eval_losses['{:}@{:}'.format(name,iepoch)] = losses[iepoch]
+ self.eval_acc1es['{:}@{:}'.format(name, iepoch)] = accs[iepoch]
+ self.eval_losses['{:}@{:}'.format(name, iepoch)] = losses[iepoch]
def __repr__(self):
num_eval = len(self.eval_names)
set_name = '[' + ', '.join(self.eval_names) + ']'
- return ('{name}({xname}, arch={arch}, FLOP={flop:.2f}M, Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets: {set_name})'.format(name=self.__class__.__name__, xname=self.name, arch=self.arch_config['arch_str'], flop=self.flop, param=self.params, seed=self.seed, num_eval=num_eval, set_name=set_name))
+ return ('{name}({xname}, arch={arch}, FLOP={flop:.2f}M, '
+ 'Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets: '
+ '{set_name})'.format(name=self.__class__.__name__, xname=self.name,
+ arch=self.arch_config['arch_str'],
+ flop=self.flop, param=self.params,
+ seed=self.seed, num_eval=num_eval,
+ set_name=set_name))
def get_total_epoch(self):
return copy.deepcopy(self.epochs)
@@ -790,16 +1069,22 @@ class ResultsCount(object):
def get_times(self):
"""Obtain the information regarding both training and evaluation time."""
if self.train_times is not None and isinstance(self.train_times, dict):
- train_times = list( self.train_times.values() )
- time_info = {'T-train@epoch': np.mean(train_times), 'T-train@total': np.sum(train_times)}
+ train_times = list(self.train_times.values())
+ time_info = {
+ 'T-train@epoch': np.mean(train_times),
+ 'T-train@total': np.sum(train_times)
+ }
else:
- time_info = {'T-train@epoch': None, 'T-train@total': None }
+ time_info = {'T-train@epoch': None, 'T-train@total': None}
for name in self.eval_names:
try:
- xtimes = [self.eval_times['{:}@{:}'.format(name,i)] for i in range(self.epochs)]
+ xtimes = [
+ self.eval_times['{:}@{:}'.format(name, i)]
+ for i in range(self.epochs)
+ ]
time_info['T-{:}@epoch'.format(name)] = np.mean(xtimes)
time_info['T-{:}@total'.format(name)] = np.sum(xtimes)
- except:
+ except Exception as unused_e: # pylint: disable=broad-except
time_info['T-{:}@epoch'.format(name)] = None
time_info['T-{:}@total'.format(name)] = None
return time_info
@@ -807,70 +1092,102 @@ class ResultsCount(object):
def get_eval_set(self):
return self.eval_names
- # get the training information
+ def judge_valid(self, iepoch):
+ if iepoch < 0 or iepoch >= self.epochs:
+ raise ValueError('invalid iepoch={:} < {:}'.format(iepoch, self.epochs))
+
def get_train(self, iepoch=None):
+ """Get the training information."""
if iepoch is None: iepoch = self.epochs-1
- assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
+ self.judge_valid(iepoch)
if self.train_times is not None:
xtime = self.train_times[iepoch]
atime = sum([self.train_times[i] for i in range(iepoch+1)])
- else: xtime, atime = None, None
- return {'iepoch' : iepoch,
- 'loss' : self.train_losses[iepoch],
- 'accuracy': self.train_acc1es[iepoch],
- 'cur_time': xtime,
- 'all_time': atime}
+ else:
+ xtime, atime = None, None
+ return {
+ 'iepoch': iepoch,
+ 'loss': self.train_losses[iepoch],
+ 'accuracy': self.train_acc1es[iepoch],
+ 'cur_time': xtime,
+ 'all_time': atime
+ }
def get_eval(self, name, iepoch=None):
"""Get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument)."""
- if iepoch is None: iepoch = self.epochs-1
- assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
+ if iepoch is None:
+ iepoch = self.epochs-1
+ self.judge_valid(iepoch)
+
def _internal_query(xname):
- if isinstance(self.eval_times,dict) and len(self.eval_times) > 0:
+ if isinstance(self.eval_times, dict) and len(self.eval_times):
xtime = self.eval_times['{:}@{:}'.format(xname, iepoch)]
- atime = sum([self.eval_times['{:}@{:}'.format(xname, i)] for i in range(iepoch+1)])
+ atime = sum([
+ self.eval_times['{:}@{:}'.format(xname, i)]
+ for i in range(iepoch + 1)
+ ])
else:
xtime, atime = None, None
- return {'iepoch' : iepoch,
- 'loss' : self.eval_losses['{:}@{:}'.format(xname, iepoch)],
- 'accuracy': self.eval_acc1es['{:}@{:}'.format(xname, iepoch)],
- 'cur_time': xtime,
- 'all_time': atime}
+ return {
+ 'iepoch': iepoch,
+ 'loss': self.eval_losses['{:}@{:}'.format(xname, iepoch)],
+ 'accuracy': self.eval_acc1es['{:}@{:}'.format(xname, iepoch)],
+ 'cur_time': xtime,
+ 'all_time': atime
+ }
+
if name == 'valid':
return _internal_query('x-valid')
else:
return _internal_query(name)
def get_net_param(self, clone=False):
- if clone: return copy.deepcopy(self.net_state_dict)
- else: return self.net_state_dict
+ if clone:
+ return copy.deepcopy(self.net_state_dict)
+ else:
+ return self.net_state_dict
def get_config(self, str2structure):
"""This function is used to obtain the config dict for this architecture."""
if str2structure is None:
- # In this case, this is architecture in the size search space of NATS-BENCH.
- if 'name' in self.arch_config and self.arch_config['name'] == 'infer.shape.tiny':
- return {'name': 'infer.shape.tiny', 'channels': self.arch_config['channels'],
- 'genotype': self.arch_config['genotype'], 'num_classes': self.arch_config['class_num']}
- # In this case, this is architecture in the topology search space of NATS-BENCH.
- else:
- return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
- 'N' : self.arch_config['num_cells'],
- 'arch_str': self.arch_config['arch_str'], 'num_classes': self.arch_config['class_num']}
- else:
- # In this case, this is architecture in the size search space of NATS-BENCH.
- if 'name' in self.arch_config and self.arch_config['name'] == 'infer.shape.tiny':
- return {'name': 'infer.shape.tiny', 'channels': self.arch_config['channels'],
- 'genotype': str2structure(self.arch_config['genotype']), 'num_classes': self.arch_config['class_num']}
- # In this case, this is architecture in the topology search space of NATS-BENCH.
- else:
- return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
- 'N' : self.arch_config['num_cells'],
- 'genotype': str2structure(self.arch_config['arch_str']), 'num_classes': self.arch_config['class_num']}
+ # In this case, this is an arch in size search space of NATS-BENCH.
+ if 'name' in self.arch_config and self.arch_config[
+ 'name'] == 'infer.shape.tiny':
+ return {
+ 'name': 'infer.shape.tiny',
+ 'channels': self.arch_config['channels'],
+ 'genotype': self.arch_config['genotype'],
+ 'num_classes': self.arch_config['class_num']
+ }
+ else: # This is an arch in NATS-BENCH's topology search space.
+ return {
+ 'name': 'infer.tiny',
+ 'C': self.arch_config['channel'],
+ 'N': self.arch_config['num_cells'],
+ 'arch_str': self.arch_config['arch_str'],
+ 'num_classes': self.arch_config['class_num']
+ }
+ else: # This is an arch in the size search space of NATS-BENCH.
+ if 'name' in self.arch_config and self.arch_config[
+ 'name'] == 'infer.shape.tiny':
+ return {
+ 'name': 'infer.shape.tiny',
+ 'channels': self.arch_config['channels'],
+ 'genotype': str2structure(self.arch_config['genotype']),
+ 'num_classes': self.arch_config['class_num']
+ }
+ else: # This is an arch in the topology search space of NATS-BENCH.
+ return {
+ 'name': 'infer.tiny',
+ 'C': self.arch_config['channel'],
+ 'N': self.arch_config['num_cells'],
+ 'genotype': str2structure(self.arch_config['arch_str']),
+ 'num_classes': self.arch_config['class_num']
+ }
def state_dict(self):
- _state_dict = {key: value for key, value in self.__dict__.items()}
- return _state_dict
+ collected_state_dict = {key: value for key, value in self.__dict__.items()}
+ return collected_state_dict
def load_state_dict(self, state_dict):
self.__dict__.update(state_dict)
diff --git a/scripts-search/NATS/search-size.sh b/scripts-search/NATS/search-size.sh
index df5b97d..86329f5 100644
--- a/scripts-search/NATS/search-size.sh
+++ b/scripts-search/NATS/search-size.sh
@@ -23,11 +23,11 @@ CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --warmup_ratio ${ratio} --rand_seed ${seed}
#
-CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --warmup_ratio ${ratio} --rand_seed ${seed}
-CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --warmup_ratio ${ratio} --rand_seed ${seed}
-CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --warmup_ratio ${ratio} --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --warmup_ratio ${ratio} --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --warmup_ratio ${ratio} --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_gumbel --warmup_ratio ${ratio} --rand_seed ${seed}
#
-CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}
-CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}
-CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_rl --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}