update NAS-Bench-102
@@ -16,7 +16,8 @@ Note: please use `PyTorch >= 1.2.0` and `Python >= 3.6.0`.

The benchmark file of NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) or [Baidu-Wangpan (code:6u5d)](https://pan.baidu.com/s/1CiaNH6C12zuZf7q-Ilm09w).
You can move it anywhere you want and pass its path to our API for initialization.
- v1.0: `NAS-Bench-102-v1_0-e61699.pth`, where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial.
- v1.0: The full data of each architecture can be downloaded from [Google Drive](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the trained weights.

The training and evaluation data used in NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7) or [Baidu-Wangpan (code:4fg7)](https://pan.baidu.com/s/1XAzavPKq3zcat1yBA1L2tQ).
It is recommended to put these data into `$TORCH_HOME` (`~/.torch/` by default). If you want to generate NAS-Bench-102 or similar NAS datasets, or train models yourself, you need these data.
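For illustration, a minimal sketch (assuming only the `$TORCH_HOME` convention stated above, nothing from the library itself) of resolving the benchmark path before handing it to the API:
```
# A minimal sketch, assuming the $TORCH_HOME convention above:
# use the environment variable when set, otherwise fall back to ~/.torch/.
import os

torch_home = os.environ.get('TORCH_HOME', os.path.join(os.path.expanduser('~'), '.torch'))
benchmark_file = os.path.join(torch_home, 'NAS-Bench-102-v1_0-e61699.pth')
print(benchmark_file)
```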
@@ -108,8 +109,12 @@ print(archRes.get_metrics('cifar10-valid', 'x-valid', None,  True)) # print loss
`NASBench102API` is the top-level API. Please see the following usage:
```
import os
from nas_102_api import NASBench102API as API
api = API('NAS-Bench-102-v1_0-e61699.pth') # load all information of NAS-Bench-102 except the trained weights
api = API('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-Bench-102-v1_0-e61699.pth')) # the same as the above line; I usually save NAS-Bench-102-v1_0-e61699.pth in ~/.torch/
api.show(-1)  # show info of all architectures
api.reload('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-BENCH-102-4-v1.0-archive'), 3) # reload the information of the 3rd architecture, including its trained weights

weights = api.get_net_param(3, 'cifar10', None) # obtain the weights of all trials for the 3rd architecture on cifar10; returns a dict where the key is the seed and the value is the trained weights
```
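As a follow-up sketch (not part of the original snippet): once `reload` has pulled in the trained weights, passing a concrete seed instead of `None` returns the weights of a single trial; the seeds 777/888/999 assumed below are the ones used by the training scripts.
```
# Sketch: inspect the returned weights, assuming the 3rd architecture
# was reloaded above and was trained with seeds 777/888/999.
all_weights = api.get_net_param(3, 'cifar10', None)  # dict: seed -> trained weights
for seed, w in all_weights.items():
  print(seed, type(w))
one_trial = api.get_net_param(3, 'cifar10', 777)     # weights of the single trial with seed 777
```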

exps/NAS-Bench-102/check.py (new file, 84 lines)
@@ -0,0 +1,84 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# python exps/NAS-Bench-102/check.py --base_save_dir
##################################################
import os, sys, time, argparse, collections
from shutil import copyfile
import torch
import torch.nn as nn
from pathlib import Path
from collections import defaultdict
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from log_utils    import AverageMeter, time_string, convert_secs2time


def check_files(save_dir, meta_file, basestr):
  meta_infos     = torch.load(meta_file, map_location='cpu')
  meta_archs     = meta_infos['archs']
  meta_num_archs = meta_infos['total']
  meta_max_node  = meta_infos['max_node']
  assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs))

  sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr))))
  print ('{:} find {:} directories used to save checkpoints'.format(time_string(), len(sub_model_dirs)))

  subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0
  num_seeds = defaultdict(lambda: 0)
  for index, sub_dir in enumerate(sub_model_dirs):
    xcheckpoints = list(sub_dir.glob('arch-*-seed-*.pth'))
    #xcheckpoints = list(sub_dir.glob('arch-*-seed-0777.pth')) + list(sub_dir.glob('arch-*-seed-0888.pth')) + list(sub_dir.glob('arch-*-seed-0999.pth'))
    arch_indexes = set()
    for checkpoint in xcheckpoints:
      temp_names = checkpoint.name.split('-')
      assert len(temp_names) == 4 and temp_names[0] == 'arch' and temp_names[2] == 'seed', 'invalid checkpoint name : {:}'.format(checkpoint.name)
      arch_indexes.add( temp_names[1] )
    subdir2archs[sub_dir] = sorted(list(arch_indexes))
    num_evaluated_arch   += len(arch_indexes)
    # count number of seeds for each architecture
    for arch_index in arch_indexes:
      num_seeds[ len(list(sub_dir.glob('arch-{:}-seed-*.pth'.format(arch_index)))) ] += 1
  print('There are {:5d} architectures that have been evaluated ({:} in total, {:} ckps in total).'.format(num_evaluated_arch, meta_num_archs, sum(k*v for k, v in num_seeds.items())))
  for key in sorted( list( num_seeds.keys() ) ): print ('There are {:5d} architectures that are evaluated {:} times.'.format(num_seeds[key], key))

  dir2ckps, dir2ckp_exists = dict(), dict()
  start_time, epoch_time = time.time(), AverageMeter()
  for IDX, (sub_dir, arch_indexes) in enumerate(subdir2archs.items()):
    seeds = [777, 888, 999]
    numrs = defaultdict(lambda: 0)
    all_checkpoints, all_ckp_exists = [], []
    for arch_index in arch_indexes:
      checkpoints = ['arch-{:}-seed-{:04d}.pth'.format(arch_index, seed) for seed in seeds]
      ckp_exists  = [(sub_dir/x).exists() for x in checkpoints]
      arch_index  = int(arch_index)
      assert 0 <= arch_index < len(meta_archs), 'invalid arch-index {:} (not found in meta_archs)'.format(arch_index)
      all_checkpoints += checkpoints
      all_ckp_exists  += ckp_exists
      numrs[sum(ckp_exists)] += 1
    dir2ckps[ str(sub_dir) ]       = all_checkpoints
    dir2ckp_exists[ str(sub_dir) ] = all_ckp_exists
    # measure time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()
    numrstr = ', '.join( ['{:}: {:03d}'.format(x, numrs[x]) for x in sorted(numrs.keys())] )
    print('{:} load [{:2d}/{:2d}] [{:03d} archs] [{:04d}->{:04d} ckps] {:} done, need {:}. {:}'.format(time_string(), IDX+1, len(subdir2archs), len(arch_indexes), len(all_checkpoints), sum(all_ckp_exists), sub_dir, convert_secs2time(epoch_time.avg * (len(subdir2archs)-IDX-1), True), numrstr))


if __name__ == '__main__':

  parser = argparse.ArgumentParser(description='NAS Benchmark 102', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('--base_save_dir',  type=str, default='./output/NAS-BENCH-102-4',     help='The base-name of folder to save checkpoints and log.')
  parser.add_argument('--max_node',       type=int, default=4,                                 help='The maximum node in a cell.')
  parser.add_argument('--channel',        type=int, default=16,                                help='The number of channels.')
  parser.add_argument('--num_cells',      type=int, default=5,                                 help='The number of cells in one stage.')
  args = parser.parse_args()

  save_dir  = Path( args.base_save_dir )
  meta_path = save_dir / 'meta-node-{:}.pth'.format(args.max_node)
  assert save_dir.exists(),  'invalid save dir path : {:}'.format(save_dir)
  assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path)
  print ('check NAS-Bench-102 in {:}'.format(save_dir))

  basestr = 'C{:}-N{:}'.format(args.channel, args.num_cells)
  check_files(save_dir, meta_path, basestr)
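As a usage note (a sketch under the script's own defaults, not part of the original diff), `check_files` can also be driven directly from Python; the values below simply mirror the argparse defaults above:
```
# Sketch: call check_files with the same values the CLI defaults produce.
from pathlib import Path

save_dir  = Path('./output/NAS-BENCH-102-4')   # --base_save_dir default
meta_path = save_dir / 'meta-node-4.pth'       # from --max_node default (4)
basestr   = 'C16-N5'                           # from --channel (16) and --num_cells (5)
check_files(save_dir, meta_path, basestr)
```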
@@ -78,6 +78,16 @@ class NASBench102API(object):
      else                                 : arch_index = -1
    else: arch_index = -1
    return arch_index

  # reload the full information (including trained weights) of the `index`-th architecture from `archive_root`
  def reload(self, archive_root, index):
    assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
    xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index))
    assert 0 <= index < len(self.meta_archs), 'invalid index of {:}'.format(index)
    assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
    xdata = torch.load(xfile_path)
    assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path)
    self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] )
    self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] )

  def query_by_arch(self, arch, use_12epochs_result=False):
    if isinstance(arch, int):
@@ -125,10 +135,18 @@ class NASBench102API(object):
        best_index, highest_accuracy = idx, accuracy
    return best_index

  # return the topology structure of the `index`-th architecture
  def arch(self, index):
    assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
    return copy.deepcopy(self.meta_archs[index])

  # obtain the trained weights of the `index`-th architecture on `dataset` with seed `seed`
  def get_net_param(self, index, dataset, seed, use_12epochs_result=False):
    if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
    else                  : basestr, arch2infos = '200epochs', self.arch2infos_full
    archresult = arch2infos[index]
    return archresult.get_net_param(dataset, seed)

  def get_more_info(self, index, dataset, use_12epochs_result=False):
    if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
    else                  : basestr, arch2infos = '200epochs', self.arch2infos_full
@@ -238,6 +256,13 @@ class ArchResults(object):
  def get_dataset_names(self):
    return list(self.dataset_seed.keys())

  # return the trained weights on `dataset`: a {seed: weights} dict when seed is None, otherwise the weights of that single trial
  def get_net_param(self, dataset, seed=None):
    if seed is None:
      x_seeds = self.dataset_seed[dataset]
      return {seed: self.all_results[(dataset, seed)].get_net_param() for seed in x_seeds}
    else:
      return self.all_results[(dataset, seed)].get_net_param()

  def query(self, dataset, seed=None):
    if seed is None:
      x_seeds = self.dataset_seed[dataset]
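Tying the two layers together, a hedged sketch of the call paths this diff adds, assuming the archive was loaded via `reload` (so the weights are present) and direct access to `arch2infos_full` as the methods above use it:
```
# Sketch: NASBench102API.get_net_param delegates to ArchResults.get_net_param.
api.reload('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-BENCH-102-4-v1.0-archive'), 3)
archresult = api.arch2infos_full[3]                    # 200-epoch results of the 3rd architecture
per_seed   = archresult.get_net_param('cifar10')       # dict: seed -> weights (the seed=None path)
one_seed   = archresult.get_net_param('cifar10', 777)  # weights of the single trial with seed 777
```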