fix bugs in RANDOM-NAS and BOHB
@@ -51,7 +51,7 @@ res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric
 cost_metrics = info.get_comput_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency
 
 # get the detailed information
-results = api.query_by_index(1, 'cifar100') # a list of all trials on cifar100
+results = api.query_by_index(1, 'cifar100') # a dict of all trials for 1st net on cifar100, where the key is the seed
 print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1]))
 print ('Latency : {:}'.format(results[0].get_latency()))
 print ('Train Info : {:}'.format(results[0].get_train()))
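
Since query_by_index now returns a dict keyed by seed rather than a list, callers should iterate over seeds. A minimal sketch of the updated usage, assuming api is an already-loaded NASBench102API and the seeds are whatever trials exist for that architecture:

# hedged sketch: consume the per-seed dict returned for one dataset
results = api.query_by_index(1, 'cifar100')   # {seed: result of one trial}
for seed, result in results.items():
  print('seed={:} latency={:}'.format(seed, result.get_latency()))
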
@@ -9,5 +9,6 @@
   "momentum" : ["float", "0.9"],
   "nesterov" : ["bool",  "1"],
   "criterion": ["str",   "Softmax"],
-  "batch_size": ["int",  "64"]
+  "batch_size": ["int",  "64"],
+  "test_batch_size": ["int",  "512"]
 }
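
Each entry in this config file is a ["type", "value"] pair of strings, and the repository's config loader casts every value by its declared type before use. A hedged sketch of that convention (load_typed_config is an illustrative stand-in, not the repo's actual helper):

import json

_CASTS = {'int': int, 'float': float, 'str': str, 'bool': lambda v: bool(int(v))}

def load_typed_config(path):
  # each value is ['type', 'raw-string']; cast the string by its declared type
  with open(path) as f:
    raw = json.load(f)
  return {key: _CASTS[typ](val) for key, (typ, val) in raw.items()}
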
exps/NAS-Bench-102/visualize.py (new file, 386 lines)
@@ -0,0 +1,386 @@
+##################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
+##################################################
+# python exps/NAS-Bench-102/visualize.py --api_path $HOME/.torch/NAS-Bench-102-v1_0-e61699.pth
+##################################################
+import os, sys, time, argparse, collections
+from tqdm import tqdm
+import numpy as np
+import torch
+import torch.nn as nn
+from pathlib import Path
+from collections import defaultdict
+import matplotlib
+import seaborn as sns
+from mpl_toolkits.mplot3d import Axes3D
+matplotlib.use('agg')
+import matplotlib.pyplot as plt
+
+lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
+if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
+from log_utils    import time_string
+from nas_102_api  import NASBench102API as API
+
+
+
+def calculate_correlation(*vectors):
+  matrix = []
+  for i, vectori in enumerate(vectors):
+    x = []
+    for j, vectorj in enumerate(vectors):
+      x.append( np.corrcoef(vectori, vectorj)[0,1] )
+    matrix.append( x )
+  return np.array(matrix)
+
+
+
+def visualize_relative_ranking(vis_save_dir):
+  print ('\n' + '-'*100)
+  cifar010_cache_path = vis_save_dir / '{:}-cache-info.pth'.format('cifar10')
+  cifar100_cache_path = vis_save_dir / '{:}-cache-info.pth'.format('cifar100')
+  imagenet_cache_path = vis_save_dir / '{:}-cache-info.pth'.format('ImageNet16-120')
+  cifar010_info = torch.load(cifar010_cache_path)
+  cifar100_info = torch.load(cifar100_cache_path)
+  imagenet_info = torch.load(imagenet_cache_path)
+  indexes       = list(range(len(cifar010_info['params'])))
+
+  print ('{:} start to visualize relative ranking'.format(time_string()))
+  # maximum accuracy with ResNet-level params 11472
+  x_010_accs    = [ cifar010_info['test_accs'][i] if cifar010_info['params'][i] <= cifar010_info['params'][11472] else -1 for i in indexes]
+  x_100_accs    = [ cifar100_info['test_accs'][i] if cifar100_info['params'][i] <= cifar100_info['params'][11472] else -1 for i in indexes]
+  x_img_accs    = [ imagenet_info['test_accs'][i] if imagenet_info['params'][i] <= imagenet_info['params'][11472] else -1 for i in indexes]
+
+  cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info['test_accs'][i])
+  cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info['test_accs'][i])
+  imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info['test_accs'][i])
+
+  cifar100_labels, imagenet_labels = [], []
+  for idx in cifar010_ord_indexes:
+    cifar100_labels.append( cifar100_ord_indexes.index(idx) )
+    imagenet_labels.append( imagenet_ord_indexes.index(idx) )
+  print ('{:} prepare data done.'.format(time_string()))
+
+  dpi, width, height = 300, 2600, 2600
+  figsize = width / float(dpi), height / float(dpi)
+  LabelSize, LegendFontsize = 18, 18
+  resnet_scale, resnet_alpha = 120, 0.5
+
+  fig = plt.figure(figsize=figsize)
+  ax  = fig.add_subplot(111)
+  plt.xlim(min(indexes), max(indexes))
+  plt.ylim(min(indexes), max(indexes))
+  #plt.ylabel('y').set_rotation(0)
+  plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize, rotation='vertical')
+  plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize)
+  #ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8, label='CIFAR-100')
+  #ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red'  , alpha=0.8, label='ImageNet-16-120')
+  #ax.scatter(indexes, indexes        , marker='o', s=0.5, c='tab:blue' , alpha=0.8, label='CIFAR-10')
+  ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8)
+  ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red'  , alpha=0.8)
+  ax.scatter(indexes, indexes        , marker='o', s=0.5, c='tab:blue' , alpha=0.8)
+  ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10')
+  ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-100')
+  ax.scatter([-1], [-1], marker='*', s=100, c='tab:red'  , label='ImageNet-16-120')
+  plt.grid(zorder=0)
+  ax.set_axisbelow(True)
+  plt.legend(loc=0, fontsize=LegendFontsize)
+  ax.set_xlabel('architecture ranking in CIFAR-10', fontsize=LabelSize)
+  ax.set_ylabel('architecture ranking', fontsize=LabelSize)
+  save_path = (vis_save_dir / 'relative-rank.pdf').resolve()
+  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
+  save_path = (vis_save_dir / 'relative-rank.png').resolve()
+  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
+  print ('{:} save into {:}'.format(time_string(), save_path))
+
+  # calculate correlation
+  sns_size = 15
+  CoRelMatrix = calculate_correlation(cifar010_info['valid_accs'], cifar010_info['test_accs'], cifar100_info['valid_accs'], cifar100_info['test_accs'], imagenet_info['valid_accs'], imagenet_info['test_accs'])
+  fig = plt.figure(figsize=figsize)
+  plt.axis('off')
+  h = sns.heatmap(CoRelMatrix, annot=True, annot_kws={'size':sns_size}, fmt='.3f', linewidths=0.5)
+  save_path = (vis_save_dir / 'co-relation-all.pdf').resolve()
+  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
+  print ('{:} save into {:}'.format(time_string(), save_path))
+
+  # calculate correlation
+  acc_bars = [92, 93]
+  for acc_bar in acc_bars:
+    selected_indexes = []
+    for i, acc in enumerate(cifar010_info['test_accs']):
+      if acc > acc_bar: selected_indexes.append( i )
+    print ('select {:} architectures'.format(len(selected_indexes)))
+    cifar010_valid_accs = np.array(cifar010_info['valid_accs'])[ selected_indexes ]
+    cifar010_test_accs  = np.array(cifar010_info['test_accs']) [ selected_indexes ]
+    cifar100_valid_accs = np.array(cifar100_info['valid_accs'])[ selected_indexes ]
+    cifar100_test_accs  = np.array(cifar100_info['test_accs']) [ selected_indexes ]
+    imagenet_valid_accs = np.array(imagenet_info['valid_accs'])[ selected_indexes ]
+    imagenet_test_accs  = np.array(imagenet_info['test_accs']) [ selected_indexes ]
+    CoRelMatrix = calculate_correlation(cifar010_valid_accs, cifar010_test_accs, cifar100_valid_accs, cifar100_test_accs, imagenet_valid_accs, imagenet_test_accs)
+    fig = plt.figure(figsize=figsize)
+    plt.axis('off')
+    h = sns.heatmap(CoRelMatrix, annot=True, annot_kws={'size':sns_size}, fmt='.3f', linewidths=0.5)
+    save_path = (vis_save_dir / 'co-relation-top-{:}.pdf'.format(len(selected_indexes))).resolve()
+    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
+    print ('{:} save into {:}'.format(time_string(), save_path))
+  plt.close('all')
+
+
+
+def visualize_info(meta_file, dataset, vis_save_dir):
+  print ('{:} start to visualize {:} information'.format(time_string(), dataset))
+  cache_file_path = vis_save_dir / '{:}-cache-info.pth'.format(dataset)
+  if not cache_file_path.exists():
+    print ('Do not find cache file : {:}'.format(cache_file_path))
+    nas_bench = API(str(meta_file))
+    params, flops, train_accs, valid_accs, test_accs, otest_accs = [], [], [], [], [], []
+    for index in range( len(nas_bench) ):
+      info = nas_bench.query_by_index(index, use_12epochs_result=False)
+      resx = info.get_comput_costs(dataset) ; flop, param = resx['flops'], resx['params']
+      if dataset == 'cifar10':
+        res = info.get_metrics('cifar10', 'train')         ; train_acc = res['accuracy']
+        res = info.get_metrics('cifar10-valid', 'x-valid') ; valid_acc = res['accuracy']
+        res = info.get_metrics('cifar10', 'ori-test')      ; test_acc  = res['accuracy']
+        res = info.get_metrics('cifar10', 'ori-test')      ; otest_acc = res['accuracy']
+      else:
+        res = info.get_metrics(dataset, 'train')    ; train_acc = res['accuracy']
+        res = info.get_metrics(dataset, 'x-valid')  ; valid_acc = res['accuracy']
+        res = info.get_metrics(dataset, 'x-test')   ; test_acc  = res['accuracy']
+        res = info.get_metrics(dataset, 'ori-test') ; otest_acc = res['accuracy']
+      if index == 11472: # resnet
+        resnet = {'params':param, 'flops': flop, 'index': 11472, 'train_acc': train_acc, 'valid_acc': valid_acc, 'test_acc': test_acc, 'otest_acc': otest_acc}
+      flops.append( flop )
+      params.append( param )
+      train_accs.append( train_acc )
+      valid_accs.append( valid_acc )
+      test_accs.append( test_acc )
+      otest_accs.append( otest_acc )
+    #resnet = {'params': 0.559, 'flops': 78.56, 'index': 11472, 'train_acc': 99.99, 'valid_acc': 90.84, 'test_acc': 93.97}
+    info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs, 'otest_accs': otest_accs}
+    info['resnet'] = resnet
+    torch.save(info, cache_file_path)
+  else:
+    print ('Find cache file : {:}'.format(cache_file_path))
+    info = torch.load(cache_file_path)
+    params, flops, train_accs, valid_accs, test_accs, otest_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs'], info['otest_accs']
+    resnet = info['resnet']
+  print ('{:} collect data done.'.format(time_string()))
+
+  indexes = list(range(len(params)))
+  dpi, width, height = 300, 2600, 2600
+  figsize = width / float(dpi), height / float(dpi)
+  LabelSize, LegendFontsize = 22, 22
+  resnet_scale, resnet_alpha = 120, 0.5
+
+  fig = plt.figure(figsize=figsize)
+  ax  = fig.add_subplot(111)
+  plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
+  if dataset == 'cifar10':
+    plt.ylim(50, 100)
+    plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
+  elif dataset == 'cifar100':
+    plt.ylim(25,  75)
+    plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
+  else:
+    plt.ylim(0, 50)
+    plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
+  ax.scatter(params, valid_accs, marker='o', s=0.5, c='tab:blue')
+  ax.scatter([resnet['params']], [resnet['valid_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=0.4)
+  plt.grid(zorder=0)
+  ax.set_axisbelow(True)
+  plt.legend(loc=4, fontsize=LegendFontsize)
+  ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
+  ax.set_ylabel('the validation accuracy (%)', fontsize=LabelSize)
+  save_path = (vis_save_dir / '{:}-param-vs-valid.pdf'.format(dataset)).resolve()
+  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
+  save_path = (vis_save_dir / '{:}-param-vs-valid.png'.format(dataset)).resolve()
+  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
+  print ('{:} save into {:}'.format(time_string(), save_path))
+
+  fig = plt.figure(figsize=figsize)
+  ax  = fig.add_subplot(111)
+  plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
+  if dataset == 'cifar10':
+    plt.ylim(50, 100)
+    plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
+  elif dataset == 'cifar100':
+    plt.ylim(25,  75)
+    plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
+  else:
+    plt.ylim(0, 50)
+    plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
+  ax.scatter(params,  test_accs, marker='o', s=0.5, c='tab:blue')
+  ax.scatter([resnet['params']], [resnet['test_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=resnet_alpha)
+  plt.grid()
+  ax.set_axisbelow(True)
+  plt.legend(loc=4, fontsize=LegendFontsize)
+  ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
+  ax.set_ylabel('the test accuracy (%)', fontsize=LabelSize)
+  save_path = (vis_save_dir / '{:}-param-vs-test.pdf'.format(dataset)).resolve()
+  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
+  save_path = (vis_save_dir / '{:}-param-vs-test.png'.format(dataset)).resolve()
+  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
+  print ('{:} save into {:}'.format(time_string(), save_path))
+
+  fig = plt.figure(figsize=figsize)
+  ax  = fig.add_subplot(111)
+  plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
+  if dataset == 'cifar10':
+    plt.ylim(50, 100)
+    plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
+  elif dataset == 'cifar100':
+    plt.ylim(20, 100)
+    plt.yticks(np.arange(20, 101, 10), fontsize=LegendFontsize)
+  else:
+    plt.ylim(25,  76)
+    plt.yticks(np.arange(25,  76, 10), fontsize=LegendFontsize)
+  ax.scatter(params, train_accs, marker='o', s=0.5, c='tab:blue')
+  ax.scatter([resnet['params']], [resnet['train_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=resnet_alpha)
+  plt.grid()
+  ax.set_axisbelow(True)
+  plt.legend(loc=4, fontsize=LegendFontsize)
+  ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
+  ax.set_ylabel('the training accuracy (%)', fontsize=LabelSize)
+  save_path = (vis_save_dir / '{:}-param-vs-train.pdf'.format(dataset)).resolve()
+  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
+  save_path = (vis_save_dir / '{:}-param-vs-train.png'.format(dataset)).resolve()
+  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
+  print ('{:} save into {:}'.format(time_string(), save_path))
+
+  fig = plt.figure(figsize=figsize)
+  ax  = fig.add_subplot(111)
+  plt.xlim(0, max(indexes))
+  plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//5), fontsize=LegendFontsize)
+  if dataset == 'cifar10':
+    plt.ylim(50, 100)
+    plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
+  elif dataset == 'cifar100':
+    plt.ylim(25,  75)
+    plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
+  else:
+    plt.ylim(0, 50)
+    plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
+  ax.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue')
+  ax.scatter([resnet['index']], [resnet['test_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=resnet_alpha)
+  plt.grid()
+  ax.set_axisbelow(True)
+  plt.legend(loc=4, fontsize=LegendFontsize)
+  ax.set_xlabel('architecture ID', fontsize=LabelSize)
+  ax.set_ylabel('the test accuracy (%)', fontsize=LabelSize)
+  save_path = (vis_save_dir / '{:}-test-over-ID.pdf'.format(dataset)).resolve()
+  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
+  save_path = (vis_save_dir / '{:}-test-over-ID.png'.format(dataset)).resolve()
+  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
+  print ('{:} save into {:}'.format(time_string(), save_path))
+  plt.close('all')
+
+
+
+def visualize_rank_over_time(meta_file, vis_save_dir):
+  print ('\n' + '-'*150)
+  vis_save_dir.mkdir(parents=True, exist_ok=True)
+  print ('{:} start to visualize rank-over-time into {:}'.format(time_string(), vis_save_dir))
+  cache_file_path = vis_save_dir / 'rank-over-time-cache-info.pth'
+  if not cache_file_path.exists():
+    print ('Do not find cache file : {:}'.format(cache_file_path))
+    nas_bench = API(str(meta_file))
+    print ('{:} load nas_bench done'.format(time_string()))
+    params, flops, train_accs, valid_accs, test_accs, otest_accs = [], [], defaultdict(list), defaultdict(list), defaultdict(list), defaultdict(list)
+    #for iepoch in range(200): for index in range( len(nas_bench) ):
+    for index in tqdm(range(len(nas_bench))):
+      info = nas_bench.query_by_index(index, use_12epochs_result=False)
+      for iepoch in range(200):
+        res = info.get_metrics('cifar10'      , 'train'   , iepoch) ; train_acc = res['accuracy']
+        res = info.get_metrics('cifar10-valid', 'x-valid' , iepoch) ; valid_acc = res['accuracy']
+        res = info.get_metrics('cifar10'      , 'ori-test', iepoch) ; test_acc  = res['accuracy']
+        res = info.get_metrics('cifar10'      , 'ori-test', iepoch) ; otest_acc = res['accuracy']
+        train_accs[iepoch].append( train_acc )
+        valid_accs[iepoch].append( valid_acc )
+        test_accs [iepoch].append( test_acc )
+        otest_accs[iepoch].append( otest_acc )
+        if iepoch == 0:
+          res = info.get_comput_costs('cifar10') ; flop, param = res['flops'], res['params']
+          flops.append( flop )
+          params.append( param )
+    info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs, 'otest_accs': otest_accs}
+    torch.save(info, cache_file_path)
+  else:
+    print ('Find cache file : {:}'.format(cache_file_path))
+    info = torch.load(cache_file_path)
+    params, flops, train_accs, valid_accs, test_accs, otest_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs'], info['otest_accs']
+  print ('{:} collect data done.'.format(time_string()))
+  #selected_epochs = [0, 100, 150, 180, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199]
+  selected_epochs = list( range(200) )
+  x_xtests = test_accs[199]
+  indexes  = list(range(len(x_xtests)))
+  ord_idxs = sorted(indexes, key=lambda i: x_xtests[i])
+  for sepoch in selected_epochs:
+    x_valids = valid_accs[sepoch]
+    valid_ord_idxs = sorted(indexes, key=lambda i: x_valids[i])
+    valid_ord_lbls = []
+    for idx in ord_idxs:
+      valid_ord_lbls.append( valid_ord_idxs.index(idx) )
+    # labeled data
+    dpi, width, height = 300, 2600, 2600
+    figsize = width / float(dpi), height / float(dpi)
+    LabelSize, LegendFontsize = 18, 18
+
+    fig = plt.figure(figsize=figsize)
+    ax  = fig.add_subplot(111)
+    plt.xlim(min(indexes), max(indexes))
+    plt.ylim(min(indexes), max(indexes))
+    plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize, rotation='vertical')
+    plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize)
+    ax.scatter(indexes, valid_ord_lbls, marker='^', s=0.5, c='tab:green', alpha=0.8)
+    ax.scatter(indexes, indexes       , marker='o', s=0.5, c='tab:blue' , alpha=0.8)
+    ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-10 validation')
+    ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10 test')
+    plt.grid(zorder=0)
+    ax.set_axisbelow(True)
+    plt.legend(loc='upper left', fontsize=LegendFontsize)
+    ax.set_xlabel('architecture ranking in the final test accuracy', fontsize=LabelSize)
+    ax.set_ylabel('architecture ranking in the validation set', fontsize=LabelSize)
+    save_path = (vis_save_dir / 'time-{:03d}.pdf'.format(sepoch)).resolve()
+    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
+    save_path = (vis_save_dir / 'time-{:03d}.png'.format(sepoch)).resolve()
+    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
+    print ('{:} save into {:}'.format(time_string(), save_path))
+    plt.close('all')
+
+
+
+def write_video(save_dir):
+  import cv2
+  video_save_path = save_dir / 'time.avi'
+  print ('{:} start create video for {:}'.format(time_string(), video_save_path))
+  images = sorted( list( save_dir.glob('time-*.png') ) )
+  ximage = cv2.imread(str(images[0]))
+  #shape  = (ximage.shape[1], ximage.shape[0])
+  shape  = (1000, 1000)
+  #writer = cv2.VideoWriter(str(video_save_path), cv2.VideoWriter_fourcc(*"MJPG"), 25, shape)
+  writer = cv2.VideoWriter(str(video_save_path), cv2.VideoWriter_fourcc(*"MJPG"), 5, shape)
+  for idx, image in enumerate(images):
+    ximage = cv2.imread(str(image))
+    _image = cv2.resize(ximage, shape)
+    writer.write(_image)
+  writer.release()
+  print ('write video [{:} frames] into {:}'.format(len(images), video_save_path))
+
+
+
+if __name__ == '__main__':
+
+  parser = argparse.ArgumentParser(description='NAS-Bench-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+  parser.add_argument('--save_dir',  type=str, default='./output/search-cell-nas-bench-102/visual', help='The base-name of folder to save checkpoints and log.')
+  parser.add_argument('--api_path',  type=str, default=None,                                        help='The path to the NAS-Bench-102 benchmark file.')
+  args = parser.parse_args()
+
+  vis_save_dir = Path(args.save_dir) / 'visuals'
+  vis_save_dir.mkdir(parents=True, exist_ok=True)
+  meta_file = Path(args.api_path)
+  assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
+  visualize_rank_over_time(str(meta_file), vis_save_dir / 'over-time')
+  write_video(vis_save_dir / 'over-time')
+  visualize_info(str(meta_file), 'cifar10' , vis_save_dir)
+  visualize_info(str(meta_file), 'cifar100', vis_save_dir)
+  visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir)
+  visualize_relative_ranking(vis_save_dir)
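
For reference, calculate_correlation above fills an N x N matrix of pairwise Pearson coefficients via np.corrcoef, so the result is symmetric with ones on the diagonal. A toy check under that assumption (random vectors, run with the function in scope):

import numpy as np
a = np.random.rand(100)
b = a + 0.1 * np.random.rand(100)   # strongly correlated with a
c = np.random.rand(100)             # roughly independent of both
m = calculate_correlation(a, b, c)
assert m.shape == (3, 3) and np.allclose(np.diag(m), 1.0)
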
@@ -53,43 +53,50 @@ def config2structure_func(max_nodes):
 
 class MyWorker(Worker):
 
-  def __init__(self, *args, convert_func=None, nas_bench=None, time_scale=None, **kwargs):
+  def __init__(self, *args, convert_func=None, nas_bench=None, time_budget=None, **kwargs):
     super().__init__(*args, **kwargs)
     self.convert_func   = convert_func
     self.nas_bench      = nas_bench
-    self.time_scale     = time_scale
-    self.seen_arch      = 0
+    self.time_budget    = time_budget
+    self.seen_archs     = []
     self.sim_cost_time  = 0
     self.real_cost_time = 0
+    self.is_end         = False
+
+  def get_the_best(self):
+    assert len(self.seen_archs) > 0
+    best_index, best_acc = -1, None
+    for arch_index in self.seen_archs:
+      info = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True)
+      vacc = info['valid-accuracy']
+      if best_acc is None or best_acc < vacc:
+        best_acc = vacc
+        best_index = arch_index
+    assert best_index != -1
+    return best_index
 
   def compute(self, config, budget, **kwargs):
     start_time = time.time()
     structure  = self.convert_func( config )
     arch_index = self.nas_bench.query_index_by_arch( structure )
-    iepoch     = 0
-    while iepoch < 12:
-      info     = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', iepoch, True)
-      cur_time = info['train-all-time'] + info['valid-per-time']
-      cur_vacc = info['valid-accuracy']
-      if time.time() - start_time + cur_time / self.time_scale > budget:
-        break
-      else:
-        iepoch += 1
-    self.sim_cost_time += cur_time
-    self.seen_arch += 1
-    remaining_time = cur_time / self.time_scale - (time.time() - start_time)
-    if remaining_time > 0:
-      time.sleep(remaining_time)
-    else:
-      import pdb; pdb.set_trace()
+    info       = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True)
+    cur_time   = info['train-all-time'] + info['valid-per-time']
+    cur_vacc   = info['valid-accuracy']
     self.real_cost_time += (time.time() - start_time)
-    return ({
-            'loss': 100 - float(cur_vacc),
-            'info': {'seen-arch'     : self.seen_arch,
-                     'sim-test-time' : self.sim_cost_time,
-                     'real-test-time': self.real_cost_time,
-                     'current-arch'  : arch_index,
-                     'current-budget': budget}
+    if self.sim_cost_time + cur_time <= self.time_budget and not self.is_end:
+      self.sim_cost_time += cur_time
+      self.seen_archs.append( arch_index )
+      return ({'loss': 100 - float(cur_vacc),
+               'info': {'seen-arch'     : len(self.seen_archs),
+                        'sim-test-time' : self.sim_cost_time,
+                        'current-arch'  : arch_index}
+            })
+    else:
+      self.is_end = True
+      return ({'loss': 100,
+               'info': {'seen-arch'     : len(self.seen_archs),
+                        'sim-test-time' : self.sim_cost_time,
+                        'current-arch'  : None}
             })
 
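The rewritten compute no longer sleeps to emulate training time: each benchmark lookup charges its simulated cost (train-all-time plus valid-per-time) against time_budget, and once the budget would be exceeded the worker sets is_end and returns a dummy loss of 100 so BOHB stops learning from it. A toy trace of that bookkeeping (all numbers invented):

# hedged illustration of the simulated-budget accounting in MyWorker.compute
sim_cost_time, time_budget, is_end = 0.0, 12000.0, False
for cur_time in [3000.0, 4500.0, 5000.0]:   # simulated seconds per queried arch
  if sim_cost_time + cur_time <= time_budget and not is_end:
    sim_cost_time += cur_time               # within budget: count this arch as seen
  else:
    is_end = True                           # budget spent: reject all later archs
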
@@ -139,16 +146,14 @@ def main(xargs, nas_bench):
   #logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string()))
   workers = []
   for i in range(num_workers):
-    w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, time_scale=xargs.time_scale, run_id=hb_run_id, id=i)
+    w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, time_budget=xargs.time_budget, run_id=hb_run_id, id=i)
     w.run(background=True)
     workers.append(w)
 
-  simulate_time_budge = xargs.time_budget // xargs.time_scale
   start_time = time.time()
-  logger.log('simulate_time_budge : {:} (in seconds).'.format(simulate_time_budge))
   bohb = BOHB(configspace=cs,
             run_id=hb_run_id,
-            eta=3, min_budget=simulate_time_budge//3, max_budget=simulate_time_budge,
+            eta=3, min_budget=12, max_budget=200,
             nameserver=ns_host,
             nameserver_port=ns_port,
             num_samples=xargs.num_samples,
@@ -161,11 +166,9 @@ def main(xargs, nas_bench):
   NS.shutdown()
 
   real_cost_time = time.time() - start_time
-  import pdb; pdb.set_trace()
 
   id2config = results.get_id2config_mapping()
   incumbent = results.get_incumbent_id()
-
   logger.log('Best found configuration: {:}'.format(id2config[incumbent]['config']))
   best_arch = config2structure( id2config[incumbent]['config'] )
 
@@ -174,7 +177,7 @@ def main(xargs, nas_bench):
   else           : logger.log('{:}'.format(info))
   logger.log('-'*100)
 
-  logger.log('workers : {:}'.format(workers[0].test_time))
+  logger.log('workers : {:.1f}s with {:} archs'.format(workers[0].time_budget, len(workers[0].seen_archs)))
   logger.close()
   return logger.log_dir, nas_bench.query_index_by_arch( best_arch )
 
@@ -190,14 +193,13 @@ if __name__ == '__main__':
   parser.add_argument('--channel',            type=int,   help='The number of channels.')
   parser.add_argument('--num_cells',          type=int,   help='The number of cells in one stage.')
   parser.add_argument('--time_budget',        type=int,   help='The total time cost budget for searching (in seconds).')
-  parser.add_argument('--time_scale' ,        type=int,   help='The time scale to accelerate the time budget.')
   # BOHB
-  parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function')
-  parser.add_argument('--min_bandwidth',    default=.3, type=float, nargs='?', help='minimum bandwidth for KDE')
-  parser.add_argument('--num_samples',      default=64, type=int, nargs='?', help='number of samples for the acquisition function')
+  parser.add_argument('--strategy', default="sampling",  type=str, nargs='?', help='optimization strategy for the acquisition function')
+  parser.add_argument('--min_bandwidth',    default=.3,  type=float, nargs='?', help='minimum bandwidth for KDE')
+  parser.add_argument('--num_samples',      default=64,  type=int, nargs='?', help='number of samples for the acquisition function')
   parser.add_argument('--random_fraction',  default=.33, type=float, nargs='?', help='fraction of random configurations')
-  parser.add_argument('--bandwidth_factor', default=3, type=int, nargs='?', help='factor multiplied to the bandwidth')
-  parser.add_argument('--n_iters', default=100, type=int, nargs='?', help='number of iterations for optimization method')
+  parser.add_argument('--bandwidth_factor', default=3,   type=int, nargs='?', help='factor multiplied to the bandwidth')
+  parser.add_argument('--n_iters',          default=100, type=int, nargs='?', help='number of iterations for optimization method')
   # log
   parser.add_argument('--workers',            type=int,   default=2,    help='number of data loading workers (default: 2)')
   parser.add_argument('--save_dir',           type=str,   help='Folder to save checkpoints and log.')
@@ -82,14 +82,29 @@ def valid_func(xloader, network, criterion):
   return arch_losses.avg, arch_top1.avg, arch_top5.avg
 
 
-def search_find_best(valid_loader, network, criterion, select_num):
-  best_arch, best_acc = None, -1
-  for iarch in range(select_num):
-    arch = network.module.random_genotype( True )
-    valid_a_loss, valid_a_top1, valid_a_top5  = valid_func(valid_loader, network, criterion)
-    if best_arch is None or best_acc < valid_a_top1:
-      best_arch, best_acc = arch, valid_a_top1
-  return best_arch
+def search_find_best(xloader, network, n_samples):
+  with torch.no_grad():
+    network.eval()
+    archs, valid_accs = [], []
+    #print ('obtain the top-{:} architectures'.format(n_samples))
+    loader_iter = iter(xloader)
+    for i in range(n_samples):
+      arch = network.module.random_genotype( True )
+      try:
+        inputs, targets = next(loader_iter)
+      except:
+        loader_iter = iter(xloader)
+        inputs, targets = next(loader_iter)
+
+      _, logits = network(inputs)
+      val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5))
+
+      archs.append( arch )
+      valid_accs.append( val_top1.item() )
+
+    best_idx = np.argmax(valid_accs)
+    best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
+    return best_arch, best_valid_acc
 
 
 def main(xargs):
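
search_find_best now scores each randomly sampled genotype on a single minibatch from xloader (restarting the iterator when it runs out) instead of a full valid_func pass, so ranking n_samples candidates is much cheaper at the price of a noisier estimate. A hedged call-site sketch (the loader and network names are assumptions):

# pick the best of 100 random genotypes by one-batch top-1 accuracy
best_arch, best_valid_acc = search_find_best(valid_loader, network, 100)
print('best random genotype {:} : {:.2f}%'.format(best_arch, best_valid_acc))
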
@@ -127,7 +142,7 @@ def main(xargs):
   search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
   # data loader
   search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
-  valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
+  valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.test_batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
 
@@ -177,7 +192,8 @@ def main(xargs):
     logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
     valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion)
     logger.log('[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
-    cur_arch = search_find_best(valid_loader, network, criterion, xargs.select_num)
+    cur_arch, cur_valid_acc = search_find_best(valid_loader, network, xargs.select_num)
+    logger.log('[{:}] find-the-best : {:}, accuracy@1={:.2f}%'.format(epoch_str, cur_arch, cur_valid_acc))
     genotypes[epoch] = cur_arch
     # check the best accuracy
     valid_accuracies[epoch] = valid_a_top1
@@ -211,13 +227,7 @@ def main(xargs):
   logger.log('\n' + '-'*200)
   logger.log('Pre-searching costs {:.1f} s'.format(search_time.sum))
   start_time = time.time()
-  best_arch, best_acc = None, -1
-  for iarch in range(xargs.select_num):
-    arch = search_model.random_genotype( True )
-    valid_a_loss, valid_a_top1, valid_a_top5  = valid_func(valid_loader, network, criterion)
-    logger.log('final evaluation [{:02d}/{:02d}] : {:} : accuracy={:.2f}%, loss={:.3f}'.format(iarch, xargs.select_num, arch, valid_a_top1, valid_a_loss))
-    if best_arch is None or best_acc < valid_a_top1:
-      best_arch, best_acc = arch, valid_a_top1
+  best_arch, best_acc = search_find_best(valid_loader, network, xargs.select_num)
   search_time.update(time.time() - start_time)
   logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum))
   if api is not None: logger.log('{:}'.format( api.query_by_arch(best_arch) ))
@@ -26,8 +26,6 @@ def get_depth_choices(nDepth, return_num):
   else         : return choices
 
 
-
-
 def conv_forward(inputs, conv, choices):
   iC = conv.in_channels
   fill_size = list(inputs.size())
@@ -104,14 +104,19 @@ class NASBench102API(object):
       print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index))
       return None
 
-  def query_by_index(self, arch_index, dataname, use_12epochs_result=False):
+  # query information with the training of 12 epochs or 200 epochs
+  # if dataname is None, return the ArchResults
+  # else, return a dict with all trials on that dataset (the key is the seed)
+  def query_by_index(self, arch_index, dataname=None, use_12epochs_result=False):
     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full
     assert arch_index in arch2infos, 'arch_index [{:}] does not in arch2info with {:}'.format(arch_index, basestr)
     archInfo = copy.deepcopy( arch2infos[ arch_index ] )
-    assert dataname in archInfo.get_dataset_names(), 'invalid dataset-name : {:}'.format(dataname)
-    info = archInfo.query(dataname)
-    return info
+    if dataname is None: return archInfo
+    else:
+      assert dataname in archInfo.get_dataset_names(), 'invalid dataset-name : {:}'.format(dataname)
+      info = archInfo.query(dataname)
+      return info
 
   def query_meta_info_by_index(self, arch_index, use_12epochs_result=False):
     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
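
The updated query_by_index therefore has two calling modes; a hedged sketch, assuming api is a loaded NASBench102API (architecture index 1 is arbitrary):

arch_info = api.query_by_index(1)              # dataname=None -> the full ArchResults
print(arch_info.get_dataset_names())           # datasets this architecture was trained on
trials = api.query_by_index(1, 'cifar100')     # dataname given -> {seed: single-trial result}
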
@@ -266,7 +271,7 @@ class ArchResults(object):
   def query(self, dataset, seed=None):
     if seed is None:
       x_seeds = self.dataset_seed[dataset]
-      return [self.all_results[ (dataset, seed) ] for seed in x_seeds]
+      return {seed: self.all_results[ (dataset, seed) ] for seed in x_seeds}
     else:
       return self.all_results[ (dataset, seed) ]
 
@@ -34,6 +34,6 @@ OMP_NUM_THREADS=4 python ./exps/algos/BOHB.py \
 	--dataset ${dataset} --data_path ${data_path} \
 	--search_space_name ${space} \
 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
-	--time_budget 12000 --time_scale 200 \
-	--n_iters 64 --num_samples 4 --random_fraction 0 \
+	--time_budget 12000  \
+	--n_iters 28 --num_samples 64 --random_fraction .33 --bandwidth_factor 3 \
 	--workers 4 --print_freq 200 --rand_seed ${seed}