update code styles
This commit is contained in:
		| @@ -1,6 +1,6 @@ | |||||||
| MIT License | MIT License | ||||||
|  |  | ||||||
| Copyright (c) 2019 Xuanyi Dong [GitHub: https://github.com/D-X-Y] | Copyright (c) 2019 Xuanyi Dong (GitHub: https://github.com/D-X-Y) | ||||||
|  |  | ||||||
| Permission is hereby granted, free of charge, to any person obtaining a copy | Permission is hereby granted, free of charge, to any person obtaining a copy | ||||||
| of this software and associated documentation files (the "Software"), to deal | of this software and associated documentation files (the "Software"), to deal | ||||||
|   | |||||||
| @@ -87,7 +87,8 @@ def test_one_shot_model(ckpath, use_train): | |||||||
|   ckp = torch.load(ckpath) |   ckp = torch.load(ckpath) | ||||||
|   xargs = ckp['args'] |   xargs = ckp['args'] | ||||||
|   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) |   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||||
|   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None) |   #config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None) | ||||||
|  |   config = load_config('./configs/nas-benchmark/algos/DARTS.config', {'class_num': class_num, 'xshape': xshape}, None) | ||||||
|   if xargs.dataset == 'cifar10': |   if xargs.dataset == 'cifar10': | ||||||
|     cifar_split = load_config('configs/nas-benchmark/cifar-split.txt', None, None) |     cifar_split = load_config('configs/nas-benchmark/cifar-split.txt', None, None) | ||||||
|     xvalid_data = deepcopy(train_data) |     xvalid_data = deepcopy(train_data) | ||||||
|   | |||||||
| @@ -15,14 +15,16 @@ def evaluate_one_shot(model, xloader, api, cal_mode, seed=111): | |||||||
|   with torch.no_grad(): |   with torch.no_grad(): | ||||||
|     logits = nn.functional.log_softmax(model.arch_parameters, dim=-1) |     logits = nn.functional.log_softmax(model.arch_parameters, dim=-1) | ||||||
|     archs = CellStructure.gen_all(model.op_names, model.max_nodes, False) |     archs = CellStructure.gen_all(model.op_names, model.max_nodes, False) | ||||||
|     probs, accuracies, gt_accs = [], [], [] |     probs, accuracies, gt_accs_10_valid, gt_accs_10_test = [], [], [], [] | ||||||
|     loader_iter = iter(xloader) |     loader_iter = iter(xloader) | ||||||
|     random.seed(seed) |     random.seed(seed) | ||||||
|     random.shuffle(archs) |     random.shuffle(archs) | ||||||
|     for idx, arch in enumerate(archs): |     for idx, arch in enumerate(archs): | ||||||
|       arch_index = api.query_index_by_arch( arch ) |       arch_index = api.query_index_by_arch( arch ) | ||||||
|       metrics = api.get_more_info(arch_index, 'cifar10-valid', None, False, False) |       metrics = api.get_more_info(arch_index, 'cifar10-valid', None, False, False) | ||||||
|       gt_accs.append( metrics['valid-accuracy'] ) |       gt_accs_10_valid.append( metrics['valid-accuracy'] ) | ||||||
|  |       metrics = api.get_more_info(arch_index, 'cifar10', None, False, False) | ||||||
|  |       gt_accs_10_test.append( metrics['test-accuracy'] ) | ||||||
|       select_logits = [] |       select_logits = [] | ||||||
|       for i, node_info in enumerate(arch.nodes): |       for i, node_info in enumerate(arch.nodes): | ||||||
|         for op, xin in node_info: |         for op, xin in node_info: | ||||||
| @@ -31,8 +33,9 @@ def evaluate_one_shot(model, xloader, api, cal_mode, seed=111): | |||||||
|           select_logits.append( logits[model.edge2index[node_str], op_index] ) |           select_logits.append( logits[model.edge2index[node_str], op_index] ) | ||||||
|       cur_prob = sum(select_logits).item() |       cur_prob = sum(select_logits).item() | ||||||
|       probs.append( cur_prob ) |       probs.append( cur_prob ) | ||||||
|     cor_prob = np.corrcoef(probs, gt_accs)[0,1] |     cor_prob_valid = np.corrcoef(probs, gt_accs_10_valid)[0,1] | ||||||
|     print ('correlation for probabilities : {:}'.format(cor_prob)) |     cor_prob_test  = np.corrcoef(probs, gt_accs_10_test )[0,1] | ||||||
|  |     print ('{:} correlation for probabilities : {:.6f} on CIFAR-10 validation and {:.6f} on CIFAR-10 test'.format(time_string(), cor_prob_valid, cor_prob_test)) | ||||||
|        |        | ||||||
|     for idx, arch in enumerate(archs): |     for idx, arch in enumerate(archs): | ||||||
|       model.set_cal_mode('dynamic', arch) |       model.set_cal_mode('dynamic', arch) | ||||||
| @@ -45,8 +48,9 @@ def evaluate_one_shot(model, xloader, api, cal_mode, seed=111): | |||||||
|       _, preds  = torch.max(logits, dim=-1) |       _, preds  = torch.max(logits, dim=-1) | ||||||
|       correct = (preds == targets.cuda() ).float() |       correct = (preds == targets.cuda() ).float() | ||||||
|       accuracies.append( correct.mean().item() ) |       accuracies.append( correct.mean().item() ) | ||||||
|       if idx != 0 and (idx % 300 == 0 or idx + 1 == len(archs) or idx == 10): |       if idx != 0 and (idx % 500 == 0 or idx + 1 == len(archs)): | ||||||
|         cor_accs = np.corrcoef(accuracies, gt_accs[:idx+1])[0,1] |         cor_accs_valid = np.corrcoef(accuracies, gt_accs_10_valid[:idx+1])[0,1] | ||||||
|         print ('{:} {:03d}/{:03d} mode={:5s}, correlation : accs={:.4f}, arch={:}'.format(time_string(), idx, len(archs), 'Train' if cal_mode else 'Eval', cor_accs, arch)) |         cor_accs_test  = np.corrcoef(accuracies, gt_accs_10_test [:idx+1])[0,1] | ||||||
|  |         print ('{:} {:05d}/{:05d} mode={:5s}, correlation : accs={:.5f} for CIFAR-10 valid, {:.5f} for CIFAR-10 test.'.format(time_string(), idx, len(archs), 'Train' if cal_mode else 'Eval', cor_accs_valid, cor_accs_test)) | ||||||
|   model.load_state_dict(weights) |   model.load_state_dict(weights) | ||||||
|   return archs, probs, accuracies |   return archs, probs, accuracies | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user