rm PD ; update NAS-Bench-102 baselines
This commit is contained in:
		| @@ -129,6 +129,27 @@ class NASBench102API(object): | ||||
|     assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs)) | ||||
|     return copy.deepcopy(self.meta_archs[index]) | ||||
|  | ||||
|   def get_more_info(self, index, dataset, use_12epochs_result=False): | ||||
|     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||
|     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||
|     archresult = arch2infos[index] | ||||
|     if dataset == 'cifar10-valid': | ||||
|       train_info = archresult.get_metrics(dataset, 'train', is_random=True) | ||||
|       valid_info = archresult.get_metrics(dataset, 'x-valid', is_random=True) | ||||
|       test__info = archresult.get_metrics(dataset, 'ori-test', is_random=True) | ||||
|       total      = train_info['iepoch'] + 1 | ||||
|       return {'train-loss'    : train_info['loss'], | ||||
|               'train-accuracy': train_info['accuracy'], | ||||
|               'train-all-time': train_info['all_time'], | ||||
|               'valid-loss'    : valid_info['loss'], | ||||
|               'valid-accuracy': valid_info['accuracy'], | ||||
|               'valid-all-time': valid_info['all_time'], | ||||
|               'valid-per-time': valid_info['all_time'] / total, | ||||
|               'test-loss'     : test__info['loss'], | ||||
|               'test-accuracy' : test__info['accuracy']} | ||||
|     else: | ||||
|       raise ValueError('coming soon...') | ||||
|  | ||||
|   def show(self, index=-1): | ||||
|     if index < 0: # show all architectures | ||||
|       print(self) | ||||
| @@ -367,23 +388,28 @@ class ResultsCount(object): | ||||
|   def get_train(self, iepoch=None): | ||||
|     if iepoch is None: iepoch = self.epochs-1 | ||||
|     assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs) | ||||
|     if self.train_times is not None: xtime = self.train_times[iepoch] | ||||
|     else                           : xtime = None | ||||
|     if self.train_times is not None: | ||||
|       xtime = self.train_times[iepoch] | ||||
|       atime = sum([self.train_times[i] for i in range(iepoch+1)]) | ||||
|     else: xtime, atime = None, None | ||||
|     return {'iepoch'  : iepoch, | ||||
|             'loss'    : self.train_losses[iepoch], | ||||
|             'accuracy': self.train_acc1es[iepoch], | ||||
|             'time'    : xtime} | ||||
|             'cur_time': xtime, | ||||
|             'all_time': atime} | ||||
|  | ||||
|   def get_eval(self, name, iepoch=None): | ||||
|     if iepoch is None: iepoch = self.epochs-1 | ||||
|     assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs) | ||||
|     if isinstance(self.eval_times,dict) and len(self.eval_times) > 0: | ||||
|       xtime = self.eval_times['{:}@{:}'.format(name,iepoch)] | ||||
|     else: xtime = None | ||||
|       atime = sum([self.eval_times['{:}@{:}'.format(name,i)] for i in range(iepoch+1)]) | ||||
|     else: xtime, atime = None, None | ||||
|     return {'iepoch'  : iepoch, | ||||
|             'loss'    : self.eval_losses['{:}@{:}'.format(name,iepoch)], | ||||
|             'accuracy': self.eval_acc1es['{:}@{:}'.format(name,iepoch)], | ||||
|             'time'    : xtime} | ||||
|             'cur_time': xtime, | ||||
|             'all_time': atime} | ||||
|  | ||||
|   def get_net_param(self): | ||||
|     return self.net_state_dict | ||||
|   | ||||
		Reference in New Issue
	
	Block a user