Update the query_by_arch function in API to be compatiable with the submission version of NAS-Bench-201
This commit is contained in:
		| @@ -53,7 +53,7 @@ def evaluate(api, weight_dir, data: str): | ||||
|     config = api.get_net_config(arch_index, data) | ||||
|     net = get_cell_based_tiny_net(config) | ||||
|     meta_info = api.query_meta_info_by_index(arch_index, hp='200' if isinstance(api, NASBench201API) else '90') | ||||
|     params = meta_info.get_net_param(data, 777) | ||||
|     params = meta_info.get_net_param(data, 888 if isinstance(api, NASBench201API) else 777) | ||||
|     with torch.no_grad(): | ||||
|       net.load_state_dict(params) | ||||
|       _, summary = weight_watcher.analyze(net, alphas=False) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user