Update visualization codes for NATS-Bench
This commit is contained in:
		| @@ -43,20 +43,14 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None): | |||||||
|   # alg2name['REINFORCE'] = 'REINFORCE-0.01' |   # alg2name['REINFORCE'] = 'REINFORCE-0.01' | ||||||
|   # alg2name['RANDOM'] = 'RANDOM' |   # alg2name['RANDOM'] = 'RANDOM' | ||||||
|   # alg2name['BOHB'] = 'BOHB' |   # alg2name['BOHB'] = 'BOHB' | ||||||
|   if dataset == 'cifar10': |  | ||||||
|     suffixes = ['-T200000', '-T200000-FULL'] |  | ||||||
|   elif dataset == 'cifar100': |  | ||||||
|     suffixes = ['-T40000', '-T40000-FULL'] |  | ||||||
|   elif search_space == 'tss': |  | ||||||
|     suffixes = ['-T120000', '-T120000-FULL'] |  | ||||||
|   elif search_space == 'sss': |  | ||||||
|     suffixes = ['-T60000', '-T60000-FULL'] |  | ||||||
|   else: |  | ||||||
|     raise ValueError('Unkonwn dataset : {:}'.format(dataset)) |  | ||||||
|   if search_space == 'tss': |   if search_space == 'tss': | ||||||
|     hp = '$\mathcal{H}^{1}$' |     hp = '$\mathcal{H}^{1}$' | ||||||
|  |     if dataset == 'cifar10': | ||||||
|  |       suffixes = ['-T800000', '-T800000-FULL'] | ||||||
|   elif search_space == 'sss': |   elif search_space == 'sss': | ||||||
|     hp = '$\mathcal{H}^{2}$' |     hp = '$\mathcal{H}^{2}$' | ||||||
|  |     if dataset == 'cifar10': | ||||||
|  |       suffixes = ['-T200000', '-T200000-FULL'] | ||||||
|   else: |   else: | ||||||
|     raise ValueError('Unkonwn search space: {:}'.format(search_space)) |     raise ValueError('Unkonwn search space: {:}'.format(search_space)) | ||||||
|  |  | ||||||
| @@ -92,21 +86,21 @@ def query_performance(api, data, dataset, ticket): | |||||||
|   return np.mean(results), np.std(results) |   return np.mean(results), np.std(results) | ||||||
|  |  | ||||||
|  |  | ||||||
| y_min_s = {('cifar10', 'tss'): 90, | y_min_s = {('cifar10', 'tss'): 91, | ||||||
|            ('cifar10', 'sss'): 90, |            ('cifar10', 'sss'): 91, | ||||||
|            ('cifar100', 'tss'): 65, |            ('cifar100', 'tss'): 65, | ||||||
|            ('cifar100', 'sss'): 65, |            ('cifar100', 'sss'): 65, | ||||||
|            ('ImageNet16-120', 'tss'): 36, |            ('ImageNet16-120', 'tss'): 36, | ||||||
|            ('ImageNet16-120', 'sss'): 40} |            ('ImageNet16-120', 'sss'): 40} | ||||||
|  |  | ||||||
| y_max_s = {('cifar10', 'tss'): 94.5, | y_max_s = {('cifar10', 'tss'): 94.5, | ||||||
|            ('cifar10', 'sss'): 94.5, |            ('cifar10', 'sss'): 93.5, | ||||||
|            ('cifar100', 'tss'): 72.5, |            ('cifar100', 'tss'): 72.5, | ||||||
|            ('cifar100', 'sss'): 70.5, |            ('cifar100', 'sss'): 70.5, | ||||||
|            ('ImageNet16-120', 'tss'): 46, |            ('ImageNet16-120', 'tss'): 46, | ||||||
|            ('ImageNet16-120', 'sss'): 46} |            ('ImageNet16-120', 'sss'): 46} | ||||||
|  |  | ||||||
| x_axis_s = {('cifar10', 'tss'): 200000, | x_axis_s = {('cifar10', 'tss'): 800000, | ||||||
|             ('cifar10', 'sss'): 200000, |             ('cifar10', 'sss'): 200000, | ||||||
|             ('cifar100', 'tss'): 400, |             ('cifar100', 'tss'): 400, | ||||||
|             ('cifar100', 'sss'): 400, |             ('cifar100', 'sss'): 400, | ||||||
| @@ -124,9 +118,9 @@ def visualize_curve(api_dict, vis_save_dir): | |||||||
|   vis_save_dir = vis_save_dir.resolve() |   vis_save_dir = vis_save_dir.resolve() | ||||||
|   vis_save_dir.mkdir(parents=True, exist_ok=True) |   vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||||
|  |  | ||||||
|   dpi, width, height = 250, 4000, 2400 |   dpi, width, height = 250, 5000, 2000 | ||||||
|   figsize = width / float(dpi), height / float(dpi) |   figsize = width / float(dpi), height / float(dpi) | ||||||
|   LabelSize, LegendFontsize = 16, 16 |   LabelSize, LegendFontsize = 28, 28 | ||||||
|  |  | ||||||
|   def sub_plot_fn(ax, search_space, dataset): |   def sub_plot_fn(ax, search_space, dataset): | ||||||
|     max_time = x_axis_s[(dataset, search_space)] |     max_time = x_axis_s[(dataset, search_space)] | ||||||
| @@ -137,6 +131,11 @@ def visualize_curve(api_dict, vis_save_dir): | |||||||
|     ax.set_xlim(0, x_axis_s[(dataset, search_space)]) |     ax.set_xlim(0, x_axis_s[(dataset, search_space)]) | ||||||
|     ax.set_ylim(y_min_s[(dataset, search_space)], |     ax.set_ylim(y_min_s[(dataset, search_space)], | ||||||
|                 y_max_s[(dataset, search_space)]) |                 y_max_s[(dataset, search_space)]) | ||||||
|  |     for tick in ax.get_xticklabels(): | ||||||
|  |       tick.set_rotation(25) | ||||||
|  |       tick.set_fontsize(LabelSize - 6) | ||||||
|  |     for tick in ax.get_yticklabels(): | ||||||
|  |       tick.set_fontsize(LabelSize - 6) | ||||||
|     for idx, (alg, xdata) in enumerate(alg2data.items()): |     for idx, (alg, xdata) in enumerate(alg2data.items()): | ||||||
|       accuracies = [] |       accuracies = [] | ||||||
|       for ticket in time_tickets: |       for ticket in time_tickets: | ||||||
| @@ -150,8 +149,8 @@ def visualize_curve(api_dict, vis_save_dir): | |||||||
|       ax.plot(time_tickets, accuracies, c=xdata['color'], linestyle=xdata['linestyle'], label='{:}'.format(alg)) |       ax.plot(time_tickets, accuracies, c=xdata['color'], linestyle=xdata['linestyle'], label='{:}'.format(alg)) | ||||||
|       ax.set_xlabel('Estimated wall-clock time', fontsize=LabelSize) |       ax.set_xlabel('Estimated wall-clock time', fontsize=LabelSize) | ||||||
|       ax.set_ylabel('Test accuracy', fontsize=LabelSize) |       ax.set_ylabel('Test accuracy', fontsize=LabelSize) | ||||||
|       ax.set_title(r'Searching results on {:} for {:}'.format(name2label[dataset], spaces2latex[search_space]), |       ax.set_title(r'Results on {:} over {:}'.format(name2label[dataset], spaces2latex[search_space]), | ||||||
|         fontsize=LabelSize+4) |         fontsize=LabelSize) | ||||||
|     ax.legend(loc=4, fontsize=LegendFontsize) |     ax.legend(loc=4, fontsize=LegendFontsize) | ||||||
|  |  | ||||||
|   fig, axs = plt.subplots(1, 2, figsize=figsize) |   fig, axs = plt.subplots(1, 2, figsize=figsize) | ||||||
| @@ -165,7 +164,7 @@ def visualize_curve(api_dict, vis_save_dir): | |||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|   parser = argparse.ArgumentParser(description='NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size', formatter_class=argparse.ArgumentDefaultsHelpFormatter) |   parser = argparse.ArgumentParser(description='NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||||
|   parser.add_argument('--save_dir',     type=str,   default='output/vis-nas-bench/nas-algos-vs-h', help='Folder to save checkpoints and log.') |   parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos-vs-h', help='Folder to save checkpoints and log.') | ||||||
|   args = parser.parse_args() |   args = parser.parse_args() | ||||||
|  |  | ||||||
|   save_dir = Path(args.save_dir) |   save_dir = Path(args.save_dir) | ||||||
|   | |||||||
| @@ -11,7 +11,7 @@ import scipy | |||||||
| import numpy as np | import numpy as np | ||||||
| from typing import List, Text, Dict, Any | from typing import List, Text, Dict, Any | ||||||
| from shutil import copyfile | from shutil import copyfile | ||||||
| from collections import defaultdict | from collections import defaultdict, OrderedDict | ||||||
| from copy    import deepcopy | from copy    import deepcopy | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| import matplotlib | import matplotlib | ||||||
| @@ -28,69 +28,103 @@ from models import get_cell_based_tiny_net | |||||||
| from nats_bench import create | from nats_bench import create | ||||||
|  |  | ||||||
|  |  | ||||||
| def visualize_relative_info(api, vis_save_dir, indicator): | name2label = {'cifar10': 'CIFAR-10', | ||||||
|  |               'cifar100': 'CIFAR-100', | ||||||
|  |               'ImageNet16-120': 'ImageNet-16-120'} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def visualize_relative_info(vis_save_dir, search_space, indicator, topk): | ||||||
|   vis_save_dir = vis_save_dir.resolve() |   vis_save_dir = vis_save_dir.resolve() | ||||||
|   # print ('{:} start to visualize {:} information'.format(time_string(), api)) |   print ('{:} start to visualize {:} with top-{:} information'.format(time_string(), search_space, topk)) | ||||||
|   vis_save_dir.mkdir(parents=True, exist_ok=True) |   vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||||
|  |   cache_file_path = vis_save_dir / 'cache-{:}-info.pth'.format(search_space) | ||||||
|  |   datasets = ['cifar10', 'cifar100', 'ImageNet16-120'] | ||||||
|  |   if not cache_file_path.exists(): | ||||||
|  |     api = create(None, search_space, fast_mode=False, verbose=False) | ||||||
|  |     all_infos = OrderedDict() | ||||||
|  |     for index in range(len(api)): | ||||||
|  |       all_info = OrderedDict() | ||||||
|  |       for dataset in datasets: | ||||||
|  |         info_less = api.get_more_info(index, dataset, hp='12', is_random=False) | ||||||
|  |         info_more = api.get_more_info(index, dataset, hp=api.full_train_epochs, is_random=False) | ||||||
|  |         all_info[dataset] = dict(less=info_less['test-accuracy'], | ||||||
|  |                                  more=info_more['test-accuracy']) | ||||||
|  |       all_infos[index] = all_info | ||||||
|  |     torch.save(all_infos, cache_file_path) | ||||||
|  |     print ('{:} save all cache data into {:}'.format(time_string(), cache_file_path)) | ||||||
|  |   else: | ||||||
|  |     api = create(None, search_space, fast_mode=True, verbose=False) | ||||||
|  |     all_infos = torch.load(cache_file_path) | ||||||
|  |  | ||||||
|   cifar010_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar10', indicator) |  | ||||||
|   cifar100_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar100', indicator) |  | ||||||
|   imagenet_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('ImageNet16-120', indicator) |  | ||||||
|   cifar010_info = torch.load(cifar010_cache_path) |  | ||||||
|   cifar100_info = torch.load(cifar100_cache_path) |  | ||||||
|   imagenet_info = torch.load(imagenet_cache_path) |  | ||||||
|   indexes       = list(range(len(cifar010_info['params']))) |  | ||||||
|  |  | ||||||
|   print ('{:} start to visualize relative ranking'.format(time_string())) |   dpi, width, height = 250, 5000, 1300 | ||||||
|  |  | ||||||
|   cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info['test_accs'][i]) |  | ||||||
|   cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info['test_accs'][i]) |  | ||||||
|   imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info['test_accs'][i]) |  | ||||||
|  |  | ||||||
|   cifar100_labels, imagenet_labels = [], [] |  | ||||||
|   for idx in cifar010_ord_indexes: |  | ||||||
|     cifar100_labels.append( cifar100_ord_indexes.index(idx) ) |  | ||||||
|     imagenet_labels.append( imagenet_ord_indexes.index(idx) ) |  | ||||||
|   print ('{:} prepare data done.'.format(time_string())) |  | ||||||
|  |  | ||||||
|   dpi, width, height = 200, 1400,  800 |  | ||||||
|   figsize = width / float(dpi), height / float(dpi) |   figsize = width / float(dpi), height / float(dpi) | ||||||
|   LabelSize, LegendFontsize = 18, 12 |   LabelSize, LegendFontsize = 16, 16 | ||||||
|   resnet_scale, resnet_alpha = 120, 0.5 |  | ||||||
|  |  | ||||||
|   fig = plt.figure(figsize=figsize) |   fig, axs = plt.subplots(1, 3, figsize=figsize) | ||||||
|   ax  = fig.add_subplot(111) |   datasets = ['cifar10', 'cifar100', 'ImageNet16-120'] | ||||||
|   plt.xlim(min(indexes), max(indexes)) |    | ||||||
|   plt.ylim(min(indexes), max(indexes)) |   def sub_plot_fn(ax, dataset, indicator): | ||||||
|   # plt.ylabel('y').set_rotation(30) |     performances = [] | ||||||
|   plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//3), fontsize=LegendFontsize, rotation='vertical') |     # pickup top 10% architectures | ||||||
|   plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//5), fontsize=LegendFontsize) |     for _index in range(len(api)): | ||||||
|   ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8) |       performances.append((all_infos[_index][dataset][indicator], _index)) | ||||||
|   ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red'  , alpha=0.8) |     performances = sorted(performances, reverse=True) | ||||||
|   ax.scatter(indexes, indexes        , marker='o', s=0.5, c='tab:blue' , alpha=0.8) |     performances = performances[: int(len(api) * topk * 0.01)] | ||||||
|   ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10') |     selected_indexes = [x[1] for x in performances] | ||||||
|   ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-100') |     print('{:} plot {:10s} with {:}, {:} architectures'.format(time_string(), dataset, indicator, len(selected_indexes))) | ||||||
|   ax.scatter([-1], [-1], marker='*', s=100, c='tab:red'  , label='ImageNet-16-120') |     standard_scores = [] | ||||||
|   plt.grid(zorder=0) |     random_scores = [] | ||||||
|   ax.set_axisbelow(True) |     for idx in selected_indexes: | ||||||
|   plt.legend(loc=0, fontsize=LegendFontsize) |       standard_scores.append( | ||||||
|   ax.set_xlabel('architecture ranking in CIFAR-10', fontsize=LabelSize) |         api.get_more_info(idx, dataset, hp=api.full_train_epochs if indicator == 'more' else '12', is_random=False)['test-accuracy']) | ||||||
|   ax.set_ylabel('architecture ranking', fontsize=LabelSize) |       random_scores.append( | ||||||
|   save_path = (vis_save_dir / '{:}-relative-rank.pdf'.format(indicator)).resolve() |         api.get_more_info(idx, dataset, hp=api.full_train_epochs if indicator == 'more' else '12', is_random=True)['test-accuracy']) | ||||||
|  |     indexes = list(range(len(selected_indexes))) | ||||||
|  |     standard_indexes = sorted(indexes, key=lambda i: standard_scores[i]) | ||||||
|  |     random_indexes = sorted(indexes, key=lambda i: random_scores[i]) | ||||||
|  |     random_labels = [] | ||||||
|  |     for idx in standard_indexes: | ||||||
|  |       random_labels.append(random_indexes.index(idx)) | ||||||
|  |     for tick in ax.get_xticklabels(): | ||||||
|  |       tick.set_fontsize(LabelSize - 3) | ||||||
|  |     for tick in ax.get_yticklabels(): | ||||||
|  |       tick.set_rotation(25) | ||||||
|  |       tick.set_fontsize(LabelSize - 3) | ||||||
|  |     ax.set_xlim(0, len(indexes)) | ||||||
|  |     ax.set_ylim(0, len(indexes)) | ||||||
|  |     ax.set_yticks(np.arange(min(indexes), max(indexes), max(indexes)//3)) | ||||||
|  |     ax.set_xticks(np.arange(min(indexes), max(indexes), max(indexes)//5)) | ||||||
|  |     ax.scatter(indexes, random_labels, marker='^', s=0.5, c='tab:green', alpha=0.8) | ||||||
|  |     ax.scatter(indexes, indexes      , marker='o', s=0.5, c='tab:blue' , alpha=0.8) | ||||||
|  |     ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='Average Over Multi-Trials') | ||||||
|  |     ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='Randomly Selected Trial') | ||||||
|  |  | ||||||
|  |     coef, p = scipy.stats.kendalltau(standard_scores, random_scores) | ||||||
|  |     ax.set_xlabel('architecture ranking in {:}'.format(name2label[dataset]), fontsize=LabelSize) | ||||||
|  |     if dataset == 'cifar10': | ||||||
|  |       ax.set_ylabel('architecture ranking', fontsize=LabelSize) | ||||||
|  |     ax.legend(loc=4, fontsize=LegendFontsize) | ||||||
|  |     return coef | ||||||
|  |  | ||||||
|  |   for dataset, ax in zip(datasets, axs): | ||||||
|  |     rank_coef = sub_plot_fn(ax, dataset, indicator) | ||||||
|  |     print('sub-plot {:} on {:} done, the ranking coefficient is {:.4f}.'.format(dataset, search_space, rank_coef)) | ||||||
|  |  | ||||||
|  |   save_path = (vis_save_dir / '{:}-rank-{:}-top{:}.pdf'.format(search_space, indicator, topk)).resolve() | ||||||
|   fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') |   fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') | ||||||
|   save_path = (vis_save_dir / '{:}-relative-rank.png'.format(indicator)).resolve() |   save_path = (vis_save_dir / '{:}-rank-{:}-top{:}.png'.format(search_space, indicator, topk)).resolve() | ||||||
|   fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') |   fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') | ||||||
|   print ('{:} save into {:}'.format(time_string(), save_path)) |   print('Save into {:}'.format(save_path)) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|   parser = argparse.ArgumentParser(description='NATS-Bench', formatter_class=argparse.ArgumentDefaultsHelpFormatter) |   parser = argparse.ArgumentParser(description='NATS-Bench', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||||
|   parser.add_argument('--save_dir',    type=str, default='output/vis-nas-bench/rank-stability', help='Folder to save checkpoints and log.') |   parser.add_argument('--save_dir',    type=str, default='output/vis-nas-bench/rank-stability', help='Folder to save checkpoints and log.') | ||||||
|   # use for train the model |  | ||||||
|   args = parser.parse_args() |   args = parser.parse_args() | ||||||
|  |  | ||||||
|   to_save_dir = Path(args.save_dir) |   to_save_dir = Path(args.save_dir) | ||||||
|  |  | ||||||
|   # Figure 2 |   for topk in [1, 5, 10, 20]: | ||||||
|   visualize_relative_info(None, to_save_dir, 'tss') |     visualize_relative_info(to_save_dir, 'tss', 'more', topk) | ||||||
|   visualize_relative_info(None, to_save_dir, 'sss') |     visualize_relative_info(to_save_dir, 'sss', 'less', topk) | ||||||
|  |   print ('{:} : complete running this file : {:}'.format(time_string(), __file__)) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user