wrote the get_nasbench201_idx_score
This commit is contained in:
		| @@ -58,15 +58,9 @@ def get_batch_jacobian(net, x, target, device, args=None): | |||||||
|     return jacob, target.detach(), y.detach(), out.detach() |     return jacob, target.detach(), y.detach(), out.detach() | ||||||
|  |  | ||||||
| def get_nasbench201_nodes_score(nodes, train_loader, searchspace, args, device): | def get_nasbench201_nodes_score(nodes, train_loader, searchspace, args, device): | ||||||
|     op_type = { |     num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output'] | ||||||
|     'input': 0, |      | ||||||
|     'nor_conv_1x1': 1, |  | ||||||
|     'nor_conv_3x3': 2, |  | ||||||
|     'avg_pool_3x3': 3, |  | ||||||
|     'skip_connect': 4, |  | ||||||
|     'none': 5, |  | ||||||
|     'output': 6, |  | ||||||
|     } |  | ||||||
|  |  | ||||||
| def get_nasbench201_idx_score(idx, train_loader, searchspace, args, device): | def get_nasbench201_idx_score(idx, train_loader, searchspace, args, device): | ||||||
|     # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |     # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user