Update the query_by_arch function in API to be compatiable with the submission version of NAS-Bench-201
This commit is contained in:
		| @@ -23,7 +23,7 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | |||||||
| from config_utils import dict2config, load_config | from config_utils import dict2config, load_config | ||||||
| from nas_201_api import NASBench201API, NASBench301API | from nas_201_api import NASBench201API, NASBench301API | ||||||
| from log_utils import time_string | from log_utils import time_string | ||||||
| from models import get_cell_based_tiny_net | from models import get_cell_based_tiny_net, CellStructure | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_api(api, is_301=True): | def test_api(api, is_301=True): | ||||||
| @@ -80,6 +80,11 @@ def test_issue_81_82(api): | |||||||
|   print(results[888].get_eval('valid')) |   print(results[888].get_eval('valid')) | ||||||
|   print(results[888].get_eval('x-valid')) |   print(results[888].get_eval('x-valid')) | ||||||
|   result_dict = api.get_more_info(index=0, dataset='cifar10-valid', iepoch=11, hp='200', is_random=False) |   result_dict = api.get_more_info(index=0, dataset='cifar10-valid', iepoch=11, hp='200', is_random=False) | ||||||
|  |   info = api.query_by_arch('|nor_conv_3x3~0|+|skip_connect~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|', '200') | ||||||
|  |   print(info) | ||||||
|  |   structure = CellStructure.str2structure('|nor_conv_3x3~0|+|skip_connect~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|') | ||||||
|  |   info = api.query_by_arch(structure, '200') | ||||||
|  |   print(info) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|   | |||||||
| @@ -4,7 +4,6 @@ | |||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_combination(space, num): | def get_combination(space, num): | ||||||
|   combs = [] |   combs = [] | ||||||
|   for i in range(num): |   for i in range(num): | ||||||
| @@ -21,7 +20,6 @@ def get_combination(space, num): | |||||||
|   return combs |   return combs | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Structure: | class Structure: | ||||||
|  |  | ||||||
|   def __init__(self, genotype): |   def __init__(self, genotype): | ||||||
|   | |||||||
| @@ -123,7 +123,7 @@ class NASBench201API(NASBenchMetaAPI): | |||||||
|     """ |     """ | ||||||
|     if self.verbose: |     if self.verbose: | ||||||
|       print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp)) |       print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp)) | ||||||
|     self._query_info_str_by_arch(arch, hp, print_information) |     return self._query_info_str_by_arch(arch, hp, print_information) | ||||||
|  |  | ||||||
|   # obtain the metric for the `index`-th architecture |   # obtain the metric for the `index`-th architecture | ||||||
|   # `dataset` indicates the dataset: |   # `dataset` indicates the dataset: | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user