update code styles
This commit is contained in:
parent
ad34af9913
commit
96152a9904
@ -1,6 +1,6 @@
|
|||||||
MIT License
|
MIT License
|
||||||
|
|
||||||
Copyright (c) 2019 Xuanyi Dong [GitHub: https://github.com/D-X-Y]
|
Copyright (c) 2019 Xuanyi Dong (GitHub: https://github.com/D-X-Y)
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
@ -6,9 +6,9 @@ Each edge here is associated with an operation selected from a predefined operat
|
|||||||
For it to be applicable for all NAS algorithms, the search space defined in NAS-Bench-102 includes 4 nodes and 5 associated operation options, which generates 15,625 neural cell candidates in total.
|
For it to be applicable for all NAS algorithms, the search space defined in NAS-Bench-102 includes 4 nodes and 5 associated operation options, which generates 15,625 neural cell candidates in total.
|
||||||
|
|
||||||
In this Markdown file, we provide:
|
In this Markdown file, we provide:
|
||||||
- [How to Use NAS-Bench-102](#how-to-use-nas-bench-102)
|
- [How to Use NAS-Bench-102](#how-to-use-nas-bench-102)
|
||||||
- [Instruction to re-generate NAS-Bench-102](#instruction-to-re-generate-nas-bench-102)
|
- [Instruction to re-generate NAS-Bench-102](#instruction-to-re-generate-nas-bench-102)
|
||||||
- [10 NAS algorithms evaluated in our paper](#to-reproduce-10-baseline-nas-algorithms-in-nas-bench-102)
|
- [10 NAS algorithms evaluated in our paper](#to-reproduce-10-baseline-nas-algorithms-in-nas-bench-102)
|
||||||
|
|
||||||
Note: please use `PyTorch >= 1.2.0` and `Python >= 3.6.0`.
|
Note: please use `PyTorch >= 1.2.0` and `Python >= 3.6.0`.
|
||||||
|
|
||||||
|
@ -87,7 +87,8 @@ def test_one_shot_model(ckpath, use_train):
|
|||||||
ckp = torch.load(ckpath)
|
ckp = torch.load(ckpath)
|
||||||
xargs = ckp['args']
|
xargs = ckp['args']
|
||||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||||
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None)
|
#config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None)
|
||||||
|
config = load_config('./configs/nas-benchmark/algos/DARTS.config', {'class_num': class_num, 'xshape': xshape}, None)
|
||||||
if xargs.dataset == 'cifar10':
|
if xargs.dataset == 'cifar10':
|
||||||
cifar_split = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
|
cifar_split = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
|
||||||
xvalid_data = deepcopy(train_data)
|
xvalid_data = deepcopy(train_data)
|
||||||
|
@ -15,14 +15,16 @@ def evaluate_one_shot(model, xloader, api, cal_mode, seed=111):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits = nn.functional.log_softmax(model.arch_parameters, dim=-1)
|
logits = nn.functional.log_softmax(model.arch_parameters, dim=-1)
|
||||||
archs = CellStructure.gen_all(model.op_names, model.max_nodes, False)
|
archs = CellStructure.gen_all(model.op_names, model.max_nodes, False)
|
||||||
probs, accuracies, gt_accs = [], [], []
|
probs, accuracies, gt_accs_10_valid, gt_accs_10_test = [], [], [], []
|
||||||
loader_iter = iter(xloader)
|
loader_iter = iter(xloader)
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
random.shuffle(archs)
|
random.shuffle(archs)
|
||||||
for idx, arch in enumerate(archs):
|
for idx, arch in enumerate(archs):
|
||||||
arch_index = api.query_index_by_arch( arch )
|
arch_index = api.query_index_by_arch( arch )
|
||||||
metrics = api.get_more_info(arch_index, 'cifar10-valid', None, False, False)
|
metrics = api.get_more_info(arch_index, 'cifar10-valid', None, False, False)
|
||||||
gt_accs.append( metrics['valid-accuracy'] )
|
gt_accs_10_valid.append( metrics['valid-accuracy'] )
|
||||||
|
metrics = api.get_more_info(arch_index, 'cifar10', None, False, False)
|
||||||
|
gt_accs_10_test.append( metrics['test-accuracy'] )
|
||||||
select_logits = []
|
select_logits = []
|
||||||
for i, node_info in enumerate(arch.nodes):
|
for i, node_info in enumerate(arch.nodes):
|
||||||
for op, xin in node_info:
|
for op, xin in node_info:
|
||||||
@ -31,8 +33,9 @@ def evaluate_one_shot(model, xloader, api, cal_mode, seed=111):
|
|||||||
select_logits.append( logits[model.edge2index[node_str], op_index] )
|
select_logits.append( logits[model.edge2index[node_str], op_index] )
|
||||||
cur_prob = sum(select_logits).item()
|
cur_prob = sum(select_logits).item()
|
||||||
probs.append( cur_prob )
|
probs.append( cur_prob )
|
||||||
cor_prob = np.corrcoef(probs, gt_accs)[0,1]
|
cor_prob_valid = np.corrcoef(probs, gt_accs_10_valid)[0,1]
|
||||||
print ('correlation for probabilities : {:}'.format(cor_prob))
|
cor_prob_test = np.corrcoef(probs, gt_accs_10_test )[0,1]
|
||||||
|
print ('{:} correlation for probabilities : {:.6f} on CIFAR-10 validation and {:.6f} on CIFAR-10 test'.format(time_string(), cor_prob_valid, cor_prob_test))
|
||||||
|
|
||||||
for idx, arch in enumerate(archs):
|
for idx, arch in enumerate(archs):
|
||||||
model.set_cal_mode('dynamic', arch)
|
model.set_cal_mode('dynamic', arch)
|
||||||
@ -45,8 +48,9 @@ def evaluate_one_shot(model, xloader, api, cal_mode, seed=111):
|
|||||||
_, preds = torch.max(logits, dim=-1)
|
_, preds = torch.max(logits, dim=-1)
|
||||||
correct = (preds == targets.cuda() ).float()
|
correct = (preds == targets.cuda() ).float()
|
||||||
accuracies.append( correct.mean().item() )
|
accuracies.append( correct.mean().item() )
|
||||||
if idx != 0 and (idx % 300 == 0 or idx + 1 == len(archs) or idx == 10):
|
if idx != 0 and (idx % 500 == 0 or idx + 1 == len(archs)):
|
||||||
cor_accs = np.corrcoef(accuracies, gt_accs[:idx+1])[0,1]
|
cor_accs_valid = np.corrcoef(accuracies, gt_accs_10_valid[:idx+1])[0,1]
|
||||||
print ('{:} {:03d}/{:03d} mode={:5s}, correlation : accs={:.4f}, arch={:}'.format(time_string(), idx, len(archs), 'Train' if cal_mode else 'Eval', cor_accs, arch))
|
cor_accs_test = np.corrcoef(accuracies, gt_accs_10_test [:idx+1])[0,1]
|
||||||
|
print ('{:} {:05d}/{:05d} mode={:5s}, correlation : accs={:.5f} for CIFAR-10 valid, {:.5f} for CIFAR-10 test.'.format(time_string(), idx, len(archs), 'Train' if cal_mode else 'Eval', cor_accs_valid, cor_accs_test))
|
||||||
model.load_state_dict(weights)
|
model.load_state_dict(weights)
|
||||||
return archs, probs, accuracies
|
return archs, probs, accuracies
|
||||||
|
Loading…
Reference in New Issue
Block a user