update README
This commit is contained in:
parent
4be2a0000c
commit
b299945b23
@ -39,7 +39,7 @@ At the moment, this project provides the following algorithms and scripts to run
|
||||
<tr> <!-- (2-nd row) -->
|
||||
<td align="center" valign="middle"> DARTS </td>
|
||||
<td align="center" valign="middle"> DARTS: Differentiable Architecture Search </td>
|
||||
<td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/master/docs/NAS-Bench-201.md">NAS-Bench-201.md</a> </td>
|
||||
<td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/master/docs/ICLR-2019-DARTS.md">ICLR-2019-DARTS.md</a> </td>
|
||||
</tr>
|
||||
<tr> <!-- (3-rd row) -->
|
||||
<td align="center" valign="middle"> GDAS </td>
|
||||
|
22
docs/ICLR-2019-DARTS.md
Normal file
22
docs/ICLR-2019-DARTS.md
Normal file
@ -0,0 +1,22 @@
|
||||
# DARTS: Differentiable Architecture Search
|
||||
|
||||
DARTS: Differentiable Architecture Search was accepted at ICLR 2019.
|
||||
In this paper, Hanxiao Liu et al. proposed a differentiable neural architecture search method named DARTS.
|
||||
DARTS has recently become very popular due to its simplicity and performance.
|
||||
|
||||
**Run DARTS on the NAS-Bench-201 search space**:
|
||||
```
|
||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 1 -1
|
||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 1 -1
|
||||
```
|
||||
|
||||
# Citation
|
||||
|
||||
```
|
||||
@inproceedings{liu2019darts,
|
||||
title = {{DARTS}: Differentiable architecture search},
|
||||
author = {Liu, Hanxiao and Simonyan, Karen and Yang, Yiming},
|
||||
booktitle = {International Conference on Learning Representations (ICLR)},
|
||||
year = {2019}
|
||||
}
|
||||
```
|
@ -181,7 +181,7 @@ If researchers can provide better results with different hyper-parameters, we ar
|
||||
- [4] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/SETN.sh cifar10 1 -1`
|
||||
- [5] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/ENAS.sh cifar10 1 -1`
|
||||
- [6] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/RANDOM-NAS.sh cifar10 1 -1`
|
||||
- [7] `bash ./scripts-search/algos/R-EA.sh -1`
|
||||
- [7] `bash ./scripts-search/algos/R-EA.sh cifar10 3 -1`
|
||||
- [8] `bash ./scripts-search/algos/Random.sh -1`
|
||||
- [9] `bash ./scripts-search/algos/REINFORCE.sh 0.5 -1`
|
||||
- [10] `bash ./scripts-search/algos/BOHB.sh -1`
|
||||
|
@ -517,7 +517,7 @@ def just_show(api):
|
||||
print ('[{:10s}-{:10s} ::: index={:5d}, accuracy={:.2f}'.format(dataset, metric_on_set, arch_index, highest_acc))
|
||||
|
||||
|
||||
def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_maxs):
|
||||
def show_nas_sharing_w(api, dataset, subset, vis_save_dir, sufix, file_name, y_lims, x_maxs):
|
||||
color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
|
||||
dpi, width, height = 300, 3400, 2600
|
||||
LabelSize, LegendFontsize = 28, 28
|
||||
@ -533,13 +533,14 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_
|
||||
plt.xlabel('The searching epoch', fontsize=LabelSize)
|
||||
plt.ylabel('The accuracy (%)', fontsize=LabelSize)
|
||||
|
||||
xpaths = {'RSPS' : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10/checkpoint/',
|
||||
'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/',
|
||||
'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10/checkpoint/',
|
||||
'GDAS' : 'output/search-cell-nas-bench-201/GDAS-cifar10/checkpoint/',
|
||||
'SETN' : 'output/search-cell-nas-bench-201/SETN-cifar10/checkpoint/',
|
||||
'ENAS' : 'output/search-cell-nas-bench-201/ENAS-cifar10/checkpoint/',
|
||||
xpaths = {'RSPS' : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10-{:}/checkpoint/'.format(sufix),
|
||||
'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10-{:}/checkpoint/'.format(sufix),
|
||||
'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10-{:}/checkpoint/'.format(sufix),
|
||||
'GDAS' : 'output/search-cell-nas-bench-201/GDAS-cifar10-{:}/checkpoint/'.format(sufix),
|
||||
'SETN' : 'output/search-cell-nas-bench-201/SETN-cifar10-{:}/checkpoint/'.format(sufix),
|
||||
'ENAS' : 'output/search-cell-nas-bench-201/ENAS-cifar10-{:}/checkpoint/'.format(sufix),
|
||||
}
|
||||
"""
|
||||
xseeds = {'RSPS' : [5349, 59613, 5983],
|
||||
'DARTS-V1': [11416, 72873, 81184, 28640],
|
||||
'DARTS-V2': [43330, 79405, 79423],
|
||||
@ -547,6 +548,15 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_
|
||||
'SETN' : [20518, 61817, 89144],
|
||||
'ENAS' : [3231, 34238, 96929],
|
||||
}
|
||||
"""
|
||||
xseeds = {'RSPS' : [23814, 28015, 95809],
|
||||
'DARTS-V1': [48349, 80877, 81920],
|
||||
'DARTS-V2': [61712, 7941 , 87041] ,
|
||||
'GDAS' : [72818, 72996, 78877],
|
||||
'SETN' : [26985, 55206, 95404],
|
||||
'ENAS' : [21792, 36605, 45029]
|
||||
}
|
||||
|
||||
|
||||
def get_accs(xdata):
|
||||
epochs, xresults = xdata['epoch'], []
|
||||
@ -579,12 +589,13 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_
|
||||
plt.fill_between(epochs, [accyss[:,i].mean()-accyss[:,i].std() for i in epochs], [accyss[:,i].mean()+accyss[:,i].std() for i in epochs], alpha=0.2, color=color_set[idx])
|
||||
#plt.legend(loc=4, fontsize=LegendFontsize)
|
||||
plt.legend(loc=0, fontsize=LegendFontsize)
|
||||
save_path = vis_save_dir / '{:}-{:}-{:}-{:}'.format(xox, dataset, subset, file_name)
|
||||
save_path = vis_save_dir / '{:}.pdf'.format(file_name)
|
||||
print('save figure into {:}\n'.format(save_path))
|
||||
fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
|
||||
|
||||
def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, file_name, y_lims, x_maxs):
|
||||
|
||||
def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, sufix, file_name, y_lims, x_maxs):
|
||||
color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
|
||||
dpi, width, height = 300, 3400, 2600
|
||||
LabelSize, LegendFontsize = 28, 28
|
||||
@ -600,13 +611,14 @@ def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, file_name,
|
||||
plt.xlabel('The searching epoch', fontsize=LabelSize)
|
||||
plt.ylabel('The accuracy (%)', fontsize=LabelSize)
|
||||
|
||||
xpaths = {'RSPS' : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10/checkpoint/',
|
||||
'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/',
|
||||
'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10/checkpoint/',
|
||||
'GDAS' : 'output/search-cell-nas-bench-201/GDAS-cifar10/checkpoint/',
|
||||
'SETN' : 'output/search-cell-nas-bench-201/SETN-cifar10/checkpoint/',
|
||||
'ENAS' : 'output/search-cell-nas-bench-201/ENAS-cifar10/checkpoint/',
|
||||
xpaths = {'RSPS' : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10-{:}/checkpoint/'.format(sufix),
|
||||
'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10-{:}/checkpoint/'.format(sufix),
|
||||
'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10-{:}/checkpoint/'.format(sufix),
|
||||
'GDAS' : 'output/search-cell-nas-bench-201/GDAS-cifar10-{:}/checkpoint/'.format(sufix),
|
||||
'SETN' : 'output/search-cell-nas-bench-201/SETN-cifar10-{:}/checkpoint/'.format(sufix),
|
||||
'ENAS' : 'output/search-cell-nas-bench-201/ENAS-cifar10-{:}/checkpoint/'.format(sufix),
|
||||
}
|
||||
"""
|
||||
xseeds = {'RSPS' : [5349, 59613, 5983],
|
||||
'DARTS-V1': [11416, 72873, 81184, 28640],
|
||||
'DARTS-V2': [43330, 79405, 79423],
|
||||
@ -614,6 +626,15 @@ def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, file_name,
|
||||
'SETN' : [20518, 61817, 89144],
|
||||
'ENAS' : [3231, 34238, 96929],
|
||||
}
|
||||
"""
|
||||
xseeds = {'RSPS' : [23814, 28015, 95809],
|
||||
'DARTS-V1': [48349, 80877, 81920],
|
||||
'DARTS-V2': [61712, 7941 , 87041] ,
|
||||
'GDAS' : [72818, 72996, 78877],
|
||||
'SETN' : [26985, 55206, 95404],
|
||||
'ENAS' : [21792, 36605, 45029]
|
||||
}
|
||||
|
||||
|
||||
def get_accs(xdata, dataset, subset):
|
||||
epochs, xresults = xdata['epoch'], []
|
||||
@ -643,8 +664,15 @@ def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, file_name,
|
||||
accyss_B = np.array( [get_accs(xdatas, data_sub_b[0], data_sub_b[1]) for xdatas in all_datas] )
|
||||
epochs = list(range(accyss_A.shape[1]))
|
||||
for j, accyss in enumerate([accyss_A, accyss_B]):
|
||||
plt.plot(epochs, [accyss[:,i].mean() for i in epochs], color=color_set[idx*2+j], linestyle='-' if j==0 else '--', label='{:} ({:})'.format(method, 'VALID' if j == 0 else 'TEST'), lw=2, alpha=0.9)
|
||||
plt.fill_between(epochs, [accyss[:,i].mean()-accyss[:,i].std() for i in epochs], [accyss[:,i].mean()+accyss[:,i].std() for i in epochs], alpha=0.2, color=color_set[idx*2+j])
|
||||
if x_maxs == 50:
|
||||
color, line = color_set[idx*2+j], '-' if j==0 else '--'
|
||||
elif x_maxs == 250:
|
||||
color, line = color_set[idx], '-' if j==0 else '--'
|
||||
else: raise ValueError('invalid x-maxs={:}'.format(x_maxs))
|
||||
plt.plot(epochs, [accyss[:,i].mean() for i in epochs], color=color, linestyle=line, label='{:} ({:})'.format(method, 'VALID' if j == 0 else 'TEST'), lw=2, alpha=0.9)
|
||||
plt.fill_between(epochs, [accyss[:,i].mean()-accyss[:,i].std() for i in epochs], [accyss[:,i].mean()+accyss[:,i].std() for i in epochs], alpha=0.2, color=color)
|
||||
setname = data_sub_a if j == 0 else data_sub_b
|
||||
print('{:} -- {:} ---- {:.2f}$\\pm${:.2f}'.format(method, setname, accyss[:,-1].mean(), accyss[:,-1].std()))
|
||||
#plt.legend(loc=4, fontsize=LegendFontsize)
|
||||
plt.legend(loc=0, fontsize=LegendFontsize)
|
||||
save_path = vis_save_dir / '{:}-{:}'.format(xox, file_name)
|
||||
@ -654,7 +682,7 @@ def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, file_name,
|
||||
|
||||
def show_reinforce(api, root, dataset, xset, file_name, y_lims):
|
||||
print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset))
|
||||
LRs = ['0.01', '0.02', '0.1', '0.2', '0.5', '1.0', '1.5', '2.0', '2.5', '3.0']
|
||||
LRs = ['0.01', '0.02', '0.1', '0.2', '0.5']
|
||||
checkpoints = ['./output/search-cell-nas-bench-201/REINFORCE-cifar10-{:}/results.pth'.format(x) for x in LRs]
|
||||
acc_lr_dict, indexes = {}, None
|
||||
for lr, checkpoint in zip(LRs, checkpoints):
|
||||
@ -684,7 +712,8 @@ def show_reinforce(api, root, dataset, xset, file_name, y_lims):
|
||||
|
||||
for idx, LR in enumerate(LRs):
|
||||
legend = 'LR={:.2f}'.format(float(LR))
|
||||
color, linestyle = color_set[idx // 2], '-' if idx % 2 == 0 else '-.'
|
||||
#color, linestyle = color_set[idx // 2], '-' if idx % 2 == 0 else '-.'
|
||||
color, linestyle = color_set[idx], '-'
|
||||
plt.plot(indexes, acc_lr_dict[LR], color=color, linestyle=linestyle, label=legend, lw=2, alpha=0.8)
|
||||
print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, np.mean(acc_lr_dict[LR]), np.std(acc_lr_dict[LR]), np.mean(acc_lr_dict[LR]), np.std(acc_lr_dict[LR])))
|
||||
plt.legend(loc=4, fontsize=LegendFontsize)
|
||||
@ -693,6 +722,49 @@ def show_reinforce(api, root, dataset, xset, file_name, y_lims):
|
||||
fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
|
||||
|
||||
|
||||
def show_rea(api, root, dataset, xset, file_name, y_lims):
  """Plot the sorted accuracies found by R-EA runs for several sample sizes.

  For each sample size in a fixed list, a results checkpoint is loaded from
  `./output/search-cell-nas-bench-201/R-EA-cifar10-SS<size>/results.pth`,
  every stored entry is looked up in `api.arch2infos_full`, and its accuracy
  on (`dataset`, `xset`) is collected. The accuracies are sorted ascendingly
  and drawn as one curve per sample size; the figure is saved as a PDF.

  Args:
    api: benchmark API object exposing `arch2infos_full`, whose values offer
      `get_metrics(dataset, xset, iepoch, is_random)`.
    root: directory object supporting the `/` operator (presumably a
      `pathlib.Path` -- confirm against the caller) where the PDF is written.
    dataset: dataset name forwarded to `get_metrics`.
    xset: subset name forwarded to `get_metrics`.
    file_name: suffix used in the output file name
      `<dataset>-<xset>-<file_name>.pdf`.
    y_lims: triple `(y_min, y_max, y_tick_interval)` for the y-axis.
  """
  print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset))
  SSs = [3, 5, 10]
  checkpoints = ['./output/search-cell-nas-bench-201/R-EA-cifar10-SS{:}/results.pth'.format(x) for x in SSs]
  acc_ss_dict, indexes = {}, None
  for ss, checkpoint in zip(SSs, checkpoints):
    all_indexes, accuracies = torch.load(checkpoint, map_location='cpu'), []
    for x in all_indexes:
      info = api.arch2infos_full[ x ]
      metrics = info.get_metrics(dataset, xset, None, False)
      accuracies.append( metrics['accuracy'] )
    # The x-axis is defined by the number of runs in the first checkpoint;
    # the remaining checkpoints are assumed to hold the same number of runs.
    if indexes is None: indexes = list(range(len(accuracies)))
    acc_ss_dict[ss] = np.array( sorted(accuracies) )
    print ('Sample-Size={:2d}, mean={:}, std={:}'.format(ss, acc_ss_dict[ss].mean(), acc_ss_dict[ss].std()))

  color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
  dpi, width, height = 300, 3400, 2600
  LabelSize, LegendFontsize = 28, 22
  # Figure size is given in inches, hence pixels divided by dpi.
  figsize = width / float(dpi), height / float(dpi)
  fig = plt.figure(figsize=figsize)
  plt.xlim(0, max(indexes))
  plt.ylim(y_lims[0], y_lims[1])
  interval_x, interval_y = 100, y_lims[2]
  plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize)
  plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize)
  plt.grid()
  plt.xlabel('The index of runs', fontsize=LabelSize)
  plt.ylabel('The accuracy (%)', fontsize=LabelSize)

  for idx, ss in enumerate(SSs):
    legend = 'sample-size={:2d}'.format(ss)
    # One solid line per sample size, each with its own color.
    color, linestyle = color_set[idx], '-'
    accs = acc_ss_dict[ss]
    mean, std = np.mean(accs), np.std(accs)
    plt.plot(indexes, accs, color=color, linestyle=linestyle, label=legend, lw=2, alpha=0.8)
    print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, mean, std, mean, std))
  plt.legend(loc=4, fontsize=LegendFontsize)
  save_path = root / '{:}-{:}-{:}.pdf'.format(dataset, xset, file_name)
  print('save figure into {:}\n'.format(save_path))
  fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser(description='NAS-Bench-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
@ -712,9 +784,25 @@ if __name__ == '__main__':
|
||||
#visualize_relative_ranking(vis_save_dir)
|
||||
|
||||
api = API(args.api_path)
|
||||
show_reinforce(api, vis_save_dir, 'cifar10-valid' , 'x-valid', 'REINFORCE-CIFAR-10', (75, 95, 5))
|
||||
import pdb; pdb.set_trace()
|
||||
#show_reinforce(api, vis_save_dir, 'cifar10-valid' , 'x-valid', 'REINFORCE-CIFAR-10', (85, 92, 2))
|
||||
#show_rea (api, vis_save_dir, 'cifar10-valid' , 'x-valid', 'REA-CIFAR-10', (88, 92, 1))
|
||||
|
||||
#plot_results_nas_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10' , 'ori-test'), vis_save_dir, 'nas-com-v2-cifar010.pdf', (85,95, 1))
|
||||
#plot_results_nas_v2(api, ('cifar100' , 'x-valid'), ('cifar100' , 'x-test' ), vis_save_dir, 'nas-com-v2-cifar100.pdf', (60,75, 3))
|
||||
#plot_results_nas_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test' ), vis_save_dir, 'nas-com-v2-imagenet.pdf', (35,48, 2))
|
||||
|
||||
show_nas_sharing_w_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10' , 'ori-test') , vis_save_dir, 'BN0', 'BN0-DARTS-CIFAR010.pdf', (0, 100,10), 50)
|
||||
show_nas_sharing_w_v2(api, ('cifar100' , 'x-valid'), ('cifar100' , 'x-test' ) , vis_save_dir, 'BN0', 'BN0-DARTS-CIFAR100.pdf', (0, 100,10), 50)
|
||||
show_nas_sharing_w_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test' ) , vis_save_dir, 'BN0', 'BN0-DARTS-ImageNet.pdf', (0, 100,10), 50)
|
||||
|
||||
show_nas_sharing_w_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10' , 'ori-test') , vis_save_dir, 'BN0', 'BN0-OTHER-CIFAR010.pdf', (0, 100,10), 250)
|
||||
show_nas_sharing_w_v2(api, ('cifar100' , 'x-valid'), ('cifar100' , 'x-test' ) , vis_save_dir, 'BN0', 'BN0-OTHER-CIFAR100.pdf', (0, 100,10), 250)
|
||||
show_nas_sharing_w_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test' ) , vis_save_dir, 'BN0', 'BN0-OTHER-ImageNet.pdf', (0, 100,10), 250)
|
||||
|
||||
show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'BN0', 'BN0-XX-CIFAR010-VALID.pdf', (0, 100,10), 250)
|
||||
show_nas_sharing_w(api, 'cifar10' , 'ori-test', vis_save_dir, 'BN0', 'BN0-XX-CIFAR010-TEST.pdf' , (0, 100,10), 250)
|
||||
import pdb; pdb.set_trace()
|
||||
"""
|
||||
for x_maxs in [50, 250]:
|
||||
show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||
show_nas_sharing_w(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||
@ -724,17 +812,11 @@ if __name__ == '__main__':
|
||||
show_nas_sharing_w(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||
|
||||
show_nas_sharing_w_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10' , 'ori-test') , vis_save_dir, 'DARTS-CIFAR010.pdf', (0, 100,10), 50)
|
||||
show_nas_sharing_w_v2(api, ('cifar100' , 'x-valid'), ('cifar100' , 'x-test' ) , vis_save_dir, 'DARTS-CIFAR100.pdf', (0, 100,10), 50)
|
||||
show_nas_sharing_w_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test' ) , vis_save_dir, 'DARTS-ImageNet.pdf', (0, 100,10), 50)
|
||||
#just_show(api)
|
||||
"""
|
||||
just_show(api)
|
||||
plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1))
|
||||
plot_results_nas(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1))
|
||||
plot_results_nas(api, 'cifar100' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
|
||||
plot_results_nas(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
|
||||
plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
|
||||
plot_results_nas(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
|
||||
plot_results_nas_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10' , 'ori-test'), vis_save_dir, 'nas-com-v2-cifar010.pdf', (85,95, 1))
|
||||
plot_results_nas_v2(api, ('cifar100' , 'x-valid'), ('cifar100' , 'x-test' ), vis_save_dir, 'nas-com-v2-cifar100.pdf', (60,75, 3))
|
||||
plot_results_nas_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test' ), vis_save_dir, 'nas-com-v2-imagenet.pdf', (35,48, 2))
|
||||
"""
|
||||
|
@ -33,13 +33,38 @@ class Model(object):
|
||||
|
||||
# This function mimics the training and evaluating procedure for a single architecture `arch`.
|
||||
# The time_cost is calculated as the total training time for a few (e.g., 12 epochs) plus the evaluation time for one epoch.
|
||||
def train_and_eval(arch, nas_bench, extra_info):
|
||||
if nas_bench is not None:
|
||||
# For use_converged_LR = True, the architecture is trained for 12 epochs, with the LR being decayed from 0.1 to 0.
|
||||
# In this case, the LR scheduler is converged.
|
||||
# For use_converged_LR = False, the architecture is planned to be trained for 200 epochs, but we early stop its procedure.
|
||||
#
|
||||
def train_and_eval(arch, nas_bench, extra_info, dataname='cifar10-valid', use_converged_LR=True):
|
||||
if use_converged_LR and nas_bench is not None:
|
||||
arch_index = nas_bench.query_index_by_arch( arch )
|
||||
assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
|
||||
info = nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True)
|
||||
info = nas_bench.get_more_info(arch_index, dataname, None, True)
|
||||
valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
|
||||
#_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs
|
||||
elif not use_converged_LR and nas_bench is not None:
|
||||
# Please use `use_converged_LR=False` for cifar10 only.
|
||||
# It did return values for cifar100 and ImageNet16-120, but it has some potential issues. (Please email me for more details)
|
||||
arch_index, nepoch = nas_bench.query_index_by_arch( arch ), 25
|
||||
assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
|
||||
xoinfo = nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True)
|
||||
xocost = nas_bench.get_cost_info(arch_index, 'cifar10-valid', False)
|
||||
info = nas_bench.get_more_info(arch_index, dataname, nepoch, False, True) # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready).
|
||||
cost = nas_bench.get_cost_info(arch_index, dataname, False)
|
||||
# The following codes are used to estimate the time cost.
|
||||
# When we build NAS-Bench-201, architectures are trained on different machines and we can not use that time record.
|
||||
# When we create checkpoints for converged_LR, we run all experiments on 1080Ti, and thus the time for each architecture can be fairly compared.
|
||||
nums = {'ImageNet16-120-train': 151700, 'ImageNet16-120-valid': 3000,
|
||||
'cifar10-valid-train' : 25000, 'cifar10-valid-valid' : 25000,
|
||||
'cifar100-train' : 50000, 'cifar100-valid' : 5000}
|
||||
estimated_train_cost = xoinfo['train-per-time'] / nums['cifar10-valid-train'] * nums['{:}-train'.format(dataname)] / xocost['latency'] * cost['latency'] * nepoch
|
||||
estimated_valid_cost = xoinfo['valid-per-time'] / nums['cifar10-valid-valid'] * nums['{:}-valid'.format(dataname)] / xocost['latency'] * cost['latency']
|
||||
try:
|
||||
valid_acc, time_cost = info['valid-accuracy'], estimated_train_cost + estimated_valid_cost
|
||||
except:
|
||||
valid_acc, time_cost = info['est-valid-accuracy'], estimated_train_cost + estimated_valid_cost
|
||||
else:
|
||||
# train a model from scratch.
|
||||
raise ValueError('NOT IMPLEMENT YET')
|
||||
@ -79,7 +104,7 @@ def mutate_arch_func(op_names):
|
||||
return mutate_arch_func
|
||||
|
||||
|
||||
def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, nas_bench, extra_info):
|
||||
def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, nas_bench, extra_info, dataname):
|
||||
"""Algorithm for regularized evolution (i.e. aging evolution).
|
||||
|
||||
Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image
|
||||
@ -150,6 +175,10 @@ def main(xargs, nas_bench):
|
||||
logger = prepare_logger(args)
|
||||
|
||||
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
|
||||
if xargs.dataset == 'cifar10':
|
||||
dataname = 'cifar10-valid'
|
||||
else:
|
||||
dataname = xargs.dataset
|
||||
if xargs.data_path is not None:
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
@ -182,7 +211,7 @@ def main(xargs, nas_bench):
|
||||
x_start_time = time.time()
|
||||
logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
|
||||
logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget))
|
||||
history, total_cost = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, nas_bench if args.ea_fast_by_api else None, extra_info)
|
||||
history, total_cost = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, nas_bench if args.ea_fast_by_api else None, extra_info, dataname)
|
||||
logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).'.format(time_string(), len(history), total_cost, time.time()-x_start_time))
|
||||
best_arch = max(history, key=lambda i: i.accuracy)
|
||||
best_arch = best_arch.arch
|
||||
|
@ -162,6 +162,13 @@ class NASBench201API(object):
|
||||
archresult = arch2infos[index]
|
||||
return archresult.get_net_param(dataset, seed)
|
||||
|
||||
# obtain the cost metric for the `index`-th architecture on a dataset
|
||||
def get_cost_info(self, index, dataset, use_12epochs_result=False):
  """Return the computational-cost record of one architecture on a dataset.

  Args:
    index: key of the architecture in the result tables.
    dataset: dataset name forwarded to `get_comput_costs`.
    use_12epochs_result: when True, read from `self.arch2infos_less`
      (the 12-epoch results); otherwise read from `self.arch2infos_full`
      (the 200-epoch results).

  Returns:
    Whatever `get_comput_costs(dataset)` of the selected result returns.
  """
  arch2infos = self.arch2infos_less if use_12epochs_result else self.arch2infos_full
  archresult = arch2infos[index]
  return archresult.get_comput_costs(dataset)
|
||||
|
||||
# obtain the metric for the `index`-th architecture
|
||||
def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
|
||||
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||
@ -177,6 +184,7 @@ class NASBench201API(object):
|
||||
total = train_info['iepoch'] + 1
|
||||
xifo = {'train-loss' : train_info['loss'],
|
||||
'train-accuracy': train_info['accuracy'],
|
||||
'train-per-time': None if train_info['all_time'] is None else train_info['all_time'] / total,
|
||||
'train-all-time': train_info['all_time'],
|
||||
'valid-loss' : valid_info['loss'],
|
||||
'valid-accuracy': valid_info['accuracy'],
|
||||
@ -188,21 +196,32 @@ class NASBench201API(object):
|
||||
return xifo
|
||||
else:
|
||||
train_info = archresult.get_metrics(dataset, 'train' , iepoch=iepoch, is_random=is_random)
|
||||
try:
|
||||
if dataset == 'cifar10':
|
||||
test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
|
||||
else:
|
||||
test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
|
||||
except:
|
||||
valid_info = None
|
||||
try:
|
||||
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
|
||||
except:
|
||||
valid_info = None
|
||||
try:
|
||||
est_valid_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
|
||||
except:
|
||||
est_valid_info = None
|
||||
xifo = {'train-loss' : train_info['loss'],
|
||||
'train-accuracy': train_info['accuracy'],
|
||||
'test-loss' : test__info['loss'],
|
||||
'test-accuracy' : test__info['accuracy']}
|
||||
'train-accuracy': train_info['accuracy']}
|
||||
if valid_info is not None:
|
||||
xifo['test-loss'] = test__info['loss'],
|
||||
xifo['test-accuracy'] = test__info['accuracy']
|
||||
if valid_info is not None:
|
||||
xifo['valid-loss'] = valid_info['loss']
|
||||
xifo['valid-accuracy'] = valid_info['accuracy']
|
||||
if est_valid_info is not None:
|
||||
xifo['est-valid-loss'] = est_valid_info['loss']
|
||||
xifo['est-valid-accuracy'] = est_valid_info['accuracy']
|
||||
return xifo
|
||||
|
||||
def show(self, index=-1):
|
||||
|
@ -1,7 +1,8 @@
|
||||
#!/bin/bash
|
||||
echo script name: $0
|
||||
|
||||
lrs="0.01 0.02 0.1 0.2 0.5 1.0 1.5 2.0 2.5 3.0"
|
||||
#lrs="0.01 0.02 0.1 0.2 0.5 1.0 1.5 2.0 2.5 3.0"
|
||||
lrs="0.01 0.02 0.1 0.2 0.5"
|
||||
|
||||
for lr in ${lrs}
|
||||
do
|
||||
|
@ -1,11 +1,11 @@
|
||||
#!/bin/bash
|
||||
# Regularized Evolution for Image Classifier Architecture Search, AAAI 2019
|
||||
# bash ./scripts-search/algos/R-EA.sh -1
|
||||
# bash ./scripts-search/algos/R-EA.sh cifar10 3 -1
|
||||
echo script name: $0
|
||||
echo $# arguments
|
||||
if [ "$#" -ne 1 ] ;then
|
||||
if [ "$#" -ne 3 ] ;then
|
||||
echo "Input illegal number of parameters " $#
|
||||
echo "Need 1 parameters for seed"
|
||||
echo "Need 3 parameters for the-dataset-name, the-ea-sample-size and the-seed"
|
||||
exit 1
|
||||
fi
|
||||
if [ "$TORCH_HOME" = "" ]; then
|
||||
@ -15,14 +15,16 @@ else
|
||||
echo "TORCH_HOME : $TORCH_HOME"
|
||||
fi
|
||||
|
||||
dataset=cifar10
|
||||
seed=$1
|
||||
#dataset=cifar10
|
||||
dataset=$1
|
||||
sample_size=$2
|
||||
seed=$3
|
||||
channel=16
|
||||
num_cells=5
|
||||
max_nodes=4
|
||||
space=nas-bench-201
|
||||
|
||||
save_dir=./output/search-cell-${space}/R-EA-${dataset}
|
||||
save_dir=./output/search-cell-${space}/R-EA-${dataset}-SS${sample_size}
|
||||
|
||||
OMP_NUM_THREADS=4 python ./exps/algos/R_EA.py \
|
||||
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||
@ -30,5 +32,5 @@ OMP_NUM_THREADS=4 python ./exps/algos/R_EA.py \
|
||||
--search_space_name ${space} \
|
||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \
|
||||
--time_budget 12000 \
|
||||
--ea_cycles 100 --ea_population 10 --ea_sample_size 3 --ea_fast_by_api 1 \
|
||||
--ea_cycles 200 --ea_population 10 --ea_sample_size ${sample_size} --ea_fast_by_api 1 \
|
||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||
|
Loading…
Reference in New Issue
Block a user