fix bugs in RANDOM-NAS and BOHB

parent 4c144b7437
commit f8f44bfb31
@@ -51,7 +51,7 @@ res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric
 cost_metrics = info.get_comput_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency

 # get the detailed information
-results = api.query_by_index(1, 'cifar100') # a list of all trials on cifar100
+results = api.query_by_index(1, 'cifar100') # a dict of all trials for 1st net on cifar100, where the key is the seed
 print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1]))
 print ('Latency : {:}'.format(results[0].get_latency()))
 print ('Train Info : {:}'.format(results[0].get_train()))
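Note that `query_by_index` now returns a seed-keyed dict rather than a list, so the unchanged `results[0]` lines in the surrounding context would need a dict lookup to keep working. A minimal sketch of iterating the new return type (the seed values are whatever the benchmark recorded, so iterate rather than index positionally):

# results maps each training seed to its ResultsInstance
results = api.query_by_index(1, 'cifar100')
for seed, result in results.items():
  print('seed={:} latency={:} train={:}'.format(seed, result.get_latency(), result.get_train()))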
@@ -9,5 +9,6 @@
   "momentum"  : ["float", "0.9"],
   "nesterov"  : ["bool",  "1"],
   "criterion" : ["str",   "Softmax"],
-  "batch_size": ["int",   "64"]
+  "batch_size": ["int",   "64"],
+  "test_batch_size": ["int", "512"]
 }
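Each value in this config is a [type, literal] pair. A hedged sketch of how such an entry can be materialized into typed Python values; the repo has its own config loader, so `load_typed_config` and the file name below are hypothetical stand-ins:

import json

def load_typed_config(path):
  # cast each ["type", "literal"] pair, e.g. ["int", "512"] -> 512
  casts = {'int': int, 'float': float, 'str': str, 'bool': lambda v: bool(int(v))}
  with open(path) as f:
    raw = json.load(f)
  return {key: casts[t](v) for key, (t, v) in raw.items()}

config = load_typed_config('RANDOM.config')  # hypothetical path
print(config['test_batch_size'])             # -> 512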
386  exps/NAS-Bench-102/visualize.py  Normal file
@@ -0,0 +1,386 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# python exps/NAS-Bench-102/visualize.py --api_path $HOME/.torch/NAS-Bench-102-v1_0-e61699.pth
##################################################
import os, sys, time, argparse, collections
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from pathlib import Path
from collections import defaultdict
import matplotlib
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
matplotlib.use('agg')
import matplotlib.pyplot as plt

lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from log_utils import time_string
from nas_102_api import NASBench102API as API


def calculate_correlation(*vectors):
  matrix = []
  for i, vectori in enumerate(vectors):
    x = []
    for j, vectorj in enumerate(vectors):
      x.append( np.corrcoef(vectori, vectorj)[0,1] )
    matrix.append( x )
  return np.array(matrix)


def visualize_relative_ranking(vis_save_dir):
  print ('\n' + '-'*100)
  cifar010_cache_path = vis_save_dir / '{:}-cache-info.pth'.format('cifar10')
  cifar100_cache_path = vis_save_dir / '{:}-cache-info.pth'.format('cifar100')
  imagenet_cache_path = vis_save_dir / '{:}-cache-info.pth'.format('ImageNet16-120')
  cifar010_info = torch.load(cifar010_cache_path)
  cifar100_info = torch.load(cifar100_cache_path)
  imagenet_info = torch.load(imagenet_cache_path)
  indexes = list(range(len(cifar010_info['params'])))

  print ('{:} start to visualize relative ranking'.format(time_string()))
  # maximum accuracy with ResNet-level params 11472
  x_010_accs = [ cifar010_info['test_accs'][i] if cifar010_info['params'][i] <= cifar010_info['params'][11472] else -1 for i in indexes]
  x_100_accs = [ cifar100_info['test_accs'][i] if cifar100_info['params'][i] <= cifar100_info['params'][11472] else -1 for i in indexes]
  x_img_accs = [ imagenet_info['test_accs'][i] if imagenet_info['params'][i] <= imagenet_info['params'][11472] else -1 for i in indexes]

  cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info['test_accs'][i])
  cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info['test_accs'][i])
  imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info['test_accs'][i])

  cifar100_labels, imagenet_labels = [], []
  for idx in cifar010_ord_indexes:
    cifar100_labels.append( cifar100_ord_indexes.index(idx) )
    imagenet_labels.append( imagenet_ord_indexes.index(idx) )
  print ('{:} prepare data done.'.format(time_string()))

  dpi, width, height = 300, 2600, 2600
  figsize = width / float(dpi), height / float(dpi)
  LabelSize, LegendFontsize = 18, 18
  resnet_scale, resnet_alpha = 120, 0.5

  fig = plt.figure(figsize=figsize)
  ax  = fig.add_subplot(111)
  plt.xlim(min(indexes), max(indexes))
  plt.ylim(min(indexes), max(indexes))
  #plt.ylabel('y').set_rotation(0)
  plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize, rotation='vertical')
  plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize)
  #ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8, label='CIFAR-100')
  #ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red'  , alpha=0.8, label='ImageNet-16-120')
  #ax.scatter(indexes, indexes        , marker='o', s=0.5, c='tab:blue' , alpha=0.8, label='CIFAR-10')
  ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8)
  ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red'  , alpha=0.8)
  ax.scatter(indexes, indexes        , marker='o', s=0.5, c='tab:blue' , alpha=0.8)
  ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10')
  ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-100')
  ax.scatter([-1], [-1], marker='*', s=100, c='tab:red'  , label='ImageNet-16-120')
  plt.grid(zorder=0)
  ax.set_axisbelow(True)
  plt.legend(loc=0, fontsize=LegendFontsize)
  ax.set_xlabel('architecture ranking in CIFAR-10', fontsize=LabelSize)
  ax.set_ylabel('architecture ranking', fontsize=LabelSize)
  save_path = (vis_save_dir / 'relative-rank.pdf').resolve()
  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
  save_path = (vis_save_dir / 'relative-rank.png').resolve()
  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
  print ('{:} save into {:}'.format(time_string(), save_path))

  # calculate correlation
  sns_size = 15
  CoRelMatrix = calculate_correlation(cifar010_info['valid_accs'], cifar010_info['test_accs'], cifar100_info['valid_accs'], cifar100_info['test_accs'], imagenet_info['valid_accs'], imagenet_info['test_accs'])
  fig = plt.figure(figsize=figsize)
  plt.axis('off')
  h = sns.heatmap(CoRelMatrix, annot=True, annot_kws={'size':sns_size}, fmt='.3f', linewidths=0.5)
  save_path = (vis_save_dir / 'co-relation-all.pdf').resolve()
  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
  print ('{:} save into {:}'.format(time_string(), save_path))

  # calculate correlation
  acc_bars = [92, 93]
  for acc_bar in acc_bars:
    selected_indexes = []
    for i, acc in enumerate(cifar010_info['test_accs']):
      if acc > acc_bar: selected_indexes.append( i )
    print ('select {:} architectures'.format(len(selected_indexes)))
    cifar010_valid_accs = np.array(cifar010_info['valid_accs'])[ selected_indexes ]
    cifar010_test_accs  = np.array(cifar010_info['test_accs']) [ selected_indexes ]
    cifar100_valid_accs = np.array(cifar100_info['valid_accs'])[ selected_indexes ]
    cifar100_test_accs  = np.array(cifar100_info['test_accs']) [ selected_indexes ]
    imagenet_valid_accs = np.array(imagenet_info['valid_accs'])[ selected_indexes ]
    imagenet_test_accs  = np.array(imagenet_info['test_accs']) [ selected_indexes ]
    CoRelMatrix = calculate_correlation(cifar010_valid_accs, cifar010_test_accs, cifar100_valid_accs, cifar100_test_accs, imagenet_valid_accs, imagenet_test_accs)
    fig = plt.figure(figsize=figsize)
    plt.axis('off')
    h = sns.heatmap(CoRelMatrix, annot=True, annot_kws={'size':sns_size}, fmt='.3f', linewidths=0.5)
    save_path = (vis_save_dir / 'co-relation-top-{:}.pdf'.format(len(selected_indexes))).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
    print ('{:} save into {:}'.format(time_string(), save_path))
  plt.close('all')


def visualize_info(meta_file, dataset, vis_save_dir):
  print ('{:} start to visualize {:} information'.format(time_string(), dataset))
  cache_file_path = vis_save_dir / '{:}-cache-info.pth'.format(dataset)
  if not cache_file_path.exists():
    print ('Do not find cache file : {:}'.format(cache_file_path))
    nas_bench = API(str(meta_file))
    params, flops, train_accs, valid_accs, test_accs, otest_accs = [], [], [], [], [], []
    for index in range( len(nas_bench) ):
      info = nas_bench.query_by_index(index, use_12epochs_result=False)
      resx = info.get_comput_costs(dataset) ; flop, param = resx['flops'], resx['params']
      if dataset == 'cifar10':
        res = info.get_metrics('cifar10', 'train')         ; train_acc = res['accuracy']
        res = info.get_metrics('cifar10-valid', 'x-valid') ; valid_acc = res['accuracy']
        res = info.get_metrics('cifar10', 'ori-test')      ; test_acc  = res['accuracy']
        res = info.get_metrics('cifar10', 'ori-test')      ; otest_acc = res['accuracy']
      else:
        res = info.get_metrics(dataset, 'train')    ; train_acc = res['accuracy']
        res = info.get_metrics(dataset, 'x-valid')  ; valid_acc = res['accuracy']
        res = info.get_metrics(dataset, 'x-test')   ; test_acc  = res['accuracy']
        res = info.get_metrics(dataset, 'ori-test') ; otest_acc = res['accuracy']
      if index == 11472: # resnet
        resnet = {'params':param, 'flops': flop, 'index': 11472, 'train_acc': train_acc, 'valid_acc': valid_acc, 'test_acc': test_acc, 'otest_acc': otest_acc}
      flops.append( flop )
      params.append( param )
      train_accs.append( train_acc )
      valid_accs.append( valid_acc )
      test_accs.append( test_acc )
      otest_accs.append( otest_acc )
    #resnet = {'params': 0.559, 'flops': 78.56, 'index': 11472, 'train_acc': 99.99, 'valid_acc': 90.84, 'test_acc': 93.97}
    info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs, 'otest_accs': otest_accs}
    info['resnet'] = resnet
    torch.save(info, cache_file_path)
  else:
    print ('Find cache file : {:}'.format(cache_file_path))
    info = torch.load(cache_file_path)
    params, flops, train_accs, valid_accs, test_accs, otest_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs'], info['otest_accs']
    resnet = info['resnet']
  print ('{:} collect data done.'.format(time_string()))

  indexes = list(range(len(params)))
  dpi, width, height = 300, 2600, 2600
  figsize = width / float(dpi), height / float(dpi)
  LabelSize, LegendFontsize = 22, 22
  resnet_scale, resnet_alpha = 120, 0.5

  fig = plt.figure(figsize=figsize)
  ax  = fig.add_subplot(111)
  plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
  if dataset == 'cifar10':
    plt.ylim(50, 100)
    plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
  elif dataset == 'cifar100':
    plt.ylim(25, 75)
    plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
  else:
    plt.ylim(0, 50)
    plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
  ax.scatter(params, valid_accs, marker='o', s=0.5, c='tab:blue')
  ax.scatter([resnet['params']], [resnet['valid_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=0.4)
  plt.grid(zorder=0)
  ax.set_axisbelow(True)
  plt.legend(loc=4, fontsize=LegendFontsize)
  ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
  ax.set_ylabel('the validation accuracy (%)', fontsize=LabelSize)
  save_path = (vis_save_dir / '{:}-param-vs-valid.pdf'.format(dataset)).resolve()
  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
  save_path = (vis_save_dir / '{:}-param-vs-valid.png'.format(dataset)).resolve()
  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
  print ('{:} save into {:}'.format(time_string(), save_path))

  fig = plt.figure(figsize=figsize)
  ax  = fig.add_subplot(111)
  plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
  if dataset == 'cifar10':
    plt.ylim(50, 100)
    plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
  elif dataset == 'cifar100':
    plt.ylim(25, 75)
    plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
  else:
    plt.ylim(0, 50)
    plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
  ax.scatter(params, test_accs, marker='o', s=0.5, c='tab:blue')
  ax.scatter([resnet['params']], [resnet['test_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=resnet_alpha)
  plt.grid()
  ax.set_axisbelow(True)
  plt.legend(loc=4, fontsize=LegendFontsize)
  ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
  ax.set_ylabel('the test accuracy (%)', fontsize=LabelSize)
  save_path = (vis_save_dir / '{:}-param-vs-test.pdf'.format(dataset)).resolve()
  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
  save_path = (vis_save_dir / '{:}-param-vs-test.png'.format(dataset)).resolve()
  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
  print ('{:} save into {:}'.format(time_string(), save_path))

  fig = plt.figure(figsize=figsize)
  ax  = fig.add_subplot(111)
  plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
  if dataset == 'cifar10':
    plt.ylim(50, 100)
    plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
  elif dataset == 'cifar100':
    plt.ylim(20, 100)
    plt.yticks(np.arange(20, 101, 10), fontsize=LegendFontsize)
  else:
    plt.ylim(25, 76)
    plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
  ax.scatter(params, train_accs, marker='o', s=0.5, c='tab:blue')
  ax.scatter([resnet['params']], [resnet['train_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=resnet_alpha)
  plt.grid()
  ax.set_axisbelow(True)
  plt.legend(loc=4, fontsize=LegendFontsize)
  ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
  ax.set_ylabel('the training accuracy (%)', fontsize=LabelSize)
  save_path = (vis_save_dir / '{:}-param-vs-train.pdf'.format(dataset)).resolve()
  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
  save_path = (vis_save_dir / '{:}-param-vs-train.png'.format(dataset)).resolve()
  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
  print ('{:} save into {:}'.format(time_string(), save_path))

  fig = plt.figure(figsize=figsize)
  ax  = fig.add_subplot(111)
  plt.xlim(0, max(indexes))
  plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//5), fontsize=LegendFontsize)
  if dataset == 'cifar10':
    plt.ylim(50, 100)
    plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
  elif dataset == 'cifar100':
    plt.ylim(25, 75)
    plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
  else:
    plt.ylim(0, 50)
    plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
  ax.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue')
  ax.scatter([resnet['index']], [resnet['test_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=resnet_alpha)
  plt.grid()
  ax.set_axisbelow(True)
  plt.legend(loc=4, fontsize=LegendFontsize)
  ax.set_xlabel('architecture ID', fontsize=LabelSize)
  ax.set_ylabel('the test accuracy (%)', fontsize=LabelSize)
  save_path = (vis_save_dir / '{:}-test-over-ID.pdf'.format(dataset)).resolve()
  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
  save_path = (vis_save_dir / '{:}-test-over-ID.png'.format(dataset)).resolve()
  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
  print ('{:} save into {:}'.format(time_string(), save_path))
  plt.close('all')


def visualize_rank_over_time(meta_file, vis_save_dir):
  print ('\n' + '-'*150)
  vis_save_dir.mkdir(parents=True, exist_ok=True)
  print ('{:} start to visualize rank-over-time into {:}'.format(time_string(), vis_save_dir))
  cache_file_path = vis_save_dir / 'rank-over-time-cache-info.pth'
  if not cache_file_path.exists():
    print ('Do not find cache file : {:}'.format(cache_file_path))
    nas_bench = API(str(meta_file))
    print ('{:} load nas_bench done'.format(time_string()))
    params, flops, train_accs, valid_accs, test_accs, otest_accs = [], [], defaultdict(list), defaultdict(list), defaultdict(list), defaultdict(list)
    #for iepoch in range(200): for index in range( len(nas_bench) ):
    for index in tqdm(range(len(nas_bench))):
      info = nas_bench.query_by_index(index, use_12epochs_result=False)
      for iepoch in range(200):
        res = info.get_metrics('cifar10'      , 'train'   , iepoch) ; train_acc = res['accuracy']
        res = info.get_metrics('cifar10-valid', 'x-valid' , iepoch) ; valid_acc = res['accuracy']
        res = info.get_metrics('cifar10'      , 'ori-test', iepoch) ; test_acc  = res['accuracy']
        res = info.get_metrics('cifar10'      , 'ori-test', iepoch) ; otest_acc = res['accuracy']
        train_accs[iepoch].append( train_acc )
        valid_accs[iepoch].append( valid_acc )
        test_accs [iepoch].append( test_acc )
        otest_accs[iepoch].append( otest_acc )
        if iepoch == 0:
          res = info.get_comput_costs('cifar10') ; flop, param = res['flops'], res['params']
          flops.append( flop )
          params.append( param )
    info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs, 'otest_accs': otest_accs}
    torch.save(info, cache_file_path)
  else:
    print ('Find cache file : {:}'.format(cache_file_path))
    info = torch.load(cache_file_path)
    params, flops, train_accs, valid_accs, test_accs, otest_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs'], info['otest_accs']
  print ('{:} collect data done.'.format(time_string()))
  #selected_epochs = [0, 100, 150, 180, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199]
  selected_epochs = list( range(200) )
  x_xtests = test_accs[199]
  indexes  = list(range(len(x_xtests)))
  ord_idxs = sorted(indexes, key=lambda i: x_xtests[i])
  for sepoch in selected_epochs:
    x_valids = valid_accs[sepoch]
    valid_ord_idxs = sorted(indexes, key=lambda i: x_valids[i])
    valid_ord_lbls = []
    for idx in ord_idxs:
      valid_ord_lbls.append( valid_ord_idxs.index(idx) )
    # labeled data
    dpi, width, height = 300, 2600, 2600
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 18, 18

    fig = plt.figure(figsize=figsize)
    ax  = fig.add_subplot(111)
    plt.xlim(min(indexes), max(indexes))
    plt.ylim(min(indexes), max(indexes))
    plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize, rotation='vertical')
    plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize)
    ax.scatter(indexes, valid_ord_lbls, marker='^', s=0.5, c='tab:green', alpha=0.8)
    ax.scatter(indexes, indexes       , marker='o', s=0.5, c='tab:blue' , alpha=0.8)
    ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-10 validation')
    ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10 test')
    plt.grid(zorder=0)
    ax.set_axisbelow(True)
    plt.legend(loc='upper left', fontsize=LegendFontsize)
    ax.set_xlabel('architecture ranking in the final test accuracy', fontsize=LabelSize)
    ax.set_ylabel('architecture ranking in the validation set', fontsize=LabelSize)
    save_path = (vis_save_dir / 'time-{:03d}.pdf'.format(sepoch)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
    save_path = (vis_save_dir / 'time-{:03d}.png'.format(sepoch)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print ('{:} save into {:}'.format(time_string(), save_path))
    plt.close('all')


def write_video(save_dir):
  import cv2
  video_save_path = save_dir / 'time.avi'
  print ('{:} start create video for {:}'.format(time_string(), video_save_path))
  images = sorted( list( save_dir.glob('time-*.png') ) )
  ximage = cv2.imread(str(images[0]))
  #shape  = (ximage.shape[1], ximage.shape[0])
  shape  = (1000, 1000)
  #writer = cv2.VideoWriter(str(video_save_path), cv2.VideoWriter_fourcc(*"MJPG"), 25, shape)
  writer = cv2.VideoWriter(str(video_save_path), cv2.VideoWriter_fourcc(*"MJPG"), 5, shape)
  for idx, image in enumerate(images):
    ximage = cv2.imread(str(image))
    _image = cv2.resize(ximage, shape)
    writer.write(_image)
  writer.release()
  print ('write video [{:} frames] into {:}'.format(len(images), video_save_path))


if __name__ == '__main__':

  parser = argparse.ArgumentParser(description='NAS-Bench-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-102/visual', help='The base-name of folder to save checkpoints and log.')
  parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-102 benchmark file.')
  args = parser.parse_args()

  vis_save_dir = Path(args.save_dir) / 'visuals'
  vis_save_dir.mkdir(parents=True, exist_ok=True)
  meta_file = Path(args.api_path)
  assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
  visualize_rank_over_time(str(meta_file), vis_save_dir / 'over-time')
  write_video(vis_save_dir / 'over-time')
  visualize_info(str(meta_file), 'cifar10' , vis_save_dir)
  visualize_info(str(meta_file), 'cifar100', vis_save_dir)
  visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir)
  visualize_relative_ranking(vis_save_dir)
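For reference, `calculate_correlation` above builds a pairwise Pearson matrix over the given accuracy vectors; a tiny self-contained usage sketch with toy data (the numbers are made up):

import numpy as np
# three toy accuracy vectors over the same five architectures
a = np.array([90.1, 91.3, 88.7, 92.0, 89.5])
b = a + np.random.randn(5) * 0.1   # strongly correlated with a
c = np.random.rand(5) * 100        # unrelated
# matrix[i][j] is the Pearson coefficient between vector i and vector j
matrix = calculate_correlation(a, b, c)
print(matrix.shape)                # -> (3, 3); diagonal entries are 1.0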
@@ -53,43 +53,50 @@ def config2structure_func(max_nodes):

 class MyWorker(Worker):

-  def __init__(self, *args, convert_func=None, nas_bench=None, time_scale=None, **kwargs):
+  def __init__(self, *args, convert_func=None, nas_bench=None, time_budget=None, **kwargs):
     super().__init__(*args, **kwargs)
     self.convert_func   = convert_func
     self.nas_bench      = nas_bench
-    self.time_scale     = time_scale
-    self.seen_arch      = 0
+    self.time_budget    = time_budget
+    self.seen_archs     = []
     self.sim_cost_time  = 0
     self.real_cost_time = 0
+    self.is_end         = False
+
+  def get_the_best(self):
+    assert len(self.seen_archs) > 0
+    best_index, best_acc = -1, None
+    for arch_index in self.seen_archs:
+      info = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True)
+      vacc = info['valid-accuracy']
+      if best_acc is None or best_acc < vacc:
+        best_acc   = vacc
+        best_index = arch_index
+    assert best_index != -1
+    return best_index

   def compute(self, config, budget, **kwargs):
     start_time = time.time()
     structure  = self.convert_func( config )
     arch_index = self.nas_bench.query_index_by_arch( structure )
-    iepoch = 0
-    while iepoch < 12:
-      info = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', iepoch, True)
-      cur_time = info['train-all-time'] + info['valid-per-time']
-      cur_vacc = info['valid-accuracy']
-      if time.time() - start_time + cur_time / self.time_scale > budget:
-        break
-      else:
-        iepoch += 1
-    self.sim_cost_time += cur_time
-    self.seen_arch += 1
-    remaining_time = cur_time / self.time_scale - (time.time() - start_time)
-    if remaining_time > 0:
-      time.sleep(remaining_time)
-    else:
-      import pdb; pdb.set_trace()
+    info       = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True)
+    cur_time   = info['train-all-time'] + info['valid-per-time']
+    cur_vacc   = info['valid-accuracy']
     self.real_cost_time += (time.time() - start_time)
-    return ({
-             'loss': 100 - float(cur_vacc),
-             'info': {'seen-arch'     : self.seen_arch,
-                      'sim-test-time' : self.sim_cost_time,
-                      'real-test-time': self.real_cost_time,
-                      'current-arch'  : arch_index,
-                      'current-budget': budget}
-            })
+    if self.sim_cost_time + cur_time <= self.time_budget and not self.is_end:
+      self.sim_cost_time += cur_time
+      self.seen_archs.append( arch_index )
+      return ({'loss': 100 - float(cur_vacc),
+               'info': {'seen-arch'     : len(self.seen_archs),
+                        'sim-test-time' : self.sim_cost_time,
+                        'current-arch'  : arch_index}
+              })
+    else:
+      self.is_end = True
+      return ({'loss': 100,
+               'info': {'seen-arch'     : len(self.seen_archs),
+                        'sim-test-time' : self.sim_cost_time,
+                        'current-arch'  : None}
+              })

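The rewritten `compute` no longer sleeps in scaled real time; it charges each query's full recorded train+valid cost against a simulated budget. A toy illustration of that bookkeeping (the per-architecture costs below are made up):

# hypothetical per-architecture simulated costs, in seconds
sim_cost_time, time_budget, is_end = 0.0, 12000.0, False
for cur_time in (3000.0, 4000.0, 6000.0):
  if sim_cost_time + cur_time <= time_budget and not is_end:
    sim_cost_time += cur_time      # this arch counts toward the search
  else:
    is_end = True                  # budget exhausted: report loss=100
print(sim_cost_time, is_end)       # -> 7000.0 True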
@@ -139,16 +146,14 @@ def main(xargs, nas_bench):
   #logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string()))
   workers = []
   for i in range(num_workers):
-    w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, time_scale=xargs.time_scale, run_id=hb_run_id, id=i)
+    w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, time_budget=xargs.time_budget, run_id=hb_run_id, id=i)
     w.run(background=True)
     workers.append(w)

-  simulate_time_budge = xargs.time_budget // xargs.time_scale
   start_time = time.time()
-  logger.log('simulate_time_budge : {:} (in seconds).'.format(simulate_time_budge))
   bohb = BOHB(configspace=cs,
               run_id=hb_run_id,
-              eta=3, min_budget=simulate_time_budge//3, max_budget=simulate_time_budge,
+              eta=3, min_budget=12, max_budget=200,
               nameserver=ns_host,
               nameserver_port=ns_port,
               num_samples=xargs.num_samples,
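With the simulated clock handled inside `compute`, the HpBandSter budgets switch from scaled seconds to training epochs, matching the benchmark's 12- and 200-epoch records. A sketch of the budget ladder such settings typically yield; the formula follows the usual Hyperband geometry and is illustrative, not a claim about HpBandSter internals:

import math
eta, min_budget, max_budget = 3, 12, 200
s_max = int(math.floor(math.log(max_budget / min_budget) / math.log(eta)))  # = 2
budgets = [max_budget / eta**i for i in range(s_max, -1, -1)]
print(budgets)  # -> [22.22..., 66.66..., 200.0] epochs per rung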
@@ -161,11 +166,9 @@ def main(xargs, nas_bench):
   NS.shutdown()

   real_cost_time = time.time() - start_time
-  import pdb; pdb.set_trace()
-
   id2config = results.get_id2config_mapping()
   incumbent = results.get_incumbent_id()

   logger.log('Best found configuration: {:}'.format(id2config[incumbent]['config']))
   best_arch = config2structure( id2config[incumbent]['config'] )

@@ -174,7 +177,7 @@ def main(xargs, nas_bench):
   else           : logger.log('{:}'.format(info))
   logger.log('-'*100)

-  logger.log('workers : {:}'.format(workers[0].test_time))
+  logger.log('workers : {:.1f}s with {:} archs'.format(workers[0].time_budget, len(workers[0].seen_archs)))
   logger.close()
   return logger.log_dir, nas_bench.query_index_by_arch( best_arch )

@@ -190,14 +193,13 @@ if __name__ == '__main__':
   parser.add_argument('--channel',     type=int, help='The number of channels.')
   parser.add_argument('--num_cells',   type=int, help='The number of cells in one stage.')
   parser.add_argument('--time_budget', type=int, help='The total time cost budget for searching (in seconds).')
-  parser.add_argument('--time_scale' , type=int, help='The time scale to accelerate the time budget.')
   # BOHB
   parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function')
   parser.add_argument('--min_bandwidth', default=.3, type=float, nargs='?', help='minimum bandwidth for KDE')
   parser.add_argument('--num_samples', default=64, type=int, nargs='?', help='number of samples for the acquisition function')
   parser.add_argument('--random_fraction', default=.33, type=float, nargs='?', help='fraction of random configurations')
   parser.add_argument('--bandwidth_factor', default=3, type=int, nargs='?', help='factor multiplied to the bandwidth')
   parser.add_argument('--n_iters', default=100, type=int, nargs='?', help='number of iterations for optimization method')
   # log
   parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
   parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
@@ -82,14 +82,29 @@ def valid_func(xloader, network, criterion):
   return arch_losses.avg, arch_top1.avg, arch_top5.avg


-def search_find_best(valid_loader, network, criterion, select_num):
-  best_arch, best_acc = None, -1
-  for iarch in range(select_num):
-    arch = network.module.random_genotype( True )
-    valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
-    if best_arch is None or best_acc < valid_a_top1:
-      best_arch, best_acc = arch, valid_a_top1
-  return best_arch
+def search_find_best(xloader, network, n_samples):
+  with torch.no_grad():
+    network.eval()
+    archs, valid_accs = [], []
+    #print ('obtain the top-{:} architectures'.format(n_samples))
+    loader_iter = iter(xloader)
+    for i in range(n_samples):
+      arch = network.module.random_genotype( True )
+      try:
+        inputs, targets = next(loader_iter)
+      except:
+        loader_iter = iter(xloader)
+        inputs, targets = next(loader_iter)
+
+      _, logits = network(inputs)
+      val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5))
+
+      archs.append( arch )
+      valid_accs.append( val_top1.item() )
+
+    best_idx = np.argmax(valid_accs)
+    best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
+    return best_arch, best_valid_acc


 def main(xargs):
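The new `search_find_best` scores each randomly drawn genotype on a single mini-batch from `xloader` instead of a full `valid_func` pass, which is far cheaper per candidate. A usage sketch (loader and network names as in the surrounding script; the sample count is illustrative):

# evaluate 100 random genotypes, each on one validation mini-batch
best_arch, best_valid_acc = search_find_best(valid_loader, network, 100)
print('best candidate {:} at {:.2f}%'.format(best_arch, best_valid_acc))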
@@ -127,7 +142,7 @@ def main(xargs):
   search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
   # data loader
   search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
-  valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
+  valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.test_batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))

@@ -177,7 +192,8 @@ def main(xargs):
     logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
     valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
     logger.log('[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
-    cur_arch = search_find_best(valid_loader, network, criterion, xargs.select_num)
+    cur_arch, cur_valid_acc = search_find_best(valid_loader, network, xargs.select_num)
+    logger.log('[{:}] find-the-best : {:}, accuracy@1={:.2f}%'.format(epoch_str, cur_arch, cur_valid_acc))
     genotypes[epoch] = cur_arch
     # check the best accuracy
     valid_accuracies[epoch] = valid_a_top1
@@ -211,13 +227,7 @@ def main(xargs):
   logger.log('\n' + '-'*200)
   logger.log('Pre-searching costs {:.1f} s'.format(search_time.sum))
   start_time = time.time()
-  best_arch, best_acc = None, -1
-  for iarch in range(xargs.select_num):
-    arch = search_model.random_genotype( True )
-    valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
-    logger.log('final evaluation [{:02d}/{:02d}] : {:} : accuracy={:.2f}%, loss={:.3f}'.format(iarch, xargs.select_num, arch, valid_a_top1, valid_a_loss))
-    if best_arch is None or best_acc < valid_a_top1:
-      best_arch, best_acc = arch, valid_a_top1
+  best_arch, best_acc = search_find_best(valid_loader, network, xargs.select_num)
   search_time.update(time.time() - start_time)
   logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum))
   if api is not None: logger.log('{:}'.format( api.query_by_arch(best_arch) ))
@@ -26,8 +26,6 @@ def get_depth_choices(nDepth, return_num):
   else : return choices


-
-
 def conv_forward(inputs, conv, choices):
   iC = conv.in_channels
   fill_size = list(inputs.size())
@@ -104,14 +104,19 @@ class NASBench102API(object):
       print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index))
       return None

-  def query_by_index(self, arch_index, dataname, use_12epochs_result=False):
+  # query information with the training of 12 epochs or 200 epochs
+  # if dataname is None, return the ArchResults
+  # else, return a dict with all trials on that dataset (the key is the seed)
+  def query_by_index(self, arch_index, dataname=None, use_12epochs_result=False):
     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full
     assert arch_index in arch2infos, 'arch_index [{:}] does not in arch2info with {:}'.format(arch_index, basestr)
     archInfo = copy.deepcopy( arch2infos[ arch_index ] )
-    assert dataname in archInfo.get_dataset_names(), 'invalid dataset-name : {:}'.format(dataname)
-    info = archInfo.query(dataname)
-    return info
+    if dataname is None: return archInfo
+    else:
+      assert dataname in archInfo.get_dataset_names(), 'invalid dataset-name : {:}'.format(dataname)
+      info = archInfo.query(dataname)
+      return info

   def query_meta_info_by_index(self, arch_index, use_12epochs_result=False):
     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
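A short sketch of the two calling conventions this change enables (the architecture index and dataset are illustrative):

# dataname omitted: get the whole ArchResults object for one architecture
arch_info = api.query_by_index(1)
# dataname given: get {seed: ResultsInstance} for that dataset
trials = api.query_by_index(1, 'cifar100')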
@@ -266,7 +271,7 @@ class ArchResults(object):
   def query(self, dataset, seed=None):
     if seed is None:
       x_seeds = self.dataset_seed[dataset]
-      return [self.all_results[ (dataset, seed) ] for seed in x_seeds]
+      return {seed: self.all_results[ (dataset, seed) ] for seed in x_seeds}
     else:
       return self.all_results[ (dataset, seed) ]

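This is the change that makes the README's `query_by_index` result seed-keyed. A toy illustration of the two branches (seeds 777 and 888 are hypothetical):

# query('cifar100')      -> {777: ResultsInstance, 888: ResultsInstance}
# query('cifar100', 777) -> the single ResultsInstance for that seed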
@@ -34,6 +34,6 @@ OMP_NUM_THREADS=4 python ./exps/algos/BOHB.py \
 	--dataset ${dataset} --data_path ${data_path} \
 	--search_space_name ${space} \
 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
-	--time_budget 12000 --time_scale 200 \
-	--n_iters 64 --num_samples 4 --random_fraction 0 \
+	--time_budget 12000 \
+	--n_iters 28 --num_samples 64 --random_fraction .33 --bandwidth_factor 3 \
 	--workers 4 --print_freq 200 --rand_seed ${seed}