diff --git a/NAS-Bench-102.md b/NAS-Bench-102.md index a3fb46f..a2e1503 100644 --- a/NAS-Bench-102.md +++ b/NAS-Bench-102.md @@ -51,7 +51,7 @@ res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric cost_metrics = info.get_comput_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency # get the detailed information -results = api.query_by_index(1, 'cifar100') # a list of all trials on cifar100 +results = api.query_by_index(1, 'cifar100') # a dict of all trials for 1st net on cifar100, where the key is the seed print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1])) print ('Latency : {:}'.format(results[0].get_latency())) print ('Train Info : {:}'.format(results[0].get_train())) diff --git a/configs/nas-benchmark/algos/RANDOM.config b/configs/nas-benchmark/algos/RANDOM.config index 8fca4d7..e2d956d 100644 --- a/configs/nas-benchmark/algos/RANDOM.config +++ b/configs/nas-benchmark/algos/RANDOM.config @@ -9,5 +9,6 @@ "momentum" : ["float", "0.9"], "nesterov" : ["bool", "1"], "criterion": ["str", "Softmax"], - "batch_size": ["int", "64"] + "batch_size": ["int", "64"], + "test_batch_size": ["int", "512"] } diff --git a/exps/NAS-Bench-102/visualize.py b/exps/NAS-Bench-102/visualize.py new file mode 100644 index 0000000..cdd501e --- /dev/null +++ b/exps/NAS-Bench-102/visualize.py @@ -0,0 +1,386 @@ +################################################## +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # +################################################## +# python exps/NAS-Bench-102/visualize.py --api_path $HOME/.torch/NAS-Bench-102-v1_0-e61699.pth +################################################## +import os, sys, time, argparse, collections +from tqdm import tqdm +import numpy as np +import torch +import torch.nn as nn +from pathlib import Path +from collections import defaultdict +import matplotlib +import seaborn as sns +from mpl_toolkits.mplot3d import Axes3D +matplotlib.use('agg') +import matplotlib.pyplot as plt + +lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() +if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +from log_utils import time_string +from nas_102_api import NASBench102API as API + + + +def calculate_correlation(*vectors): + matrix = [] + for i, vectori in enumerate(vectors): + x = [] + for j, vectorj in enumerate(vectors): + x.append( np.corrcoef(vectori, vectorj)[0,1] ) + matrix.append( x ) + return np.array(matrix) + + + +def visualize_relative_ranking(vis_save_dir): + print ('\n' + '-'*100) + cifar010_cache_path = vis_save_dir / '{:}-cache-info.pth'.format('cifar10') + cifar100_cache_path = vis_save_dir / '{:}-cache-info.pth'.format('cifar100') + imagenet_cache_path = vis_save_dir / '{:}-cache-info.pth'.format('ImageNet16-120') + cifar010_info = torch.load(cifar010_cache_path) + cifar100_info = torch.load(cifar100_cache_path) + imagenet_info = torch.load(imagenet_cache_path) + indexes = list(range(len(cifar010_info['params']))) + + print ('{:} start to visualize relative ranking'.format(time_string())) + # maximum accuracy with ResNet-level params 11472 + x_010_accs = [ cifar010_info['test_accs'][i] if cifar010_info['params'][i] <= cifar010_info['params'][11472] else -1 for i in indexes] + x_100_accs = [ cifar100_info['test_accs'][i] if cifar100_info['params'][i] <= cifar100_info['params'][11472] else -1 for i in indexes] + x_img_accs = [ imagenet_info['test_accs'][i] if imagenet_info['params'][i] <= imagenet_info['params'][11472] else -1 for i in indexes] + + cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info['test_accs'][i]) + cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info['test_accs'][i]) + imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info['test_accs'][i]) + + cifar100_labels, imagenet_labels = [], [] + for idx in cifar010_ord_indexes: + cifar100_labels.append( cifar100_ord_indexes.index(idx) ) + imagenet_labels.append( imagenet_ord_indexes.index(idx) ) + print ('{:} prepare data done.'.format(time_string())) + + dpi, width, height = 300, 2600, 2600 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 18, 18 + resnet_scale, resnet_alpha = 120, 0.5 + + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111) + plt.xlim(min(indexes), max(indexes)) + plt.ylim(min(indexes), max(indexes)) + #plt.ylabel('y').set_rotation(0) + plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize, rotation='vertical') + plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize) + #ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8, label='CIFAR-100') + #ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red' , alpha=0.8, label='ImageNet-16-120') + #ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8, label='CIFAR-10') + ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8) + ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red' , alpha=0.8) + ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8) + ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10') + ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-100') + ax.scatter([-1], [-1], marker='*', s=100, c='tab:red' , label='ImageNet-16-120') + plt.grid(zorder=0) + ax.set_axisbelow(True) + plt.legend(loc=0, fontsize=LegendFontsize) + ax.set_xlabel('architecture ranking in CIFAR-10', fontsize=LabelSize) + ax.set_ylabel('architecture ranking', fontsize=LabelSize) + save_path = (vis_save_dir / 'relative-rank.pdf').resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') + save_path = (vis_save_dir / 'relative-rank.png').resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') + print ('{:} save into {:}'.format(time_string(), save_path)) + + # calculate correlation + sns_size = 15 + CoRelMatrix = calculate_correlation(cifar010_info['valid_accs'], cifar010_info['test_accs'], cifar100_info['valid_accs'], cifar100_info['test_accs'], imagenet_info['valid_accs'], imagenet_info['test_accs']) + fig = plt.figure(figsize=figsize) + plt.axis('off') + h = sns.heatmap(CoRelMatrix, annot=True, annot_kws={'size':sns_size}, fmt='.3f', linewidths=0.5) + save_path = (vis_save_dir / 'co-relation-all.pdf').resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') + print ('{:} save into {:}'.format(time_string(), save_path)) + + # calculate correlation + acc_bars = [92, 93] + for acc_bar in acc_bars: + selected_indexes = [] + for i, acc in enumerate(cifar010_info['test_accs']): + if acc > acc_bar: selected_indexes.append( i ) + print ('select {:} architectures'.format(len(selected_indexes))) + cifar010_valid_accs = np.array(cifar010_info['valid_accs'])[ selected_indexes ] + cifar010_test_accs = np.array(cifar010_info['test_accs']) [ selected_indexes ] + cifar100_valid_accs = np.array(cifar100_info['valid_accs'])[ selected_indexes ] + cifar100_test_accs = np.array(cifar100_info['test_accs']) [ selected_indexes ] + imagenet_valid_accs = np.array(imagenet_info['valid_accs'])[ selected_indexes ] + imagenet_test_accs = np.array(imagenet_info['test_accs']) [ selected_indexes ] + CoRelMatrix = calculate_correlation(cifar010_valid_accs, cifar010_test_accs, cifar100_valid_accs, cifar100_test_accs, imagenet_valid_accs, imagenet_test_accs) + fig = plt.figure(figsize=figsize) + plt.axis('off') + h = sns.heatmap(CoRelMatrix, annot=True, annot_kws={'size':sns_size}, fmt='.3f', linewidths=0.5) + save_path = (vis_save_dir / 'co-relation-top-{:}.pdf'.format(len(selected_indexes))).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') + print ('{:} save into {:}'.format(time_string(), save_path)) + plt.close('all') + + + +def visualize_info(meta_file, dataset, vis_save_dir): + print ('{:} start to visualize {:} information'.format(time_string(), dataset)) + cache_file_path = vis_save_dir / '{:}-cache-info.pth'.format(dataset) + if not cache_file_path.exists(): + print ('Do not find cache file : {:}'.format(cache_file_path)) + nas_bench = API(str(meta_file)) + params, flops, train_accs, valid_accs, test_accs, otest_accs = [], [], [], [], [], [] + for index in range( len(nas_bench) ): + info = nas_bench.query_by_index(index, use_12epochs_result=False) + resx = info.get_comput_costs(dataset) ; flop, param = resx['flops'], resx['params'] + if dataset == 'cifar10': + res = info.get_metrics('cifar10', 'train') ; train_acc = res['accuracy'] + res = info.get_metrics('cifar10-valid', 'x-valid') ; valid_acc = res['accuracy'] + res = info.get_metrics('cifar10', 'ori-test') ; test_acc = res['accuracy'] + res = info.get_metrics('cifar10', 'ori-test') ; otest_acc = res['accuracy'] + else: + res = info.get_metrics(dataset, 'train') ; train_acc = res['accuracy'] + res = info.get_metrics(dataset, 'x-valid') ; valid_acc = res['accuracy'] + res = info.get_metrics(dataset, 'x-test') ; test_acc = res['accuracy'] + res = info.get_metrics(dataset, 'ori-test') ; otest_acc = res['accuracy'] + if index == 11472: # resnet + resnet = {'params':param, 'flops': flop, 'index': 11472, 'train_acc': train_acc, 'valid_acc': valid_acc, 'test_acc': test_acc, 'otest_acc': otest_acc} + flops.append( flop ) + params.append( param ) + train_accs.append( train_acc ) + valid_accs.append( valid_acc ) + test_accs.append( test_acc ) + otest_accs.append( otest_acc ) + #resnet = {'params': 0.559, 'flops': 78.56, 'index': 11472, 'train_acc': 99.99, 'valid_acc': 90.84, 'test_acc': 93.97} + info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs, 'otest_accs': otest_accs} + info['resnet'] = resnet + torch.save(info, cache_file_path) + else: + print ('Find cache file : {:}'.format(cache_file_path)) + info = torch.load(cache_file_path) + params, flops, train_accs, valid_accs, test_accs, otest_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs'], info['otest_accs'] + resnet = info['resnet'] + print ('{:} collect data done.'.format(time_string())) + + indexes = list(range(len(params))) + dpi, width, height = 300, 2600, 2600 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 22, 22 + resnet_scale, resnet_alpha = 120, 0.5 + + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111) + plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize) + if dataset == 'cifar10': + plt.ylim(50, 100) + plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize) + elif dataset == 'cifar100': + plt.ylim(25, 75) + plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize) + else: + plt.ylim(0, 50) + plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize) + ax.scatter(params, valid_accs, marker='o', s=0.5, c='tab:blue') + ax.scatter([resnet['params']], [resnet['valid_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=0.4) + plt.grid(zorder=0) + ax.set_axisbelow(True) + plt.legend(loc=4, fontsize=LegendFontsize) + ax.set_xlabel('#parameters (MB)', fontsize=LabelSize) + ax.set_ylabel('the validation accuracy (%)', fontsize=LabelSize) + save_path = (vis_save_dir / '{:}-param-vs-valid.pdf'.format(dataset)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') + save_path = (vis_save_dir / '{:}-param-vs-valid.png'.format(dataset)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') + print ('{:} save into {:}'.format(time_string(), save_path)) + + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111) + plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize) + if dataset == 'cifar10': + plt.ylim(50, 100) + plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize) + elif dataset == 'cifar100': + plt.ylim(25, 75) + plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize) + else: + plt.ylim(0, 50) + plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize) + ax.scatter(params, test_accs, marker='o', s=0.5, c='tab:blue') + ax.scatter([resnet['params']], [resnet['test_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=resnet_alpha) + plt.grid() + ax.set_axisbelow(True) + plt.legend(loc=4, fontsize=LegendFontsize) + ax.set_xlabel('#parameters (MB)', fontsize=LabelSize) + ax.set_ylabel('the test accuracy (%)', fontsize=LabelSize) + save_path = (vis_save_dir / '{:}-param-vs-test.pdf'.format(dataset)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') + save_path = (vis_save_dir / '{:}-param-vs-test.png'.format(dataset)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') + print ('{:} save into {:}'.format(time_string(), save_path)) + + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111) + plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize) + if dataset == 'cifar10': + plt.ylim(50, 100) + plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize) + elif dataset == 'cifar100': + plt.ylim(20, 100) + plt.yticks(np.arange(20, 101, 10), fontsize=LegendFontsize) + else: + plt.ylim(25, 76) + plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize) + ax.scatter(params, train_accs, marker='o', s=0.5, c='tab:blue') + ax.scatter([resnet['params']], [resnet['train_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=resnet_alpha) + plt.grid() + ax.set_axisbelow(True) + plt.legend(loc=4, fontsize=LegendFontsize) + ax.set_xlabel('#parameters (MB)', fontsize=LabelSize) + ax.set_ylabel('the trarining accuracy (%)', fontsize=LabelSize) + save_path = (vis_save_dir / '{:}-param-vs-train.pdf'.format(dataset)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') + save_path = (vis_save_dir / '{:}-param-vs-train.png'.format(dataset)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') + print ('{:} save into {:}'.format(time_string(), save_path)) + + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111) + plt.xlim(0, max(indexes)) + plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//5), fontsize=LegendFontsize) + if dataset == 'cifar10': + plt.ylim(50, 100) + plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize) + elif dataset == 'cifar100': + plt.ylim(25, 75) + plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize) + else: + plt.ylim(0, 50) + plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize) + ax.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue') + ax.scatter([resnet['index']], [resnet['test_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=resnet_alpha) + plt.grid() + ax.set_axisbelow(True) + plt.legend(loc=4, fontsize=LegendFontsize) + ax.set_xlabel('architecture ID', fontsize=LabelSize) + ax.set_ylabel('the test accuracy (%)', fontsize=LabelSize) + save_path = (vis_save_dir / '{:}-test-over-ID.pdf'.format(dataset)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') + save_path = (vis_save_dir / '{:}-test-over-ID.png'.format(dataset)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') + print ('{:} save into {:}'.format(time_string(), save_path)) + plt.close('all') + + + +def visualize_rank_over_time(meta_file, vis_save_dir): + print ('\n' + '-'*150) + vis_save_dir.mkdir(parents=True, exist_ok=True) + print ('{:} start to visualize rank-over-time into {:}'.format(time_string(), vis_save_dir)) + cache_file_path = vis_save_dir / 'rank-over-time-cache-info.pth' + if not cache_file_path.exists(): + print ('Do not find cache file : {:}'.format(cache_file_path)) + nas_bench = API(str(meta_file)) + print ('{:} load nas_bench done'.format(time_string())) + params, flops, train_accs, valid_accs, test_accs, otest_accs = [], [], defaultdict(list), defaultdict(list), defaultdict(list), defaultdict(list) + #for iepoch in range(200): for index in range( len(nas_bench) ): + for index in tqdm(range(len(nas_bench))): + info = nas_bench.query_by_index(index, use_12epochs_result=False) + for iepoch in range(200): + res = info.get_metrics('cifar10' , 'train' , iepoch) ; train_acc = res['accuracy'] + res = info.get_metrics('cifar10-valid', 'x-valid' , iepoch) ; valid_acc = res['accuracy'] + res = info.get_metrics('cifar10' , 'ori-test', iepoch) ; test_acc = res['accuracy'] + res = info.get_metrics('cifar10' , 'ori-test', iepoch) ; otest_acc = res['accuracy'] + train_accs[iepoch].append( train_acc ) + valid_accs[iepoch].append( valid_acc ) + test_accs [iepoch].append( test_acc ) + otest_accs[iepoch].append( otest_acc ) + if iepoch == 0: + res = info.get_comput_costs('cifar10') ; flop, param = res['flops'], res['params'] + flops.append( flop ) + params.append( param ) + info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs, 'otest_accs': otest_accs} + torch.save(info, cache_file_path) + else: + print ('Find cache file : {:}'.format(cache_file_path)) + info = torch.load(cache_file_path) + params, flops, train_accs, valid_accs, test_accs, otest_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs'], info['otest_accs'] + print ('{:} collect data done.'.format(time_string())) + #selected_epochs = [0, 100, 150, 180, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199] + selected_epochs = list( range(200) ) + x_xtests = test_accs[199] + indexes = list(range(len(x_xtests))) + ord_idxs = sorted(indexes, key=lambda i: x_xtests[i]) + for sepoch in selected_epochs: + x_valids = valid_accs[sepoch] + valid_ord_idxs = sorted(indexes, key=lambda i: x_valids[i]) + valid_ord_lbls = [] + for idx in ord_idxs: + valid_ord_lbls.append( valid_ord_idxs.index(idx) ) + # labeled data + dpi, width, height = 300, 2600, 2600 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 18, 18 + + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111) + plt.xlim(min(indexes), max(indexes)) + plt.ylim(min(indexes), max(indexes)) + plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize, rotation='vertical') + plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize) + ax.scatter(indexes, valid_ord_lbls, marker='^', s=0.5, c='tab:green', alpha=0.8) + ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8) + ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-10 validation') + ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10 test') + plt.grid(zorder=0) + ax.set_axisbelow(True) + plt.legend(loc='upper left', fontsize=LegendFontsize) + ax.set_xlabel('architecture ranking in the final test accuracy', fontsize=LabelSize) + ax.set_ylabel('architecture ranking in the validation set', fontsize=LabelSize) + save_path = (vis_save_dir / 'time-{:03d}.pdf'.format(sepoch)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') + save_path = (vis_save_dir / 'time-{:03d}.png'.format(sepoch)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') + print ('{:} save into {:}'.format(time_string(), save_path)) + plt.close('all') + + + +def write_video(save_dir): + import cv2 + video_save_path = save_dir / 'time.avi' + print ('{:} start create video for {:}'.format(time_string(), video_save_path)) + images = sorted( list( save_dir.glob('time-*.png') ) ) + ximage = cv2.imread(str(images[0])) + #shape = (ximage.shape[1], ximage.shape[0]) + shape = (1000, 1000) + #writer = cv2.VideoWriter(str(video_save_path), cv2.VideoWriter_fourcc(*"MJPG"), 25, shape) + writer = cv2.VideoWriter(str(video_save_path), cv2.VideoWriter_fourcc(*"MJPG"), 5, shape) + for idx, image in enumerate(images): + ximage = cv2.imread(str(image)) + _image = cv2.resize(ximage, shape) + writer.write(_image) + writer.release() + print ('write video [{:} frames] into {:}'.format(len(images), video_save_path)) + + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser(description='NAS-Bench-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-102/visual', help='The base-name of folder to save checkpoints and log.') + parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-102 benchmark file.') + args = parser.parse_args() + + vis_save_dir = Path(args.save_dir) / 'visuals' + vis_save_dir.mkdir(parents=True, exist_ok=True) + meta_file = Path(args.api_path) + assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file) + visualize_rank_over_time(str(meta_file), vis_save_dir / 'over-time') + write_video(vis_save_dir / 'over-time') + visualize_info(str(meta_file), 'cifar10' , vis_save_dir) + visualize_info(str(meta_file), 'cifar100', vis_save_dir) + visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir) + visualize_relative_ranking(vis_save_dir) diff --git a/exps/algos/BOHB.py b/exps/algos/BOHB.py index eea14bc..c267f6c 100644 --- a/exps/algos/BOHB.py +++ b/exps/algos/BOHB.py @@ -53,43 +53,50 @@ def config2structure_func(max_nodes): class MyWorker(Worker): - def __init__(self, *args, convert_func=None, nas_bench=None, time_scale=None, **kwargs): + def __init__(self, *args, convert_func=None, nas_bench=None, time_budget=None, **kwargs): super().__init__(*args, **kwargs) self.convert_func = convert_func self.nas_bench = nas_bench - self.time_scale = time_scale - self.seen_arch = 0 + self.time_budget = time_budget + self.seen_archs = [] self.sim_cost_time = 0 self.real_cost_time = 0 + self.is_end = False + + def get_the_best(self): + assert len(self.seen_archs) > 0 + best_index, best_acc = -1, None + for arch_index in self.seen_archs: + info = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True) + vacc = info['valid-accuracy'] + if best_acc is None or best_acc < vacc: + best_acc = vacc + best_index = arch_index + assert best_index != -1 + return best_index def compute(self, config, budget, **kwargs): start_time = time.time() structure = self.convert_func( config ) arch_index = self.nas_bench.query_index_by_arch( structure ) - iepoch = 0 - while iepoch < 12: - info = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', iepoch, True) - cur_time = info['train-all-time'] + info['valid-per-time'] - cur_vacc = info['valid-accuracy'] - if time.time() - start_time + cur_time / self.time_scale > budget: - break - else: - iepoch += 1 - self.sim_cost_time += cur_time - self.seen_arch += 1 - remaining_time = cur_time / self.time_scale - (time.time() - start_time) - if remaining_time > 0: - time.sleep(remaining_time) - else: - import pdb; pdb.set_trace() + info = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True) + cur_time = info['train-all-time'] + info['valid-per-time'] + cur_vacc = info['valid-accuracy'] self.real_cost_time += (time.time() - start_time) - return ({ - 'loss': 100 - float(cur_vacc), - 'info': {'seen-arch' : self.seen_arch, - 'sim-test-time' : self.sim_cost_time, - 'real-test-time': self.real_cost_time, - 'current-arch' : arch_index, - 'current-budget': budget} + if self.sim_cost_time + cur_time <= self.time_budget and not self.is_end: + self.sim_cost_time += cur_time + self.seen_archs.append( arch_index ) + return ({'loss': 100 - float(cur_vacc), + 'info': {'seen-arch' : len(self.seen_archs), + 'sim-test-time' : self.sim_cost_time, + 'current-arch' : arch_index} + }) + else: + self.is_end = True + return ({'loss': 100, + 'info': {'seen-arch' : len(self.seen_archs), + 'sim-test-time' : self.sim_cost_time, + 'current-arch' : None} }) @@ -139,16 +146,14 @@ def main(xargs, nas_bench): #logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string())) workers = [] for i in range(num_workers): - w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, time_scale=xargs.time_scale, run_id=hb_run_id, id=i) + w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, time_budget=xargs.time_budget, run_id=hb_run_id, id=i) w.run(background=True) workers.append(w) - simulate_time_budge = xargs.time_budget // xargs.time_scale start_time = time.time() - logger.log('simulate_time_budge : {:} (in seconds).'.format(simulate_time_budge)) bohb = BOHB(configspace=cs, run_id=hb_run_id, - eta=3, min_budget=simulate_time_budge//3, max_budget=simulate_time_budge, + eta=3, min_budget=12, max_budget=200, nameserver=ns_host, nameserver_port=ns_port, num_samples=xargs.num_samples, @@ -161,11 +166,9 @@ def main(xargs, nas_bench): NS.shutdown() real_cost_time = time.time() - start_time - import pdb; pdb.set_trace() id2config = results.get_id2config_mapping() incumbent = results.get_incumbent_id() - logger.log('Best found configuration: {:}'.format(id2config[incumbent]['config'])) best_arch = config2structure( id2config[incumbent]['config'] ) @@ -174,7 +177,7 @@ def main(xargs, nas_bench): else : logger.log('{:}'.format(info)) logger.log('-'*100) - logger.log('workers : {:}'.format(workers[0].test_time)) + logger.log('workers : {:.1f}s with {:} archs'.format(workers[0].time_budget, len(workers[0].seen_archs))) logger.close() return logger.log_dir, nas_bench.query_index_by_arch( best_arch ) @@ -190,14 +193,13 @@ if __name__ == '__main__': parser.add_argument('--channel', type=int, help='The number of channels.') parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.') parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).') - parser.add_argument('--time_scale' , type=int, help='The time scale to accelerate the time budget.') # BOHB - parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function') - parser.add_argument('--min_bandwidth', default=.3, type=float, nargs='?', help='minimum bandwidth for KDE') - parser.add_argument('--num_samples', default=64, type=int, nargs='?', help='number of samples for the acquisition function') + parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function') + parser.add_argument('--min_bandwidth', default=.3, type=float, nargs='?', help='minimum bandwidth for KDE') + parser.add_argument('--num_samples', default=64, type=int, nargs='?', help='number of samples for the acquisition function') parser.add_argument('--random_fraction', default=.33, type=float, nargs='?', help='fraction of random configurations') - parser.add_argument('--bandwidth_factor', default=3, type=int, nargs='?', help='factor multiplied to the bandwidth') - parser.add_argument('--n_iters', default=100, type=int, nargs='?', help='number of iterations for optimization method') + parser.add_argument('--bandwidth_factor', default=3, type=int, nargs='?', help='factor multiplied to the bandwidth') + parser.add_argument('--n_iters', default=100, type=int, nargs='?', help='number of iterations for optimization method') # log parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)') parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.') diff --git a/exps/algos/RANDOM-NAS.py b/exps/algos/RANDOM-NAS.py index dd9653c..5699427 100644 --- a/exps/algos/RANDOM-NAS.py +++ b/exps/algos/RANDOM-NAS.py @@ -82,14 +82,29 @@ def valid_func(xloader, network, criterion): return arch_losses.avg, arch_top1.avg, arch_top5.avg -def search_find_best(valid_loader, network, criterion, select_num): - best_arch, best_acc = None, -1 - for iarch in range(select_num): - arch = network.module.random_genotype( True ) - valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion) - if best_arch is None or best_acc < valid_a_top1: - best_arch, best_acc = arch, valid_a_top1 - return best_arch +def search_find_best(xloader, network, n_samples): + with torch.no_grad(): + network.eval() + archs, valid_accs = [], [] + #print ('obtain the top-{:} architectures'.format(n_samples)) + loader_iter = iter(xloader) + for i in range(n_samples): + arch = network.module.random_genotype( True ) + try: + inputs, targets = next(loader_iter) + except: + loader_iter = iter(xloader) + inputs, targets = next(loader_iter) + + _, logits = network(inputs) + val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5)) + + archs.append( arch ) + valid_accs.append( val_top1.item() ) + + best_idx = np.argmax(valid_accs) + best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx] + return best_arch, best_valid_acc def main(xargs): @@ -127,7 +142,7 @@ def main(xargs): search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split) # data loader search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) - valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) + valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.test_batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) @@ -177,7 +192,8 @@ def main(xargs): logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum)) valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion) logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) - cur_arch = search_find_best(valid_loader, network, criterion, xargs.select_num) + cur_arch, cur_valid_acc = search_find_best(valid_loader, network, xargs.select_num) + logger.log('[{:}] find-the-best : {:}, accuracy@1={:.2f}%'.format(epoch_str, cur_arch, cur_valid_acc)) genotypes[epoch] = cur_arch # check the best accuracy valid_accuracies[epoch] = valid_a_top1 @@ -211,13 +227,7 @@ def main(xargs): logger.log('\n' + '-'*200) logger.log('Pre-searching costs {:.1f} s'.format(search_time.sum)) start_time = time.time() - best_arch, best_acc = None, -1 - for iarch in range(xargs.select_num): - arch = search_model.random_genotype( True ) - valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion) - logger.log('final evaluation [{:02d}/{:02d}] : {:} : accuracy={:.2f}%, loss={:.3f}'.format(iarch, xargs.select_num, arch, valid_a_top1, valid_a_loss)) - if best_arch is None or best_acc < valid_a_top1: - best_arch, best_acc = arch, valid_a_top1 + best_arch, best_acc = search_find_best(valid_loader, network, xargs.select_num) search_time.update(time.time() - start_time) logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum)) if api is not None: logger.log('{:}'.format( api.query_by_arch(best_arch) )) diff --git a/lib/models/shape_searchs/SearchCifarResNet.py b/lib/models/shape_searchs/SearchCifarResNet.py index 6271616..e944058 100644 --- a/lib/models/shape_searchs/SearchCifarResNet.py +++ b/lib/models/shape_searchs/SearchCifarResNet.py @@ -26,8 +26,6 @@ def get_depth_choices(nDepth, return_num): else : return choices - - def conv_forward(inputs, conv, choices): iC = conv.in_channels fill_size = list(inputs.size()) diff --git a/lib/nas_102_api/api.py b/lib/nas_102_api/api.py index f459380..9005a18 100644 --- a/lib/nas_102_api/api.py +++ b/lib/nas_102_api/api.py @@ -104,14 +104,19 @@ class NASBench102API(object): print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index)) return None - def query_by_index(self, arch_index, dataname, use_12epochs_result=False): + # query information with the training of 12 epochs or 200 epochs + # if dataname is None, return the ArchResults + # else, return a dict with all trials on that dataset (the key is the seed) + def query_by_index(self, arch_index, dataname=None, use_12epochs_result=False): if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less else : basestr, arch2infos = '200epochs', self.arch2infos_full assert arch_index in arch2infos, 'arch_index [{:}] does not in arch2info with {:}'.format(arch_index, basestr) archInfo = copy.deepcopy( arch2infos[ arch_index ] ) - assert dataname in archInfo.get_dataset_names(), 'invalid dataset-name : {:}'.format(dataname) - info = archInfo.query(dataname) - return info + if dataname is None: return archInfo + else: + assert dataname in archInfo.get_dataset_names(), 'invalid dataset-name : {:}'.format(dataname) + info = archInfo.query(dataname) + return info def query_meta_info_by_index(self, arch_index, use_12epochs_result=False): if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less @@ -266,7 +271,7 @@ class ArchResults(object): def query(self, dataset, seed=None): if seed is None: x_seeds = self.dataset_seed[dataset] - return [self.all_results[ (dataset, seed) ] for seed in x_seeds] + return {seed: self.all_results[ (dataset, seed) ] for seed in x_seeds} else: return self.all_results[ (dataset, seed) ] diff --git a/scripts-search/algos/BOHB.sh b/scripts-search/algos/BOHB.sh index 4d07f0a..9a0a15c 100644 --- a/scripts-search/algos/BOHB.sh +++ b/scripts-search/algos/BOHB.sh @@ -34,6 +34,6 @@ OMP_NUM_THREADS=4 python ./exps/algos/BOHB.py \ --dataset ${dataset} --data_path ${data_path} \ --search_space_name ${space} \ --arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ - --time_budget 12000 --time_scale 200 \ - --n_iters 64 --num_samples 4 --random_fraction 0 \ + --time_budget 12000 \ + --n_iters 28 --num_samples 64 --random_fraction .33 --bandwidth_factor 3 \ --workers 4 --print_freq 200 --rand_seed ${seed}