From a45808b8e6b8b7ded266c3818264bfc43446dfb1 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Wed, 1 Jul 2020 12:29:46 +0000 Subject: [PATCH] Update the test codes for NAS-Bench-API --- docs/NAS-Bench-201.md | 2 + exps/NAS-Bench-201/test-nas-api-vis.py | 254 ++++++++++++++++++++++++- exps/NAS-Bench-201/test-nas-api.py | 213 ++------------------- lib/nas_201_api/api_201.py | 6 +- lib/nas_201_api/api_utils.py | 24 ++- 5 files changed, 287 insertions(+), 212 deletions(-) diff --git a/docs/NAS-Bench-201.md b/docs/NAS-Bench-201.md index 2520e03..13800ba 100644 --- a/docs/NAS-Bench-201.md +++ b/docs/NAS-Bench-201.md @@ -40,6 +40,8 @@ It is recommended to put these data into `$TORCH_HOME` (`~/.torch/` by default). ## How to Use NAS-Bench-201 +**More usage can be found in [our test codes](https://github.com/D-X-Y/AutoDL-Projects/blob/master/exps/NAS-Bench-201/test-nas-api.py)**. + 1. Creating an API instance from a file: ``` from nas_201_api import NASBench201API as API diff --git a/exps/NAS-Bench-201/test-nas-api-vis.py b/exps/NAS-Bench-201/test-nas-api-vis.py index 34bd18b..f08ec5f 100644 --- a/exps/NAS-Bench-201/test-nas-api-vis.py +++ b/exps/NAS-Bench-201/test-nas-api-vis.py @@ -81,6 +81,244 @@ def visualize_info(api, vis_save_dir, indicator): print ('{:} save into {:}'.format(time_string(), save_path)) +def visualize_sss_info(api, dataset, vis_save_dir): + vis_save_dir = vis_save_dir.resolve() + print ('{:} start to visualize {:} information'.format(time_string(), dataset)) + vis_save_dir.mkdir(parents=True, exist_ok=True) + cache_file_path = vis_save_dir / '{:}-cache-sss-info.pth'.format(dataset) + if not cache_file_path.exists(): + print ('Do not find cache file : {:}'.format(cache_file_path)) + params, flops, train_accs, valid_accs, test_accs = [], [], [], [], [] + for index in range(len(api)): + info = api.get_cost_info(index, dataset) + params.append(info['params']) + flops.append(info['flops']) + # accuracy + info = api.get_more_info(index, dataset, hp='90') + train_accs.append(info['train-accuracy']) + test_accs.append(info['test-accuracy']) + if dataset == 'cifar10': + info = api.get_more_info(index, 'cifar10-valid', hp='90') + valid_accs.append(info['valid-accuracy']) + else: + valid_accs.append(info['valid-accuracy']) + info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs} + torch.save(info, cache_file_path) + else: + print ('Find cache file : {:}'.format(cache_file_path)) + info = torch.load(cache_file_path) + params, flops, train_accs, valid_accs, test_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs'] + print ('{:} collect data done.'.format(time_string())) + + pyramid = ['8:16:32:48:64', '8:8:16:32:48', '8:8:16:16:32', '8:8:16:16:48', '8:8:16:16:64', '16:16:32:32:64', '32:32:64:64:64'] + pyramid_indexes = [api.query_index_by_arch(x) for x in pyramid] + largest_indexes = [api.query_index_by_arch('64:64:64:64:64')] + + indexes = list(range(len(params))) + dpi, width, height = 250, 8500, 1300 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 24, 24 + # resnet_scale, resnet_alpha = 120, 0.5 + xscale, xalpha = 120, 0.8 + + fig, axs = plt.subplots(1, 4, figsize=figsize) + # ax1, ax2, ax3, ax4, ax5 = axs + for ax in axs: + for tick in ax.xaxis.get_major_ticks(): + tick.label.set_fontsize(LabelSize) + ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.0f')) + for tick in ax.yaxis.get_major_ticks(): + tick.label.set_fontsize(LabelSize) + ax2, ax3, ax4, ax5 = axs + # ax1.xaxis.set_ticks(np.arange(0, max(indexes), max(indexes)//5)) + # ax1.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue') + # ax1.set_xlabel('architecture ID', fontsize=LabelSize) + # ax1.set_ylabel('test accuracy (%)', fontsize=LabelSize) + + ax2.scatter(params, train_accs, marker='o', s=0.5, c='tab:blue') + ax2.scatter([params[x] for x in pyramid_indexes], [train_accs[x] for x in pyramid_indexes], marker='*', s=xscale, c='tab:orange', label='Pyramid Structure', alpha=xalpha) + ax2.scatter([params[x] for x in largest_indexes], [train_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) + ax2.set_xlabel('#parameters (MB)', fontsize=LabelSize) + ax2.set_ylabel('train accuracy (%)', fontsize=LabelSize) + ax2.legend(loc=4, fontsize=LegendFontsize) + + ax3.scatter(params, test_accs, marker='o', s=0.5, c='tab:blue') + ax3.scatter([params[x] for x in pyramid_indexes], [test_accs[x] for x in pyramid_indexes], marker='*', s=xscale, c='tab:orange', label='Pyramid Structure', alpha=xalpha) + ax3.scatter([params[x] for x in largest_indexes], [test_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) + ax3.set_xlabel('#parameters (MB)', fontsize=LabelSize) + ax3.set_ylabel('test accuracy (%)', fontsize=LabelSize) + ax3.legend(loc=4, fontsize=LegendFontsize) + + ax4.scatter(flops, train_accs, marker='o', s=0.5, c='tab:blue') + ax4.scatter([flops[x] for x in pyramid_indexes], [train_accs[x] for x in pyramid_indexes], marker='*', s=xscale, c='tab:orange', label='Pyramid Structure', alpha=xalpha) + ax4.scatter([flops[x] for x in largest_indexes], [train_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) + ax4.set_xlabel('#FLOPs (M)', fontsize=LabelSize) + ax4.set_ylabel('train accuracy (%)', fontsize=LabelSize) + ax4.legend(loc=4, fontsize=LegendFontsize) + + ax5.scatter(flops, test_accs, marker='o', s=0.5, c='tab:blue') + ax5.scatter([flops[x] for x in pyramid_indexes], [test_accs[x] for x in pyramid_indexes], marker='*', s=xscale, c='tab:orange', label='Pyramid Structure', alpha=xalpha) + ax5.scatter([flops[x] for x in largest_indexes], [test_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) + ax5.set_xlabel('#FLOPs (M)', fontsize=LabelSize) + ax5.set_ylabel('test accuracy (%)', fontsize=LabelSize) + ax5.legend(loc=4, fontsize=LegendFontsize) + + save_path = vis_save_dir / 'sss-{:}.png'.format(dataset) + fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') + print ('{:} save into {:}'.format(time_string(), save_path)) + plt.close('all') + + +def visualize_tss_info(api, dataset, vis_save_dir): + vis_save_dir = vis_save_dir.resolve() + print ('{:} start to visualize {:} information'.format(time_string(), dataset)) + vis_save_dir.mkdir(parents=True, exist_ok=True) + cache_file_path = vis_save_dir / '{:}-cache-tss-info.pth'.format(dataset) + if not cache_file_path.exists(): + print ('Do not find cache file : {:}'.format(cache_file_path)) + params, flops, train_accs, valid_accs, test_accs = [], [], [], [], [] + for index in range(len(api)): + info = api.get_cost_info(index, dataset) + params.append(info['params']) + flops.append(info['flops']) + # accuracy + info = api.get_more_info(index, dataset, hp='200') + train_accs.append(info['train-accuracy']) + test_accs.append(info['test-accuracy']) + if dataset == 'cifar10': + info = api.get_more_info(index, 'cifar10-valid', hp='200') + valid_accs.append(info['valid-accuracy']) + else: + valid_accs.append(info['valid-accuracy']) + info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs} + torch.save(info, cache_file_path) + else: + print ('Find cache file : {:}'.format(cache_file_path)) + info = torch.load(cache_file_path) + params, flops, train_accs, valid_accs, test_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs'] + print ('{:} collect data done.'.format(time_string())) + + resnet = ['|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|'] + resnet_indexes = [api.query_index_by_arch(x) for x in resnet] + largest_indexes = [api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|')] + + indexes = list(range(len(params))) + dpi, width, height = 250, 8500, 1300 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 24, 24 + # resnet_scale, resnet_alpha = 120, 0.5 + xscale, xalpha = 120, 0.8 + + fig, axs = plt.subplots(1, 4, figsize=figsize) + # ax1, ax2, ax3, ax4, ax5 = axs + for ax in axs: + for tick in ax.xaxis.get_major_ticks(): + tick.label.set_fontsize(LabelSize) + ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.0f')) + for tick in ax.yaxis.get_major_ticks(): + tick.label.set_fontsize(LabelSize) + ax2, ax3, ax4, ax5 = axs + # ax1.xaxis.set_ticks(np.arange(0, max(indexes), max(indexes)//5)) + # ax1.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue') + # ax1.set_xlabel('architecture ID', fontsize=LabelSize) + # ax1.set_ylabel('test accuracy (%)', fontsize=LabelSize) + + ax2.scatter(params, train_accs, marker='o', s=0.5, c='tab:blue') + ax2.scatter([params[x] for x in resnet_indexes] , [train_accs[x] for x in resnet_indexes], marker='*', s=xscale, c='tab:orange', label='ResNet', alpha=xalpha) + ax2.scatter([params[x] for x in largest_indexes], [train_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) + ax2.set_xlabel('#parameters (MB)', fontsize=LabelSize) + ax2.set_ylabel('train accuracy (%)', fontsize=LabelSize) + ax2.legend(loc=4, fontsize=LegendFontsize) + + ax3.scatter(params, test_accs, marker='o', s=0.5, c='tab:blue') + ax3.scatter([params[x] for x in resnet_indexes] , [test_accs[x] for x in resnet_indexes], marker='*', s=xscale, c='tab:orange', label='ResNet', alpha=xalpha) + ax3.scatter([params[x] for x in largest_indexes], [test_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) + ax3.set_xlabel('#parameters (MB)', fontsize=LabelSize) + ax3.set_ylabel('test accuracy (%)', fontsize=LabelSize) + ax3.legend(loc=4, fontsize=LegendFontsize) + + ax4.scatter(flops, train_accs, marker='o', s=0.5, c='tab:blue') + ax4.scatter([flops[x] for x in resnet_indexes], [train_accs[x] for x in resnet_indexes], marker='*', s=xscale, c='tab:orange', label='ResNet', alpha=xalpha) + ax4.scatter([flops[x] for x in largest_indexes], [train_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) + ax4.set_xlabel('#FLOPs (M)', fontsize=LabelSize) + ax4.set_ylabel('train accuracy (%)', fontsize=LabelSize) + ax4.legend(loc=4, fontsize=LegendFontsize) + + ax5.scatter(flops, test_accs, marker='o', s=0.5, c='tab:blue') + ax5.scatter([flops[x] for x in resnet_indexes], [test_accs[x] for x in resnet_indexes], marker='*', s=xscale, c='tab:orange', label='ResNet', alpha=xalpha) + ax5.scatter([flops[x] for x in largest_indexes], [test_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) + ax5.set_xlabel('#FLOPs (M)', fontsize=LabelSize) + ax5.set_ylabel('test accuracy (%)', fontsize=LabelSize) + ax5.legend(loc=4, fontsize=LegendFontsize) + + save_path = vis_save_dir / 'tss-{:}.png'.format(dataset) + fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') + print ('{:} save into {:}'.format(time_string(), save_path)) + plt.close('all') + + +def visualize_rank_info(api, vis_save_dir, indicator): + vis_save_dir = vis_save_dir.resolve() + # print ('{:} start to visualize {:} information'.format(time_string(), api)) + vis_save_dir.mkdir(parents=True, exist_ok=True) + + cifar010_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar10', indicator) + cifar100_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar100', indicator) + imagenet_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('ImageNet16-120', indicator) + cifar010_info = torch.load(cifar010_cache_path) + cifar100_info = torch.load(cifar100_cache_path) + imagenet_info = torch.load(imagenet_cache_path) + indexes = list(range(len(cifar010_info['params']))) + + print ('{:} start to visualize relative ranking'.format(time_string())) + + dpi, width, height = 250, 3800, 1200 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 14, 14 + + fig, axs = plt.subplots(1, 3, figsize=figsize) + ax1, ax2, ax3 = axs + + def get_labels(info): + ord_test_indexes = sorted(indexes, key=lambda i: info['test_accs'][i]) + ord_valid_indexes = sorted(indexes, key=lambda i: info['valid_accs'][i]) + labels = [] + for idx in ord_test_indexes: + labels.append(ord_valid_indexes.index(idx)) + return labels + + def plot_ax(labels, ax, name): + for tick in ax.xaxis.get_major_ticks(): + tick.label.set_fontsize(LabelSize) + for tick in ax.yaxis.get_major_ticks(): + tick.label.set_fontsize(LabelSize) + tick.label.set_rotation(90) + ax.set_xlim(min(indexes), max(indexes)) + ax.set_ylim(min(indexes), max(indexes)) + ax.yaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes)//3)) + ax.xaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes)//5)) + ax.scatter(indexes, labels , marker='^', s=0.5, c='tab:green', alpha=0.8) + ax.scatter(indexes, indexes, marker='o', s=0.5, c='tab:blue' , alpha=0.8) + ax.scatter([-1], [-1], marker='^', s=100, c='tab:green' , label='{:} test'.format(name)) + ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='{:} validation'.format(name)) + ax.legend(loc=4, fontsize=LegendFontsize) + ax.set_xlabel('ranking on the {:} validation'.format(name), fontsize=LabelSize) + ax.set_ylabel('architecture ranking', fontsize=LabelSize) + labels = get_labels(cifar010_info) + plot_ax(labels, ax1, 'CIFAR-10') + labels = get_labels(cifar100_info) + plot_ax(labels, ax2, 'CIFAR-100') + labels = get_labels(imagenet_info) + plot_ax(labels, ax3, 'ImageNet-16-120') + + save_path = (vis_save_dir / '{:}-same-relative-rank.pdf'.format(indicator)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') + save_path = (vis_save_dir / '{:}-same-relative-rank.png'.format(indicator)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') + print ('{:} save into {:}'.format(time_string(), save_path)) + plt.close('all') + + if __name__ == '__main__': parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('--save_dir', type=str, default='output/NAS-BENCH-202', help='Folder to save checkpoints and log.') @@ -88,6 +326,20 @@ if __name__ == '__main__': # use for train the model args = parser.parse_args() - visualize_info(None, Path('output/vis-nas-bench/'), 'tss') + visualize_rank_info(None, Path('output/vis-nas-bench/'), 'tss') + visualize_rank_info(None, Path('output/vis-nas-bench/'), 'sss') + api201 = NASBench201API(None, verbose=True) + visualize_tss_info(api201, 'cifar10', Path('output/vis-nas-bench')) + visualize_tss_info(api201, 'cifar100', Path('output/vis-nas-bench')) + visualize_tss_info(api201, 'ImageNet16-120', Path('output/vis-nas-bench')) + + api301 = NASBench301API(None, verbose=True) + visualize_sss_info(api301, 'cifar10', Path('output/vis-nas-bench')) + visualize_sss_info(api301, 'cifar100', Path('output/vis-nas-bench')) + visualize_sss_info(api301, 'ImageNet16-120', Path('output/vis-nas-bench')) + + visualize_info(None, Path('output/vis-nas-bench/'), 'tss') visualize_info(None, Path('output/vis-nas-bench/'), 'sss') + visualize_rank_info(None, Path('output/vis-nas-bench/'), 'tss') + visualize_rank_info(None, Path('output/vis-nas-bench/'), 'sss') diff --git a/exps/NAS-Bench-201/test-nas-api.py b/exps/NAS-Bench-201/test-nas-api.py index 1e9ff42..d0f69dc 100644 --- a/exps/NAS-Bench-201/test-nas-api.py +++ b/exps/NAS-Bench-201/test-nas-api.py @@ -48,236 +48,51 @@ def test_api(api, is_301=True): print('') params = api.get_net_param(12, 'cifar10', None) - # obtain the config and create the network + # Obtain the config and create the network config = api.get_net_config(12, 'cifar10') print('{:}\n'.format(config)) network = get_cell_based_tiny_net(config) network.load_state_dict(next(iter(params.values()))) - # obtain the cost information + # Obtain the cost information info = api.get_cost_info(12, 'cifar10') print('{:}\n'.format(info)) info = api.get_latency(12, 'cifar10') print('{:}\n'.format(info)) - # count the number of architectures + # Count the number of architectures info = api.statistics('cifar100', '12') print('{:}\n'.format(info)) - # show the information of the 123-th architecture + # Show the information of the 123-th architecture api.show(123) - # obtain both cost and performance information + # Obtain both cost and performance information info = api.get_more_info(1234, 'cifar10') print('{:}\n'.format(info)) print('{:} finish testing the api : {:}'.format(time_string(), api)) -def visualize_sss_info(api, dataset, vis_save_dir): - vis_save_dir = vis_save_dir.resolve() - print ('{:} start to visualize {:} information'.format(time_string(), dataset)) - vis_save_dir.mkdir(parents=True, exist_ok=True) - cache_file_path = vis_save_dir / '{:}-cache-sss-info.pth'.format(dataset) - if not cache_file_path.exists(): - print ('Do not find cache file : {:}'.format(cache_file_path)) - params, flops, train_accs, valid_accs, test_accs = [], [], [], [], [] - for index in range(len(api)): - info = api.get_cost_info(index, dataset) - params.append(info['params']) - flops.append(info['flops']) - # accuracy - info = api.get_more_info(index, dataset, hp='90') - train_accs.append(info['train-accuracy']) - test_accs.append(info['test-accuracy']) - if dataset == 'cifar10': - info = api.get_more_info(index, 'cifar10-valid', hp='90') - valid_accs.append(info['valid-accuracy']) - else: - valid_accs.append(info['valid-accuracy']) - info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs} - torch.save(info, cache_file_path) - else: - print ('Find cache file : {:}'.format(cache_file_path)) - info = torch.load(cache_file_path) - params, flops, train_accs, valid_accs, test_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs'] - print ('{:} collect data done.'.format(time_string())) - - pyramid = ['8:16:32:48:64', '8:8:16:32:48', '8:8:16:16:32', '8:8:16:16:48', '8:8:16:16:64', '16:16:32:32:64', '32:32:64:64:64'] - pyramid_indexes = [api.query_index_by_arch(x) for x in pyramid] - largest_indexes = [api.query_index_by_arch('64:64:64:64:64')] - - indexes = list(range(len(params))) - dpi, width, height = 250, 8500, 1300 - figsize = width / float(dpi), height / float(dpi) - LabelSize, LegendFontsize = 24, 24 - # resnet_scale, resnet_alpha = 120, 0.5 - xscale, xalpha = 120, 0.8 - - fig, axs = plt.subplots(1, 4, figsize=figsize) - # ax1, ax2, ax3, ax4, ax5 = axs - for ax in axs: - for tick in ax.xaxis.get_major_ticks(): - tick.label.set_fontsize(LabelSize) - ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.0f')) - for tick in ax.yaxis.get_major_ticks(): - tick.label.set_fontsize(LabelSize) - ax2, ax3, ax4, ax5 = axs - # ax1.xaxis.set_ticks(np.arange(0, max(indexes), max(indexes)//5)) - # ax1.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue') - # ax1.set_xlabel('architecture ID', fontsize=LabelSize) - # ax1.set_ylabel('test accuracy (%)', fontsize=LabelSize) - - ax2.scatter(params, train_accs, marker='o', s=0.5, c='tab:blue') - ax2.scatter([params[x] for x in pyramid_indexes], [train_accs[x] for x in pyramid_indexes], marker='*', s=xscale, c='tab:orange', label='Pyramid Structure', alpha=xalpha) - ax2.scatter([params[x] for x in largest_indexes], [train_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) - ax2.set_xlabel('#parameters (MB)', fontsize=LabelSize) - ax2.set_ylabel('train accuracy (%)', fontsize=LabelSize) - ax2.legend(loc=4, fontsize=LegendFontsize) - - ax3.scatter(params, test_accs, marker='o', s=0.5, c='tab:blue') - ax3.scatter([params[x] for x in pyramid_indexes], [test_accs[x] for x in pyramid_indexes], marker='*', s=xscale, c='tab:orange', label='Pyramid Structure', alpha=xalpha) - ax3.scatter([params[x] for x in largest_indexes], [test_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) - ax3.set_xlabel('#parameters (MB)', fontsize=LabelSize) - ax3.set_ylabel('test accuracy (%)', fontsize=LabelSize) - ax3.legend(loc=4, fontsize=LegendFontsize) - - ax4.scatter(flops, train_accs, marker='o', s=0.5, c='tab:blue') - ax4.scatter([flops[x] for x in pyramid_indexes], [train_accs[x] for x in pyramid_indexes], marker='*', s=xscale, c='tab:orange', label='Pyramid Structure', alpha=xalpha) - ax4.scatter([flops[x] for x in largest_indexes], [train_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) - ax4.set_xlabel('#FLOPs (M)', fontsize=LabelSize) - ax4.set_ylabel('train accuracy (%)', fontsize=LabelSize) - ax4.legend(loc=4, fontsize=LegendFontsize) - - ax5.scatter(flops, test_accs, marker='o', s=0.5, c='tab:blue') - ax5.scatter([flops[x] for x in pyramid_indexes], [test_accs[x] for x in pyramid_indexes], marker='*', s=xscale, c='tab:orange', label='Pyramid Structure', alpha=xalpha) - ax5.scatter([flops[x] for x in largest_indexes], [test_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) - ax5.set_xlabel('#FLOPs (M)', fontsize=LabelSize) - ax5.set_ylabel('test accuracy (%)', fontsize=LabelSize) - ax5.legend(loc=4, fontsize=LegendFontsize) - - save_path = vis_save_dir / 'sss-{:}.png'.format(dataset) - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - print ('{:} save into {:}'.format(time_string(), save_path)) - plt.close('all') - - -def visualize_tss_info(api, dataset, vis_save_dir): - vis_save_dir = vis_save_dir.resolve() - print ('{:} start to visualize {:} information'.format(time_string(), dataset)) - vis_save_dir.mkdir(parents=True, exist_ok=True) - cache_file_path = vis_save_dir / '{:}-cache-tss-info.pth'.format(dataset) - if not cache_file_path.exists(): - print ('Do not find cache file : {:}'.format(cache_file_path)) - params, flops, train_accs, valid_accs, test_accs = [], [], [], [], [] - for index in range(len(api)): - info = api.get_cost_info(index, dataset) - params.append(info['params']) - flops.append(info['flops']) - # accuracy - info = api.get_more_info(index, dataset, hp='200') - train_accs.append(info['train-accuracy']) - test_accs.append(info['test-accuracy']) - if dataset == 'cifar10': - info = api.get_more_info(index, 'cifar10-valid', hp='200') - valid_accs.append(info['valid-accuracy']) - else: - valid_accs.append(info['valid-accuracy']) - info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs} - torch.save(info, cache_file_path) - else: - print ('Find cache file : {:}'.format(cache_file_path)) - info = torch.load(cache_file_path) - params, flops, train_accs, valid_accs, test_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs'] - print ('{:} collect data done.'.format(time_string())) - - resnet = ['|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|'] - resnet_indexes = [api.query_index_by_arch(x) for x in resnet] - largest_indexes = [api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|')] - - indexes = list(range(len(params))) - dpi, width, height = 250, 8500, 1300 - figsize = width / float(dpi), height / float(dpi) - LabelSize, LegendFontsize = 24, 24 - # resnet_scale, resnet_alpha = 120, 0.5 - xscale, xalpha = 120, 0.8 - - fig, axs = plt.subplots(1, 4, figsize=figsize) - # ax1, ax2, ax3, ax4, ax5 = axs - for ax in axs: - for tick in ax.xaxis.get_major_ticks(): - tick.label.set_fontsize(LabelSize) - ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.0f')) - for tick in ax.yaxis.get_major_ticks(): - tick.label.set_fontsize(LabelSize) - ax2, ax3, ax4, ax5 = axs - # ax1.xaxis.set_ticks(np.arange(0, max(indexes), max(indexes)//5)) - # ax1.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue') - # ax1.set_xlabel('architecture ID', fontsize=LabelSize) - # ax1.set_ylabel('test accuracy (%)', fontsize=LabelSize) - - ax2.scatter(params, train_accs, marker='o', s=0.5, c='tab:blue') - ax2.scatter([params[x] for x in resnet_indexes] , [train_accs[x] for x in resnet_indexes], marker='*', s=xscale, c='tab:orange', label='ResNet', alpha=xalpha) - ax2.scatter([params[x] for x in largest_indexes], [train_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) - ax2.set_xlabel('#parameters (MB)', fontsize=LabelSize) - ax2.set_ylabel('train accuracy (%)', fontsize=LabelSize) - ax2.legend(loc=4, fontsize=LegendFontsize) - - ax3.scatter(params, test_accs, marker='o', s=0.5, c='tab:blue') - ax3.scatter([params[x] for x in resnet_indexes] , [test_accs[x] for x in resnet_indexes], marker='*', s=xscale, c='tab:orange', label='ResNet', alpha=xalpha) - ax3.scatter([params[x] for x in largest_indexes], [test_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) - ax3.set_xlabel('#parameters (MB)', fontsize=LabelSize) - ax3.set_ylabel('test accuracy (%)', fontsize=LabelSize) - ax3.legend(loc=4, fontsize=LegendFontsize) - - ax4.scatter(flops, train_accs, marker='o', s=0.5, c='tab:blue') - ax4.scatter([flops[x] for x in resnet_indexes], [train_accs[x] for x in resnet_indexes], marker='*', s=xscale, c='tab:orange', label='ResNet', alpha=xalpha) - ax4.scatter([flops[x] for x in largest_indexes], [train_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) - ax4.set_xlabel('#FLOPs (M)', fontsize=LabelSize) - ax4.set_ylabel('train accuracy (%)', fontsize=LabelSize) - ax4.legend(loc=4, fontsize=LegendFontsize) - - ax5.scatter(flops, test_accs, marker='o', s=0.5, c='tab:blue') - ax5.scatter([flops[x] for x in resnet_indexes], [test_accs[x] for x in resnet_indexes], marker='*', s=xscale, c='tab:orange', label='ResNet', alpha=xalpha) - ax5.scatter([flops[x] for x in largest_indexes], [test_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) - ax5.set_xlabel('#FLOPs (M)', fontsize=LabelSize) - ax5.set_ylabel('test accuracy (%)', fontsize=LabelSize) - ax5.legend(loc=4, fontsize=LegendFontsize) - - save_path = vis_save_dir / 'tss-{:}.png'.format(dataset) - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - print ('{:} save into {:}'.format(time_string(), save_path)) - plt.close('all') - - def test_issue_81_82(api): - results = api.query_by_index(0, 'cifar10') + results = api.query_by_index(0, 'cifar10-valid', hp='12') results = api.query_by_index(0, 'cifar10-valid', hp='200') - print(results.keys()) + print(list(results.keys())) + print(results[888].get_eval('valid')) print(results[888].get_eval('x-valid')) result_dict = api.get_more_info(index=0, dataset='cifar10-valid', iepoch=11, hp='200', is_random=False) if __name__ == '__main__': - parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--save_dir', type=str, default='output/NAS-BENCH-202', help='Folder to save checkpoints and log.') - parser.add_argument('--check_N', type=int, default=32768, help='For safety.') - # use for train the model - args = parser.parse_args() api201 = NASBench201API(os.path.join(os.environ['TORCH_HOME'], 'NAS-Bench-201-v1_0-e61699.pth'), verbose=True) test_issue_81_82(api201) - test_api(api201, False) + # test_api(api201, False) + print ('Test {:} done'.format(api201)) + api201 = NASBench201API(None, verbose=True) test_issue_81_82(api201) - visualize_tss_info(api201, 'cifar10', Path('output/vis-nas-bench')) - visualize_tss_info(api201, 'cifar100', Path('output/vis-nas-bench')) - visualize_tss_info(api201, 'ImageNet16-120', Path('output/vis-nas-bench')) test_api(api201, False) + print ('Test {:} done'.format(api201)) - api301 = NASBench301API(None, verbose=True) - visualize_sss_info(api301, 'cifar10', Path('output/vis-nas-bench')) - visualize_sss_info(api301, 'cifar100', Path('output/vis-nas-bench')) - visualize_sss_info(api301, 'ImageNet16-120', Path('output/vis-nas-bench')) - test_api(api301, True) - - # save_dir = '{:}/visual'.format(args.save_dir) + # api301 = NASBench301API(None, verbose=True) + # test_api(api301, True) diff --git a/lib/nas_201_api/api_201.py b/lib/nas_201_api/api_201.py index f5accd0..257fc30 100644 --- a/lib/nas_201_api/api_201.py +++ b/lib/nas_201_api/api_201.py @@ -184,17 +184,17 @@ class NASBench201API(NASBenchMetaAPI): if valid_info is not None: xinfo['valid-loss'] = valid_info['loss'] xinfo['valid-accuracy'] = valid_info['accuracy'] - xinfo['valid-per-time'] = valid_info['all_time'] / total + xinfo['valid-per-time'] = valid_info['all_time'] / total if valid_info['all_time'] is not None else None xinfo['valid-all-time'] = valid_info['all_time'] if test_info is not None: xinfo['test-loss'] = test_info['loss'] xinfo['test-accuracy'] = test_info['accuracy'] - xinfo['test-per-time'] = test_info['all_time'] / total + xinfo['test-per-time'] = test_info['all_time'] / total if test_info['all_time'] is not None else None xinfo['test-all-time'] = test_info['all_time'] if valtest_info is not None: xinfo['valtest-loss'] = valtest_info['loss'] xinfo['valtest-accuracy'] = valtest_info['accuracy'] - xinfo['valtest-per-time'] = valtest_info['all_time'] / total + xinfo['valtest-per-time'] = valtest_info['all_time'] / total if valtest_info['all_time'] is not None else None xinfo['valtest-all-time'] = valtest_info['all_time'] return xinfo diff --git a/lib/nas_201_api/api_utils.py b/lib/nas_201_api/api_utils.py index 40fa03e..e9213f4 100644 --- a/lib/nas_201_api/api_utils.py +++ b/lib/nas_201_api/api_utils.py @@ -660,15 +660,21 @@ class ResultsCount(object): """Get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument).""" if iepoch is None: iepoch = self.epochs-1 assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs) - if isinstance(self.eval_times,dict) and len(self.eval_times) > 0: - xtime = self.eval_times['{:}@{:}'.format(name,iepoch)] - atime = sum([self.eval_times['{:}@{:}'.format(name,i)] for i in range(iepoch+1)]) - else: xtime, atime = None, None - return {'iepoch' : iepoch, - 'loss' : self.eval_losses['{:}@{:}'.format(name,iepoch)], - 'accuracy': self.eval_acc1es['{:}@{:}'.format(name,iepoch)], - 'cur_time': xtime, - 'all_time': atime} + def _internal_query(xname): + if isinstance(self.eval_times,dict) and len(self.eval_times) > 0: + xtime = self.eval_times['{:}@{:}'.format(xname, iepoch)] + atime = sum([self.eval_times['{:}@{:}'.format(xname, i)] for i in range(iepoch+1)]) + else: + xtime, atime = None, None + return {'iepoch' : iepoch, + 'loss' : self.eval_losses['{:}@{:}'.format(xname, iepoch)], + 'accuracy': self.eval_acc1es['{:}@{:}'.format(xname, iepoch)], + 'cur_time': xtime, + 'all_time': atime} + if name == 'valid': + return _internal_query('x-valid') + else: + return _internal_query(name) def get_net_param(self, clone=False): if clone: return copy.deepcopy(self.net_state_dict)