update README

D-X-Y 2020-01-16 01:43:07 +11:00
parent 4be2a0000c
commit b299945b23
8 changed files with 205 additions and 50 deletions

View File

@@ -39,7 +39,7 @@ At the moment, this project provides the following algorithms and scripts to run
 <tr> <!-- (2-nd row) -->
   <td align="center" valign="middle"> DARTS </td>
   <td align="center" valign="middle"> DARTS: Differentiable Architecture Search </td>
-  <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/master/docs/NAS-Bench-201.md">NAS-Bench-201.md</a> </td>
+  <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/master/docs/ICLR-2019-DARTS.md">ICLR-2019-DARTS.md</a> </td>
 </tr>
 <tr> <!-- (3-nd row) -->
   <td align="center" valign="middle"> GDAS </td>

docs/ICLR-2019-DARTS.md (new file, 22 lines added)
View File

@@ -0,0 +1,22 @@
# DARTS: Differentiable Architecture Search
DARTS: Differentiable Architecture Search was accepted at ICLR 2019.
In this paper, Hanxiao Liu et al. propose a differentiable neural architecture search method named DARTS.
DARTS has recently become very popular due to its simplicity and strong performance.
**Run DARTS on the NAS-Bench-201 search space**:
```
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 1 -1
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 1 -1
```
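
For background, the differentiable part of DARTS comes from relaxing the discrete choice of operation on each edge into a softmax-weighted mixture, so the architecture parameters can be optimized by gradient descent together with the network weights. The snippet below is a minimal illustrative sketch of that relaxation only; it is not code from this repository, and the candidate operations and tensor shapes are arbitrary.
```
import torch
import torch.nn as nn

class MixedOp(nn.Module):
  """Softmax-weighted mixture over candidate operations (the DARTS relaxation)."""
  def __init__(self, candidate_ops):
    super(MixedOp, self).__init__()
    self.ops   = nn.ModuleList(candidate_ops)
    self.alpha = nn.Parameter(torch.zeros(len(candidate_ops)))  # architecture parameters

  def forward(self, x):
    weights = torch.softmax(self.alpha, dim=0)                  # differentiable w.r.t. alpha
    return sum(w * op(x) for w, op in zip(weights, self.ops))

# toy usage: one edge mixing a 3x3 convolution and a skip connection
edge = MixedOp([nn.Conv2d(16, 16, 3, padding=1), nn.Identity()])
out  = edge(torch.randn(2, 16, 8, 8))
print(out.shape)  # torch.Size([2, 16, 8, 8])
```
In the bi-level DARTS setup, `alpha` is updated on validation data while the operation weights are updated on training data.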
# Citation
```
@inproceedings{liu2019darts,
title = {{DARTS}: Differentiable architecture search},
author = {Liu, Hanxiao and Simonyan, Karen and Yang, Yiming},
booktitle = {International Conference on Learning Representations (ICLR)},
year = {2019}
}
```

View File

@@ -181,7 +181,7 @@ If researchers can provide better results with different hyper-parameters, we ar
 - [4] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/SETN.sh cifar10 1 -1`
 - [5] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/ENAS.sh cifar10 1 -1`
 - [6] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/RANDOM-NAS.sh cifar10 1 -1`
-- [7] `bash ./scripts-search/algos/R-EA.sh -1`
+- [7] `bash ./scripts-search/algos/R-EA.sh cifar10 3 -1`
 - [8] `bash ./scripts-search/algos/Random.sh -1`
 - [9] `bash ./scripts-search/algos/REINFORCE.sh 0.5 -1`
 - [10] `bash ./scripts-search/algos/BOHB.sh -1`

View File

@@ -517,7 +517,7 @@ def just_show(api):
     print ('[{:10s}-{:10s} ::: index={:5d}, accuracy={:.2f}'.format(dataset, metric_on_set, arch_index, highest_acc))
 
 
-def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_maxs):
+def show_nas_sharing_w(api, dataset, subset, vis_save_dir, sufix, file_name, y_lims, x_maxs):
   color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
   dpi, width, height = 300, 3400, 2600
   LabelSize, LegendFontsize = 28, 28
@@ -533,13 +533,14 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_
   plt.xlabel('The searching epoch', fontsize=LabelSize)
   plt.ylabel('The accuracy (%)', fontsize=LabelSize)
-  xpaths = {'RSPS'    : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10/checkpoint/',
-            'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/',
-            'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10/checkpoint/',
-            'GDAS'    : 'output/search-cell-nas-bench-201/GDAS-cifar10/checkpoint/',
-            'SETN'    : 'output/search-cell-nas-bench-201/SETN-cifar10/checkpoint/',
-            'ENAS'    : 'output/search-cell-nas-bench-201/ENAS-cifar10/checkpoint/',
+  xpaths = {'RSPS'    : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10-{:}/checkpoint/'.format(sufix),
+            'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10-{:}/checkpoint/'.format(sufix),
+            'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10-{:}/checkpoint/'.format(sufix),
+            'GDAS'    : 'output/search-cell-nas-bench-201/GDAS-cifar10-{:}/checkpoint/'.format(sufix),
+            'SETN'    : 'output/search-cell-nas-bench-201/SETN-cifar10-{:}/checkpoint/'.format(sufix),
+            'ENAS'    : 'output/search-cell-nas-bench-201/ENAS-cifar10-{:}/checkpoint/'.format(sufix),
             }
+  """
   xseeds = {'RSPS'    : [5349, 59613, 5983],
             'DARTS-V1': [11416, 72873, 81184, 28640],
             'DARTS-V2': [43330, 79405, 79423],
@@ -547,6 +548,15 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_
             'SETN'    : [20518, 61817, 89144],
             'ENAS'    : [3231, 34238, 96929],
             }
+  """
+  xseeds = {'RSPS'    : [23814, 28015, 95809],
+            'DARTS-V1': [48349, 80877, 81920],
+            'DARTS-V2': [61712, 7941, 87041],
+            'GDAS'    : [72818, 72996, 78877],
+            'SETN'    : [26985, 55206, 95404],
+            'ENAS'    : [21792, 36605, 45029]
+            }
 
   def get_accs(xdata):
     epochs, xresults = xdata['epoch'], []
@@ -579,12 +589,13 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_
     plt.fill_between(epochs, [accyss[:,i].mean()-accyss[:,i].std() for i in epochs], [accyss[:,i].mean()+accyss[:,i].std() for i in epochs], alpha=0.2, color=color_set[idx])
   #plt.legend(loc=4, fontsize=LegendFontsize)
   plt.legend(loc=0, fontsize=LegendFontsize)
-  save_path = vis_save_dir / '{:}-{:}-{:}-{:}'.format(xox, dataset, subset, file_name)
+  save_path = vis_save_dir / '{:}.pdf'.format(file_name)
   print('save figure into {:}\n'.format(save_path))
   fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
 
 
-def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, file_name, y_lims, x_maxs):
+def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, sufix, file_name, y_lims, x_maxs):
   color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
   dpi, width, height = 300, 3400, 2600
   LabelSize, LegendFontsize = 28, 28
@@ -600,13 +611,14 @@ def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, file_name,
   plt.xlabel('The searching epoch', fontsize=LabelSize)
   plt.ylabel('The accuracy (%)', fontsize=LabelSize)
-  xpaths = {'RSPS'    : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10/checkpoint/',
-            'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/',
-            'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10/checkpoint/',
-            'GDAS'    : 'output/search-cell-nas-bench-201/GDAS-cifar10/checkpoint/',
-            'SETN'    : 'output/search-cell-nas-bench-201/SETN-cifar10/checkpoint/',
-            'ENAS'    : 'output/search-cell-nas-bench-201/ENAS-cifar10/checkpoint/',
+  xpaths = {'RSPS'    : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10-{:}/checkpoint/'.format(sufix),
+            'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10-{:}/checkpoint/'.format(sufix),
+            'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10-{:}/checkpoint/'.format(sufix),
+            'GDAS'    : 'output/search-cell-nas-bench-201/GDAS-cifar10-{:}/checkpoint/'.format(sufix),
+            'SETN'    : 'output/search-cell-nas-bench-201/SETN-cifar10-{:}/checkpoint/'.format(sufix),
+            'ENAS'    : 'output/search-cell-nas-bench-201/ENAS-cifar10-{:}/checkpoint/'.format(sufix),
             }
+  """
   xseeds = {'RSPS'    : [5349, 59613, 5983],
             'DARTS-V1': [11416, 72873, 81184, 28640],
             'DARTS-V2': [43330, 79405, 79423],
@@ -614,6 +626,15 @@ def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, file_name,
             'SETN'    : [20518, 61817, 89144],
             'ENAS'    : [3231, 34238, 96929],
             }
+  """
+  xseeds = {'RSPS'    : [23814, 28015, 95809],
+            'DARTS-V1': [48349, 80877, 81920],
+            'DARTS-V2': [61712, 7941, 87041],
+            'GDAS'    : [72818, 72996, 78877],
+            'SETN'    : [26985, 55206, 95404],
+            'ENAS'    : [21792, 36605, 45029]
+            }
 
   def get_accs(xdata, dataset, subset):
     epochs, xresults = xdata['epoch'], []
@@ -643,8 +664,15 @@ def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, file_name,
     accyss_B = np.array( [get_accs(xdatas, data_sub_b[0], data_sub_b[1]) for xdatas in all_datas] )
     epochs = list(range(accyss_A.shape[1]))
     for j, accyss in enumerate([accyss_A, accyss_B]):
-      plt.plot(epochs, [accyss[:,i].mean() for i in epochs], color=color_set[idx*2+j], linestyle='-' if j==0 else '--', label='{:} ({:})'.format(method, 'VALID' if j == 0 else 'TEST'), lw=2, alpha=0.9)
-      plt.fill_between(epochs, [accyss[:,i].mean()-accyss[:,i].std() for i in epochs], [accyss[:,i].mean()+accyss[:,i].std() for i in epochs], alpha=0.2, color=color_set[idx*2+j])
+      if x_maxs == 50:
+        color, line = color_set[idx*2+j], '-' if j==0 else '--'
+      elif x_maxs == 250:
+        color, line = color_set[idx], '-' if j==0 else '--'
+      else: raise ValueError('invalid x-maxs={:}'.format(x_maxs))
+      plt.plot(epochs, [accyss[:,i].mean() for i in epochs], color=color, linestyle=line, label='{:} ({:})'.format(method, 'VALID' if j == 0 else 'TEST'), lw=2, alpha=0.9)
+      plt.fill_between(epochs, [accyss[:,i].mean()-accyss[:,i].std() for i in epochs], [accyss[:,i].mean()+accyss[:,i].std() for i in epochs], alpha=0.2, color=color)
+      setname = data_sub_a if j == 0 else data_sub_b
+      print('{:} -- {:} ---- {:.2f}$\\pm${:.2f}'.format(method, setname, accyss[:,-1].mean(), accyss[:,-1].std()))
   #plt.legend(loc=4, fontsize=LegendFontsize)
   plt.legend(loc=0, fontsize=LegendFontsize)
   save_path = vis_save_dir / '{:}-{:}'.format(xox, file_name)
@@ -654,7 +682,7 @@ def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, file_name,
 
 def show_reinforce(api, root, dataset, xset, file_name, y_lims):
   print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset))
-  LRs = ['0.01', '0.02', '0.1', '0.2', '0.5', '1.0', '1.5', '2.0', '2.5', '3.0']
+  LRs = ['0.01', '0.02', '0.1', '0.2', '0.5']
   checkpoints = ['./output/search-cell-nas-bench-201/REINFORCE-cifar10-{:}/results.pth'.format(x) for x in LRs]
   acc_lr_dict, indexes = {}, None
   for lr, checkpoint in zip(LRs, checkpoints):
@@ -684,7 +712,8 @@ def show_reinforce(api, root, dataset, xset, file_name, y_lims):
   for idx, LR in enumerate(LRs):
     legend = 'LR={:.2f}'.format(float(LR))
-    color, linestyle = color_set[idx // 2], '-' if idx % 2 == 0 else '-.'
+    #color, linestyle = color_set[idx // 2], '-' if idx % 2 == 0 else '-.'
+    color, linestyle = color_set[idx], '-'
     plt.plot(indexes, acc_lr_dict[LR], color=color, linestyle=linestyle, label=legend, lw=2, alpha=0.8)
     print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, np.mean(acc_lr_dict[LR]), np.std(acc_lr_dict[LR]), np.mean(acc_lr_dict[LR]), np.std(acc_lr_dict[LR])))
   plt.legend(loc=4, fontsize=LegendFontsize)
@@ -693,6 +722,49 @@ def show_reinforce(api, root, dataset, xset, file_name, y_lims):
   fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
 
+
+def show_rea(api, root, dataset, xset, file_name, y_lims):
+  print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset))
+  SSs = [3, 5, 10]
+  checkpoints = ['./output/search-cell-nas-bench-201/R-EA-cifar10-SS{:}/results.pth'.format(x) for x in SSs]
+  acc_ss_dict, indexes = {}, None
+  for ss, checkpoint in zip(SSs, checkpoints):
+    all_indexes, accuracies = torch.load(checkpoint, map_location='cpu'), []
+    for x in all_indexes:
+      info = api.arch2infos_full[ x ]
+      metrics = info.get_metrics(dataset, xset, None, False)
+      accuracies.append( metrics['accuracy'] )
+    if indexes is None: indexes = list(range(len(accuracies)))
+    acc_ss_dict[ss] = np.array( sorted(accuracies) )
+    print ('Sample-Size={:2d}, mean={:}, std={:}'.format(ss, acc_ss_dict[ss].mean(), acc_ss_dict[ss].std()))
+
+  color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
+  dpi, width, height = 300, 3400, 2600
+  LabelSize, LegendFontsize = 28, 22
+  figsize = width / float(dpi), height / float(dpi)
+  fig = plt.figure(figsize=figsize)
+  x_axis = np.arange(0, 600)
+  plt.xlim(0, max(indexes))
+  plt.ylim(y_lims[0], y_lims[1])
+  interval_x, interval_y = 100, y_lims[2]
+  plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize)
+  plt.yticks(np.arange(y_lims[0], y_lims[1], interval_y), fontsize=LegendFontsize)
+  plt.grid()
+  plt.xlabel('The index of runs', fontsize=LabelSize)
+  plt.ylabel('The accuracy (%)', fontsize=LabelSize)
+  for idx, ss in enumerate(SSs):
+    legend = 'sample-size={:2d}'.format(ss)
+    #color, linestyle = color_set[idx // 2], '-' if idx % 2 == 0 else '-.'
+    color, linestyle = color_set[idx], '-'
+    plt.plot(indexes, acc_ss_dict[ss], color=color, linestyle=linestyle, label=legend, lw=2, alpha=0.8)
+    print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, np.mean(acc_ss_dict[ss]), np.std(acc_ss_dict[ss]), np.mean(acc_ss_dict[ss]), np.std(acc_ss_dict[ss])))
+  plt.legend(loc=4, fontsize=LegendFontsize)
+  save_path = root / '{:}-{:}-{:}.pdf'.format(dataset, xset, file_name)
+  print('save figure into {:}\n'.format(save_path))
+  fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
+
 
 if __name__ == '__main__':
   parser = argparse.ArgumentParser(description='NAS-Bench-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
@@ -712,9 +784,25 @@ if __name__ == '__main__':
   #visualize_relative_ranking(vis_save_dir)
 
   api = API(args.api_path)
-  show_reinforce(api, vis_save_dir, 'cifar10-valid' , 'x-valid', 'REINFORCE-CIFAR-10', (75, 95, 5))
-  import pdb; pdb.set_trace()
+  #show_reinforce(api, vis_save_dir, 'cifar10-valid' , 'x-valid', 'REINFORCE-CIFAR-10', (85, 92, 2))
+  #show_rea      (api, vis_save_dir, 'cifar10-valid' , 'x-valid', 'REA-CIFAR-10', (88, 92, 1))
+
+  #plot_results_nas_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10'       , 'ori-test'), vis_save_dir, 'nas-com-v2-cifar010.pdf', (85,95, 1))
+  #plot_results_nas_v2(api, ('cifar100'      , 'x-valid'), ('cifar100'      , 'x-test' ), vis_save_dir, 'nas-com-v2-cifar100.pdf', (60,75, 3))
+  #plot_results_nas_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test' ), vis_save_dir, 'nas-com-v2-imagenet.pdf', (35,48, 2))
+
+  show_nas_sharing_w_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10'       , 'ori-test'), vis_save_dir, 'BN0', 'BN0-DARTS-CIFAR010.pdf', (0, 100,10), 50)
+  show_nas_sharing_w_v2(api, ('cifar100'      , 'x-valid'), ('cifar100'      , 'x-test' ), vis_save_dir, 'BN0', 'BN0-DARTS-CIFAR100.pdf', (0, 100,10), 50)
+  show_nas_sharing_w_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test' ), vis_save_dir, 'BN0', 'BN0-DARTS-ImageNet.pdf', (0, 100,10), 50)
+  show_nas_sharing_w_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10'       , 'ori-test'), vis_save_dir, 'BN0', 'BN0-OTHER-CIFAR010.pdf', (0, 100,10), 250)
+  show_nas_sharing_w_v2(api, ('cifar100'      , 'x-valid'), ('cifar100'      , 'x-test' ), vis_save_dir, 'BN0', 'BN0-OTHER-CIFAR100.pdf', (0, 100,10), 250)
+  show_nas_sharing_w_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test' ), vis_save_dir, 'BN0', 'BN0-OTHER-ImageNet.pdf', (0, 100,10), 250)
+  show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'BN0', 'BN0-XX-CIFAR010-VALID.pdf', (0, 100,10), 250)
+  show_nas_sharing_w(api, 'cifar10'       , 'ori-test', vis_save_dir, 'BN0', 'BN0-XX-CIFAR010-TEST.pdf' , (0, 100,10), 250)
+  import pdb; pdb.set_trace()
+  """
   for x_maxs in [50, 250]:
     show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
     show_nas_sharing_w(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
@@ -724,17 +812,11 @@
     show_nas_sharing_w(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
 
   show_nas_sharing_w_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10'       , 'ori-test') , vis_save_dir, 'DARTS-CIFAR010.pdf', (0, 100,10), 50)
-  show_nas_sharing_w_v2(api, ('cifar100'      , 'x-valid'), ('cifar100'      , 'x-test'  ) , vis_save_dir, 'DARTS-CIFAR100.pdf', (0, 100,10), 50)
-  show_nas_sharing_w_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test'  ) , vis_save_dir, 'DARTS-ImageNet.pdf', (0, 100,10), 50)
-  #just_show(api)
-  """
+  just_show(api)
   plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1))
   plot_results_nas(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1))
   plot_results_nas(api, 'cifar100'      , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
   plot_results_nas(api, 'cifar100'      , 'x-test'  , vis_save_dir, 'nas-com.pdf', (55,75, 3))
   plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
   plot_results_nas(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-com.pdf', (35,50, 3))
-  plot_results_nas_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10'  , 'ori-test'), vis_save_dir, 'nas-com-v2-cifar010.pdf', (85,95, 1))
-  plot_results_nas_v2(api, ('cifar100'      , 'x-valid'), ('cifar100' , 'x-test' ), vis_save_dir, 'nas-com-v2-cifar100.pdf', (60,75, 3))
-  plot_results_nas_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test' ), vis_save_dir, 'nas-com-v2-imagenet.pdf', (35,48, 2))
   """

View File

@@ -33,13 +33,38 @@ class Model(object):
 # This function is to mimic the training and evaluating procedure for a single architecture `arch`.
 # The time_cost is calculated as the total training time for a few epochs (e.g., 12) plus the evaluation time for one epoch.
-def train_and_eval(arch, nas_bench, extra_info):
-  if nas_bench is not None:
+# For use_converged_LR = True, the architecture is trained for 12 epochs, with the LR decayed from 0.1 to 0.
+# In this case, the LR scheduler is converged.
+# For use_converged_LR = False, the architecture is planned to be trained for 200 epochs, but we early stop its procedure.
+#
+def train_and_eval(arch, nas_bench, extra_info, dataname='cifar10-valid', use_converged_LR=True):
+  if use_converged_LR and nas_bench is not None:
     arch_index = nas_bench.query_index_by_arch( arch )
     assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
-    info = nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True)
+    info = nas_bench.get_more_info(arch_index, dataname, None, True)
     valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
     #_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs
+  elif not use_converged_LR and nas_bench is not None:
+    # Please use `use_converged_LR=False` for cifar10 only.
+    # It does return values for cifar100 and ImageNet16-120, but there are some potential issues (please email me for more details).
+    arch_index, nepoch = nas_bench.query_index_by_arch( arch ), 25
+    assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
+    xoinfo = nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True)
+    xocost = nas_bench.get_cost_info(arch_index, 'cifar10-valid', False)
+    info   = nas_bench.get_more_info(arch_index, dataname, nepoch, False, True)  # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready)
+    cost   = nas_bench.get_cost_info(arch_index, dataname, False)
+    # The following code is used to estimate the time cost.
+    # When we built NAS-Bench-201, architectures were trained on different machines, so those time records cannot be used directly.
+    # When we create checkpoints for converged_LR, we run all experiments on 1080Ti, and thus the time for each architecture can be fairly compared.
+    nums = {'ImageNet16-120-train': 151700, 'ImageNet16-120-valid': 3000,
+            'cifar10-valid-train' : 25000,  'cifar10-valid-valid' : 25000,
+            'cifar100-train'      : 50000,  'cifar100-valid'      : 5000}
+    estimated_train_cost = xoinfo['train-per-time'] / nums['cifar10-valid-train'] * nums['{:}-train'.format(dataname)] / xocost['latency'] * cost['latency'] * nepoch
+    estimated_valid_cost = xoinfo['valid-per-time'] / nums['cifar10-valid-valid'] * nums['{:}-valid'.format(dataname)] / xocost['latency'] * cost['latency']
+    try:
+      valid_acc, time_cost = info['valid-accuracy'], estimated_train_cost + estimated_valid_cost
+    except:
+      valid_acc, time_cost = info['est-valid-accuracy'], estimated_train_cost + estimated_valid_cost
   else:
     # train a model from scratch.
     raise ValueError('NOT IMPLEMENT YET')
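
To make the time-cost estimate in this hunk concrete: the per-sample training and validation times measured on `cifar10-valid` are rescaled by the target dataset's split sizes and by the latency ratio between the two settings, and the training part is multiplied by the number of epochs. The sketch below mirrors that arithmetic only; it is not part of the commit, and the example numbers are made up.
```
# Illustrative only: mirrors the scaling used by `estimated_train_cost` / `estimated_valid_cost` above.
def estimate_costs(train_per_time, valid_per_time, ref_latency, tgt_latency,
                   ref_train=25000, ref_valid=25000, tgt_train=50000, tgt_valid=5000, nepoch=25):
  per_sample_train = train_per_time / ref_train   # seconds per training sample on cifar10-valid
  per_sample_valid = valid_per_time / ref_valid   # seconds per validation sample on cifar10-valid
  latency_ratio    = tgt_latency / ref_latency    # the same architecture may run slower or faster on the target data
  est_train = per_sample_train * tgt_train * latency_ratio * nepoch
  est_valid = per_sample_valid * tgt_valid * latency_ratio
  return est_train, est_valid

# e.g., rescaling hypothetical cifar10-valid timings to cifar100 (50000 train / 5000 valid images)
print(estimate_costs(train_per_time=12.0, valid_per_time=1.5, ref_latency=0.012, tgt_latency=0.013))
```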
@@ -79,7 +104,7 @@ def mutate_arch_func(op_names):
   return mutate_arch_func
 
-def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, nas_bench, extra_info):
+def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, nas_bench, extra_info, dataname):
   """Algorithm for regularized evolution (i.e. aging evolution).
 
   Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image
@@ -150,6 +175,10 @@ def main(xargs, nas_bench):
   logger = prepare_logger(args)
 
   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
+  if xargs.dataset == 'cifar10':
+    dataname = 'cifar10-valid'
+  else:
+    dataname = xargs.dataset
   if xargs.data_path is not None:
     train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
     split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
@@ -182,7 +211,7 @@ def main(xargs, nas_bench):
   x_start_time = time.time()
   logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
   logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget))
-  history, total_cost = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, nas_bench if args.ea_fast_by_api else None, extra_info)
+  history, total_cost = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, nas_bench if args.ea_fast_by_api else None, extra_info, dataname)
   logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).'.format(time_string(), len(history), total_cost, time.time()-x_start_time))
   best_arch = max(history, key=lambda i: i.accuracy)
   best_arch = best_arch.arch

View File

@@ -162,6 +162,13 @@ class NASBench201API(object):
     archresult = arch2infos[index]
     return archresult.get_net_param(dataset, seed)
 
+  # obtain the cost metric for the `index`-th architecture on a dataset
+  def get_cost_info(self, index, dataset, use_12epochs_result=False):
+    if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
+    else                  : basestr, arch2infos = '200epochs', self.arch2infos_full
+    archresult = arch2infos[index]
+    return archresult.get_comput_costs(dataset)
+
   # obtain the metric for the `index`-th architecture
   def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
@@ -177,6 +184,7 @@ class NASBench201API(object):
       total = train_info['iepoch'] + 1
       xifo = {'train-loss'    : train_info['loss'],
               'train-accuracy': train_info['accuracy'],
+              'train-per-time': None if train_info['all_time'] is None else train_info['all_time'] / total,
               'train-all-time': train_info['all_time'],
               'valid-loss'    : valid_info['loss'],
               'valid-accuracy': valid_info['accuracy'],
@@ -188,21 +196,32 @@ class NASBench201API(object):
       return xifo
     else:
       train_info = archresult.get_metrics(dataset, 'train', iepoch=iepoch, is_random=is_random)
-      if dataset == 'cifar10':
-        test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
-      else:
-        test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
+      try:
+        if dataset == 'cifar10':
+          test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
+        else:
+          test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
+      except:
+        valid_info = None
       try:
         valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
       except:
         valid_info = None
+      try:
+        est_valid_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
+      except:
+        est_valid_info = None
       xifo = {'train-loss'    : train_info['loss'],
-              'train-accuracy': train_info['accuracy'],
-              'test-loss'     : test__info['loss'],
-              'test-accuracy' : test__info['accuracy']}
+              'train-accuracy': train_info['accuracy']}
+      if valid_info is not None:
+        xifo['test-loss']     = test__info['loss'],
+        xifo['test-accuracy'] = test__info['accuracy']
       if valid_info is not None:
         xifo['valid-loss']     = valid_info['loss']
         xifo['valid-accuracy'] = valid_info['accuracy']
+      if est_valid_info is not None:
+        xifo['est-valid-loss']     = est_valid_info['loss']
+        xifo['est-valid-accuracy'] = est_valid_info['accuracy']
       return xifo
 
   def show(self, index=-1):
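
For reference, a minimal usage sketch of the two queries touched in this file, `get_cost_info` and `get_more_info`. It assumes the benchmark file used by the scripts below (`NAS-Bench-201-v1_0-e61699.pth`) has been downloaded and that the repo's `lib` directory, which provides the `nas_201_api` package, is on `PYTHONPATH`; the exact keys returned by `get_more_info` depend on which metrics exist for the chosen dataset.
```
from nas_201_api import NASBench201API as API

api  = API('NAS-Bench-201-v1_0-e61699.pth')                  # path to the downloaded benchmark file
info = api.get_more_info(0, 'cifar100', None, False, True)   # index=0, iepoch=None, 200-epoch results, randomly chosen seed
cost = api.get_cost_info(0, 'cifar100', False)               # computation-cost record (includes e.g. 'latency')
print(info)
print(cost)
```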

View File

@@ -1,7 +1,8 @@
 #!/bin/bash
 echo script name: $0
-lrs="0.01 0.02 0.1 0.2 0.5 1.0 1.5 2.0 2.5 3.0"
+#lrs="0.01 0.02 0.1 0.2 0.5 1.0 1.5 2.0 2.5 3.0"
+lrs="0.01 0.02 0.1 0.2 0.5"
 for lr in ${lrs}
 do

View File

@@ -1,11 +1,11 @@
 #!/bin/bash
 # Regularized Evolution for Image Classifier Architecture Search, AAAI 2019
-# bash ./scripts-search/algos/R-EA.sh -1
+# bash ./scripts-search/algos/R-EA.sh cifar10 3 -1
 echo script name: $0
 echo $# arguments
-if [ "$#" -ne 1 ] ;then
+if [ "$#" -ne 3 ] ;then
   echo "Input illegal number of parameters " $#
-  echo "Need 1 parameters for seed"
+  echo "Need 3 parameters for the-dataset-name, the-ea-sample-size and the-seed"
   exit 1
 fi
 if [ "$TORCH_HOME" = "" ]; then
@@ -15,14 +15,16 @@ else
   echo "TORCH_HOME : $TORCH_HOME"
 fi
 
-dataset=cifar10
-seed=$1
+#dataset=cifar10
+dataset=$1
+sample_size=$2
+seed=$3
 channel=16
 num_cells=5
 max_nodes=4
 space=nas-bench-201
-save_dir=./output/search-cell-${space}/R-EA-${dataset}
+save_dir=./output/search-cell-${space}/R-EA-${dataset}-SS${sample_size}
 
 OMP_NUM_THREADS=4 python ./exps/algos/R_EA.py \
   --save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
@@ -30,5 +32,5 @@ OMP_NUM_THREADS=4 python ./exps/algos/R_EA.py \
   --search_space_name ${space} \
   --arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \
   --time_budget 12000 \
-  --ea_cycles 100 --ea_population 10 --ea_sample_size 3 --ea_fast_by_api 1 \
+  --ea_cycles 200 --ea_population 10 --ea_sample_size ${sample_size} --ea_fast_by_api 1 \
   --workers 4 --print_freq 200 --rand_seed ${seed}