diff --git a/docs/NATS-Bench.md b/docs/NATS-Bench.md index 787a86b..4312bae 100644 --- a/docs/NATS-Bench.md +++ b/docs/NATS-Bench.md @@ -8,8 +8,6 @@ We also show the versatility of NATS-Bench by benchmarking 13 recent state-of-th This facilitates a much larger community of researchers to focus on developing better NAS algorithms in a more comparable and computationally effective environment. -**coming soon!** - The structure of this Markdown file: - [How to use NATS-Bench?](#How-to-Use-NATS-Bench) - [How to re-create NATS-Bench from scratch?](#how-to-re-create-nats-bench-from-scratch) @@ -33,14 +31,48 @@ To merge the chunks into the original full archive, you can use `cat file_name* | Date | benchmark file (tss) | archive (tss) | full archive (tss) | benchmark file (sss) | archive (sss) | full archive (sss) | |:-----------|:---------------------:|:-------------:|:------------------:|:-------------------------------:|:--------------------------:|:------------------:| -| 2020.08.31 | | | | [NATS-sss-v1_0-50262.pickle.pbz2](https://drive.google.com/file/d/1IabIvzWeDdDAWICBzFtTCMXxYWPIOIOX/view?usp=sharing) | [NATS-sss-v1_0-50262-simple.tar](https://drive.google.com/file/d/1scOMTUwcQhAMa_IMedp9lTzwmgqHLGgA/view?usp=sharing) | NATS-sss-v1_0-50262-full | +| 2020.08.31 | [NATS-tss-v1_0-3ffb9.pickle.pbz2](https://drive.google.com/file/d/1vzyK0UVH2D3fTpa1_dSWnp1gvGpAxRul/view?usp=sharing) | [NATS-tss-v1_0-3ffb9-simple.tar](https://drive.google.com/file/d/17_saCsj_krKjlCBLOJEpNtzPXArMCqxU/view?usp=sharing) | NATS-tss-v1_0-3ffb9-full | [NATS-sss-v1_0-50262.pickle.pbz2](https://drive.google.com/file/d/1IabIvzWeDdDAWICBzFtTCMXxYWPIOIOX/view?usp=sharing) | [NATS-sss-v1_0-50262-simple.tar](https://drive.google.com/file/d/1scOMTUwcQhAMa_IMedp9lTzwmgqHLGgA/view?usp=sharing) | NATS-sss-v1_0-50262-full | 1, create the benchmark instance: ``` +# Create the API instance for the size search space in NATS api = create(None, 'sss', fast_mode=True, verbose=True) + 
+# Create the API instance for the topology search space in NATS +api = create(None, 'tss', fast_mode=True, verbose=True) ``` +2, query the performance: +``` +# Query the loss / accuracy / time for the 1234-th candidate architecture on CIFAR-10 +# info is a dict, where you can easily figure out the meaning by key +info = api.get_more_info(1234, 'cifar10') + +# Query the flops, params, latency. info is a dict. +info = api.get_cost_info(12, 'cifar10') + +# Simulate the training of the 1224-th candidate: +validation_accuracy, latency, time_cost, current_total_time_cost = api.simulate_train_eval(1224, dataset='cifar10', hp='12') +``` + +3, others: +``` +# Clear the parameters of the 12-th candidate. +api.clear_params(12) + +# Reload all information of the 12-th candidate. +api.reload(index=12) + +# Create the instance of the 12-th candidate for CIFAR-10. +from models import get_cell_based_tiny_net +config = api.get_net_config(12, 'cifar10') +network = get_cell_based_tiny_net(config) + +# Load the pre-trained weights: params is a dict, where the key is the seed and value is the weights. +params = api.get_net_param(12, 'cifar10', None) +network.load_state_dict(next(iter(params.values()))) +``` ## How to Re-create NATS-Bench from Scratch @@ -53,6 +85,10 @@ bash ./scripts/NATS-Bench/train-shapes.sh 00000-32767 90 777 ``` The checkpoint of all candidates are located at `output/NATS-Bench-size` by default. +After training these candidate architectures, please use the following command to re-organize all checkpoints into the official benchmark file. +``` +python exps/NATS-Bench/sss-collect.py +``` ### The Topology Search Space @@ -63,7 +99,10 @@ bash scripts/NATS-Bench/train-topology.sh 00000-15624 200 '777 888 999' ``` The checkpoint of all candidates are located at `output/NATS-Bench-topology` by default. - +After training these candidate architectures, please use the following command to re-organize all checkpoints into the official benchmark file. 
+``` +python exps/NATS-Bench/tss-collect.py +``` ## To Reproduce 13 Baseline NAS Algorithms in NAS-Bench-201 diff --git a/exps/NAS-Bench-201/visualize.py b/exps/NAS-Bench-201/visualize.py index c04e57e..451e614 100644 --- a/exps/NAS-Bench-201/visualize.py +++ b/exps/NAS-Bench-201/visualize.py @@ -801,7 +801,6 @@ if __name__ == '__main__': show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'BN0', 'BN0-XX-CIFAR010-VALID.pdf', (0, 100,10), 250) show_nas_sharing_w(api, 'cifar10' , 'ori-test', vis_save_dir, 'BN0', 'BN0-XX-CIFAR010-TEST.pdf' , (0, 100,10), 250) - import pdb; pdb.set_trace() """ for x_maxs in [50, 250]: show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) diff --git a/exps/NATS-Bench/test-nats-api.py b/exps/NATS-Bench/test-nats-api.py index 15a1826..91da061 100644 --- a/exps/NATS-Bench/test-nats-api.py +++ b/exps/NATS-Bench/test-nats-api.py @@ -48,7 +48,6 @@ def test_api(api, sss_or_tss=True): print('') params = api.get_net_param(12, 'cifar10', None) - import pdb; pdb.set_trace() # Obtain the config and create the network config = api.get_net_config(12, 'cifar10') print('{:}\n'.format(config)) diff --git a/exps/NATS-algos/bohb.py b/exps/NATS-algos/bohb.py index 980cdd0..50df9f9 100644 --- a/exps/NATS-algos/bohb.py +++ b/exps/NATS-algos/bohb.py @@ -95,7 +95,7 @@ def main(xargs, api): logger.log('{:} use api : {:}'.format(time_string(), api)) api.reset_time() - search_space = get_search_spaces(xargs.search_space, 'nas-bench-301') + search_space = get_search_spaces(xargs.search_space, 'nats-bench') if xargs.search_space == 'tss': cs = get_topology_config_space(search_space) config2structure = config2topology_func() diff --git a/exps/NATS-algos/random_wo_share.py b/exps/NATS-algos/random_wo_share.py index 5f4b46b..275eaa6 100644 --- a/exps/NATS-algos/random_wo_share.py +++ b/exps/NATS-algos/random_wo_share.py @@ -33,7 +33,7 @@ def main(xargs, api): logger.log('{:} use api : 
{:}'.format(time_string(), api)) api.reset_time() - search_space = get_search_spaces(xargs.search_space, 'nas-bench-301') + search_space = get_search_spaces(xargs.search_space, 'nats-bench') if xargs.search_space == 'tss': random_arch = random_topology_func(search_space) else: diff --git a/exps/NATS-algos/regularized_ea.py b/exps/NATS-algos/regularized_ea.py index c5effef..d01d7f1 100644 --- a/exps/NATS-algos/regularized_ea.py +++ b/exps/NATS-algos/regularized_ea.py @@ -160,7 +160,7 @@ def main(xargs, api): prepare_seed(xargs.rand_seed) logger = prepare_logger(args) - search_space = get_search_spaces(xargs.search_space, 'nas-bench-301') + search_space = get_search_spaces(xargs.search_space, 'nats-bench') if xargs.search_space == 'tss': random_arch = random_topology_func(search_space) mutate_arch = mutate_topology_func(search_space) diff --git a/exps/NATS-algos/reinforce.py b/exps/NATS-algos/reinforce.py index 10dfe76..da09e99 100644 --- a/exps/NATS-algos/reinforce.py +++ b/exps/NATS-algos/reinforce.py @@ -124,7 +124,7 @@ def main(xargs, api): prepare_seed(xargs.rand_seed) logger = prepare_logger(args) - search_space = get_search_spaces(xargs.search_space, 'nas-bench-301') + search_space = get_search_spaces(xargs.search_space, 'nats-bench') if xargs.search_space == 'tss': policy = PolicyTopology(search_space) else: diff --git a/exps/NATS-algos/search-cell.py b/exps/NATS-algos/search-cell.py index d057cab..cedd3cb 100644 --- a/exps/NATS-algos/search-cell.py +++ b/exps/NATS-algos/search-cell.py @@ -342,9 +342,8 @@ def main(xargs): logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) - search_space = get_search_spaces(xargs.search_space, 'nas-bench-301') + search_space = get_search_spaces(xargs.search_space, 'nats-bench') - model_config = dict2config( 
dict(name='generic', C=xargs.channel, N=xargs.num_cells, max_nodes=xargs.max_nodes, num_classes=class_num, space=search_space, affine=bool(xargs.affine), track_running_stats=bool(xargs.track_running_stats)), None) diff --git a/exps/NATS-algos/search-size.py b/exps/NATS-algos/search-size.py index ba836b4..5fe37c7 100644 --- a/exps/NATS-algos/search-size.py +++ b/exps/NATS-algos/search-size.py @@ -155,8 +155,8 @@ def main(xargs): logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) - search_space = get_search_spaces(xargs.search_space, 'nas-bench-301') - + search_space = get_search_spaces(xargs.search_space, 'nats-bench') + model_config = dict2config( dict(name='generic', super_type='search-shape', candidate_Cs=search_space['candidates'], max_num_Cs=search_space['numbers'], num_classes=class_num, genotype=args.genotype, affine=bool(xargs.affine), track_running_stats=bool(xargs.track_running_stats)), None) diff --git a/exps/experimental/test-ww-bench.py b/exps/experimental/test-ww-bench.py index c81e707..6cd3847 100644 --- a/exps/experimental/test-ww-bench.py +++ b/exps/experimental/test-ww-bench.py @@ -3,10 +3,10 @@ ########################################################################################################################################################### # Before run these commands, the files must be properly put. 
# -# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset cifar10 -# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset cifar100 -# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset ImageNet16-120 -# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space tss --base_path $HOME/.torch/NAS-Bench-201-v1_1 --dataset cifar10 +# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NATS-sss-v1_0-50262 --dataset cifar10 +# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NATS-sss-v1_0-50262 --dataset cifar100 +# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NATS-sss-v1_0-50262 --dataset ImageNet16-120 +# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space tss --base_path $HOME/.torch/NATS-tss-v1_0-3ffb9 --dataset cifar10 ########################################################################################################################################################### import os, gc, sys, math, argparse, psutil import numpy as np @@ -140,7 +140,7 @@ if __name__ == '__main__': save_dir = Path(args.save_dir) save_dir.mkdir(parents=True, exist_ok=True) meta_file = Path(args.base_path + '.pth') - weight_dir = Path(args.base_path + '-archive') + weight_dir = Path(args.base_path + '-full') assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file) assert weight_dir.exists() and weight_dir.is_dir(), 'invalid path for weight dir : 
{:}'.format(weight_dir) diff --git a/exps/experimental/visualize-nas-bench-x.py b/exps/experimental/visualize-nas-bench-x.py index 03d8438..70412e6 100644 --- a/exps/experimental/visualize-nas-bench-x.py +++ b/exps/experimental/visualize-nas-bench-x.py @@ -395,9 +395,9 @@ if __name__ == '__main__': for xdata in datasets: visualize_tss_info(api201, xdata, to_save_dir) - api301 = create(None, 'size', verbose=True) + api_sss = create(None, 'size', verbose=True) for xdata in datasets: - visualize_sss_info(api301, xdata, to_save_dir) + visualize_sss_info(api_sss, xdata, to_save_dir) visualize_info(None, to_save_dir, 'tss') visualize_info(None, to_save_dir, 'sss') diff --git a/lib/nats_bench/api_size.py b/lib/nats_bench/api_size.py index 9eb604f..14cc5b5 100644 --- a/lib/nats_bench/api_size.py +++ b/lib/nats_bench/api_size.py @@ -15,9 +15,9 @@ from .api_utils import pickle_load from .api_utils import ArchResults from .api_utils import NASBenchMetaAPI from .api_utils import remap_dataset_set_names +from .api_utils import PICKLE_EXT -PICKLE_EXT = 'pickle.pbz2' ALL_BASE_NAMES = ['NATS-sss-v1_0-50262'] @@ -58,6 +58,7 @@ class NATSsize(NASBenchMetaAPI): """ The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """ def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, fast_mode: bool=False, verbose: bool=True): + self.ALL_BASE_NAMES = ALL_BASE_NAMES self.filename = None self._search_space_name = 'size' self._fast_mode = fast_mode @@ -120,39 +121,6 @@ class NATSsize(NASBenchMetaAPI): print('{:} Create NATS-Bench (size) done with {:}/{:} architectures avaliable.'.format( time_string(), len(self.evaluated_indexes), len(self.meta_archs))) - def reload(self, archive_root: Text = None, index: int = None): - """Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'. - If index is None, overwrite all ckps. 
- """ - if self.verbose: - print('{:} Call clear_params with archive_root={:} and index={:}'.format( - time_string(), archive_root, index)) - if archive_root is None: - archive_root = os.path.join(os.environ['TORCH_HOME'], '{:}-full'.format(ALL_BASE_NAMES[-1])) - if not os.path.isdir(archive_root): - warnings.warn('The input archive_root is None and the default archive_root path ({:}) does not exist, try to use self.archive_dir.'.format(archive_root)) - archive_root = self.archive_dir - if archive_root is None or not os.path.isdir(archive_root): - raise ValueError('Invalid archive_root : {:}'.format(archive_root)) - if index is None: - indexes = list(range(len(self))) - else: - indexes = [index] - for idx in indexes: - assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx) - xfile_path = os.path.join(archive_root, '{:06d}.{:}'.format(idx, PICKLE_EXT)) - if not os.path.isfile(xfile_path): - xfile_path = os.path.join(archive_root, '{:d}.{:}'.format(idx, PICKLE_EXT)) - assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path) - xdata = pickle_load(xfile_path) - assert isinstance(xdata, dict), 'invalid format of data in {:}'.format(xfile_path) - self.evaluated_indexes.add(idx) - hp2archres = OrderedDict() - for hp_key, results in xdata.items(): - hp2archres[hp_key] = ArchResults.create_from_state_dict(results) - self._avaliable_hps.add(hp_key) - self.arch2infos_dict[idx] = hp2archres - def query_info_str_by_arch(self, arch, hp: Text='12'): """ This function is used to query the information of a specific architecture 'arch' can be an architecture index or an architecture string diff --git a/lib/nats_bench/api_topology.py b/lib/nats_bench/api_topology.py index 1413483..c8659fc 100644 --- a/lib/nats_bench/api_topology.py +++ b/lib/nats_bench/api_topology.py @@ -16,9 +16,9 @@ from .api_utils import pickle_load from .api_utils import ArchResults from .api_utils import NASBenchMetaAPI from .api_utils import remap_dataset_set_names 
+from .api_utils import PICKLE_EXT -PICKLE_EXT = 'pickle.pbz2' ALL_BASE_NAMES = ['NATS-tss-v1_0-3ffb9'] @@ -55,6 +55,7 @@ class NATStopology(NASBenchMetaAPI): """ The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """ def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, fast_mode: bool=False, verbose: bool=True): + self.ALL_BASE_NAMES = ALL_BASE_NAMES self.filename = None self._search_space_name = 'topology' self._fast_mode = fast_mode @@ -117,39 +118,6 @@ class NATStopology(NASBenchMetaAPI): print('{:} Create NATS-Bench (topology) done with {:}/{:} architectures avaliable.'.format( time_string(), len(self.evaluated_indexes), len(self.meta_archs))) - def reload(self, archive_root: Text = None, index: int = None): - """Overwrite all information of the 'index'-th architecture in the search space. - If index is None, overwrite all ckps. - """ - if self.verbose: - print('{:} Call clear_params with archive_root={:} and index={:}'.format( - time_string(), archive_root, index)) - if archive_root is None: - archive_root = os.path.join(os.environ['TORCH_HOME'], '{:}-full'.format(ALL_BASE_NAMES[-1])) - if not os.path.isdir(archive_root): - warnings.warn('The input archive_root is None and the default archive_root path ({:}) does not exist, try to use self.archive_dir.'.format(archive_root)) - archive_root = self.archive_dir - if archive_root is None or not os.path.isdir(archive_root): - raise ValueError('Invalid archive_root : {:}'.format(archive_root)) - if index is None: - indexes = list(range(len(self))) - else: - indexes = [index] - for idx in indexes: - assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx) - xfile_path = os.path.join(archive_root, '{:06d}.{:}'.format(idx, PICKLE_EXT)) - if not os.path.isfile(xfile_path): - xfile_path = os.path.join(archive_root, '{:d}.{:}'.format(idx, PICKLE_EXT)) - assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path) 
- xdata = pickle_load(xfile_path) - assert isinstance(xdata, dict), 'invalid format of data in {:}'.format(xfile_path) - self.evaluated_indexes.add(idx) - hp2archres = OrderedDict() - for hp_key, results in xdata.items(): - hp2archres[hp_key] = ArchResults.create_from_state_dict(results) - self._avaliable_hps.add(hp_key) - self.arch2infos_dict[idx] = hp2archres - def query_info_str_by_arch(self, arch, hp: Text='12'): """ This function is used to query the information of a specific architecture 'arch' can be an architecture index or an architecture string diff --git a/lib/nats_bench/api_utils.py b/lib/nats_bench/api_utils.py index d7b4b79..4f7ca35 100644 --- a/lib/nats_bench/api_utils.py +++ b/lib/nats_bench/api_utils.py @@ -17,6 +17,9 @@ from typing import List, Text, Union, Dict, Optional from collections import OrderedDict, defaultdict +PICKLE_EXT = 'pickle.pbz2' + + def pickle_save(obj, file_path, ext='.pbz2', protocol=4): """Use pickle to save data (obj) into file_path. According to https://docs.python.org/3/library/pickle.html#data-stream-format, Protocol version 4 was added in Python 3.4. It adds support for very large objects, pickling more kinds of objects, and some data format optimizations. It is the default protocol starting with Python 3.8. @@ -132,6 +135,41 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): """Return a random index of all architectures.""" return random.randint(0, len(self.meta_archs)-1) + def reload(self, archive_root: Text = None, index: int = None): + """Overwrite all information of the 'index'-th architecture in the search space, + where the data will be loaded from 'archive_root'. + If archive_root is None, it will try to load from the default path os.environ['TORCH_HOME'] / 'BASE_NAME'-full. + If index is None, overwrite all ckps. 
+ """ + if self.verbose: + print('{:} Call reload with archive_root={:} and index={:}'.format( + time_string(), archive_root, index)) + if archive_root is None: + archive_root = os.path.join(os.environ['TORCH_HOME'], '{:}-full'.format(self.ALL_BASE_NAMES[-1])) + if not os.path.isdir(archive_root): + warnings.warn('The input archive_root is None and the default archive_root path ({:}) does not exist, try to use self.archive_dir.'.format(archive_root)) + archive_root = self.archive_dir + if archive_root is None or not os.path.isdir(archive_root): + raise ValueError('Invalid archive_root : {:}'.format(archive_root)) + if index is None: + indexes = list(range(len(self))) + else: + indexes = [index] + for idx in indexes: + assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx) + xfile_path = os.path.join(archive_root, '{:06d}.{:}'.format(idx, PICKLE_EXT)) + if not os.path.isfile(xfile_path): + xfile_path = os.path.join(archive_root, '{:d}.{:}'.format(idx, PICKLE_EXT)) + assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path) + xdata = pickle_load(xfile_path) + assert isinstance(xdata, dict), 'invalid format of data in {:}'.format(xfile_path) + self.evaluated_indexes.add(idx) + hp2archres = OrderedDict() + for hp_key, results in xdata.items(): + hp2archres[hp_key] = ArchResults.create_from_state_dict(results) + self._avaliable_hps.add(hp_key) + self.arch2infos_dict[idx] = hp2archres + def query_index_by_arch(self, arch): """ This function is used to query the index of an architecture in the search space. 
In the topology search space, the input arch can be an architecture string such as '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|'; @@ -176,12 +214,6 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): if self.verbose: print('{:} Call _prepare_info with index={:} skip because it is in arch2infos_dict'.format(time_string(), index)) - @abc.abstractmethod - def reload(self, archive_root: Text = None, index: int = None): - """Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'. - If index is None, overwrite all ckps. - """ - def clear_params(self, index: int, hp: Optional[Text]=None): """Remove the architecture's weights to save memory. :arg