Update the query_by_arch function in API to be compatiable with the submission version of NAS-Bench-201
This commit is contained in:
		| @@ -53,7 +53,7 @@ def evaluate(api, weight_dir, data: str): | ||||
|     config = api.get_net_config(arch_index, data) | ||||
|     net = get_cell_based_tiny_net(config) | ||||
|     meta_info = api.query_meta_info_by_index(arch_index, hp='200' if isinstance(api, NASBench201API) else '90') | ||||
|     params = meta_info.get_net_param(data, 777) | ||||
|     params = meta_info.get_net_param(data, 888 if isinstance(api, NASBench201API) else 777) | ||||
|     with torch.no_grad(): | ||||
|       net.load_state_dict(params) | ||||
|       _, summary = weight_watcher.analyze(net, alphas=False) | ||||
|   | ||||
| @@ -90,6 +90,10 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): | ||||
|     else: arch_index = -1 | ||||
|     return arch_index | ||||
|  | ||||
|   def query_by_arch(self, arch, hp): | ||||
|     # This is to make the current version be compatible with the old version. | ||||
|     return self.query_info_str_by_arch(arch, hp) | ||||
|  | ||||
|   @abc.abstractmethod | ||||
|   def reload(self, archive_root: Text = None, index: int = None): | ||||
|     """Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'. | ||||
|   | ||||
		Reference in New Issue
	
	Block a user