Sync NATS-Bench's v1.0 and update algorithm names

This commit is contained in:
D-X-Y 2020-10-15 21:56:10 +11:00
parent 10e5f05935
commit 7d55192d83
7 changed files with 28 additions and 26 deletions

View File

@ -6,3 +6,4 @@
- [2019.01.31] [13e908f] GDAS codes were publicly released.
- [2020.07.01] [a45808b] Upgrade NAS-API to the 2.0 version.
- [2020.09.16] [7052265] Create NATS-BENCH.
- [2020.10.15] [ ] Update NATS-BENCH to version 1.0

View File

@ -7,6 +7,7 @@ We analyze the validity of our benchmark in terms of various criteria and perfor
We also show the versatility of NATS-Bench by benchmarking 13 recent state-of-the-art NAS algorithms on it. All logs and diagnostic information trained using the same setup for each candidate are provided.
This facilitates a much larger community of researchers to focus on developing better NAS algorithms in a more comparable and computationally effective environment.
**You can use `pip install nats_bench` to install the library of NATS-Bench.**
The structure of this Markdown file:
- [How to use NATS-Bench?](#How-to-Use-NATS-Bench)
@ -175,18 +176,18 @@ python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HO
python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777
Run the channel search strategy in FBNet-V2
Run the channel search strategy in FBNet-V2 -- masking + Gumbel-Softmax :
python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_gumbel --rand_seed 777
Run the channel search strategy in TuNAS:
Run the channel search strategy in TuNAS -- masking + sampling :
python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 --use_api 0
python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777 --use_api 0
python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_rl --arch_weight_decay 0 --rand_seed 777
```
### Final Discovered Architectures for Each Algorithm
@ -250,7 +251,7 @@ GDAS:
If you find that NATS-Bench helps your research, please consider citing it:
```
@article{dong2020nats,
title={NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size},
title={{NATS-Bench}: Benchmarking NAS algorithms for Architecture Topology and Size},
author={Dong, Xuanyi and Liu, Lu and Musial, Katarzyna and Gabrys, Bogdan},
journal={arXiv preprint arXiv:2009.00437},
year={2020}

View File

@ -43,7 +43,7 @@ from models import get_cell_based_tiny_net, get_search_spaces
from nats_bench import create
# Ad-hoc for TuNAS
# Ad-hoc for RL algorithms.
class ExponentialMovingAverage(object):
"""Class that maintains an exponential moving average."""

View File

@ -44,8 +44,8 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suf
# alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix)
# alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix)
alg2name['channel-wise interpolation'] = 'tas-affine0_BN0-AWD0.001{:}'.format(suffix)
alg2name['masking + Gumbel-Softmax'] = 'fbv2-affine0_BN0-AWD0.001{:}'.format(suffix)
alg2name['masking + sampling'] = 'tunas-affine0_BN0-AWD0.0{:}'.format(suffix)
alg2name['masking + Gumbel-Softmax'] = 'mask_gumbel-affine0_BN0-AWD0.001{:}'.format(suffix)
alg2name['masking + sampling'] = 'mask_rl-affine0_BN0-AWD0.0{:}'.format(suffix)
for alg, name in alg2name.items():
alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth')
alg2data = OrderedDict()

View File

@ -3,8 +3,8 @@
#####################################################
# Here, we utilized three techniques to search for the number of channels:
# - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019"
# - masking + Gumbel-Softmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
# - masking + sampling from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
# - masking + Gumbel-Softmax (mask_gumbel) from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
# - masking + sampling (mask_rl) from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
from typing import List, Text, Any
import random, torch
import torch.nn as nn
@ -52,10 +52,10 @@ class GenericNAS301Model(nn.Module):
def set_algo(self, algo: Text):
# used for searching
assert self._algo is None, 'This functioin can only be called once.'
assert algo in ['fbv2', 'tunas', 'tas'], 'invalid algo : {:}'.format(algo)
assert algo in ['mask_gumbel', 'mask_rl', 'tas'], 'invalid algo : {:}'.format(algo)
self._algo = algo
self._arch_parameters = nn.Parameter(1e-3*torch.randn(self._max_num_Cs, len(self._candidate_Cs)))
# if algo == 'fbv2' or algo == 'tunas':
# if algo == 'mask_gumbel' or algo == 'mask_rl':
self.register_buffer('_masks', torch.zeros(len(self._candidate_Cs), max(self._candidate_Cs)))
for i in range(len(self._candidate_Cs)):
self._masks.data[i, :self._candidate_Cs[i]] = 1
@ -130,7 +130,7 @@ class GenericNAS301Model(nn.Module):
else:
mask = self._masks[random.randint(0, len(self._masks)-1)]
feature = feature * mask.view(1, -1, 1, 1)
elif self._algo == 'fbv2':
elif self._algo == 'mask_gumbel':
weights = nn.functional.gumbel_softmax(self._arch_parameters[idx:idx+1], tau=self.tau, dim=-1)
mask = torch.matmul(weights, self._masks).view(1, -1, 1, 1)
feature = feature * mask
@ -148,7 +148,7 @@ class GenericNAS301Model(nn.Module):
else:
miss = torch.zeros(feature.shape[0], feature.shape[1]-out.shape[1], feature.shape[2], feature.shape[3], device=feature.device)
feature = torch.cat((out, miss), dim=1)
elif self._algo == 'tunas':
elif self._algo == 'mask_rl':
prob = nn.functional.softmax(self._arch_parameters[idx:idx+1], dim=-1)
dist = torch.distributions.Categorical(prob)
action = dist.sample()

View File

@ -939,9 +939,9 @@ class ArchResults(object):
x.load_state_dict(state_dict)
return x
# This function is used to clear the weights saved in each 'result'
# This can help reduce the memory footprint.
def clear_params(self):
"""Clear the weights saved in each 'result'."""
# NOTE(xuanyidong): This can help reduce the memory footprint.
for unused_key, result in self.all_results.items():
del result.net_state_dict
result.net_state_dict = None

View File

@ -23,11 +23,11 @@ CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --warmup_ratio ${ratio} --rand_seed ${seed}
#
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --warmup_ratio ${ratio} --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --warmup_ratio ${ratio} --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --warmup_ratio ${ratio} --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --warmup_ratio ${ratio} --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --warmup_ratio ${ratio} --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_gumbel --warmup_ratio ${ratio} --rand_seed ${seed}
#
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_rl --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}