fix bugs in RANDOM-NAS and BOHB
This commit is contained in:
		| @@ -104,14 +104,19 @@ class NASBench102API(object): | ||||
|       print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index)) | ||||
|       return None | ||||
|  | ||||
|   def query_by_index(self, arch_index, dataname, use_12epochs_result=False): | ||||
|   # query information with the training of 12 epochs or 200 epochs | ||||
|   # if dataname is None, return the ArchResults | ||||
|   # else, return a dict with all trials on that dataset (the key is the seed) | ||||
|   def query_by_index(self, arch_index, dataname=None, use_12epochs_result=False): | ||||
|     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||
|     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||
|     assert arch_index in arch2infos, 'arch_index [{:}] does not in arch2info with {:}'.format(arch_index, basestr) | ||||
|     archInfo = copy.deepcopy( arch2infos[ arch_index ] ) | ||||
|     assert dataname in archInfo.get_dataset_names(), 'invalid dataset-name : {:}'.format(dataname) | ||||
|     info = archInfo.query(dataname) | ||||
|     return info | ||||
|     if dataname is None: return archInfo | ||||
|     else: | ||||
|       assert dataname in archInfo.get_dataset_names(), 'invalid dataset-name : {:}'.format(dataname) | ||||
|       info = archInfo.query(dataname) | ||||
|       return info | ||||
|  | ||||
|   def query_meta_info_by_index(self, arch_index, use_12epochs_result=False): | ||||
|     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||
| @@ -266,7 +271,7 @@ class ArchResults(object): | ||||
|   def query(self, dataset, seed=None): | ||||
|     if seed is None: | ||||
|       x_seeds = self.dataset_seed[dataset] | ||||
|       return [self.all_results[ (dataset, seed) ] for seed in x_seeds] | ||||
|       return {seed: self.all_results[ (dataset, seed) ] for seed in x_seeds} | ||||
|     else: | ||||
|       return self.all_results[ (dataset, seed) ] | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user