Fix minor bugs in test-ww.py
This commit is contained in:
		| @@ -1,7 +1,9 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | ||||||
| ##################################################### | ##################################################### | ||||||
| # [2020.03.09] Upgrade to v1.2 | # [2020.02.25] Initialize the API as v1.1 | ||||||
|  | # [2020.03.09] Upgrade the API to v1.2 | ||||||
|  | # [2020.03.16] Upgrade the API to v1.3 | ||||||
| import os | import os | ||||||
| from setuptools import setup | from setuptools import setup | ||||||
|  |  | ||||||
| @@ -13,7 +15,7 @@ def read(fname='README.md'): | |||||||
|  |  | ||||||
| setup( | setup( | ||||||
|     name = "nas_bench_201", |     name = "nas_bench_201", | ||||||
|     version = "1.2", |     version = "1.3", | ||||||
|     author = "Xuanyi Dong", |     author = "Xuanyi Dong", | ||||||
|     author_email = "dongxuanyi888@gmail.com", |     author_email = "dongxuanyi888@gmail.com", | ||||||
|     description = "API for NAS-Bench-201 (a benchmark for neural architecture search).", |     description = "API for NAS-Bench-201 (a benchmark for neural architecture search).", | ||||||
|   | |||||||
| @@ -37,7 +37,8 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool): | |||||||
|   final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []}) |   final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []}) | ||||||
|   final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []}) |   final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []}) | ||||||
|   for idx in range(len(api)): |   for idx in range(len(api)): | ||||||
|     info = api.get_more_info(idx, data, use_12epochs_result=use_12epochs_result, is_random=False) |     # info = api.get_more_info(idx, data, use_12epochs_result=use_12epochs_result, is_random=False) | ||||||
|  |     # import pdb; pdb.set_trace() | ||||||
|     for key in ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']: |     for key in ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']: | ||||||
|       info = api.get_more_info(idx, key, use_12epochs_result=False, is_random=False) |       info = api.get_more_info(idx, key, use_12epochs_result=False, is_random=False) | ||||||
|       if key == 'cifar10-valid': |       if key == 'cifar10-valid': | ||||||
| @@ -50,7 +51,7 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool): | |||||||
|     config = api.get_net_config(idx, data) |     config = api.get_net_config(idx, data) | ||||||
|     net = get_cell_based_tiny_net(config) |     net = get_cell_based_tiny_net(config) | ||||||
|     api.reload(weight_dir, idx) |     api.reload(weight_dir, idx) | ||||||
|     params = api.get_net_param(idx, data, None) |     params = api.get_net_param(idx, data, None, use_12epochs_result=use_12epochs_result) | ||||||
|     cur_norms = [] |     cur_norms = [] | ||||||
|     for seed, param in params.items(): |     for seed, param in params.items(): | ||||||
|       with torch.no_grad(): |       with torch.no_grad(): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user