#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
#####################################################
# python exps/NAS-Bench-201/test-correlation.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth
#####################################################
import sys
import argparse
from copy import deepcopy
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm

# Make the repository's `lib` directory importable when this file is run
# as a script (it lives two levels below the repo root).
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))

from log_utils import time_string
from models import CellStructure
from nas_201_api import NASBench201API as API
def check_unique_arch(meta_file):
    """Report how many architectures in the benchmark are topologically unique.

    Loads the NAS-Bench-201 API from ``meta_file`` and prints the number of
    unique cell structures under three canonicalization modes of
    ``CellStructure.to_unique_str`` (``None`` / ``False`` / ``True`` for the
    ``consider_zero`` flag).

    Args:
        meta_file: path to the NAS-Bench-201 benchmark file (str or Path).
    """
    api = API(str(meta_file))
    arch_strs = deepcopy(api.meta_archs)
    xarchs = [CellStructure.str2structure(x) for x in arch_strs]

    def get_unique_matrix(archs, consider_zero):
        # Canonical string per architecture; equal strings <=> same architecture.
        UniquStrs = [arch.to_unique_str(consider_zero) for arch in archs]
        print('{:} create unique-string ({:}/{:}) done'.format(
            time_string(), len(set(UniquStrs)), len(UniquStrs)))
        # Group indices by canonical string.
        Unique2Index = dict()
        for index, xstr in enumerate(UniquStrs):
            if xstr not in Unique2Index:
                Unique2Index[xstr] = list()
            Unique2Index[xstr].append(index)
        # Boolean similarity matrix: sm_matrix[i, j] == True iff archs i and j
        # share the same canonical string (diagonal True via eye).
        sm_matrix = torch.eye(len(archs)).bool()
        for _, xlist in Unique2Index.items():
            for i in xlist:
                for j in xlist:
                    sm_matrix[i, j] = True
        # Assign one id per equivalence class.
        unique_ids, unique_num = [-1 for _ in archs], 0
        for i in range(len(unique_ids)):
            if unique_ids[i] > -1:
                continue  # already labelled through an earlier neighbour scan
            neighbours = sm_matrix[i].nonzero().view(-1).tolist()
            for nghb in neighbours:
                assert unique_ids[nghb] == -1, 'impossible'
                unique_ids[nghb] = unique_num
            unique_num += 1  # was mangled as `+ =` in the corrupted source
        return sm_matrix, unique_ids, unique_num

    print('There are {:} valid-archs'.format(sum(arch.check_valid() for arch in xarchs)))
    sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, None)
    print('{:} There are {:} unique architectures (considering nothing).'.format(time_string(), unique_num))
    sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, False)
    print('{:} There are {:} unique architectures (not considering zero).'.format(time_string(), unique_num))
    sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, True)
    print('{:} There are {:} unique architectures (considering zero).'.format(time_string(), unique_num))
def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, need_print=False):
    """Correlate early-training cifar10-valid accuracy with final accuracies.

    For every architecture in the benchmark, takes the cifar10-valid accuracy
    after ``test_epoch`` epochs (the proxy signal a bandit would observe) and
    computes its Pearson correlation with the fully-trained valid/test
    accuracies on cifar10, cifar100 and ImageNet16-120.

    Args:
        meta_file: a NAS-Bench-201 ``API`` instance or a path to the benchmark file.
        test_epoch: number of training epochs for the proxy accuracy.
        use_less_or_not: if True, query the 12-epoch ("less") training setting.
        is_rand: passed to ``api.get_more_info`` — pick a random trial seed.
        need_print: if True, print each correlation as it is computed.

    Returns:
        list of six correlation floats, ordered as
        [C-010-V, C-010-T, C-100-V, C-100-T, I16-V, I16-T].
    """
    if isinstance(meta_file, API):
        api = meta_file
    else:
        api = API(str(meta_file))

    cifar10_currs = []
    cifar10_valid = []
    cifar10_test = []
    cifar100_valid = []
    cifar100_test = []
    imagenet_valid = []
    imagenet_test = []
    for idx, arch in enumerate(api):
        # Proxy: accuracy at epoch `test_epoch` (0-based index, hence -1).
        results = api.get_more_info(idx, 'cifar10-valid', test_epoch - 1, use_less_or_not, is_rand)
        cifar10_currs.append(results['valid-accuracy'])
        # --->>>>> fully-trained references below.
        results = api.get_more_info(idx, 'cifar10-valid', None, False, is_rand)
        cifar10_valid.append(results['valid-accuracy'])
        results = api.get_more_info(idx, 'cifar10', None, False, is_rand)
        cifar10_test.append(results['test-accuracy'])
        results = api.get_more_info(idx, 'cifar100', None, False, is_rand)
        cifar100_test.append(results['test-accuracy'])
        cifar100_valid.append(results['valid-accuracy'])
        results = api.get_more_info(idx, 'ImageNet16-120', None, False, is_rand)
        imagenet_test.append(results['test-accuracy'])
        imagenet_valid.append(results['valid-accuracy'])

    def get_cor(A, B):
        # Pearson correlation coefficient between two accuracy vectors.
        return float(np.corrcoef(A, B)[0, 1])

    cors = []
    for basestr, xlist in zip(['C-010-V', 'C-010-T', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'],
                              [cifar10_valid, cifar10_test, cifar100_valid, cifar100_test,
                               imagenet_valid, imagenet_test]):
        correlation = get_cor(cifar10_currs, xlist)
        if need_print:
            print('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(
                test_epoch, '012' if use_less_or_not else '200', basestr, correlation))
        cors.append(correlation)
    return cors
def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand):
    """Run ``check_cor_for_bandit`` 100 times and report mean/std correlations.

    Repetition only matters when ``is_rand`` is True (each call then samples a
    random trial seed inside the API); the mean and std over the repeats are
    printed for each of the six correlation targets.

    Args:
        meta_file: a NAS-Bench-201 ``API`` instance or a path to the benchmark file.
        test_epoch: number of training epochs for the proxy accuracy.
        use_less_or_not: if True, query the 12-epoch ("less") training setting.
        is_rand: pick a random trial seed per query.
    """
    corrs = []
    for i in tqdm(range(100)):
        x = check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand, False)
        corrs.append(x)
    xstrs = ['C-010-V', 'C-010-T', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T']
    correlations = np.array(corrs)  # shape: (100 repeats, 6 targets)
    print('------>>>>>>>>{:03d}/{:}>>>>>>>>------'.format(test_epoch, '012' if use_less_or_not else '200'))
    for idx, xstr in enumerate(xstrs):
        print('{:8s} ::: mean={:.4f}, std={:.4f} :: {:.4f}\\pm{:.4f}'.format(
            xstr,
            correlations[:, idx].mean(), correlations[:, idx].std(),
            correlations[:, idx].mean(), correlations[:, idx].std()))
    print('')
if __name__ == '__main__':
    parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
    parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-201/visuals',
                        help='The base-name of folder to save checkpoints and log.')
    parser.add_argument('--api_path', type=str, default=None,
                        help='The path to the NAS-Bench-201 benchmark file.')
    args = parser.parse_args()

    vis_save_dir = Path(args.save_dir)
    vis_save_dir.mkdir(parents=True, exist_ok=True)
    meta_file = Path(args.api_path)
    assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)

    # check_unique_arch(meta_file)  # optional: uniqueness analysis
    api = API(str(meta_file))
    # Correlations for the 12-epoch ("less") setting, then the 200-epoch setting
    # at increasing proxy-training lengths.
    check_cor_for_bandit_v2(api, 6, True, True)
    check_cor_for_bandit_v2(api, 12, True, True)
    check_cor_for_bandit_v2(api, 12, False, True)
    check_cor_for_bandit_v2(api, 24, False, True)
    check_cor_for_bandit_v2(api, 100, False, True)
    check_cor_for_bandit_v2(api, 150, False, True)
    check_cor_for_bandit_v2(api, 175, False, True)
    check_cor_for_bandit_v2(api, 200, False, True)
    print('----')