Merge branch 'master' of github.com:D-X-Y/AutoDL-Projects
commit d58b59a3f3
@@ -6,3 +6,4 @@
 - [2019.01.31] [13e908f] GDAS codes were publicly released.
 - [2020.07.01] [a45808b] Upgrade NAS-API to the 2.0 version.
 - [2020.09.16] [7052265] Create NATS-BENCH.
+- [2020.10.15] [446262a] Update NATS-BENCH to version 1.0.
@@ -61,7 +61,7 @@ At this moment, this project provides the following algorithms and scripts to ru
 </tr>
 <tr> <!-- (6-th row) -->
 <td align="center" valign="middle"> NATS-Bench </td>
-<td align="center" valign="middle"> <a href="https://xuanyidong.com/assets/projects/NATS-Bench"> NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size</a> </td>
+<td align="center" valign="middle"> <a href="https://xuanyidong.com/assets/projects/NATS-Bench"> NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size</a> </td>
 <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/master/docs/NATS-Bench.md">NATS-Bench.md</a> </td>
 </tr>
 <tr> <!-- (7-th row) -->
@@ -100,7 +100,7 @@ Some methods use knowledge distillation (KD), which requires pre-trained models.
 If you find that this project helps your research, please consider citing some of the following papers:
 ```
 @article{dong2020nats,
-  title={NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size},
+  title={{NATS-Bench}: Benchmarking NAS Algorithms for Architecture Topology and Size},
   author={Dong, Xuanyi and Liu, Lu and Musial, Katarzyna and Gabrys, Bogdan},
   journal={arXiv preprint arXiv:2009.00437},
   year={2020}
@@ -61,7 +61,7 @@
 </tr>
 <tr> <!-- (6-th row) -->
 <td align="center" valign="middle"> NATS-Bench </td>
-<td align="center" valign="middle"> <a href="https://xuanyidong.com/assets/projects/NATS-Bench"> NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size</a> </td>
+<td align="center" valign="middle"> <a href="https://xuanyidong.com/assets/projects/NATS-Bench"> NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size</a> </td>
 <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/master/docs/NATS-Bench.md">NATS-Bench.md</a> </td>
 </tr>
 <tr> <!-- (7-th row) -->
@@ -99,7 +99,7 @@ Some methods use knowledge distillation (KD), which requires pre-trained models.
 If you find that this project helps your research or engineering work, please consider citing some of the following papers:
 ```
 @article{dong2020nats,
-  title={NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size},
+  title={{NATS-Bench}: Benchmarking NAS Algorithms for Architecture Topology and Size},
   author={Dong, Xuanyi and Liu, Lu and Musial, Katarzyna and Gabrys, Bogdan},
   journal={arXiv preprint arXiv:2009.00437},
   year={2020}
@@ -1,5 +1,7 @@
 # [NAS-BENCH-201: Extending the Scope of Reproducible Neural Architecture Search](https://openreview.net/forum?id=HJxyZkBKDr)
 
+**Since our NAS-BENCH-201 has been extended to NATS-Bench, this `README` is deprecated and no longer maintained. Please use [NATS-Bench](https://github.com/D-X-Y/AutoDL-Projects/blob/master/docs/NATS-Bench.md), which has 5x more architecture information and a faster API than NAS-BENCH-201.**
+
 We propose an algorithm-agnostic NAS benchmark (NAS-Bench-201) with a fixed search space, which provides a unified benchmark for almost any up-to-date NAS algorithm.
 The design of our search space is inspired by that used in the most popular cell-based searching algorithms, where a cell is represented as a directed acyclic graph.
 Each edge here is associated with an operation selected from a predefined operation set.
@@ -70,17 +72,18 @@ api.show(2)
 # show the mean loss and accuracy of an architecture
 info = api.query_meta_info_by_index(1)  # This is an instance of `ArchResults`
 res_metrics = info.get_metrics('cifar10', 'train')  # This is a dict with metric names as keys
-cost_metrics = info.get_comput_costs('cifar100')  # This is a dict with metric names as keys, e.g., flops, params, latency
+cost_metrics = info.get_compute_costs('cifar100')  # This is a dict with metric names as keys, e.g., flops, params, latency
 
 # get the detailed information
 results = api.query_by_index(1, 'cifar100')  # a dict of all trials for 1st net on cifar100, where the key is the seed
 print('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1]))
-print('Latency : {:}'.format(results[0].get_latency()))
-print('Train Info : {:}'.format(results[0].get_train()))
-print('Valid Info : {:}'.format(results[0].get_eval('x-valid')))
-print('Test Info : {:}'.format(results[0].get_eval('x-test')))
-# for the metric after a specific epoch
-print('Train Info [10-th epoch] : {:}'.format(results[0].get_train(10)))
+for seed, result in results.items():
+  print('Latency : {:}'.format(result.get_latency()))
+  print('Train Info : {:}'.format(result.get_train()))
+  print('Valid Info : {:}'.format(result.get_eval('x-valid')))
+  print('Test Info : {:}'.format(result.get_eval('x-test')))
+  # for the metric after a specific epoch
+  print('Train Info [10-th epoch] : {:}'.format(result.get_train(10)))
 ```
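The loop above prints each trial separately; aggregating across training seeds is a one-liner on top of the same `results` dict. A minimal sketch of such a continuation (an illustration, assuming `get_eval` returns a dict with an `accuracy` key as in the snippet above):

```python
# Mean final x-test accuracy of this architecture over all training seeds.
accs = [result.get_eval('x-test')['accuracy'] for result in results.values()]
print('mean x-test accuracy over {:} seeds: {:.2f}%'.format(len(accs), sum(accs) / len(accs)))
```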
 
 4. Query the index of an architecture by string
@@ -171,7 +174,7 @@ api.get_more_info(112, 'ImageNet16-120', None, hp='200', is_random=True)
 If you find that NAS-Bench-201 helps your research, please consider citing it:
 ```
 @inproceedings{dong2020nasbench201,
-  title = {NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search},
+  title = {{NAS-Bench-201}: Extending the Scope of Reproducible Neural Architecture Search},
   author = {Dong, Xuanyi and Yang, Yi},
   booktitle = {International Conference on Learning Representations (ICLR)},
   url = {https://openreview.net/forum?id=HJxyZkBKDr},
@@ -1,5 +1,7 @@
 # [NAS-BENCH-201: Extending the Scope of Reproducible Neural Architecture Search](https://openreview.net/forum?id=HJxyZkBKDr)
 
+**Since our NAS-BENCH-201 has been extended to NATS-Bench, this README is deprecated and no longer maintained. Please use [NATS-Bench](https://github.com/D-X-Y/AutoDL-Projects/blob/master/docs/NATS-Bench.md), which has 5x more architecture information and a faster API than NAS-BENCH-201.**
+
 We propose an algorithm-agnostic NAS benchmark (NAS-Bench-201) with a fixed search space, which provides a unified benchmark for almost any up-to-date NAS algorithm.
 The design of our search space is inspired by that used in the most popular cell-based searching algorithms, where a cell is represented as a directed acyclic graph.
 Each edge here is associated with an operation selected from a predefined operation set.
@@ -68,17 +70,18 @@ api.show(2)
 # show the mean loss and accuracy of an architecture
 info = api.query_meta_info_by_index(1)  # This is an instance of `ArchResults`
 res_metrics = info.get_metrics('cifar10', 'train')  # This is a dict with metric names as keys
-cost_metrics = info.get_comput_costs('cifar100')  # This is a dict with metric names as keys, e.g., flops, params, latency
+cost_metrics = info.get_compute_costs('cifar100')  # This is a dict with metric names as keys, e.g., flops, params, latency
 
 # get the detailed information
 results = api.query_by_index(1, 'cifar100')  # a dict of all trials for 1st net on cifar100, where the key is the seed
 print('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1]))
-print('Latency : {:}'.format(results[0].get_latency()))
-print('Train Info : {:}'.format(results[0].get_train()))
-print('Valid Info : {:}'.format(results[0].get_eval('x-valid')))
-print('Test Info : {:}'.format(results[0].get_eval('x-test')))
-# for the metric after a specific epoch
-print('Train Info [10-th epoch] : {:}'.format(results[0].get_train(10)))
+for seed, result in results.items():
+  print('Latency : {:}'.format(result.get_latency()))
+  print('Train Info : {:}'.format(result.get_train()))
+  print('Valid Info : {:}'.format(result.get_eval('x-valid')))
+  print('Test Info : {:}'.format(result.get_eval('x-test')))
+  # for the metric after a specific epoch
+  print('Train Info [10-th epoch] : {:}'.format(result.get_train(10)))
 ```
 
 4. Query the index of an architecture by string
@@ -242,7 +245,7 @@ In commands [1-6], the first args `cifar10` indicates the dataset name, the seco
 If you find that NAS-Bench-201 helps your research, please consider citing it:
 ```
 @inproceedings{dong2020nasbench201,
-  title = {NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search},
+  title = {{NAS-Bench-201}: Extending the Scope of Reproducible Neural Architecture Search},
   author = {Dong, Xuanyi and Yang, Yi},
   booktitle = {International Conference on Learning Representations (ICLR)},
   url = {https://openreview.net/forum?id=HJxyZkBKDr},
@@ -1,4 +1,4 @@
-# [NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size](https://arxiv.org/pdf/2009.00437.pdf)
+# [NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size](https://arxiv.org/pdf/2009.00437.pdf)
 
 Neural architecture search (NAS) has attracted a lot of attention and has been illustrated to bring tangible benefits in a large number of applications in the past few years. Network topology and network size have been regarded as two of the most important aspects for the performance of deep learning models and the community has spawned lots of searching algorithms for both of those aspects of the neural architectures. However, the performance gain from these searching algorithms is achieved under different search spaces and training setups. This makes the overall performance of the algorithms incomparable and the improvement from a sub-module of the searching model unclear.
 In this paper, we propose NATS-Bench, a unified benchmark on searching for both topology and size, for (almost) any up-to-date NAS algorithm.
@@ -7,11 +7,12 @@ We analyze the validity of our benchmark in terms of various criteria and perfor
 We also show the versatility of NATS-Bench by benchmarking 13 recent state-of-the-art NAS algorithms on it. All logs and diagnostic information trained using the same setup for each candidate are provided.
 This facilitates a much larger community of researchers to focus on developing better NAS algorithms in a more comparable and computationally effective environment.
 
+**You can use `pip install nats_bench` to install the library of NATS-Bench.**
 
 The structure of this Markdown file:
 - [How to use NATS-Bench?](#How-to-Use-NATS-Bench)
 - [How to re-create NATS-Bench from scratch?](#how-to-re-create-nats-bench-from-scratch)
-- [How to reproduce benchmarked results?](#to-reproduce-13-baseline-nas-algorithms-in-nas-bench-201)
+- [How to reproduce benchmarked results?](#to-reproduce-13-baseline-nas-algorithms-in-nats-bench)
 
 
 ## How to Use [NATS-Bench](https://arxiv.org/pdf/2009.00437.pdf)
@@ -79,8 +80,12 @@ params = api.get_net_param(12, 'cifar10', None)
 network.load_state_dict(next(iter(params.values())))
 ```
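For orientation, here is a compact end-to-end sketch of the usage this hunk extends. It assumes the benchmark file is already under `$TORCH_HOME` and that `get_cell_based_tiny_net` is importable from this repo's `models` package (the same import used in `exps/NATS-algos/search-size.py`); both are assumptions, not part of the diff:

```python
from nats_bench import create
from models import get_cell_based_tiny_net  # assumes this repo's lib/ is on sys.path

api = create(None, 'tss', fast_mode=True, verbose=False)  # topology search space
config = api.get_net_config(12, 'cifar10')       # config dict of the 12-th architecture
network = get_cell_based_tiny_net(config)        # build the network from the config
params = api.get_net_param(12, 'cifar10', None)  # dict: training seed -> state_dict
network.load_state_dict(next(iter(params.values())))  # load one seed's weights
```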
 
 
 ## How to Re-create NATS-Bench from Scratch
 
+You need to use the [AutoDL-Projects](https://github.com/D-X-Y/AutoDL-Projects) repo to re-create NATS-Bench from scratch.
+
 ### The Size Search Space
 
 The following command will train all architecture candidates in the size search space with 90 epochs and the random seed `777`.
@@ -110,7 +115,9 @@ python exps/NATS-Bench/tss-collect.py
 ```
 
 
-## To Reproduce 13 Baseline NAS Algorithms in NAS-Bench-201
+## To Reproduce 13 Baseline NAS Algorithms in NATS-Bench
 
+You need to use the [AutoDL-Projects](https://github.com/D-X-Y/AutoDL-Projects) repo to run 13 baseline NAS methods.
+
 ### Reproduce NAS methods on the topology search space
 
@@ -171,18 +178,18 @@ python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HO
 python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777
 
 
-Run the search strategy in FBNet-V2
+Run the channel search strategy in FBNet-V2 -- masking + Gumbel-Softmax:
 
-python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
-python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
-python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed 777
+python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
+python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
+python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_gumbel --rand_seed 777
 
 
-Run the search strategy in TuNAS:
+Run the channel search strategy in TuNAS -- masking + sampling:
 
-python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 --use_api 0
-python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777
-python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --rand_seed 777
+python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777 --use_api 0
+python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777
+python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_rl --arch_weight_decay 0 --rand_seed 777
 ```
 
 ### Final Discovered Architectures for Each Algorithm
@@ -246,7 +253,7 @@ GDAS:
 If you find that NATS-Bench helps your research, please consider citing it:
 ```
 @article{dong2020nats,
-  title={NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size},
+  title={{NATS-Bench}: Benchmarking NAS Algorithms for Architecture Topology and Size},
   author={Dong, Xuanyi and Liu, Lu and Musial, Katarzyna and Gabrys, Bogdan},
   journal={arXiv preprint arXiv:2009.00437},
   year={2020}
@@ -1,28 +1,30 @@
 ##################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
 ###########################################################################################################################################
 #
 # In this file, we aim to evaluate three kinds of channel searching strategies:
 # - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019"
-# - masking + Gumbel-Softmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
-# - masking + sampling from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
-# For simplicity, we use tas, fbv2, and tunas to refer these three strategies. Their official implementations are at the following links:
+# - masking + Gumbel-Softmax (mask_gumbel) from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
+# - masking + sampling (mask_rl) from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
+#
+# For simplicity, we use tas, mask_gumbel, and mask_rl to refer to these three strategies. Their official implementations are at the following links:
 # - TAS: https://github.com/D-X-Y/AutoDL-Projects/blob/master/docs/NeurIPS-2019-TAS.md
 # - FBNetV2: https://github.com/facebookresearch/mobile-vision
 # - TuNAS: https://github.com/google-research/google-research/tree/master/tunas
 ####
-# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio 0.25
+# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --warmup_ratio 0.25
 ####
 # python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
 # python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
 # python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777
 ####
-# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
-# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
-# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed 777
+# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
+# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
+# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_gumbel --rand_seed 777
 ####
-# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 --use_api 0
-# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777
-# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --rand_seed 777
+# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777 --use_api 0
+# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777
+# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_rl --arch_weight_decay 0 --rand_seed 777
 ###########################################################################################################################################
 import os, sys, time, random, argparse
 import numpy as np
@@ -41,7 +43,7 @@ from models import get_cell_based_tiny_net, get_search_spaces
 from nats_bench import create
 
 
-# Ad-hoc for TuNAS
+# Ad-hoc for RL algorithms.
 class ExponentialMovingAverage(object):
   """Class that maintains an exponential moving average."""
@@ -94,13 +96,13 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
     a_optimizer.zero_grad()
     _, logits, log_probs = network(arch_inputs)
     arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
-    if algo == 'tunas':
+    if algo == 'mask_rl':
       with torch.no_grad():
         RL_BASELINE_EMA.update(arch_prec1.item())
         rl_advantage = arch_prec1 - RL_BASELINE_EMA.value
       rl_log_prob = sum(log_probs)
       arch_loss = - rl_advantage * rl_log_prob
-    elif algo == 'tas' or algo == 'fbv2':
+    elif algo == 'tas' or algo == 'mask_gumbel':
       arch_loss = criterion(logits, arch_targets)
     else:
       raise ValueError('invalid algorithm name: {:}'.format(algo))
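The `mask_rl` branch above is a REINFORCE-style update with a moving-average baseline: the reward is the architecture's top-1 accuracy, and the loss raises the log-probability of sampled channel choices whose accuracy beats the baseline. A self-contained toy sketch of that loss, with hypothetical numbers and independent of the repo's classes:

```python
import torch

log_probs = [torch.tensor(-1.2, requires_grad=True)]  # log-prob of the sampled channel choice
arch_prec1 = torch.tensor(73.4)                       # reward: top-1 accuracy of this sample
baseline = 71.0                                       # EMA of past rewards
rl_advantage = arch_prec1 - baseline
arch_loss = - rl_advantage * sum(log_probs)           # REINFORCE with baseline
arch_loss.backward()                                  # ascend the advantage-weighted log-prob
```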
@@ -231,7 +233,7 @@ def main(xargs):
 
     logger.log('\n[Search the {:}-th epoch] {:}, LR={:}, controller-warmup={:}, enable_controller={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()), network.warmup_ratio, enable_controller))
 
-    if xargs.algo == 'fbv2' or xargs.algo == 'tas':
+    if xargs.algo == 'mask_gumbel' or xargs.algo == 'tas':
       network.set_tau(xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1))
       logger.log('[RESET tau as : {:}]'.format(network.tau))
     search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
@@ -291,7 +293,7 @@ if __name__ == '__main__':
   parser.add_argument('--data_path'   , type=str, help='Path to dataset')
   parser.add_argument('--dataset'     , type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
   parser.add_argument('--search_space', type=str, default='sss', choices=['sss'], help='The search space name.')
-  parser.add_argument('--algo'        , type=str, choices=['tas', 'fbv2', 'tunas'], help='The search space name.')
+  parser.add_argument('--algo'        , type=str, choices=['tas', 'mask_gumbel', 'mask_rl'], help='The search algorithm name.')
   parser.add_argument('--genotype'    , type=str, default='|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|', help='The genotype.')
   parser.add_argument('--use_api'     , type=int, default=1, choices=[0,1], help='Whether use API or not (which will cost much memory).')
   # FOR GDAS
@@ -43,9 +43,9 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suf
   # alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix)
   # alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix)
   # alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix)
-  alg2name['channel-wise interpaltion'] = 'tas-affine0_BN0-AWD0.001{:}'.format(suffix)
-  alg2name['masking + Gumbel-Softmax'] = 'fbv2-affine0_BN0-AWD0.001{:}'.format(suffix)
-  alg2name['masking + sampling'] = 'tunas-affine0_BN0-AWD0.0{:}'.format(suffix)
+  alg2name['channel-wise interpolation'] = 'tas-affine0_BN0-AWD0.001{:}'.format(suffix)
+  alg2name['masking + Gumbel-Softmax'] = 'mask_gumbel-affine0_BN0-AWD0.001{:}'.format(suffix)
+  alg2name['masking + sampling'] = 'mask_rl-affine0_BN0-AWD0.0{:}'.format(suffix)
   for alg, name in alg2name.items():
     alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth')
   alg2data = OrderedDict()
@@ -3,8 +3,8 @@
 #####################################################
 # Here, we utilized three techniques to search for the number of channels:
 # - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019"
-# - masking + Gumbel-Softmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
-# - masking + sampling from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
+# - masking + Gumbel-Softmax (mask_gumbel) from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
+# - masking + sampling (mask_rl) from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
 from typing import List, Text, Any
 import random, torch
 import torch.nn as nn
@@ -52,10 +52,10 @@ class GenericNAS301Model(nn.Module):
   def set_algo(self, algo: Text):
     # used for searching
     assert self._algo is None, 'This function can only be called once.'
-    assert algo in ['fbv2', 'tunas', 'tas'], 'invalid algo : {:}'.format(algo)
+    assert algo in ['mask_gumbel', 'mask_rl', 'tas'], 'invalid algo : {:}'.format(algo)
     self._algo = algo
     self._arch_parameters = nn.Parameter(1e-3*torch.randn(self._max_num_Cs, len(self._candidate_Cs)))
-    # if algo == 'fbv2' or algo == 'tunas':
+    # if algo == 'mask_gumbel' or algo == 'mask_rl':
     self.register_buffer('_masks', torch.zeros(len(self._candidate_Cs), max(self._candidate_Cs)))
     for i in range(len(self._candidate_Cs)):
       self._masks.data[i, :self._candidate_Cs[i]] = 1
@@ -130,7 +130,7 @@ class GenericNAS301Model(nn.Module):
       else:
         mask = self._masks[random.randint(0, len(self._masks)-1)]
       feature = feature * mask.view(1, -1, 1, 1)
-    elif self._algo == 'fbv2':
+    elif self._algo == 'mask_gumbel':
       weights = nn.functional.gumbel_softmax(self._arch_parameters[idx:idx+1], tau=self.tau, dim=-1)
       mask = torch.matmul(weights, self._masks).view(1, -1, 1, 1)
       feature = feature * mask
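The `mask_gumbel` branch blends the candidate channel masks with a Gumbel-Softmax over the architecture parameters, so the masking stays differentiable. A self-contained toy sketch with hypothetical sizes, mirroring the two lines above:

```python
import torch
import torch.nn as nn

masks = torch.tensor([[1., 1., 0., 0.],
                      [1., 1., 1., 1.]])           # two candidates: keep 2 or 4 channels
arch_param = torch.randn(1, 2, requires_grad=True)  # logits over the two candidates
weights = nn.functional.gumbel_softmax(arch_param, tau=1.0, dim=-1)  # soft one-hot
soft_mask = torch.matmul(weights, masks).view(1, -1, 1, 1)
feature = torch.randn(2, 4, 3, 3) * soft_mask       # differentiable w.r.t. arch_param
```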
@@ -148,7 +148,7 @@ class GenericNAS301Model(nn.Module):
       else:
         miss = torch.zeros(feature.shape[0], feature.shape[1]-out.shape[1], feature.shape[2], feature.shape[3], device=feature.device)
         feature = torch.cat((out, miss), dim=1)
-    elif self._algo == 'tunas':
+    elif self._algo == 'mask_rl':
       prob = nn.functional.softmax(self._arch_parameters[idx:idx+1], dim=-1)
       dist = torch.distributions.Categorical(prob)
       action = dist.sample()
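The `mask_rl` branch instead samples one candidate channel count from a Categorical distribution and applies its hard mask; the sampling is non-differentiable, which is why the REINFORCE loss in `search-size.py` trains the architecture parameters. A toy sketch with hypothetical sizes, following the `_masks` construction shown in `set_algo` above:

```python
import torch
import torch.nn as nn

candidate_Cs = [8, 16, 24, 32]                        # candidate channel counts
masks = torch.zeros(len(candidate_Cs), max(candidate_Cs))
for i, c in enumerate(candidate_Cs):
  masks[i, :c] = 1.0                                  # mask i keeps the first c channels

arch_param = 1e-3 * torch.randn(1, len(candidate_Cs))
prob = nn.functional.softmax(arch_param, dim=-1)
action = torch.distributions.Categorical(prob).sample()      # pick one candidate
feature = torch.randn(2, max(candidate_Cs), 4, 4)
feature = feature * masks[action.item()].view(1, -1, 1, 1)   # zero out pruned channels
```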
@@ -3,15 +3,18 @@
 ##############################################################################
 # NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
 ##############################################################################
-# The official Application Programming Interface (API) for NATS-Bench.       #
-##############################################################################
-from .api_utils import pickle_save, pickle_load
-from .api_utils import ArchResults, ResultsCount
-from .api_topology import NATStopology
-from .api_size import NATSsize
+"""The official Application Programming Interface (API) for NATS-Bench."""
+from nats_bench.api_size import NATSsize
+from nats_bench.api_topology import NATStopology
+from nats_bench.api_utils import ArchResults
+from nats_bench.api_utils import pickle_load
+from nats_bench.api_utils import pickle_save
+from nats_bench.api_utils import ResultsCount
 
 
 NATS_BENCH_API_VERSIONs = ['v1.0']  # [2020.08.31]
+NATS_BENCH_SSS_NAMEs = ('sss', 'size')
+NATS_BENCH_TSS_NAMEs = ('tss', 'topology')
 
 def version():
@@ -24,13 +27,43 @@ def create(file_path_or_dict, search_space, fast_mode=False, verbose=True):
   Args:
     file_path_or_dict: None or a file path or a directory path.
     search_space: This is a string that indicates the search space in NATS-Bench.
-    fast_mode: If True, we will not load all the data at initialization, instead, the data for each candidate architecture will be loaded when quering it;
-               If False, we will load all the data during initialization.
+    fast_mode: If True, we will not load all the data at initialization,
+      instead, the data for each candidate architecture will be loaded when
+      querying it; If False, we will load all the data during initialization.
     verbose: This is a flag to indicate whether to log additional information.
 
   Raises:
     ValueError: If the matched search space description is not found.
 
   Returns:
     The created NATS-Bench API.
   """
-  if search_space in ['tss', 'topology']:
+  if search_space in NATS_BENCH_TSS_NAMEs:
     return NATStopology(file_path_or_dict, fast_mode, verbose)
-  elif search_space in ['sss', 'size']:
+  elif search_space in NATS_BENCH_SSS_NAMEs:
     return NATSsize(file_path_or_dict, fast_mode, verbose)
   else:
     raise ValueError('invalid search space : {:}'.format(search_space))
 
 
+def search_space_info(main_tag, aux_tag):
+  """Obtain the search space information."""
+  nats_sss = dict(candidates=[8, 16, 24, 32, 40, 48, 56, 64],
+                  num_layers=5)
+  nats_tss = dict(op_names=['none', 'skip_connect',
+                            'nor_conv_1x1', 'nor_conv_3x3',
+                            'avg_pool_3x3'],
+                  num_nodes=4)
+  if main_tag == 'nats-bench':
+    if aux_tag in NATS_BENCH_SSS_NAMEs:
+      return nats_sss
+    elif aux_tag in NATS_BENCH_TSS_NAMEs:
+      return nats_tss
+    else:
+      raise ValueError('Unknown auxiliary tag: {:}'.format(aux_tag))
+  elif main_tag == 'nas-bench-201':
+    if aux_tag is not None:
+      raise ValueError('For NAS-Bench-201, the auxiliary tag should be None.')
+    return nats_tss
+  else:
+    raise ValueError('Unknown main tag: {:}'.format(main_tag))
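A minimal usage sketch of the two helpers above, assuming the benchmark files are already under `$TORCH_HOME` (paths and availability are assumptions, not part of the diff):

```python
from nats_bench import create, search_space_info

api = create(None, 'sss', fast_mode=True, verbose=False)  # size search space API
info = search_space_info('nats-bench', 'sss')
print(info['candidates'], info['num_layers'])  # [8, 16, 24, 32, 40, 48, 56, 64] and 5
```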
@@ -1,65 +1,84 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
 ##############################################################################
-# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
-#####################################################################################
-# The history of benchmark files (the name is NATS-sss-[version]-[md5].pickle.pbz2) #
-# [2020.08.31] NATS-sss-v1_0-50262.pickle.pbz2                                      #
-#####################################################################################
-import os, copy, random, numpy as np
-from typing import List, Text, Union, Dict, Optional
-from collections import OrderedDict, defaultdict
-from .api_utils import time_string
-from .api_utils import pickle_load
-from .api_utils import ArchResults
-from .api_utils import NASBenchMetaAPI
-from .api_utils import remap_dataset_set_names
-from .api_utils import nats_is_dir
-from .api_utils import nats_is_file
-from .api_utils import PICKLE_EXT
+# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
+##############################################################################
+# The history of benchmark files are as follows,                             #
+# where the format is (the name is NATS-sss-[version]-[md5].pickle.pbz2)     #
+# [2020.08.31] NATS-sss-v1_0-50262.pickle.pbz2                               #
+##############################################################################
+# pylint: disable=line-too-long
+"""The API for size search space in NATS-Bench."""
+import collections
+import copy
+import os
+import random
+from typing import Dict, Optional, Text, Union, Any
+
+from nats_bench.api_utils import ArchResults
+from nats_bench.api_utils import NASBenchMetaAPI
+from nats_bench.api_utils import nats_is_dir
+from nats_bench.api_utils import nats_is_file
+from nats_bench.api_utils import PICKLE_EXT
+from nats_bench.api_utils import pickle_load
+from nats_bench.api_utils import time_string
 
 
 ALL_BASE_NAMES = ['NATS-sss-v1_0-50262']
 
 def print_information(information, extra_info=None, show=False):
   """print out the information of a given ArchResults."""
   dataset_names = information.get_dataset_names()
-  strings = [information.arch_str, 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)]
+  strings = [
+      information.arch_str,
+      'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)
+  ]
 
   def metric2str(loss, acc):
     return 'loss = {:.3f} & top1 = {:.2f}%'.format(loss, acc)
 
-  for ida, dataset in enumerate(dataset_names):
+  for dataset in dataset_names:
     metric = information.get_compute_costs(dataset)
     flop, param, latency = metric['flops'], metric['params'], metric['latency']
-    str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency is not None and latency > 0 else None)
+    str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(
+        dataset, flop, param,
+        '{:.2f}'.format(latency *
+                        1000) if latency is not None and latency > 0 else None)
     train_info = information.get_metrics(dataset, 'train')
     if dataset == 'cifar10-valid':
       valid_info = information.get_metrics(dataset, 'x-valid')
       test__info = information.get_metrics(dataset, 'ori-test')
       str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(
-        dataset, metric2str(train_info['loss'], train_info['accuracy']),
-        metric2str(valid_info['loss'], valid_info['accuracy']),
-        metric2str(test__info['loss'], test__info['accuracy']))
+          dataset, metric2str(train_info['loss'], train_info['accuracy']),
+          metric2str(valid_info['loss'], valid_info['accuracy']),
+          metric2str(test__info['loss'], test__info['accuracy']))
     elif dataset == 'cifar10':
       test__info = information.get_metrics(dataset, 'ori-test')
-      str2 = '{:14s} train : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
+      str2 = '{:14s} train : [{:}], test : [{:}]'.format(
+          dataset, metric2str(train_info['loss'], train_info['accuracy']),
+          metric2str(test__info['loss'], test__info['accuracy']))
     else:
       valid_info = information.get_metrics(dataset, 'x-valid')
       test__info = information.get_metrics(dataset, 'x-test')
-      str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
+      str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(
+          dataset, metric2str(train_info['loss'], train_info['accuracy']),
+          metric2str(valid_info['loss'], valid_info['accuracy']),
+          metric2str(test__info['loss'], test__info['accuracy']))
     strings += [str1, str2]
   if show: print('\n'.join(strings))
   return strings
"""
|
||||
This is the class for the API of size search space in NATS-Bench.
|
||||
"""
|
||||
class NATSsize(NASBenchMetaAPI):
|
||||
"""This is the class for the API of size search space in NATS-Bench."""
|
||||
|
||||
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
|
||||
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, fast_mode: bool=False, verbose: bool=True):
|
||||
self.ALL_BASE_NAMES = ALL_BASE_NAMES
|
||||
def __init__(self,
|
||||
file_path_or_dict: Optional[Union[Text, Dict[Text, Any]]] = None,
|
||||
fast_mode: bool = False,
|
||||
verbose: bool = True):
|
||||
"""The initialization function that takes the dataset file path (or a dict loaded from that path) as input."""
|
||||
self._all_base_names = ALL_BASE_NAMES
|
||||
self.filename = None
|
||||
self._search_space_name = 'size'
|
||||
self._fast_mode = fast_mode
|
||||
@@ -67,25 +86,36 @@ class NATSsize(NASBenchMetaAPI):
     self.reset_time()
     if file_path_or_dict is None:
       if self._fast_mode:
-        self._archive_dir = os.path.join(os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
+        self._archive_dir = os.path.join(
+            os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
       else:
-        file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], '{:}.{:}'.format(ALL_BASE_NAMES[-1], PICKLE_EXT))
-      print ('{:} Try to use the default NATS-Bench (size) path from fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode, file_path_or_dict))
+        file_path_or_dict = os.path.join(
+            os.environ['TORCH_HOME'], '{:}.{:}'.format(
+                ALL_BASE_NAMES[-1], PICKLE_EXT))
+      print('{:} Try to use the default NATS-Bench (size) path from '
+            'fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode,
+                                                 file_path_or_dict))
     if isinstance(file_path_or_dict, str):
       file_path_or_dict = str(file_path_or_dict)
       if verbose:
-        print('{:} Try to create the NATS-Bench (size) api from {:} with fast_mode={:}'.format(time_string(), file_path_or_dict, fast_mode))
-      if not nats_is_file(file_path_or_dict) and not nats_is_dir(file_path_or_dict):
-        raise ValueError('{:} is neither a file or a dir.'.format(file_path_or_dict))
+        print('{:} Try to create the NATS-Bench (size) api '
+              'from {:} with fast_mode={:}'.format(
+                  time_string(), file_path_or_dict, fast_mode))
+      if not nats_is_file(file_path_or_dict) and not nats_is_dir(
+          file_path_or_dict):
+        raise ValueError('{:} is neither a file or a dir.'.format(
+            file_path_or_dict))
       self.filename = os.path.basename(file_path_or_dict)
       if fast_mode:
         if nats_is_file(file_path_or_dict):
-          raise ValueError('fast_mode={:} must feed the path for directory : {:}'.format(fast_mode, file_path_or_dict))
+          raise ValueError('fast_mode={:} must feed the path for directory '
+                           ': {:}'.format(fast_mode, file_path_or_dict))
         else:
           self._archive_dir = file_path_or_dict
       else:
         if nats_is_dir(file_path_or_dict):
-          raise ValueError('fast_mode={:} must feed the path for file : {:}'.format(fast_mode, file_path_or_dict))
+          raise ValueError('fast_mode={:} must feed the path for file '
+                           ': {:}'.format(fast_mode, file_path_or_dict))
         else:
           file_path_or_dict = pickle_load(file_path_or_dict)
     elif isinstance(file_path_or_dict, dict):
@@ -93,68 +123,95 @@ class NATSsize(NASBenchMetaAPI):
     self.verbose = verbose
     if isinstance(file_path_or_dict, dict):
       keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
-      for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
+      for key in keys:
+        if key not in file_path_or_dict:
+          raise ValueError('Can not find key[{:}] in the dict'.format(key))
       self.meta_archs = copy.deepcopy(file_path_or_dict['meta_archs'])
-      # This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
-      self.arch2infos_dict = OrderedDict()
+      # NOTE(xuanyidong): This is a dict mapping each architecture to a dict,
+      # where the key is #epochs and the value is ArchResults
+      self.arch2infos_dict = collections.OrderedDict()
       self._avaliable_hps = set()
       for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
         all_infos = file_path_or_dict['arch2infos'][xkey]
-        hp2archres = OrderedDict()
+        hp2archres = collections.OrderedDict()
         for hp_key, results in all_infos.items():
           hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
           self._avaliable_hps.add(hp_key)  # save the available hyper-parameter
         self.arch2infos_dict[xkey] = hp2archres
       self.evaluated_indexes = set(file_path_or_dict['evaluated_indexes'])
     elif self.archive_dir is not None:
-      benchmark_meta = pickle_load('{:}/meta.{:}'.format(self.archive_dir, PICKLE_EXT))
+      benchmark_meta = pickle_load('{:}/meta.{:}'.format(
+          self.archive_dir, PICKLE_EXT))
       self.meta_archs = copy.deepcopy(benchmark_meta['meta_archs'])
-      self.arch2infos_dict = OrderedDict()
+      self.arch2infos_dict = collections.OrderedDict()
       self._avaliable_hps = set()
       self.evaluated_indexes = set()
     else:
-      raise ValueError('file_path_or_dict [{:}] must be a dict or archive_dir must be set'.format(type(file_path_or_dict)))
+      raise ValueError('file_path_or_dict [{:}] must be a dict or archive_dir '
+                       'must be set'.format(type(file_path_or_dict)))
     self.archstr2index = {}
     for idx, arch in enumerate(self.meta_archs):
-      assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
+      if arch in self.archstr2index:
+        raise ValueError('This [{:}]-th arch {:} already in the '
+                         'dict ({:}).'.format(
+                             idx, arch, self.archstr2index[arch]))
      self.archstr2index[arch] = idx
     if self.verbose:
-      print('{:} Create NATS-Bench (size) done with {:}/{:} architectures avaliable.'.format(
-          time_string(), len(self.evaluated_indexes), len(self.meta_archs)))
+      print('{:} Create NATS-Bench (size) done with {:}/{:} architectures '
+            'avaliable.'.format(time_string(),
+                                len(self.evaluated_indexes),
+                                len(self.meta_archs)))
 
-  def query_info_str_by_arch(self, arch, hp: Text='12'):
-    """ This function is used to query the information of a specific architecture
-        'arch' can be an architecture index or an architecture string
-        When hp=01, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/01E.config'
-        When hp=12, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/12E.config'
-        When hp=90, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/90E.config'
-        The difference between these three configurations are the number of training epochs.
-    """
+  def query_info_str_by_arch(self, arch, hp: Text = '12'):
+    """Query the information of a specific architecture.
+
+    Args:
+      arch: it can be an architecture index or an architecture string.
+
+      hp: the hyperparameter indicator, could be 01, 12, or 90. The difference
+        between these three configurations is the number of training epochs.
+
+    Returns:
+      ArchResults instance
+    """
     if self.verbose:
-      print('{:} Call query_info_str_by_arch with arch={:} and hp={:}'.format(time_string(), arch, hp))
+      print('{:} Call query_info_str_by_arch with arch={:} '
+            'and hp={:}'.format(time_string(), arch, hp))
     return self._query_info_str_by_arch(arch, hp, print_information)
 
-  def get_more_info(self, index, dataset: Text, iepoch=None, hp='12', is_random=True):
-    """This function will return the metric for the `index`-th architecture
-       `dataset` indicates the dataset:
+  def get_more_info(self,
+                    index,
+                    dataset,
+                    iepoch=None,
+                    hp: Text = '12',
+                    is_random: bool = True):
+    """Return the metric for the `index`-th architecture.
+
+    Args:
+      index: the architecture index.
+      dataset:
         'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set
         'cifar10' : using the proposed train+valid set of CIFAR-10 as the training set
         'cifar100' : using the proposed train set of CIFAR-100 as the training set
         'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set
-       `iepoch` indicates the index of training epochs from 0 to 11/199.
+      iepoch: the index of training epochs from 0 to 11/199.
        When iepoch=None, it will return the metric for the last training epoch
        When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0)
-       `hp` indicates different hyper-parameters for training
+      hp: indicates different hyper-parameters for training
        When hp=01, it trains the network with 01 epochs and the LR decayed from 0.1 to 0 within 01 epochs
        When hp=12, it trains the network with 12 epochs and the LR decayed from 0.1 to 0 within 12 epochs
        When hp=90, it trains the network with 90 epochs and the LR decayed from 0.1 to 0 within 90 epochs
-       `is_random`
+      is_random:
        When is_random=True, the performance of a random architecture will be returned
        When is_random=False, the performance of all trials will be averaged.
+
+    Returns:
+      a dict, where key is the metric name and value is its value.
     """
     if self.verbose:
-      print('{:} Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(
-          time_string(), index, dataset, iepoch, hp, is_random))
+      print('{:} Call the get_more_info function with index={:}, dataset={:}, '
+            'iepoch={:}, hp={:}, and is_random={:}.'.format(
+                time_string(), index, dataset, iepoch, hp, is_random))
     index = self.query_index_by_arch(index)  # in case the input is an architecture string or an arch object, convert it to the index
     self._prepare_info(index)
     if index not in self.arch2infos_dict:
@@ -165,38 +222,47 @@ class NATSsize(NASBenchMetaAPI):
       seeds = archresult.get_dataset_seeds(dataset)
       is_random = random.choice(seeds)
     # collect the training information
-    train_info = archresult.get_metrics(dataset, 'train', iepoch=iepoch, is_random=is_random)
+    train_info = archresult.get_metrics(
+        dataset, 'train', iepoch=iepoch, is_random=is_random)
     total = train_info['iepoch'] + 1
-    xinfo = {'train-loss' : train_info['loss'],
-             'train-accuracy': train_info['accuracy'],
-             'train-per-time': train_info['all_time'] / total,
-             'train-all-time': train_info['all_time']}
+    xinfo = {
+        'train-loss': train_info['loss'],
+        'train-accuracy': train_info['accuracy'],
+        'train-per-time': train_info['all_time'] / total,
+        'train-all-time': train_info['all_time']
+    }
     # collect the evaluation information
     if dataset == 'cifar10-valid':
-      valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
+      valid_info = archresult.get_metrics(
+          dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
       try:
-        test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
-      except:
+        test_info = archresult.get_metrics(
+            dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
+      except Exception as unused_e:  # pylint: disable=broad-except
        test_info = None
      valtest_info = None
     else:
-      try: # collect results on the proposed test set
+      try:  # collect results on the proposed test set
        if dataset == 'cifar10':
-          test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
+          test_info = archresult.get_metrics(
+              dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
        else:
-          test_info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
-      except:
+          test_info = archresult.get_metrics(
+              dataset, 'x-test', iepoch=iepoch, is_random=is_random)
+      except Exception as unused_e:  # pylint: disable=broad-except
        test_info = None
-      try: # collect results on the proposed validation set
-        valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
-      except:
+      try:  # collect results on the proposed validation set
+        valid_info = archresult.get_metrics(
+            dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
+      except Exception as unused_e:  # pylint: disable=broad-except
        valid_info = None
      try:
        if dataset != 'cifar10':
-          valtest_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
+          valtest_info = archresult.get_metrics(
+              dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
        else:
          valtest_info = None
-      except:
+      except Exception as unused_e:  # pylint: disable=broad-except
        valtest_info = None
     if valid_info is not None:
       xinfo['valid-loss'] = valid_info['loss']
@@ -216,11 +282,5 @@ class NATSsize(NASBenchMetaAPI):
     return xinfo
 
   def show(self, index: int = -1) -> None:
-    """
-    This function will print the information of a specific (or all) architecture(s).
-
-    :param index: If the index < 0: it will loop for all architectures and print their information one by one.
-                  else: it will print the information of the 'index'-th architecture.
-    :return: nothing
-    """
+    """Print the information of a specific (or all) architecture(s)."""
     self._show(index, print_information)
lib/nats_bench/api_test.py (new file, +59 lines)
@@ -0,0 +1,59 @@
+##############################################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 ##########################
+##############################################################################
+# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
+##############################################################################
+"""This file is used to quickly test the API."""
+import random
+
+from nats_bench.api_size import NATSsize
+from nats_bench.api_topology import NATStopology
+
+
+def test_nats_bench_tss(benchmark_dir):
+  return test_nats_bench(benchmark_dir, True)
+
+
+def test_nats_bench_sss(benchmark_dir):
+  return test_nats_bench(benchmark_dir, False)
+
+
+def test_nats_bench(benchmark_dir, is_tss, verbose=False):
+  if is_tss:
+    api = NATStopology(benchmark_dir, True, verbose)
+  else:
+    api = NATSsize(benchmark_dir, True, verbose)
+
+  test_indexes = [random.randint(0, len(api) - 1) for _ in range(10)]
+  key2dataset = {'cifar10': 'CIFAR-10',
+                 'cifar100': 'CIFAR-100',
+                 'ImageNet16-120': 'ImageNet16-120'}
+
+  for index in test_indexes:
+    print('\n\nEvaluate the {:5d}-th architecture.'.format(index))
+
+    for key, dataset in key2dataset.items():
+      # Query the loss / accuracy / time for the `index`-th candidate
+      # architecture on CIFAR-10
+      # info is a dict, where you can easily figure out the meaning by key
+      info = api.get_more_info(index, key)
+      print('  -->> The performance on {:}: {:}'.format(dataset, info))
+
+      # Query the flops, params, latency. info is a dict.
+      info = api.get_cost_info(index, key)
+      print('  -->> The cost info on {:}: {:}'.format(dataset, info))
+
+      # Simulate the training of the `index`-th candidate:
+      validation_accuracy, latency, time_cost, current_total_time_cost = api.simulate_train_eval(
+          index, dataset=key, hp='12')
+      print('  -->> The validation accuracy={:}, latency={:}, '
+            'the current time cost={:} s, accumulated time cost={:} s'
+            .format(validation_accuracy, latency, time_cost,
+                    current_total_time_cost))
+
+      # Print the configuration of the `index`-th architecture on CIFAR-10
+      config = api.get_net_config(index, key)
+      print('  -->> The configuration on {:} is {:}'.format(dataset, config))
+
+    # Show the information of the `index`-th architecture
+    api.show(index)
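A hedged sketch of invoking these helpers; the directory names follow the `'{:}-simple'` fast-mode convention built from `ALL_BASE_NAMES` in `api_size.py`/`api_topology.py`, but the actual paths on your machine are an assumption:

```python
# Hypothetical invocation; adjust the paths to where the fast-mode
# benchmark directories were extracted.
if __name__ == '__main__':
  test_nats_bench_tss('NATS-tss-v1_0-3ffb9-simple')
  test_nats_bench_sss('NATS-sss-v1_0-50262-simple')
```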
@@ -2,61 +2,83 @@
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
 ##############################################################################
 # NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
-#####################################################################################
-# The history of benchmark files (the name is NATS-tss-[version]-[md5].pickle.pbz2) #
-# [2020.08.31] NATS-tss-v1_0-3ffb9.pickle.pbz2                                      #
-#####################################################################################
-import os, copy, random, numpy as np
-from typing import List, Text, Union, Dict, Optional
-from collections import OrderedDict, defaultdict
-import warnings
-from .api_utils import time_string
-from .api_utils import pickle_load
-from .api_utils import ArchResults
-from .api_utils import NASBenchMetaAPI
-from .api_utils import remap_dataset_set_names
-from .api_utils import nats_is_dir
-from .api_utils import nats_is_file
-from .api_utils import PICKLE_EXT
+##############################################################################
+# The history of benchmark files are as follows,                             #
+# where the format is (the name is NATS-tss-[version]-[md5].pickle.pbz2)     #
+# [2020.08.31] NATS-tss-v1_0-3ffb9.pickle.pbz2                               #
+##############################################################################
+# pylint: disable=line-too-long
+"""The API for topology search space in NATS-Bench."""
+import collections
+import copy
+import os
+import random
+from typing import Any, Dict, List, Optional, Text, Union
+
+from nats_bench.api_utils import ArchResults
+from nats_bench.api_utils import NASBenchMetaAPI
+from nats_bench.api_utils import nats_is_dir
+from nats_bench.api_utils import nats_is_file
+from nats_bench.api_utils import PICKLE_EXT
+from nats_bench.api_utils import pickle_load
+from nats_bench.api_utils import time_string
+
+import numpy as np
 
 
 ALL_BASE_NAMES = ['NATS-tss-v1_0-3ffb9']
 
def print_information(information, extra_info=None, show=False):
  """print out the information of a given ArchResults."""
  dataset_names = information.get_dataset_names()
  strings = [information.arch_str, 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)]
  strings = [
      information.arch_str,
      'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)
  ]

  def metric2str(loss, acc):
    return 'loss = {:.3f} & top1 = {:.2f}%'.format(loss, acc)

  for ida, dataset in enumerate(dataset_names):
  for dataset in dataset_names:
    metric = information.get_compute_costs(dataset)
    flop, param, latency = metric['flops'], metric['params'], metric['latency']
    str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency is not None and latency > 0 else None)
    str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(
        dataset, flop, param,
        '{:.2f}'.format(latency *
                        1000) if latency is not None and latency > 0 else None)
    train_info = information.get_metrics(dataset, 'train')
    if dataset == 'cifar10-valid':
      valid_info = information.get_metrics(dataset, 'x-valid')
      str2 = '{:14s} train : [{:}], valid : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']))
      str2 = '{:14s} train : [{:}], valid : [{:}]'.format(
          dataset, metric2str(train_info['loss'], train_info['accuracy']),
          metric2str(valid_info['loss'], valid_info['accuracy']))
    elif dataset == 'cifar10':
      test__info = information.get_metrics(dataset, 'ori-test')
      str2 = '{:14s} train : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
      str2 = '{:14s} train : [{:}], test : [{:}]'.format(
          dataset, metric2str(train_info['loss'], train_info['accuracy']),
          metric2str(test__info['loss'], test__info['accuracy']))
    else:
      valid_info = information.get_metrics(dataset, 'x-valid')
      test__info = information.get_metrics(dataset, 'x-test')
      str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
      str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(
          dataset, metric2str(train_info['loss'], train_info['accuracy']),
          metric2str(valid_info['loss'], valid_info['accuracy']),
          metric2str(test__info['loss'], test__info['accuracy']))
    strings += [str1, str2]
  if show: print('\n'.join(strings))
  return strings

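To see what `print_information` produces, here is a minimal usage sketch; it assumes the module is importable as `nats_bench.api_topology`, that the benchmark file named below has been downloaded into the working directory, and that `query_meta_info_by_index` (inherited from `NASBenchMetaAPI`) returns the `ArchResults` for an index, as in the NAS-Bench-201 API.

```python
from nats_bench.api_topology import NATStopology, print_information

# Assumes NATS-tss-v1_0-3ffb9.pickle.pbz2 sits in the working directory.
api = NATStopology('NATS-tss-v1_0-3ffb9.pickle.pbz2', fast_mode=False, verbose=False)
meta_info = api.query_meta_info_by_index(0, hp='12')  # an ArchResults instance
lines = print_information(meta_info, extra_info='index-0', show=True)
```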
"""
|
||||
This is the class for the API of topology search space in NATS-Bench.
|
||||
"""
|
||||
class NATStopology(NASBenchMetaAPI):
|
||||
"""This is the class for the API of topology search space in NATS-Bench."""
|
||||
|
||||
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
|
||||
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, fast_mode: bool=False, verbose: bool=True):
|
||||
self.ALL_BASE_NAMES = ALL_BASE_NAMES
|
||||
def __init__(self,
|
||||
file_path_or_dict: Optional[Union[Text, Dict[Text, Any]]] = None,
|
||||
fast_mode: bool = False,
|
||||
verbose: bool = True):
|
||||
"""The initialization function that takes the dataset file path (or a dict loaded from that path) as input."""
|
||||
self._all_base_names = ALL_BASE_NAMES
|
||||
self.filename = None
|
||||
self._search_space_name = 'topology'
|
||||
self._fast_mode = fast_mode
|
||||
@ -64,25 +86,35 @@ class NATStopology(NASBenchMetaAPI):
    self.reset_time()
    if file_path_or_dict is None:
      if self._fast_mode:
        self._archive_dir = os.path.join(os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
        self._archive_dir = os.path.join(
            os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
      else:
        file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], '{:}.{:}'.format(ALL_BASE_NAMES[-1], PICKLE_EXT))
      print ('{:} Try to use the default NATS-Bench (topology) path from {:}.'.format(time_string(), file_path_or_dict))
        file_path_or_dict = os.path.join(
            os.environ['TORCH_HOME'], '{:}.{:}'.format(
                ALL_BASE_NAMES[-1], PICKLE_EXT))
      print('{:} Try to use the default NATS-Bench (topology) path '
            'from {:}.'.format(time_string(), file_path_or_dict))
    if isinstance(file_path_or_dict, str):
      file_path_or_dict = str(file_path_or_dict)
      if verbose:
        print('{:} Try to create the NATS-Bench (topology) api from {:} with fast_mode={:}'.format(time_string(), file_path_or_dict, fast_mode))
      if not nats_is_file(file_path_or_dict) and not nats_is_dir(file_path_or_dict):
        raise ValueError('{:} is neither a file nor a dir.'.format(file_path_or_dict))
        print('{:} Try to create the NATS-Bench (topology) api '
              'from {:} with fast_mode={:}'.format(
                  time_string(), file_path_or_dict, fast_mode))
      if not nats_is_file(file_path_or_dict) and not nats_is_dir(
          file_path_or_dict):
        raise ValueError('{:} is neither a file nor a dir.'.format(
            file_path_or_dict))
      self.filename = os.path.basename(file_path_or_dict)
      if fast_mode:
        if nats_is_file(file_path_or_dict):
          raise ValueError('fast_mode={:} must feed the path for directory : {:}'.format(fast_mode, file_path_or_dict))
          raise ValueError('fast_mode={:} must feed the path for directory '
                           ': {:}'.format(fast_mode, file_path_or_dict))
        else:
          self._archive_dir = file_path_or_dict
      else:
        if nats_is_dir(file_path_or_dict):
          raise ValueError('fast_mode={:} must feed the path for file : {:}'.format(fast_mode, file_path_or_dict))
          raise ValueError('fast_mode={:} must feed the path for file '
                           ': {:}'.format(fast_mode, file_path_or_dict))
        else:
          file_path_or_dict = pickle_load(file_path_or_dict)
    elif isinstance(file_path_or_dict, dict):
@ -90,65 +122,73 @@ class NATStopology(NASBenchMetaAPI):
    self.verbose = verbose
    if isinstance(file_path_or_dict, dict):
      keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
      for key in keys: assert key in file_path_or_dict, 'Cannot find key[{:}] in the dict'.format(key)
      for key in keys:
        if key not in file_path_or_dict:
          raise ValueError('Cannot find key[{:}] in the dict'.format(key))
      self.meta_archs = copy.deepcopy(file_path_or_dict['meta_archs'])
      # This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
      self.arch2infos_dict = OrderedDict()
      # NOTE(xuanyidong): This is a dict mapping each architecture to a dict,
      # where the key is #epochs and the value is ArchResults
      self.arch2infos_dict = collections.OrderedDict()
      self._avaliable_hps = set()
      for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
        all_infos = file_path_or_dict['arch2infos'][xkey]
        hp2archres = OrderedDict()
        hp2archres = collections.OrderedDict()
        for hp_key, results in all_infos.items():
          hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
          self._avaliable_hps.add(hp_key)  # save the available hyper-parameter
        self.arch2infos_dict[xkey] = hp2archres
      self.evaluated_indexes = set(file_path_or_dict['evaluated_indexes'])
    elif self.archive_dir is not None:
      benchmark_meta = pickle_load('{:}/meta.{:}'.format(self.archive_dir, PICKLE_EXT))
      benchmark_meta = pickle_load('{:}/meta.{:}'.format(
          self.archive_dir, PICKLE_EXT))
      self.meta_archs = copy.deepcopy(benchmark_meta['meta_archs'])
      self.arch2infos_dict = OrderedDict()
      self.arch2infos_dict = collections.OrderedDict()
      self._avaliable_hps = set()
      self.evaluated_indexes = set()
    else:
      raise ValueError('file_path_or_dict [{:}] must be a dict or archive_dir must be set'.format(type(file_path_or_dict)))
      raise ValueError('file_path_or_dict [{:}] must be a dict or archive_dir '
                       'must be set'.format(type(file_path_or_dict)))
    self.archstr2index = {}
    for idx, arch in enumerate(self.meta_archs):
      assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
      if arch in self.archstr2index:
        raise ValueError('This [{:}]-th arch {:} already in the '
                         'dict ({:}).'.format(
                             idx, arch, self.archstr2index[arch]))
      self.archstr2index[arch] = idx
    if self.verbose:
      print('{:} Create NATS-Bench (topology) done with {:}/{:} architectures available.'.format(
          time_string(), len(self.evaluated_indexes), len(self.meta_archs)))
      print('{:} Create NATS-Bench (topology) done with {:}/{:} architectures '
            'available.'.format(time_string(),
                                len(self.evaluated_indexes),
                                len(self.meta_archs)))

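The two loading modes above can be exercised as follows. This is a minimal sketch, assuming `TORCH_HOME` points at a directory that holds the archives named in the header comment: `NATS-tss-v1_0-3ffb9.pickle.pbz2` for full mode, and the `NATS-tss-v1_0-3ffb9-simple` directory for fast mode.

```python
import os
from nats_bench.api_topology import NATStopology  # module path assumed

os.environ.setdefault('TORCH_HOME', os.path.expanduser('~/.torch'))

# Full mode: decompress and load the whole pickled archive up front.
api_full = NATStopology(file_path_or_dict=None, fast_mode=False, verbose=True)

# Fast mode: only record the '-simple' directory; per-arch results load lazily.
api_fast = NATStopology(file_path_or_dict=None, fast_mode=True, verbose=True)
```

Fast mode trades the long one-off load for lazy per-architecture reads, which is why `__init__` only records `_archive_dir` in that branch instead of calling `pickle_load`.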
  def query_info_str_by_arch(self, arch, hp: Text='12'):
    """ This function is used to query the information of a specific architecture
        'arch' can be an architecture index or an architecture string
        When hp=12, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/12E.config'
        When hp=200, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/200E.config'
        The difference between these two configurations is the number of training epochs.
  def query_info_str_by_arch(self, arch, hp: Text = '12'):
    """Query the information of a specific architecture.

    Args:
      arch: it can be an architecture index or an architecture string.

      hp: the hyperparameter indicator, could be 12 or 200. The difference
        between these two configurations is the number of training epochs.

    Returns:
      ArchResults instance
    """
    if self.verbose:
      print('{:} Call query_info_str_by_arch with arch={:} and hp={:}'.format(time_string(), arch, hp))
      print('{:} Call query_info_str_by_arch with arch={:} '
            'and hp={:}'.format(time_string(), arch, hp))
    return self._query_info_str_by_arch(arch, hp, print_information)

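A small sketch of calling `query_info_str_by_arch`, reusing `api_full` from the construction sketch above; the architecture can be given either as the raw topology string or as its integer index.

```python
arch = '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|'
result = api_full.query_info_str_by_arch(arch, hp='200')  # 200-epoch records
print(result)
```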
  # obtain the metric for the `index`-th architecture
  # `dataset` indicates the dataset:
  #   'cifar10-valid'  : using the proposed train set of CIFAR-10 as the training set
  #   'cifar10'        : using the proposed train+valid set of CIFAR-10 as the training set
  #   'cifar100'       : using the proposed train set of CIFAR-100 as the training set
  #   'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set
  # `iepoch` indicates the index of training epochs from 0 to 11/199.
  #   When iepoch=None, it will return the metric for the last training epoch
  #   When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0)
  # `use_12epochs_result` indicates different hyper-parameters for training
  #   When use_12epochs_result=True, it trains the network with 12 epochs and the LR decayed from 0.1 to 0 within 12 epochs
  #   When use_12epochs_result=False, it trains the network with 200 epochs and the LR decayed from 0.1 to 0 within 200 epochs
  # `is_random`
  #   When is_random=True, the performance of a random trial will be returned
  #   When is_random=False, the performance of all trials will be averaged.
  def get_more_info(self, index, dataset, iepoch=None, hp='12', is_random=True):
  def get_more_info(self,
                    index,
                    dataset,
                    iepoch=None,
                    hp: Text = '12',
                    is_random: bool = True):
    """Return the metric for the `index`-th architecture."""
    if self.verbose:
      print('{:} Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(
          time_string(), index, dataset, iepoch, hp, is_random))
      print('{:} Call the get_more_info function with index={:}, dataset={:}, '
            'iepoch={:}, hp={:}, and is_random={:}.'.format(
                time_string(), index, dataset, iepoch, hp, is_random))
    index = self.query_index_by_arch(index)  # handle the case where the input is an arch string or an arch instance rather than an index
    self._prepare_info(index)
    if index not in self.arch2infos_dict:
@ -161,36 +201,43 @@ class NATStopology(NASBenchMetaAPI):
    # collect the training information
    train_info = archresult.get_metrics(dataset, 'train', iepoch=iepoch, is_random=is_random)
    total = train_info['iepoch'] + 1
    xinfo = {'train-loss' : train_info['loss'],
             'train-accuracy': train_info['accuracy'],
             'train-per-time': train_info['all_time'] / total if train_info['all_time'] is not None else None,
             'train-all-time': train_info['all_time']}
    xinfo = {
        'train-loss':
            train_info['loss'],
        'train-accuracy':
            train_info['accuracy'],
        'train-per-time':
            train_info['all_time'] /
            total if train_info['all_time'] is not None else None,
        'train-all-time':
            train_info['all_time']
    }
    # collect the evaluation information
    if dataset == 'cifar10-valid':
      valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
      try:
        test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
      except:
      except Exception as unused_e:  # pylint: disable=broad-except
        test_info = None
      valtest_info = None
    else:
      try: # collect results on the proposed test set
      try:  # collect results on the proposed test set
        if dataset == 'cifar10':
          test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
        else:
          test_info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
      except:
      except Exception as unused_e:  # pylint: disable=broad-except
        test_info = None
      try: # collect results on the proposed validation set
      try:  # collect results on the proposed validation set
        valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
      except:
      except Exception as unused_e:  # pylint: disable=broad-except
        valid_info = None
      try:
        if dataset != 'cifar10':
          valtest_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
        else:
          valtest_info = None
      except:
      except Exception as unused_e:  # pylint: disable=broad-except
        valtest_info = None
    if valid_info is not None:
      xinfo['valid-loss'] = valid_info['loss']
@ -214,46 +261,52 @@ class NATStopology(NASBenchMetaAPI):
    self._show(index, print_information)

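A usage sketch for `get_more_info`, again reusing `api_full` from the construction sketch above; only keys that this hunk itself constructs (`train-*`, `valid-loss`) are read.

```python
# Average over all trials (is_random=False) of the final 200-epoch metrics
# for architecture 0 on the CIFAR-10 validation split.
info = api_full.get_more_info(0, 'cifar10-valid',
                              iepoch=None, hp='200', is_random=False)
print('train-loss     :', info['train-loss'])
print('train-accuracy :', info['train-accuracy'])
print('valid-loss     :', info['valid-loss'])
```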
  @staticmethod
  def str2lists(arch_str: Text) -> List[tuple]:
    """
    This function shows how to read the string-based architecture encoding.
    It is the same as the `str2structure` func in `AutoDL-Projects/lib/models/cell_searchs/genotypes.py`
  def str2lists(arch_str: Text) -> List[Any]:
    """Shows how to read the string-based architecture encoding.

    :param
    Args:
      arch_str: the input is a string that indicates the architecture topology, such as
                    |nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
    :return: a list of tuples, containing multiple (op, input_node_index) pairs.
    Returns:
      a list of tuples, containing multiple (op, input_node_index) pairs.

    :usage
    [USAGE]
    It is the same as the `str2structure` func in AutoDL-Projects:
      `github.com/D-X-Y/AutoDL-Projects/lib/models/cell_searchs/genotypes.py`
    ```
      arch = api.str2lists( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
      print('there are {:} nodes in this arch'.format(len(arch)+1))  # arch is a list
      for i, node in enumerate(arch):
        print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node))
    ```
    """
    node_strs = arch_str.split('+')
    genotypes = []
    for i, node_str in enumerate(node_strs):
      inputs = list(filter(lambda x: x != '', node_str.split('|')))
      for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
      inputs = ( xi.split('~') for xi in inputs )
      input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs)
      genotypes.append( input_infos )
    for unused_i, node_str in enumerate(node_strs):
      inputs = list(filter(lambda x: x != '', node_str.split('|')))  # pylint: disable=g-explicit-bool-comparison
      for xinput in inputs:
        assert len(
            xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
      inputs = (xi.split('~') for xi in inputs)
      input_infos = tuple((op, int(idx)) for (op, idx) in inputs)
      genotypes.append(input_infos)
    return genotypes

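Because `str2lists` is a `@staticmethod`, it can be called on the class directly; the expected value below follows mechanically from the parsing loop above.

```python
arch = NATStopology.str2lists(
    '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|')
assert arch == [(('nor_conv_1x1', 0),),
                (('none', 0), ('none', 1)),
                (('none', 0), ('none', 1), ('skip_connect', 2))]
for i, node in enumerate(arch):
  print('the {:}-th node sums {:} input(s): {:}'.format(i + 1, len(node), node))
```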
  @staticmethod
  def str2matrix(arch_str: Text,
                 search_space: List[Text] = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']) -> np.ndarray:
    """
    This func shows how to convert the string-based architecture encoding to the encoding strategy in NAS-Bench-101.
                 search_space: List[Text] = ('none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3')) -> np.ndarray:
    """Convert the string-based architecture encoding to the encoding strategy in NAS-Bench-101.

    :param
    Args:
      arch_str: the input is a string that indicates the architecture topology, such as
                    |nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
      search_space: a list of operation strings; the default list is the topology search space for NATS-BENCH.
        the default value should be consistent with this line https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/models/cell_operations.py#L24
    :return

    Returns:
      the numpy matrix (2-D np.ndarray) representing the DAG of this architecture topology
    :usage

    [USAGE]
      matrix = api.str2matrix( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
      This matrix is a 4-by-4 matrix representing a cell with 4 nodes (only the lower left triangle is useful).
         [ [0, 0, 0, 0],  # the first line represents the input (0-th) node
@ -262,19 +315,19 @@ class NATStopology(NASBenchMetaAPI):
           [0, 0, 1, 0] ]  # the fourth line represents the 3-rd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node ) + 1-th-op( 2-th-node )
      In the topology search space in NATS-BENCH, 0-th-op is 'none', 1-th-op is 'skip_connect',
      2-th-op is 'nor_conv_1x1', 3-th-op is 'nor_conv_3x3', 4-th-op is 'avg_pool_3x3'.
    :(NOTE)
    [NOTE]
      If a node has two input-edges from the same node, this function does not work, because one edge will overwrite the other.
    """
    node_strs = arch_str.split('+')
    num_nodes = len(node_strs) + 1
    matrix = np.zeros((num_nodes, num_nodes))
    for i, node_str in enumerate(node_strs):
      inputs = list(filter(lambda x: x != '', node_str.split('|')))
      for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
      inputs = list(filter(lambda x: x != '', node_str.split('|')))  # pylint: disable=g-explicit-bool-comparison
      for xinput in inputs:
        assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
      for xi in inputs:
        op, idx = xi.split('~')
        if op not in search_space: raise ValueError('this op ({:}) is not in {:}'.format(op, search_space))
        op_idx, node_idx = search_space.index(op), int(idx)
        matrix[i+1, node_idx] = op_idx
    return matrix

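And the matching worked example for `str2matrix`, using the default search space; note that a 'none' edge and a missing edge are both encoded as 0, which is the docstring's overlap caveat in another guise.

```python
import numpy as np

matrix = NATStopology.str2matrix(
    '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|')
expected = np.array([[0., 0., 0., 0.],   # input (0-th) node
                     [2., 0., 0., 0.],   # node 1: nor_conv_1x1 (op 2) on node 0
                     [0., 0., 0., 0.],   # node 2: two 'none' edges stay zero
                     [0., 0., 1., 0.]])  # node 3: skip_connect (op 1) on node 2
assert np.array_equal(matrix, expected)
```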
File diff suppressed because it is too large
@ -23,11 +23,11 @@ CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --warmup_ratio ${ratio} --rand_seed ${seed}

#
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --warmup_ratio ${ratio} --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --warmup_ratio ${ratio} --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --warmup_ratio ${ratio} --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --warmup_ratio ${ratio} --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --warmup_ratio ${ratio} --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_gumbel --warmup_ratio ${ratio} --rand_seed ${seed}

#
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_rl --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}