diff --git a/CHANGE-LOG.md b/CHANGE-LOG.md index b3d4150..e19e656 100644 --- a/CHANGE-LOG.md +++ b/CHANGE-LOG.md @@ -7,4 +7,4 @@ - [2020.07.01] [a45808b] Upgrade NAS-API to the 2.0 version. - [2020.09.16] [7052265] Create NATS-BENCH. - [2020.10.15] [446262a] Update NATS-BENCH to version 1.0 -- [2020.12.20] [59b5696] Update NATS-BENCH to version 1.1 +- [2020.12.20] [dae387a] Update NATS-BENCH to version 1.1 diff --git a/lib/nats_bench/api_test.py b/lib/nats_bench/api_test.py index 12f9a80..7bedfbf 100644 --- a/lib/nats_bench/api_test.py +++ b/lib/nats_bench/api_test.py @@ -17,7 +17,13 @@ from nats_bench.api_topology import ALL_BASE_NAMES as tss_base_names def get_fake_torch_home_dir(): - return os.environ['FAKE_TORCH_HOME'] + print('This file is {:}'.format(os.path.abspath(__file__))) + print('The current directory is {:}'.format(os.path.abspath(os.getcwd()))) + xname = 'FAKE_TORCH_HOME' + if xname in os.environ: + return os.environ['FAKE_TORCH_HOME'] + else: + return os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'fake_torch_dir') class TestNATSBench(object): @@ -70,8 +76,10 @@ class TestNATSBench(object): print(xinfo) print(data[777].train_acc1es) - info_012_epochs = api.get_more_info(284, 'cifar10', hp=200) - print(info_012_epochs['train-accuracy']) + info_012_epochs = api.get_more_info(284, 'cifar10', hp= 12) + print('Train accuracy for 12 epochs is {:}'.format(info_012_epochs['train-accuracy'])) + info_200_epochs = api.get_more_info(284, 'cifar10', hp=200) + print('Train accuracy for 200 epochs is {:}'.format(info_200_epochs['train-accuracy'])) def _test_nats_bench(benchmark_dir, is_tss, fake_random, verbose=False):