fix bugs in RANDOM-NAS and BOHB
This commit is contained in:
parent
4c144b7437
commit
f8f44bfb31
@ -51,7 +51,7 @@ res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric
|
||||
cost_metrics = info.get_comput_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency
|
||||
|
||||
# get the detailed information
|
||||
results = api.query_by_index(1, 'cifar100') # a list of all trials on cifar100
|
||||
results = api.query_by_index(1, 'cifar100') # a dict of all trials for 1st net on cifar100, where the key is the seed
|
||||
print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1]))
|
||||
print ('Latency : {:}'.format(results[0].get_latency()))
|
||||
print ('Train Info : {:}'.format(results[0].get_train()))
|
||||
|
@ -9,5 +9,6 @@
|
||||
"momentum" : ["float", "0.9"],
|
||||
"nesterov" : ["bool", "1"],
|
||||
"criterion": ["str", "Softmax"],
|
||||
"batch_size": ["int", "64"]
|
||||
"batch_size": ["int", "64"],
|
||||
"test_batch_size": ["int", "512"]
|
||||
}
|
||||
|
386
exps/NAS-Bench-102/visualize.py
Normal file
386
exps/NAS-Bench-102/visualize.py
Normal file
@ -0,0 +1,386 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
# python exps/NAS-Bench-102/visualize.py --api_path $HOME/.torch/NAS-Bench-102-v1_0-e61699.pth
|
||||
##################################################
|
||||
import os, sys, time, argparse, collections
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
import matplotlib
|
||||
import seaborn as sns
|
||||
from mpl_toolkits.mplot3d import Axes3D
|
||||
matplotlib.use('agg')
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from log_utils import time_string
|
||||
from nas_102_api import NASBench102API as API
|
||||
|
||||
|
||||
|
||||
def calculate_correlation(*vectors):
|
||||
matrix = []
|
||||
for i, vectori in enumerate(vectors):
|
||||
x = []
|
||||
for j, vectorj in enumerate(vectors):
|
||||
x.append( np.corrcoef(vectori, vectorj)[0,1] )
|
||||
matrix.append( x )
|
||||
return np.array(matrix)
|
||||
|
||||
|
||||
|
||||
def visualize_relative_ranking(vis_save_dir):
|
||||
print ('\n' + '-'*100)
|
||||
cifar010_cache_path = vis_save_dir / '{:}-cache-info.pth'.format('cifar10')
|
||||
cifar100_cache_path = vis_save_dir / '{:}-cache-info.pth'.format('cifar100')
|
||||
imagenet_cache_path = vis_save_dir / '{:}-cache-info.pth'.format('ImageNet16-120')
|
||||
cifar010_info = torch.load(cifar010_cache_path)
|
||||
cifar100_info = torch.load(cifar100_cache_path)
|
||||
imagenet_info = torch.load(imagenet_cache_path)
|
||||
indexes = list(range(len(cifar010_info['params'])))
|
||||
|
||||
print ('{:} start to visualize relative ranking'.format(time_string()))
|
||||
# maximum accuracy with ResNet-level params 11472
|
||||
x_010_accs = [ cifar010_info['test_accs'][i] if cifar010_info['params'][i] <= cifar010_info['params'][11472] else -1 for i in indexes]
|
||||
x_100_accs = [ cifar100_info['test_accs'][i] if cifar100_info['params'][i] <= cifar100_info['params'][11472] else -1 for i in indexes]
|
||||
x_img_accs = [ imagenet_info['test_accs'][i] if imagenet_info['params'][i] <= imagenet_info['params'][11472] else -1 for i in indexes]
|
||||
|
||||
cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info['test_accs'][i])
|
||||
cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info['test_accs'][i])
|
||||
imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info['test_accs'][i])
|
||||
|
||||
cifar100_labels, imagenet_labels = [], []
|
||||
for idx in cifar010_ord_indexes:
|
||||
cifar100_labels.append( cifar100_ord_indexes.index(idx) )
|
||||
imagenet_labels.append( imagenet_ord_indexes.index(idx) )
|
||||
print ('{:} prepare data done.'.format(time_string()))
|
||||
|
||||
dpi, width, height = 300, 2600, 2600
|
||||
figsize = width / float(dpi), height / float(dpi)
|
||||
LabelSize, LegendFontsize = 18, 18
|
||||
resnet_scale, resnet_alpha = 120, 0.5
|
||||
|
||||
fig = plt.figure(figsize=figsize)
|
||||
ax = fig.add_subplot(111)
|
||||
plt.xlim(min(indexes), max(indexes))
|
||||
plt.ylim(min(indexes), max(indexes))
|
||||
#plt.ylabel('y').set_rotation(0)
|
||||
plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize, rotation='vertical')
|
||||
plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize)
|
||||
#ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8, label='CIFAR-100')
|
||||
#ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red' , alpha=0.8, label='ImageNet-16-120')
|
||||
#ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8, label='CIFAR-10')
|
||||
ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8)
|
||||
ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red' , alpha=0.8)
|
||||
ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8)
|
||||
ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10')
|
||||
ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-100')
|
||||
ax.scatter([-1], [-1], marker='*', s=100, c='tab:red' , label='ImageNet-16-120')
|
||||
plt.grid(zorder=0)
|
||||
ax.set_axisbelow(True)
|
||||
plt.legend(loc=0, fontsize=LegendFontsize)
|
||||
ax.set_xlabel('architecture ranking in CIFAR-10', fontsize=LabelSize)
|
||||
ax.set_ylabel('architecture ranking', fontsize=LabelSize)
|
||||
save_path = (vis_save_dir / 'relative-rank.pdf').resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
save_path = (vis_save_dir / 'relative-rank.png').resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
|
||||
print ('{:} save into {:}'.format(time_string(), save_path))
|
||||
|
||||
# calculate correlation
|
||||
sns_size = 15
|
||||
CoRelMatrix = calculate_correlation(cifar010_info['valid_accs'], cifar010_info['test_accs'], cifar100_info['valid_accs'], cifar100_info['test_accs'], imagenet_info['valid_accs'], imagenet_info['test_accs'])
|
||||
fig = plt.figure(figsize=figsize)
|
||||
plt.axis('off')
|
||||
h = sns.heatmap(CoRelMatrix, annot=True, annot_kws={'size':sns_size}, fmt='.3f', linewidths=0.5)
|
||||
save_path = (vis_save_dir / 'co-relation-all.pdf').resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
print ('{:} save into {:}'.format(time_string(), save_path))
|
||||
|
||||
# calculate correlation
|
||||
acc_bars = [92, 93]
|
||||
for acc_bar in acc_bars:
|
||||
selected_indexes = []
|
||||
for i, acc in enumerate(cifar010_info['test_accs']):
|
||||
if acc > acc_bar: selected_indexes.append( i )
|
||||
print ('select {:} architectures'.format(len(selected_indexes)))
|
||||
cifar010_valid_accs = np.array(cifar010_info['valid_accs'])[ selected_indexes ]
|
||||
cifar010_test_accs = np.array(cifar010_info['test_accs']) [ selected_indexes ]
|
||||
cifar100_valid_accs = np.array(cifar100_info['valid_accs'])[ selected_indexes ]
|
||||
cifar100_test_accs = np.array(cifar100_info['test_accs']) [ selected_indexes ]
|
||||
imagenet_valid_accs = np.array(imagenet_info['valid_accs'])[ selected_indexes ]
|
||||
imagenet_test_accs = np.array(imagenet_info['test_accs']) [ selected_indexes ]
|
||||
CoRelMatrix = calculate_correlation(cifar010_valid_accs, cifar010_test_accs, cifar100_valid_accs, cifar100_test_accs, imagenet_valid_accs, imagenet_test_accs)
|
||||
fig = plt.figure(figsize=figsize)
|
||||
plt.axis('off')
|
||||
h = sns.heatmap(CoRelMatrix, annot=True, annot_kws={'size':sns_size}, fmt='.3f', linewidths=0.5)
|
||||
save_path = (vis_save_dir / 'co-relation-top-{:}.pdf'.format(len(selected_indexes))).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
print ('{:} save into {:}'.format(time_string(), save_path))
|
||||
plt.close('all')
|
||||
|
||||
|
||||
|
||||
def visualize_info(meta_file, dataset, vis_save_dir):
|
||||
print ('{:} start to visualize {:} information'.format(time_string(), dataset))
|
||||
cache_file_path = vis_save_dir / '{:}-cache-info.pth'.format(dataset)
|
||||
if not cache_file_path.exists():
|
||||
print ('Do not find cache file : {:}'.format(cache_file_path))
|
||||
nas_bench = API(str(meta_file))
|
||||
params, flops, train_accs, valid_accs, test_accs, otest_accs = [], [], [], [], [], []
|
||||
for index in range( len(nas_bench) ):
|
||||
info = nas_bench.query_by_index(index, use_12epochs_result=False)
|
||||
resx = info.get_comput_costs(dataset) ; flop, param = resx['flops'], resx['params']
|
||||
if dataset == 'cifar10':
|
||||
res = info.get_metrics('cifar10', 'train') ; train_acc = res['accuracy']
|
||||
res = info.get_metrics('cifar10-valid', 'x-valid') ; valid_acc = res['accuracy']
|
||||
res = info.get_metrics('cifar10', 'ori-test') ; test_acc = res['accuracy']
|
||||
res = info.get_metrics('cifar10', 'ori-test') ; otest_acc = res['accuracy']
|
||||
else:
|
||||
res = info.get_metrics(dataset, 'train') ; train_acc = res['accuracy']
|
||||
res = info.get_metrics(dataset, 'x-valid') ; valid_acc = res['accuracy']
|
||||
res = info.get_metrics(dataset, 'x-test') ; test_acc = res['accuracy']
|
||||
res = info.get_metrics(dataset, 'ori-test') ; otest_acc = res['accuracy']
|
||||
if index == 11472: # resnet
|
||||
resnet = {'params':param, 'flops': flop, 'index': 11472, 'train_acc': train_acc, 'valid_acc': valid_acc, 'test_acc': test_acc, 'otest_acc': otest_acc}
|
||||
flops.append( flop )
|
||||
params.append( param )
|
||||
train_accs.append( train_acc )
|
||||
valid_accs.append( valid_acc )
|
||||
test_accs.append( test_acc )
|
||||
otest_accs.append( otest_acc )
|
||||
#resnet = {'params': 0.559, 'flops': 78.56, 'index': 11472, 'train_acc': 99.99, 'valid_acc': 90.84, 'test_acc': 93.97}
|
||||
info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs, 'otest_accs': otest_accs}
|
||||
info['resnet'] = resnet
|
||||
torch.save(info, cache_file_path)
|
||||
else:
|
||||
print ('Find cache file : {:}'.format(cache_file_path))
|
||||
info = torch.load(cache_file_path)
|
||||
params, flops, train_accs, valid_accs, test_accs, otest_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs'], info['otest_accs']
|
||||
resnet = info['resnet']
|
||||
print ('{:} collect data done.'.format(time_string()))
|
||||
|
||||
indexes = list(range(len(params)))
|
||||
dpi, width, height = 300, 2600, 2600
|
||||
figsize = width / float(dpi), height / float(dpi)
|
||||
LabelSize, LegendFontsize = 22, 22
|
||||
resnet_scale, resnet_alpha = 120, 0.5
|
||||
|
||||
fig = plt.figure(figsize=figsize)
|
||||
ax = fig.add_subplot(111)
|
||||
plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
|
||||
if dataset == 'cifar10':
|
||||
plt.ylim(50, 100)
|
||||
plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
|
||||
elif dataset == 'cifar100':
|
||||
plt.ylim(25, 75)
|
||||
plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
|
||||
else:
|
||||
plt.ylim(0, 50)
|
||||
plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
|
||||
ax.scatter(params, valid_accs, marker='o', s=0.5, c='tab:blue')
|
||||
ax.scatter([resnet['params']], [resnet['valid_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=0.4)
|
||||
plt.grid(zorder=0)
|
||||
ax.set_axisbelow(True)
|
||||
plt.legend(loc=4, fontsize=LegendFontsize)
|
||||
ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
|
||||
ax.set_ylabel('the validation accuracy (%)', fontsize=LabelSize)
|
||||
save_path = (vis_save_dir / '{:}-param-vs-valid.pdf'.format(dataset)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
save_path = (vis_save_dir / '{:}-param-vs-valid.png'.format(dataset)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
|
||||
print ('{:} save into {:}'.format(time_string(), save_path))
|
||||
|
||||
fig = plt.figure(figsize=figsize)
|
||||
ax = fig.add_subplot(111)
|
||||
plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
|
||||
if dataset == 'cifar10':
|
||||
plt.ylim(50, 100)
|
||||
plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
|
||||
elif dataset == 'cifar100':
|
||||
plt.ylim(25, 75)
|
||||
plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
|
||||
else:
|
||||
plt.ylim(0, 50)
|
||||
plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
|
||||
ax.scatter(params, test_accs, marker='o', s=0.5, c='tab:blue')
|
||||
ax.scatter([resnet['params']], [resnet['test_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=resnet_alpha)
|
||||
plt.grid()
|
||||
ax.set_axisbelow(True)
|
||||
plt.legend(loc=4, fontsize=LegendFontsize)
|
||||
ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
|
||||
ax.set_ylabel('the test accuracy (%)', fontsize=LabelSize)
|
||||
save_path = (vis_save_dir / '{:}-param-vs-test.pdf'.format(dataset)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
save_path = (vis_save_dir / '{:}-param-vs-test.png'.format(dataset)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
|
||||
print ('{:} save into {:}'.format(time_string(), save_path))
|
||||
|
||||
fig = plt.figure(figsize=figsize)
|
||||
ax = fig.add_subplot(111)
|
||||
plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
|
||||
if dataset == 'cifar10':
|
||||
plt.ylim(50, 100)
|
||||
plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
|
||||
elif dataset == 'cifar100':
|
||||
plt.ylim(20, 100)
|
||||
plt.yticks(np.arange(20, 101, 10), fontsize=LegendFontsize)
|
||||
else:
|
||||
plt.ylim(25, 76)
|
||||
plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
|
||||
ax.scatter(params, train_accs, marker='o', s=0.5, c='tab:blue')
|
||||
ax.scatter([resnet['params']], [resnet['train_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=resnet_alpha)
|
||||
plt.grid()
|
||||
ax.set_axisbelow(True)
|
||||
plt.legend(loc=4, fontsize=LegendFontsize)
|
||||
ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
|
||||
ax.set_ylabel('the trarining accuracy (%)', fontsize=LabelSize)
|
||||
save_path = (vis_save_dir / '{:}-param-vs-train.pdf'.format(dataset)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
save_path = (vis_save_dir / '{:}-param-vs-train.png'.format(dataset)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
|
||||
print ('{:} save into {:}'.format(time_string(), save_path))
|
||||
|
||||
fig = plt.figure(figsize=figsize)
|
||||
ax = fig.add_subplot(111)
|
||||
plt.xlim(0, max(indexes))
|
||||
plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//5), fontsize=LegendFontsize)
|
||||
if dataset == 'cifar10':
|
||||
plt.ylim(50, 100)
|
||||
plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
|
||||
elif dataset == 'cifar100':
|
||||
plt.ylim(25, 75)
|
||||
plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
|
||||
else:
|
||||
plt.ylim(0, 50)
|
||||
plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
|
||||
ax.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue')
|
||||
ax.scatter([resnet['index']], [resnet['test_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=resnet_alpha)
|
||||
plt.grid()
|
||||
ax.set_axisbelow(True)
|
||||
plt.legend(loc=4, fontsize=LegendFontsize)
|
||||
ax.set_xlabel('architecture ID', fontsize=LabelSize)
|
||||
ax.set_ylabel('the test accuracy (%)', fontsize=LabelSize)
|
||||
save_path = (vis_save_dir / '{:}-test-over-ID.pdf'.format(dataset)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
save_path = (vis_save_dir / '{:}-test-over-ID.png'.format(dataset)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
|
||||
print ('{:} save into {:}'.format(time_string(), save_path))
|
||||
plt.close('all')
|
||||
|
||||
|
||||
|
||||
def visualize_rank_over_time(meta_file, vis_save_dir):
|
||||
print ('\n' + '-'*150)
|
||||
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
print ('{:} start to visualize rank-over-time into {:}'.format(time_string(), vis_save_dir))
|
||||
cache_file_path = vis_save_dir / 'rank-over-time-cache-info.pth'
|
||||
if not cache_file_path.exists():
|
||||
print ('Do not find cache file : {:}'.format(cache_file_path))
|
||||
nas_bench = API(str(meta_file))
|
||||
print ('{:} load nas_bench done'.format(time_string()))
|
||||
params, flops, train_accs, valid_accs, test_accs, otest_accs = [], [], defaultdict(list), defaultdict(list), defaultdict(list), defaultdict(list)
|
||||
#for iepoch in range(200): for index in range( len(nas_bench) ):
|
||||
for index in tqdm(range(len(nas_bench))):
|
||||
info = nas_bench.query_by_index(index, use_12epochs_result=False)
|
||||
for iepoch in range(200):
|
||||
res = info.get_metrics('cifar10' , 'train' , iepoch) ; train_acc = res['accuracy']
|
||||
res = info.get_metrics('cifar10-valid', 'x-valid' , iepoch) ; valid_acc = res['accuracy']
|
||||
res = info.get_metrics('cifar10' , 'ori-test', iepoch) ; test_acc = res['accuracy']
|
||||
res = info.get_metrics('cifar10' , 'ori-test', iepoch) ; otest_acc = res['accuracy']
|
||||
train_accs[iepoch].append( train_acc )
|
||||
valid_accs[iepoch].append( valid_acc )
|
||||
test_accs [iepoch].append( test_acc )
|
||||
otest_accs[iepoch].append( otest_acc )
|
||||
if iepoch == 0:
|
||||
res = info.get_comput_costs('cifar10') ; flop, param = res['flops'], res['params']
|
||||
flops.append( flop )
|
||||
params.append( param )
|
||||
info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs, 'otest_accs': otest_accs}
|
||||
torch.save(info, cache_file_path)
|
||||
else:
|
||||
print ('Find cache file : {:}'.format(cache_file_path))
|
||||
info = torch.load(cache_file_path)
|
||||
params, flops, train_accs, valid_accs, test_accs, otest_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs'], info['otest_accs']
|
||||
print ('{:} collect data done.'.format(time_string()))
|
||||
#selected_epochs = [0, 100, 150, 180, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199]
|
||||
selected_epochs = list( range(200) )
|
||||
x_xtests = test_accs[199]
|
||||
indexes = list(range(len(x_xtests)))
|
||||
ord_idxs = sorted(indexes, key=lambda i: x_xtests[i])
|
||||
for sepoch in selected_epochs:
|
||||
x_valids = valid_accs[sepoch]
|
||||
valid_ord_idxs = sorted(indexes, key=lambda i: x_valids[i])
|
||||
valid_ord_lbls = []
|
||||
for idx in ord_idxs:
|
||||
valid_ord_lbls.append( valid_ord_idxs.index(idx) )
|
||||
# labeled data
|
||||
dpi, width, height = 300, 2600, 2600
|
||||
figsize = width / float(dpi), height / float(dpi)
|
||||
LabelSize, LegendFontsize = 18, 18
|
||||
|
||||
fig = plt.figure(figsize=figsize)
|
||||
ax = fig.add_subplot(111)
|
||||
plt.xlim(min(indexes), max(indexes))
|
||||
plt.ylim(min(indexes), max(indexes))
|
||||
plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize, rotation='vertical')
|
||||
plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize)
|
||||
ax.scatter(indexes, valid_ord_lbls, marker='^', s=0.5, c='tab:green', alpha=0.8)
|
||||
ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8)
|
||||
ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-10 validation')
|
||||
ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10 test')
|
||||
plt.grid(zorder=0)
|
||||
ax.set_axisbelow(True)
|
||||
plt.legend(loc='upper left', fontsize=LegendFontsize)
|
||||
ax.set_xlabel('architecture ranking in the final test accuracy', fontsize=LabelSize)
|
||||
ax.set_ylabel('architecture ranking in the validation set', fontsize=LabelSize)
|
||||
save_path = (vis_save_dir / 'time-{:03d}.pdf'.format(sepoch)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
save_path = (vis_save_dir / 'time-{:03d}.png'.format(sepoch)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
|
||||
print ('{:} save into {:}'.format(time_string(), save_path))
|
||||
plt.close('all')
|
||||
|
||||
|
||||
|
||||
def write_video(save_dir):
|
||||
import cv2
|
||||
video_save_path = save_dir / 'time.avi'
|
||||
print ('{:} start create video for {:}'.format(time_string(), video_save_path))
|
||||
images = sorted( list( save_dir.glob('time-*.png') ) )
|
||||
ximage = cv2.imread(str(images[0]))
|
||||
#shape = (ximage.shape[1], ximage.shape[0])
|
||||
shape = (1000, 1000)
|
||||
#writer = cv2.VideoWriter(str(video_save_path), cv2.VideoWriter_fourcc(*"MJPG"), 25, shape)
|
||||
writer = cv2.VideoWriter(str(video_save_path), cv2.VideoWriter_fourcc(*"MJPG"), 5, shape)
|
||||
for idx, image in enumerate(images):
|
||||
ximage = cv2.imread(str(image))
|
||||
_image = cv2.resize(ximage, shape)
|
||||
writer.write(_image)
|
||||
writer.release()
|
||||
print ('write video [{:} frames] into {:}'.format(len(images), video_save_path))
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser(description='NAS-Bench-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-102/visual', help='The base-name of folder to save checkpoints and log.')
|
||||
parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-102 benchmark file.')
|
||||
args = parser.parse_args()
|
||||
|
||||
vis_save_dir = Path(args.save_dir) / 'visuals'
|
||||
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
meta_file = Path(args.api_path)
|
||||
assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
|
||||
visualize_rank_over_time(str(meta_file), vis_save_dir / 'over-time')
|
||||
write_video(vis_save_dir / 'over-time')
|
||||
visualize_info(str(meta_file), 'cifar10' , vis_save_dir)
|
||||
visualize_info(str(meta_file), 'cifar100', vis_save_dir)
|
||||
visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir)
|
||||
visualize_relative_ranking(vis_save_dir)
|
@ -53,43 +53,50 @@ def config2structure_func(max_nodes):
|
||||
|
||||
class MyWorker(Worker):
|
||||
|
||||
def __init__(self, *args, convert_func=None, nas_bench=None, time_scale=None, **kwargs):
|
||||
def __init__(self, *args, convert_func=None, nas_bench=None, time_budget=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.convert_func = convert_func
|
||||
self.nas_bench = nas_bench
|
||||
self.time_scale = time_scale
|
||||
self.seen_arch = 0
|
||||
self.time_budget = time_budget
|
||||
self.seen_archs = []
|
||||
self.sim_cost_time = 0
|
||||
self.real_cost_time = 0
|
||||
self.is_end = False
|
||||
|
||||
def get_the_best(self):
|
||||
assert len(self.seen_archs) > 0
|
||||
best_index, best_acc = -1, None
|
||||
for arch_index in self.seen_archs:
|
||||
info = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True)
|
||||
vacc = info['valid-accuracy']
|
||||
if best_acc is None or best_acc < vacc:
|
||||
best_acc = vacc
|
||||
best_index = arch_index
|
||||
assert best_index != -1
|
||||
return best_index
|
||||
|
||||
def compute(self, config, budget, **kwargs):
|
||||
start_time = time.time()
|
||||
structure = self.convert_func( config )
|
||||
arch_index = self.nas_bench.query_index_by_arch( structure )
|
||||
iepoch = 0
|
||||
while iepoch < 12:
|
||||
info = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', iepoch, True)
|
||||
info = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True)
|
||||
cur_time = info['train-all-time'] + info['valid-per-time']
|
||||
cur_vacc = info['valid-accuracy']
|
||||
if time.time() - start_time + cur_time / self.time_scale > budget:
|
||||
break
|
||||
else:
|
||||
iepoch += 1
|
||||
self.sim_cost_time += cur_time
|
||||
self.seen_arch += 1
|
||||
remaining_time = cur_time / self.time_scale - (time.time() - start_time)
|
||||
if remaining_time > 0:
|
||||
time.sleep(remaining_time)
|
||||
else:
|
||||
import pdb; pdb.set_trace()
|
||||
self.real_cost_time += (time.time() - start_time)
|
||||
return ({
|
||||
'loss': 100 - float(cur_vacc),
|
||||
'info': {'seen-arch' : self.seen_arch,
|
||||
if self.sim_cost_time + cur_time <= self.time_budget and not self.is_end:
|
||||
self.sim_cost_time += cur_time
|
||||
self.seen_archs.append( arch_index )
|
||||
return ({'loss': 100 - float(cur_vacc),
|
||||
'info': {'seen-arch' : len(self.seen_archs),
|
||||
'sim-test-time' : self.sim_cost_time,
|
||||
'real-test-time': self.real_cost_time,
|
||||
'current-arch' : arch_index,
|
||||
'current-budget': budget}
|
||||
'current-arch' : arch_index}
|
||||
})
|
||||
else:
|
||||
self.is_end = True
|
||||
return ({'loss': 100,
|
||||
'info': {'seen-arch' : len(self.seen_archs),
|
||||
'sim-test-time' : self.sim_cost_time,
|
||||
'current-arch' : None}
|
||||
})
|
||||
|
||||
|
||||
@ -139,16 +146,14 @@ def main(xargs, nas_bench):
|
||||
#logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string()))
|
||||
workers = []
|
||||
for i in range(num_workers):
|
||||
w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, time_scale=xargs.time_scale, run_id=hb_run_id, id=i)
|
||||
w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, time_budget=xargs.time_budget, run_id=hb_run_id, id=i)
|
||||
w.run(background=True)
|
||||
workers.append(w)
|
||||
|
||||
simulate_time_budge = xargs.time_budget // xargs.time_scale
|
||||
start_time = time.time()
|
||||
logger.log('simulate_time_budge : {:} (in seconds).'.format(simulate_time_budge))
|
||||
bohb = BOHB(configspace=cs,
|
||||
run_id=hb_run_id,
|
||||
eta=3, min_budget=simulate_time_budge//3, max_budget=simulate_time_budge,
|
||||
eta=3, min_budget=12, max_budget=200,
|
||||
nameserver=ns_host,
|
||||
nameserver_port=ns_port,
|
||||
num_samples=xargs.num_samples,
|
||||
@ -161,11 +166,9 @@ def main(xargs, nas_bench):
|
||||
NS.shutdown()
|
||||
|
||||
real_cost_time = time.time() - start_time
|
||||
import pdb; pdb.set_trace()
|
||||
|
||||
id2config = results.get_id2config_mapping()
|
||||
incumbent = results.get_incumbent_id()
|
||||
|
||||
logger.log('Best found configuration: {:}'.format(id2config[incumbent]['config']))
|
||||
best_arch = config2structure( id2config[incumbent]['config'] )
|
||||
|
||||
@ -174,7 +177,7 @@ def main(xargs, nas_bench):
|
||||
else : logger.log('{:}'.format(info))
|
||||
logger.log('-'*100)
|
||||
|
||||
logger.log('workers : {:}'.format(workers[0].test_time))
|
||||
logger.log('workers : {:.1f}s with {:} archs'.format(workers[0].time_budget, len(workers[0].seen_archs)))
|
||||
logger.close()
|
||||
return logger.log_dir, nas_bench.query_index_by_arch( best_arch )
|
||||
|
||||
@ -190,7 +193,6 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
|
||||
parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).')
|
||||
parser.add_argument('--time_scale' , type=int, help='The time scale to accelerate the time budget.')
|
||||
# BOHB
|
||||
parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function')
|
||||
parser.add_argument('--min_bandwidth', default=.3, type=float, nargs='?', help='minimum bandwidth for KDE')
|
||||
|
@ -82,14 +82,29 @@ def valid_func(xloader, network, criterion):
|
||||
return arch_losses.avg, arch_top1.avg, arch_top5.avg
|
||||
|
||||
|
||||
def search_find_best(valid_loader, network, criterion, select_num):
|
||||
best_arch, best_acc = None, -1
|
||||
for iarch in range(select_num):
|
||||
def search_find_best(xloader, network, n_samples):
|
||||
with torch.no_grad():
|
||||
network.eval()
|
||||
archs, valid_accs = [], []
|
||||
#print ('obtain the top-{:} architectures'.format(n_samples))
|
||||
loader_iter = iter(xloader)
|
||||
for i in range(n_samples):
|
||||
arch = network.module.random_genotype( True )
|
||||
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||
if best_arch is None or best_acc < valid_a_top1:
|
||||
best_arch, best_acc = arch, valid_a_top1
|
||||
return best_arch
|
||||
try:
|
||||
inputs, targets = next(loader_iter)
|
||||
except:
|
||||
loader_iter = iter(xloader)
|
||||
inputs, targets = next(loader_iter)
|
||||
|
||||
_, logits = network(inputs)
|
||||
val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5))
|
||||
|
||||
archs.append( arch )
|
||||
valid_accs.append( val_top1.item() )
|
||||
|
||||
best_idx = np.argmax(valid_accs)
|
||||
best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
|
||||
return best_arch, best_valid_acc
|
||||
|
||||
|
||||
def main(xargs):
|
||||
@ -127,7 +142,7 @@ def main(xargs):
|
||||
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||
# data loader
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.test_batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
|
||||
@ -177,7 +192,8 @@ def main(xargs):
|
||||
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
|
||||
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
|
||||
cur_arch = search_find_best(valid_loader, network, criterion, xargs.select_num)
|
||||
cur_arch, cur_valid_acc = search_find_best(valid_loader, network, xargs.select_num)
|
||||
logger.log('[{:}] find-the-best : {:}, accuracy@1={:.2f}%'.format(epoch_str, cur_arch, cur_valid_acc))
|
||||
genotypes[epoch] = cur_arch
|
||||
# check the best accuracy
|
||||
valid_accuracies[epoch] = valid_a_top1
|
||||
@ -211,13 +227,7 @@ def main(xargs):
|
||||
logger.log('\n' + '-'*200)
|
||||
logger.log('Pre-searching costs {:.1f} s'.format(search_time.sum))
|
||||
start_time = time.time()
|
||||
best_arch, best_acc = None, -1
|
||||
for iarch in range(xargs.select_num):
|
||||
arch = search_model.random_genotype( True )
|
||||
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||
logger.log('final evaluation [{:02d}/{:02d}] : {:} : accuracy={:.2f}%, loss={:.3f}'.format(iarch, xargs.select_num, arch, valid_a_top1, valid_a_loss))
|
||||
if best_arch is None or best_acc < valid_a_top1:
|
||||
best_arch, best_acc = arch, valid_a_top1
|
||||
best_arch, best_acc = search_find_best(valid_loader, network, xargs.select_num)
|
||||
search_time.update(time.time() - start_time)
|
||||
logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum))
|
||||
if api is not None: logger.log('{:}'.format( api.query_by_arch(best_arch) ))
|
||||
|
@ -26,8 +26,6 @@ def get_depth_choices(nDepth, return_num):
|
||||
else : return choices
|
||||
|
||||
|
||||
|
||||
|
||||
def conv_forward(inputs, conv, choices):
|
||||
iC = conv.in_channels
|
||||
fill_size = list(inputs.size())
|
||||
|
@ -104,11 +104,16 @@ class NASBench102API(object):
|
||||
print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index))
|
||||
return None
|
||||
|
||||
def query_by_index(self, arch_index, dataname, use_12epochs_result=False):
|
||||
# query information with the training of 12 epochs or 200 epochs
|
||||
# if dataname is None, return the ArchResults
|
||||
# else, return a dict with all trials on that dataset (the key is the seed)
|
||||
def query_by_index(self, arch_index, dataname=None, use_12epochs_result=False):
|
||||
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||
assert arch_index in arch2infos, 'arch_index [{:}] does not in arch2info with {:}'.format(arch_index, basestr)
|
||||
archInfo = copy.deepcopy( arch2infos[ arch_index ] )
|
||||
if dataname is None: return archInfo
|
||||
else:
|
||||
assert dataname in archInfo.get_dataset_names(), 'invalid dataset-name : {:}'.format(dataname)
|
||||
info = archInfo.query(dataname)
|
||||
return info
|
||||
@ -266,7 +271,7 @@ class ArchResults(object):
|
||||
def query(self, dataset, seed=None):
|
||||
if seed is None:
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
return [self.all_results[ (dataset, seed) ] for seed in x_seeds]
|
||||
return {seed: self.all_results[ (dataset, seed) ] for seed in x_seeds}
|
||||
else:
|
||||
return self.all_results[ (dataset, seed) ]
|
||||
|
||||
|
@ -34,6 +34,6 @@ OMP_NUM_THREADS=4 python ./exps/algos/BOHB.py \
|
||||
--dataset ${dataset} --data_path ${data_path} \
|
||||
--search_space_name ${space} \
|
||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
||||
--time_budget 12000 --time_scale 200 \
|
||||
--n_iters 64 --num_samples 4 --random_fraction 0 \
|
||||
--time_budget 12000 \
|
||||
--n_iters 28 --num_samples 64 --random_fraction .33 --bandwidth_factor 3 \
|
||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||
|
Loading…
Reference in New Issue
Block a user