update README
This commit is contained in:
		| @@ -39,7 +39,7 @@ At the moment, this project provides the following algorithms and scripts to run | |||||||
|     <tr> <!-- (2-nd row) --> |     <tr> <!-- (2-nd row) --> | ||||||
|     <td align="center" valign="middle"> DARTS </td> |     <td align="center" valign="middle"> DARTS </td> | ||||||
|     <td align="center" valign="middle"> DARTS: Differentiable Architecture Search </td> |     <td align="center" valign="middle"> DARTS: Differentiable Architecture Search </td> | ||||||
|     <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/master/docs/NAS-Bench-201.md">NAS-Bench-201.md</a> </td> |     <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/master/docs/ICLR-2019-DARTS.md">ICLR-2019-DARTS.md</a> </td> | ||||||
|     </tr> |     </tr> | ||||||
|     <tr> <!-- (3-nd row) --> |     <tr> <!-- (3-nd row) --> | ||||||
|     <td align="center" valign="middle"> GDAS </td> |     <td align="center" valign="middle"> GDAS </td> | ||||||
|   | |||||||
							
								
								
									
										22
									
								
								docs/ICLR-2019-DARTS.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								docs/ICLR-2019-DARTS.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | |||||||
|  | # DARTS: Differentiable Architecture Search | ||||||
|  |  | ||||||
|  | DARTS: Differentiable Architecture Search is accepted by ICLR 2019. | ||||||
|  | In this paper, Hanxiao proposed a differentiable neural architecture search method, named as DARTS. | ||||||
|  | Recently, DARTS becomes very popular due to its simplicity and performance. | ||||||
|  |  | ||||||
|  | **Run DARTS on the NAS-Bench-201 search space**: | ||||||
|  | ``` | ||||||
|  | CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 1 -1 | ||||||
|  | CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh     cifar10 1 -1 | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | # Citation | ||||||
|  |  | ||||||
|  | ``` | ||||||
|  | @inproceedings{liu2019darts, | ||||||
|  |   title     = {{DARTS}: Differentiable architecture search}, | ||||||
|  |   author    = {Liu, Hanxiao and Simonyan, Karen and Yang, Yiming}, | ||||||
|  |   booktitle = {International Conference on Learning Representations (ICLR)}, | ||||||
|  |   year      = {2019} | ||||||
|  | } | ||||||
|  | ``` | ||||||
| @@ -181,7 +181,7 @@ If researchers can provide better results with different hyper-parameters, we ar | |||||||
| - [4] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/SETN.sh     cifar10 1 -1` | - [4] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/SETN.sh     cifar10 1 -1` | ||||||
| - [5] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/ENAS.sh     cifar10 1 -1` | - [5] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/ENAS.sh     cifar10 1 -1` | ||||||
| - [6] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/RANDOM-NAS.sh cifar10 1 -1` | - [6] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/RANDOM-NAS.sh cifar10 1 -1` | ||||||
| - [7] `bash ./scripts-search/algos/R-EA.sh -1` | - [7] `bash ./scripts-search/algos/R-EA.sh cifar10 3 -1` | ||||||
| - [8] `bash ./scripts-search/algos/Random.sh -1` | - [8] `bash ./scripts-search/algos/Random.sh -1` | ||||||
| - [9] `bash ./scripts-search/algos/REINFORCE.sh 0.5 -1` | - [9] `bash ./scripts-search/algos/REINFORCE.sh 0.5 -1` | ||||||
| - [10] `bash ./scripts-search/algos/BOHB.sh -1` | - [10] `bash ./scripts-search/algos/BOHB.sh -1` | ||||||
|   | |||||||
| @@ -517,7 +517,7 @@ def just_show(api): | |||||||
|     print ('[{:10s}-{:10s} ::: index={:5d}, accuracy={:.2f}'.format(dataset, metric_on_set, arch_index, highest_acc)) |     print ('[{:10s}-{:10s} ::: index={:5d}, accuracy={:.2f}'.format(dataset, metric_on_set, arch_index, highest_acc)) | ||||||
|  |  | ||||||
|  |  | ||||||
| def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_maxs): | def show_nas_sharing_w(api, dataset, subset, vis_save_dir, sufix, file_name, y_lims, x_maxs): | ||||||
|   color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] |   color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] | ||||||
|   dpi, width, height = 300, 3400, 2600 |   dpi, width, height = 300, 3400, 2600 | ||||||
|   LabelSize, LegendFontsize = 28, 28 |   LabelSize, LegendFontsize = 28, 28 | ||||||
| @@ -533,13 +533,14 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_ | |||||||
|   plt.xlabel('The searching epoch', fontsize=LabelSize) |   plt.xlabel('The searching epoch', fontsize=LabelSize) | ||||||
|   plt.ylabel('The accuracy (%)', fontsize=LabelSize) |   plt.ylabel('The accuracy (%)', fontsize=LabelSize) | ||||||
|  |  | ||||||
|   xpaths = {'RSPS'    : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10/checkpoint/', |   xpaths = {'RSPS'    : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10-{:}/checkpoint/'.format(sufix),  | ||||||
|             'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/', |             'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10-{:}/checkpoint/'.format(sufix), | ||||||
|             'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10/checkpoint/', |             'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10-{:}/checkpoint/'.format(sufix), | ||||||
|             'GDAS'    : 'output/search-cell-nas-bench-201/GDAS-cifar10/checkpoint/', |             'GDAS'    : 'output/search-cell-nas-bench-201/GDAS-cifar10-{:}/checkpoint/'.format(sufix), | ||||||
|             'SETN'    : 'output/search-cell-nas-bench-201/SETN-cifar10/checkpoint/', |             'SETN'    : 'output/search-cell-nas-bench-201/SETN-cifar10-{:}/checkpoint/'.format(sufix), | ||||||
|             'ENAS'    : 'output/search-cell-nas-bench-201/ENAS-cifar10/checkpoint/', |             'ENAS'    : 'output/search-cell-nas-bench-201/ENAS-cifar10-{:}/checkpoint/'.format(sufix), | ||||||
|            } |            } | ||||||
|  |   """ | ||||||
|   xseeds = {'RSPS'    : [5349, 59613, 5983], |   xseeds = {'RSPS'    : [5349, 59613, 5983], | ||||||
|             'DARTS-V1': [11416, 72873, 81184, 28640], |             'DARTS-V1': [11416, 72873, 81184, 28640], | ||||||
|             'DARTS-V2': [43330, 79405, 79423], |             'DARTS-V2': [43330, 79405, 79423], | ||||||
| @@ -547,6 +548,15 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_ | |||||||
|             'SETN'    : [20518, 61817, 89144], |             'SETN'    : [20518, 61817, 89144], | ||||||
|             'ENAS'    : [3231, 34238, 96929], |             'ENAS'    : [3231, 34238, 96929], | ||||||
|            } |            } | ||||||
|  |   """ | ||||||
|  |   xseeds = {'RSPS'    : [23814, 28015, 95809], | ||||||
|  |             'DARTS-V1': [48349, 80877, 81920], | ||||||
|  |             'DARTS-V2': [61712, 7941 , 87041] , | ||||||
|  |             'GDAS'    : [72818, 72996, 78877], | ||||||
|  |             'SETN'    : [26985, 55206, 95404], | ||||||
|  |             'ENAS'    : [21792, 36605, 45029] | ||||||
|  |            } | ||||||
|  |  | ||||||
|  |  | ||||||
|   def get_accs(xdata): |   def get_accs(xdata): | ||||||
|     epochs, xresults = xdata['epoch'], [] |     epochs, xresults = xdata['epoch'], [] | ||||||
| @@ -579,12 +589,13 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_ | |||||||
|     plt.fill_between(epochs, [accyss[:,i].mean()-accyss[:,i].std() for i in epochs], [accyss[:,i].mean()+accyss[:,i].std() for i in epochs], alpha=0.2, color=color_set[idx]) |     plt.fill_between(epochs, [accyss[:,i].mean()-accyss[:,i].std() for i in epochs], [accyss[:,i].mean()+accyss[:,i].std() for i in epochs], alpha=0.2, color=color_set[idx]) | ||||||
|   #plt.legend(loc=4, fontsize=LegendFontsize) |   #plt.legend(loc=4, fontsize=LegendFontsize) | ||||||
|   plt.legend(loc=0, fontsize=LegendFontsize) |   plt.legend(loc=0, fontsize=LegendFontsize) | ||||||
|   save_path = vis_save_dir / '{:}-{:}-{:}-{:}'.format(xox, dataset, subset, file_name) |   save_path = vis_save_dir / '{:}.pdf'.format(file_name) | ||||||
|   print('save figure into {:}\n'.format(save_path)) |   print('save figure into {:}\n'.format(save_path)) | ||||||
|   fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf') |   fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf') | ||||||
|  |  | ||||||
|  |  | ||||||
| def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, file_name, y_lims, x_maxs): |  | ||||||
|  | def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, sufix, file_name, y_lims, x_maxs): | ||||||
|   color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] |   color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] | ||||||
|   dpi, width, height = 300, 3400, 2600 |   dpi, width, height = 300, 3400, 2600 | ||||||
|   LabelSize, LegendFontsize = 28, 28 |   LabelSize, LegendFontsize = 28, 28 | ||||||
| @@ -600,13 +611,14 @@ def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, file_name, | |||||||
|   plt.xlabel('The searching epoch', fontsize=LabelSize) |   plt.xlabel('The searching epoch', fontsize=LabelSize) | ||||||
|   plt.ylabel('The accuracy (%)', fontsize=LabelSize) |   plt.ylabel('The accuracy (%)', fontsize=LabelSize) | ||||||
|  |  | ||||||
|   xpaths = {'RSPS'    : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10/checkpoint/', |   xpaths = {'RSPS'    : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10-{:}/checkpoint/'.format(sufix),  | ||||||
|             'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/', |             'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10-{:}/checkpoint/'.format(sufix), | ||||||
|             'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10/checkpoint/', |             'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10-{:}/checkpoint/'.format(sufix), | ||||||
|             'GDAS'    : 'output/search-cell-nas-bench-201/GDAS-cifar10/checkpoint/', |             'GDAS'    : 'output/search-cell-nas-bench-201/GDAS-cifar10-{:}/checkpoint/'.format(sufix), | ||||||
|             'SETN'    : 'output/search-cell-nas-bench-201/SETN-cifar10/checkpoint/', |             'SETN'    : 'output/search-cell-nas-bench-201/SETN-cifar10-{:}/checkpoint/'.format(sufix), | ||||||
|             'ENAS'    : 'output/search-cell-nas-bench-201/ENAS-cifar10/checkpoint/', |             'ENAS'    : 'output/search-cell-nas-bench-201/ENAS-cifar10-{:}/checkpoint/'.format(sufix), | ||||||
|            } |            } | ||||||
|  |   """ | ||||||
|   xseeds = {'RSPS'    : [5349, 59613, 5983], |   xseeds = {'RSPS'    : [5349, 59613, 5983], | ||||||
|             'DARTS-V1': [11416, 72873, 81184, 28640], |             'DARTS-V1': [11416, 72873, 81184, 28640], | ||||||
|             'DARTS-V2': [43330, 79405, 79423], |             'DARTS-V2': [43330, 79405, 79423], | ||||||
| @@ -614,6 +626,15 @@ def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, file_name, | |||||||
|             'SETN'    : [20518, 61817, 89144], |             'SETN'    : [20518, 61817, 89144], | ||||||
|             'ENAS'    : [3231, 34238, 96929], |             'ENAS'    : [3231, 34238, 96929], | ||||||
|            } |            } | ||||||
|  |   """ | ||||||
|  |   xseeds = {'RSPS'    : [23814, 28015, 95809], | ||||||
|  |             'DARTS-V1': [48349, 80877, 81920], | ||||||
|  |             'DARTS-V2': [61712, 7941 , 87041] , | ||||||
|  |             'GDAS'    : [72818, 72996, 78877], | ||||||
|  |             'SETN'    : [26985, 55206, 95404], | ||||||
|  |             'ENAS'    : [21792, 36605, 45029] | ||||||
|  |            } | ||||||
|  |  | ||||||
|  |  | ||||||
|   def get_accs(xdata, dataset, subset): |   def get_accs(xdata, dataset, subset): | ||||||
|     epochs, xresults = xdata['epoch'], [] |     epochs, xresults = xdata['epoch'], [] | ||||||
| @@ -643,8 +664,15 @@ def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, file_name, | |||||||
|     accyss_B = np.array( [get_accs(xdatas, data_sub_b[0], data_sub_b[1]) for xdatas in all_datas] ) |     accyss_B = np.array( [get_accs(xdatas, data_sub_b[0], data_sub_b[1]) for xdatas in all_datas] ) | ||||||
|     epochs = list(range(accyss_A.shape[1])) |     epochs = list(range(accyss_A.shape[1])) | ||||||
|     for j, accyss in enumerate([accyss_A, accyss_B]): |     for j, accyss in enumerate([accyss_A, accyss_B]): | ||||||
|       plt.plot(epochs, [accyss[:,i].mean() for i in epochs], color=color_set[idx*2+j], linestyle='-' if j==0 else '--', label='{:} ({:})'.format(method, 'VALID' if j == 0 else 'TEST'), lw=2, alpha=0.9) |       if x_maxs == 50: | ||||||
|       plt.fill_between(epochs, [accyss[:,i].mean()-accyss[:,i].std() for i in epochs], [accyss[:,i].mean()+accyss[:,i].std() for i in epochs], alpha=0.2, color=color_set[idx*2+j]) |         color, line = color_set[idx*2+j], '-' if j==0 else '--' | ||||||
|  |       elif x_maxs == 250: | ||||||
|  |         color, line = color_set[idx], '-' if j==0 else '--' | ||||||
|  |       else: raise ValueError('invalid x-maxs={:}'.format(x_maxs)) | ||||||
|  |       plt.plot(epochs, [accyss[:,i].mean() for i in epochs], color=color, linestyle=line, label='{:} ({:})'.format(method, 'VALID' if j == 0 else 'TEST'), lw=2, alpha=0.9) | ||||||
|  |       plt.fill_between(epochs, [accyss[:,i].mean()-accyss[:,i].std() for i in epochs], [accyss[:,i].mean()+accyss[:,i].std() for i in epochs], alpha=0.2, color=color) | ||||||
|  |       setname = data_sub_a if j == 0 else data_sub_b | ||||||
|  |       print('{:} -- {:} ---- {:.2f}$\\pm${:.2f}'.format(method, setname, accyss[:,-1].mean(), accyss[:,-1].std())) | ||||||
|   #plt.legend(loc=4, fontsize=LegendFontsize) |   #plt.legend(loc=4, fontsize=LegendFontsize) | ||||||
|   plt.legend(loc=0, fontsize=LegendFontsize) |   plt.legend(loc=0, fontsize=LegendFontsize) | ||||||
|   save_path = vis_save_dir / '{:}-{:}'.format(xox, file_name) |   save_path = vis_save_dir / '{:}-{:}'.format(xox, file_name) | ||||||
| @@ -654,7 +682,7 @@ def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, file_name, | |||||||
|  |  | ||||||
| def show_reinforce(api, root, dataset, xset, file_name, y_lims): | def show_reinforce(api, root, dataset, xset, file_name, y_lims): | ||||||
|   print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset)) |   print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset)) | ||||||
|   LRs = ['0.01', '0.02', '0.1', '0.2', '0.5', '1.0', '1.5', '2.0', '2.5', '3.0'] |   LRs = ['0.01', '0.02', '0.1', '0.2', '0.5'] | ||||||
|   checkpoints = ['./output/search-cell-nas-bench-201/REINFORCE-cifar10-{:}/results.pth'.format(x) for x in LRs] |   checkpoints = ['./output/search-cell-nas-bench-201/REINFORCE-cifar10-{:}/results.pth'.format(x) for x in LRs] | ||||||
|   acc_lr_dict, indexes = {}, None |   acc_lr_dict, indexes = {}, None | ||||||
|   for lr, checkpoint in zip(LRs, checkpoints): |   for lr, checkpoint in zip(LRs, checkpoints): | ||||||
| @@ -684,7 +712,8 @@ def show_reinforce(api, root, dataset, xset, file_name, y_lims): | |||||||
|  |  | ||||||
|   for idx, LR in enumerate(LRs): |   for idx, LR in enumerate(LRs): | ||||||
|     legend = 'LR={:.2f}'.format(float(LR)) |     legend = 'LR={:.2f}'.format(float(LR)) | ||||||
|     color, linestyle = color_set[idx // 2], '-' if idx % 2 == 0 else '-.' |     #color, linestyle = color_set[idx // 2], '-' if idx % 2 == 0 else '-.' | ||||||
|  |     color, linestyle = color_set[idx], '-' | ||||||
|     plt.plot(indexes, acc_lr_dict[LR], color=color, linestyle=linestyle, label=legend, lw=2, alpha=0.8) |     plt.plot(indexes, acc_lr_dict[LR], color=color, linestyle=linestyle, label=legend, lw=2, alpha=0.8) | ||||||
|     print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, np.mean(acc_lr_dict[LR]), np.std(acc_lr_dict[LR]), np.mean(acc_lr_dict[LR]), np.std(acc_lr_dict[LR]))) |     print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, np.mean(acc_lr_dict[LR]), np.std(acc_lr_dict[LR]), np.mean(acc_lr_dict[LR]), np.std(acc_lr_dict[LR]))) | ||||||
|   plt.legend(loc=4, fontsize=LegendFontsize) |   plt.legend(loc=4, fontsize=LegendFontsize) | ||||||
| @@ -693,6 +722,49 @@ def show_reinforce(api, root, dataset, xset, file_name, y_lims): | |||||||
|   fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf') |   fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf') | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def show_rea(api, root, dataset, xset, file_name, y_lims): | ||||||
|  |   print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset)) | ||||||
|  |   SSs = [3, 5, 10] | ||||||
|  |   checkpoints = ['./output/search-cell-nas-bench-201/R-EA-cifar10-SS{:}/results.pth'.format(x) for x in SSs] | ||||||
|  |   acc_ss_dict, indexes = {}, None | ||||||
|  |   for ss, checkpoint in zip(SSs, checkpoints): | ||||||
|  |     all_indexes, accuracies = torch.load(checkpoint, map_location='cpu'), [] | ||||||
|  |     for x in all_indexes: | ||||||
|  |       info = api.arch2infos_full[ x ] | ||||||
|  |       metrics = info.get_metrics(dataset, xset, None, False) | ||||||
|  |       accuracies.append( metrics['accuracy'] ) | ||||||
|  |     if indexes is None: indexes = list(range(len(accuracies))) | ||||||
|  |     acc_ss_dict[ss] = np.array( sorted(accuracies) ) | ||||||
|  |     print ('Sample-Size={:2d}, mean={:}, std={:}'.format(ss, acc_ss_dict[ss].mean(), acc_ss_dict[ss].std())) | ||||||
|  |    | ||||||
|  |   color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] | ||||||
|  |   dpi, width, height = 300, 3400, 2600 | ||||||
|  |   LabelSize, LegendFontsize = 28, 22 | ||||||
|  |   figsize = width / float(dpi), height / float(dpi) | ||||||
|  |   fig = plt.figure(figsize=figsize) | ||||||
|  |   x_axis = np.arange(0, 600) | ||||||
|  |   plt.xlim(0, max(indexes)) | ||||||
|  |   plt.ylim(y_lims[0], y_lims[1]) | ||||||
|  |   interval_x, interval_y = 100, y_lims[2] | ||||||
|  |   plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize) | ||||||
|  |   plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize) | ||||||
|  |   plt.grid() | ||||||
|  |   plt.xlabel('The index of runs', fontsize=LabelSize) | ||||||
|  |   plt.ylabel('The accuracy (%)', fontsize=LabelSize) | ||||||
|  |  | ||||||
|  |   for idx, ss in enumerate(SSs): | ||||||
|  |     legend = 'sample-size={:2d}'.format(ss) | ||||||
|  |     #color, linestyle = color_set[idx // 2], '-' if idx % 2 == 0 else '-.' | ||||||
|  |     color, linestyle = color_set[idx], '-' | ||||||
|  |     plt.plot(indexes, acc_ss_dict[ss], color=color, linestyle=linestyle, label=legend, lw=2, alpha=0.8) | ||||||
|  |     print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, np.mean(acc_ss_dict[ss]), np.std(acc_ss_dict[ss]), np.mean(acc_ss_dict[ss]), np.std(acc_ss_dict[ss]))) | ||||||
|  |   plt.legend(loc=4, fontsize=LegendFontsize) | ||||||
|  |   save_path = root / '{:}-{:}-{:}.pdf'.format(dataset, xset, file_name) | ||||||
|  |   print('save figure into {:}\n'.format(save_path)) | ||||||
|  |   fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf') | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|  |  | ||||||
|   parser = argparse.ArgumentParser(description='NAS-Bench-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter) |   parser = argparse.ArgumentParser(description='NAS-Bench-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||||
| @@ -712,9 +784,25 @@ if __name__ == '__main__': | |||||||
|   #visualize_relative_ranking(vis_save_dir) |   #visualize_relative_ranking(vis_save_dir) | ||||||
|  |  | ||||||
|   api = API(args.api_path) |   api = API(args.api_path) | ||||||
|   show_reinforce(api, vis_save_dir, 'cifar10-valid' , 'x-valid', 'REINFORCE-CIFAR-10', (75, 95, 5)) |   #show_reinforce(api, vis_save_dir, 'cifar10-valid' , 'x-valid', 'REINFORCE-CIFAR-10', (85, 92, 2)) | ||||||
|   import pdb; pdb.set_trace() |   #show_rea      (api, vis_save_dir, 'cifar10-valid' , 'x-valid', 'REA-CIFAR-10', (88, 92, 1)) | ||||||
|  |  | ||||||
|  |   #plot_results_nas_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10'       , 'ori-test'), vis_save_dir, 'nas-com-v2-cifar010.pdf', (85,95, 1)) | ||||||
|  |   #plot_results_nas_v2(api, ('cifar100'      , 'x-valid'), ('cifar100'      , 'x-test'  ), vis_save_dir, 'nas-com-v2-cifar100.pdf', (60,75, 3)) | ||||||
|  |   #plot_results_nas_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test'  ), vis_save_dir, 'nas-com-v2-imagenet.pdf', (35,48, 2)) | ||||||
|  |  | ||||||
|  |   show_nas_sharing_w_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10'       , 'ori-test') , vis_save_dir, 'BN0', 'BN0-DARTS-CIFAR010.pdf', (0, 100,10), 50) | ||||||
|  |   show_nas_sharing_w_v2(api, ('cifar100'      , 'x-valid'), ('cifar100'      , 'x-test'  ) , vis_save_dir, 'BN0', 'BN0-DARTS-CIFAR100.pdf', (0, 100,10), 50) | ||||||
|  |   show_nas_sharing_w_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test'  ) , vis_save_dir, 'BN0', 'BN0-DARTS-ImageNet.pdf', (0, 100,10), 50) | ||||||
|  |  | ||||||
|  |   show_nas_sharing_w_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10'       , 'ori-test') , vis_save_dir, 'BN0', 'BN0-OTHER-CIFAR010.pdf', (0, 100,10), 250) | ||||||
|  |   show_nas_sharing_w_v2(api, ('cifar100'      , 'x-valid'), ('cifar100'      , 'x-test'  ) , vis_save_dir, 'BN0', 'BN0-OTHER-CIFAR100.pdf', (0, 100,10), 250) | ||||||
|  |   show_nas_sharing_w_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test'  ) , vis_save_dir, 'BN0', 'BN0-OTHER-ImageNet.pdf', (0, 100,10), 250) | ||||||
|  |  | ||||||
|  |   show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'BN0', 'BN0-XX-CIFAR010-VALID.pdf', (0, 100,10), 250) | ||||||
|  |   show_nas_sharing_w(api, 'cifar10'       , 'ori-test', vis_save_dir, 'BN0', 'BN0-XX-CIFAR010-TEST.pdf' , (0, 100,10), 250) | ||||||
|  |   import pdb; pdb.set_trace() | ||||||
|  |   """ | ||||||
|   for x_maxs in [50, 250]: |   for x_maxs in [50, 250]: | ||||||
|     show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) |     show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||||
|     show_nas_sharing_w(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) |     show_nas_sharing_w(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||||
| @@ -724,17 +812,11 @@ if __name__ == '__main__': | |||||||
|     show_nas_sharing_w(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) |     show_nas_sharing_w(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||||
|    |    | ||||||
|   show_nas_sharing_w_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10'       , 'ori-test') , vis_save_dir, 'DARTS-CIFAR010.pdf', (0, 100,10), 50) |   show_nas_sharing_w_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10'       , 'ori-test') , vis_save_dir, 'DARTS-CIFAR010.pdf', (0, 100,10), 50) | ||||||
|   show_nas_sharing_w_v2(api, ('cifar100'      , 'x-valid'), ('cifar100'      , 'x-test'  ) , vis_save_dir, 'DARTS-CIFAR100.pdf', (0, 100,10), 50) |   just_show(api) | ||||||
|   show_nas_sharing_w_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test'  ) , vis_save_dir, 'DARTS-ImageNet.pdf', (0, 100,10), 50) |  | ||||||
|   #just_show(api) |  | ||||||
|   """ |  | ||||||
|   plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1)) |   plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1)) | ||||||
|   plot_results_nas(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1)) |   plot_results_nas(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1)) | ||||||
|   plot_results_nas(api, 'cifar100'      , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3)) |   plot_results_nas(api, 'cifar100'      , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3)) | ||||||
|   plot_results_nas(api, 'cifar100'      , 'x-test'  , vis_save_dir, 'nas-com.pdf', (55,75, 3)) |   plot_results_nas(api, 'cifar100'      , 'x-test'  , vis_save_dir, 'nas-com.pdf', (55,75, 3)) | ||||||
|   plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3)) |   plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3)) | ||||||
|   plot_results_nas(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-com.pdf', (35,50, 3)) |   plot_results_nas(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-com.pdf', (35,50, 3)) | ||||||
|   plot_results_nas_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10'       , 'ori-test'), vis_save_dir, 'nas-com-v2-cifar010.pdf', (85,95, 1)) |  | ||||||
|   plot_results_nas_v2(api, ('cifar100'      , 'x-valid'), ('cifar100'      , 'x-test'  ), vis_save_dir, 'nas-com-v2-cifar100.pdf', (60,75, 3)) |  | ||||||
|   plot_results_nas_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test'  ), vis_save_dir, 'nas-com-v2-imagenet.pdf', (35,48, 2)) |  | ||||||
|   """ |   """ | ||||||
|   | |||||||
| @@ -33,13 +33,38 @@ class Model(object): | |||||||
|  |  | ||||||
| # This function is to mimic the training and evaluatinig procedure for a single architecture `arch`. | # This function is to mimic the training and evaluatinig procedure for a single architecture `arch`. | ||||||
| # The time_cost is calculated as the total training time for a few (e.g., 12 epochs) plus the evaluation time for one epoch. | # The time_cost is calculated as the total training time for a few (e.g., 12 epochs) plus the evaluation time for one epoch. | ||||||
| def train_and_eval(arch, nas_bench, extra_info): | # For use_converged_LR = True, the architecture is trained for 12 epochs, with LR being decaded from 0.1 to 0. | ||||||
|   if nas_bench is not None: | #       In this case, the LR schedular is converged. | ||||||
|  | # For use_converged_LR = False, the architecture is planed to be trained for 200 epochs, but we early stop its procedure. | ||||||
|  | #        | ||||||
|  | def train_and_eval(arch, nas_bench, extra_info, dataname='cifar10-valid', use_converged_LR=True): | ||||||
|  |   if use_converged_LR and nas_bench is not None: | ||||||
|     arch_index = nas_bench.query_index_by_arch( arch ) |     arch_index = nas_bench.query_index_by_arch( arch ) | ||||||
|     assert arch_index >= 0, 'can not find this arch : {:}'.format(arch) |     assert arch_index >= 0, 'can not find this arch : {:}'.format(arch) | ||||||
|     info = nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True) |     info = nas_bench.get_more_info(arch_index, dataname, None, True) | ||||||
|     valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time'] |     valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time'] | ||||||
|     #_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs |     #_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs | ||||||
|  |   elif not use_converged_LR and nas_bench is not None: | ||||||
|  |     # Please use `use_converged_LR=False` for cifar10 only. | ||||||
|  |     # It did return values for cifar100 and ImageNet16-120, but it has some potential issues. (Please email me for more details) | ||||||
|  |     arch_index, nepoch = nas_bench.query_index_by_arch( arch ), 25 | ||||||
|  |     assert arch_index >= 0, 'can not find this arch : {:}'.format(arch) | ||||||
|  |     xoinfo = nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True) | ||||||
|  |     xocost = nas_bench.get_cost_info(arch_index, 'cifar10-valid', False) | ||||||
|  |     info = nas_bench.get_more_info(arch_index, dataname, nepoch, False, True) # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready). | ||||||
|  |     cost = nas_bench.get_cost_info(arch_index, dataname, False) | ||||||
|  |     # The following codes are used to estimate the time cost. | ||||||
|  |     # When we build NAS-Bench-201, architectures are trained on different machines and we can not use that time record. | ||||||
|  |     # When we create checkpoints for converged_LR, we run all experiments on 1080Ti, and thus the time for each architecture can be fairly compared. | ||||||
|  |     nums = {'ImageNet16-120-train': 151700, 'ImageNet16-120-valid': 3000, | ||||||
|  |             'cifar10-valid-train' : 25000,  'cifar10-valid-valid' : 25000, | ||||||
|  |             'cifar100-train'      : 50000,  'cifar100-valid'      : 5000} | ||||||
|  |     estimated_train_cost = xoinfo['train-per-time'] / nums['cifar10-valid-train'] * nums['{:}-train'.format(dataname)] / xocost['latency'] * cost['latency'] * nepoch | ||||||
|  |     estimated_valid_cost = xoinfo['valid-per-time'] / nums['cifar10-valid-valid'] * nums['{:}-valid'.format(dataname)] / xocost['latency'] * cost['latency'] | ||||||
|  |     try: | ||||||
|  |       valid_acc, time_cost = info['valid-accuracy'], estimated_train_cost + estimated_valid_cost | ||||||
|  |     except: | ||||||
|  |       valid_acc, time_cost = info['est-valid-accuracy'], estimated_train_cost + estimated_valid_cost | ||||||
|   else: |   else: | ||||||
|     # train a model from scratch. |     # train a model from scratch. | ||||||
|     raise ValueError('NOT IMPLEMENT YET') |     raise ValueError('NOT IMPLEMENT YET') | ||||||
| @@ -79,7 +104,7 @@ def mutate_arch_func(op_names): | |||||||
|   return mutate_arch_func |   return mutate_arch_func | ||||||
|  |  | ||||||
|  |  | ||||||
| def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, nas_bench, extra_info): | def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, nas_bench, extra_info, dataname): | ||||||
|   """Algorithm for regularized evolution (i.e. aging evolution). |   """Algorithm for regularized evolution (i.e. aging evolution). | ||||||
|    |    | ||||||
|   Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image |   Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image | ||||||
| @@ -150,6 +175,10 @@ def main(xargs, nas_bench): | |||||||
|   logger = prepare_logger(args) |   logger = prepare_logger(args) | ||||||
|  |  | ||||||
|   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' |   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' | ||||||
|  |   if xargs.dataset == 'cifar10': | ||||||
|  |     dataname = 'cifar10-valid' | ||||||
|  |   else: | ||||||
|  |     dataname = xargs.dataset | ||||||
|   if xargs.data_path is not None: |   if xargs.data_path is not None: | ||||||
|     train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) |     train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||||
|     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' |     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||||
| @@ -182,7 +211,7 @@ def main(xargs, nas_bench): | |||||||
|   x_start_time = time.time() |   x_start_time = time.time() | ||||||
|   logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench)) |   logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench)) | ||||||
|   logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget)) |   logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget)) | ||||||
|   history, total_cost = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, nas_bench if args.ea_fast_by_api else None, extra_info) |   history, total_cost = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, nas_bench if args.ea_fast_by_api else None, extra_info, dataname) | ||||||
|   logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).'.format(time_string(), len(history), total_cost, time.time()-x_start_time)) |   logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).'.format(time_string(), len(history), total_cost, time.time()-x_start_time)) | ||||||
|   best_arch = max(history, key=lambda i: i.accuracy) |   best_arch = max(history, key=lambda i: i.accuracy) | ||||||
|   best_arch = best_arch.arch |   best_arch = best_arch.arch | ||||||
|   | |||||||
| @@ -162,6 +162,13 @@ class NASBench201API(object): | |||||||
|     archresult = arch2infos[index] |     archresult = arch2infos[index] | ||||||
|     return archresult.get_net_param(dataset, seed) |     return archresult.get_net_param(dataset, seed) | ||||||
|  |  | ||||||
|  |   # obtain the cost metric for the `index`-th architecture on a dataset | ||||||
|  |   def get_cost_info(self, index, dataset, use_12epochs_result=False): | ||||||
|  |     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||||
|  |     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||||
|  |     archresult = arch2infos[index] | ||||||
|  |     return archresult.get_comput_costs(dataset) | ||||||
|  |  | ||||||
|   # obtain the metric for the `index`-th architecture |   # obtain the metric for the `index`-th architecture | ||||||
|   def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False, is_random=True): |   def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False, is_random=True): | ||||||
|     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less |     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||||
| @@ -177,6 +184,7 @@ class NASBench201API(object): | |||||||
|       total      = train_info['iepoch'] + 1 |       total      = train_info['iepoch'] + 1 | ||||||
|       xifo = {'train-loss'    : train_info['loss'], |       xifo = {'train-loss'    : train_info['loss'], | ||||||
|               'train-accuracy': train_info['accuracy'], |               'train-accuracy': train_info['accuracy'], | ||||||
|  |               'train-per-time': None if train_info['all_time'] is None else train_info['all_time'] / total, | ||||||
|               'train-all-time': train_info['all_time'], |               'train-all-time': train_info['all_time'], | ||||||
|               'valid-loss'    : valid_info['loss'], |               'valid-loss'    : valid_info['loss'], | ||||||
|               'valid-accuracy': valid_info['accuracy'], |               'valid-accuracy': valid_info['accuracy'], | ||||||
| @@ -188,21 +196,32 @@ class NASBench201API(object): | |||||||
|       return xifo |       return xifo | ||||||
|     else: |     else: | ||||||
|       train_info = archresult.get_metrics(dataset, 'train'   , iepoch=iepoch, is_random=is_random) |       train_info = archresult.get_metrics(dataset, 'train'   , iepoch=iepoch, is_random=is_random) | ||||||
|       if dataset == 'cifar10': |       try: | ||||||
|         test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) |         if dataset == 'cifar10': | ||||||
|       else: |           test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) | ||||||
|         test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random) |         else: | ||||||
|  |           test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random) | ||||||
|  |       except: | ||||||
|  |         valid_info = None | ||||||
|       try: |       try: | ||||||
|         valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random) |         valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random) | ||||||
|       except: |       except: | ||||||
|         valid_info = None |         valid_info = None | ||||||
|  |       try: | ||||||
|  |         est_valid_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) | ||||||
|  |       except: | ||||||
|  |         est_valid_info = None | ||||||
|       xifo = {'train-loss'    : train_info['loss'], |       xifo = {'train-loss'    : train_info['loss'], | ||||||
|               'train-accuracy': train_info['accuracy'], |               'train-accuracy': train_info['accuracy']} | ||||||
|               'test-loss'     : test__info['loss'], |       if valid_info is not None: | ||||||
|               'test-accuracy' : test__info['accuracy']} |         xifo['test-loss'] = test__info['loss'], | ||||||
|  |         xifo['test-accuracy'] = test__info['accuracy'] | ||||||
|       if valid_info is not None: |       if valid_info is not None: | ||||||
|         xifo['valid-loss'] = valid_info['loss'] |         xifo['valid-loss'] = valid_info['loss'] | ||||||
|         xifo['valid-accuracy'] = valid_info['accuracy'] |         xifo['valid-accuracy'] = valid_info['accuracy'] | ||||||
|  |       if est_valid_info is not None: | ||||||
|  |         xifo['est-valid-loss'] = est_valid_info['loss'] | ||||||
|  |         xifo['est-valid-accuracy'] = est_valid_info['accuracy'] | ||||||
|       return xifo |       return xifo | ||||||
|  |  | ||||||
|   def show(self, index=-1): |   def show(self, index=-1): | ||||||
|   | |||||||
| @@ -1,7 +1,8 @@ | |||||||
| #!/bin/bash | #!/bin/bash | ||||||
| echo script name: $0 | echo script name: $0 | ||||||
|  |  | ||||||
| lrs="0.01 0.02 0.1 0.2 0.5 1.0 1.5 2.0 2.5 3.0" | #lrs="0.01 0.02 0.1 0.2 0.5 1.0 1.5 2.0 2.5 3.0" | ||||||
|  | lrs="0.01 0.02 0.1 0.2 0.5" | ||||||
|  |  | ||||||
| for lr in ${lrs} | for lr in ${lrs} | ||||||
| do | do | ||||||
|   | |||||||
| @@ -1,11 +1,11 @@ | |||||||
| #!/bin/bash | #!/bin/bash | ||||||
| # Regularized Evolution for Image Classifier Architecture Search, AAAI 2019 | # Regularized Evolution for Image Classifier Architecture Search, AAAI 2019 | ||||||
| # bash ./scripts-search/algos/R-EA.sh -1 | # bash ./scripts-search/algos/R-EA.sh cifar10 3 -1 | ||||||
| echo script name: $0 | echo script name: $0 | ||||||
| echo $# arguments | echo $# arguments | ||||||
| if [ "$#" -ne 1 ] ;then | if [ "$#" -ne 3 ] ;then | ||||||
|   echo "Input illegal number of parameters " $# |   echo "Input illegal number of parameters " $# | ||||||
|   echo "Need 1 parameters for seed" |   echo "Need 3 parameters for the-dataset-name, the-ea-sample-size and the-seed" | ||||||
|   exit 1 |   exit 1 | ||||||
| fi | fi | ||||||
| if [ "$TORCH_HOME" = "" ]; then | if [ "$TORCH_HOME" = "" ]; then | ||||||
| @@ -15,14 +15,16 @@ else | |||||||
|   echo "TORCH_HOME : $TORCH_HOME" |   echo "TORCH_HOME : $TORCH_HOME" | ||||||
| fi | fi | ||||||
|  |  | ||||||
| dataset=cifar10 | #dataset=cifar10 | ||||||
| seed=$1 | dataset=$1 | ||||||
|  | sample_size=$2 | ||||||
|  | seed=$3 | ||||||
| channel=16 | channel=16 | ||||||
| num_cells=5 | num_cells=5 | ||||||
| max_nodes=4 | max_nodes=4 | ||||||
| space=nas-bench-201 | space=nas-bench-201 | ||||||
|  |  | ||||||
| save_dir=./output/search-cell-${space}/R-EA-${dataset} | save_dir=./output/search-cell-${space}/R-EA-${dataset}-SS${sample_size} | ||||||
|  |  | ||||||
| OMP_NUM_THREADS=4 python ./exps/algos/R_EA.py \ | OMP_NUM_THREADS=4 python ./exps/algos/R_EA.py \ | ||||||
| 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | ||||||
| @@ -30,5 +32,5 @@ OMP_NUM_THREADS=4 python ./exps/algos/R_EA.py \ | |||||||
| 	--search_space_name ${space} \ | 	--search_space_name ${space} \ | ||||||
| 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \ | 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \ | ||||||
| 	--time_budget 12000 \ | 	--time_budget 12000 \ | ||||||
| 	--ea_cycles 100 --ea_population 10 --ea_sample_size 3 --ea_fast_by_api 1 \ | 	--ea_cycles 200 --ea_population 10 --ea_sample_size ${sample_size} --ea_fast_by_api 1 \ | ||||||
| 	--workers 4 --print_freq 200 --rand_seed ${seed} | 	--workers 4 --print_freq 200 --rand_seed ${seed} | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user