Upgrade NAS-Bench-201 to APIv1.3/FILEv1.1
@@ -5,4 +5,5 @@ from .api import NASBench201API
from .api import ArchResults, ResultsCount

# NAS_BENCH_201_API_VERSION="v1.1"  # [2020.02.25]
NAS_BENCH_201_API_VERSION="v1.2"  # [2020.03.09]
# NAS_BENCH_201_API_VERSION="v1.2"  # [2020.03.09]
NAS_BENCH_201_API_VERSION="v1.3"  # [2020.03.16]

@@ -3,11 +3,14 @@
############################################################################################
# NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 #
############################################################################################
# The history of benchmark files:
# [2020.02.25] NAS-Bench-201-v1_0-e61699.pth : 6219 architectures are trained once, 1621 architectures are trained twice, 7785 architectures are trained three times. `LESS` only supports CIFAR10-VALID.
# [2020.03.08] Next version (coming soon)
# [2020.03.16] NAS-Bench-201-v1_1-096897.pth : 2225 architectures are trained once, 5439 architectures are trained twice, 7961 architectures are trained three times on all training sets. For the hyper-parameters with the total epochs of 12, each model is trained on CIFAR-10, CIFAR-100, ImageNet16-120 once, and is trained on CIFAR-10-VALID twice.
#
# I'm still actively enhancing this benchmark. Please feel free to contact me if you have any questions w.r.t. NAS-Bench-201.
#
import os, copy, random, torch, numpy as np
from pathlib import Path
from typing import List, Text, Union, Dict
from collections import OrderedDict, defaultdict

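A minimal usage sketch for loading the upgraded benchmark file; the module name nas_201_api and the file location are assumptions, while the .pth file name comes from the history above:

from nas_201_api import NASBench201API  # module name assumed from this package's layout
# build the API object from the v1.1 benchmark file (path is illustrative)
api = NASBench201API('NAS-Bench-201-v1_1-096897.pth')
print(len(api))  # number of architectures in the search space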
@@ -44,9 +47,12 @@ class NASBench201API(object):

  """ The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
  def __init__(self, file_path_or_dict: Union[Text, Dict], verbose: bool=True):
    if isinstance(file_path_or_dict, str):
    self.filename = None
    if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
      file_path_or_dict = str(file_path_or_dict)
      if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict))
      assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
      self.filename = Path(file_path_or_dict).name
      file_path_or_dict = torch.load(file_path_or_dict)
    elif isinstance(file_path_or_dict, dict):
      file_path_or_dict = copy.deepcopy( file_path_or_dict )
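Since the constructor now also accepts a pathlib.Path and remembers the file name, a brief sketch (directory and file location are illustrative):

from pathlib import Path
from nas_201_api import NASBench201API  # module name assumed as above
api = NASBench201API(Path('.') / 'NAS-Bench-201-v1_1-096897.pth', verbose=False)
print(api)  # the repr below now also reports file=NAS-Bench-201-v1_1-096897.pth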
@@ -76,7 +82,7 @@ class NASBench201API(object):
    return len(self.meta_archs)

  def __repr__(self):
    return ('{name}({num}/{total} architectures)'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs)))
    return ('{name}({num}/{total} architectures, file={filename})'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs), filename=self.filename))

  def random(self):
    """Return a random index of all architectures."""
@@ -98,9 +104,10 @@ class NASBench201API(object):
    else: arch_index = -1
    return arch_index

  # Overwrite all information of the 'index'-th architecture in the search space.
  # It will load its data from 'archive_root'.
  def reload(self, archive_root: Text, index: int):
    """Overwrite all information of the 'index'-th architecture in the search space.
         It will load its data from 'archive_root'.
    """
    assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
    xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index))
    assert 0 <= index < len(self.meta_archs), 'invalid index of {:}'.format(index)
@@ -109,6 +116,13 @@ class NASBench201API(object):
    assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path)
    self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] )
    self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] )
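A hedged sketch of reload; it assumes the per-architecture result files ({:06d}-FULL.pth) were unpacked into a local directory whose name here is hypothetical:

# refresh the 6111-th architecture (index is illustrative) from the unpacked archive
api.reload('./NAS-Bench-201-archive', 6111)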

  def clear_params(self, index: int, use_12epochs_result: bool):
    """Remove the architecture's weights to save memory."""
    if use_12epochs_result: arch2infos = self.arch2infos_less
    else                  : arch2infos = self.arch2infos_full
    archresult = arch2infos[index]
    archresult.clear_params()

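The new clear_params helper can then release the stored weights once they are no longer needed; a small sketch (api is the NASBench201API instance from the earlier sketch):

# drop the trained weights of architecture 6111 kept for the 200-epoch results
api.clear_params(6111, use_12epochs_result=False)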
  # This function is used to query the information of a specific architecture
  # 'arch' can be an architecture index or an architecture string
@@ -162,6 +176,7 @@ class NASBench201API(object):
    return archInfo

  def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None, use_12epochs_result=False):
    """Find the architecture with the highest accuracy based on some constraints."""
    if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
    else                  : basestr, arch2infos = '200epochs', self.arch2infos_full
    best_index, highest_accuracy = -1, None
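A usage sketch for find_best; the split name 'x-valid' and the returned (index, accuracy) pair are assumptions based on the splits used in get_more_info below and the best_index/highest_accuracy variables above:

# architecture with the highest validation accuracy on CIFAR-10 (200-epoch results)
best_index, best_accuracy = api.find_best('cifar10-valid', 'x-valid', use_12epochs_result=False)
print(best_index, best_accuracy)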
@@ -255,6 +270,65 @@ class NASBench201API(object):
  # `is_random`
  #   When is_random=True, the performance of a random architecture will be returned
  #   When is_random=False, the performance of all trials will be averaged.
  def get_more_info(self, index: int, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
    if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
    else                  : basestr, arch2infos = '200epochs', self.arch2infos_full
    archresult = arch2infos[index]
    # if a random trial is requested, pick one of the available seeds first
    if isinstance(is_random, bool) and is_random:
      seeds = archresult.get_dataset_seeds(dataset)
      is_random = random.choice(seeds)
    # collect the training information
    train_info = archresult.get_metrics(dataset, 'train', iepoch=iepoch, is_random=is_random)
    total = train_info['iepoch'] + 1
    xinfo = {'train-loss'    : train_info['loss'],
             'train-accuracy': train_info['accuracy'],
             'train-per-time': train_info['all_time'] / total,
             'train-all-time': train_info['all_time']}
    # collect the evaluation information
    if dataset == 'cifar10-valid':
      valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
      try:
        test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
      except:
        test_info = None
      valtest_info = None
    else:
      try: # collect results on the proposed test set
        if dataset == 'cifar10':
          test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
        else:
          test_info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
      except:
        test_info = None
      try: # collect results on the proposed validation set
        valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
      except:
        valid_info = None
      try:
        if dataset != 'cifar10':
          valtest_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
        else:
          valtest_info = None
      except:
        valtest_info = None
    if valid_info is not None:
      xinfo['valid-loss'] = valid_info['loss']
      xinfo['valid-accuracy'] = valid_info['accuracy']
      xinfo['valid-per-time'] = valid_info['all_time'] / total
      xinfo['valid-all-time'] = valid_info['all_time']
    if test_info is not None:
      xinfo['test-loss'] = test_info['loss']
      xinfo['test-accuracy'] = test_info['accuracy']
      xinfo['test-per-time'] = test_info['all_time'] / total
      xinfo['test-all-time'] = test_info['all_time']
    if valtest_info is not None:
      xinfo['valtest-loss'] = valtest_info['loss']
      xinfo['valtest-accuracy'] = valtest_info['accuracy']
      xinfo['valtest-per-time'] = valtest_info['all_time'] / total
      xinfo['valtest-all-time'] = valtest_info['all_time']
    return xinfo
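A sketch of querying the rewritten get_more_info; the keys follow the xinfo dictionary built above, and the index is illustrative:

# average over all trials of architecture 123 on CIFAR-100, final epoch of the 200-epoch schedule
info = api.get_more_info(123, 'cifar100', iepoch=None, use_12epochs_result=False, is_random=False)
print(info['train-accuracy'], info['valid-accuracy'], info['test-accuracy'])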
  """ # The following logic is deprecated after March 15 2020, when the benchmark file was upgraded from NAS-Bench-201-v1_0-e61699.pth to NAS-Bench-201-v1_1-096897.pth.
  def get_more_info(self, index: int, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
    if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
    else                  : basestr, arch2infos = '200epochs', self.arch2infos_full
@@ -312,6 +386,7 @@ class NASBench201API(object):
        xifo['est-valid-loss'] = est_valid_info['loss']
        xifo['est-valid-accuracy'] = est_valid_info['accuracy']
      return xifo
  """


  def show(self, index: int = -1) -> None:
@@ -349,6 +424,26 @@ class NASBench201API(object):
        print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))


  def statistics(self, dataset: Text, use_12epochs_result: bool) -> Dict[int, int]:
    """
    Count how many architectures were evaluated with each number of trials (seeds) on the given dataset.
    """
    valid_datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
    if dataset not in valid_datasets:
      raise ValueError('{:} not in {:}'.format(dataset, valid_datasets))
    if use_12epochs_result: arch2infos = self.arch2infos_less
    else                  : arch2infos = self.arch2infos_full
    nums = defaultdict(lambda: 0)
    for index in range(len(self)):
      archInfo = arch2infos[index]
      dataset_seed = archInfo.dataset_seed
      if dataset not in dataset_seed:
        nums[0] += 1
      else:
        nums[len(dataset_seed[dataset])] += 1
    return dict(nums)
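A sketch of the new statistics helper; per the loop above it returns a dict mapping a trial count to the number of architectures evaluated that many times:

counts = api.statistics('cifar10', use_12epochs_result=False)
print(counts)  # e.g. keys 1, 2, 3 for architectures trained once, twice, or three times on CIFAR-10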


  @staticmethod
  def str2lists(arch_str: Text) -> List[tuple]:
    """
