77 lines
		
	
	
		
			3.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			77 lines
		
	
	
		
			3.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # This file is for experimental usage
 | |
| import torch, random
 | |
| import numpy as np
 | |
| from copy import deepcopy
 | |
| import torch.nn as nn
 | |
| 
 | |
| # modules in AutoDL
 | |
| from models import CellStructure
 | |
| from log_utils import time_string
 | |
| 
 | |
| 
 | |
| def evaluate_one_shot(model, xloader, api, cal_mode, seed=111):
 | |
|     print(
 | |
|         "This is an old version of codes to use NAS-Bench-API, and should be modified to align with the new version. Please contact me for more details if you use this function."
 | |
|     )
 | |
|     weights = deepcopy(model.state_dict())
 | |
|     model.train(cal_mode)
 | |
|     with torch.no_grad():
 | |
|         logits = nn.functional.log_softmax(model.arch_parameters, dim=-1)
 | |
|         archs = CellStructure.gen_all(model.op_names, model.max_nodes, False)
 | |
|         probs, accuracies, gt_accs_10_valid, gt_accs_10_test = [], [], [], []
 | |
|         loader_iter = iter(xloader)
 | |
|         random.seed(seed)
 | |
|         random.shuffle(archs)
 | |
|         for idx, arch in enumerate(archs):
 | |
|             arch_index = api.query_index_by_arch(arch)
 | |
|             metrics = api.get_more_info(arch_index, "cifar10-valid", None, False, False)
 | |
|             gt_accs_10_valid.append(metrics["valid-accuracy"])
 | |
|             metrics = api.get_more_info(arch_index, "cifar10", None, False, False)
 | |
|             gt_accs_10_test.append(metrics["test-accuracy"])
 | |
|             select_logits = []
 | |
|             for i, node_info in enumerate(arch.nodes):
 | |
|                 for op, xin in node_info:
 | |
|                     node_str = "{:}<-{:}".format(i + 1, xin)
 | |
|                     op_index = model.op_names.index(op)
 | |
|                     select_logits.append(logits[model.edge2index[node_str], op_index])
 | |
|             cur_prob = sum(select_logits).item()
 | |
|             probs.append(cur_prob)
 | |
|         cor_prob_valid = np.corrcoef(probs, gt_accs_10_valid)[0, 1]
 | |
|         cor_prob_test = np.corrcoef(probs, gt_accs_10_test)[0, 1]
 | |
|         print(
 | |
|             "{:} correlation for probabilities : {:.6f} on CIFAR-10 validation and {:.6f} on CIFAR-10 test".format(
 | |
|                 time_string(), cor_prob_valid, cor_prob_test
 | |
|             )
 | |
|         )
 | |
| 
 | |
|         for idx, arch in enumerate(archs):
 | |
|             model.set_cal_mode("dynamic", arch)
 | |
|             try:
 | |
|                 inputs, targets = next(loader_iter)
 | |
|             except:
 | |
|                 loader_iter = iter(xloader)
 | |
|                 inputs, targets = next(loader_iter)
 | |
|             _, logits = model(inputs.cuda())
 | |
|             _, preds = torch.max(logits, dim=-1)
 | |
|             correct = (preds == targets.cuda()).float()
 | |
|             accuracies.append(correct.mean().item())
 | |
|             if idx != 0 and (idx % 500 == 0 or idx + 1 == len(archs)):
 | |
|                 cor_accs_valid = np.corrcoef(accuracies, gt_accs_10_valid[: idx + 1])[
 | |
|                     0, 1
 | |
|                 ]
 | |
|                 cor_accs_test = np.corrcoef(accuracies, gt_accs_10_test[: idx + 1])[
 | |
|                     0, 1
 | |
|                 ]
 | |
|                 print(
 | |
|                     "{:} {:05d}/{:05d} mode={:5s}, correlation : accs={:.5f} for CIFAR-10 valid, {:.5f} for CIFAR-10 test.".format(
 | |
|                         time_string(),
 | |
|                         idx,
 | |
|                         len(archs),
 | |
|                         "Train" if cal_mode else "Eval",
 | |
|                         cor_accs_valid,
 | |
|                         cor_accs_test,
 | |
|                     )
 | |
|                 )
 | |
|     model.load_state_dict(weights)
 | |
|     return archs, probs, accuracies
 |