update NAS-Bench-102
This commit is contained in:
		| @@ -78,6 +78,16 @@ class NASBench102API(object): | ||||
|       else                                 : arch_index = -1 | ||||
|     else: arch_index = -1 | ||||
|     return arch_index | ||||
|  | ||||
|   def reload(self, archive_root, index): | ||||
|     assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root) | ||||
|     xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index)) | ||||
|     assert 0 <= index < len(self.meta_archs), 'invalid index of {:}'.format(index) | ||||
|     assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path) | ||||
|     xdata = torch.load(xfile_path) | ||||
|     assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path) | ||||
|     self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] ) | ||||
|     self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] ) | ||||
|    | ||||
|   def query_by_arch(self, arch, use_12epochs_result=False): | ||||
|     if isinstance(arch, int): | ||||
| @@ -125,10 +135,18 @@ class NASBench102API(object): | ||||
|         best_index, highest_accuracy = idx, accuracy | ||||
|     return best_index | ||||
|  | ||||
|   # return the topology structure of the `index`-th architecture | ||||
|   def arch(self, index): | ||||
|     assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs)) | ||||
|     return copy.deepcopy(self.meta_archs[index]) | ||||
|  | ||||
|   # obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed` | ||||
|   def get_net_param(self, index, dataset, seed, use_12epochs_result=False): | ||||
|     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||
|     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||
|     archresult = arch2infos[index] | ||||
|     return archresult.get_net_param(dataset, seed) | ||||
|  | ||||
|   def get_more_info(self, index, dataset, use_12epochs_result=False): | ||||
|     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||
|     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||
| @@ -238,6 +256,13 @@ class ArchResults(object): | ||||
|   def get_dataset_names(self): | ||||
|     return list(self.dataset_seed.keys()) | ||||
|  | ||||
|   def get_net_param(self, dataset, seed=None): | ||||
|     if seed is None: | ||||
|       x_seeds = self.dataset_seed[dataset] | ||||
|       return {seed: self.all_results[(dataset, seed)].get_net_param() for seed in x_seeds} | ||||
|     else: | ||||
|       return self.all_results[(dataset, seed)].get_net_param() | ||||
|  | ||||
|   def query(self, dataset, seed=None): | ||||
|     if seed is None: | ||||
|       x_seeds = self.dataset_seed[dataset] | ||||
|   | ||||
		Reference in New Issue
	
	Block a user