Update visualization codes for NATS-Bench
This commit is contained in:
		| @@ -43,20 +43,14 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None): | ||||
|   # alg2name['REINFORCE'] = 'REINFORCE-0.01' | ||||
|   # alg2name['RANDOM'] = 'RANDOM' | ||||
|   # alg2name['BOHB'] = 'BOHB' | ||||
|   if dataset == 'cifar10': | ||||
|     suffixes = ['-T200000', '-T200000-FULL'] | ||||
|   elif dataset == 'cifar100': | ||||
|     suffixes = ['-T40000', '-T40000-FULL'] | ||||
|   elif search_space == 'tss': | ||||
|     suffixes = ['-T120000', '-T120000-FULL'] | ||||
|   elif search_space == 'sss': | ||||
|     suffixes = ['-T60000', '-T60000-FULL'] | ||||
|   else: | ||||
|     raise ValueError('Unkonwn dataset : {:}'.format(dataset)) | ||||
|   if search_space == 'tss': | ||||
|     hp = '$\mathcal{H}^{1}$' | ||||
|     if dataset == 'cifar10': | ||||
|       suffixes = ['-T800000', '-T800000-FULL'] | ||||
|   elif search_space == 'sss': | ||||
|     hp = '$\mathcal{H}^{2}$' | ||||
|     if dataset == 'cifar10': | ||||
|       suffixes = ['-T200000', '-T200000-FULL'] | ||||
|   else: | ||||
|     raise ValueError('Unkonwn search space: {:}'.format(search_space)) | ||||
|  | ||||
| @@ -92,21 +86,21 @@ def query_performance(api, data, dataset, ticket): | ||||
|   return np.mean(results), np.std(results) | ||||
|  | ||||
|  | ||||
| y_min_s = {('cifar10', 'tss'): 90, | ||||
|            ('cifar10', 'sss'): 90, | ||||
| y_min_s = {('cifar10', 'tss'): 91, | ||||
|            ('cifar10', 'sss'): 91, | ||||
|            ('cifar100', 'tss'): 65, | ||||
|            ('cifar100', 'sss'): 65, | ||||
|            ('ImageNet16-120', 'tss'): 36, | ||||
|            ('ImageNet16-120', 'sss'): 40} | ||||
|  | ||||
| y_max_s = {('cifar10', 'tss'): 94.5, | ||||
|            ('cifar10', 'sss'): 94.5, | ||||
|            ('cifar10', 'sss'): 93.5, | ||||
|            ('cifar100', 'tss'): 72.5, | ||||
|            ('cifar100', 'sss'): 70.5, | ||||
|            ('ImageNet16-120', 'tss'): 46, | ||||
|            ('ImageNet16-120', 'sss'): 46} | ||||
|  | ||||
| x_axis_s = {('cifar10', 'tss'): 200000, | ||||
| x_axis_s = {('cifar10', 'tss'): 800000, | ||||
|             ('cifar10', 'sss'): 200000, | ||||
|             ('cifar100', 'tss'): 400, | ||||
|             ('cifar100', 'sss'): 400, | ||||
| @@ -124,9 +118,9 @@ def visualize_curve(api_dict, vis_save_dir): | ||||
|   vis_save_dir = vis_save_dir.resolve() | ||||
|   vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|   dpi, width, height = 250, 4000, 2400 | ||||
|   dpi, width, height = 250, 5000, 2000 | ||||
|   figsize = width / float(dpi), height / float(dpi) | ||||
|   LabelSize, LegendFontsize = 16, 16 | ||||
|   LabelSize, LegendFontsize = 28, 28 | ||||
|  | ||||
|   def sub_plot_fn(ax, search_space, dataset): | ||||
|     max_time = x_axis_s[(dataset, search_space)] | ||||
| @@ -137,6 +131,11 @@ def visualize_curve(api_dict, vis_save_dir): | ||||
|     ax.set_xlim(0, x_axis_s[(dataset, search_space)]) | ||||
|     ax.set_ylim(y_min_s[(dataset, search_space)], | ||||
|                 y_max_s[(dataset, search_space)]) | ||||
|     for tick in ax.get_xticklabels(): | ||||
|       tick.set_rotation(25) | ||||
|       tick.set_fontsize(LabelSize - 6) | ||||
|     for tick in ax.get_yticklabels(): | ||||
|       tick.set_fontsize(LabelSize - 6) | ||||
|     for idx, (alg, xdata) in enumerate(alg2data.items()): | ||||
|       accuracies = [] | ||||
|       for ticket in time_tickets: | ||||
| @@ -150,8 +149,8 @@ def visualize_curve(api_dict, vis_save_dir): | ||||
|       ax.plot(time_tickets, accuracies, c=xdata['color'], linestyle=xdata['linestyle'], label='{:}'.format(alg)) | ||||
|       ax.set_xlabel('Estimated wall-clock time', fontsize=LabelSize) | ||||
|       ax.set_ylabel('Test accuracy', fontsize=LabelSize) | ||||
|       ax.set_title(r'Searching results on {:} for {:}'.format(name2label[dataset], spaces2latex[search_space]), | ||||
|         fontsize=LabelSize+4) | ||||
|       ax.set_title(r'Results on {:} over {:}'.format(name2label[dataset], spaces2latex[search_space]), | ||||
|         fontsize=LabelSize) | ||||
|     ax.legend(loc=4, fontsize=LegendFontsize) | ||||
|  | ||||
|   fig, axs = plt.subplots(1, 2, figsize=figsize) | ||||
| @@ -165,7 +164,7 @@ def visualize_curve(api_dict, vis_save_dir): | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   parser = argparse.ArgumentParser(description='NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--save_dir',     type=str,   default='output/vis-nas-bench/nas-algos-vs-h', help='Folder to save checkpoints and log.') | ||||
|   parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos-vs-h', help='Folder to save checkpoints and log.') | ||||
|   args = parser.parse_args() | ||||
|  | ||||
|   save_dir = Path(args.save_dir) | ||||
|   | ||||
| @@ -11,7 +11,7 @@ import scipy | ||||
| import numpy as np | ||||
| from typing import List, Text, Dict, Any | ||||
| from shutil import copyfile | ||||
| from collections import defaultdict | ||||
| from collections import defaultdict, OrderedDict | ||||
| from copy    import deepcopy | ||||
| from pathlib import Path | ||||
| import matplotlib | ||||
| @@ -28,69 +28,103 @@ from models import get_cell_based_tiny_net | ||||
| from nats_bench import create | ||||
|  | ||||
|  | ||||
| def visualize_relative_info(api, vis_save_dir, indicator): | ||||
| name2label = {'cifar10': 'CIFAR-10', | ||||
|               'cifar100': 'CIFAR-100', | ||||
|               'ImageNet16-120': 'ImageNet-16-120'} | ||||
|  | ||||
|  | ||||
| def visualize_relative_info(vis_save_dir, search_space, indicator, topk): | ||||
|   vis_save_dir = vis_save_dir.resolve() | ||||
|   # print ('{:} start to visualize {:} information'.format(time_string(), api)) | ||||
|   print ('{:} start to visualize {:} with top-{:} information'.format(time_string(), search_space, topk)) | ||||
|   vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|   cache_file_path = vis_save_dir / 'cache-{:}-info.pth'.format(search_space) | ||||
|   datasets = ['cifar10', 'cifar100', 'ImageNet16-120'] | ||||
|   if not cache_file_path.exists(): | ||||
|     api = create(None, search_space, fast_mode=False, verbose=False) | ||||
|     all_infos = OrderedDict() | ||||
|     for index in range(len(api)): | ||||
|       all_info = OrderedDict() | ||||
|       for dataset in datasets: | ||||
|         info_less = api.get_more_info(index, dataset, hp='12', is_random=False) | ||||
|         info_more = api.get_more_info(index, dataset, hp=api.full_train_epochs, is_random=False) | ||||
|         all_info[dataset] = dict(less=info_less['test-accuracy'], | ||||
|                                  more=info_more['test-accuracy']) | ||||
|       all_infos[index] = all_info | ||||
|     torch.save(all_infos, cache_file_path) | ||||
|     print ('{:} save all cache data into {:}'.format(time_string(), cache_file_path)) | ||||
|   else: | ||||
|     api = create(None, search_space, fast_mode=True, verbose=False) | ||||
|     all_infos = torch.load(cache_file_path) | ||||
|  | ||||
|   cifar010_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar10', indicator) | ||||
|   cifar100_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar100', indicator) | ||||
|   imagenet_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('ImageNet16-120', indicator) | ||||
|   cifar010_info = torch.load(cifar010_cache_path) | ||||
|   cifar100_info = torch.load(cifar100_cache_path) | ||||
|   imagenet_info = torch.load(imagenet_cache_path) | ||||
|   indexes       = list(range(len(cifar010_info['params']))) | ||||
|  | ||||
|   print ('{:} start to visualize relative ranking'.format(time_string())) | ||||
|  | ||||
|   cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info['test_accs'][i]) | ||||
|   cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info['test_accs'][i]) | ||||
|   imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info['test_accs'][i]) | ||||
|  | ||||
|   cifar100_labels, imagenet_labels = [], [] | ||||
|   for idx in cifar010_ord_indexes: | ||||
|     cifar100_labels.append( cifar100_ord_indexes.index(idx) ) | ||||
|     imagenet_labels.append( imagenet_ord_indexes.index(idx) ) | ||||
|   print ('{:} prepare data done.'.format(time_string())) | ||||
|  | ||||
|   dpi, width, height = 200, 1400,  800 | ||||
|   dpi, width, height = 250, 5000, 1300 | ||||
|   figsize = width / float(dpi), height / float(dpi) | ||||
|   LabelSize, LegendFontsize = 18, 12 | ||||
|   resnet_scale, resnet_alpha = 120, 0.5 | ||||
|   LabelSize, LegendFontsize = 16, 16 | ||||
|  | ||||
|   fig = plt.figure(figsize=figsize) | ||||
|   ax  = fig.add_subplot(111) | ||||
|   plt.xlim(min(indexes), max(indexes)) | ||||
|   plt.ylim(min(indexes), max(indexes)) | ||||
|   # plt.ylabel('y').set_rotation(30) | ||||
|   plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//3), fontsize=LegendFontsize, rotation='vertical') | ||||
|   plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//5), fontsize=LegendFontsize) | ||||
|   ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8) | ||||
|   ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red'  , alpha=0.8) | ||||
|   ax.scatter(indexes, indexes        , marker='o', s=0.5, c='tab:blue' , alpha=0.8) | ||||
|   ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10') | ||||
|   ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-100') | ||||
|   ax.scatter([-1], [-1], marker='*', s=100, c='tab:red'  , label='ImageNet-16-120') | ||||
|   plt.grid(zorder=0) | ||||
|   ax.set_axisbelow(True) | ||||
|   plt.legend(loc=0, fontsize=LegendFontsize) | ||||
|   ax.set_xlabel('architecture ranking in CIFAR-10', fontsize=LabelSize) | ||||
|   ax.set_ylabel('architecture ranking', fontsize=LabelSize) | ||||
|   save_path = (vis_save_dir / '{:}-relative-rank.pdf'.format(indicator)).resolve() | ||||
|   fig, axs = plt.subplots(1, 3, figsize=figsize) | ||||
|   datasets = ['cifar10', 'cifar100', 'ImageNet16-120'] | ||||
|    | ||||
|   def sub_plot_fn(ax, dataset, indicator): | ||||
|     performances = [] | ||||
|     # pickup top 10% architectures | ||||
|     for _index in range(len(api)): | ||||
|       performances.append((all_infos[_index][dataset][indicator], _index)) | ||||
|     performances = sorted(performances, reverse=True) | ||||
|     performances = performances[: int(len(api) * topk * 0.01)] | ||||
|     selected_indexes = [x[1] for x in performances] | ||||
|     print('{:} plot {:10s} with {:}, {:} architectures'.format(time_string(), dataset, indicator, len(selected_indexes))) | ||||
|     standard_scores = [] | ||||
|     random_scores = [] | ||||
|     for idx in selected_indexes: | ||||
|       standard_scores.append( | ||||
|         api.get_more_info(idx, dataset, hp=api.full_train_epochs if indicator == 'more' else '12', is_random=False)['test-accuracy']) | ||||
|       random_scores.append( | ||||
|         api.get_more_info(idx, dataset, hp=api.full_train_epochs if indicator == 'more' else '12', is_random=True)['test-accuracy']) | ||||
|     indexes = list(range(len(selected_indexes))) | ||||
|     standard_indexes = sorted(indexes, key=lambda i: standard_scores[i]) | ||||
|     random_indexes = sorted(indexes, key=lambda i: random_scores[i]) | ||||
|     random_labels = [] | ||||
|     for idx in standard_indexes: | ||||
|       random_labels.append(random_indexes.index(idx)) | ||||
|     for tick in ax.get_xticklabels(): | ||||
|       tick.set_fontsize(LabelSize - 3) | ||||
|     for tick in ax.get_yticklabels(): | ||||
|       tick.set_rotation(25) | ||||
|       tick.set_fontsize(LabelSize - 3) | ||||
|     ax.set_xlim(0, len(indexes)) | ||||
|     ax.set_ylim(0, len(indexes)) | ||||
|     ax.set_yticks(np.arange(min(indexes), max(indexes), max(indexes)//3)) | ||||
|     ax.set_xticks(np.arange(min(indexes), max(indexes), max(indexes)//5)) | ||||
|     ax.scatter(indexes, random_labels, marker='^', s=0.5, c='tab:green', alpha=0.8) | ||||
|     ax.scatter(indexes, indexes      , marker='o', s=0.5, c='tab:blue' , alpha=0.8) | ||||
|     ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='Average Over Multi-Trials') | ||||
|     ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='Randomly Selected Trial') | ||||
|  | ||||
|     coef, p = scipy.stats.kendalltau(standard_scores, random_scores) | ||||
|     ax.set_xlabel('architecture ranking in {:}'.format(name2label[dataset]), fontsize=LabelSize) | ||||
|     if dataset == 'cifar10': | ||||
|       ax.set_ylabel('architecture ranking', fontsize=LabelSize) | ||||
|     ax.legend(loc=4, fontsize=LegendFontsize) | ||||
|     return coef | ||||
|  | ||||
|   for dataset, ax in zip(datasets, axs): | ||||
|     rank_coef = sub_plot_fn(ax, dataset, indicator) | ||||
|     print('sub-plot {:} on {:} done, the ranking coefficient is {:.4f}.'.format(dataset, search_space, rank_coef)) | ||||
|  | ||||
|   save_path = (vis_save_dir / '{:}-rank-{:}-top{:}.pdf'.format(search_space, indicator, topk)).resolve() | ||||
|   fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') | ||||
|   save_path = (vis_save_dir / '{:}-relative-rank.png'.format(indicator)).resolve() | ||||
|   save_path = (vis_save_dir / '{:}-rank-{:}-top{:}.png'.format(search_space, indicator, topk)).resolve() | ||||
|   fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') | ||||
|   print ('{:} save into {:}'.format(time_string(), save_path)) | ||||
|   print('Save into {:}'.format(save_path)) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   parser = argparse.ArgumentParser(description='NATS-Bench', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--save_dir',    type=str, default='output/vis-nas-bench/rank-stability', help='Folder to save checkpoints and log.') | ||||
|   # use for train the model | ||||
|   args = parser.parse_args() | ||||
|  | ||||
|   to_save_dir = Path(args.save_dir) | ||||
|  | ||||
|   # Figure 2 | ||||
|   visualize_relative_info(None, to_save_dir, 'tss') | ||||
|   visualize_relative_info(None, to_save_dir, 'sss') | ||||
|   for topk in [1, 5, 10, 20]: | ||||
|     visualize_relative_info(to_save_dir, 'tss', 'more', topk) | ||||
|     visualize_relative_info(to_save_dir, 'sss', 'less', topk) | ||||
|   print ('{:} : complete running this file : {:}'.format(time_string(), __file__)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user