2020-02-23 00:30:37 +01:00
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
#####################################################
2020-01-14 14:52:06 +01:00
# python exps/NAS-Bench-201/visualize.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth
2020-02-23 00:30:37 +01:00
#####################################################
2019-12-29 10:17:26 +01:00
import os , sys , time , argparse , collections
from tqdm import tqdm
2020-01-02 06:49:16 +01:00
from collections import OrderedDict
2019-12-29 10:17:26 +01:00
import numpy as np
import torch
from pathlib import Path
from collections import defaultdict
import matplotlib
import seaborn as sns
from mpl_toolkits . mplot3d import Axes3D
matplotlib . use ( ' agg ' )
import matplotlib . pyplot as plt
lib_dir = ( Path ( __file__ ) . parent / ' .. ' / ' .. ' / ' lib ' ) . resolve ( )
if str ( lib_dir ) not in sys . path : sys . path . insert ( 0 , str ( lib_dir ) )
from log_utils import time_string
2020-01-14 14:52:06 +01:00
from nas_201_api import NASBench201API as API
2019-12-29 10:17:26 +01:00
def calculate_correlation(*vectors):
    """Return the matrix of pairwise correlation coefficients.

    Args:
        *vectors: equally sized 1-D sequences of numbers.

    Returns:
        A ``numpy.ndarray`` whose entry ``[i, j]`` is
        ``np.corrcoef(vectors[i], vectors[j])[0, 1]``. For ``k >= 1`` inputs
        the shape is ``(k, k)``; with no inputs an empty array is returned.
    """
    return np.array([
        [np.corrcoef(row_vector, col_vector)[0, 1] for col_vector in vectors]
        for row_vector in vectors
    ])
def visualize_relative_ranking(vis_save_dir):
    """Plot how architecture rankings and accuracies agree across datasets.

    Loads ``cifar10-cache-info.pth``, ``cifar100-cache-info.pth`` and
    ``ImageNet16-120-cache-info.pth`` from ``vis_save_dir`` and writes into the
    same directory:

    * ``relative-rank.pdf`` / ``relative-rank.png``: for every architecture,
      its rank (ascending test accuracy) on CIFAR-100 and ImageNet16-120
      plotted against its rank on CIFAR-10;
    * ``co-relation-all.pdf``: heatmap of the correlation matrix between the
      validation and test accuracies of the three datasets;
    * ``co-relation-top-{k}.pdf``: the same heatmap restricted to the ``k``
      architectures whose CIFAR-10 test accuracy exceeds 92, then 93.

    Args:
        vis_save_dir: ``pathlib.Path`` of the directory holding the caches;
            the figures are saved there as well.
    """
    print('\n' + '-' * 100)
    cache_name = '{:}-cache-info.pth'
    cifar010_info = torch.load(vis_save_dir / cache_name.format('cifar10'))
    cifar100_info = torch.load(vis_save_dir / cache_name.format('cifar100'))
    imagenet_info = torch.load(vis_save_dir / cache_name.format('ImageNet16-120'))
    indexes = list(range(len(cifar010_info['params'])))

    print('{:} start to visualize relative ranking'.format(time_string()))

    def order_by_test_acc(info):
        # Architecture indexes sorted by ascending test accuracy.
        return sorted(indexes, key=lambda i: info['test_accs'][i])

    def rank_lookup(ordered):
        # Map architecture index -> rank; replaces a per-item list.index scan.
        return {arch: rank for rank, arch in enumerate(ordered)}

    cifar010_ord_indexes = order_by_test_acc(cifar010_info)
    cifar100_rank = rank_lookup(order_by_test_acc(cifar100_info))
    imagenet_rank = rank_lookup(order_by_test_acc(imagenet_info))
    cifar100_labels = [cifar100_rank[idx] for idx in cifar010_ord_indexes]
    imagenet_labels = [imagenet_rank[idx] for idx in cifar010_ord_indexes]
    print('{:} prepare data done.'.format(time_string()))

    dpi, width, height = 300, 2600, 2600
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 18, 18

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    lowest, highest = min(indexes), max(indexes)
    plt.xlim(lowest, highest)
    plt.ylim(lowest, highest)
    ticks = np.arange(lowest, highest, highest // 6)
    plt.yticks(ticks, fontsize=LegendFontsize, rotation='vertical')
    plt.xticks(ticks, fontsize=LegendFontsize)
    ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8)
    ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red', alpha=0.8)
    ax.scatter(indexes, indexes, marker='o', s=0.5, c='tab:blue', alpha=0.8)
    # Points at (-1, -1) lie outside the axis limits; they only provide
    # large-marker legend entries for the tiny data markers above.
    ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue', label='CIFAR-10')
    ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-100')
    ax.scatter([-1], [-1], marker='*', s=100, c='tab:red', label='ImageNet-16-120')
    plt.grid(zorder=0)
    ax.set_axisbelow(True)
    plt.legend(loc=0, fontsize=LegendFontsize)
    ax.set_xlabel('architecture ranking in CIFAR-10', fontsize=LabelSize)
    ax.set_ylabel('architecture ranking', fontsize=LabelSize)
    for ext in ('pdf', 'png'):
        save_path = (vis_save_dir / 'relative-rank.{:}'.format(ext)).resolve()
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format=ext)
    print('{:} save into {:}'.format(time_string(), save_path))

    # calculate correlation
    sns_size = 15

    def save_heatmap(matrix, file_name):
        # Render one annotated correlation heatmap and save it as a PDF.
        heat_fig = plt.figure(figsize=figsize)
        plt.axis('off')
        sns.heatmap(matrix, annot=True, annot_kws={'size': sns_size}, fmt='.3f', linewidths=0.5)
        heat_path = (vis_save_dir / file_name).resolve()
        heat_fig.savefig(heat_path, dpi=dpi, bbox_inches='tight', format='pdf')
        print('{:} save into {:}'.format(time_string(), heat_path))

    # Row/column order of every heatmap: valid then test accuracy for
    # CIFAR-10, CIFAR-100 and ImageNet16-120.
    columns = [(info, key)
               for info in (cifar010_info, cifar100_info, imagenet_info)
               for key in ('valid_accs', 'test_accs')]
    save_heatmap(calculate_correlation(*[info[key] for info, key in columns]),
                 'co-relation-all.pdf')

    # Same correlation, restricted to architectures above a CIFAR-10 test bar.
    for acc_bar in [92, 93]:
        selected_indexes = [i for i, acc in enumerate(cifar010_info['test_accs']) if acc > acc_bar]
        print('select {:} architectures'.format(len(selected_indexes)))
        subsets = [np.array(info[key])[selected_indexes] for info, key in columns]
        save_heatmap(calculate_correlation(*subsets),
                     'co-relation-top-{:}.pdf'.format(len(selected_indexes)))
    plt.close('all')
def visualize_info(meta_file, dataset, vis_save_dir):
    """Scatter-plot accuracy against parameter count / architecture ID.

    The per-architecture statistics are read from
    ``vis_save_dir / '{dataset}-cache-info.pth'`` when that file exists;
    otherwise they are collected from the benchmark at ``meta_file`` and the
    cache is written. Four figures are then saved (PDF and PNG each) into
    ``vis_save_dir``: ``{dataset}-param-vs-valid``, ``{dataset}-param-vs-test``,
    ``{dataset}-param-vs-train`` and ``{dataset}-test-over-ID``. The
    architecture with index 11472 is highlighted as ``resnet`` in every figure.

    Args:
        meta_file: path of the benchmark file handed to ``API``.
        dataset: ``'cifar10'``, ``'cifar100'``, or another dataset name
            accepted by the benchmark (any other value uses the third set of
            axis limits).
        vis_save_dir: ``pathlib.Path`` output / cache directory.

    Raises:
        ValueError: if the cache has to be built but the benchmark has no
            architecture with index 11472.
    """
    print('{:} start to visualize {:} information'.format(time_string(), dataset))
    cache_file_path = vis_save_dir / '{:}-cache-info.pth'.format(dataset)
    resnet_index = 11472  # resnet
    if not cache_file_path.exists():
        print('Do not find cache file : {:}'.format(cache_file_path))
        nas_bench = API(str(meta_file))
        params, flops, train_accs, valid_accs, test_accs, otest_accs = [], [], [], [], [], []
        # (dataset name, split) queried for train / valid / test / otest.
        # CIFAR-10 takes validation from 'cifar10-valid' and uses 'ori-test'
        # for both test columns.
        if dataset == 'cifar10':
            metric_keys = [('cifar10', 'train'), ('cifar10-valid', 'x-valid'),
                           ('cifar10', 'ori-test'), ('cifar10', 'ori-test')]
        else:
            metric_keys = [(dataset, 'train'), (dataset, 'x-valid'),
                           (dataset, 'x-test'), (dataset, 'ori-test')]
        resnet = None
        for index in range(len(nas_bench)):
            arch_info = nas_bench.query_by_index(index, use_12epochs_result=False)
            costs = arch_info.get_comput_costs(dataset)
            flop, param = costs['flops'], costs['params']
            train_acc, valid_acc, test_acc, otest_acc = [
                arch_info.get_metrics(name, split)['accuracy'] for name, split in metric_keys
            ]
            if index == resnet_index:
                resnet = {'params': param, 'flops': flop, 'index': resnet_index,
                          'train_acc': train_acc, 'valid_acc': valid_acc,
                          'test_acc': test_acc, 'otest_acc': otest_acc}
            flops.append(flop)
            params.append(param)
            train_accs.append(train_acc)
            valid_accs.append(valid_acc)
            test_accs.append(test_acc)
            otest_accs.append(otest_acc)
        if resnet is None:
            # Previously this surfaced as an unbound-variable NameError.
            raise ValueError('architecture index {:} is not available in {:}'.format(resnet_index, meta_file))
        info = {'params': params, 'flops': flops, 'train_accs': train_accs,
                'valid_accs': valid_accs, 'test_accs': test_accs, 'otest_accs': otest_accs}
        info['resnet'] = resnet
        torch.save(info, cache_file_path)
    else:
        print('Find cache file : {:}'.format(cache_file_path))
        info = torch.load(cache_file_path)
        params, train_accs = info['params'], info['train_accs']
        valid_accs, test_accs = info['valid_accs'], info['test_accs']
        resnet = info['resnet']
    print('{:} collect data done.'.format(time_string()))

    indexes = list(range(len(params)))
    dpi, width, height = 300, 2600, 2600
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 22, 22
    resnet_scale, resnet_alpha = 120, 0.5

    # ((y-min, y-max), y-ticks) for the valid/test figures and for the
    # training figure, which uses its own limits on two of the datasets.
    if dataset == 'cifar10':
        common_y = train_y = ((50, 100), np.arange(50, 101, 10))
    elif dataset == 'cifar100':
        common_y = ((25, 75), np.arange(25, 76, 10))
        train_y = ((20, 100), np.arange(20, 101, 10))
    else:
        common_y = ((0, 50), np.arange(0, 51, 10))
        train_y = ((25, 76), np.arange(25, 76, 10))

    def draw(xs, ys, resnet_xy, y_axis, x_label, y_label, stem,
             star_alpha=resnet_alpha, grid_kwargs=None, over_index=False):
        # Draw one scatter figure and save it as '{dataset}-{stem}.pdf/.png'.
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)
        if over_index:
            plt.xlim(0, max(indexes))
            plt.xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5), fontsize=LegendFontsize)
        else:
            plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
        (y_low, y_high), y_ticks = y_axis
        plt.ylim(y_low, y_high)
        plt.yticks(y_ticks, fontsize=LegendFontsize)
        ax.scatter(xs, ys, marker='o', s=0.5, c='tab:blue')
        ax.scatter([resnet_xy[0]], [resnet_xy[1]], marker='*', s=resnet_scale,
                   c='tab:orange', label='resnet', alpha=star_alpha)
        plt.grid(**(grid_kwargs or {}))
        ax.set_axisbelow(True)
        plt.legend(loc=4, fontsize=LegendFontsize)
        ax.set_xlabel(x_label, fontsize=LabelSize)
        ax.set_ylabel(y_label, fontsize=LabelSize)
        for ext in ('pdf', 'png'):
            save_path = (vis_save_dir / '{:}-{:}.{:}'.format(dataset, stem, ext)).resolve()
            fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format=ext)
        print('{:} save into {:}'.format(time_string(), save_path))

    # The validation figure keeps its own marker alpha and grid z-order.
    draw(params, valid_accs, (resnet['params'], resnet['valid_acc']), common_y,
         '#parameters (MB)', 'the validation accuracy (%)', 'param-vs-valid',
         star_alpha=0.4, grid_kwargs={'zorder': 0})
    draw(params, test_accs, (resnet['params'], resnet['test_acc']), common_y,
         '#parameters (MB)', 'the test accuracy (%)', 'param-vs-test')
    draw(params, train_accs, (resnet['params'], resnet['train_acc']), train_y,
         '#parameters (MB)', 'the training accuracy (%)', 'param-vs-train')
    draw(indexes, test_accs, (resnet['index'], resnet['test_acc']), common_y,
         'architecture ID', 'the test accuracy (%)', 'test-over-ID', over_index=True)
    plt.close('all')
def visualize_rank_over_time(meta_file, vis_save_dir):
    """Plot, for each of 200 epochs, validation rank against final test rank.

    Per-epoch CIFAR-10 accuracies of every architecture are read from
    ``vis_save_dir / 'rank-over-time-cache-info.pth'`` when present; otherwise
    they are collected from the benchmark at ``meta_file`` and cached there.
    For every epoch ``e`` in ``range(200)`` the figure ``time-{e:03d}.pdf`` /
    ``time-{e:03d}.png`` is written: architectures are ordered along x by
    their test accuracy at epoch 199, and y shows their rank by validation
    accuracy at epoch ``e``.

    Args:
        meta_file: path of the benchmark file handed to ``API``.
        vis_save_dir: ``pathlib.Path`` output / cache directory; created if
            missing.
    """
    print('\n' + '-' * 150)
    vis_save_dir.mkdir(parents=True, exist_ok=True)
    print('{:} start to visualize rank-over-time into {:}'.format(time_string(), vis_save_dir))
    cache_file_path = vis_save_dir / 'rank-over-time-cache-info.pth'
    if not cache_file_path.exists():
        print('Do not find cache file : {:}'.format(cache_file_path))
        nas_bench = API(str(meta_file))
        print('{:} load nas_bench done'.format(time_string()))
        params, flops = [], []
        # Accuracy containers are keyed by epoch; each value lists one entry
        # per architecture in benchmark order.
        train_accs, valid_accs = defaultdict(list), defaultdict(list)
        test_accs, otest_accs = defaultdict(list), defaultdict(list)
        for index in tqdm(range(len(nas_bench))):
            arch_info = nas_bench.query_by_index(index, use_12epochs_result=False)
            for iepoch in range(200):
                train_accs[iepoch].append(arch_info.get_metrics('cifar10', 'train', iepoch)['accuracy'])
                valid_accs[iepoch].append(arch_info.get_metrics('cifar10-valid', 'x-valid', iepoch)['accuracy'])
                test_accs[iepoch].append(arch_info.get_metrics('cifar10', 'ori-test', iepoch)['accuracy'])
                otest_accs[iepoch].append(arch_info.get_metrics('cifar10', 'ori-test', iepoch)['accuracy'])
                if iepoch == 0:
                    # Costs do not depend on the epoch; record them once.
                    costs = arch_info.get_comput_costs('cifar10')
                    flops.append(costs['flops'])
                    params.append(costs['params'])
        info = {'params': params, 'flops': flops, 'train_accs': train_accs,
                'valid_accs': valid_accs, 'test_accs': test_accs, 'otest_accs': otest_accs}
        torch.save(info, cache_file_path)
    else:
        print('Find cache file : {:}'.format(cache_file_path))
        info = torch.load(cache_file_path)
    valid_accs, test_accs = info['valid_accs'], info['test_accs']
    print('{:} collect data done.'.format(time_string()))

    selected_epochs = list(range(200))
    final_tests = test_accs[199]
    indexes = list(range(len(final_tests)))
    ord_idxs = sorted(indexes, key=lambda i: final_tests[i])

    # Figure settings are identical for every epoch.
    dpi, width, height = 300, 2600, 2600
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 18, 18
    lowest, highest = min(indexes), max(indexes)
    ticks = np.arange(lowest, highest, highest // 6)

    for sepoch in selected_epochs:
        epoch_valids = valid_accs[sepoch]
        valid_order = sorted(indexes, key=lambda i: epoch_valids[i])
        # architecture index -> validation rank; avoids a list.index scan
        # per architecture.
        valid_rank = {arch: rank for rank, arch in enumerate(valid_order)}
        valid_ord_lbls = [valid_rank[idx] for idx in ord_idxs]

        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)
        plt.xlim(lowest, highest)
        plt.ylim(lowest, highest)
        plt.yticks(ticks, fontsize=LegendFontsize, rotation='vertical')
        plt.xticks(ticks, fontsize=LegendFontsize)
        ax.scatter(indexes, valid_ord_lbls, marker='^', s=0.5, c='tab:green', alpha=0.8)
        ax.scatter(indexes, indexes, marker='o', s=0.5, c='tab:blue', alpha=0.8)
        # Off-canvas points that only provide large legend markers.
        ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-10 validation')
        ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue', label='CIFAR-10 test')
        plt.grid(zorder=0)
        ax.set_axisbelow(True)
        plt.legend(loc='upper left', fontsize=LegendFontsize)
        ax.set_xlabel('architecture ranking in the final test accuracy', fontsize=LabelSize)
        ax.set_ylabel('architecture ranking in the validation set', fontsize=LabelSize)
        for ext in ('pdf', 'png'):
            save_path = (vis_save_dir / 'time-{:03d}.{:}'.format(sepoch, ext)).resolve()
            fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format=ext)
        # Release the figure right away instead of keeping 200 of them open.
        plt.close(fig)
        print('{:} save into {:}'.format(time_string(), save_path))
    plt.close('all')
def write_video(save_dir):
    """Assemble the ``time-*.png`` frames in ``save_dir`` into ``time.avi``.

    Frames are taken in sorted path order, resized to 1000x1000 and written
    as an MJPG-encoded video at 5 frames per second.

    Args:
        save_dir: ``pathlib.Path`` directory containing the frames; the video
            is written there as well.

    Raises:
        FileNotFoundError: if ``save_dir`` contains no ``time-*.png`` file.
    """
    import cv2  # local import: OpenCV is only needed by this function
    video_save_path = save_dir / 'time.avi'
    print('{:} start create video for {:}'.format(time_string(), video_save_path))
    images = sorted(save_dir.glob('time-*.png'))
    if not images:
        raise FileNotFoundError('no time-*.png frames found in {:}'.format(save_dir))
    shape = (1000, 1000)
    writer = cv2.VideoWriter(str(video_save_path), cv2.VideoWriter_fourcc(*'MJPG'), 5, shape)
    try:
        for image in images:
            frame = cv2.imread(str(image))
            writer.write(cv2.resize(frame, shape))
    finally:
        # Always finalize the container, even if a frame fails to load.
        writer.release()
    print('write video [{:} frames] into {:}'.format(len(images), video_save_path))
2020-01-14 14:52:06 +01:00
def plot_results_nas_v2(api, dataset_xset_a, dataset_xset_b, root, file_name, y_lims):
    """Plot sorted accuracies of four search methods on two (dataset, split) pairs.

    For each of REA, REINFORCE, RANDOM and BOHB the architectures stored in the
    method's hard-coded ``results.pth`` are looked up in ``api``; their
    accuracies on both pairs are sorted jointly (by A, then B) and drawn as a
    solid line (pair A) and a dashed line (pair B). Mean and standard
    deviation of each curve are printed, and the figure is saved as a PDF to
    ``root / file_name``.

    Args:
        api: benchmark object exposing ``arch2infos_full``.
        dataset_xset_a: ``(dataset, split)`` for the solid curves.
        dataset_xset_b: ``(dataset, split)`` for the dashed curves.
        root: ``pathlib.Path`` output directory.
        file_name: name of the saved figure.
        y_lims: ``(y-min, y-max, y-tick step)``.
    """
    print('root-path : {:} and {:}'.format(dataset_xset_a, dataset_xset_b))
    checkpoints = ['./output/search-cell-nas-bench-201/R-EA-cifar10/results.pth',
                   './output/search-cell-nas-bench-201/REINFORCE-cifar10/results.pth',
                   './output/search-cell-nas-bench-201/RAND-cifar10/results.pth',
                   './output/search-cell-nas-bench-201/BOHB-cifar10/results.pth',
                   ]
    legends, indexes = ['REA', 'REINFORCE', 'RANDOM', 'BOHB'], None
    All_Accs_A, All_Accs_B = OrderedDict(), OrderedDict()
    for legend, checkpoint in zip(legends, checkpoints):
        all_indexes = torch.load(checkpoint, map_location='cpu')
        accuracy_pairs = []
        for arch in all_indexes:
            info = api.arch2infos_full[arch]
            acc_a = info.get_metrics(dataset_xset_a[0], dataset_xset_a[1], None, False)['accuracy']
            acc_b = info.get_metrics(dataset_xset_b[0], dataset_xset_b[1], None, False)['accuracy']
            accuracy_pairs.append((acc_a, acc_b))
        # The x axis is taken from the first method's number of runs.
        if indexes is None:
            indexes = list(range(len(all_indexes)))
        accuracy_pairs.sort()
        All_Accs_A[legend] = [acc_a for acc_a, _ in accuracy_pairs]
        All_Accs_B[legend] = [acc_b for _, acc_b in accuracy_pairs]

    color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
    dpi, width, height = 300, 3400, 2600
    LabelSize, LegendFontsize = 28, 28
    figsize = width / float(dpi), height / float(dpi)
    fig = plt.figure(figsize=figsize)
    last_index = max(indexes)
    plt.xlim(0, last_index)
    plt.ylim(y_lims[0], y_lims[1])
    interval_x, interval_y = 100, y_lims[2]
    plt.xticks(np.arange(0, last_index, interval_x), fontsize=LegendFontsize)
    plt.yticks(np.arange(y_lims[0], y_lims[1], interval_y), fontsize=LegendFontsize)
    plt.grid()
    plt.xlabel('The index of runs', fontsize=LabelSize)
    plt.ylabel('The accuracy (%)', fontsize=LabelSize)

    for idx, legend in enumerate(legends):
        plt.plot(indexes, All_Accs_B[legend], color=color_set[idx], linestyle='--',
                 label='{:}'.format(legend), lw=1, alpha=0.5)
        plt.plot(indexes, All_Accs_A[legend], color=color_set[idx], linestyle='-', lw=1)
        for All_Accs in [All_Accs_A, All_Accs_B]:
            mean, std = np.mean(All_Accs[legend]), np.std(All_Accs[legend])
            print('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, mean, std, mean, std))
    plt.legend(loc=4, fontsize=LegendFontsize)
    save_path = root / '{:}'.format(file_name)
    print('save figure into {:}\n'.format(save_path))
    fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
    # The figure is not returned, so release it once it is on disk.
    plt.close(fig)
2020-01-01 12:51:00 +01:00
def plot_results_nas(api, dataset, xset, root, file_name, y_lims):
    """Plot the sorted accuracies reached by four search methods.

    For each of REA, REINFORCE, RANDOM and BOHB the architectures stored in the
    method's hard-coded ``results.pth`` are looked up in ``api`` and their
    accuracy on ``(dataset, xset)`` is sorted ascending and drawn as one
    curve. Mean and standard deviation of each curve are printed, and the
    figure is saved as a PDF to ``root / '{dataset}-{xset}-{file_name}'``.

    Args:
        api: benchmark object exposing ``arch2infos_full``.
        dataset: dataset name passed to ``get_metrics``.
        xset: split name passed to ``get_metrics``.
        root: ``pathlib.Path`` output directory.
        file_name: suffix of the saved figure name.
        y_lims: ``(y-min, y-max, y-tick step)``.
    """
    print('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset))
    checkpoints = ['./output/search-cell-nas-bench-201/R-EA-cifar10/results.pth',
                   './output/search-cell-nas-bench-201/REINFORCE-cifar10/results.pth',
                   './output/search-cell-nas-bench-201/RAND-cifar10/results.pth',
                   './output/search-cell-nas-bench-201/BOHB-cifar10/results.pth',
                   ]
    legends, indexes = ['REA', 'REINFORCE', 'RANDOM', 'BOHB'], None
    All_Accs = OrderedDict()
    for legend, checkpoint in zip(legends, checkpoints):
        all_indexes = torch.load(checkpoint, map_location='cpu')
        accuracies = [
            api.arch2infos_full[arch].get_metrics(dataset, xset, None, False)['accuracy']
            for arch in all_indexes
        ]
        # The x axis is taken from the first method's number of runs.
        if indexes is None:
            indexes = list(range(len(all_indexes)))
        All_Accs[legend] = sorted(accuracies)

    color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
    dpi, width, height = 300, 3400, 2600
    LabelSize, LegendFontsize = 28, 28
    figsize = width / float(dpi), height / float(dpi)
    fig = plt.figure(figsize=figsize)
    last_index = max(indexes)
    plt.xlim(0, last_index)
    plt.ylim(y_lims[0], y_lims[1])
    interval_x, interval_y = 100, y_lims[2]
    plt.xticks(np.arange(0, last_index, interval_x), fontsize=LegendFontsize)
    plt.yticks(np.arange(y_lims[0], y_lims[1], interval_y), fontsize=LegendFontsize)
    plt.grid()
    plt.xlabel('The index of runs', fontsize=LabelSize)
    plt.ylabel('The accuracy (%)', fontsize=LabelSize)

    for idx, legend in enumerate(legends):
        plt.plot(indexes, All_Accs[legend], color=color_set[idx], linestyle='-',
                 label='{:}'.format(legend), lw=2)
        mean, std = np.mean(All_Accs[legend]), np.std(All_Accs[legend])
        print('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, mean, std, mean, std))
    plt.legend(loc=4, fontsize=LegendFontsize)
    save_path = root / '{:}-{:}-{:}'.format(dataset, xset, file_name)
    print('save figure into {:}\n'.format(save_path))
    fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
    # The figure is not returned, so release it once it is on disk.
    plt.close(fig)
def just_show(api):
    """Print summary statistics for the weight-sharing search methods.

    Prints, in order:

    1. the mean of three hard-coded search times per method;
    2. for every method, mean and standard deviation (over three hard-coded
       seeds) of the accuracy of the architecture found at the last search
       epoch, on six (dataset, split) pairs;
    3. the six accuracies of architecture index 11472 (resnet);
    4. the best architecture index and accuracy reported by
       ``api.find_best`` for each pair.

    Checkpoints are read from hard-coded paths under
    ``output/search-cell-nas-bench-201``.

    Args:
        api: benchmark object exposing ``arch2infos_full``,
            ``query_index_by_arch`` and ``find_best``.
    """
    xtimes = {'RSPS': [8082.5, 7794.2, 8144.7],
              'DARTS-V1': [11582.1, 11347.0, 11948.2],
              'DARTS-V2': [35694.7, 36132.7, 35518.0],
              'GDAS': [31334.1, 31478.6, 32016.7],
              'SETN': [33528.8, 33831.5, 35058.3],
              'ENAS': [14340.2, 13817.3, 14018.9]}
    for xkey, xlist in xtimes.items():
        print('{:4s} : mean-time={:.2f} s'.format(xkey, np.array(xlist).mean()))

    xpaths = {'RSPS': 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10/checkpoint/',
              'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/',
              'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10/checkpoint/',
              'GDAS': 'output/search-cell-nas-bench-201/GDAS-cifar10/checkpoint/',
              'SETN': 'output/search-cell-nas-bench-201/SETN-cifar10/checkpoint/',
              'ENAS': 'output/search-cell-nas-bench-201/ENAS-cifar10/checkpoint/',
              }
    xseeds = {'RSPS': [5349, 59613, 5983],
              'DARTS-V1': [11416, 72873, 81184],
              'DARTS-V2': [43330, 79405, 79423],
              'GDAS': [19677, 884, 95950],
              'SETN': [20518, 61817, 89144],
              'ENAS': [3231, 34238, 96929],
              }
    # (dataset, split) pairs reported for every architecture, in print order.
    pairs = [('cifar10-valid', 'x-valid'), ('cifar10', 'ori-test'),
             ('cifar100', 'x-valid'), ('cifar100', 'x-test'),
             ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test')]

    def get_accs(xdata, index=-1):
        # With index == -1 the architecture is the genotype recorded for the
        # last epoch of the checkpoint; otherwise `index` is used directly.
        if index == -1:
            epochs = xdata['epoch']
            genotype = xdata['genotypes'][epochs - 1]
            index = api.query_index_by_arch(genotype)
        return [
            api.arch2infos_full[index].get_metrics(dataset, xset, None, False)['accuracy']
            for dataset, xset in pairs
        ]

    for xkey, xpath_prefix in xpaths.items():
        all_paths = ['{:}/seed-{:}-basic.pth'.format(xpath_prefix, seed) for seed in xseeds[xkey]]
        # Load onto the CPU like the other loaders in this file, so the
        # checkpoints can be read on a machine without a GPU.
        all_datas = [torch.load(xpath, map_location='cpu') for xpath in all_paths]
        accyss = np.array([get_accs(xdatas) for xdatas in all_datas])
        print('\nxkey = {:}'.format(xkey))
        for i in range(accyss.shape[1]):
            print('---->>>> {:.2f}$\\pm${:.2f}'.format(accyss[:, i].mean(), accyss[:, i].std()))

    print('\n{:}'.format(get_accs(None, 11472)))  # resnet
    for dataset, metric_on_set in pairs:
        arch_index, highest_acc = api.find_best(dataset, metric_on_set)
        print('[{:10s}-{:10s} ::: index={:5d}, accuracy={:.2f}'.format(dataset, metric_on_set, arch_index, highest_acc))
2020-01-15 15:43:07 +01:00
def show_nas_sharing_w ( api , dataset , subset , vis_save_dir , sufix , file_name , y_lims , x_maxs ) :
2020-01-01 12:51:00 +01:00
color_set = [ ' r ' , ' b ' , ' g ' , ' c ' , ' m ' , ' y ' , ' k ' ]
dpi , width , height = 300 , 3400 , 2600
LabelSize , LegendFontsize = 28 , 28
figsize = width / float ( dpi ) , height / float ( dpi )
fig = plt . figure ( figsize = figsize )
2020-01-02 04:35:58 +01:00
#x_maxs = 250
plt . xlim ( 0 , x_maxs + 1 )
2020-01-01 12:51:00 +01:00
plt . ylim ( y_lims [ 0 ] , y_lims [ 1 ] )
interval_x , interval_y = x_maxs / / 5 , y_lims [ 2 ]
2020-01-02 04:35:58 +01:00
plt . xticks ( np . arange ( 0 , x_maxs + 1 , interval_x ) , fontsize = LegendFontsize )
2020-01-01 12:51:00 +01:00
plt . yticks ( np . arange ( y_lims [ 0 ] , y_lims [ 1 ] , interval_y ) , fontsize = LegendFontsize )
plt . grid ( )
plt . xlabel ( ' The searching epoch ' , fontsize = LabelSize )
plt . ylabel ( ' The accuracy ( % ) ' , fontsize = LabelSize )
2020-01-15 15:43:07 +01:00
xpaths = { ' RSPS ' : ' output/search-cell-nas-bench-201/RANDOM-NAS-cifar10- {:} /checkpoint/ ' . format ( sufix ) ,
' DARTS-V1 ' : ' output/search-cell-nas-bench-201/DARTS-V1-cifar10- {:} /checkpoint/ ' . format ( sufix ) ,
' DARTS-V2 ' : ' output/search-cell-nas-bench-201/DARTS-V2-cifar10- {:} /checkpoint/ ' . format ( sufix ) ,
' GDAS ' : ' output/search-cell-nas-bench-201/GDAS-cifar10- {:} /checkpoint/ ' . format ( sufix ) ,
' SETN ' : ' output/search-cell-nas-bench-201/SETN-cifar10- {:} /checkpoint/ ' . format ( sufix ) ,
' ENAS ' : ' output/search-cell-nas-bench-201/ENAS-cifar10- {:} /checkpoint/ ' . format ( sufix ) ,
2020-01-01 12:51:00 +01:00
}
2020-01-15 15:43:07 +01:00
"""
2020-01-01 12:51:00 +01:00
xseeds = { ' RSPS ' : [ 5349 , 59613 , 5983 ] ,
2020-01-14 14:52:06 +01:00
' DARTS-V1 ' : [ 11416 , 72873 , 81184 , 28640 ] ,
2020-01-01 12:51:00 +01:00
' DARTS-V2 ' : [ 43330 , 79405 , 79423 ] ,
' GDAS ' : [ 19677 , 884 , 95950 ] ,
' SETN ' : [ 20518 , 61817 , 89144 ] ,
2020-01-14 14:52:06 +01:00
' ENAS ' : [ 3231 , 34238 , 96929 ] ,
2020-01-01 12:51:00 +01:00
}
2020-01-15 15:43:07 +01:00
"""
xseeds = { ' RSPS ' : [ 23814 , 28015 , 95809 ] ,
' DARTS-V1 ' : [ 48349 , 80877 , 81920 ] ,
' DARTS-V2 ' : [ 61712 , 7941 , 87041 ] ,
' GDAS ' : [ 72818 , 72996 , 78877 ] ,
' SETN ' : [ 26985 , 55206 , 95404 ] ,
' ENAS ' : [ 21792 , 36605 , 45029 ]
}
2020-01-01 12:51:00 +01:00
def get_accs ( xdata ) :
epochs , xresults = xdata [ ' epoch ' ] , [ ]
2020-01-14 14:52:06 +01:00
if - 1 in xdata [ ' genotypes ' ] :
metrics = api . arch2infos_full [ api . query_index_by_arch ( xdata [ ' genotypes ' ] [ - 1 ] ) ] . get_metrics ( dataset , subset , None , False )
else :
metrics = api . arch2infos_full [ api . random ( ) ] . get_metrics ( dataset , subset , None , False )
2020-01-09 12:26:23 +01:00
xresults . append ( metrics [ ' accuracy ' ] )
2020-01-01 12:51:00 +01:00
for iepoch in range ( epochs ) :
genotype = xdata [ ' genotypes ' ] [ iepoch ]
index = api . query_index_by_arch ( genotype )
metrics = api . arch2infos_full [ index ] . get_metrics ( dataset , subset , None , False )
xresults . append ( metrics [ ' accuracy ' ] )
return xresults
2020-01-02 04:35:58 +01:00
if x_maxs == 50 :
xox , xxxstrs = ' v2 ' , [ ' DARTS-V1 ' , ' DARTS-V2 ' ]
elif x_maxs == 250 :
xox , xxxstrs = ' v1 ' , [ ' RSPS ' , ' GDAS ' , ' SETN ' , ' ENAS ' ]
else : raise ValueError ( ' invalid x_maxs= {:} ' . format ( x_maxs ) )
for idx , method in enumerate ( xxxstrs ) :
2020-01-01 12:51:00 +01:00
xkey = method
all_paths = [ ' {:} /seed- {:} -basic.pth ' . format ( xpaths [ xkey ] , seed ) for seed in xseeds [ xkey ] ]
2020-01-02 04:35:58 +01:00
all_datas = [ torch . load ( xpath , map_location = ' cpu ' ) for xpath in all_paths ]
2020-01-01 12:51:00 +01:00
accyss = [ get_accs ( xdatas ) for xdatas in all_datas ]
accyss = np . array ( accyss )
epochs = list ( range ( accyss . shape [ 1 ] ) )
plt . plot ( epochs , [ accyss [ : , i ] . mean ( ) for i in epochs ] , color = color_set [ idx ] , linestyle = ' - ' , label = ' {:} ' . format ( method ) , lw = 2 )
plt . fill_between ( epochs , [ accyss [ : , i ] . mean ( ) - accyss [ : , i ] . std ( ) for i in epochs ] , [ accyss [ : , i ] . mean ( ) + accyss [ : , i ] . std ( ) for i in epochs ] , alpha = 0.2 , color = color_set [ idx ] )
2020-01-02 04:35:58 +01:00
#plt.legend(loc=4, fontsize=LegendFontsize)
plt . legend ( loc = 0 , fontsize = LegendFontsize )
2020-01-15 15:43:07 +01:00
save_path = vis_save_dir / ' {:} .pdf ' . format ( file_name )
2020-01-01 12:51:00 +01:00
print ( ' save figure into {:} \n ' . format ( save_path ) )
fig . savefig ( str ( save_path ) , dpi = dpi , bbox_inches = ' tight ' , format = ' pdf ' )
2020-01-15 15:43:07 +01:00
def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, sufix, file_name, y_lims, x_maxs):
    """Plot accuracy-vs-search-epoch curves of weight-sharing NAS methods on two data splits.

    For every method in the group selected by ``x_maxs``, the per-seed search
    checkpoints are loaded, the architecture recorded at each epoch is looked
    up in ``api``, and the mean accuracy (with a +/- one-std band) over seeds
    is drawn once for ``data_sub_a`` (solid line, labelled VALID) and once for
    ``data_sub_b`` (dashed line, labelled TEST).

    Args:
        api: benchmark API object; this function uses ``api.arch2infos_full``,
            ``api.query_index_by_arch`` and ``api.random``.
        data_sub_a: ``(dataset, subset)`` pair forwarded to ``get_metrics``.
        data_sub_b: ``(dataset, subset)`` pair forwarded to ``get_metrics``.
        vis_save_dir: directory object supporting the ``/`` operator; the
            figure is written to ``vis_save_dir / '<v1|v2>-<file_name>'``.
        sufix: suffix of the checkpoint directory names under
            ``output/search-cell-nas-bench-201/``.
        file_name: file name of the saved figure (always written as PDF).
        y_lims: ``(y_min, y_max, y_tick_interval)``.
        x_maxs: 50 selects DARTS-V1/DARTS-V2; 250 selects RSPS/GDAS/SETN/ENAS.

    Raises:
        ValueError: if ``x_maxs`` is neither 50 nor 250.
    """
    # Validate first so an invalid request never allocates a figure.
    if x_maxs == 50:
        xox, xxxstrs = 'v2', ['DARTS-V1', 'DARTS-V2']
    elif x_maxs == 250:
        xox, xxxstrs = 'v1', ['RSPS', 'GDAS', 'SETN', 'ENAS']
    else:
        raise ValueError('invalid x_maxs={:}'.format(x_maxs))

    color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
    dpi, width, height = 300, 3400, 2600
    LabelSize, LegendFontsize = 28, 28
    figsize = width / float(dpi), height / float(dpi)

    # Method name -> prefix of its search-output directory.
    dir_prefixes = {'RSPS': 'RANDOM-NAS',
                    'DARTS-V1': 'DARTS-V1',
                    'DARTS-V2': 'DARTS-V2',
                    'GDAS': 'GDAS',
                    'SETN': 'SETN',
                    'ENAS': 'ENAS'}
    xpaths = {method: 'output/search-cell-nas-bench-201/{:}-cifar10-{:}/checkpoint/'.format(prefix, sufix)
              for method, prefix in dir_prefixes.items()}
    # Seeds of an earlier set of runs, kept for reference only:
    #   'RSPS'     : [5349, 59613, 5983]
    #   'DARTS-V1' : [11416, 72873, 81184, 28640]
    #   'DARTS-V2' : [43330, 79405, 79423]
    #   'GDAS'     : [19677, 884, 95950]
    #   'SETN'     : [20518, 61817, 89144]
    #   'ENAS'     : [3231, 34238, 96929]
    xseeds = {'RSPS': [23814, 28015, 95809],
              'DARTS-V1': [48349, 80877, 81920],
              'DARTS-V2': [61712, 7941, 87041],
              'GDAS': [72818, 72996, 78877],
              'SETN': [26985, 55206, 95404],
              'ENAS': [21792, 36605, 45029]}

    def get_accs(xdata, dataset, subset):
        """Return [initial accuracy] + [accuracy of the genotype stored for each epoch]."""
        def accuracy_of(arch_index):
            metrics = api.arch2infos_full[arch_index].get_metrics(dataset, subset, None, False)
            return metrics['accuracy']

        epochs, genotypes = xdata['epoch'], xdata['genotypes']
        if -1 in genotypes:
            start_index = api.query_index_by_arch(genotypes[-1])
        else:
            # No genotype stored under key -1: start the curve from an
            # architecture drawn by api.random().
            start_index = api.random()
        xresults = [accuracy_of(start_index)]
        for iepoch in range(epochs):
            xresults.append(accuracy_of(api.query_index_by_arch(genotypes[iepoch])))
        return xresults

    fig = plt.figure(figsize=figsize)
    try:
        plt.xlim(0, x_maxs + 1)
        plt.ylim(y_lims[0], y_lims[1])
        interval_x, interval_y = x_maxs // 5, y_lims[2]
        plt.xticks(np.arange(0, x_maxs + 1, interval_x), fontsize=LegendFontsize)
        plt.yticks(np.arange(y_lims[0], y_lims[1], interval_y), fontsize=LegendFontsize)
        plt.grid()
        plt.xlabel('The searching epoch', fontsize=LabelSize)
        plt.ylabel('The accuracy (%)', fontsize=LabelSize)

        for idx, method in enumerate(xxxstrs):
            all_paths = ['{:}/seed-{:}-basic.pth'.format(xpaths[method], seed) for seed in xseeds[method]]
            all_datas = [torch.load(xpath, map_location='cpu') for xpath in all_paths]
            # Rows are seeds, columns are epochs (column 0 is the initial architecture).
            accyss_A = np.array([get_accs(xdatas, data_sub_a[0], data_sub_a[1]) for xdatas in all_datas])
            accyss_B = np.array([get_accs(xdatas, data_sub_b[0], data_sub_b[1]) for xdatas in all_datas])
            epochs = list(range(accyss_A.shape[1]))
            for j, accyss in enumerate([accyss_A, accyss_B]):
                # With two methods (x_maxs == 50) every curve gets its own
                # colour; with four methods both splits share the method colour.
                color = color_set[idx * 2 + j] if x_maxs == 50 else color_set[idx]
                line = '-' if j == 0 else '--'
                mean, std = accyss.mean(axis=0), accyss.std(axis=0)
                plt.plot(epochs, mean, color=color, linestyle=line,
                         label='{:} ({:})'.format(method, 'VALID' if j == 0 else 'TEST'), lw=2, alpha=0.9)
                plt.fill_between(epochs, mean - std, mean + std, alpha=0.2, color=color)
                setname = data_sub_a if j == 0 else data_sub_b
                print('{:} -- {:} ---- {:.2f}$\\pm${:.2f}'.format(method, setname, mean[-1], std[-1]))
        #plt.legend(loc=4, fontsize=LegendFontsize)
        plt.legend(loc=0, fontsize=LegendFontsize)
        save_path = vis_save_dir / '{:}-{:}'.format(xox, file_name)
        print('save figure into {:}\n'.format(save_path))
        fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
    finally:
        # pyplot keeps every figure alive until it is closed; release it so
        # repeated calls do not accumulate open figures.
        plt.close(fig)
def show_reinforce(api, root, dataset, xset, file_name, y_lims):
    """Plot the sorted accuracies reached by REINFORCE runs, one curve per learning rate.

    For each learning rate, ``results.pth`` is loaded from
    ``./output/search-cell-nas-bench-201/REINFORCE-cifar10-<lr>/``; every entry
    of the loaded object is used as a key into ``api.arch2infos_full`` and the
    resulting accuracies are sorted in ascending order before plotting.

    Args:
        api: benchmark API object; only ``api.arch2infos_full`` is used.
        root: directory object supporting the ``/`` operator; the figure is
            written to ``root / '<dataset>-<xset>-<file_name>.pdf'``.
        dataset: dataset name forwarded to ``get_metrics``.
        xset: subset name forwarded to ``get_metrics``.
        file_name: last component of the output file stem.
        y_lims: ``(y_min, y_max, y_tick_interval)``.
    """
    print('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset))
    LRs = ['0.01', '0.02', '0.1', '0.2', '0.5']
    acc_lr_dict, indexes = {}, None
    for lr in LRs:
        checkpoint = './output/search-cell-nas-bench-201/REINFORCE-cifar10-{:}/results.pth'.format(lr)
        all_indexes = torch.load(checkpoint, map_location='cpu')
        accuracies = [api.arch2infos_full[x].get_metrics(dataset, xset, None, False)['accuracy']
                      for x in all_indexes]
        # The x-axis is taken from the first checkpoint; the plot call below
        # uses it for every learning rate.
        if indexes is None:
            indexes = list(range(len(accuracies)))
        acc_lr_dict[lr] = np.array(sorted(accuracies))
        print('LR={:.3f}, mean={:}, std={:}'.format(float(lr), acc_lr_dict[lr].mean(), acc_lr_dict[lr].std()))

    color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
    dpi, width, height = 300, 3400, 2600
    LabelSize, LegendFontsize = 28, 22
    figsize = width / float(dpi), height / float(dpi)

    fig = plt.figure(figsize=figsize)
    try:
        plt.xlim(0, max(indexes))
        plt.ylim(y_lims[0], y_lims[1])
        interval_x, interval_y = 100, y_lims[2]
        plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize)
        plt.yticks(np.arange(y_lims[0], y_lims[1], interval_y), fontsize=LegendFontsize)
        plt.grid()
        plt.xlabel('The index of runs', fontsize=LabelSize)
        plt.ylabel('The accuracy (%)', fontsize=LabelSize)

        for idx, LR in enumerate(LRs):
            legend = 'LR={:.2f}'.format(float(LR))
            #color, linestyle = color_set[idx // 2], '-' if idx % 2 == 0 else '-.'
            color, linestyle = color_set[idx], '-'
            curve = acc_lr_dict[LR]
            plt.plot(indexes, curve, color=color, linestyle=linestyle, label=legend, lw=2, alpha=0.8)
            # Mean/std are printed twice on purpose: raw values, then the
            # rounded "mean$\pm$std" form.
            print('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(
                legend, np.mean(curve), np.std(curve), np.mean(curve), np.std(curve)))
        plt.legend(loc=4, fontsize=LegendFontsize)
        save_path = root / '{:}-{:}-{:}.pdf'.format(dataset, xset, file_name)
        print('save figure into {:}\n'.format(save_path))
        fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
    finally:
        # pyplot keeps every figure alive until it is closed.
        plt.close(fig)
2020-01-01 12:51:00 +01:00
2020-01-15 15:43:07 +01:00
def show_rea(api, root, dataset, xset, file_name, y_lims):
    """Plot the sorted accuracies reached by R-EA runs, one curve per sample size.

    For each sample size, ``results.pth`` is loaded from
    ``./output/search-cell-nas-bench-201/R-EA-cifar10-SS<size>/``; every entry
    of the loaded object is used as a key into ``api.arch2infos_full`` and the
    resulting accuracies are sorted in ascending order before plotting.

    Args:
        api: benchmark API object; only ``api.arch2infos_full`` is used.
        root: directory object supporting the ``/`` operator; the figure is
            written to ``root / '<dataset>-<xset>-<file_name>.pdf'``.
        dataset: dataset name forwarded to ``get_metrics``.
        xset: subset name forwarded to ``get_metrics``.
        file_name: last component of the output file stem.
        y_lims: ``(y_min, y_max, y_tick_interval)``.
    """
    print('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset))
    SSs = [3, 5, 10]
    acc_ss_dict, indexes = {}, None
    for ss in SSs:
        checkpoint = './output/search-cell-nas-bench-201/R-EA-cifar10-SS{:}/results.pth'.format(ss)
        all_indexes = torch.load(checkpoint, map_location='cpu')
        accuracies = [api.arch2infos_full[x].get_metrics(dataset, xset, None, False)['accuracy']
                      for x in all_indexes]
        # The x-axis is taken from the first checkpoint; the plot call below
        # uses it for every sample size.
        if indexes is None:
            indexes = list(range(len(accuracies)))
        acc_ss_dict[ss] = np.array(sorted(accuracies))
        print('Sample-Size={:2d}, mean={:}, std={:}'.format(ss, acc_ss_dict[ss].mean(), acc_ss_dict[ss].std()))

    color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
    dpi, width, height = 300, 3400, 2600
    LabelSize, LegendFontsize = 28, 22
    figsize = width / float(dpi), height / float(dpi)

    fig = plt.figure(figsize=figsize)
    try:
        plt.xlim(0, max(indexes))
        plt.ylim(y_lims[0], y_lims[1])
        interval_x, interval_y = 100, y_lims[2]
        plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize)
        plt.yticks(np.arange(y_lims[0], y_lims[1], interval_y), fontsize=LegendFontsize)
        plt.grid()
        plt.xlabel('The index of runs', fontsize=LabelSize)
        plt.ylabel('The accuracy (%)', fontsize=LabelSize)

        for idx, ss in enumerate(SSs):
            legend = 'sample-size={:2d}'.format(ss)
            #color, linestyle = color_set[idx // 2], '-' if idx % 2 == 0 else '-.'
            color, linestyle = color_set[idx], '-'
            curve = acc_ss_dict[ss]
            plt.plot(indexes, curve, color=color, linestyle=linestyle, label=legend, lw=2, alpha=0.8)
            # Mean/std are printed twice on purpose: raw values, then the
            # rounded "mean$\pm$std" form.
            print('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(
                legend, np.mean(curve), np.std(curve), np.mean(curve), np.std(curve)))
        plt.legend(loc=4, fontsize=LegendFontsize)
        save_path = root / '{:}-{:}-{:}.pdf'.format(dataset, xset, file_name)
        print('save figure into {:}\n'.format(save_path))
        fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
    finally:
        # pyplot keeps every figure alive until it is closed.
        plt.close(fig)
2019-12-29 10:17:26 +01:00
# Script entry point: parse the CLI options, load the benchmark API and render
# the currently enabled figures into --save_dir.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='NAS-Bench-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.')
    parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-201 benchmark file.')
    args = parser.parse_args()

    # --api_path has no usable default; without this check Path(None) would
    # fail with an unhelpful TypeError.
    if args.api_path is None:
        parser.error('--api_path is required (path to the NAS-Bench-201 benchmark file)')

    vis_save_dir = Path(args.save_dir)
    vis_save_dir.mkdir(parents=True, exist_ok=True)
    meta_file = Path(args.api_path)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not meta_file.exists():
        raise FileNotFoundError('invalid path for api : {:}'.format(meta_file))
    #visualize_rank_over_time(str(meta_file), vis_save_dir / 'over-time')
    #write_video(vis_save_dir / 'over-time')
    #visualize_info(str(meta_file), 'cifar10' , vis_save_dir)
    #visualize_info(str(meta_file), 'cifar100', vis_save_dir)
    #visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir)
    #visualize_relative_ranking(vis_save_dir)

    api = API(args.api_path)
    #show_reinforce(api, vis_save_dir, 'cifar10-valid' , 'x-valid', 'REINFORCE-CIFAR-10', (85, 92, 2))
    #show_rea      (api, vis_save_dir, 'cifar10-valid' , 'x-valid', 'REA-CIFAR-10', (88, 92, 1))
    #plot_results_nas_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10'       , 'ori-test'), vis_save_dir, 'nas-com-v2-cifar010.pdf', (85,95, 1))
    #plot_results_nas_v2(api, ('cifar100'      , 'x-valid'), ('cifar100'      , 'x-test'  ), vis_save_dir, 'nas-com-v2-cifar100.pdf', (60,75, 3))
    #plot_results_nas_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test'  ), vis_save_dir, 'nas-com-v2-imagenet.pdf', (35,48, 2))

    # (validation split, test split, tag used in the output file name)
    dataset_pairs = [
        (('cifar10-valid', 'x-valid'), ('cifar10', 'ori-test'), 'CIFAR010'),
        (('cifar100', 'x-valid'), ('cifar100', 'x-test'), 'CIFAR100'),
        (('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test'), 'ImageNet'),
    ]
    # x_maxs=50 plots the DARTS variants, x_maxs=250 the remaining methods;
    # all DARTS figures are produced before the others.
    for x_maxs, group in ((50, 'DARTS'), (250, 'OTHER')):
        for data_sub_a, data_sub_b, tag in dataset_pairs:
            show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, 'BN0',
                                  'BN0-{:}-{:}.pdf'.format(group, tag), (0, 100, 10), x_maxs)

    show_nas_sharing_w(api, 'cifar10-valid', 'x-valid', vis_save_dir, 'BN0', 'BN0-XX-CIFAR010-VALID.pdf', (0, 100, 10), 250)
    show_nas_sharing_w(api, 'cifar10', 'ori-test', vis_save_dir, 'BN0', 'BN0-XX-CIFAR010-TEST.pdf', (0, 100, 10), 250)
    # NOTE: the leftover `import pdb; pdb.set_trace()` breakpoint that used to
    # follow was removed; it blocks non-interactive runs.

    # Disabled invocations from earlier experiments, kept as an inert string.
    """
    for x_maxs in [50, 250]:
      show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
      show_nas_sharing_w(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
      show_nas_sharing_w(api, 'cifar100'      , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
      show_nas_sharing_w(api, 'cifar100'      , 'x-test'  , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
      show_nas_sharing_w(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
      show_nas_sharing_w(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
    show_nas_sharing_w_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10'       , 'ori-test') , vis_save_dir, 'DARTS-CIFAR010.pdf', (0, 100,10), 50)
    just_show(api)
    plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1))
    plot_results_nas(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1))
    plot_results_nas(api, 'cifar100'      , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
    plot_results_nas(api, 'cifar100'      , 'x-test'  , vis_save_dir, 'nas-com.pdf', (55,75, 3))
    plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
    plot_results_nas(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-com.pdf', (35,50, 3))
    """