#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
########################################################
# python exps/NAS-Bench-201/test-correlation.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth
########################################################
import sys, argparse
import numpy as np
from copy import deepcopy
from tqdm import tqdm
import torch
from pathlib import Path

from xautodl.log_utils import time_string
from xautodl.models import CellStructure
from nas_201_api import NASBench201API as API
def check_unique_arch(meta_file):
    """Report how many architectures in NAS-Bench-201 are unique.

    Loads the benchmark from ``meta_file``, parses every architecture string
    into a ``CellStructure``, and prints the number of unique architectures
    under three canonicalization modes of ``to_unique_str`` (None / False /
    True for ``consider_zero``).

    Args:
        meta_file: path to the NAS-Bench-201 benchmark file.
    """
    api = API(str(meta_file))
    arch_strs = deepcopy(api.meta_archs)
    xarchs = [CellStructure.str2structure(x) for x in arch_strs]

    def get_unique_matrix(archs, consider_zero):
        # Architectures that map to the same canonical string are identical.
        unique_strs = [arch.to_unique_str(consider_zero) for arch in archs]
        print(
            "{:} create unique-string ({:}/{:}) done".format(
                time_string(), len(set(unique_strs)), len(unique_strs)
            )
        )
        unique2index = dict()
        for index, xstr in enumerate(unique_strs):
            if xstr not in unique2index:
                unique2index[xstr] = list()
            unique2index[xstr].append(index)
        # sm_matrix[i, j] is True iff architectures i and j are identical.
        sm_matrix = torch.eye(len(archs)).bool()
        for xlist in unique2index.values():
            for i in xlist:
                for j in xlist:
                    sm_matrix[i, j] = True
        # Assign one shared id per group of identical architectures.
        unique_ids, unique_num = [-1 for _ in archs], 0
        for i in range(len(unique_ids)):
            if unique_ids[i] > -1:
                continue
            neighbours = sm_matrix[i].nonzero().view(-1).tolist()
            for nghb in neighbours:
                assert unique_ids[nghb] == -1, "impossible"
                unique_ids[nghb] = unique_num
            unique_num += 1  # fixed: source had the invalid token sequence `+ =`
        return sm_matrix, unique_ids, unique_num

    print(
        "There are {:} valid-archs".format(sum(arch.check_valid() for arch in xarchs))
    )
    sm_matrix, unique_ids, unique_num = get_unique_matrix(xarchs, None)
    print(
        "{:} There are {:} unique architectures (considering nothing).".format(
            time_string(), unique_num
        )
    )
    sm_matrix, unique_ids, unique_num = get_unique_matrix(xarchs, False)
    print(
        "{:} There are {:} unique architectures (not considering zero).".format(
            time_string(), unique_num
        )
    )
    sm_matrix, unique_ids, unique_num = get_unique_matrix(xarchs, True)
    print(
        "{:} There are {:} unique architectures (considering zero).".format(
            time_string(), unique_num
        )
    )
2019-12-31 12:02:11 +01:00
2021-03-18 09:02:55 +01:00
def check_cor_for_bandit(
    meta_file, test_epoch, use_less_or_not, is_rand=True, need_print=False
):
    """Correlate short-training cifar10-valid accuracy with final accuracies.

    For every architecture, queries the validation accuracy on cifar10-valid
    after ``test_epoch`` epochs (optionally the 12-epoch "less" setting), then
    correlates it against the fully-trained valid/test accuracies on
    cifar10, cifar100 and ImageNet16-120.

    Args:
        meta_file: an ``API`` instance or a path to the benchmark file.
        test_epoch: number of training epochs for the proxy accuracy.
        use_less_or_not: if True, use the 12-epoch ("less") training results.
        is_rand: if True, sample a random trial/seed per query.
        need_print: if True, print each correlation as it is computed.

    Returns:
        List of six Pearson correlation coefficients, ordered as
        C-010-V, C-010-T, C-100-V, C-100-T, I16-V, I16-T.
    """
    if isinstance(meta_file, API):
        api = meta_file
    else:
        api = API(str(meta_file))
    cifar10_currs = []
    cifar10_valid = []
    cifar10_test = []
    cifar100_valid = []
    cifar100_test = []
    imagenet_test = []
    imagenet_valid = []
    for idx, arch in enumerate(api):
        # Proxy signal: accuracy after only `test_epoch` epochs of training.
        results = api.get_more_info(
            idx, "cifar10-valid", test_epoch - 1, use_less_or_not, is_rand
        )
        cifar10_currs.append(results["valid-accuracy"])
        # --->>>>> fully-trained reference accuracies below
        results = api.get_more_info(idx, "cifar10-valid", None, False, is_rand)
        cifar10_valid.append(results["valid-accuracy"])
        results = api.get_more_info(idx, "cifar10", None, False, is_rand)
        cifar10_test.append(results["test-accuracy"])
        results = api.get_more_info(idx, "cifar100", None, False, is_rand)
        cifar100_test.append(results["test-accuracy"])
        cifar100_valid.append(results["valid-accuracy"])
        results = api.get_more_info(idx, "ImageNet16-120", None, False, is_rand)
        imagenet_test.append(results["test-accuracy"])
        imagenet_valid.append(results["valid-accuracy"])

    def get_cor(A, B):
        # Pearson correlation coefficient between two equal-length sequences.
        return float(np.corrcoef(A, B)[0, 1])

    cors = []
    for basestr, xlist in zip(
        ["C-010-V", "C-010-T", "C-100-V", "C-100-T", "I16-V", "I16-T"],
        [
            cifar10_valid,
            cifar10_test,
            cifar100_valid,
            cifar100_test,
            imagenet_valid,
            imagenet_test,
        ],
    ):
        correlation = get_cor(cifar10_currs, xlist)
        if need_print:
            print(
                "With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}".format(
                    test_epoch,
                    "012" if use_less_or_not else "200",
                    basestr,
                    correlation,
                )
            )
        cors.append(correlation)
    return cors
2019-12-31 12:02:11 +01:00
2020-01-01 12:18:42 +01:00
def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand):
    """Repeat ``check_cor_for_bandit`` 100 times and print mean/std per metric.

    Args:
        meta_file: an ``API`` instance or a path to the benchmark file.
        test_epoch: number of training epochs for the proxy accuracy.
        use_less_or_not: if True, use the 12-epoch ("less") training results.
        is_rand: if True, sample a random trial/seed per query (this is what
            makes the 100 repetitions differ).
    """
    corrs = []
    for _ in tqdm(range(100)):
        x = check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand, False)
        corrs.append(x)
    # xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T']
    xstrs = ["C-010-V", "C-010-T", "C-100-V", "C-100-T", "I16-V", "I16-T"]
    correlations = np.array(corrs)  # shape: (100, 6)
    print(
        "------>>>>>>>> {:03d}/{:} >>>>>>>>------".format(
            test_epoch, "012" if use_less_or_not else "200"
        )
    )
    for idx, xstr in enumerate(xstrs):
        print(
            "{:8s} ::: mean={:.4f}, std={:.4f} :: {:.4f}\\pm{:.4f}".format(
                xstr,
                correlations[:, idx].mean(),
                correlations[:, idx].std(),
                correlations[:, idx].mean(),
                correlations[:, idx].std(),
            )
        )
    print("")
2020-01-01 12:18:42 +01:00
2021-03-17 10:25:58 +01:00
if __name__ == "__main__":
    parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./output/search-cell-nas-bench-201/visuals",
        help="The base-name of folder to save checkpoints and log.",
    )
    parser.add_argument(
        "--api_path",
        type=str,
        default=None,
        help="The path to the NAS-Bench-201 benchmark file.",
    )
    args = parser.parse_args()

    vis_save_dir = Path(args.save_dir)
    vis_save_dir.mkdir(parents=True, exist_ok=True)
    meta_file = Path(args.api_path)
    assert meta_file.exists(), "invalid path for api : {:}".format(meta_file)

    # check_unique_arch(meta_file)
    # Load the benchmark once and reuse it across all correlation runs.
    api = API(str(meta_file))
    # for iepoch in [11, 25, 50, 100, 150, 175, 200]:
    #   check_cor_for_bandit(api, 6, iepoch)
    #   check_cor_for_bandit(api, 12, iepoch)
    check_cor_for_bandit_v2(api, 6, True, True)
    check_cor_for_bandit_v2(api, 12, True, True)
    check_cor_for_bandit_v2(api, 12, False, True)
    check_cor_for_bandit_v2(api, 24, False, True)
    check_cor_for_bandit_v2(api, 100, False, True)
    check_cor_for_bandit_v2(api, 150, False, True)
    check_cor_for_bandit_v2(api, 175, False, True)
    check_cor_for_bandit_v2(api, 200, False, True)
    print("----")