Update visualization codes
This commit is contained in:
		| @@ -94,11 +94,11 @@ def visualize_sss_info(api, dataset, vis_save_dir): | ||||
|       params.append(info['params']) | ||||
|       flops.append(info['flops']) | ||||
|       # accuracy | ||||
|       info = api.get_more_info(index, dataset, hp='90') | ||||
|       info = api.get_more_info(index, dataset, hp='90', is_random=False) | ||||
|       train_accs.append(info['train-accuracy']) | ||||
|       test_accs.append(info['test-accuracy']) | ||||
|       if dataset == 'cifar10': | ||||
|         info = api.get_more_info(index, 'cifar10-valid', hp='90') | ||||
|         info = api.get_more_info(index, 'cifar10-valid', hp='90', is_random=False) | ||||
|         valid_accs.append(info['valid-accuracy']) | ||||
|       else: | ||||
|         valid_accs.append(info['valid-accuracy']) | ||||
| @@ -182,11 +182,11 @@ def visualize_tss_info(api, dataset, vis_save_dir): | ||||
|       params.append(info['params']) | ||||
|       flops.append(info['flops']) | ||||
|       # accuracy | ||||
|       info = api.get_more_info(index, dataset, hp='200') | ||||
|       info = api.get_more_info(index, dataset, hp='200', is_random=False) | ||||
|       train_accs.append(info['train-accuracy']) | ||||
|       test_accs.append(info['test-accuracy']) | ||||
|       if dataset == 'cifar10': | ||||
|         info = api.get_more_info(index, 'cifar10-valid', hp='200') | ||||
|         info = api.get_more_info(index, 'cifar10-valid', hp='200', is_random=False) | ||||
|         valid_accs.append(info['valid-accuracy']) | ||||
|       else: | ||||
|         valid_accs.append(info['valid-accuracy']) | ||||
| @@ -319,6 +319,68 @@ def visualize_rank_info(api, vis_save_dir, indicator): | ||||
|   plt.close('all') | ||||
|  | ||||
|  | ||||
| def calculate_correlation(*vectors): | ||||
|   matrix = [] | ||||
|   for i, vectori in enumerate(vectors): | ||||
|     x = [] | ||||
|     for j, vectorj in enumerate(vectors): | ||||
|       x.append( np.corrcoef(vectori, vectorj)[0,1] ) | ||||
|     matrix.append( x ) | ||||
|   return np.array(matrix) | ||||
|  | ||||
|  | ||||
| def visualize_all_rank_info(api, vis_save_dir, indicator): | ||||
|   vis_save_dir = vis_save_dir.resolve() | ||||
|   # print ('{:} start to visualize {:} information'.format(time_string(), api)) | ||||
|   vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|   cifar010_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar10', indicator) | ||||
|   cifar100_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar100', indicator) | ||||
|   imagenet_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('ImageNet16-120', indicator) | ||||
|   cifar010_info = torch.load(cifar010_cache_path) | ||||
|   cifar100_info = torch.load(cifar100_cache_path) | ||||
|   imagenet_info = torch.load(imagenet_cache_path) | ||||
|   indexes       = list(range(len(cifar010_info['params']))) | ||||
|  | ||||
|   print ('{:} start to visualize relative ranking'.format(time_string())) | ||||
|     | ||||
|  | ||||
|   dpi, width, height = 250, 3200, 1400 | ||||
|   figsize = width / float(dpi), height / float(dpi) | ||||
|   LabelSize, LegendFontsize = 14, 14 | ||||
|  | ||||
|   fig, axs = plt.subplots(1, 2, figsize=figsize) | ||||
|   ax1, ax2 = axs | ||||
|  | ||||
|   sns_size = 15 | ||||
|   CoRelMatrix = calculate_correlation(cifar010_info['valid_accs'], cifar010_info['test_accs'], cifar100_info['valid_accs'], cifar100_info['test_accs'], imagenet_info['valid_accs'], imagenet_info['test_accs']) | ||||
|    | ||||
|   sns.heatmap(CoRelMatrix, annot=True, annot_kws={'size':sns_size}, fmt='.3f', linewidths=0.5, ax=ax1, | ||||
|               xticklabels=['C10-V', 'C10-T', 'C100-V', 'C100-T', 'I120-V', 'I120-T'], | ||||
|               yticklabels=['C10-V', 'C10-T', 'C100-V', 'C100-T', 'I120-V', 'I120-T']) | ||||
|    | ||||
|   selected_indexes, acc_bar = [], 92 | ||||
|   for i, acc in enumerate(cifar010_info['test_accs']): | ||||
|     if acc > acc_bar: selected_indexes.append( i ) | ||||
|   cifar010_valid_accs = np.array(cifar010_info['valid_accs'])[ selected_indexes ] | ||||
|   cifar010_test_accs  = np.array(cifar010_info['test_accs']) [ selected_indexes ] | ||||
|   cifar100_valid_accs = np.array(cifar100_info['valid_accs'])[ selected_indexes ] | ||||
|   cifar100_test_accs  = np.array(cifar100_info['test_accs']) [ selected_indexes ] | ||||
|   imagenet_valid_accs = np.array(imagenet_info['valid_accs'])[ selected_indexes ] | ||||
|   imagenet_test_accs  = np.array(imagenet_info['test_accs']) [ selected_indexes ] | ||||
|   CoRelMatrix = calculate_correlation(cifar010_valid_accs, cifar010_test_accs, cifar100_valid_accs, cifar100_test_accs, imagenet_valid_accs, imagenet_test_accs) | ||||
|    | ||||
|   sns.heatmap(CoRelMatrix, annot=True, annot_kws={'size':sns_size}, fmt='.3f', linewidths=0.5, ax=ax2, | ||||
|               xticklabels=['C10-V', 'C10-T', 'C100-V', 'C100-T', 'I120-V', 'I120-T'], | ||||
|               yticklabels=['C10-V', 'C10-T', 'C100-V', 'C100-T', 'I120-V', 'I120-T']) | ||||
|   ax1.set_title('Correlation coefficient over ALL candidates') | ||||
|   ax2.set_title('Correlation coefficient over candidates with accuracy > {:}%'.format(acc_bar)) | ||||
|   save_path = (vis_save_dir / '{:}-all-relative-rank.png'.format(indicator)).resolve() | ||||
|   fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') | ||||
|   print ('{:} save into {:}'.format(time_string(), save_path)) | ||||
|   plt.close('all') | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--save_dir',    type=str, default='output/NAS-BENCH-202', help='Folder to save checkpoints and log.') | ||||
| @@ -326,20 +388,19 @@ if __name__ == '__main__': | ||||
|   # use for train the model | ||||
|   args = parser.parse_args() | ||||
|  | ||||
|   visualize_rank_info(None, Path('output/vis-nas-bench/'), 'tss') | ||||
|   visualize_rank_info(None, Path('output/vis-nas-bench/'), 'sss') | ||||
|  | ||||
|   datasets = ['cifar10', 'cifar100', 'ImageNet16-120'] | ||||
|   api201 = NASBench201API(None, verbose=True) | ||||
|   visualize_tss_info(api201, 'cifar10', Path('output/vis-nas-bench')) | ||||
|   visualize_tss_info(api201, 'cifar100', Path('output/vis-nas-bench')) | ||||
|   visualize_tss_info(api201, 'ImageNet16-120', Path('output/vis-nas-bench')) | ||||
|   for xdata in datasets: | ||||
|     visualize_tss_info(api201, xdata, Path('output/vis-nas-bench')) | ||||
|  | ||||
|   api301 = NASBench301API(None, verbose=True) | ||||
|   visualize_sss_info(api301, 'cifar10', Path('output/vis-nas-bench')) | ||||
|   visualize_sss_info(api301, 'cifar100', Path('output/vis-nas-bench')) | ||||
|   visualize_sss_info(api301, 'ImageNet16-120', Path('output/vis-nas-bench')) | ||||
|   for xdata in datasets: | ||||
|     visualize_sss_info(api301, xdata, Path('output/vis-nas-bench')) | ||||
|  | ||||
|   visualize_info(None, Path('output/vis-nas-bench/'), 'tss') | ||||
|   visualize_info(None, Path('output/vis-nas-bench/'), 'sss') | ||||
|   visualize_rank_info(None, Path('output/vis-nas-bench/'), 'tss') | ||||
|   visualize_rank_info(None, Path('output/vis-nas-bench/'), 'sss') | ||||
|  | ||||
|   visualize_all_rank_info(None, Path('output/vis-nas-bench/'), 'tss') | ||||
|   visualize_all_rank_info(None, Path('output/vis-nas-bench/'), 'sss') | ||||
|   | ||||
		Reference in New Issue
	
	Block a user