Support accumulate and reset time function for API
This commit is contained in:
		| @@ -58,6 +58,7 @@ class NASBench201API(NASBenchMetaAPI): | ||||
|   def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, | ||||
|                verbose: bool=True): | ||||
|     self.filename = None | ||||
|     self.reset_time() | ||||
|     if file_path_or_dict is None: | ||||
|       file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], ALL_BENCHMARK_FILES[-1]) | ||||
|       print ('Try to use the default NAS-Bench-201 path from {:}.'.format(file_path_or_dict)) | ||||
|   | ||||
| @@ -57,6 +57,7 @@ class NASBench301API(NASBenchMetaAPI): | ||||
|   """ The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """ | ||||
|   def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, verbose: bool=True): | ||||
|     self.filename = None | ||||
|     self.reset_time() | ||||
|     if file_path_or_dict is None: | ||||
|       file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], ALL_BENCHMARK_FILES[-1]) | ||||
|     if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path): | ||||
| @@ -128,7 +129,7 @@ class NASBench301API(NASBenchMetaAPI): | ||||
|     """ | ||||
|     if self.verbose: | ||||
|       print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp)) | ||||
|     self._query_info_str_by_arch(arch, hp, print_information) | ||||
|     return self._query_info_str_by_arch(arch, hp, print_information) | ||||
|  | ||||
|   def get_more_info(self, index: int, dataset: Text, iepoch=None, hp='12', is_random=True): | ||||
|     """This function will return the metric for the `index`-th architecture | ||||
|   | ||||
| @@ -61,6 +61,25 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): | ||||
|   def avaliable_hps(self): | ||||
|     return list(copy.deepcopy(self._avaliable_hps)) | ||||
|  | ||||
|   @property | ||||
|   def used_time(self): | ||||
|     return self._used_time | ||||
|  | ||||
|   def reset_time(self): | ||||
|     self._used_time = 0 | ||||
|  | ||||
|   def simulate_train_eval(self, arch, dataset, hp='12'): | ||||
|     index = self.query_index_by_arch(arch) | ||||
|     all_names = ('cifar10', 'cifar100', 'ImageNet16-120') | ||||
|     assert dataset in all_names, 'Invalid dataset name : {:} vs {:}'.format(dataset, all_names) | ||||
|     if dataset == 'cifar10': | ||||
|       info = self.get_more_info(index, 'cifar10-valid', iepoch=None, hp=hp, is_random=True) | ||||
|     else: | ||||
|       info = self.get_more_info(index, dataset, iepoch=None, hp=hp, is_random=True) | ||||
|     valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time'] | ||||
|     self._used_time += time_cost | ||||
|     return valid_acc, time_cost, self._used_time | ||||
|  | ||||
|   def random(self): | ||||
|     """Return a random index of all architectures.""" | ||||
|     return random.randint(0, len(self.meta_archs)-1) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user