Update NATS-Bench API to v1.1
This commit is contained in:
parent
dae387a97d
commit
ff989ba814
@ -7,4 +7,4 @@
|
|||||||
- [2020.07.01] [a45808b] Upgrade NAS-API to the 2.0 version.
|
- [2020.07.01] [a45808b] Upgrade NAS-API to the 2.0 version.
|
||||||
- [2020.09.16] [7052265] Create NATS-BENCH.
|
- [2020.09.16] [7052265] Create NATS-BENCH.
|
||||||
- [2020.10.15] [446262a] Update NATS-BENCH to version 1.0
|
- [2020.10.15] [446262a] Update NATS-BENCH to version 1.0
|
||||||
- [2020.12.20] [59b5696] Update NATS-BENCH to version 1.1
|
- [2020.12.20] [dae387a] Update NATS-BENCH to version 1.1
|
||||||
|
@ -17,7 +17,13 @@ from nats_bench.api_topology import ALL_BASE_NAMES as tss_base_names
|
|||||||
|
|
||||||
|
|
||||||
def get_fake_torch_home_dir():
|
def get_fake_torch_home_dir():
|
||||||
|
print('This file is {:}'.format(os.path.abspath(__file__)))
|
||||||
|
print('The current directory is {:}'.format(os.path.abspath(os.getcwd())))
|
||||||
|
xname = 'FAKE_TORCH_HOME'
|
||||||
|
if xname in os.environ:
|
||||||
return os.environ['FAKE_TORCH_HOME']
|
return os.environ['FAKE_TORCH_HOME']
|
||||||
|
else:
|
||||||
|
return os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'fake_torch_dir')
|
||||||
|
|
||||||
|
|
||||||
class TestNATSBench(object):
|
class TestNATSBench(object):
|
||||||
@ -70,8 +76,10 @@ class TestNATSBench(object):
|
|||||||
print(xinfo)
|
print(xinfo)
|
||||||
print(data[777].train_acc1es)
|
print(data[777].train_acc1es)
|
||||||
|
|
||||||
info_012_epochs = api.get_more_info(284, 'cifar10', hp=200)
|
info_012_epochs = api.get_more_info(284, 'cifar10', hp= 12)
|
||||||
print(info_012_epochs['train-accuracy'])
|
print('Train accuracy for 12 epochs is {:}'.format(info_012_epochs['train-accuracy']))
|
||||||
|
info_200_epochs = api.get_more_info(284, 'cifar10', hp=200)
|
||||||
|
print('Train accuracy for 200 epochs is {:}'.format(info_200_epochs['train-accuracy']))
|
||||||
|
|
||||||
|
|
||||||
def _test_nats_bench(benchmark_dir, is_tss, fake_random, verbose=False):
|
def _test_nats_bench(benchmark_dir, is_tss, fake_random, verbose=False):
|
||||||
|
Loading…
Reference in New Issue
Block a user