diff --git a/.gitmodules b/.gitmodules index 09d12b9..a9f9dad 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule ".latent-data/qlib"] path = .latent-data/qlib url = git@github.com:D-X-Y/qlib.git +[submodule ".latent-data/NATS-Bench"] + path = .latent-data/NATS-Bench + url = git@github.com:D-X-Y/NATS-Bench.git diff --git a/.latent-data/NATS-Bench b/.latent-data/NATS-Bench new file mode 160000 index 0000000..51187c1 --- /dev/null +++ b/.latent-data/NATS-Bench @@ -0,0 +1 @@ +Subproject commit 51187c1e9152ff79b02b11c80bca0b03b402a7e5 diff --git a/README.md b/README.md index 82d8b77..5456f2e 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,12 @@ Some visualization codes may require `opencv`. CIFAR and ImageNet should be downloaded and extracted into `$TORCH_HOME`. Some methods use knowledge distillation (KD), which require pre-trained models. Please download these models from [Google Drive](https://drive.google.com/open?id=1ANmiYEGX-IQZTfH8w0aSpj-Wypg-0DR-) (or train by yourself) and save into `.latent-data`. +Please use +``` +git clone --recurse-submodules git@github.com:D-X-Y/AutoDL-Projects.git +``` +to download this repo with submodules. + ## Citation If you find that this project helps your research, please consider citing the related paper: diff --git a/docs/NATS-Bench.md b/docs/NATS-Bench.md index 65cc5b9..a5cda8b 100644 --- a/docs/NATS-Bench.md +++ b/docs/NATS-Bench.md @@ -8,6 +8,7 @@ We also show the versatility of NATS-Bench by benchmarking 13 recent state-of-th This facilitates a much larger community of researchers to focus on developing better NAS algorithms in a more comparable and computationally effective environment. **You can use `pip install nats_bench` to install the library of NATS-Bench.** +or install from the [source codes](https://github.com/D-X-Y/NATS-Bench) via `python setup.py install`. The structure of this Markdown file: - [How to use NATS-Bench?](#How-to-Use-NATS-Bench) diff --git a/exps/NATS-algos/README.md b/exps/NATS-algos/README.md index fc90c36..361ce4f 100644 --- a/exps/NATS-algos/README.md +++ b/exps/NATS-algos/README.md @@ -9,6 +9,11 @@ The Python files in this folder are used to re-produce the results in ``NATS-Ben - [`regularized_ea.py`](https://github.com/D-X-Y/AutoDL-Projects/blob/main/exps/NATS-algos/regularized_ea.py) contains the REA algorithm for both search spaces. - [`reinforce.py`](https://github.com/D-X-Y/AutoDL-Projects/blob/main/exps/NATS-algos/reinforce.py) contains the REINFORCE algorithm for both search spaces. +## Requirements + +- `nats_bench`>=v1.1 : you can use `pip install nats_bench` to install or from [sources](https://github.com/D-X-Y/NATS-Bench) +- `hpbandster` : if you want to run BOHB + ## Citation If you find that this project helps your research, please consider citing the related paper: diff --git a/lib/layers/super_mlp.py b/lib/layers/super_mlp.py new file mode 100644 index 0000000..ffd3f50 --- /dev/null +++ b/lib/layers/super_mlp.py @@ -0,0 +1,24 @@ +import torch.nn as nn +from typing import Optional + +class MLP(nn.Module): + # MLP: FC -> Activation -> Drop -> FC -> Drop + def __init__(self, in_features, hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer=nn.GELU, + drop: Optional[float] = None): + super(MLP, self).__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop or 0) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/lib/layers/super_module.py b/lib/layers/super_module.py new file mode 100644 index 0000000..a5d3968 --- /dev/null +++ b/lib/layers/super_module.py @@ -0,0 +1,7 @@ +import torch.nn as nn + +class SuperModule(nn.Module): + def __init__(self): + super(SuperModule, self).__init__() + + diff --git a/lib/nats_bench/__init__.py b/lib/nats_bench/__init__.py deleted file mode 100644 index 0702d52..0000000 --- a/lib/nats_bench/__init__.py +++ /dev/null @@ -1,70 +0,0 @@ -############################################################################## -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 ########################## -############################################################################## -# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size # -############################################################################## -"""The official Application Programming Interface (API) for NATS-Bench.""" -from nats_bench.api_size import NATSsize -from nats_bench.api_topology import NATStopology -from nats_bench.api_utils import ArchResults -from nats_bench.api_utils import pickle_load -from nats_bench.api_utils import pickle_save -from nats_bench.api_utils import ResultsCount - - -NATS_BENCH_API_VERSIONs = ['v1.0', # [2020.08.31] - 'v1.1'] # [2020.12.20] adding unit tests -NATS_BENCH_SSS_NAMEs = ('sss', 'size') -NATS_BENCH_TSS_NAMEs = ('tss', 'topology') - - -def version(): - return NATS_BENCH_API_VERSIONs[-1] - - -def create(file_path_or_dict, search_space, fast_mode=False, verbose=True): - """Create the instead for NATS API. - - Args: - file_path_or_dict: None or a file path or a directory path. - search_space: This is a string indicates the search space in NATS-Bench. - fast_mode: If True, we will not load all the data at initialization, - instead, the data for each candidate architecture will be loaded when - quering it; If False, we will load all the data during initialization. - verbose: This is a flag to indicate whether log additional information. - - Raises: - ValueError: If not find the matched serach space description. - - Returns: - The created NATS-Bench API. - """ - if search_space in NATS_BENCH_TSS_NAMEs: - return NATStopology(file_path_or_dict, fast_mode, verbose) - elif search_space in NATS_BENCH_SSS_NAMEs: - return NATSsize(file_path_or_dict, fast_mode, verbose) - else: - raise ValueError('invalid search space : {:}'.format(search_space)) - - -def search_space_info(main_tag, aux_tag): - """Obtain the search space information.""" - nats_sss = dict(candidates=[8, 16, 24, 32, 40, 48, 56, 64], - num_layers=5) - nats_tss = dict(op_names=['none', 'skip_connect', - 'nor_conv_1x1', 'nor_conv_3x3', - 'avg_pool_3x3'], - num_nodes=4) - if main_tag == 'nats-bench': - if aux_tag in NATS_BENCH_SSS_NAMEs: - return nats_sss - elif aux_tag in NATS_BENCH_TSS_NAMEs: - return nats_tss - else: - raise ValueError('Unknown auxiliary tag: {:}'.format(aux_tag)) - elif main_tag == 'nas-bench-201': - if aux_tag is not None: - raise ValueError('For NAS-Bench-201, the auxiliary tag should be None.') - return nats_tss - else: - raise ValueError('Unknown main tag: {:}'.format(main_tag)) diff --git a/lib/nats_bench/api_size.py b/lib/nats_bench/api_size.py deleted file mode 100644 index c446481..0000000 --- a/lib/nats_bench/api_size.py +++ /dev/null @@ -1,291 +0,0 @@ -##################################################### -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 # -############################################################################## -# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size # -############################################################################## -# The history of benchmark files are as follows, # -# where the format is (the name is NATS-sss-[version]-[md5].pickle.pbz2) # -# [2020.08.31] NATS-sss-v1_0-50262.pickle.pbz2 # -############################################################################## -# pylint: disable=line-too-long -"""The API for size search space in NATS-Bench.""" -import collections -import copy -import os -import random -from typing import Dict, Optional, Text, Union, Any - -from nats_bench.api_utils import ArchResults -from nats_bench.api_utils import NASBenchMetaAPI -from nats_bench.api_utils import get_torch_home -from nats_bench.api_utils import nats_is_dir -from nats_bench.api_utils import nats_is_file -from nats_bench.api_utils import PICKLE_EXT -from nats_bench.api_utils import pickle_load -from nats_bench.api_utils import time_string - - -ALL_BASE_NAMES = ['NATS-sss-v1_0-50262'] - - -def print_information(information, extra_info=None, show=False): - """print out the information of a given ArchResults.""" - dataset_names = information.get_dataset_names() - strings = [ - information.arch_str, - 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info) - ] - - def metric2str(loss, acc): - return 'loss = {:.3f} & top1 = {:.2f}%'.format(loss, acc) - - for dataset in dataset_names: - metric = information.get_compute_costs(dataset) - flop, param, latency = metric['flops'], metric['params'], metric['latency'] - str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format( - dataset, flop, param, - '{:.2f}'.format(latency * - 1000) if latency is not None and latency > 0 else None) - train_info = information.get_metrics(dataset, 'train') - if dataset == 'cifar10-valid': - valid_info = information.get_metrics(dataset, 'x-valid') - test__info = information.get_metrics(dataset, 'ori-test') - str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format( - dataset, metric2str(train_info['loss'], train_info['accuracy']), - metric2str(valid_info['loss'], valid_info['accuracy']), - metric2str(test__info['loss'], test__info['accuracy'])) - elif dataset == 'cifar10': - test__info = information.get_metrics(dataset, 'ori-test') - str2 = '{:14s} train : [{:}], test : [{:}]'.format( - dataset, metric2str(train_info['loss'], train_info['accuracy']), - metric2str(test__info['loss'], test__info['accuracy'])) - else: - valid_info = information.get_metrics(dataset, 'x-valid') - test__info = information.get_metrics(dataset, 'x-test') - str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format( - dataset, metric2str(train_info['loss'], train_info['accuracy']), - metric2str(valid_info['loss'], valid_info['accuracy']), - metric2str(test__info['loss'], test__info['accuracy'])) - strings += [str1, str2] - if show: print('\n'.join(strings)) - return strings - - -class NATSsize(NASBenchMetaAPI): - """This is the class for the API of size search space in NATS-Bench.""" - - def __init__(self, - file_path_or_dict: Optional[Union[Text, Dict[Text, Any]]] = None, - fast_mode: bool = False, - verbose: bool = True): - """The initialization function that takes the dataset file path (or a dict loaded from that path) as input.""" - self._all_base_names = ALL_BASE_NAMES - self.filename = None - self._search_space_name = 'size' - self._fast_mode = fast_mode - self._archive_dir = None - self._full_train_epochs = 90 - self.reset_time() - if file_path_or_dict is None: - if self._fast_mode: - self._archive_dir = os.path.join( - get_torch_home(), '{:}-simple'.format(ALL_BASE_NAMES[-1])) - else: - file_path_or_dict = os.path.join( - get_torch_home(), '{:}.{:}'.format( - ALL_BASE_NAMES[-1], PICKLE_EXT)) - print('{:} Try to use the default NATS-Bench (size) path from ' - 'fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode, - file_path_or_dict)) - if isinstance(file_path_or_dict, str): - file_path_or_dict = str(file_path_or_dict) - if verbose: - print('{:} Try to create the NATS-Bench (size) api ' - 'from {:} with fast_mode={:}'.format( - time_string(), file_path_or_dict, fast_mode)) - if not nats_is_file(file_path_or_dict) and not nats_is_dir( - file_path_or_dict): - raise ValueError('{:} is neither a file or a dir.'.format( - file_path_or_dict)) - self.filename = os.path.basename(file_path_or_dict) - if fast_mode: - if nats_is_file(file_path_or_dict): - raise ValueError('fast_mode={:} must feed the path for directory ' - ': {:}'.format(fast_mode, file_path_or_dict)) - else: - self._archive_dir = file_path_or_dict - else: - if nats_is_dir(file_path_or_dict): - raise ValueError('fast_mode={:} must feed the path for file ' - ': {:}'.format(fast_mode, file_path_or_dict)) - else: - file_path_or_dict = pickle_load(file_path_or_dict) - elif isinstance(file_path_or_dict, dict): - file_path_or_dict = copy.deepcopy(file_path_or_dict) - self.verbose = verbose - if isinstance(file_path_or_dict, dict): - keys = ('meta_archs', 'arch2infos', 'evaluated_indexes') - for key in keys: - if key not in file_path_or_dict: - raise ValueError('Can not find key[{:}] in the dict'.format(key)) - self.meta_archs = copy.deepcopy(file_path_or_dict['meta_archs']) - # NOTE(xuanyidong): This is a dict mapping each architecture to a dict, - # where the key is #epochs and the value is ArchResults - self.arch2infos_dict = collections.OrderedDict() - self._avaliable_hps = set() - for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())): - all_infos = file_path_or_dict['arch2infos'][xkey] - hp2archres = collections.OrderedDict() - for hp_key, results in all_infos.items(): - hp2archres[hp_key] = ArchResults.create_from_state_dict(results) - self._avaliable_hps.add(hp_key) # save the avaliable hyper-parameter - self.arch2infos_dict[xkey] = hp2archres - self.evaluated_indexes = set(file_path_or_dict['evaluated_indexes']) - elif self.archive_dir is not None: - benchmark_meta = pickle_load('{:}/meta.{:}'.format( - self.archive_dir, PICKLE_EXT)) - self.meta_archs = copy.deepcopy(benchmark_meta['meta_archs']) - self.arch2infos_dict = collections.OrderedDict() - self._avaliable_hps = set() - self.evaluated_indexes = set() - else: - raise ValueError('file_path_or_dict [{:}] must be a dict or archive_dir ' - 'must be set'.format(type(file_path_or_dict))) - self.archstr2index = {} - for idx, arch in enumerate(self.meta_archs): - if arch in self.archstr2index: - raise ValueError('This [{:}]-th arch {:} already in the ' - 'dict ({:}).'.format( - idx, arch, self.archstr2index[arch])) - self.archstr2index[arch] = idx - if self.verbose: - print('{:} Create NATS-Bench (size) done with {:}/{:} architectures ' - 'avaliable.'.format(time_string(), - len(self.evaluated_indexes), - len(self.meta_archs))) - - def query_info_str_by_arch(self, arch, hp: Text = '12'): - """Query the information of a specific architecture. - - Args: - arch: it can be an architecture index or an architecture string. - - hp: the hyperparamete indicator, could be 01, 12, or 90. The difference - between these three configurations are the number of training epochs. - - Returns: - ArchResults instance - """ - if self.verbose: - print('{:} Call query_info_str_by_arch with arch={:}' - 'and hp={:}'.format(time_string(), arch, hp)) - return self._query_info_str_by_arch(arch, hp, print_information) - - def get_more_info(self, - index, - dataset, - iepoch=None, - hp: Text = '12', - is_random: bool = True): - """Return the metric for the `index`-th architecture. - - Args: - index: the architecture index. - dataset: - 'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set - 'cifar10' : using the proposed train+valid set of CIFAR-10 as the training set - 'cifar100' : using the proposed train set of CIFAR-100 as the training set - 'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set - iepoch: the index of training epochs from 0 to 11/199. - When iepoch=None, it will return the metric for the last training epoch - When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0) - hp: indicates different hyper-parameters for training - When hp=01, it trains the network with 01 epochs and the LR decayed from 0.1 to 0 within 01 epochs - When hp=12, it trains the network with 01 epochs and the LR decayed from 0.1 to 0 within 12 epochs - When hp=90, it trains the network with 01 epochs and the LR decayed from 0.1 to 0 within 90 epochs - is_random: - When is_random=True, the performance of a random architecture will be returned - When is_random=False, the performanceo of all trials will be averaged. - - Returns: - a dict, where key is the metric name and value is its value. - """ - if self.verbose: - print('{:} Call the get_more_info function with index={:}, dataset={:}, ' - 'iepoch={:}, hp={:}, and is_random={:}.'.format( - time_string(), index, dataset, iepoch, hp, is_random)) - index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object - self._prepare_info(index) - if index not in self.arch2infos_dict: - raise ValueError('Did not find {:} from arch2infos_dict.'.format(index)) - archresult = self.arch2infos_dict[index][str(hp)] - # if randomly select one trial, select the seed at first - if isinstance(is_random, bool) and is_random: - seeds = archresult.get_dataset_seeds(dataset) - is_random = random.choice(seeds) - # collect the training information - train_info = archresult.get_metrics( - dataset, 'train', iepoch=iepoch, is_random=is_random) - total = train_info['iepoch'] + 1 - xinfo = { - 'train-loss': train_info['loss'], - 'train-accuracy': train_info['accuracy'], - 'train-per-time': train_info['all_time'] / total, - 'train-all-time': train_info['all_time'] - } - # collect the evaluation information - if dataset == 'cifar10-valid': - valid_info = archresult.get_metrics( - dataset, 'x-valid', iepoch=iepoch, is_random=is_random) - try: - test_info = archresult.get_metrics( - dataset, 'ori-test', iepoch=iepoch, is_random=is_random) - except Exception as unused_e: # pylint: disable=broad-except - test_info = None - valtest_info = None - xinfo['comment'] = 'In this dict, train-loss/accuracy/time is the metric on the train set of CIFAR-10. The test-loss/accuracy/time is the performance of the CIFAR-10 test set after training on the train set by {:} epochs. The per-time and total-time indicate the per epoch and total time costs, respectively.'.format(hp) - else: - if dataset == 'cifar10': - xinfo['comment'] = 'In this dict, train-loss/accuracy/time is the metric on the train+valid sets of CIFAR-10. The test-loss/accuracy/time is the performance of the CIFAR-10 test set after training on the train+valid sets by {:} epochs. The per-time and total-time indicate the per epoch and total time costs, respectively.'.format(hp) - try: # collect results on the proposed test set - if dataset == 'cifar10': - test_info = archresult.get_metrics( - dataset, 'ori-test', iepoch=iepoch, is_random=is_random) - else: - test_info = archresult.get_metrics( - dataset, 'x-test', iepoch=iepoch, is_random=is_random) - except Exception as unused_e: # pylint: disable=broad-except - test_info = None - try: # collect results on the proposed validation set - valid_info = archresult.get_metrics( - dataset, 'x-valid', iepoch=iepoch, is_random=is_random) - except Exception as unused_e: # pylint: disable=broad-except - valid_info = None - try: - if dataset != 'cifar10': - valtest_info = archresult.get_metrics( - dataset, 'ori-test', iepoch=iepoch, is_random=is_random) - else: - valtest_info = None - except Exception as unused_e: # pylint: disable=broad-except - valtest_info = None - if valid_info is not None: - xinfo['valid-loss'] = valid_info['loss'] - xinfo['valid-accuracy'] = valid_info['accuracy'] - xinfo['valid-per-time'] = valid_info['all_time'] / total - xinfo['valid-all-time'] = valid_info['all_time'] - if test_info is not None: - xinfo['test-loss'] = test_info['loss'] - xinfo['test-accuracy'] = test_info['accuracy'] - xinfo['test-per-time'] = test_info['all_time'] / total - xinfo['test-all-time'] = test_info['all_time'] - if valtest_info is not None: - xinfo['valtest-loss'] = valtest_info['loss'] - xinfo['valtest-accuracy'] = valtest_info['accuracy'] - xinfo['valtest-per-time'] = valtest_info['all_time'] / total - xinfo['valtest-all-time'] = valtest_info['all_time'] - return xinfo - - def show(self, index: int = -1) -> None: - """Print the information of a specific (or all) architecture(s).""" - self._show(index, print_information) diff --git a/lib/nats_bench/api_test.py b/lib/nats_bench/api_test.py deleted file mode 100644 index 7bedfbf..0000000 --- a/lib/nats_bench/api_test.py +++ /dev/null @@ -1,131 +0,0 @@ -############################################################################## -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 ########################## -############################################################################## -# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size # -############################################################################## -# pytest --capture=tee-sys # -############################################################################## -"""This file is used to quickly test the API.""" -import os -import pytest -import random - -from nats_bench.api_size import NATSsize -from nats_bench.api_size import ALL_BASE_NAMES as sss_base_names -from nats_bench.api_topology import NATStopology -from nats_bench.api_topology import ALL_BASE_NAMES as tss_base_names - - -def get_fake_torch_home_dir(): - print('This file is {:}'.format(os.path.abspath(__file__))) - print('The current directory is {:}'.format(os.path.abspath(os.getcwd()))) - xname = 'FAKE_TORCH_HOME' - if xname in os.environ: - return os.environ['FAKE_TORCH_HOME'] - else: - return os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'fake_torch_dir') - - -class TestNATSBench(object): - - def test_nats_bench_tss(self, benchmark_dir=None, fake_random=True): - if benchmark_dir is None: - benchmark_dir = os.path.join(get_fake_torch_home_dir(), sss_base_names[-1] + '-simple') - return _test_nats_bench(benchmark_dir, True, fake_random) - - def test_nats_bench_sss(self, benchmark_dir=None, fake_random=True): - if benchmark_dir is None: - benchmark_dir = os.path.join(get_fake_torch_home_dir(), tss_base_names[-1] + '-simple') - return _test_nats_bench(benchmark_dir, False, fake_random) - - def prepare_fake_tss(self): - print('') - tss_benchmark_dir = os.path.join(get_fake_torch_home_dir(), tss_base_names[-1] + '-simple') - api = NATStopology(tss_benchmark_dir, True, False) - return api - - def test_01_th_issue(self): - # Link: https://github.com/D-X-Y/NATS-Bench/issues/1 - api = self.prepare_fake_tss() - # The performance of 0-th architecture on CIFAR-10 (trained by 12 epochs) - info = api.get_more_info(0, 'cifar10', hp=12) - # First of all, the data split in NATS-Bench is different from that in the official CIFAR paper. - # In NATS-Bench, we split the original CIFAR-10 training set into two parts, i.e., a training set and a validation set. - # In the following, we will use the splits of NATS-Bench to explain. - print(info['comment']) - print('The loss on the training + validation sets of CIFAR-10: {:}'.format(info['train-loss'])) - print('The total training time for 12 epochs on the training + validation sets of CIFAR-10: {:}'.format(info['train-all-time'])) - print('The per-epoch training time on CIFAR-10: {:}'.format(info['train-per-time'])) - print('The total evaluation time on the test set of CIFAR-10 for 12 times: {:}'.format(info['test-all-time'])) - print('The evaluation time on the test set of CIFAR-10: {:}'.format(info['test-per-time'])) - cost_info = api.get_cost_info(0, 'cifar10') - xkeys = ['T-train@epoch', # The per epoch training time on the training + validation sets of CIFAR-10. - 'T-train@total', - 'T-ori-test@epoch', # The time cost for the evaluation on CIFAR-10 test set. - 'T-ori-test@total'] # T-ori-test@epoch * 12 times. - for xkey in xkeys: - print('The cost info [{:}] for 0-th architecture on CIFAR-10 is {:}'.format(xkey, cost_info[xkey])) - - def test_02_th_issue(self): - # https://github.com/D-X-Y/NATS-Bench/issues/2 - api = self.prepare_fake_tss() - data = api.query_by_index(284, dataname='cifar10', hp=200) - for xkey, xvalue in data.items(): - print('{:} : {:}'.format(xkey, xvalue)) - xinfo = data[777].get_train() - print(xinfo) - print(data[777].train_acc1es) - - info_012_epochs = api.get_more_info(284, 'cifar10', hp= 12) - print('Train accuracy for 12 epochs is {:}'.format(info_012_epochs['train-accuracy'])) - info_200_epochs = api.get_more_info(284, 'cifar10', hp=200) - print('Train accuracy for 200 epochs is {:}'.format(info_200_epochs['train-accuracy'])) - - -def _test_nats_bench(benchmark_dir, is_tss, fake_random, verbose=False): - """The main test entry for NATS-Bench.""" - if is_tss: - api = NATStopology(benchmark_dir, True, verbose) - else: - api = NATSsize(benchmark_dir, True, verbose) - - if fake_random: - test_indexes = [0, 11, 284] - else: - test_indexes = [random.randint(0, len(api) - 1) for _ in range(10)] - - key2dataset = {'cifar10': 'CIFAR-10', - 'cifar100': 'CIFAR-100', - 'ImageNet16-120': 'ImageNet16-120'} - - for index in test_indexes: - print('\n\nEvaluate the {:5d}-th architecture.'.format(index)) - - for key, dataset in key2dataset.items(): - # Query the loss / accuracy / time for the `index`-th candidate - # architecture on CIFAR-10 - # info is a dict, where you can easily figure out the meaning by key - info = api.get_more_info(index, key) - print(' -->> The performance on {:}: {:}'.format(dataset, info)) - - # Query the flops, params, latency. info is a dict. - info = api.get_cost_info(index, key) - print(' -->> The cost info on {:}: {:}'.format(dataset, info)) - - # Simulate the training of the `index`-th candidate: - validation_accuracy, latency, time_cost, current_total_time_cost = api.simulate_train_eval( - index, dataset=key, hp='12') - print(' -->> The validation accuracy={:}, latency={:}, ' - 'the current time cost={:} s, accumulated time cost={:} s' - .format(validation_accuracy, latency, time_cost, - current_total_time_cost)) - - # Print the configuration of the `index`-th architecture on CIFAR-10 - config = api.get_net_config(index, key) - print(' -->> The configuration on {:} is {:}'.format(dataset, config)) - - # Show the information of the `index`-th architecture - api.show(index) - - with pytest.raises(ValueError): - api.get_more_info(100000, 'cifar10') diff --git a/lib/nats_bench/api_topology.py b/lib/nats_bench/api_topology.py deleted file mode 100644 index 05a633d..0000000 --- a/lib/nats_bench/api_topology.py +++ /dev/null @@ -1,338 +0,0 @@ -##################################################### -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 # -############################################################################## -# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size # -############################################################################## -# The history of benchmark files are as follows, # -# where the format is (the name is NATS-tss-[version]-[md5].pickle.pbz2) # -# [2020.08.31] NATS-tss-v1_0-3ffb9.pickle.pbz2 # -############################################################################## -# pylint: disable=line-too-long -"""The API for topology search space in NATS-Bench.""" -import collections -import copy -import os -import random -from typing import Any, Dict, List, Optional, Text, Union - -from nats_bench.api_utils import ArchResults -from nats_bench.api_utils import NASBenchMetaAPI -from nats_bench.api_utils import get_torch_home -from nats_bench.api_utils import nats_is_dir -from nats_bench.api_utils import nats_is_file -from nats_bench.api_utils import PICKLE_EXT -from nats_bench.api_utils import pickle_load -from nats_bench.api_utils import time_string - -import numpy as np - - -ALL_BASE_NAMES = ['NATS-tss-v1_0-3ffb9'] - - -def print_information(information, extra_info=None, show=False): - """print out the information of a given ArchResults.""" - dataset_names = information.get_dataset_names() - strings = [ - information.arch_str, - 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info) - ] - - def metric2str(loss, acc): - return 'loss = {:.3f} & top1 = {:.2f}%'.format(loss, acc) - - for dataset in dataset_names: - metric = information.get_compute_costs(dataset) - flop, param, latency = metric['flops'], metric['params'], metric['latency'] - str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format( - dataset, flop, param, - '{:.2f}'.format(latency * - 1000) if latency is not None and latency > 0 else None) - train_info = information.get_metrics(dataset, 'train') - if dataset == 'cifar10-valid': - valid_info = information.get_metrics(dataset, 'x-valid') - str2 = '{:14s} train : [{:}], valid : [{:}]'.format( - dataset, metric2str(train_info['loss'], train_info['accuracy']), - metric2str(valid_info['loss'], valid_info['accuracy'])) - elif dataset == 'cifar10': - test__info = information.get_metrics(dataset, 'ori-test') - str2 = '{:14s} train : [{:}], test : [{:}]'.format( - dataset, metric2str(train_info['loss'], train_info['accuracy']), - metric2str(test__info['loss'], test__info['accuracy'])) - else: - valid_info = information.get_metrics(dataset, 'x-valid') - test__info = information.get_metrics(dataset, 'x-test') - str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format( - dataset, metric2str(train_info['loss'], train_info['accuracy']), - metric2str(valid_info['loss'], valid_info['accuracy']), - metric2str(test__info['loss'], test__info['accuracy'])) - strings += [str1, str2] - if show: print('\n'.join(strings)) - return strings - - -class NATStopology(NASBenchMetaAPI): - """This is the class for the API of topology search space in NATS-Bench.""" - - def __init__(self, - file_path_or_dict: Optional[Union[Text, Dict[Text, Any]]] = None, - fast_mode: bool = False, - verbose: bool = True): - """The initialization function that takes the dataset file path (or a dict loaded from that path) as input.""" - self._all_base_names = ALL_BASE_NAMES - self.filename = None - self._search_space_name = 'topology' - self._fast_mode = fast_mode - self._archive_dir = None - self._full_train_epochs = 200 - self.reset_time() - if file_path_or_dict is None: - if self._fast_mode: - self._archive_dir = os.path.join( - get_torch_home(), '{:}-simple'.format(ALL_BASE_NAMES[-1])) - else: - file_path_or_dict = os.path.join( - get_torch_home(), '{:}.{:}'.format( - ALL_BASE_NAMES[-1], PICKLE_EXT)) - print('{:} Try to use the default NATS-Bench (topology) path from ' - 'fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode, file_path_or_dict)) - if isinstance(file_path_or_dict, str): - file_path_or_dict = str(file_path_or_dict) - if verbose: - print('{:} Try to create the NATS-Bench (topology) api ' - 'from {:} with fast_mode={:}'.format( - time_string(), file_path_or_dict, fast_mode)) - if not nats_is_file(file_path_or_dict) and not nats_is_dir( - file_path_or_dict): - raise ValueError('{:} is neither a file or a dir.'.format( - file_path_or_dict)) - self.filename = os.path.basename(file_path_or_dict) - if fast_mode: - if nats_is_file(file_path_or_dict): - raise ValueError('fast_mode={:} must feed the path for directory ' - ': {:}'.format(fast_mode, file_path_or_dict)) - else: - self._archive_dir = file_path_or_dict - else: - if nats_is_dir(file_path_or_dict): - raise ValueError('fast_mode={:} must feed the path for file ' - ': {:}'.format(fast_mode, file_path_or_dict)) - else: - file_path_or_dict = pickle_load(file_path_or_dict) - elif isinstance(file_path_or_dict, dict): - file_path_or_dict = copy.deepcopy(file_path_or_dict) - self.verbose = verbose - if isinstance(file_path_or_dict, dict): - keys = ('meta_archs', 'arch2infos', 'evaluated_indexes') - for key in keys: - if key not in file_path_or_dict: - raise ValueError('Can not find key[{:}] in the dict'.format(key)) - self.meta_archs = copy.deepcopy(file_path_or_dict['meta_archs']) - # NOTE(xuanyidong): This is a dict mapping each architecture to a dict, - # where the key is #epochs and the value is ArchResults - self.arch2infos_dict = collections.OrderedDict() - self._avaliable_hps = set() - for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())): - all_infos = file_path_or_dict['arch2infos'][xkey] - hp2archres = collections.OrderedDict() - for hp_key, results in all_infos.items(): - hp2archres[hp_key] = ArchResults.create_from_state_dict(results) - self._avaliable_hps.add(hp_key) # save the avaliable hyper-parameter - self.arch2infos_dict[xkey] = hp2archres - self.evaluated_indexes = set(file_path_or_dict['evaluated_indexes']) - elif self.archive_dir is not None: - benchmark_meta = pickle_load('{:}/meta.{:}'.format( - self.archive_dir, PICKLE_EXT)) - self.meta_archs = copy.deepcopy(benchmark_meta['meta_archs']) - self.arch2infos_dict = collections.OrderedDict() - self._avaliable_hps = set() - self.evaluated_indexes = set() - else: - raise ValueError('file_path_or_dict [{:}] must be a dict or archive_dir ' - 'must be set'.format(type(file_path_or_dict))) - self.archstr2index = {} - for idx, arch in enumerate(self.meta_archs): - if arch in self.archstr2index: - raise ValueError('This [{:}]-th arch {:} already in the ' - 'dict ({:}).'.format( - idx, arch, self.archstr2index[arch])) - self.archstr2index[arch] = idx - if self.verbose: - print('{:} Create NATS-Bench (topology) done with {:}/{:} architectures ' - 'avaliable.'.format(time_string(), - len(self.evaluated_indexes), - len(self.meta_archs))) - - def query_info_str_by_arch(self, arch, hp: Text = '12'): - """Query the information of a specific architecture. - - Args: - arch: it can be an architecture index or an architecture string. - - hp: the hyperparamete indicator, could be 12 or 200. The difference - between these three configurations are the number of training epochs. - - Returns: - ArchResults instance - """ - if self.verbose: - print('{:} Call query_info_str_by_arch with arch={:}' - 'and hp={:}'.format(time_string(), arch, hp)) - return self._query_info_str_by_arch(arch, hp, print_information) - - def get_more_info(self, - index, - dataset, - iepoch=None, - hp: Text = '12', - is_random: bool = True): - """Return the metric for the `index`-th architecture.""" - if self.verbose: - print('{:} Call the get_more_info function with index={:}, dataset={:}, ' - 'iepoch={:}, hp={:}, and is_random={:}.'.format( - time_string(), index, dataset, iepoch, hp, is_random)) - index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object - self._prepare_info(index) - if index not in self.arch2infos_dict: - raise ValueError('Did not find {:} from arch2infos_dict.'.format(index)) - archresult = self.arch2infos_dict[index][str(hp)] - # if randomly select one trial, select the seed at first - if isinstance(is_random, bool) and is_random: - seeds = archresult.get_dataset_seeds(dataset) - is_random = random.choice(seeds) - # collect the training information - train_info = archresult.get_metrics(dataset, 'train', iepoch=iepoch, is_random=is_random) - total = train_info['iepoch'] + 1 - xinfo = { - 'train-loss': - train_info['loss'], - 'train-accuracy': - train_info['accuracy'], - 'train-per-time': - train_info['all_time'] / - total if train_info['all_time'] is not None else None, - 'train-all-time': - train_info['all_time'] - } - # collect the evaluation information - if dataset == 'cifar10-valid': - valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random) - try: - test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) - except Exception as unused_e: # pylint: disable=broad-except - test_info = None - valtest_info = None - xinfo['comment'] = 'In this dict, train-loss/accuracy/time is the metric on the train set of CIFAR-10. The test-loss/accuracy/time is the performance of the CIFAR-10 test set after training on the train set by {:} epochs. The per-time and total-time indicate the per epoch and total time costs, respectively.'.format(hp) - else: - if dataset == 'cifar10': - xinfo['comment'] = 'In this dict, train-loss/accuracy/time is the metric on the train+valid sets of CIFAR-10. The test-loss/accuracy/time is the performance of the CIFAR-10 test set after training on the train+valid sets by {:} epochs. The per-time and total-time indicate the per epoch and total time costs, respectively.'.format(hp) - try: # collect results on the proposed test set - if dataset == 'cifar10': - test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) - else: - test_info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random) - except Exception as unused_e: # pylint: disable=broad-except - test_info = None - try: # collect results on the proposed validation set - valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random) - except Exception as unused_e: # pylint: disable=broad-except - valid_info = None - try: - if dataset != 'cifar10': - valtest_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) - else: - valtest_info = None - except Exception as unused_e: # pylint: disable=broad-except - valtest_info = None - if valid_info is not None: - xinfo['valid-loss'] = valid_info['loss'] - xinfo['valid-accuracy'] = valid_info['accuracy'] - xinfo['valid-per-time'] = valid_info['all_time'] / total if valid_info['all_time'] is not None else None - xinfo['valid-all-time'] = valid_info['all_time'] - if test_info is not None: - xinfo['test-loss'] = test_info['loss'] - xinfo['test-accuracy'] = test_info['accuracy'] - xinfo['test-per-time'] = test_info['all_time'] / total if test_info['all_time'] is not None else None - xinfo['test-all-time'] = test_info['all_time'] - if valtest_info is not None: - xinfo['valtest-loss'] = valtest_info['loss'] - xinfo['valtest-accuracy'] = valtest_info['accuracy'] - xinfo['valtest-per-time'] = valtest_info['all_time'] / total if valtest_info['all_time'] is not None else None - xinfo['valtest-all-time'] = valtest_info['all_time'] - return xinfo - - def show(self, index: int = -1) -> None: - """This function will print the information of a specific (or all) architecture(s).""" - self._show(index, print_information) - - @staticmethod - def str2lists(arch_str: Text) -> List[Any]: - """Shows how to read the string-based architecture encoding. - - Args: - arch_str: the input is a string indicates the architecture topology, such as - |nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2| - Returns: - a list of tuple, contains multiple (op, input_node_index) pairs. - - [USAGE] - It is the same as the `str2structure` func in AutoDL-Projects: - `github.com/D-X-Y/AutoDL-Projects/lib/models/cell_searchs/genotypes.py` - ``` - arch = api.str2lists( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' ) - print ('there are {:} nodes in this arch'.format(len(arch)+1)) # arch is a list - for i, node in enumerate(arch): - print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node)) - ``` - """ - node_strs = arch_str.split('+') - genotypes = [] - for unused_i, node_str in enumerate(node_strs): - inputs = list(filter(lambda x: x != '', node_str.split('|'))) # pylint: disable=g-explicit-bool-comparison - for xinput in inputs: - assert len( - xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput) - inputs = (xi.split('~') for xi in inputs) - input_infos = tuple((op, int(idx)) for (op, idx) in inputs) - genotypes.append(input_infos) - return genotypes - - @staticmethod - def str2matrix(arch_str: Text, - search_space: List[Text] = ('none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3')) -> np.ndarray: - """Convert the string-based architecture encoding to the encoding strategy in NAS-Bench-101. - - Args: - arch_str: the input is a string indicates the architecture topology, such as - |nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2| - search_space: a list of operation string, the default list is the topology search space for NATS-BENCH. - the default value should be be consistent with this line https://github.com/D-X-Y/AutoDL-Projects/blob/main/lib/models/cell_operations.py#L24 - - Returns: - the numpy matrix (2-D np.ndarray) representing the DAG of this architecture topology - - [USAGE] - matrix = api.str2matrix( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' ) - This matrix is 4-by-4 matrix representing a cell with 4 nodes (only the lower left triangle is useful). - [ [0, 0, 0, 0], # the first line represents the input (0-th) node - [2, 0, 0, 0], # the second line represents the 1-st node, is calculated by 2-th-op( 0-th-node ) - [0, 0, 0, 0], # the third line represents the 2-nd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node ) - [0, 0, 1, 0] ] # the fourth line represents the 3-rd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node ) + 1-th-op( 2-th-node ) - In the topology search space in NATS-BENCH, 0-th-op is 'none', 1-th-op is 'skip_connect', - 2-th-op is 'nor_conv_1x1', 3-th-op is 'nor_conv_3x3', 4-th-op is 'avg_pool_3x3'. - [NOTE] - If a node has two input-edges from the same node, this function does not work. One edge will be overlapped. - """ - node_strs = arch_str.split('+') - num_nodes = len(node_strs) + 1 - matrix = np.zeros((num_nodes, num_nodes)) - for i, node_str in enumerate(node_strs): - inputs = list(filter(lambda x: x != '', node_str.split('|'))) # pylint: disable=g-explicit-bool-comparison - for xinput in inputs: - assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput) - for xi in inputs: - op, idx = xi.split('~') - if op not in search_space: raise ValueError('this op ({:}) is not in {:}'.format(op, search_space)) - op_idx, node_idx = search_space.index(op), int(idx) - matrix[i+1, node_idx] = op_idx - return matrix diff --git a/lib/nats_bench/api_utils.py b/lib/nats_bench/api_utils.py deleted file mode 100644 index 06b5617..0000000 --- a/lib/nats_bench/api_utils.py +++ /dev/null @@ -1,1217 +0,0 @@ -##################################################### -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 # -############################################################################## -# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size # -############################################################################## -"""In this file, we define NASBenchMetaAPI, ArchResults, and ResultsCount. - - NASBenchMetaAPI is the abstract class for benchmark APIs. - We also define the class ArchResults, which contains all - information of a single architecture trained by one kind of hyper-parameters - on three datasets. We also define the class ResultsCount, which contains all - information of a single trial for a single architecture. -""" -import abc -import bz2 -import collections -import copy -import os -import pickle -import random -import time -from typing import Any, Dict, Optional, Text, Union -import warnings - -import numpy as np - - -_FILE_SYSTEM = 'default' -PICKLE_EXT = 'pickle.pbz2' - - -def time_string(): - iso_time_format = '%Y-%m-%d %X' - string = '[{:}]'.format( - time.strftime(iso_time_format, time.gmtime(time.time()))) - return string - - -def reset_file_system(lib: Text = 'default'): - global _FILE_SYSTEM - _FILE_SYSTEM = lib - - -def get_file_system(): - return _FILE_SYSTEM - - -def get_torch_home(): - if 'TORCH_HOME' in os.environ: - return os.environ['TORCH_HOME'] - elif 'HOME' in os.environ: - return os.path.join(os.environ['HOME'], '.torch') - else: - raise ValueError('Did not find HOME in os.environ. ' - 'Please at least setup the path of HOME or TORCH_HOME ' - 'in the environment.') - - -def nats_is_dir(file_path): - if _FILE_SYSTEM == 'default': - return os.path.isdir(file_path) - elif _FILE_SYSTEM == 'google': - import tensorflow as tf # pylint: disable=g-import-not-at-top - return tf.io.gfile.isdir(file_path) - else: - raise ValueError('Unknown file system lib: {:}'.format(_FILE_SYSTEM)) - - -def nats_is_file(file_path): - if _FILE_SYSTEM == 'default': - return os.path.isfile(file_path) - elif _FILE_SYSTEM == 'google': - import tensorflow as tf # pylint: disable=g-import-not-at-top - return tf.io.gfile.exists(file_path) and not tf.io.gfile.isdir(file_path) - else: - raise ValueError('Unknown file system lib: {:}'.format(_FILE_SYSTEM)) - - -def pickle_save(obj, file_path, ext='.pbz2', protocol=4): - """Use pickle to save data (obj) into file_path. - - Args: - obj: The object to be saved into a path. - file_path: The target saving path. - ext: The extension of file name. - protocol: The pickle protocol. According to this documentation - (https://docs.python.org/3/library/pickle.html#data-stream-format), - the protocol version 4 was added in Python 3.4. It adds support for very - large objects, pickling more kinds of objects, and some data format - optimizations. It is the default protocol starting with Python 3.8. - """ - # with open(file_path, 'wb') as cfile: - if _FILE_SYSTEM == 'default': - with bz2.BZ2File(str(file_path) + ext, 'wb') as cfile: - pickle.dump(obj, cfile, protocol=protocol) # pytype: disable=wrong-arg-types - else: - raise ValueError('Unknown file system lib: {:}'.format(_FILE_SYSTEM)) - - -def pickle_load(file_path, ext='.pbz2'): - """Use pickle to load the file on different systems.""" - # return pickle.load(open(file_path, "rb")) - if nats_is_file(str(file_path)): - xfile_path = str(file_path) - else: - xfile_path = str(file_path) + ext - if _FILE_SYSTEM == 'default': - with bz2.BZ2File(xfile_path, 'rb') as cfile: - return pickle.load(cfile) # pytype: disable=wrong-arg-types - elif _FILE_SYSTEM == 'google': - import tensorflow as tf # pylint: disable=g-import-not-at-top - file_content = tf.io.gfile.GFile(file_path, mode='rb').read() - byte_content = bz2.decompress(file_content) - return pickle.loads(byte_content) - else: - raise ValueError('Unknown file system lib: {:}'.format(_FILE_SYSTEM)) - - -def remap_dataset_set_names(dataset, metric_on_set, verbose=False): - """Re-map the metric_on_set to internal keys.""" - if verbose: - print('Call internal function _remap_dataset_set_names with dataset={:} ' - 'and metric_on_set={:}'.format(dataset, metric_on_set)) - if dataset == 'cifar10' and metric_on_set == 'valid': - dataset, metric_on_set = 'cifar10-valid', 'x-valid' - elif dataset == 'cifar10' and metric_on_set == 'test': - dataset, metric_on_set = 'cifar10', 'ori-test' - elif dataset == 'cifar10' and metric_on_set == 'train': - dataset, metric_on_set = 'cifar10', 'train' - elif (dataset == 'cifar100' or - dataset == 'ImageNet16-120') and metric_on_set == 'valid': - metric_on_set = 'x-valid' - elif (dataset == 'cifar100' or - dataset == 'ImageNet16-120') and metric_on_set == 'test': - metric_on_set = 'x-test' - if verbose: - print(' return dataset={:} and metric_on_set={:}'.format( - dataset, metric_on_set)) - return dataset, metric_on_set - - -class NASBenchMetaAPI(metaclass=abc.ABCMeta): - """The abstract class for NATS Bench API.""" - - @abc.abstractmethod - def __init__(self, - file_path_or_dict: Optional[Union[Text, Dict[Text, Any]]] = None, - fast_mode: bool = False, - verbose: bool = True): - """The initialization function that takes the dataset file path (or a dict loaded from that path) as input.""" - # NOTE(xuanyidong): the following attributes must be initilaized in subclass - self.meta_archs = None - self.verbose = None - self.evaluated_indexes = None - self.arch2infos_dict = None - self.filename = None - self._fast_mode = None - self._archive_dir = None - self._avaliable_hps = None - self.archstr2index = None - - def __getitem__(self, index: int): - return copy.deepcopy(self.meta_archs[index]) - - def arch(self, index: int): - """Return the topology structure of the `index`-th architecture.""" - if self.verbose: - print('Call the arch function with index={:}'.format(index)) - if index < 0 or index >= len(self.meta_archs): - raise ValueError('invalid index : {:} vs. {:}.'.format( - index, len(self.meta_archs))) - return copy.deepcopy(self.meta_archs[index]) - - def __len__(self): - return len(self.meta_archs) - - def __repr__(self): - return ('{name}({num}/{total} architectures, fast_mode={fast_mode}, ' - 'file={filename})'.format( - name=self.__class__.__name__, - num=len(self.evaluated_indexes), total=len(self.meta_archs), - fast_mode=self.fast_mode, filename=self.filename)) - - @property - def avaliable_hps(self): - return list(copy.deepcopy(self._avaliable_hps)) - - @property - def used_time(self): - return self._used_time - - @property - def search_space_name(self): - return self._search_space_name - - @property - def fast_mode(self): - return self._fast_mode - - @property - def archive_dir(self): - return self._archive_dir - - @property - def full_train_epochs(self): - return self._full_train_epochs - - def reset_archive_dir(self, archive_dir): - self._archive_dir = archive_dir - - def reset_fast_mode(self, fast_mode): - self._fast_mode = fast_mode - - def reset_time(self): - self._used_time = 0 - - @abc.abstractmethod - def get_more_info(self, - index, - dataset, - iepoch=None, - hp: Text = '12', - is_random: bool = True): - """Return the metric for the `index`-th architecture.""" - - def simulate_train_eval(self, - arch, - dataset, - iepoch=None, - hp='12', - account_time=True): - """This function is used to simulate training and evaluating an arch.""" - index = self.query_index_by_arch(arch) - all_names = ('cifar10', 'cifar100', 'ImageNet16-120') - if dataset not in all_names: - raise ValueError('Invalid dataset name : {:} vs {:}'.format( - dataset, all_names)) - if dataset == 'cifar10': - info = self.get_more_info( - index, 'cifar10-valid', iepoch=iepoch, hp=hp, is_random=True) - else: - info = self.get_more_info( - index, dataset, iepoch=iepoch, hp=hp, is_random=True) - valid_acc, time_cost = info[ - 'valid-accuracy'], info['train-all-time'] + info['valid-per-time'] - latency = self.get_latency(index, dataset) - if account_time: - self._used_time += time_cost - return valid_acc, latency, time_cost, self._used_time - - def random(self): - """Return a random index of all architectures.""" - return random.randint(0, len(self.meta_archs)-1) - - def reload(self, archive_root: Text = None, index: int = None): - """Overwrite all information of the 'index'-th architecture in search space. - - Args: - archive_root: If archive_root is None, it will try to load from the - default path os.environ['TORCH_HOME'] / 'BASE_NAME'-full. - index: If index is None, overwrite all ckps. - """ - if self.verbose: - print('{:} Call clear_params with archive_root={:} and index={:}'.format( - time_string(), archive_root, index)) - if archive_root is None: - archive_root = os.path.join(os.environ['TORCH_HOME'], - '{:}-full'.format(self._all_base_names[-1])) - if not nats_is_dir(archive_root): - warnings.warn('The input archive_root is None and the default ' - 'archive_root path ({:}) does not exist, try to use ' - 'self.archive_dir.'.format(archive_root)) - archive_root = self.archive_dir - if archive_root is None or not nats_is_dir(archive_root): - raise ValueError('Invalid archive_root : {:}'.format(archive_root)) - if index is None: - indexes = list(range(len(self))) - else: - indexes = [index] - for idx in indexes: - if not (0 <= idx < len(self.meta_archs)): # pylint: disable=superfluous-parens - raise ValueError('invalid index of {:}'.format(idx)) - xfile_path = os.path.join(archive_root, - '{:06d}.{:}'.format(idx, PICKLE_EXT)) - if not nats_is_file(xfile_path): - xfile_path = os.path.join(archive_root, - '{:d}.{:}'.format(idx, PICKLE_EXT)) - assert nats_is_file(xfile_path), 'invalid data path : {:}'.format( - xfile_path) - xdata = pickle_load(xfile_path) - assert isinstance(xdata, dict), 'invalid format of data in {:}'.format( - xfile_path) - self.evaluated_indexes.add(idx) - hp2archres = collections.OrderedDict() - for hp_key, results in xdata.items(): - hp2archres[hp_key] = ArchResults.create_from_state_dict(results) - self._avaliable_hps.add(hp_key) - self.arch2infos_dict[idx] = hp2archres - - def query_index_by_arch(self, arch): - """Query the index of an architecture in the search space. - - Args: - arch: For topology search space, the input arch can be an architecture - string such as '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|'; # pylint: disable=line-too-long - or an instance that has the 'tostr' function that can - generate the architecture string; - or it is directly an architecture index, in this case, - we will check whether it is valid or not. - This function will return the index. - If return -1, it means this architecture is not in the search space. - Otherwise, it will return an intenger in - [0, the-number-of-candidates-in-the-search-space). - - Raises: - ValueError: If did not find the architecture in this benchmark. - - Returns: - The index of the architcture in this benchmark. - """ - if self.verbose: - print('{:} Call query_index_by_arch with arch={:}'.format( - time_string(), arch)) - if isinstance(arch, int): - if 0 <= arch < len(self): - return arch - else: - raise ValueError('Invalid architecture index {:} vs [{:}, {:}].'.format( - arch, 0, len(self))) - elif isinstance(arch, str): - if arch in self.archstr2index: - arch_index = self.archstr2index[arch] - else: - arch_index = -1 - elif hasattr(arch, 'tostr'): - if arch.tostr() in self.archstr2index: - arch_index = self.archstr2index[arch.tostr()] - else: - arch_index = -1 - else: - arch_index = -1 - return arch_index - - def query_by_arch(self, arch, hp): - """Make the current version be compatible with the old NAS-Bench-201 version.""" - return self.query_info_str_by_arch(arch, hp) - - def _prepare_info(self, index): - """This is a function to load the data from disk when using fast mode.""" - if index not in self.arch2infos_dict: - if self.fast_mode and self.archive_dir is not None: - self.reload(self.archive_dir, index) - elif not self.fast_mode: - if self.verbose: - print('{:} Call _prepare_info with index={:} skip because it is not' - 'the fast mode.'.format(time_string(), index)) - else: - raise ValueError('Invalid status: fast_mode={:} and ' - 'archive_dir={:}'.format( - self.fast_mode, self.archive_dir)) - else: - if index not in self.evaluated_indexes: - raise ValueError('The index of {:} is not in self.evaluated_indexes, ' - 'there must be something wrong.'.format(index)) - if self.verbose: - print('{:} Call _prepare_info with index={:} skip because it is in ' - 'arch2infos_dict'.format(time_string(), index)) - - def clear_params(self, index: int, hp: Optional[Text] = None): - """Remove the architecture's weights to save memory. - - Args: - index: the index of the target architecture - hp: a flag to controll how to clear the parameters. - -- None: clear all the weights in '01'/'12'/'90', which indicates - the number of training epochs. - -- '01' or '12' or '90': clear all the weights in - arch2infos_dict[index][hp]. - """ - if self.verbose: - print('{:} Call clear_params with index={:} and hp={:}'.format( - time_string(), index, hp)) - if index not in self.arch2infos_dict: - warnings.warn('The {:}-th architecture is not in the benchmark data yet, ' - 'no need to clear params.'.format(index)) - elif hp is None: - for key, result in self.arch2infos_dict[index].items(): - result.clear_params() - else: - if str(hp) not in self.arch2infos_dict[index]: - raise ValueError('The {:}-th architecture only has hyper-parameters ' - 'of {:} instead of {:}.'.format( - index, list(self.arch2infos_dict[index].keys()), - hp)) - self.arch2infos_dict[index][str(hp)].clear_params() - - @abc.abstractmethod - def query_info_str_by_arch(self, arch, hp: Text = '12'): - """This function is used to query the information of a specific architecture.""" - - def _query_info_str_by_arch(self, - arch, - hp: Text = '12', - print_information=None): - """Internal function to query the information of `arch` when using `hp`.""" - arch_index = self.query_index_by_arch(arch) - self._prepare_info(arch_index) - if arch_index in self.arch2infos_dict: - if hp not in self.arch2infos_dict[arch_index]: - raise ValueError('The {:}-th architecture only has hyper-parameters of ' - '{:} instead of {:}.'.format( - arch_index, - list(self.arch2infos_dict[arch_index].keys()), hp)) - info = self.arch2infos_dict[arch_index][hp] - strings = print_information(info, 'arch-index={:}'.format(arch_index)) - return '\n'.join(strings) - else: - warnings.warn('Find this arch-index : {:}, but this arch is not ' - 'evaluated.'.format(arch_index)) - return None - - def query_meta_info_by_index(self, arch_index, hp: Text = '12'): - """Return ArchResults for the 'arch_index'-th architecture.""" - if self.verbose: - print('Call query_meta_info_by_index with arch_index={:}, hp={:}'.format( - arch_index, hp)) - self._prepare_info(arch_index) - if arch_index in self.arch2infos_dict: - if str(hp) not in self.arch2infos_dict[arch_index]: - raise ValueError('The {:}-th architecture only has hyper-parameters of ' - '{:} instead of {:}.'.format( - arch_index, - list(self.arch2infos_dict[arch_index].keys()), - hp)) - info = self.arch2infos_dict[arch_index][str(hp)] - else: - raise ValueError('arch_index [{:}] does not in arch2infos'.format( - arch_index)) - return copy.deepcopy(info) - - def query_by_index(self, - arch_index: int, - dataname: Union[None, Text] = None, - hp: Text = '12'): - """Query the information with the training of 01/12/90/200 epochs. - - Args: - arch_index: The architecture index in this benchmark. - dataname: If dataname is None, return the ArchResults; otherwise, we will - return a dict with all trials on that dataset - (the key is the seed). - Options are 'cifar10-valid', 'cifar10', 'cifar100', - and 'ImageNet16-120'. - -- cifar10-valid : train the model on CIFAR-10 training set. - -- cifar10 : train the model on CIFAR-10 training + validation set. - -- cifar100 : train the model on CIFAR-100 training set. - -- ImageNet16-120 : train the model on ImageNet16-120 training set. - hp: The hyperparameters. - If hp=01, we train the model by 01 epochs. - If hp=12, we train the model by 01 epochs. - If hp=90, we train the model by 01 epochs. - If hp=200, we train the model by 01 epochs. - See github.com/D-X-Y/AutoDL-Projects/configs/nas-benchmark/hyper-opts - for more details. - - Raises: - ValueError: If not find the matched serach space description. - - Returns: - An instance fo ArchResults. - """ - if self.verbose: - print('{:} Call query_by_index with arch_index={:}, dataname={:}, ' - 'hp={:}'.format(time_string(), arch_index, dataname, hp)) - info = self.query_meta_info_by_index(arch_index, str(hp)) - if dataname is None: - return info - else: - if dataname not in info.get_dataset_names(): - raise ValueError('invalid dataset-name : {:} vs. {:}'.format( - dataname, info.get_dataset_names())) - return info.query(dataname) - - def find_best(self, - dataset, - metric_on_set, - flop_max=None, - param_max=None, - hp: Text = '12'): - """Find the architecture with the highest accuracy based on some constraints.""" - if self.verbose: - print('{:} Call find_best with dataset={:}, metric_on_set={:}, hp={:} ' - '| with #FLOPs < {:} and #Params < {:}'.format( - time_string(), dataset, metric_on_set, hp, flop_max, param_max)) - dataset, metric_on_set = remap_dataset_set_names( - dataset, metric_on_set, self.verbose) - best_index, highest_accuracy = -1, None - evaluated_indexes = sorted(list(self.evaluated_indexes)) - for arch_index in evaluated_indexes: - self._prepare_info(arch_index) - arch_info = self.arch2infos_dict[arch_index][hp] - info = arch_info.get_compute_costs(dataset) # the information of costs - flop, param, latency = info['flops'], info['params'], info['latency'] - if flop_max is not None and flop > flop_max: - continue - if param_max is not None and param > param_max: - continue - xinfo = arch_info.get_metrics( - dataset, metric_on_set) # the information of loss and accuracy - loss, accuracy = xinfo['loss'], xinfo['accuracy'] - if best_index == -1: - best_index, highest_accuracy = arch_index, accuracy - elif highest_accuracy < accuracy: - best_index, highest_accuracy = arch_index, accuracy - del latency, loss - if self.verbose: - print(' the best architecture : [{:}] {:} with accuracy={:.3f}%'.format( - best_index, self.arch(best_index), highest_accuracy)) - return best_index, highest_accuracy - - def get_net_param(self, index, dataset, seed: Optional[int], hp: Text = '12'): - """Obtain the trained weights of the `index`-th arch on `dataset`. - - Args: - index: The architecture index. - dataset: The training dataset name. - seed: - -- None : return a dict containing the trained weights of all trials, - where each key is a seed and its corresponding value - is the weights. - -- Interger : return the weights of a specific trial, whose seed - is this interger. - hp: - -- 01 : train the model by 01 epochs - -- 12 : train the model by 12 epochs - -- 90 : train the model by 90 epochs - -- 200 : train the model by 200 epochs - Returns: - PyTorch weights. - """ - if self.verbose: - print('{:} Call the get_net_param function with index={:}, dataset={:}, ' - 'seed={:}, hp={:}'.format(time_string(), index, dataset, seed, hp)) - info = self.query_meta_info_by_index(index, hp) - return info.get_net_param(dataset, seed) - - def get_net_config(self, index: int, dataset: Text): - """Obtain the configuration for the `index`-th architecture on `dataset`. - - Args: - index: The architecture index. - dataset: 4 possible options as follows, - -- cifar10-valid : train the model on the CIFAR-10 training set. - -- cifar10 : train the model on the CIFAR-10 training + validation set. - -- cifar100 : train the model on the CIFAR-100 training set. - -- ImageNet16-120 : train the model on the ImageNet16-120 training set. - Returns: - A dict. - - Note: some examlpes for using this function: - config = api.get_net_config(128, 'cifar10') - """ - if self.verbose: - print('{:} Call the get_net_config function with index={:}, ' - 'dataset={:}.'.format(time_string(), index, dataset)) - self._prepare_info(index) - if index in self.arch2infos_dict: - info = self.arch2infos_dict[index] - else: - raise ValueError( - 'The arch_index={:} is not in arch2infos_dict.'.format(index)) - info = next(iter(info.values())) - results = info.query(dataset, None) - results = next(iter(results.values())) - return results.get_config(None) - - def get_cost_info(self, - index: int, - dataset: Text, - hp: Text = '12') -> Dict[Text, float]: - """To obtain the cost metric for the `index`-th architecture on a dataset.""" - if self.verbose: - print('{:} Call the get_cost_info function with index={:}, ' - 'dataset={:}, and hp={:}.'.format( - time_string(), index, dataset, hp)) - self._prepare_info(index) - info = self.query_meta_info_by_index(index, hp) - return info.get_compute_costs(dataset) - - def get_latency(self, index: int, dataset: Text, hp: Text = '12') -> float: - """Obtain the latency of the network. - - Note: by default it will return the latency with the batch size of 256. - Args: - index: the index of the target architecture - dataset: the dataset name (cifar10-valid, cifar10, cifar100, - and ImageNet16-120) - hp: the hyperparamete indicator. - - Returns: - return a float value in seconds - """ - if self.verbose: - print('{:} Call the get_latency function with index={:}, ' - 'dataset={:}, and hp={:}.'.format( - time_string(), index, dataset, hp)) - cost_dict = self.get_cost_info(index, dataset, hp) - return cost_dict['latency'] - - @abc.abstractmethod - def show(self, index=-1): - """This function will print the information of a specific (or all) architecture(s).""" - - def _show(self, index=-1, print_information=None) -> None: - """Print the information of a specific (or all) architecture(s). - - Args: - index: If the index < 0: it will loop for all architectures and print - their information one by one. Else: it will print the information - of the 'index'-th architecture. - - print_information: A function to print result. - - Returns: None - """ - if index < 0: # show all architectures - print(self) - evaluated_indexes = sorted(list(self.evaluated_indexes)) - for i, idx in enumerate(evaluated_indexes): - print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th ' - 'architecture! '.format(i, len(evaluated_indexes), idx) + '-'*10) - print('arch : {:}'.format(self.meta_archs[idx])) - for unused_key, result in self.arch2infos_dict[index].items(): - strings = print_information(result) - print('>' * 40 + ' {:03d} epochs '.format( - result.get_total_epoch()) + '>' * 40) - print('\n'.join(strings)) - print('<' * 40 + '------------' + '<' * 40) - else: - if 0 <= index < len(self.meta_archs): - if index not in self.evaluated_indexes: - self._prepare_info(index) - if index not in self.evaluated_indexes: - print('The {:}-th architecture has not been evaluated ' - 'or not saved.'.format(index)) - else: - # arch_info = self.arch2infos_dict[index] - for unused_key, result in self.arch2infos_dict[index].items(): - strings = print_information(result) - print('>' * 40 + ' {:03d} epochs '.format( - result.get_total_epoch()) + '>' * 40) - print('\n'.join(strings)) - print('<' * 40 + '------------' + '<' * 40) - else: - print('This index ({:}) is out of range (0~{:}).'.format( - index, len(self.meta_archs))) - - def statistics(self, dataset: Text, hp: Union[Text, int]) -> Dict[int, int]: - """This function will count the number of total trials.""" - if self.verbose: - print('Call the statistics function with dataset={:} and hp={:}.'.format( - dataset, hp)) - valid_datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'] - if dataset not in valid_datasets: - raise ValueError('{:} not in {:}'.format(dataset, valid_datasets)) - nums, hp = collections.defaultdict(lambda: 0), str(hp) - # for index in range(len(self)): - for index in self.evaluated_indexes: - arch_info = self.arch2infos_dict[index][hp] - dataset_seed = arch_info.dataset_seed - if dataset not in dataset_seed: - nums[0] += 1 - else: - nums[len(dataset_seed[dataset])] += 1 - return dict(nums) - - -class ArchResults(object): - """A class to maintain the results of an architecture under different settings.""" - - def __init__(self, arch_index, arch_str): - self.arch_index = int(arch_index) - self.arch_str = copy.deepcopy(arch_str) - self.all_results = dict() - self.dataset_seed = dict() - self.clear_net_done = False - - def get_compute_costs(self, dataset): - """Return the computation cost on the input dataset.""" - x_seeds = self.dataset_seed[dataset] - results = [self.all_results[(dataset, seed)] for seed in x_seeds] - - flops = [result.flop for result in results] - params = [result.params for result in results] - latencies = [result.get_latency() for result in results] - latencies = [x for x in latencies if x > 0] - mean_latency = np.mean(latencies) if len(latencies) else None - time_infos = collections.defaultdict(list) - for result in results: - time_info = result.get_times() - for key, value in time_info.items(): - time_infos[key].append(value) - - info = { - 'flops': np.mean(flops), - 'params': np.mean(params), - 'latency': mean_latency - } - for key, value in time_infos.items(): - if len(value) and value[0] is not None: - info[key] = np.mean(value) - else: - info[key] = None - return info - - def get_metrics(self, dataset, setname, iepoch=None, is_random=False): - """Obtain the loss, accuracy, etc information on a specific dataset. - - If not specify, each set refer to the proposed split in NAS-Bench-201. - If some args return None or raise error, then it is not avaliable. - ======================================== - - Args: - dataset: 4 possible options as follows - -- cifar10-valid : train the model on the CIFAR-10 training set. - -- cifar10 : train the model on the CIFAR-10 training + validation set. - -- cifar100 : train the model on the CIFAR-100 training set. - -- ImageNet16-120 : train the model on the ImageNet16-120 training set. - setname: each dataset has different setnames - -- When dataset = cifar10-valid, you can use 'train', - 'x-valid', and 'ori-test' - ------ 'train' : the metric on the training set. - ------ 'x-valid' : the metric on the validation set. - ------ 'ori-test' : the metric on the test set. - -- When dataset = cifar10, you can use 'train', 'ori-test'. - ------ 'train' : the metric on the training + validation set. - ------ 'ori-test' : the metric on the test set. - -- When dataset = cifar100 or ImageNet16-120, you can use 'train', - 'ori-test', 'x-valid', and 'x-test' - ------ 'train' : the metric on the training set. - ------ 'x-valid' : the metric on the validation set. - ------ 'x-test' : the metric on the test set. - ------ 'ori-test' : the metric on the validation + test set. - iepoch: (None or an integer in [0, the-number-of-total-training-epochs) - ------ None : return the metric after the last training epoch. - ------ an integer i : return the metric after the i-th training epoch. - is_random: - ------ True : return the metric of a randomly selected trial. - ------ False : return the averaged metric of all avaliable trials. - ------ an integer indicating the 'seed' value : return the metric of a - specific trial (whose random seed is 'is_random'). - - Returns: - All the metrics given the input setting. - """ - x_seeds = self.dataset_seed[dataset] - results = [self.all_results[(dataset, seed)] for seed in x_seeds] - infos = collections.defaultdict(list) - for result in results: - if setname == 'train': - info = result.get_train(iepoch) - else: - info = result.get_eval(setname, iepoch) - for key, value in info.items(): - infos[key].append(value) - return_info = dict() - if isinstance(is_random, bool) and is_random: # randomly select one - index = random.randint(0, len(results)-1) - for key, value in infos.items(): - return_info[key] = value[index] - elif isinstance(is_random, bool) and not is_random: # average - for key, value in infos.items(): - if len(value) and value[0] is not None: - return_info[key] = np.mean(value) - else: - return_info[key] = None - elif isinstance(is_random, int): # specify the seed - if is_random not in x_seeds: - raise ValueError('can not find random seed ({:}) from {:}'.format( - is_random, x_seeds)) - index = x_seeds.index(is_random) - for key, value in infos.items(): - return_info[key] = value[index] - else: - raise ValueError('invalid value for is_random: {:}'.format(is_random)) - return return_info - - # def show(self, is_print=False): - # return print_information(self, None, is_print) - - def get_dataset_names(self): - return list(self.dataset_seed.keys()) - - def get_dataset_seeds(self, dataset): - return copy.deepcopy(self.dataset_seed[dataset]) - - def get_net_param(self, dataset: Text, seed: Union[None, int] = None): - """Return the trained network's weights on the 'dataset'. - - Args: - dataset: 'cifar10-valid', 'cifar10', 'cifar100', or 'ImageNet16-120'. - seed: an integer indicates the seed value - or None that indicates returing all trials. - - Returns: - The trained weights (parameters). - """ - if seed is None: - x_seeds = self.dataset_seed[dataset] - return { - seed: self.all_results[(dataset, seed)].get_net_param() - for seed in x_seeds - } - else: - xkey = (dataset, seed) - if xkey in self.all_results: - return self.all_results[xkey].get_net_param() - else: - raise ValueError('key={:} not in {:}'.format( - xkey, list(self.all_results.keys()))) - - def reset_latency(self, dataset: Text, seed: Union[None, Text], - latency: float) -> None: - """This function is used to reset the latency in all corresponding ResultsCount(s).""" - if seed is None: - for seed in self.dataset_seed[dataset]: - self.all_results[(dataset, seed)].update_latency([latency]) - else: - self.all_results[(dataset, seed)].update_latency([latency]) - - def reset_pseudo_train_times(self, dataset: Text, seed: Union[None, Text], - estimated_per_epoch_time: float) -> None: - """This function is used to reset the train-times in all corresponding ResultsCount(s).""" - if seed is None: - for seed in self.dataset_seed[dataset]: - self.all_results[( - dataset, seed)].reset_pseudo_train_times(estimated_per_epoch_time) - else: - self.all_results[( - dataset, seed)].reset_pseudo_train_times(estimated_per_epoch_time) - - def reset_pseudo_eval_times(self, dataset: Text, seed: Union[None, Text], - eval_name: Text, - estimated_per_epoch_time: float) -> None: - """This function is used to reset the eval-times in all corresponding ResultsCount(s).""" - if seed is None: - for seed in self.dataset_seed[dataset]: - self.all_results[(dataset, seed)].reset_pseudo_eval_times( - eval_name, estimated_per_epoch_time) - else: - self.all_results[(dataset, seed)].reset_pseudo_eval_times( - eval_name, estimated_per_epoch_time) - - def get_latency(self, dataset: Text) -> float: - """Get the latency of a model on the target dataset.""" - latencies = [] - for seed in self.dataset_seed[dataset]: - latency = self.all_results[(dataset, seed)].get_latency() - if not isinstance(latency, float) or latency <= 0: - raise ValueError('invalid latency of {:} with seed={:} : {:}'.format( - dataset, seed, latency)) - latencies.append(latency) - return sum(latencies) / len(latencies) - - def get_total_epoch(self, dataset=None): - """Return the total number of training epochs.""" - if dataset is None: - epochss = [] - for xdata, x_seeds in self.dataset_seed.items(): - epochss += [ - self.all_results[(xdata, seed)].get_total_epoch() - for seed in x_seeds - ] - elif isinstance(dataset, str): - x_seeds = self.dataset_seed[dataset] - epochss = [ - self.all_results[(dataset, seed)].get_total_epoch() - for seed in x_seeds - ] - else: - raise ValueError('invalid dataset={:}'.format(dataset)) - if len(set(epochss)) > 1: - raise ValueError( - 'Each trial mush have the same number of training epochs : {:}' - .format(epochss)) - return epochss[-1] - - def query(self, dataset, seed=None): - """Return the ResultsCount object (containing all information of a single trial) for 'dataset' and 'seed'.""" - if seed is None: - x_seeds = self.dataset_seed[dataset] - return {seed: self.all_results[(dataset, seed)] for seed in x_seeds} - else: - return self.all_results[(dataset, seed)] - - def arch_idx_str(self): - return '{:06d}'.format(self.arch_index) - - def update(self, dataset_name, seed, result): - """Update the result for the given dataset and seed.""" - if dataset_name not in self.dataset_seed: - self.dataset_seed[dataset_name] = [] - if seed in self.dataset_seed[dataset_name]: - raise ValueError('{:}-th arch alreadly has this seed ({:}) on {:}'.format( - self.arch_index, seed, dataset_name)) - self.dataset_seed[dataset_name].append(seed) - self.dataset_seed[dataset_name] = sorted(self.dataset_seed[dataset_name]) - assert (dataset_name, seed) not in self.all_results - self.all_results[(dataset_name, seed)] = result - self.clear_net_done = False - - def state_dict(self): - """Return a dict that can be used to re-create this instance.""" - state_dict = dict() - for key, value in self.__dict__.items(): - if key == 'all_results': # contain the class of ResultsCount - xvalue = dict() - if not isinstance(value, dict): - raise ValueError('invalid type of value for {:} : {:}'.format( - key, type(value))) - for cur_k, cur_v in value.items(): - if not isinstance(cur_v, ResultsCount): - raise ValueError('invalid type of value for {:}/{:} : {:}'.format( - key, cur_k, type(cur_v))) - xvalue[cur_k] = cur_v.state_dict() - else: - xvalue = value - state_dict[key] = xvalue - return state_dict - - def load_state_dict(self, state_dict): - """Update self based on the input dict.""" - new_state_dict = dict() - for key, value in state_dict.items(): - if key == 'all_results': # To convert to the class of ResultsCount - xvalue = dict() - if not isinstance(value, dict): - raise ValueError('invalid type of value for {:} : {:}'.format( - key, type(value))) - for cur_k, cur_v in value.items(): - xvalue[cur_k] = ResultsCount.create_from_state_dict(cur_v) - else: xvalue = value - new_state_dict[key] = xvalue - self.__dict__.update(new_state_dict) - - @staticmethod - def create_from_state_dict(state_dict_or_file): - """Create the ArchResults instance from a dict or a file.""" - x = ArchResults(-1, -1) - if isinstance(state_dict_or_file, str): # a file path - state_dict = pickle_load(state_dict_or_file) - elif isinstance(state_dict_or_file, dict): - state_dict = state_dict_or_file - else: - raise ValueError('invalid type of state_dict_or_file : {:}'.format( - type(state_dict_or_file))) - x.load_state_dict(state_dict) - return x - - def clear_params(self): - """Clear the weights saved in each 'result'.""" - # NOTE(xuanyidong): This can help reduce the memory footprint. - for unused_key, result in self.all_results.items(): - del result.net_state_dict - result.net_state_dict = None - self.clear_net_done = True - - def debug_test(self): - """Help debug and test, which will call most methods.""" - all_dataset = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'] - for dataset in all_dataset: - print('---->>>> {:}'.format(dataset)) - print('The latency on {:} is {:} s'.format( - dataset, self.get_latency(dataset))) - for seed in self.dataset_seed[dataset]: - result = self.all_results[(dataset, seed)] - print(' ==>> result = {:}'.format(result)) - print(' ==>> cost = {:}'.format(result.get_times())) - - def __repr__(self): - return ('{name}(arch-index={index}, arch={arch}, ' - '{num} runs, clear={clear})'.format( - name=self.__class__.__name__, - index=self.arch_index, - arch=self.arch_str, - num=len(self.all_results), - clear=self.clear_net_done)) - - -class ResultsCount(object): - """ResultsCount is to save the information of one trial for a single architecture.""" - - def __init__(self, name, state_dict, train_accs, train_losses, params, flop, - arch_config, seed, epochs, latency): - self.name = name - self.net_state_dict = state_dict - self.train_acc1es = copy.deepcopy(train_accs) - self.train_acc5es = None - self.train_losses = copy.deepcopy(train_losses) - self.train_times = None - self.arch_config = copy.deepcopy(arch_config) - self.params = params - self.flop = flop - self.seed = seed - self.epochs = epochs - self.latency = latency - # evaluation results - self.reset_eval() - - def update_train_info(self, train_acc1es, train_acc5es, train_losses, - train_times) -> None: - self.train_acc1es = train_acc1es - self.train_acc5es = train_acc5es - self.train_losses = train_losses - self.train_times = train_times - - def reset_pseudo_train_times(self, estimated_per_epoch_time: float) -> None: - """Assign the training times.""" - train_times = collections.OrderedDict() - for i in range(self.epochs): - train_times[i] = estimated_per_epoch_time - self.train_times = train_times - - def reset_pseudo_eval_times( - self, eval_name: Text, estimated_per_epoch_time: float) -> None: - """Assign the evaluation times.""" - if eval_name not in self.eval_names: - raise ValueError('invalid eval name : {:}'.format(eval_name)) - for i in range(self.epochs): - self.eval_times['{:}@{:}'.format(eval_name, i)] = estimated_per_epoch_time - - def reset_eval(self): - self.eval_names = [] - self.eval_acc1es = {} - self.eval_times = {} - self.eval_losses = {} - - def update_latency(self, latency): - self.latency = copy.deepcopy(latency) - - def get_latency(self) -> float: - """Return the latency value in seconds.""" - # NOTE(xuanyidong): -1 represents not avaliable, - # NOTE(xuanyidong): otherwise it should be a float value. - if self.latency is None: - return -1.0 - else: - return sum(self.latency) / len(self.latency) - - def update_eval(self, accs, losses, times): - """To update the evaluataion results.""" - data_names = set([x.split('@')[0] for x in accs.keys()]) - for data_name in data_names: - if data_name in self.eval_names: - raise ValueError('{:} has already been added into ' - 'eval-names'.format(data_name)) - self.eval_names.append(data_name) - for iepoch in range(self.epochs): - xkey = '{:}@{:}'.format(data_name, iepoch) - self.eval_acc1es[xkey] = accs[xkey] - self.eval_losses[xkey] = losses[xkey] - self.eval_times[xkey] = times[xkey] - - def update_OLD_eval(self, name, accs, losses): # pylint: disable=invalid-name - """To update the evaluataion results (old NAS-Bench-201 version).""" - assert name not in self.eval_names, '{:} has already added'.format(name) - self.eval_names.append(name) - for iepoch in range(self.epochs): - if iepoch in accs: - self.eval_acc1es['{:}@{:}'.format(name, iepoch)] = accs[iepoch] - self.eval_losses['{:}@{:}'.format(name, iepoch)] = losses[iepoch] - - def __repr__(self): - num_eval = len(self.eval_names) - set_name = '[' + ', '.join(self.eval_names) + ']' - return ('{name}({xname}, arch={arch}, FLOP={flop:.2f}M, ' - 'Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets: ' - '{set_name})'.format(name=self.__class__.__name__, xname=self.name, - arch=self.arch_config['arch_str'], - flop=self.flop, param=self.params, - seed=self.seed, num_eval=num_eval, - set_name=set_name)) - - def get_total_epoch(self): - return copy.deepcopy(self.epochs) - - def get_times(self): - """Obtain the information regarding both training and evaluation time.""" - if self.train_times is not None and isinstance(self.train_times, dict): - train_times = list(self.train_times.values()) - time_info = { - 'T-train@epoch': np.mean(train_times), - 'T-train@total': np.sum(train_times) - } - else: - time_info = {'T-train@epoch': None, 'T-train@total': None} - for name in self.eval_names: - try: - xtimes = [ - self.eval_times['{:}@{:}'.format(name, i)] - for i in range(self.epochs) - ] - time_info['T-{:}@epoch'.format(name)] = np.mean(xtimes) - time_info['T-{:}@total'.format(name)] = np.sum(xtimes) - except Exception as unused_e: # pylint: disable=broad-except - time_info['T-{:}@epoch'.format(name)] = None - time_info['T-{:}@total'.format(name)] = None - return time_info - - def get_eval_set(self): - return self.eval_names - - def judge_valid(self, iepoch): - if iepoch < 0 or iepoch >= self.epochs: - raise ValueError('invalid iepoch={:} < {:}'.format(iepoch, self.epochs)) - - def get_train(self, iepoch=None): - """Get the training information.""" - if iepoch is None: iepoch = self.epochs-1 - self.judge_valid(iepoch) - if self.train_times is not None: - xtime = self.train_times[iepoch] - atime = sum([self.train_times[i] for i in range(iepoch+1)]) - else: - xtime, atime = None, None - return { - 'iepoch': iepoch, - 'loss': self.train_losses[iepoch], - 'accuracy': self.train_acc1es[iepoch], - 'cur_time': xtime, - 'all_time': atime - } - - def get_eval(self, name, iepoch=None): - """Get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument).""" - if iepoch is None: - iepoch = self.epochs-1 - self.judge_valid(iepoch) - - def _internal_query(xname): - if isinstance(self.eval_times, dict) and len(self.eval_times): - xtime = self.eval_times['{:}@{:}'.format(xname, iepoch)] - atime = sum([ - self.eval_times['{:}@{:}'.format(xname, i)] - for i in range(iepoch + 1) - ]) - else: - xtime, atime = None, None - return { - 'iepoch': iepoch, - 'loss': self.eval_losses['{:}@{:}'.format(xname, iepoch)], - 'accuracy': self.eval_acc1es['{:}@{:}'.format(xname, iepoch)], - 'cur_time': xtime, - 'all_time': atime - } - - if name == 'valid': - return _internal_query('x-valid') - else: - return _internal_query(name) - - def get_net_param(self, clone=False): - if clone: - return copy.deepcopy(self.net_state_dict) - else: - return self.net_state_dict - - def get_config(self, str2structure): - """This function is used to obtain the config dict for this architecture.""" - if str2structure is None: - # In this case, this is an arch in size search space of NATS-BENCH. - if 'name' in self.arch_config and self.arch_config[ - 'name'] == 'infer.shape.tiny': - return { - 'name': 'infer.shape.tiny', - 'channels': self.arch_config['channels'], - 'genotype': self.arch_config['genotype'], - 'num_classes': self.arch_config['class_num'] - } - else: # This is an arch in NATS-BENCH's topology search space. - return { - 'name': 'infer.tiny', - 'C': self.arch_config['channel'], - 'N': self.arch_config['num_cells'], - 'arch_str': self.arch_config['arch_str'], - 'num_classes': self.arch_config['class_num'] - } - else: # This is an arch in the size search space of NATS-BENCH. - if 'name' in self.arch_config and self.arch_config[ - 'name'] == 'infer.shape.tiny': - return { - 'name': 'infer.shape.tiny', - 'channels': self.arch_config['channels'], - 'genotype': str2structure(self.arch_config['genotype']), - 'num_classes': self.arch_config['class_num'] - } - else: # This is an arch in the topology search space of NATS-BENCH. - return { - 'name': 'infer.tiny', - 'C': self.arch_config['channel'], - 'N': self.arch_config['num_cells'], - 'genotype': str2structure(self.arch_config['arch_str']), - 'num_classes': self.arch_config['class_num'] - } - - def state_dict(self): - collected_state_dict = {key: value for key, value in self.__dict__.items()} - return collected_state_dict - - def load_state_dict(self, state_dict): - self.__dict__.update(state_dict) - - @staticmethod - def create_from_state_dict(state_dict): - x = ResultsCount(None, None, None, None, None, None, None, None, None, None) - x.load_state_dict(state_dict) - return x diff --git a/lib/spaces/__init__.py b/lib/spaces/__init__.py new file mode 100644 index 0000000..792d600 --- /dev/null +++ b/lib/spaces/__init__.py @@ -0,0 +1 @@ +#