update NAS-Bench-102
.gitignore (vendored)
							| @@ -115,3 +115,6 @@ GPU-*.sh | ||||
| cal.sh | ||||
| aaa | ||||
| cx.sh | ||||
|  | ||||
| NAS-Bench-102-v1_0.pth | ||||
| lib/NAS-Bench-102-v1_0.pth | ||||
|   | ||||
| @@ -6,11 +6,16 @@ Each edge here is associated with an operation selected from a predefined operat | ||||
| For it to be applicable for all NAS algorithms, the search space defined in NAS-Bench-102 includes 4 nodes and 5 associated operation options, which generates 15,625 neural cell candidates in total. | ||||
|  | ||||
| In this Markdown file, we provide: | ||||
| - Detailed instruction to reproduce NAS-Bench-102. | ||||
| - 10 NAS algorithms evaluated in our paper. | ||||
| - [How to Use NAS-Bench-102](#how-to-use-nas-bench-102) | ||||
| - [Instruction to re-generate NAS-Bench-102](#instruction-to-re-generate-nas-bench-102) | ||||
| - [10 NAS algorithms evaluated in our paper](#to-reproduce-10-baseline-nas-algorithms-in-nas-bench-102) | ||||
|  | ||||
| Note: please use `PyTorch >= 1.2.0` and `Python >= 3.6.0`. | ||||
|  | ||||
| The data file of NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) or [Baidu-Wangpan]. | ||||
|  | ||||
| The training and evaluation data used in NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7) or [Baidu-Wangpan]. | ||||
|  | ||||
| ## How to Use NAS-Bench-102 | ||||
|  | ||||
| 1. Creating an API instance from a file: | ||||
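A minimal sketch of this step, assuming the downloaded `NAS-Bench-102-v1_0.pth` file sits in the working directory and that `lib` is on the Python path so `nas_102_api` (see item 5 below) can be imported:

```python
# Sketch: create the benchmark API object from the downloaded data file.
from nas_102_api import NASBench102API as API

api = API('NAS-Bench-102-v1_0.pth')        # loads the benchmark data via torch.load
print(len(api.meta_archs))                 # 15,625 candidate architectures in total
print(len(api.evaluated_indexes))          # indexes that have stored training results
```

Note that the `api.py` changes further down in this commit still contain a temporary `pdb.set_trace()` in the constructor, so this call will drop into the debugger until that line is removed.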
| @@ -35,8 +40,8 @@ api.show(2) | ||||
|  | ||||
| # show the mean loss and accuracy of an architecture | ||||
| info = api.query_meta_info_by_index(1) | ||||
| loss, accuracy = info.get_metrics('cifar10', 'train') | ||||
| flops, params, latency = info.get_comput_costs('cifar100') | ||||
| res_metrics = info.get_metrics('cifar10', 'train') | ||||
| cost_metrics = info.get_comput_costs('cifar100') | ||||
|  | ||||
| # get the detailed information | ||||
| results = api.query_by_index(1, 'cifar100') | ||||
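Since `get_metrics` and `get_comput_costs` now return dictionaries (see the `api.py` changes later in this commit), a small sketch of reading the returned values, using the key names from that code:

```python
# Both accessors return plain dicts in the updated API.
print(res_metrics['loss'], res_metrics['accuracy'])    # training loss and top-1 accuracy on CIFAR-10
print(cost_metrics['flops'],     # FLOPs, in M
      cost_metrics['params'],    # parameter size, in MB
      cost_metrics['latency'])   # latency in seconds (printed as ms elsewhere)
```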
| @@ -55,7 +60,8 @@ index = api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1 | ||||
| api.show(index) | ||||
| ``` | ||||
|  | ||||
| 5. For other usages, please see `lib/aa_nas_api/api.py` | ||||
| 5. For other usages, please see `lib/nas_102_api/api.py` | ||||
|  | ||||
|  | ||||
| ### Detailed Instruction | ||||
|  | ||||
| @@ -98,8 +104,10 @@ print(archRes.get_metrics('cifar10-valid', 'x-valid', None,  True)) # print loss | ||||
| ``` | ||||
| from nas_102_api import NASBench102API as API | ||||
| api = API('NAS-Bench-102-v1_0.pth') | ||||
| api.show(-1)  # show info of all architectures | ||||
| ``` | ||||
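For reference, with the updated `show` method shown later in this commit, a negative index prints every evaluated architecture, while a valid non-negative index prints a single one; both cases report the 200-epoch and the 12-epoch results:

```python
api.show(1)     # print the 200-epoch and 12-epoch results of architecture 1
api.show(-1)    # print the results of every evaluated architecture
```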
|  | ||||
|  | ||||
| ## Instruction to Re-Generate NAS-Bench-102 | ||||
|  | ||||
| 1. generate the meta file for NAS-Bench-102 using the following script, where `NAS-BENCH-102` indicates the name and `4` indicates the maximum number of nodes in a cell. | ||||
| @@ -139,6 +147,7 @@ CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NAS-Bench-102/train-a-net.sh resnet | ||||
| CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NAS-Bench-102/train-a-net.sh '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|skip_connect~1|skip_connect~2|' 16 5 | ||||
| ``` | ||||
|  | ||||
|  | ||||
| ## To Reproduce 10 Baseline NAS Algorithms in NAS-Bench-102 | ||||
|  | ||||
| We have tried our best to implement each method. However, some algorithms might still obtain non-optimal results, since their hyper-parameters might not fit our NAS-Bench-102 search space. | ||||
|   | ||||
| @@ -1,6 +1,8 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| ################################################################################# | ||||
| # NAS-Bench-102: Extending the Scope of Reproducible Neural Architecture Search # | ||||
| ################################################################################# | ||||
| import os, sys, copy, random, torch, numpy as np | ||||
| from collections import OrderedDict, defaultdict | ||||
|  | ||||
| @@ -12,19 +14,21 @@ def print_information(information, extra_info=None, show=False): | ||||
|     return 'loss = {:.3f}, top1 = {:.2f}%'.format(loss, acc) | ||||
|  | ||||
|   for ida, dataset in enumerate(dataset_names): | ||||
|     flop, param, latency = information.get_comput_costs(dataset) | ||||
|     str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency > 0 else None) | ||||
|     train_loss, train_acc = information.get_metrics(dataset, 'train') | ||||
|     #flop, param, latency = information.get_comput_costs(dataset) | ||||
|     metric = information.get_comput_costs(dataset) | ||||
|     flop, param, latency = metric['flops'], metric['params'], metric['latency'] | ||||
|     str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency is not None and latency > 0 else None) | ||||
|     train_info = information.get_metrics(dataset, 'train') | ||||
|     if dataset == 'cifar10-valid': | ||||
|       valid_loss, valid_acc = information.get_metrics(dataset, 'x-valid') | ||||
|       str2 = '{:14s} train : [{:}], valid : [{:}]'.format(dataset, metric2str(train_loss, train_acc), metric2str(valid_loss, valid_acc)) | ||||
|       valid_info = information.get_metrics(dataset, 'x-valid') | ||||
|       str2 = '{:14s} train : [{:}], valid : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy'])) | ||||
|     elif dataset == 'cifar10': | ||||
|       test__loss, test__acc = information.get_metrics(dataset, 'ori-test') | ||||
|       str2 = '{:14s} train : [{:}], test  : [{:}]'.format(dataset, metric2str(train_loss, train_acc), metric2str(test__loss, test__acc)) | ||||
|       test__info = information.get_metrics(dataset, 'ori-test') | ||||
|       str2 = '{:14s} train : [{:}], test  : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy'])) | ||||
|     else: | ||||
|       valid_loss, valid_acc = information.get_metrics(dataset, 'x-valid') | ||||
|       test__loss, test__acc = information.get_metrics(dataset, 'x-test') | ||||
|       str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(dataset, metric2str(train_loss, train_acc), metric2str(valid_loss, valid_acc), metric2str(test__loss, test__acc)) | ||||
|       valid_info = information.get_metrics(dataset, 'x-valid') | ||||
|       test__info = information.get_metrics(dataset, 'x-test') | ||||
|       str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy'])) | ||||
|     strings += [str1, str2] | ||||
|   if show: print('\n'.join(strings)) | ||||
|   return strings | ||||
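A hypothetical usage sketch of this helper, based only on the signature above: it takes an `ArchResults`-style object, an optional extra string, and a `show` flag, and returns the formatted lines.

```python
# Hypothetical usage: format (and print) the stored results of one architecture.
info  = api.query_meta_info_by_index(1)
lines = print_information(info, extra_info='arch-index=1', show=True)
# `lines` holds the per-dataset cost/accuracy strings built in the loop above.
```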
| @@ -34,19 +38,21 @@ class NASBench102API(object): | ||||
|  | ||||
|   def __init__(self, file_path_or_dict, verbose=True): | ||||
|     if isinstance(file_path_or_dict, str): | ||||
|       if verbose: print('try to create NAS-Bench-102 api from {:}'.format(file_path_or_dict)) | ||||
|       if verbose: print('try to create the NAS-Bench-102 api from {:}'.format(file_path_or_dict)) | ||||
|       assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict) | ||||
|       file_path_or_dict = torch.load(file_path_or_dict) | ||||
|     else: | ||||
|       file_path_or_dict = copy.deepcopy( file_path_or_dict ) | ||||
|     assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict)) | ||||
|     import pdb; pdb.set_trace() # we will update this api soon | ||||
|     keys = ('meta_archs', 'arch2infos', 'evaluated_indexes') | ||||
|     for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key) | ||||
|     self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] ) | ||||
|     self.arch2infos = OrderedDict() | ||||
|     self.arch2infos_less = OrderedDict() | ||||
|     self.arch2infos_full = OrderedDict() | ||||
|     for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())): | ||||
|       self.arch2infos[xkey] = ArchResults.create_from_state_dict( file_path_or_dict['arch2infos'][xkey] ) | ||||
|       all_info = file_path_or_dict['arch2infos'][xkey] | ||||
|       self.arch2infos_less[xkey] = ArchResults.create_from_state_dict( all_info['less'] ) | ||||
|       self.arch2infos_full[xkey] = ArchResults.create_from_state_dict( all_info['full'] ) | ||||
|     self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes'])) | ||||
|     self.archstr2index = {} | ||||
|     for idx, arch in enumerate(self.meta_archs): | ||||
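A brief sketch of the structures this constructor builds; the attribute and method names are taken from the surrounding code, and the assumption that every evaluated index has both result tables is not verified here:

```python
idx      = api.evaluated_indexes[0]     # an index expected to have stored results
info_12  = api.arch2infos_less[idx]     # ArchResults from the 12-epoch training
info_200 = api.arch2infos_full[idx]     # ArchResults from the full 200-epoch training
arch_str = api.meta_archs[idx]          # the architecture string of this index
assert api.query_index_by_arch(arch_str) == idx   # archstr2index maps the string back
```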
| @@ -73,35 +79,46 @@ class NASBench102API(object): | ||||
|     else: arch_index = -1 | ||||
|     return arch_index | ||||
|    | ||||
|   def query_by_arch(self, arch): | ||||
|     arch_index = self.query_index_by_arch(arch) | ||||
|     if arch_index == -1: return None | ||||
|     if arch_index in self.arch2infos: | ||||
|       strings = print_information(self.arch2infos[ arch_index ], 'arch-index={:}'.format(arch_index)) | ||||
|   def query_by_arch(self, arch, use_12epochs_result=False): | ||||
|     if isinstance(arch, int): | ||||
|       arch_index = arch | ||||
|     else: | ||||
|       arch_index = self.query_index_by_arch(arch) | ||||
|     if arch_index == -1: return None # the following two lines are used to support few training epochs | ||||
|     if use_12epochs_result: arch2infos = self.arch2infos_less | ||||
|     else                  : arch2infos = self.arch2infos_full | ||||
|     if arch_index in arch2infos: | ||||
|       strings = print_information(arch2infos[ arch_index ], 'arch-index={:}'.format(arch_index)) | ||||
|       return '\n'.join(strings) | ||||
|     else: | ||||
|       print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index)) | ||||
|       return None | ||||
|  | ||||
|   def query_by_index(self, arch_index, dataname): | ||||
|     assert arch_index in self.arch2infos, 'arch_index [{:}] does not in arch2info'.format(arch_index) | ||||
|     archInfo = copy.deepcopy( self.arch2infos[ arch_index ] ) | ||||
|   def query_by_index(self, arch_index, dataname, use_12epochs_result=False): | ||||
|     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||
|     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||
|     assert arch_index in arch2infos, 'arch_index [{:}] is not in arch2infos with {:}'.format(arch_index, basestr) | ||||
|     archInfo = copy.deepcopy( arch2infos[ arch_index ] ) | ||||
|     assert dataname in archInfo.get_dataset_names(), 'invalid dataset-name : {:}'.format(dataname) | ||||
|     info = archInfo.query(dataname) | ||||
|     return info | ||||
|  | ||||
|   def query_meta_info_by_index(self, arch_index): | ||||
|     assert arch_index in self.arch2infos, 'arch_index [{:}] does not in arch2info'.format(arch_index) | ||||
|     archInfo = copy.deepcopy( self.arch2infos[ arch_index ] ) | ||||
|   def query_meta_info_by_index(self, arch_index, use_12epochs_result=False): | ||||
|     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||
|     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||
|     assert arch_index in arch2infos, 'arch_index [{:}] is not in arch2infos with {:}'.format(arch_index, basestr) | ||||
|     archInfo = copy.deepcopy( arch2infos[ arch_index ] ) | ||||
|     return archInfo | ||||
|  | ||||
|   def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None): | ||||
|   def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None, use_12epochs_result=False): | ||||
|     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||
|     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||
|     best_index, highest_accuracy = -1, None | ||||
|     for i, idx in enumerate(self.evaluated_indexes): | ||||
|       flop, param, latency = self.arch2infos[idx].get_comput_costs(dataset) | ||||
|       cost_info = arch2infos[idx].get_comput_costs(dataset)  # returns a dict in the updated API | ||||
|       flop, param, latency = cost_info['flops'], cost_info['params'], cost_info['latency'] | ||||
|       if FLOP_max  is not None and flop  > FLOP_max : continue | ||||
|       if Param_max is not None and param > Param_max: continue | ||||
|       loss, accuracy = self.arch2infos[idx].get_metrics(dataset, metric_on_set) | ||||
|       metric_info = arch2infos[idx].get_metrics(dataset, metric_on_set)  # returns a dict in the updated API | ||||
|       loss, accuracy = metric_info['loss'], metric_info['accuracy'] | ||||
|       if best_index == -1: | ||||
|         best_index, highest_accuracy = idx, accuracy | ||||
|       elif highest_accuracy < accuracy: | ||||
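A hypothetical usage sketch of the new `use_12epochs_result` switch across these query methods; the architecture string is the example used in the Markdown file above, and the return value of `find_best` is not shown in this hunk, so it is not used here:

```python
arch = '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|skip_connect~1|skip_connect~2|'

print(api.query_by_arch(arch))                              # report from the 200-epoch results
print(api.query_by_arch(arch, use_12epochs_result=True))    # report from the 12-epoch results

info_12 = api.query_meta_info_by_index(1, use_12epochs_result=True)    # ArchResults, 12-epoch table
res     = api.query_by_index(1, 'cifar100', use_12epochs_result=True)  # per-dataset query results
```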
| @@ -113,21 +130,29 @@ class NASBench102API(object): | ||||
|     return copy.deepcopy(self.meta_archs[index]) | ||||
|  | ||||
|   def show(self, index=-1): | ||||
|     if index == -1: # show all architectures | ||||
|     if index < 0: # show all architectures | ||||
|       print(self) | ||||
|       for i, idx in enumerate(self.evaluated_indexes): | ||||
|         print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(self.evaluated_indexes), idx) + '-'*10) | ||||
|         print('arch : {:}'.format(self.meta_archs[idx])) | ||||
|         strings = print_information(self.arch2infos[idx]) | ||||
|         print('>' * 20) | ||||
|         strings = print_information(self.arch2infos_full[idx]) | ||||
|         print('>' * 40 + ' 200 epochs ' + '>' * 40) | ||||
|         print('\n'.join(strings)) | ||||
|         print('<' * 20) | ||||
|         strings = print_information(self.arch2infos_less[idx]) | ||||
|         print('>' * 40 + '  12 epochs ' + '>' * 40) | ||||
|         print('\n'.join(strings)) | ||||
|         print('<' * 40 + '------------' + '<' * 40) | ||||
|     else: | ||||
|       if 0 <= index < len(self.meta_archs): | ||||
|         if index not in self.evaluated_indexes: print('The {:}-th architecture has not been evaluated or not saved.'.format(index)) | ||||
|         else: | ||||
|           strings = print_information(self.arch2infos[index]) | ||||
|           strings = print_information(self.arch2infos_full[index]) | ||||
|           print('>' * 40 + ' 200 epochs ' + '>' * 40) | ||||
|           print('\n'.join(strings)) | ||||
|           strings = print_information(self.arch2infos_less[index]) | ||||
|           print('>' * 40 + '  12 epochs ' + '>' * 40) | ||||
|           print('\n'.join(strings)) | ||||
|           print('<' * 40 + '------------' + '<' * 40) | ||||
|       else: | ||||
|         print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs))) | ||||
|  | ||||
|   | ||||