##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 ##########################
##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
##############################################################################
# Run via: pytest --capture=tee-sys                                          #
##############################################################################
"""This file is used to quickly test the API."""
import os
import random

import pytest

from nats_bench.api_size import ALL_BASE_NAMES as sss_base_names
from nats_bench.api_size import NATSsize
from nats_bench.api_topology import ALL_BASE_NAMES as tss_base_names
from nats_bench.api_topology import NATStopology
def get_fake_torch_home_dir():
    """Return the directory holding the (fake) benchmark files.

    Reads the ``FAKE_TORCH_HOME`` environment variable; raises KeyError if
    it is not set.  (The original mangled literal ``' FAKE_TORCH_HOME '``
    contained spurious spaces and could never match the real variable name.)
    """
    return os.environ['FAKE_TORCH_HOME']
class TestNATSBench(object):
    """Smoke tests for the NATS-Bench topology (TSS) and size (SSS) APIs."""

    def test_nats_bench_tss(self, benchmark_dir=None, fake_random=True):
        """Exercise the topology search space benchmark.

        NOTE(fix): the path must be built from ``tss_base_names`` — the
        original code used ``sss_base_names`` here, loading a size-search-space
        file into the topology API.
        """
        if benchmark_dir is None:
            benchmark_dir = os.path.join(
                get_fake_torch_home_dir(), tss_base_names[-1] + '-simple')
        return _test_nats_bench(benchmark_dir, True, fake_random)

    def test_nats_bench_sss(self, benchmark_dir=None, fake_random=True):
        """Exercise the size search space benchmark.

        NOTE(fix): uses ``sss_base_names`` (the original mistakenly used
        ``tss_base_names``), mirroring ``test_nats_bench_tss``.
        """
        if benchmark_dir is None:
            benchmark_dir = os.path.join(
                get_fake_torch_home_dir(), sss_base_names[-1] + '-simple')
        return _test_nats_bench(benchmark_dir, False, fake_random)

    def test_01_th_issue(self):
        """Regression test for https://github.com/D-X-Y/NATS-Bench/issues/1."""
        print('')
        # A NATStopology must be given a TSS benchmark file (the original code
        # built the path from sss_base_names).
        tss_benchmark_dir = os.path.join(
            get_fake_torch_home_dir(), tss_base_names[-1] + '-simple')
        api = NATStopology(tss_benchmark_dir, True, False)
        # The performance of the 0-th architecture on CIFAR-10 (trained by 12 epochs).
        info = api.get_more_info(0, 'cifar10', hp=12)
        print('The loss on the training set of CIFAR-10: {:}'.format(info['train-loss']))
        print('The total training time for 12 epochs on CIFAR-10: {:}'.format(info['train-all-time']))
        print('The per-epoch training time on CIFAR-10: {:}'.format(info['train-per-time']))
        print('The total evaluation time on the test set of CIFAR-10 for 12 times: {:}'.format(info['test-all-time']))
        print('The evaluation time on the test set of CIFAR-10: {:}'.format(info['test-per-time']))
        # NOTE: the train/validation/test splits of CIFAR-10 in the NATS-Bench
        # paper differ from the original CIFAR paper.
        cost_info = api.get_cost_info(0, 'cifar10')
        xkeys = [
            # Per-epoch training cost; the NATS-Bench CIFAR-10 training set is
            # a subset of the original CIFAR training set.
            'T-train@epoch',
            'T-train@total',
            # Evaluation cost on the ORIGINAL CIFAR-10 test split, which is
            # the validation + test sets in NATS-Bench.
            'T-ori-test@epoch',
            'T-ori-test@total',  # T-ori-test@epoch * 12 times.
        ]
        for xkey in xkeys:
            print('The cost info [{:}] for 0-th architecture on CIFAR-10 is {:}'.format(
                xkey, cost_info[xkey]))
def _test_nats_bench(benchmark_dir, is_tss, fake_random, verbose=False):
    """The main test entry for NATS-Bench.

    Args:
        benchmark_dir: path to the benchmark data to load.
        is_tss: True -> query the topology search space via NATStopology;
            False -> query the size search space via NATSsize.
        fake_random: True -> use a fixed, reproducible set of architecture
            indices; False -> sample 10 indices at random.
        verbose: forwarded to the API constructor.

    Raises:
        ValueError: re-checked at the end — querying an out-of-range
            architecture index must fail.
    """
    if is_tss:
        api = NATStopology(benchmark_dir, True, verbose)
    else:
        api = NATSsize(benchmark_dir, True, verbose)

    if fake_random:
        test_indexes = [0, 11, 241]
    else:
        test_indexes = [random.randint(0, len(api) - 1) for _ in range(10)]

    key2dataset = {'cifar10': 'CIFAR-10',
                   'cifar100': 'CIFAR-100',
                   'ImageNet16-120': 'ImageNet16-120'}

    for index in test_indexes:
        print('\n\nEvaluate the {:5d}-th architecture.'.format(index))
        for key, dataset in key2dataset.items():
            # Query the loss / accuracy / time for the `index`-th candidate
            # architecture; info is a dict whose meaning is clear from its keys.
            info = api.get_more_info(index, key)
            print('-->> The performance on {:}: {:}'.format(dataset, info))
            # Query the flops, params, latency. info is a dict.
            info = api.get_cost_info(index, key)
            print('-->> The cost info on {:}: {:}'.format(dataset, info))
            # Simulate the training of the `index`-th candidate.
            validation_accuracy, latency, time_cost, current_total_time_cost = api.simulate_train_eval(
                index, dataset=key, hp='12')
            print('-->> The validation accuracy={:}, latency={:}, '
                  'the current time cost={:} s, accumulated time cost={:} s'
                  .format(validation_accuracy, latency, time_cost,
                          current_total_time_cost))
            # Print the configuration of the `index`-th architecture.
            config = api.get_net_config(index, key)
            print('-->> The configuration on {:} is {:}'.format(dataset, config))
        # Show the information of the `index`-th architecture.
        api.show(index)

    # An out-of-range architecture index must raise ValueError.
    with pytest.raises(ValueError):
        api.get_more_info(100000, 'cifar10')