Update visualization codes for NATS-Bench
This commit is contained in:
		| @@ -25,6 +25,7 @@ from config_utils import dict2config, load_config | |||||||
| from nats_bench import create | from nats_bench import create | ||||||
| from log_utils import time_string | from log_utils import time_string | ||||||
|  |  | ||||||
|  |  | ||||||
| plt.rcParams.update({ | plt.rcParams.update({ | ||||||
|     "text.usetex": True, |     "text.usetex": True, | ||||||
|     "font.family": "sans-serif", |     "font.family": "sans-serif", | ||||||
| @@ -46,7 +47,7 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None): | |||||||
|   if search_space == 'tss': |   if search_space == 'tss': | ||||||
|     hp = '$\mathcal{H}^{1}$' |     hp = '$\mathcal{H}^{1}$' | ||||||
|     if dataset == 'cifar10': |     if dataset == 'cifar10': | ||||||
|       suffixes = ['-T800000', '-T800000-FULL'] |       suffixes = ['-T1200000', '-T1200000-FULL'] | ||||||
|   elif search_space == 'sss': |   elif search_space == 'sss': | ||||||
|     hp = '$\mathcal{H}^{2}$' |     hp = '$\mathcal{H}^{2}$' | ||||||
|     if dataset == 'cifar10': |     if dataset == 'cifar10': | ||||||
| @@ -100,7 +101,7 @@ y_max_s = {('cifar10', 'tss'): 94.5, | |||||||
|            ('ImageNet16-120', 'tss'): 46, |            ('ImageNet16-120', 'tss'): 46, | ||||||
|            ('ImageNet16-120', 'sss'): 46} |            ('ImageNet16-120', 'sss'): 46} | ||||||
|  |  | ||||||
| x_axis_s = {('cifar10', 'tss'): 800000, | x_axis_s = {('cifar10', 'tss'): 1200000, | ||||||
|             ('cifar10', 'sss'): 200000, |             ('cifar10', 'sss'): 200000, | ||||||
|             ('cifar100', 'tss'): 400, |             ('cifar100', 'tss'): 400, | ||||||
|             ('cifar100', 'sss'): 400, |             ('cifar100', 'sss'): 400, | ||||||
| @@ -114,6 +115,16 @@ name2label = {'cifar10': 'CIFAR-10', | |||||||
| spaces2latex = {'tss': r'$\mathcal{S}_{t}$', | spaces2latex = {'tss': r'$\mathcal{S}_{t}$', | ||||||
|                 'sss': r'$\mathcal{S}_{s}$',} |                 'sss': r'$\mathcal{S}_{s}$',} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # FuncFormatter can be used as a decorator | ||||||
|  | @ticker.FuncFormatter | ||||||
|  | def major_formatter(x, pos): | ||||||
|  |   if x == 0: | ||||||
|  |     return '0' | ||||||
|  |   else: | ||||||
|  |     return "{:.2f}e5".format(x/1e5) | ||||||
|  |  | ||||||
|  |  | ||||||
| def visualize_curve(api_dict, vis_save_dir): | def visualize_curve(api_dict, vis_save_dir): | ||||||
|   vis_save_dir = vis_save_dir.resolve() |   vis_save_dir = vis_save_dir.resolve() | ||||||
|   vis_save_dir.mkdir(parents=True, exist_ok=True) |   vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||||
| @@ -136,6 +147,7 @@ def visualize_curve(api_dict, vis_save_dir): | |||||||
|       tick.set_fontsize(LabelSize - 6) |       tick.set_fontsize(LabelSize - 6) | ||||||
|     for tick in ax.get_yticklabels(): |     for tick in ax.get_yticklabels(): | ||||||
|       tick.set_fontsize(LabelSize - 6) |       tick.set_fontsize(LabelSize - 6) | ||||||
|  |     ax.xaxis.set_major_formatter(major_formatter) | ||||||
|     for idx, (alg, xdata) in enumerate(alg2data.items()): |     for idx, (alg, xdata) in enumerate(alg2data.items()): | ||||||
|       accuracies = [] |       accuracies = [] | ||||||
|       for ticket in time_tickets: |       for ticket in time_tickets: | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user