Create NATS
This commit is contained in:
		| @@ -1,9 +1,11 @@ | ||||
| ############################################################### | ||||
| # NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) # | ||||
| ############################################################### | ||||
| # NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size | ||||
| ############################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06           # | ||||
| ############################################################### | ||||
| # Usage: python exps/NAS-Bench-201/test-nas-api.py | ||||
| # Usage: python exps/NAS-Bench-201/test-nas-api.py            # | ||||
| ############################################################### | ||||
| import os, sys, time, torch, argparse | ||||
| import numpy as np | ||||
| @@ -21,7 +23,7 @@ import matplotlib.ticker as ticker | ||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from config_utils import dict2config, load_config | ||||
| from nas_201_api import NASBench201API, NASBench301API | ||||
| from nats_bench import create | ||||
| from log_utils import time_string | ||||
| from models import get_cell_based_tiny_net, CellStructure | ||||
|  | ||||
| @@ -97,15 +99,14 @@ def test_issue_81_82(api): | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|  | ||||
|   api201 = NASBench201API(os.path.join(os.environ['TORCH_HOME'], 'NAS-Bench-201-v1_0-e61699.pth'), verbose=True) | ||||
|   api201 = create(os.path.join(os.environ['TORCH_HOME'], 'NAS-Bench-201-v1_0-e61699.pth'), 'topology', True) | ||||
|   test_issue_81_82(api201) | ||||
|   # test_api(api201, False) | ||||
|   print ('Test {:} done'.format(api201)) | ||||
|  | ||||
|   api201 = NASBench201API(None, verbose=True) | ||||
|   api201 = create(None, 'topology', True)  # use the default file path | ||||
|   test_issue_81_82(api201) | ||||
|   test_api(api201, False) | ||||
|   print ('Test {:} done'.format(api201)) | ||||
|  | ||||
|   # api301 = NASBench301API(None, verbose=True) | ||||
|   # test_api(api301, True) | ||||
|   api301 = create(None, 'size', True) | ||||
|   test_api(api301, True) | ||||
|   | ||||
| @@ -16,7 +16,7 @@ from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from config_utils import dict2config | ||||
| # NAS-Bench-201 related module or function | ||||
| from models       import CellStructure, get_cell_based_tiny_net | ||||
| from nas_201_api  import NASBench301API, ArchResults, ResultsCount | ||||
| from nas_201_api  import ArchResults, ResultsCount | ||||
| from procedures   import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders | ||||
|  | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user