fix bugs in RANDOM-NAS and BOHB

This commit is contained in:
D-X-Y 2019-12-29 20:17:26 +11:00
parent 4c144b7437
commit f8f44bfb31
8 changed files with 469 additions and 67 deletions

View File

@ -51,7 +51,7 @@ res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric
cost_metrics = info.get_comput_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency
# get the detailed information
results = api.query_by_index(1, 'cifar100') # a list of all trials on cifar100
results = api.query_by_index(1, 'cifar100') # a dict of all trials for 1st net on cifar100, where the key is the seed
print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1]))
print ('Latency : {:}'.format(results[0].get_latency()))
print ('Train Info : {:}'.format(results[0].get_train()))

View File

@ -9,5 +9,6 @@
"momentum" : ["float", "0.9"],
"nesterov" : ["bool", "1"],
"criterion": ["str", "Softmax"],
"batch_size": ["int", "64"]
"batch_size": ["int", "64"],
"test_batch_size": ["int", "512"]
}

View File

@ -0,0 +1,386 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# python exps/NAS-Bench-102/visualize.py --api_path $HOME/.torch/NAS-Bench-102-v1_0-e61699.pth
##################################################
import os, sys, time, argparse, collections
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from pathlib import Path
from collections import defaultdict
import matplotlib
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
matplotlib.use('agg')
import matplotlib.pyplot as plt
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from log_utils import time_string
from nas_102_api import NASBench102API as API
def calculate_correlation(*vectors):
    """Return the pairwise correlation matrix of the given vectors.

    Entry ``[i, j]`` is the off-diagonal element of ``np.corrcoef`` applied to
    the i-th and the j-th vector, so the result has one row and one column per
    input vector.
    """
    rows = [
        [np.corrcoef(lhs, rhs)[0, 1] for rhs in vectors]
        for lhs in vectors
    ]
    return np.array(rows)
def _save_correlation_heatmap(matrix, save_path, figsize, dpi, annot_size):
    """Draw ``matrix`` as an annotated seaborn heatmap and save it as a PDF."""
    fig = plt.figure(figsize=figsize)
    plt.axis('off')
    sns.heatmap(matrix, annot=True, annot_kws={'size': annot_size}, fmt='.3f', linewidths=0.5)
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
    print ('{:} save into {:}'.format(time_string(), save_path))


def visualize_relative_ranking(vis_save_dir):
    """Plot how architecture rankings and accuracies agree across datasets.

    Loads ``cifar10-cache-info.pth``, ``cifar100-cache-info.pth`` and
    ``ImageNet16-120-cache-info.pth`` from ``vis_save_dir`` (the caches that
    ``visualize_info`` writes) and saves into the same directory:

    * ``relative-rank.pdf`` / ``relative-rank.png``: for architectures ordered
      by CIFAR-10 test accuracy, their rank on the other two datasets;
    * ``co-relation-all.pdf``: correlation heatmap of the validation and test
      accuracies of the three datasets over all architectures;
    * ``co-relation-top-<N>.pdf``: the same heatmap restricted to the N
      architectures whose CIFAR-10 test accuracy exceeds 92 and 93.

    Args:
      vis_save_dir: directory path object supporting the ``/`` operator.
    """
    print ('\n' + '-'*100)
    cifar010_info = torch.load(vis_save_dir / '{:}-cache-info.pth'.format('cifar10'))
    cifar100_info = torch.load(vis_save_dir / '{:}-cache-info.pth'.format('cifar100'))
    imagenet_info = torch.load(vis_save_dir / '{:}-cache-info.pth'.format('ImageNet16-120'))
    indexes = list(range(len(cifar010_info['params'])))

    print ('{:} start to visualize relative ranking'.format(time_string()))

    def rank_of_each_arch(accs):
        # Map architecture index -> position in ascending-accuracy order.
        # A dict lookup replaces the repeated list.index() scan, which was
        # quadratic in the number of architectures.
        order = sorted(indexes, key=lambda i: accs[i])
        return {arch: position for position, arch in enumerate(order)}

    cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info['test_accs'][i])
    cifar100_rank = rank_of_each_arch(cifar100_info['test_accs'])
    imagenet_rank = rank_of_each_arch(imagenet_info['test_accs'])
    cifar100_labels = [cifar100_rank[idx] for idx in cifar010_ord_indexes]
    imagenet_labels = [imagenet_rank[idx] for idx in cifar010_ord_indexes]
    print ('{:} prepare data done.'.format(time_string()))

    dpi, width, height = 300, 2600, 2600
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 18, 18

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    low, high = min(indexes), max(indexes)
    plt.xlim(low, high)
    plt.ylim(low, high)
    plt.yticks(np.arange(low, high, high//6), fontsize=LegendFontsize, rotation='vertical')
    plt.xticks(np.arange(low, high, high//6), fontsize=LegendFontsize)
    ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8)
    ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red' , alpha=0.8)
    ax.scatter(indexes, indexes        , marker='o', s=0.5, c='tab:blue' , alpha=0.8)
    # Off-canvas points drawn with large markers only to get readable legend entries.
    ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10')
    ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-100')
    ax.scatter([-1], [-1], marker='*', s=100, c='tab:red'  , label='ImageNet-16-120')
    plt.grid(zorder=0)
    ax.set_axisbelow(True)
    plt.legend(loc=0, fontsize=LegendFontsize)
    ax.set_xlabel('architecture ranking in CIFAR-10', fontsize=LabelSize)
    ax.set_ylabel('architecture ranking', fontsize=LabelSize)
    save_path = (vis_save_dir / 'relative-rank.pdf').resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
    save_path = (vis_save_dir / 'relative-rank.png').resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print ('{:} save into {:}'.format(time_string(), save_path))

    # correlation over all architectures
    sns_size = 15
    acc_keys = ('valid_accs', 'test_accs')
    all_infos = (cifar010_info, cifar100_info, imagenet_info)
    CoRelMatrix = calculate_correlation(*[info[key] for info in all_infos for key in acc_keys])
    _save_correlation_heatmap(CoRelMatrix, (vis_save_dir / 'co-relation-all.pdf').resolve(), figsize, dpi, sns_size)

    # correlation over the architectures above each CIFAR-10 test-accuracy bar
    for acc_bar in [92, 93]:
        selected_indexes = [i for i, acc in enumerate(cifar010_info['test_accs']) if acc > acc_bar]
        print ('select {:} architectures'.format(len(selected_indexes)))
        selected_vectors = [np.array(info[key])[selected_indexes] for info in all_infos for key in acc_keys]
        CoRelMatrix = calculate_correlation(*selected_vectors)
        save_path = (vis_save_dir / 'co-relation-top-{:}.pdf'.format(len(selected_indexes))).resolve()
        _save_correlation_heatmap(CoRelMatrix, save_path, figsize, dpi, sns_size)
    plt.close('all')
def _save_accuracy_scatter(xs, accs, resnet_xy, y_axis, xticks, xlabel, ylabel, save_stem,
                           vis_save_dir, xlim=None, resnet_alpha=0.5):
    """Scatter ``accs`` over ``xs``, mark the ResNet point, save as PDF and PNG.

    Args:
      xs, accs: x and y coordinates, one entry per architecture.
      resnet_xy: ``(x, y)`` of the ResNet marker, or None to draw no marker
        (and therefore no legend).
      y_axis: ``((ymin, ymax), (tick_start, tick_stop, tick_step))``.
      xticks: tick positions of the x axis.
      xlabel, ylabel: axis labels.
      save_stem: file name without extension, relative to ``vis_save_dir``.
      vis_save_dir: output directory path object.
      xlim: optional ``(xmin, xmax)``; left to matplotlib when None.
      resnet_alpha: transparency of the ResNet marker.
    """
    dpi, width, height = 300, 2600, 2600
    LabelSize, LegendFontsize = 22, 22
    resnet_scale = 120
    (y_low, y_high), y_tick_args = y_axis

    fig = plt.figure(figsize=(width / float(dpi), height / float(dpi)))
    ax = fig.add_subplot(111)
    if xlim is not None:
        plt.xlim(*xlim)
    plt.xticks(xticks, fontsize=LegendFontsize)
    plt.ylim(y_low, y_high)
    plt.yticks(np.arange(*y_tick_args), fontsize=LegendFontsize)
    ax.scatter(xs, accs, marker='o', s=0.5, c='tab:blue')
    if resnet_xy is not None:
        ax.scatter([resnet_xy[0]], [resnet_xy[1]], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=resnet_alpha)
    plt.grid(zorder=0)
    ax.set_axisbelow(True)
    if resnet_xy is not None:
        plt.legend(loc=4, fontsize=LegendFontsize)
    ax.set_xlabel(xlabel, fontsize=LabelSize)
    ax.set_ylabel(ylabel, fontsize=LabelSize)
    save_path = (vis_save_dir / '{:}.pdf'.format(save_stem)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
    save_path = (vis_save_dir / '{:}.png'.format(save_stem)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print ('{:} save into {:}'.format(time_string(), save_path))


def visualize_info(meta_file, dataset, vis_save_dir):
    """Collect per-architecture statistics for ``dataset`` and plot them.

    The statistics (params, flops, train/valid/test/ori-test accuracy of every
    architecture, plus the entry at index 11472 stored under ``'resnet'``) are
    read through the benchmark API and cached in
    ``<vis_save_dir>/<dataset>-cache-info.pth``; an existing cache is reused.
    Four scatter plots are then saved as PDF and PNG: parameters vs.
    validation / test / training accuracy, and test accuracy over the
    architecture index.

    Args:
      meta_file: path of the benchmark file, passed to ``API`` as a string.
      dataset: dataset name; ``'cifar10'`` and ``'cifar100'`` have dedicated
        axis ranges, any other name uses the third set of ranges.
      vis_save_dir: directory path object for the cache and the figures.
    """
    resnet_index = 11472  # index treated as the ResNet-like reference architecture
    print ('{:} start to visualize {:} information'.format(time_string(), dataset))
    cache_file_path = vis_save_dir / '{:}-cache-info.pth'.format(dataset)
    if not cache_file_path.exists():
        print ('Do not find cache file : {:}'.format(cache_file_path))
        nas_bench = API(str(meta_file))
        params, flops, train_accs, valid_accs, test_accs, otest_accs = [], [], [], [], [], []
        # Stays None when the benchmark has no entry at resnet_index; it was
        # previously left unbound in that case, raising NameError below.
        resnet = None
        if dataset == 'cifar10':
            # For cifar10 the validation number comes from 'cifar10-valid' and
            # both test entries are read from the 'ori-test' split.
            splits = (('cifar10', 'train'), ('cifar10-valid', 'x-valid'), ('cifar10', 'ori-test'), ('cifar10', 'ori-test'))
        else:
            splits = ((dataset, 'train'), (dataset, 'x-valid'), (dataset, 'x-test'), (dataset, 'ori-test'))
        for index in range(len(nas_bench)):
            arch_info = nas_bench.query_by_index(index, use_12epochs_result=False)
            costs = arch_info.get_comput_costs(dataset)
            flop, param = costs['flops'], costs['params']
            train_acc, valid_acc, test_acc, otest_acc = [arch_info.get_metrics(name, split)['accuracy'] for name, split in splits]
            if index == resnet_index:
                resnet = {'params': param, 'flops': flop, 'index': resnet_index, 'train_acc': train_acc, 'valid_acc': valid_acc, 'test_acc': test_acc, 'otest_acc': otest_acc}
            flops.append( flop )
            params.append( param )
            train_accs.append( train_acc )
            valid_accs.append( valid_acc )
            test_accs.append( test_acc )
            otest_accs.append( otest_acc )
        info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs, 'otest_accs': otest_accs}
        info['resnet'] = resnet
        torch.save(info, cache_file_path)
    else:
        print ('Find cache file : {:}'.format(cache_file_path))
        info = torch.load(cache_file_path)
        params, flops, train_accs, valid_accs, test_accs, otest_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs'], info['otest_accs']
        resnet = info['resnet']
    print ('{:} collect data done.'.format(time_string()))

    indexes = list(range(len(params)))

    # ((ymin, ymax), np.arange arguments for the y ticks), per dataset
    eval_axes = {'cifar10': ((50, 100), (50, 101, 10)), 'cifar100': ((25, 75), (25, 76, 10))}
    eval_axis = eval_axes.get(dataset, ((0, 50), (0, 51, 10)))
    train_axes = {'cifar10': ((50, 100), (50, 101, 10)), 'cifar100': ((20, 100), (20, 101, 10))}
    train_axis = train_axes.get(dataset, ((25, 76), (25, 76, 10)))

    def resnet_point(x_key, y_key):
        return None if resnet is None else (resnet[x_key], resnet[y_key])

    param_ticks = np.arange(0, 1.6, 0.3)
    _save_accuracy_scatter(params, valid_accs, resnet_point('params', 'valid_acc'), eval_axis, param_ticks,
                           '#parameters (MB)', 'the validation accuracy (%)',
                           '{:}-param-vs-valid'.format(dataset), vis_save_dir, resnet_alpha=0.4)
    _save_accuracy_scatter(params, test_accs, resnet_point('params', 'test_acc'), eval_axis, param_ticks,
                           '#parameters (MB)', 'the test accuracy (%)',
                           '{:}-param-vs-test'.format(dataset), vis_save_dir)
    _save_accuracy_scatter(params, train_accs, resnet_point('params', 'train_acc'), train_axis, param_ticks,
                           '#parameters (MB)', 'the training accuracy (%)',
                           '{:}-param-vs-train'.format(dataset), vis_save_dir)
    index_ticks = np.arange(min(indexes), max(indexes), max(indexes)//5)
    _save_accuracy_scatter(indexes, test_accs, resnet_point('index', 'test_acc'), eval_axis, index_ticks,
                           'architecture ID', 'the test accuracy (%)',
                           '{:}-test-over-ID'.format(dataset), vis_save_dir, xlim=(0, max(indexes)))
    plt.close('all')
def visualize_rank_over_time(meta_file, vis_save_dir):
    """Plot, for every epoch, validation ranking against the final test ranking.

    Per-epoch CIFAR-10 accuracies of every architecture are read through the
    benchmark API and cached in ``<vis_save_dir>/rank-over-time-cache-info.pth``
    (an existing cache is reused). For each of the 200 epochs one figure is
    saved as ``time-<epoch>.pdf`` and ``time-<epoch>.png``: architectures are
    ordered along x by their test accuracy at the last epoch, and y shows
    their rank by validation accuracy at that epoch.

    Args:
      meta_file: path of the benchmark file, passed to ``API`` as a string.
      vis_save_dir: output directory path object; created if missing.
    """
    num_epochs = 200
    print ('\n' + '-'*150)
    vis_save_dir.mkdir(parents=True, exist_ok=True)
    print ('{:} start to visualize rank-over-time into {:}'.format(time_string(), vis_save_dir))
    cache_file_path = vis_save_dir / 'rank-over-time-cache-info.pth'
    if not cache_file_path.exists():
        print ('Do not find cache file : {:}'.format(cache_file_path))
        nas_bench = API(str(meta_file))
        print ('{:} load nas_bench done'.format(time_string()))
        params, flops = [], []
        # epoch -> list of accuracies, one entry per architecture
        train_accs, valid_accs, test_accs, otest_accs = defaultdict(list), defaultdict(list), defaultdict(list), defaultdict(list)
        for index in tqdm(range(len(nas_bench))):
            arch_info = nas_bench.query_by_index(index, use_12epochs_result=False)
            # The costs are recorded once per architecture, not once per epoch.
            costs = arch_info.get_comput_costs('cifar10')
            flops.append( costs['flops'] )
            params.append( costs['params'] )
            for iepoch in range(num_epochs):
                train_accs[iepoch].append( arch_info.get_metrics('cifar10'      , 'train'   , iepoch)['accuracy'] )
                valid_accs[iepoch].append( arch_info.get_metrics('cifar10-valid', 'x-valid' , iepoch)['accuracy'] )
                test_accs [iepoch].append( arch_info.get_metrics('cifar10'      , 'ori-test', iepoch)['accuracy'] )
                otest_accs[iepoch].append( arch_info.get_metrics('cifar10'      , 'ori-test', iepoch)['accuracy'] )
        info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs, 'otest_accs': otest_accs}
        torch.save(info, cache_file_path)
    else:
        print ('Find cache file : {:}'.format(cache_file_path))
        info = torch.load(cache_file_path)
        params, flops, train_accs, valid_accs, test_accs, otest_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs'], info['otest_accs']
    print ('{:} collect data done.'.format(time_string()))

    final_tests = test_accs[num_epochs - 1]
    indexes = list(range(len(final_tests)))
    ord_idxs = sorted(indexes, key=lambda i: final_tests[i])

    # Figure settings do not depend on the epoch.
    dpi, width, height = 300, 2600, 2600
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 18, 18
    low, high = min(indexes), max(indexes)
    ticks = np.arange(low, high, high//6)

    for sepoch in range(num_epochs):
        x_valids = valid_accs[sepoch]
        valid_ord_idxs = sorted(indexes, key=lambda i: x_valids[i])
        # architecture index -> validation rank; a dict lookup replaces the
        # list.index() scan that made every epoch quadratic.
        valid_rank = {arch: position for position, arch in enumerate(valid_ord_idxs)}
        valid_ord_lbls = [valid_rank[idx] for idx in ord_idxs]

        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)
        plt.xlim(low, high)
        plt.ylim(low, high)
        plt.yticks(ticks, fontsize=LegendFontsize, rotation='vertical')
        plt.xticks(ticks, fontsize=LegendFontsize)
        ax.scatter(indexes, valid_ord_lbls, marker='^', s=0.5, c='tab:green', alpha=0.8)
        ax.scatter(indexes, indexes       , marker='o', s=0.5, c='tab:blue' , alpha=0.8)
        # Off-canvas points drawn with large markers only to get readable legend entries.
        ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-10 validation')
        ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10 test')
        plt.grid(zorder=0)
        ax.set_axisbelow(True)
        plt.legend(loc='upper left', fontsize=LegendFontsize)
        ax.set_xlabel('architecture ranking in the final test accuracy', fontsize=LabelSize)
        ax.set_ylabel('architecture ranking in the validation set', fontsize=LabelSize)
        save_path = (vis_save_dir / 'time-{:03d}.pdf'.format(sepoch)).resolve()
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
        save_path = (vis_save_dir / 'time-{:03d}.png'.format(sepoch)).resolve()
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
        print ('{:} save into {:}'.format(time_string(), save_path))
        # Release the figure right away so 200 figures never stay open at once.
        plt.close(fig)
    plt.close('all')
def write_video(save_dir):
    """Assemble the ``time-*.png`` frames in ``save_dir`` into ``time.avi``.

    Frames are taken in sorted file-name order, resized to 1000x1000 and
    written with the MJPG codec at 5 frames per second. When no frame exists
    the function prints a message and returns without creating a video;
    frames that OpenCV cannot read are skipped.

    Args:
      save_dir: directory path object holding the frames and receiving the video.
    """
    import cv2
    video_save_path = save_dir / 'time.avi'
    print ('{:} start create video for {:}'.format(time_string(), video_save_path))
    images = sorted(save_dir.glob('time-*.png'))
    if not images:
        # Previously this case crashed with IndexError on images[0].
        print ('Do not find any time-*.png frame in {:}, skip the video.'.format(save_dir))
        return
    shape = (1000, 1000)
    writer = cv2.VideoWriter(str(video_save_path), cv2.VideoWriter_fourcc(*"MJPG"), 5, shape)
    num_written = 0
    try:
        for image in images:
            frame = cv2.imread(str(image))
            if frame is None:  # cv2.imread signals an unreadable file with None
                print ('skip unreadable frame : {:}'.format(image))
                continue
            writer.write(cv2.resize(frame, shape))
            num_written += 1
    finally:
        # Always finalize the container, even if a frame fails to process.
        writer.release()
    print ('write video [{:} frames] into {:}'.format(num_written, video_save_path))
if __name__ == '__main__':
    # Entry point: build every visualization under <save_dir>/visuals.
    parser = argparse.ArgumentParser(description='NAS-Bench-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-102/visual', help='The base-name of folder to save checkpoints and log.')
    parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-102 benchmark file.')
    args = parser.parse_args()

    # Validate before touching the file system. Path(None) used to raise a
    # TypeError before the helpful message could be shown, and assert-based
    # checks disappear under ``python -O``.
    if args.api_path is None:
        parser.error('--api_path is required')
    meta_file = Path(args.api_path)
    if not meta_file.exists():
        parser.error('invalid path for api : {:}'.format(meta_file))

    vis_save_dir = Path(args.save_dir) / 'visuals'
    vis_save_dir.mkdir(parents=True, exist_ok=True)

    # The per-epoch frames must exist before the video is assembled, and the
    # three per-dataset caches must exist before the relative ranking is drawn.
    visualize_rank_over_time(str(meta_file), vis_save_dir / 'over-time')
    write_video(vis_save_dir / 'over-time')
    for dataset in ('cifar10', 'cifar100', 'ImageNet16-120'):
        visualize_info(str(meta_file), dataset, vis_save_dir)
    visualize_relative_ranking(vis_save_dir)

View File

@ -53,43 +53,50 @@ def config2structure_func(max_nodes):
class MyWorker(Worker):
def __init__(self, *args, convert_func=None, nas_bench=None, time_scale=None, **kwargs):
def __init__(self, *args, convert_func=None, nas_bench=None, time_budget=None, **kwargs):
super().__init__(*args, **kwargs)
self.convert_func = convert_func
self.nas_bench = nas_bench
self.time_scale = time_scale
self.seen_arch = 0
self.time_budget = time_budget
self.seen_archs = []
self.sim_cost_time = 0
self.real_cost_time = 0
self.is_end = False
def get_the_best(self):
assert len(self.seen_archs) > 0
best_index, best_acc = -1, None
for arch_index in self.seen_archs:
info = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True)
vacc = info['valid-accuracy']
if best_acc is None or best_acc < vacc:
best_acc = vacc
best_index = arch_index
assert best_index != -1
return best_index
def compute(self, config, budget, **kwargs):
start_time = time.time()
structure = self.convert_func( config )
arch_index = self.nas_bench.query_index_by_arch( structure )
iepoch = 0
while iepoch < 12:
info = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', iepoch, True)
cur_time = info['train-all-time'] + info['valid-per-time']
cur_vacc = info['valid-accuracy']
if time.time() - start_time + cur_time / self.time_scale > budget:
break
else:
iepoch += 1
self.sim_cost_time += cur_time
self.seen_arch += 1
remaining_time = cur_time / self.time_scale - (time.time() - start_time)
if remaining_time > 0:
time.sleep(remaining_time)
else:
import pdb; pdb.set_trace()
info = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True)
cur_time = info['train-all-time'] + info['valid-per-time']
cur_vacc = info['valid-accuracy']
self.real_cost_time += (time.time() - start_time)
return ({
'loss': 100 - float(cur_vacc),
'info': {'seen-arch' : self.seen_arch,
'sim-test-time' : self.sim_cost_time,
'real-test-time': self.real_cost_time,
'current-arch' : arch_index,
'current-budget': budget}
if self.sim_cost_time + cur_time <= self.time_budget and not self.is_end:
self.sim_cost_time += cur_time
self.seen_archs.append( arch_index )
return ({'loss': 100 - float(cur_vacc),
'info': {'seen-arch' : len(self.seen_archs),
'sim-test-time' : self.sim_cost_time,
'current-arch' : arch_index}
})
else:
self.is_end = True
return ({'loss': 100,
'info': {'seen-arch' : len(self.seen_archs),
'sim-test-time' : self.sim_cost_time,
'current-arch' : None}
})
@ -139,16 +146,14 @@ def main(xargs, nas_bench):
#logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string()))
workers = []
for i in range(num_workers):
w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, time_scale=xargs.time_scale, run_id=hb_run_id, id=i)
w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, time_budget=xargs.time_budget, run_id=hb_run_id, id=i)
w.run(background=True)
workers.append(w)
simulate_time_budge = xargs.time_budget // xargs.time_scale
start_time = time.time()
logger.log('simulate_time_budge : {:} (in seconds).'.format(simulate_time_budge))
bohb = BOHB(configspace=cs,
run_id=hb_run_id,
eta=3, min_budget=simulate_time_budge//3, max_budget=simulate_time_budge,
eta=3, min_budget=12, max_budget=200,
nameserver=ns_host,
nameserver_port=ns_port,
num_samples=xargs.num_samples,
@ -161,11 +166,9 @@ def main(xargs, nas_bench):
NS.shutdown()
real_cost_time = time.time() - start_time
import pdb; pdb.set_trace()
id2config = results.get_id2config_mapping()
incumbent = results.get_incumbent_id()
logger.log('Best found configuration: {:}'.format(id2config[incumbent]['config']))
best_arch = config2structure( id2config[incumbent]['config'] )
@ -174,7 +177,7 @@ def main(xargs, nas_bench):
else : logger.log('{:}'.format(info))
logger.log('-'*100)
logger.log('workers : {:}'.format(workers[0].test_time))
logger.log('workers : {:.1f}s with {:} archs'.format(workers[0].time_budget, len(workers[0].seen_archs)))
logger.close()
return logger.log_dir, nas_bench.query_index_by_arch( best_arch )
@ -190,14 +193,13 @@ if __name__ == '__main__':
parser.add_argument('--channel', type=int, help='The number of channels.')
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).')
parser.add_argument('--time_scale' , type=int, help='The time scale to accelerate the time budget.')
# BOHB
parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function')
parser.add_argument('--min_bandwidth', default=.3, type=float, nargs='?', help='minimum bandwidth for KDE')
parser.add_argument('--num_samples', default=64, type=int, nargs='?', help='number of samples for the acquisition function')
parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function')
parser.add_argument('--min_bandwidth', default=.3, type=float, nargs='?', help='minimum bandwidth for KDE')
parser.add_argument('--num_samples', default=64, type=int, nargs='?', help='number of samples for the acquisition function')
parser.add_argument('--random_fraction', default=.33, type=float, nargs='?', help='fraction of random configurations')
parser.add_argument('--bandwidth_factor', default=3, type=int, nargs='?', help='factor multiplied to the bandwidth')
parser.add_argument('--n_iters', default=100, type=int, nargs='?', help='number of iterations for optimization method')
parser.add_argument('--bandwidth_factor', default=3, type=int, nargs='?', help='factor multiplied to the bandwidth')
parser.add_argument('--n_iters', default=100, type=int, nargs='?', help='number of iterations for optimization method')
# log
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')

View File

@ -82,14 +82,29 @@ def valid_func(xloader, network, criterion):
return arch_losses.avg, arch_top1.avg, arch_top5.avg
def search_find_best(valid_loader, network, criterion, select_num):
best_arch, best_acc = None, -1
for iarch in range(select_num):
arch = network.module.random_genotype( True )
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
if best_arch is None or best_acc < valid_a_top1:
best_arch, best_acc = arch, valid_a_top1
return best_arch
def search_find_best(xloader, network, n_samples):
    """Sample ``n_samples`` random architectures and return the best one.

    Each call to ``network.module.random_genotype(True)`` draws one
    architecture, which is then scored by the top-1 accuracy of ``network`` on
    a single batch taken from ``xloader``. The loader is restarted whenever it
    runs out of batches.

    Args:
      xloader: iterable yielding ``(inputs, targets)`` batches.
      network: wrapped model exposing ``module.random_genotype`` and returning
        a pair whose second element is the logits.
      n_samples: number of architectures to sample.

    Returns:
      ``(best_arch, best_valid_acc)``: the sampled architecture with the
      highest top-1 accuracy, and that accuracy.
    """
    with torch.no_grad():
        network.eval()
        archs, valid_accs = [], []
        loader_iter = iter(xloader)
        for _ in range(n_samples):
            arch = network.module.random_genotype( True )
            try:
                inputs, targets = next(loader_iter)
            except StopIteration:
                # Only exhaustion restarts the loader; any other error (bad
                # batch, worker failure, Ctrl-C) must surface instead of being
                # swallowed by a bare ``except``.
                loader_iter = iter(xloader)
                inputs, targets = next(loader_iter)

            _, logits = network(inputs)
            val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5))

            archs.append( arch )
            valid_accs.append( val_top1.item() )

        best_idx = np.argmax(valid_accs)
        best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
        return best_arch, best_valid_acc
def main(xargs):
@ -127,7 +142,7 @@ def main(xargs):
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
# data loader
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.test_batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
@ -177,7 +192,8 @@ def main(xargs):
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
cur_arch = search_find_best(valid_loader, network, criterion, xargs.select_num)
cur_arch, cur_valid_acc = search_find_best(valid_loader, network, xargs.select_num)
logger.log('[{:}] find-the-best : {:}, accuracy@1={:.2f}%'.format(epoch_str, cur_arch, cur_valid_acc))
genotypes[epoch] = cur_arch
# check the best accuracy
valid_accuracies[epoch] = valid_a_top1
@ -211,13 +227,7 @@ def main(xargs):
logger.log('\n' + '-'*200)
logger.log('Pre-searching costs {:.1f} s'.format(search_time.sum))
start_time = time.time()
best_arch, best_acc = None, -1
for iarch in range(xargs.select_num):
arch = search_model.random_genotype( True )
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
logger.log('final evaluation [{:02d}/{:02d}] : {:} : accuracy={:.2f}%, loss={:.3f}'.format(iarch, xargs.select_num, arch, valid_a_top1, valid_a_loss))
if best_arch is None or best_acc < valid_a_top1:
best_arch, best_acc = arch, valid_a_top1
best_arch, best_acc = search_find_best(valid_loader, network, xargs.select_num)
search_time.update(time.time() - start_time)
logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum))
if api is not None: logger.log('{:}'.format( api.query_by_arch(best_arch) ))

View File

@ -26,8 +26,6 @@ def get_depth_choices(nDepth, return_num):
else : return choices
def conv_forward(inputs, conv, choices):
iC = conv.in_channels
fill_size = list(inputs.size())

View File

@ -104,14 +104,19 @@ class NASBench102API(object):
print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index))
return None
def query_by_index(self, arch_index, dataname, use_12epochs_result=False):
# query information with the training of 12 epochs or 200 epochs
# if dataname is None, return the ArchResults
# else, return a dict with all trials on that dataset (the key is the seed)
def query_by_index(self, arch_index, dataname=None, use_12epochs_result=False):
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
else : basestr, arch2infos = '200epochs', self.arch2infos_full
assert arch_index in arch2infos, 'arch_index [{:}] does not in arch2info with {:}'.format(arch_index, basestr)
archInfo = copy.deepcopy( arch2infos[ arch_index ] )
assert dataname in archInfo.get_dataset_names(), 'invalid dataset-name : {:}'.format(dataname)
info = archInfo.query(dataname)
return info
if dataname is None: return archInfo
else:
assert dataname in archInfo.get_dataset_names(), 'invalid dataset-name : {:}'.format(dataname)
info = archInfo.query(dataname)
return info
def query_meta_info_by_index(self, arch_index, use_12epochs_result=False):
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
@ -266,7 +271,7 @@ class ArchResults(object):
def query(self, dataset, seed=None):
if seed is None:
x_seeds = self.dataset_seed[dataset]
return [self.all_results[ (dataset, seed) ] for seed in x_seeds]
return {seed: self.all_results[ (dataset, seed) ] for seed in x_seeds}
else:
return self.all_results[ (dataset, seed) ]

View File

@ -34,6 +34,6 @@ OMP_NUM_THREADS=4 python ./exps/algos/BOHB.py \
--dataset ${dataset} --data_path ${data_path} \
--search_space_name ${space} \
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
--time_budget 12000 --time_scale 200 \
--n_iters 64 --num_samples 4 --random_fraction 0 \
--time_budget 12000 \
--n_iters 28 --num_samples 64 --random_fraction .33 --bandwidth_factor 3 \
--workers 4 --print_freq 200 --rand_seed ${seed}