Update the query_by_arch function in API to be compatiable with the submission version of NAS-Bench-201
This commit is contained in:
		| @@ -23,7 +23,7 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from config_utils import dict2config, load_config | ||||
| from nas_201_api import NASBench201API, NASBench301API | ||||
| from log_utils import time_string | ||||
| from models import get_cell_based_tiny_net | ||||
| from models import get_cell_based_tiny_net, CellStructure | ||||
|  | ||||
|  | ||||
| def test_api(api, is_301=True): | ||||
| @@ -80,6 +80,11 @@ def test_issue_81_82(api): | ||||
|   print(results[888].get_eval('valid')) | ||||
|   print(results[888].get_eval('x-valid')) | ||||
|   result_dict = api.get_more_info(index=0, dataset='cifar10-valid', iepoch=11, hp='200', is_random=False) | ||||
|   info = api.query_by_arch('|nor_conv_3x3~0|+|skip_connect~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|', '200') | ||||
|   print(info) | ||||
|   structure = CellStructure.str2structure('|nor_conv_3x3~0|+|skip_connect~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|') | ||||
|   info = api.query_by_arch(structure, '200') | ||||
|   print(info) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   | ||||
		Reference in New Issue
	
	Block a user