update hp of BOHB
This commit is contained in:
parent
dd6cf5a9c5
commit
db44e56fb6
@ -148,6 +148,7 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, n
|
|||||||
api = meta_file
|
api = meta_file
|
||||||
else:
|
else:
|
||||||
api = API(str(meta_file))
|
api = API(str(meta_file))
|
||||||
|
cifar10_currs = []
|
||||||
cifar10_valid = []
|
cifar10_valid = []
|
||||||
cifar10_test = []
|
cifar10_test = []
|
||||||
cifar100_valid = []
|
cifar100_valid = []
|
||||||
@ -156,6 +157,9 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, n
|
|||||||
imagenet_valid = []
|
imagenet_valid = []
|
||||||
for idx, arch in enumerate(api):
|
for idx, arch in enumerate(api):
|
||||||
results = api.get_more_info(idx, 'cifar10-valid' , test_epoch-1, use_less_or_not, is_rand)
|
results = api.get_more_info(idx, 'cifar10-valid' , test_epoch-1, use_less_or_not, is_rand)
|
||||||
|
cifar10_currs.append( results['valid-accuracy'] )
|
||||||
|
# --->>>>>
|
||||||
|
results = api.get_more_info(idx, 'cifar10-valid' , None, False, is_rand)
|
||||||
cifar10_valid.append( results['valid-accuracy'] )
|
cifar10_valid.append( results['valid-accuracy'] )
|
||||||
results = api.get_more_info(idx, 'cifar10' , None, False, is_rand)
|
results = api.get_more_info(idx, 'cifar10' , None, False, is_rand)
|
||||||
cifar10_test.append( results['test-accuracy'] )
|
cifar10_test.append( results['test-accuracy'] )
|
||||||
@ -168,8 +172,8 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, n
|
|||||||
def get_cor(A, B):
|
def get_cor(A, B):
|
||||||
return float(np.corrcoef(A, B)[0,1])
|
return float(np.corrcoef(A, B)[0,1])
|
||||||
cors = []
|
cors = []
|
||||||
for basestr, xlist in zip(['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'], [cifar10_test, cifar100_valid, cifar100_test, imagenet_valid, imagenet_test]):
|
for basestr, xlist in zip(['C-010-V', 'C-010-T', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'], [cifar10_valid, cifar10_test, cifar100_valid, cifar100_test, imagenet_valid, imagenet_test]):
|
||||||
correlation = get_cor(cifar10_valid, xlist)
|
correlation = get_cor(cifar10_currs, xlist)
|
||||||
if need_print: print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, '012' if use_less_or_not else '200', basestr, correlation))
|
if need_print: print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, '012' if use_less_or_not else '200', basestr, correlation))
|
||||||
cors.append( correlation )
|
cors.append( correlation )
|
||||||
#print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist)))
|
#print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist)))
|
||||||
@ -183,7 +187,8 @@ def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand):
|
|||||||
for i in tqdm(range(100)):
|
for i in tqdm(range(100)):
|
||||||
x = check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand, False)
|
x = check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand, False)
|
||||||
corrs.append( x )
|
corrs.append( x )
|
||||||
xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T']
|
#xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T']
|
||||||
|
xstrs = ['C-010-V', 'C-010-T', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T']
|
||||||
correlations = np.array(corrs)
|
correlations = np.array(corrs)
|
||||||
print('------>>>>>>>> {:03d}/{:} >>>>>>>> ------'.format(test_epoch, '012' if use_less_or_not else '200'))
|
print('------>>>>>>>> {:03d}/{:} >>>>>>>> ------'.format(test_epoch, '012' if use_less_or_not else '200'))
|
||||||
for idx, xstr in enumerate(xstrs):
|
for idx, xstr in enumerate(xstrs):
|
||||||
@ -213,5 +218,6 @@ if __name__ == '__main__':
|
|||||||
check_cor_for_bandit_v2(api, 24, False, True)
|
check_cor_for_bandit_v2(api, 24, False, True)
|
||||||
check_cor_for_bandit_v2(api, 100, False, True)
|
check_cor_for_bandit_v2(api, 100, False, True)
|
||||||
check_cor_for_bandit_v2(api, 150, False, True)
|
check_cor_for_bandit_v2(api, 150, False, True)
|
||||||
|
check_cor_for_bandit_v2(api, 175, False, True)
|
||||||
check_cor_for_bandit_v2(api, 200, False, True)
|
check_cor_for_bandit_v2(api, 200, False, True)
|
||||||
print('----')
|
print('----')
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
##################################################
|
##################################################
|
||||||
import os, sys, time, argparse, collections
|
import os, sys, time, argparse, collections
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from collections import OrderedDict
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -412,7 +413,7 @@ def plot_results_nas(api, dataset, xset, root, file_name, y_lims):
|
|||||||
|
|
||||||
|
|
||||||
def just_show(api):
|
def just_show(api):
|
||||||
xtimes = {'RSPS': [8082.5, 7794.2, 8144.7],
|
xtimes = {'RSPS' : [8082.5, 7794.2, 8144.7],
|
||||||
'DARTS-V1': [11582.1, 11347.0, 11948.2],
|
'DARTS-V1': [11582.1, 11347.0, 11948.2],
|
||||||
'DARTS-V2': [35694.7, 36132.7, 35518.0],
|
'DARTS-V2': [35694.7, 36132.7, 35518.0],
|
||||||
'GDAS' : [31334.1, 31478.6, 32016.7],
|
'GDAS' : [31334.1, 31478.6, 32016.7],
|
||||||
@ -420,7 +421,7 @@ def just_show(api):
|
|||||||
'ENAS' : [14340.2, 13817.3, 14018.9]}
|
'ENAS' : [14340.2, 13817.3, 14018.9]}
|
||||||
for xkey, xlist in xtimes.items():
|
for xkey, xlist in xtimes.items():
|
||||||
xlist = np.array(xlist)
|
xlist = np.array(xlist)
|
||||||
print ('{:4s} : mean-time={:.1f} s'.format(xkey, xlist.mean()))
|
print ('{:4s} : mean-time={:.2f} s'.format(xkey, xlist.mean()))
|
||||||
|
|
||||||
xpaths = {'RSPS' : 'output/search-cell-nas-bench-102/RANDOM-NAS-cifar10/checkpoint/',
|
xpaths = {'RSPS' : 'output/search-cell-nas-bench-102/RANDOM-NAS-cifar10/checkpoint/',
|
||||||
'DARTS-V1': 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/',
|
'DARTS-V1': 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/',
|
||||||
@ -546,6 +547,7 @@ if __name__ == '__main__':
|
|||||||
#visualize_relative_ranking(vis_save_dir)
|
#visualize_relative_ranking(vis_save_dir)
|
||||||
|
|
||||||
api = API(args.api_path)
|
api = API(args.api_path)
|
||||||
|
"""
|
||||||
for x_maxs in [50, 250]:
|
for x_maxs in [50, 250]:
|
||||||
show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||||
show_nas_sharing_w(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
show_nas_sharing_w(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||||
@ -553,12 +555,11 @@ if __name__ == '__main__':
|
|||||||
show_nas_sharing_w(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
show_nas_sharing_w(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||||
show_nas_sharing_w(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
show_nas_sharing_w(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||||
show_nas_sharing_w(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
show_nas_sharing_w(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||||
"""
|
|
||||||
just_show(api)
|
just_show(api)
|
||||||
|
"""
|
||||||
plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1))
|
plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1))
|
||||||
plot_results_nas(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1))
|
plot_results_nas(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1))
|
||||||
plot_results_nas(api, 'cifar100' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
|
plot_results_nas(api, 'cifar100' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
|
||||||
plot_results_nas(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
|
plot_results_nas(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
|
||||||
plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
|
plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
|
||||||
plot_results_nas(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
|
plot_results_nas(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
|
||||||
"""
|
|
||||||
|
@ -184,7 +184,7 @@ def main(xargs, nas_bench):
|
|||||||
|
|
||||||
logger.log('workers : {:.1f}s with {:} archs'.format(workers[0].time_budget, len(workers[0].seen_archs)))
|
logger.log('workers : {:.1f}s with {:} archs'.format(workers[0].time_budget, len(workers[0].seen_archs)))
|
||||||
logger.close()
|
logger.close()
|
||||||
return logger.log_dir, nas_bench.query_index_by_arch( best_arch )
|
return logger.log_dir, nas_bench.query_index_by_arch( best_arch ), real_cost_time
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -219,12 +219,14 @@ if __name__ == '__main__':
|
|||||||
print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset))
|
print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset))
|
||||||
nas_bench = API(args.arch_nas_dataset)
|
nas_bench = API(args.arch_nas_dataset)
|
||||||
if args.rand_seed < 0:
|
if args.rand_seed < 0:
|
||||||
save_dir, all_indexes, num = None, [], 500
|
save_dir, all_indexes, num, all_times = None, [], 500, []
|
||||||
for i in range(num):
|
for i in range(num):
|
||||||
print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num))
|
print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num))
|
||||||
args.rand_seed = random.randint(1, 100000)
|
args.rand_seed = random.randint(1, 100000)
|
||||||
save_dir, index = main(args, nas_bench)
|
save_dir, index, ctime = main(args, nas_bench)
|
||||||
all_indexes.append( index )
|
all_indexes.append( index )
|
||||||
|
all_times.append( ctime )
|
||||||
|
print ('\n average time : {:.3f} s'.format(sum(all_times)/len(all_times)))
|
||||||
torch.save(all_indexes, save_dir / 'results.pth')
|
torch.save(all_indexes, save_dir / 'results.pth')
|
||||||
else:
|
else:
|
||||||
main(args, nas_bench)
|
main(args, nas_bench)
|
||||||
|
@ -29,5 +29,5 @@ OMP_NUM_THREADS=4 python ./exps/algos/BOHB.py \
|
|||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
||||||
--time_budget 12000 \
|
--time_budget 12000 \
|
||||||
--n_iters 28 --num_samples 64 --random_fraction .33 --bandwidth_factor 3 \
|
--n_iters 50 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 \
|
||||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||||
|
Loading…
Reference in New Issue
Block a user