Update docs of NATS-Bench

This commit is contained in:
D-X-Y 2020-09-16 09:04:22 +00:00
parent 9db28392c2
commit 7052265501
14 changed files with 99 additions and 95 deletions

View File

@ -8,8 +8,6 @@ We also show the versatility of NATS-Bench by benchmarking 13 recent state-of-th
This facilitates a much larger community of researchers to focus on developing better NAS algorithms in a more comparable and computationally effective environment.
**coming soon!**
The structure of this Markdown file:
- [How to use NATS-Bench?](#How-to-Use-NATS-Bench)
- [How to re-create NATS-Bench from scratch?](#how-to-re-create-nats-bench-from-scratch)
@ -33,14 +31,48 @@ To merge the chunks into the original full archive, you can use `cat file_name*
| Date | benchmark file (tss) | archive (tss) | full archive (tss) | benchmark file (sss) | archive (sss) | full archive (sss) |
|:-----------|:---------------------:|:-------------:|:------------------:|:-------------------------------:|:--------------------------:|:------------------:|
| 2020.08.31 | | | | [NATS-sss-v1_0-50262.pickle.pbz2](https://drive.google.com/file/d/1IabIvzWeDdDAWICBzFtTCMXxYWPIOIOX/view?usp=sharing) | [NATS-sss-v1_0-50262-simple.tar](https://drive.google.com/file/d/1scOMTUwcQhAMa_IMedp9lTzwmgqHLGgA/view?usp=sharing) | NATS-sss-v1_0-50262-full |
| 2020.08.31 | [NATS-tss-v1_0-3ffb9.pickle.pbz2](https://drive.google.com/file/d/1vzyK0UVH2D3fTpa1_dSWnp1gvGpAxRul/view?usp=sharing) | [NATS-tss-v1_0-3ffb9-simple.tar](https://drive.google.com/file/d/17_saCsj_krKjlCBLOJEpNtzPXArMCqxU/view?usp=sharing) | NATS-tss-v1_0-3ffb9-full | [NATS-sss-v1_0-50262.pickle.pbz2](https://drive.google.com/file/d/1IabIvzWeDdDAWICBzFtTCMXxYWPIOIOX/view?usp=sharing) | [NATS-sss-v1_0-50262-simple.tar](https://drive.google.com/file/d/1scOMTUwcQhAMa_IMedp9lTzwmgqHLGgA/view?usp=sharing) | NATS-sss-v1_0-50262-full |
1, create the benchmark instance:
```
# Create the API instance for the size search space in NATS
api = create(None, 'sss', fast_mode=True, verbose=True)
# Create the API instance for the topology search space in NATS
api = create(None, 'tss', fast_mode=True, verbose=True)
```
2, query the performance:
```
# Query the loss / accuracy / time for 1234-th candidate architecture on CIFAR-10
# info is a dict, where you can easily figure out the meaning by key
info = api.get_more_info(1234, 'cifar10')
# Query the flops, params, latency. info is a dict.
info = api.get_cost_info(12, 'cifar10')
# Simulate the training of the 1224-th candidate:
validation_accuracy, latency, time_cost, current_total_time_cost = api.simulate_train_eval(1224, dataset='cifar10', hp='12')
```
3, others:
```
# Clear the parameters of the 12-th candidate.
api.clear_params(12)
# Reload all information of the 12-th candidate.
api.reload(index=12)
# Create the instance of the 12-th candidate for CIFAR-10.
from models import get_cell_based_tiny_net
config = api.get_net_config(12, 'cifar10')
network = get_cell_based_tiny_net(config)
# Load the pre-trained weights: params is a dict, where the key is the seed and value is the weights.
params = api.get_net_param(12, 'cifar10', None)
network.load_state_dict(next(iter(params.values())))
```
## How to Re-create NATS-Bench from Scratch
@ -53,6 +85,10 @@ bash ./scripts/NATS-Bench/train-shapes.sh 00000-32767 90 777
```
The checkpoints of all candidates are located at `output/NATS-Bench-size` by default.
After training these candidate architectures, please use the following command to re-organize all checkpoints into the official benchmark file.
```
python exps/NATS-Bench/sss-collect.py
```
### The Topology Search Space
@ -63,7 +99,10 @@ bash scripts/NATS-Bench/train-topology.sh 00000-15624 200 '777 888 999'
```
The checkpoints of all candidates are located at `output/NATS-Bench-topology` by default.
After training these candidate architectures, please use the following command to re-organize all checkpoints into the official benchmark file.
```
python exps/NATS-Bench/tss-collect.py
```
## To Reproduce 13 Baseline NAS Algorithms in NAS-Bench-201

View File

@ -801,7 +801,6 @@ if __name__ == '__main__':
show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'BN0', 'BN0-XX-CIFAR010-VALID.pdf', (0, 100,10), 250)
show_nas_sharing_w(api, 'cifar10' , 'ori-test', vis_save_dir, 'BN0', 'BN0-XX-CIFAR010-TEST.pdf' , (0, 100,10), 250)
import pdb; pdb.set_trace()
"""
for x_maxs in [50, 250]:
show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)

View File

@ -48,7 +48,6 @@ def test_api(api, sss_or_tss=True):
print('')
params = api.get_net_param(12, 'cifar10', None)
import pdb; pdb.set_trace()
# Obtain the config and create the network
config = api.get_net_config(12, 'cifar10')
print('{:}\n'.format(config))

View File

@ -95,7 +95,7 @@ def main(xargs, api):
logger.log('{:} use api : {:}'.format(time_string(), api))
api.reset_time()
search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
search_space = get_search_spaces(xargs.search_space, 'nats-bench')
if xargs.search_space == 'tss':
cs = get_topology_config_space(search_space)
config2structure = config2topology_func()

View File

@ -33,7 +33,7 @@ def main(xargs, api):
logger.log('{:} use api : {:}'.format(time_string(), api))
api.reset_time()
search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
search_space = get_search_spaces(xargs.search_space, 'nats-bench')
if xargs.search_space == 'tss':
random_arch = random_topology_func(search_space)
else:

View File

@ -160,7 +160,7 @@ def main(xargs, api):
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
search_space = get_search_spaces(xargs.search_space, 'nats-bench')
if xargs.search_space == 'tss':
random_arch = random_topology_func(search_space)
mutate_arch = mutate_topology_func(search_space)

View File

@ -124,7 +124,7 @@ def main(xargs, api):
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
search_space = get_search_spaces(xargs.search_space, 'nats-bench')
if xargs.search_space == 'tss':
policy = PolicyTopology(search_space)
else:

View File

@ -342,9 +342,8 @@ def main(xargs):
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
search_space = get_search_spaces(xargs.search_space, 'nats-bench')
model_config = dict2config(
dict(name='generic', C=xargs.channel, N=xargs.num_cells, max_nodes=xargs.max_nodes, num_classes=class_num,
space=search_space, affine=bool(xargs.affine), track_running_stats=bool(xargs.track_running_stats)), None)

View File

@ -155,8 +155,8 @@ def main(xargs):
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
search_space = get_search_spaces(xargs.search_space, 'nats-bench')
model_config = dict2config(
dict(name='generic', super_type='search-shape', candidate_Cs=search_space['candidates'], max_num_Cs=search_space['numbers'], num_classes=class_num,
genotype=args.genotype, affine=bool(xargs.affine), track_running_stats=bool(xargs.track_running_stats)), None)

View File

@ -3,10 +3,10 @@
###########################################################################################################################################################
# Before running these commands, the benchmark files must be placed at the paths used below.
#
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset cifar10
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset cifar100
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset ImageNet16-120
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space tss --base_path $HOME/.torch/NAS-Bench-201-v1_1 --dataset cifar10
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NATS-sss-v1_0-50262 --dataset cifar10
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NATS-sss-v1_0-50262 --dataset cifar100
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NATS-sss-v1_0-50262 --dataset ImageNet16-120
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space tss --base_path $HOME/.torch/NATS-tss-v1_0-3ffb9 --dataset cifar10
###########################################################################################################################################################
import os, gc, sys, math, argparse, psutil
import numpy as np
@ -140,7 +140,7 @@ if __name__ == '__main__':
save_dir = Path(args.save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
meta_file = Path(args.base_path + '.pth')
weight_dir = Path(args.base_path + '-archive')
weight_dir = Path(args.base_path + '-full')
assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
assert weight_dir.exists() and weight_dir.is_dir(), 'invalid path for weight dir : {:}'.format(weight_dir)

View File

@ -395,9 +395,9 @@ if __name__ == '__main__':
for xdata in datasets:
visualize_tss_info(api201, xdata, to_save_dir)
api301 = create(None, 'size', verbose=True)
api_sss = create(None, 'size', verbose=True)
for xdata in datasets:
visualize_sss_info(api301, xdata, to_save_dir)
visualize_sss_info(api_sss, xdata, to_save_dir)
visualize_info(None, to_save_dir, 'tss')
visualize_info(None, to_save_dir, 'sss')

View File

@ -15,9 +15,9 @@ from .api_utils import pickle_load
from .api_utils import ArchResults
from .api_utils import NASBenchMetaAPI
from .api_utils import remap_dataset_set_names
from .api_utils import PICKLE_EXT
PICKLE_EXT = 'pickle.pbz2'
ALL_BASE_NAMES = ['NATS-sss-v1_0-50262']
@ -58,6 +58,7 @@ class NATSsize(NASBenchMetaAPI):
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, fast_mode: bool=False, verbose: bool=True):
self.ALL_BASE_NAMES = ALL_BASE_NAMES
self.filename = None
self._search_space_name = 'size'
self._fast_mode = fast_mode
@ -120,39 +121,6 @@ class NATSsize(NASBenchMetaAPI):
print('{:} Create NATS-Bench (size) done with {:}/{:} architectures avaliable.'.format(
time_string(), len(self.evaluated_indexes), len(self.meta_archs)))
def reload(self, archive_root: Text = None, index: int = None):
    """Overwrite all information of the 'index'-th architecture in the search space.

    The data is loaded from per-architecture files stored in 'archive_root'.

    Args:
      archive_root: the directory that contains one '<index>.<PICKLE_EXT>' file
        per architecture. If None, the default path
        os.environ['TORCH_HOME'] / '<ALL_BASE_NAMES[-1]>-full' is tried first;
        if TORCH_HOME is unset or that directory does not exist,
        self.archive_dir is used instead.
      index: the index of the architecture to overwrite.
        If index is None, overwrite all ckps.

    Raises:
      ValueError: if no existing archive directory can be determined.
      AssertionError: if an index is out of range, the data file is missing,
        or the loaded data is not a dict.
    """
    if self.verbose:
        print('{:} Call reload with archive_root={:} and index={:}'.format(
            time_string(), archive_root, index))
    if archive_root is None:
        # A missing TORCH_HOME must not abort here: fall through to the
        # self.archive_dir fallback exactly like a missing default directory.
        torch_home = os.environ.get('TORCH_HOME')
        if torch_home is not None:
            archive_root = os.path.join(torch_home, '{:}-full'.format(ALL_BASE_NAMES[-1]))
        if archive_root is None or not os.path.isdir(archive_root):
            warnings.warn('The input archive_root is None and the default archive_root path ({:}) does not exist, try to use self.archive_dir.'.format(archive_root))
            archive_root = self.archive_dir
    if archive_root is None or not os.path.isdir(archive_root):
        raise ValueError('Invalid archive_root : {:}'.format(archive_root))
    indexes = list(range(len(self))) if index is None else [index]
    for idx in indexes:
        assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx)
        # Files are normally named with a zero-padded 6-digit index; fall back
        # to the un-padded name when the padded one is absent.
        xfile_path = os.path.join(archive_root, '{:06d}.{:}'.format(idx, PICKLE_EXT))
        if not os.path.isfile(xfile_path):
            xfile_path = os.path.join(archive_root, '{:d}.{:}'.format(idx, PICKLE_EXT))
        assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
        xdata = pickle_load(xfile_path)
        assert isinstance(xdata, dict), 'invalid format of data in {:}'.format(xfile_path)
        self.evaluated_indexes.add(idx)
        # Rebuild the hp -> ArchResults mapping for this architecture and
        # record every hyper-parameter key that was seen.
        hp2archres = OrderedDict()
        for hp_key, results in xdata.items():
            hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
            self._avaliable_hps.add(hp_key)
        self.arch2infos_dict[idx] = hp2archres
def query_info_str_by_arch(self, arch, hp: Text='12'):
""" This function is used to query the information of a specific architecture
'arch' can be an architecture index or an architecture string

View File

@ -16,9 +16,9 @@ from .api_utils import pickle_load
from .api_utils import ArchResults
from .api_utils import NASBenchMetaAPI
from .api_utils import remap_dataset_set_names
from .api_utils import PICKLE_EXT
PICKLE_EXT = 'pickle.pbz2'
ALL_BASE_NAMES = ['NATS-tss-v1_0-3ffb9']
@ -55,6 +55,7 @@ class NATStopology(NASBenchMetaAPI):
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, fast_mode: bool=False, verbose: bool=True):
self.ALL_BASE_NAMES = ALL_BASE_NAMES
self.filename = None
self._search_space_name = 'topology'
self._fast_mode = fast_mode
@ -117,39 +118,6 @@ class NATStopology(NASBenchMetaAPI):
print('{:} Create NATS-Bench (topology) done with {:}/{:} architectures avaliable.'.format(
time_string(), len(self.evaluated_indexes), len(self.meta_archs)))
def reload(self, archive_root: Text = None, index: int = None):
    """Overwrite all information of the 'index'-th architecture in the search space.

    Args:
      archive_root: the directory that contains one '<index>.<PICKLE_EXT>' file
        per architecture. If None, the default path
        os.environ['TORCH_HOME'] / '<ALL_BASE_NAMES[-1]>-full' is tried first;
        if TORCH_HOME is unset or that directory does not exist,
        self.archive_dir is used instead.
      index: the index of the architecture to overwrite.
        If index is None, overwrite all ckps.

    Raises:
      ValueError: if no existing archive directory can be determined.
      AssertionError: if an index is out of range, the data file is missing,
        or the loaded data is not a dict.
    """
    if self.verbose:
        print('{:} Call reload with archive_root={:} and index={:}'.format(
            time_string(), archive_root, index))
    if archive_root is None:
        # Treat an unset TORCH_HOME like a missing default directory so the
        # self.archive_dir fallback is still reached.
        torch_home = os.environ.get('TORCH_HOME')
        if torch_home is not None:
            archive_root = os.path.join(torch_home, '{:}-full'.format(ALL_BASE_NAMES[-1]))
        if archive_root is None or not os.path.isdir(archive_root):
            warnings.warn('The input archive_root is None and the default archive_root path ({:}) does not exist, try to use self.archive_dir.'.format(archive_root))
            archive_root = self.archive_dir
    if archive_root is None or not os.path.isdir(archive_root):
        raise ValueError('Invalid archive_root : {:}'.format(archive_root))
    indexes = list(range(len(self))) if index is None else [index]
    for idx in indexes:
        assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx)
        # Prefer the zero-padded 6-digit file name, then the un-padded one.
        xfile_path = os.path.join(archive_root, '{:06d}.{:}'.format(idx, PICKLE_EXT))
        if not os.path.isfile(xfile_path):
            xfile_path = os.path.join(archive_root, '{:d}.{:}'.format(idx, PICKLE_EXT))
        assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
        xdata = pickle_load(xfile_path)
        assert isinstance(xdata, dict), 'invalid format of data in {:}'.format(xfile_path)
        self.evaluated_indexes.add(idx)
        # Replace the stored hp -> ArchResults mapping of this architecture.
        hp2archres = OrderedDict()
        for hp_key, results in xdata.items():
            hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
            self._avaliable_hps.add(hp_key)
        self.arch2infos_dict[idx] = hp2archres
def query_info_str_by_arch(self, arch, hp: Text='12'):
""" This function is used to query the information of a specific architecture
'arch' can be an architecture index or an architecture string

View File

@ -17,6 +17,9 @@ from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict
PICKLE_EXT = 'pickle.pbz2'
def pickle_save(obj, file_path, ext='.pbz2', protocol=4):
"""Use pickle to save data (obj) into file_path.
According to https://docs.python.org/3/library/pickle.html#data-stream-format, Protocol version 4 was added in Python 3.4. It adds support for very large objects, pickling more kinds of objects, and some data format optimizations. It is the default protocol starting with Python 3.8.
@ -132,6 +135,41 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
"""Return a random index of all architectures."""
return random.randint(0, len(self.meta_archs)-1)
def reload(self, archive_root: Text = None, index: int = None):
    """Overwrite all information of the 'index'-th architecture in the search space.

    The data is loaded from per-architecture files stored in 'archive_root'.

    Args:
      archive_root: the directory that contains one '<index>.<PICKLE_EXT>' file
        per architecture. If None, the default path
        os.environ['TORCH_HOME'] / '<BASE_NAME>-full' is tried first, where
        BASE_NAME is self.ALL_BASE_NAMES[-1]; if TORCH_HOME is unset or that
        directory does not exist, self.archive_dir is used instead.
      index: the index of the architecture to overwrite.
        If index is None, overwrite all ckps.

    Raises:
      ValueError: if no existing archive directory can be determined.
      AssertionError: if an index is out of range, the data file is missing,
        or the loaded data is not a dict.
    """
    if self.verbose:
        print('{:} Call reload with archive_root={:} and index={:}'.format(
            time_string(), archive_root, index))
    if archive_root is None:
        # A missing TORCH_HOME must not abort here: fall through to the
        # self.archive_dir fallback exactly like a missing default directory.
        torch_home = os.environ.get('TORCH_HOME')
        if torch_home is not None:
            archive_root = os.path.join(torch_home, '{:}-full'.format(self.ALL_BASE_NAMES[-1]))
        if archive_root is None or not os.path.isdir(archive_root):
            warnings.warn('The input archive_root is None and the default archive_root path ({:}) does not exist, try to use self.archive_dir.'.format(archive_root))
            archive_root = self.archive_dir
    if archive_root is None or not os.path.isdir(archive_root):
        raise ValueError('Invalid archive_root : {:}'.format(archive_root))
    indexes = list(range(len(self))) if index is None else [index]
    for idx in indexes:
        assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx)
        # Files are normally named with a zero-padded 6-digit index; fall back
        # to the un-padded name when the padded one is absent.
        xfile_path = os.path.join(archive_root, '{:06d}.{:}'.format(idx, PICKLE_EXT))
        if not os.path.isfile(xfile_path):
            xfile_path = os.path.join(archive_root, '{:d}.{:}'.format(idx, PICKLE_EXT))
        assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
        xdata = pickle_load(xfile_path)
        assert isinstance(xdata, dict), 'invalid format of data in {:}'.format(xfile_path)
        self.evaluated_indexes.add(idx)
        # Rebuild the hp -> ArchResults mapping for this architecture and
        # record every hyper-parameter key that was seen.
        hp2archres = OrderedDict()
        for hp_key, results in xdata.items():
            hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
            self._avaliable_hps.add(hp_key)
        self.arch2infos_dict[idx] = hp2archres
def query_index_by_arch(self, arch):
""" This function is used to query the index of an architecture in the search space.
In the topology search space, the input arch can be an architecture string such as '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|';
@ -176,12 +214,6 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
if self.verbose:
print('{:} Call _prepare_info with index={:} skip because it is in arch2infos_dict'.format(time_string(), index))
@abc.abstractmethod
def reload(self, archive_root: Text = None, index: int = None):
    """Overwrite all information of the 'index'-th architecture in the search space.

    Abstract hook with no body of its own; the behaviour is supplied by the
    implementation that overrides it.

    Args:
      archive_root: the location the data will be loaded from.
      index: the index of the architecture to overwrite.
        If index is None, overwrite all ckps.
    """
def clear_params(self, index: int, hp: Optional[Text]=None):
"""Remove the architecture's weights to save memory.
:arg