update hp of BOHB

This commit is contained in:
D-X-Y 2020-01-02 16:49:16 +11:00
parent dd6cf5a9c5
commit db44e56fb6
4 changed files with 20 additions and 11 deletions

View File

@ -148,6 +148,7 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, n
api = meta_file api = meta_file
else: else:
api = API(str(meta_file)) api = API(str(meta_file))
cifar10_currs = []
cifar10_valid = [] cifar10_valid = []
cifar10_test = [] cifar10_test = []
cifar100_valid = [] cifar100_valid = []
@ -156,6 +157,9 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, n
imagenet_valid = [] imagenet_valid = []
for idx, arch in enumerate(api): for idx, arch in enumerate(api):
results = api.get_more_info(idx, 'cifar10-valid' , test_epoch-1, use_less_or_not, is_rand) results = api.get_more_info(idx, 'cifar10-valid' , test_epoch-1, use_less_or_not, is_rand)
cifar10_currs.append( results['valid-accuracy'] )
# --->>>>>
results = api.get_more_info(idx, 'cifar10-valid' , None, False, is_rand)
cifar10_valid.append( results['valid-accuracy'] ) cifar10_valid.append( results['valid-accuracy'] )
results = api.get_more_info(idx, 'cifar10' , None, False, is_rand) results = api.get_more_info(idx, 'cifar10' , None, False, is_rand)
cifar10_test.append( results['test-accuracy'] ) cifar10_test.append( results['test-accuracy'] )
@ -168,8 +172,8 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, n
def get_cor(A, B): def get_cor(A, B):
return float(np.corrcoef(A, B)[0,1]) return float(np.corrcoef(A, B)[0,1])
cors = [] cors = []
for basestr, xlist in zip(['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'], [cifar10_test, cifar100_valid, cifar100_test, imagenet_valid, imagenet_test]): for basestr, xlist in zip(['C-010-V', 'C-010-T', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'], [cifar10_valid, cifar10_test, cifar100_valid, cifar100_test, imagenet_valid, imagenet_test]):
correlation = get_cor(cifar10_valid, xlist) correlation = get_cor(cifar10_currs, xlist)
if need_print: print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, '012' if use_less_or_not else '200', basestr, correlation)) if need_print: print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, '012' if use_less_or_not else '200', basestr, correlation))
cors.append( correlation ) cors.append( correlation )
#print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist))) #print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist)))
@ -183,7 +187,8 @@ def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand):
for i in tqdm(range(100)): for i in tqdm(range(100)):
x = check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand, False) x = check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand, False)
corrs.append( x ) corrs.append( x )
xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'] #xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T']
xstrs = ['C-010-V', 'C-010-T', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T']
correlations = np.array(corrs) correlations = np.array(corrs)
print('------>>>>>>>> {:03d}/{:} >>>>>>>> ------'.format(test_epoch, '012' if use_less_or_not else '200')) print('------>>>>>>>> {:03d}/{:} >>>>>>>> ------'.format(test_epoch, '012' if use_less_or_not else '200'))
for idx, xstr in enumerate(xstrs): for idx, xstr in enumerate(xstrs):
@ -213,5 +218,6 @@ if __name__ == '__main__':
check_cor_for_bandit_v2(api, 24, False, True) check_cor_for_bandit_v2(api, 24, False, True)
check_cor_for_bandit_v2(api, 100, False, True) check_cor_for_bandit_v2(api, 100, False, True)
check_cor_for_bandit_v2(api, 150, False, True) check_cor_for_bandit_v2(api, 150, False, True)
check_cor_for_bandit_v2(api, 175, False, True)
check_cor_for_bandit_v2(api, 200, False, True) check_cor_for_bandit_v2(api, 200, False, True)
print('----') print('----')

View File

@ -5,6 +5,7 @@
################################################## ##################################################
import os, sys, time, argparse, collections import os, sys, time, argparse, collections
from tqdm import tqdm from tqdm import tqdm
from collections import OrderedDict
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -412,7 +413,7 @@ def plot_results_nas(api, dataset, xset, root, file_name, y_lims):
def just_show(api): def just_show(api):
xtimes = {'RSPS': [8082.5, 7794.2, 8144.7], xtimes = {'RSPS' : [8082.5, 7794.2, 8144.7],
'DARTS-V1': [11582.1, 11347.0, 11948.2], 'DARTS-V1': [11582.1, 11347.0, 11948.2],
'DARTS-V2': [35694.7, 36132.7, 35518.0], 'DARTS-V2': [35694.7, 36132.7, 35518.0],
'GDAS' : [31334.1, 31478.6, 32016.7], 'GDAS' : [31334.1, 31478.6, 32016.7],
@ -420,7 +421,7 @@ def just_show(api):
'ENAS' : [14340.2, 13817.3, 14018.9]} 'ENAS' : [14340.2, 13817.3, 14018.9]}
for xkey, xlist in xtimes.items(): for xkey, xlist in xtimes.items():
xlist = np.array(xlist) xlist = np.array(xlist)
print ('{:4s} : mean-time={:.1f} s'.format(xkey, xlist.mean())) print ('{:4s} : mean-time={:.2f} s'.format(xkey, xlist.mean()))
xpaths = {'RSPS' : 'output/search-cell-nas-bench-102/RANDOM-NAS-cifar10/checkpoint/', xpaths = {'RSPS' : 'output/search-cell-nas-bench-102/RANDOM-NAS-cifar10/checkpoint/',
'DARTS-V1': 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/', 'DARTS-V1': 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/',
@ -546,6 +547,7 @@ if __name__ == '__main__':
#visualize_relative_ranking(vis_save_dir) #visualize_relative_ranking(vis_save_dir)
api = API(args.api_path) api = API(args.api_path)
"""
for x_maxs in [50, 250]: for x_maxs in [50, 250]:
show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
show_nas_sharing_w(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) show_nas_sharing_w(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
@ -553,12 +555,11 @@ if __name__ == '__main__':
show_nas_sharing_w(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) show_nas_sharing_w(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
show_nas_sharing_w(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) show_nas_sharing_w(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
show_nas_sharing_w(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) show_nas_sharing_w(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
"""
just_show(api) just_show(api)
"""
plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1)) plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1))
plot_results_nas(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1)) plot_results_nas(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1))
plot_results_nas(api, 'cifar100' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3)) plot_results_nas(api, 'cifar100' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
plot_results_nas(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-com.pdf', (55,75, 3)) plot_results_nas(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3)) plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
plot_results_nas(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-com.pdf', (35,50, 3)) plot_results_nas(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
"""

View File

@ -184,7 +184,7 @@ def main(xargs, nas_bench):
logger.log('workers : {:.1f}s with {:} archs'.format(workers[0].time_budget, len(workers[0].seen_archs))) logger.log('workers : {:.1f}s with {:} archs'.format(workers[0].time_budget, len(workers[0].seen_archs)))
logger.close() logger.close()
return logger.log_dir, nas_bench.query_index_by_arch( best_arch ) return logger.log_dir, nas_bench.query_index_by_arch( best_arch ), real_cost_time
@ -219,12 +219,14 @@ if __name__ == '__main__':
print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset)) print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset))
nas_bench = API(args.arch_nas_dataset) nas_bench = API(args.arch_nas_dataset)
if args.rand_seed < 0: if args.rand_seed < 0:
save_dir, all_indexes, num = None, [], 500 save_dir, all_indexes, num, all_times = None, [], 500, []
for i in range(num): for i in range(num):
print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num)) print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num))
args.rand_seed = random.randint(1, 100000) args.rand_seed = random.randint(1, 100000)
save_dir, index = main(args, nas_bench) save_dir, index, ctime = main(args, nas_bench)
all_indexes.append( index ) all_indexes.append( index )
all_times.append( ctime )
print ('\n average time : {:.3f} s'.format(sum(all_times)/len(all_times)))
torch.save(all_indexes, save_dir / 'results.pth') torch.save(all_indexes, save_dir / 'results.pth')
else: else:
main(args, nas_bench) main(args, nas_bench)

View File

@ -29,5 +29,5 @@ OMP_NUM_THREADS=4 python ./exps/algos/BOHB.py \
--search_space_name ${space} \ --search_space_name ${space} \
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ --arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
--time_budget 12000 \ --time_budget 12000 \
--n_iters 28 --num_samples 64 --random_fraction .33 --bandwidth_factor 3 \ --n_iters 50 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 \
--workers 4 --print_freq 200 --rand_seed ${seed} --workers 4 --print_freq 200 --rand_seed ${seed}