simplify baselines
@@ -41,8 +41,9 @@ class NASBench102API(object):
       if verbose: print('try to create the NAS-Bench-102 api from {:}'.format(file_path_or_dict))
       assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
       file_path_or_dict = torch.load(file_path_or_dict)
-    else:
+    elif isinstance(file_path_or_dict, dict):
       file_path_or_dict = copy.deepcopy( file_path_or_dict )
+    else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))
     assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
     keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
     for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
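For context, a minimal caller-side sketch of the stricter constructor logic in this hunk: a string is treated as a checkpoint path and loaded via torch.load, a dict is deep-copied, and any other type now raises a ValueError instead of silently falling into the deep-copy branch. The import path and checkpoint file name below are assumptions for illustration, not part of this commit.

import torch
from nas_102_api import NASBench102API  # import path is an assumption

# From a checkpoint path (loaded with torch.load internally)...
api = NASBench102API('NAS-Bench-102-v1_0-e61699.pth')

# ...or from an already-loaded dict, which gets deep-copied.
raw = torch.load('NAS-Bench-102-v1_0-e61699.pth')
api_from_dict = NASBench102API(raw)

# Any other type now fails fast.
try:
  NASBench102API(123)
except ValueError as err:
  print(err)  # invalid type : <class 'int'> not in [str, dict]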
@@ -152,26 +153,40 @@ class NASBench102API(object):
     archresult = arch2infos[index]
     return archresult.get_net_param(dataset, seed)

-  def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False):
+  # obtain the metric for the `index`-th architecture
+  def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full
     archresult = arch2infos[index]
     if dataset == 'cifar10-valid':
-      train_info = archresult.get_metrics(dataset, 'train'   , iepoch=iepoch, is_random=True)
-      valid_info = archresult.get_metrics(dataset, 'x-valid' , iepoch=iepoch, is_random=True)
-      test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=True)
+      train_info = archresult.get_metrics(dataset, 'train'   , iepoch=iepoch, is_random=is_random)
+      valid_info = archresult.get_metrics(dataset, 'x-valid' , iepoch=iepoch, is_random=is_random)
+      try:
+        test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
+      except:
+        test__info = None
       total      = train_info['iepoch'] + 1
-      return {'train-loss'    : train_info['loss'],
+      xifo = {'train-loss'    : train_info['loss'],
               'train-accuracy': train_info['accuracy'],
               'train-all-time': train_info['all_time'],
               'valid-loss'    : valid_info['loss'],
               'valid-accuracy': valid_info['accuracy'],
               'valid-all-time': valid_info['all_time'],
-              'valid-per-time': valid_info['all_time'] / total,
+              'valid-per-time': None if valid_info['all_time'] is None else valid_info['all_time'] / total}
+      if test__info is not None:
+        xifo['test-loss']     = test__info['loss']
+        xifo['test-accuracy'] = test__info['accuracy']
+      return xifo
+    else:
+      train_info = archresult.get_metrics(dataset, 'train'   , iepoch=iepoch, is_random=is_random)
+      if dataset == 'cifar10':
+        test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
+      else:
+        test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
+      return {'train-loss'    : train_info['loss'],
+              'train-accuracy': train_info['accuracy'],
               'test-loss'     : test__info['loss'],
               'test-accuracy' : test__info['accuracy']}
-    else:
-      raise ValueError('coming soon...')

   def show(self, index=-1):
     if index < 0: # show all architectures
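A caller-side sketch for the updated get_more_info (the `api` object carries over from the sketch above, and architecture index 0 is an arbitrary choice): on 'cifar10-valid' the test entries are only present when 'ori-test' metrics could be fetched, so they should be read defensively, while other datasets go through the new else-branch and return only train/test entries.

# 'cifar10-valid': train/valid entries are always present; test entries only
# when 'ori-test' metrics were available, so use dict.get for those.
info = api.get_more_info(0, 'cifar10-valid', iepoch=None,
                         use_12epochs_result=True, is_random=True)
print(info['train-loss'], info['valid-accuracy'], info['valid-per-time'])
test_acc = info.get('test-accuracy')  # may be absent, i.e. None here

# Other datasets return train/test only; 'cifar10' reads the 'ori-test'
# split, everything else reads 'x-test'.
info = api.get_more_info(0, 'cifar10', use_12epochs_result=False, is_random=False)
print(info['train-accuracy'], info['test-accuracy'])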
@@ -369,7 +384,7 @@ class ResultsCount(object):
   def update_latency(self, latency):
     self.latency = copy.deepcopy( latency )

-  def update_eval(self, accs, losses, times): # old version
+  def update_eval(self, accs, losses, times):  # new version
     data_names = set([x.split('@')[0] for x in accs.keys()])
     for data_name in data_names:
       assert data_name not in self.eval_names, '{:} has already been added into eval-names'.format(data_name)
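The update_eval hunk only changes a comment, but its visible body shows the key convention it relies on: every key in accs (and, by symmetry, losses and times) carries a data-name prefix before an '@'. A tiny standalone illustration with made-up keys:

accs = {'ori-test@11': 85.2, 'ori-test@199': 93.1, 'x-valid@199': 91.4}  # hypothetical keys
data_names = set(x.split('@')[0] for x in accs.keys())
print(sorted(data_names))  # ['ori-test', 'x-valid']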