Update visualization codes for NATS-Bench
This commit is contained in:
		| @@ -25,6 +25,7 @@ from config_utils import dict2config, load_config | ||||
| from nats_bench import create | ||||
| from log_utils import time_string | ||||
|  | ||||
|  | ||||
| plt.rcParams.update({ | ||||
|     "text.usetex": True, | ||||
|     "font.family": "sans-serif", | ||||
| @@ -46,7 +47,7 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None): | ||||
|   if search_space == 'tss': | ||||
|     hp = '$\mathcal{H}^{1}$' | ||||
|     if dataset == 'cifar10': | ||||
|       suffixes = ['-T800000', '-T800000-FULL'] | ||||
|       suffixes = ['-T1200000', '-T1200000-FULL'] | ||||
|   elif search_space == 'sss': | ||||
|     hp = '$\mathcal{H}^{2}$' | ||||
|     if dataset == 'cifar10': | ||||
| @@ -100,7 +101,7 @@ y_max_s = {('cifar10', 'tss'): 94.5, | ||||
|            ('ImageNet16-120', 'tss'): 46, | ||||
|            ('ImageNet16-120', 'sss'): 46} | ||||
|  | ||||
| x_axis_s = {('cifar10', 'tss'): 800000, | ||||
| x_axis_s = {('cifar10', 'tss'): 1200000, | ||||
|             ('cifar10', 'sss'): 200000, | ||||
|             ('cifar100', 'tss'): 400, | ||||
|             ('cifar100', 'sss'): 400, | ||||
| @@ -114,6 +115,16 @@ name2label = {'cifar10': 'CIFAR-10', | ||||
| spaces2latex = {'tss': r'$\mathcal{S}_{t}$', | ||||
|                 'sss': r'$\mathcal{S}_{s}$',} | ||||
|  | ||||
|  | ||||
| # FuncFormatter can be used as a decorator | ||||
| @ticker.FuncFormatter | ||||
| def major_formatter(x, pos): | ||||
|   if x == 0: | ||||
|     return '0' | ||||
|   else: | ||||
|     return "{:.2f}e5".format(x/1e5) | ||||
|  | ||||
|  | ||||
| def visualize_curve(api_dict, vis_save_dir): | ||||
|   vis_save_dir = vis_save_dir.resolve() | ||||
|   vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
| @@ -136,6 +147,7 @@ def visualize_curve(api_dict, vis_save_dir): | ||||
|       tick.set_fontsize(LabelSize - 6) | ||||
|     for tick in ax.get_yticklabels(): | ||||
|       tick.set_fontsize(LabelSize - 6) | ||||
|     ax.xaxis.set_major_formatter(major_formatter) | ||||
|     for idx, (alg, xdata) in enumerate(alg2data.items()): | ||||
|       accuracies = [] | ||||
|       for ticket in time_tickets: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user