Update NATS-Bench API to v1.1
This commit is contained in:
		| @@ -7,4 +7,4 @@ | |||||||
| - [2020.07.01] [a45808b] Upgrade NAS-API to the 2.0 version. | - [2020.07.01] [a45808b] Upgrade NAS-API to the 2.0 version. | ||||||
| - [2020.09.16] [7052265] Create NATS-BENCH. | - [2020.09.16] [7052265] Create NATS-BENCH. | ||||||
| - [2020.10.15] [446262a] Update NATS-BENCH to version 1.0 | - [2020.10.15] [446262a] Update NATS-BENCH to version 1.0 | ||||||
| - [2020.12.20] [59b5696] Update NATS-BENCH to version 1.1 | - [2020.12.20] [dae387a] Update NATS-BENCH to version 1.1 | ||||||
|   | |||||||
| @@ -17,7 +17,13 @@ from nats_bench.api_topology import ALL_BASE_NAMES as tss_base_names | |||||||
|  |  | ||||||
|  |  | ||||||
| def get_fake_torch_home_dir(): | def get_fake_torch_home_dir(): | ||||||
|  |   print('This file is {:}'.format(os.path.abspath(__file__))) | ||||||
|  |   print('The current directory is {:}'.format(os.path.abspath(os.getcwd()))) | ||||||
|  |   xname = 'FAKE_TORCH_HOME' | ||||||
|  |   if xname in os.environ: | ||||||
|     return os.environ['FAKE_TORCH_HOME'] |     return os.environ['FAKE_TORCH_HOME'] | ||||||
|  |   else: | ||||||
|  |     return os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'fake_torch_dir') | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestNATSBench(object): | class TestNATSBench(object): | ||||||
| @@ -70,8 +76,10 @@ class TestNATSBench(object): | |||||||
|     print(xinfo) |     print(xinfo) | ||||||
|     print(data[777].train_acc1es) |     print(data[777].train_acc1es) | ||||||
|  |  | ||||||
|     info_012_epochs = api.get_more_info(284, 'cifar10', hp=200) |     info_012_epochs = api.get_more_info(284, 'cifar10', hp= 12) | ||||||
|     print(info_012_epochs['train-accuracy']) |     print('Train accuracy for  12 epochs is {:}'.format(info_012_epochs['train-accuracy'])) | ||||||
|  |     info_200_epochs = api.get_more_info(284, 'cifar10', hp=200) | ||||||
|  |     print('Train accuracy for 200 epochs is {:}'.format(info_200_epochs['train-accuracy'])) | ||||||
|   |   | ||||||
|  |  | ||||||
| def _test_nats_bench(benchmark_dir, is_tss, fake_random, verbose=False): | def _test_nats_bench(benchmark_dir, is_tss, fake_random, verbose=False): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user