rm PD ; update NAS-Bench-102 baselines
This commit is contained in:
		| @@ -36,6 +36,7 @@ def get_cell_based_tiny_net(config): | ||||
| def get_search_spaces(xtype, name): | ||||
|   if xtype == 'cell': | ||||
|     from .cell_operations import SearchSpaceNames | ||||
|     assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys()) | ||||
|     return SearchSpaceNames[name] | ||||
|   else: | ||||
|     raise ValueError('invalid search-space type is {:}'.format(xtype)) | ||||
|   | ||||
| @@ -16,12 +16,13 @@ OPS = { | ||||
|   'skip_connect' : lambda C_in, C_out, stride, affine: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine), | ||||
| } | ||||
|  | ||||
| CONNECT_NAS_BENCHMARK  = ['none', 'skip_connect', 'nor_conv_3x3'] | ||||
| AA_NAS_BENCHMARK       = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3'] | ||||
| CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3'] | ||||
| NAS_BENCH_102         = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3'] | ||||
|  | ||||
| SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK, | ||||
|                     'aa-nas'      : AA_NAS_BENCHMARK, | ||||
|                     'full'        : sorted(list(OPS.keys()))} | ||||
| SearchSpaceNames = {'connect-nas'  : CONNECT_NAS_BENCHMARK, | ||||
|                     'aa-nas'       : NAS_BENCH_102, | ||||
|                     'nas-bench-102': NAS_BENCH_102, | ||||
|                     'full'         : sorted(list(OPS.keys()))} | ||||
|  | ||||
|  | ||||
| class ReLUConvBN(nn.Module): | ||||
|   | ||||
| @@ -129,6 +129,27 @@ class NASBench102API(object): | ||||
|     assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs)) | ||||
|     return copy.deepcopy(self.meta_archs[index]) | ||||
|  | ||||
|   def get_more_info(self, index, dataset, use_12epochs_result=False): | ||||
|     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||
|     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||
|     archresult = arch2infos[index] | ||||
|     if dataset == 'cifar10-valid': | ||||
|       train_info = archresult.get_metrics(dataset, 'train', is_random=True) | ||||
|       valid_info = archresult.get_metrics(dataset, 'x-valid', is_random=True) | ||||
|       test__info = archresult.get_metrics(dataset, 'ori-test', is_random=True) | ||||
|       total      = train_info['iepoch'] + 1 | ||||
|       return {'train-loss'    : train_info['loss'], | ||||
|               'train-accuracy': train_info['accuracy'], | ||||
|               'train-all-time': train_info['all_time'], | ||||
|               'valid-loss'    : valid_info['loss'], | ||||
|               'valid-accuracy': valid_info['accuracy'], | ||||
|               'valid-all-time': valid_info['all_time'], | ||||
|               'valid-per-time': valid_info['all_time'] / total, | ||||
|               'test-loss'     : test__info['loss'], | ||||
|               'test-accuracy' : test__info['accuracy']} | ||||
|     else: | ||||
|       raise ValueError('coming soon...') | ||||
|  | ||||
|   def show(self, index=-1): | ||||
|     if index < 0: # show all architectures | ||||
|       print(self) | ||||
| @@ -367,23 +388,28 @@ class ResultsCount(object): | ||||
|   def get_train(self, iepoch=None): | ||||
|     if iepoch is None: iepoch = self.epochs-1 | ||||
|     assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs) | ||||
|     if self.train_times is not None: xtime = self.train_times[iepoch] | ||||
|     else                           : xtime = None | ||||
|     if self.train_times is not None: | ||||
|       xtime = self.train_times[iepoch] | ||||
|       atime = sum([self.train_times[i] for i in range(iepoch+1)]) | ||||
|     else: xtime, atime = None, None | ||||
|     return {'iepoch'  : iepoch, | ||||
|             'loss'    : self.train_losses[iepoch], | ||||
|             'accuracy': self.train_acc1es[iepoch], | ||||
|             'time'    : xtime} | ||||
|             'cur_time': xtime, | ||||
|             'all_time': atime} | ||||
|  | ||||
|   def get_eval(self, name, iepoch=None): | ||||
|     if iepoch is None: iepoch = self.epochs-1 | ||||
|     assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs) | ||||
|     if isinstance(self.eval_times,dict) and len(self.eval_times) > 0: | ||||
|       xtime = self.eval_times['{:}@{:}'.format(name,iepoch)] | ||||
|     else: xtime = None | ||||
|       atime = sum([self.eval_times['{:}@{:}'.format(name,i)] for i in range(iepoch+1)]) | ||||
|     else: xtime, atime = None, None | ||||
|     return {'iepoch'  : iepoch, | ||||
|             'loss'    : self.eval_losses['{:}@{:}'.format(name,iepoch)], | ||||
|             'accuracy': self.eval_acc1es['{:}@{:}'.format(name,iepoch)], | ||||
|             'time'    : xtime} | ||||
|             'cur_time': xtime, | ||||
|             'all_time': atime} | ||||
|  | ||||
|   def get_net_param(self): | ||||
|     return self.net_state_dict | ||||
|   | ||||
		Reference in New Issue
	
	Block a user