Fix minor bugs in test-ww.py
This commit is contained in:
		| @@ -1,7 +1,9 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | ||||
| ##################################################### | ||||
| # [2020.03.09] Upgrade to v1.2 | ||||
| # [2020.02.25] Initialize the API as v1.1 | ||||
| # [2020.03.09] Upgrade the API to v1.2 | ||||
| # [2020.03.16] Upgrade the API to v1.3 | ||||
| import os | ||||
| from setuptools import setup | ||||
|  | ||||
| @@ -13,7 +15,7 @@ def read(fname='README.md'): | ||||
|  | ||||
| setup( | ||||
|     name = "nas_bench_201", | ||||
|     version = "1.2", | ||||
|     version = "1.3", | ||||
|     author = "Xuanyi Dong", | ||||
|     author_email = "dongxuanyi888@gmail.com", | ||||
|     description = "API for NAS-Bench-201 (a benchmark for neural architecture search).", | ||||
|   | ||||
| @@ -37,7 +37,8 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool): | ||||
|   final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []}) | ||||
|   final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []}) | ||||
|   for idx in range(len(api)): | ||||
|     info = api.get_more_info(idx, data, use_12epochs_result=use_12epochs_result, is_random=False) | ||||
|     # info = api.get_more_info(idx, data, use_12epochs_result=use_12epochs_result, is_random=False) | ||||
|     # import pdb; pdb.set_trace() | ||||
|     for key in ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']: | ||||
|       info = api.get_more_info(idx, key, use_12epochs_result=False, is_random=False) | ||||
|       if key == 'cifar10-valid': | ||||
| @@ -50,7 +51,7 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool): | ||||
|     config = api.get_net_config(idx, data) | ||||
|     net = get_cell_based_tiny_net(config) | ||||
|     api.reload(weight_dir, idx) | ||||
|     params = api.get_net_param(idx, data, None) | ||||
|     params = api.get_net_param(idx, data, None, use_12epochs_result=use_12epochs_result) | ||||
|     cur_norms = [] | ||||
|     for seed, param in params.items(): | ||||
|       with torch.no_grad(): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user