##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 ##########################
##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
##############################################################################
# pytest --capture=tee-sys                                                   #
##############################################################################
"""This file is used to quickly test the API."""
import os
import pytest
import random
from nats_bench.api_size import NATSsize
from nats_bench.api_size import ALL_BASE_NAMES as sss_base_names
from nats_bench.api_topology import NATStopology
from nats_bench.api_topology import ALL_BASE_NAMES as tss_base_names


def get_fake_torch_home_dir():
    """Return the directory standing in for ``$TORCH_HOME`` during tests.

    Raises:
        KeyError: if the ``FAKE_TORCH_HOME`` environment variable is unset.
    """
    fake_home = os.environ['FAKE_TORCH_HOME']
    return fake_home


class TestNATSBench(object):
    """Smoke tests covering both NATS-Bench search-space APIs."""

    def test_nats_bench_tss(self, benchmark_dir=None, fake_random=True):
        """Run the end-to-end checks against the topology search space (TSS).

        Args:
            benchmark_dir: directory holding the benchmark files; defaults to
                the TSS benchmark under the fake ``$TORCH_HOME`` directory.
            fake_random: if True, probe a fixed set of architecture indexes.
        """
        if benchmark_dir is None:
            # Fix: the TSS test must point at the topology-search-space files;
            # it previously (and wrongly) built the path from sss_base_names.
            benchmark_dir = os.path.join(get_fake_torch_home_dir(), tss_base_names[-1] + '-simple')
        return _test_nats_bench(benchmark_dir, True, fake_random)

    def test_nats_bench_sss(self, benchmark_dir=None, fake_random=True):
        """Run the end-to-end checks against the size search space (SSS).

        Args:
            benchmark_dir: directory holding the benchmark files; defaults to
                the SSS benchmark under the fake ``$TORCH_HOME`` directory.
            fake_random: if True, probe a fixed set of architecture indexes.
        """
        if benchmark_dir is None:
            # Fix: the SSS test must point at the size-search-space files;
            # it previously (and wrongly) built the path from tss_base_names.
            benchmark_dir = os.path.join(get_fake_torch_home_dir(), sss_base_names[-1] + '-simple')
        return _test_nats_bench(benchmark_dir, False, fake_random)

    def test_01_th_issue(self):
        """Regression test for https://github.com/D-X-Y/NATS-Bench/issues/1."""
        print('')
        # Fix: this check targets the topology API (NATStopology), so the
        # benchmark path must be built from tss_base_names, not sss_base_names.
        tss_benchmark_dir = os.path.join(get_fake_torch_home_dir(), tss_base_names[-1] + '-simple')
        api = NATStopology(tss_benchmark_dir, True, False)
        # The performance of 0-th architecture on CIFAR-10 (trained by 12 epochs)
        info = api.get_more_info(0, 'cifar10', hp=12)
        print('The loss on the training set of CIFAR-10: {:}'.format(info['train-loss']))
        print('The total training time for 12 epochs on CIFAR-10: {:}'.format(info['train-all-time']))
        print('The per-epoch training time on CIFAR-10: {:}'.format(info['train-per-time']))
        print('The total evaluation time on the test set of CIFAR-10 for 12 times: {:}'.format(info['test-all-time']))
        print('The evaluation time on the test set of CIFAR-10: {:}'.format(info['test-per-time']))
        # Please note that the splits of train/validation/test on CIFAR-10 in
        # the NATS-Bench paper differ from the original CIFAR paper.
        cost_info = api.get_cost_info(0, 'cifar10')
        xkeys = ['T-train@epoch',  # Per-epoch training cost on the NATS-Bench CIFAR-10 training subset
                 'T-train@total',
                 'T-ori-test@epoch',  # Evaluation cost on the original CIFAR-10 test split
                                      # (= NATS-Bench validation + test sets).
                 'T-ori-test@total']  # T-ori-test@epoch * 12 times.
        for xkey in xkeys:
            print('The cost info [{:}] for 0-th architecture on CIFAR-10 is {:}'.format(xkey, cost_info[xkey]))


def _test_nats_bench(benchmark_dir, is_tss, fake_random, verbose=False):
    """The main test entry for NATS-Bench.

    Args:
        benchmark_dir: directory containing the benchmark data files.
        is_tss: select the topology (True) or size (False) search-space API.
        fake_random: if True, probe a fixed trio of architecture indexes
            instead of sampling ten indexes uniformly at random.
        verbose: forwarded to the API constructor.
    """
    # Pick and build the API class matching the requested search space.
    api_ctor = NATStopology if is_tss else NATSsize
    api = api_ctor(benchmark_dir, True, verbose)

    if fake_random:
        test_indexes = [0, 11, 241]
    else:
        test_indexes = [random.randint(0, len(api) - 1) for _ in range(10)]

    key2dataset = {'cifar10': 'CIFAR-10',
                   'cifar100': 'CIFAR-100',
                   'ImageNet16-120': 'ImageNet16-120'}
    for arch_index in test_indexes:
        print('\n\nEvaluate the {:5d}-th architecture.'.format(arch_index))
        for data_key, data_name in key2dataset.items():
            # Query the loss / accuracy / time of this candidate architecture
            # on the dataset; the returned dict's keys are self-describing.
            perf_info = api.get_more_info(arch_index, data_key)
            print(' -->> The performance on {:}: {:}'.format(data_name, perf_info))
            # Query the flops, params and latency; also returned as a dict.
            cost_info = api.get_cost_info(arch_index, data_key)
            print(' -->> The cost info on {:}: {:}'.format(data_name, cost_info))
            # Simulate training this candidate (12-epoch hyper-parameter set).
            validation_accuracy, latency, time_cost, current_total_time_cost = api.simulate_train_eval(
                arch_index, dataset=data_key, hp='12')
            print(' -->> The validation accuracy={:}, latency={:}, '
                  'the current time cost={:} s, accumulated time cost={:} s'
                  .format(validation_accuracy, latency, time_cost,
                          current_total_time_cost))
            # Fetch the network configuration of this candidate on the dataset.
            net_config = api.get_net_config(arch_index, data_key)
            print(' -->> The configuration on {:} is {:}'.format(data_name, net_config))
        # Show a human-readable summary of this candidate architecture.
        api.show(arch_index)

    # An out-of-range architecture index must be rejected with a ValueError.
    with pytest.raises(ValueError):
        api.get_more_info(100000, 'cifar10')