update hp of BOHB
This commit is contained in:
		| @@ -148,6 +148,7 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, n | |||||||
|     api = meta_file |     api = meta_file | ||||||
|   else: |   else: | ||||||
|     api = API(str(meta_file)) |     api = API(str(meta_file)) | ||||||
|  |   cifar10_currs     = [] | ||||||
|   cifar10_valid     = [] |   cifar10_valid     = [] | ||||||
|   cifar10_test      = [] |   cifar10_test      = [] | ||||||
|   cifar100_valid    = [] |   cifar100_valid    = [] | ||||||
| @@ -156,6 +157,9 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, n | |||||||
|   imagenet_valid    = [] |   imagenet_valid    = [] | ||||||
|   for idx, arch in enumerate(api): |   for idx, arch in enumerate(api): | ||||||
|     results = api.get_more_info(idx, 'cifar10-valid' , test_epoch-1, use_less_or_not, is_rand) |     results = api.get_more_info(idx, 'cifar10-valid' , test_epoch-1, use_less_or_not, is_rand) | ||||||
|  |     cifar10_currs.append( results['valid-accuracy'] ) | ||||||
|  |     # --->>>>> | ||||||
|  |     results = api.get_more_info(idx, 'cifar10-valid' , None, False, is_rand) | ||||||
|     cifar10_valid.append( results['valid-accuracy'] ) |     cifar10_valid.append( results['valid-accuracy'] ) | ||||||
|     results = api.get_more_info(idx, 'cifar10'       , None, False, is_rand) |     results = api.get_more_info(idx, 'cifar10'       , None, False, is_rand) | ||||||
|     cifar10_test.append( results['test-accuracy'] ) |     cifar10_test.append( results['test-accuracy'] ) | ||||||
| @@ -168,8 +172,8 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, n | |||||||
|   def get_cor(A, B): |   def get_cor(A, B): | ||||||
|     return float(np.corrcoef(A, B)[0,1]) |     return float(np.corrcoef(A, B)[0,1]) | ||||||
|   cors = [] |   cors = [] | ||||||
|   for basestr, xlist in zip(['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'], [cifar10_test, cifar100_valid, cifar100_test, imagenet_valid, imagenet_test]): |   for basestr, xlist in zip(['C-010-V', 'C-010-T', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'], [cifar10_valid, cifar10_test, cifar100_valid, cifar100_test, imagenet_valid, imagenet_test]): | ||||||
|     correlation = get_cor(cifar10_valid, xlist) |     correlation = get_cor(cifar10_currs, xlist) | ||||||
|     if need_print: print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, '012' if use_less_or_not else '200', basestr, correlation)) |     if need_print: print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, '012' if use_less_or_not else '200', basestr, correlation)) | ||||||
|     cors.append( correlation ) |     cors.append( correlation ) | ||||||
|     #print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist))) |     #print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist))) | ||||||
| @@ -183,7 +187,8 @@ def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand): | |||||||
|   for i in tqdm(range(100)): |   for i in tqdm(range(100)): | ||||||
|     x = check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand, False) |     x = check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand, False) | ||||||
|     corrs.append( x ) |     corrs.append( x ) | ||||||
|   xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'] |   #xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'] | ||||||
|  |   xstrs = ['C-010-V', 'C-010-T', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'] | ||||||
|   correlations = np.array(corrs) |   correlations = np.array(corrs) | ||||||
|   print('------>>>>>>>> {:03d}/{:} >>>>>>>> ------'.format(test_epoch, '012' if use_less_or_not else '200')) |   print('------>>>>>>>> {:03d}/{:} >>>>>>>> ------'.format(test_epoch, '012' if use_less_or_not else '200')) | ||||||
|   for idx, xstr in enumerate(xstrs): |   for idx, xstr in enumerate(xstrs): | ||||||
| @@ -213,5 +218,6 @@ if __name__ == '__main__': | |||||||
|   check_cor_for_bandit_v2(api,  24, False, True) |   check_cor_for_bandit_v2(api,  24, False, True) | ||||||
|   check_cor_for_bandit_v2(api, 100, False, True) |   check_cor_for_bandit_v2(api, 100, False, True) | ||||||
|   check_cor_for_bandit_v2(api, 150, False, True) |   check_cor_for_bandit_v2(api, 150, False, True) | ||||||
|  |   check_cor_for_bandit_v2(api, 175, False, True) | ||||||
|   check_cor_for_bandit_v2(api, 200, False, True) |   check_cor_for_bandit_v2(api, 200, False, True) | ||||||
|   print('----') |   print('----') | ||||||
|   | |||||||
| @@ -5,6 +5,7 @@ | |||||||
| ################################################## | ################################################## | ||||||
| import os, sys, time, argparse, collections | import os, sys, time, argparse, collections | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
|  | from collections import OrderedDict | ||||||
| import numpy as np | import numpy as np | ||||||
| import torch | import torch | ||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
| @@ -420,7 +421,7 @@ def just_show(api): | |||||||
|             'ENAS'    : [14340.2, 13817.3, 14018.9]} |             'ENAS'    : [14340.2, 13817.3, 14018.9]} | ||||||
|   for xkey, xlist in xtimes.items(): |   for xkey, xlist in xtimes.items(): | ||||||
|     xlist = np.array(xlist) |     xlist = np.array(xlist) | ||||||
|     print ('{:4s} : mean-time={:.1f} s'.format(xkey, xlist.mean())) |     print ('{:4s} : mean-time={:.2f} s'.format(xkey, xlist.mean())) | ||||||
|  |  | ||||||
|   xpaths = {'RSPS'    : 'output/search-cell-nas-bench-102/RANDOM-NAS-cifar10/checkpoint/', |   xpaths = {'RSPS'    : 'output/search-cell-nas-bench-102/RANDOM-NAS-cifar10/checkpoint/', | ||||||
|             'DARTS-V1': 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/', |             'DARTS-V1': 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/', | ||||||
| @@ -546,6 +547,7 @@ if __name__ == '__main__': | |||||||
|   #visualize_relative_ranking(vis_save_dir) |   #visualize_relative_ranking(vis_save_dir) | ||||||
|  |  | ||||||
|   api = API(args.api_path) |   api = API(args.api_path) | ||||||
|  |   """ | ||||||
|   for x_maxs in [50, 250]: |   for x_maxs in [50, 250]: | ||||||
|     show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) |     show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||||
|     show_nas_sharing_w(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) |     show_nas_sharing_w(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||||
| @@ -553,12 +555,11 @@ if __name__ == '__main__': | |||||||
|     show_nas_sharing_w(api, 'cifar100'      , 'x-test'  , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) |     show_nas_sharing_w(api, 'cifar100'      , 'x-test'  , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||||
|     show_nas_sharing_w(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) |     show_nas_sharing_w(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||||
|     show_nas_sharing_w(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) |     show_nas_sharing_w(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||||
|   """ |  | ||||||
|   just_show(api) |   just_show(api) | ||||||
|  |   """ | ||||||
|   plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1)) |   plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1)) | ||||||
|   plot_results_nas(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1)) |   plot_results_nas(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1)) | ||||||
|   plot_results_nas(api, 'cifar100'      , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3)) |   plot_results_nas(api, 'cifar100'      , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3)) | ||||||
|   plot_results_nas(api, 'cifar100'      , 'x-test'  , vis_save_dir, 'nas-com.pdf', (55,75, 3)) |   plot_results_nas(api, 'cifar100'      , 'x-test'  , vis_save_dir, 'nas-com.pdf', (55,75, 3)) | ||||||
|   plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3)) |   plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3)) | ||||||
|   plot_results_nas(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-com.pdf', (35,50, 3)) |   plot_results_nas(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-com.pdf', (35,50, 3)) | ||||||
|   """ |  | ||||||
|   | |||||||
| @@ -184,7 +184,7 @@ def main(xargs, nas_bench): | |||||||
|  |  | ||||||
|   logger.log('workers : {:.1f}s with {:} archs'.format(workers[0].time_budget, len(workers[0].seen_archs))) |   logger.log('workers : {:.1f}s with {:} archs'.format(workers[0].time_budget, len(workers[0].seen_archs))) | ||||||
|   logger.close() |   logger.close() | ||||||
|   return logger.log_dir, nas_bench.query_index_by_arch( best_arch ) |   return logger.log_dir, nas_bench.query_index_by_arch( best_arch ), real_cost_time | ||||||
|    |    | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -219,12 +219,14 @@ if __name__ == '__main__': | |||||||
|     print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset)) |     print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset)) | ||||||
|     nas_bench = API(args.arch_nas_dataset) |     nas_bench = API(args.arch_nas_dataset) | ||||||
|   if args.rand_seed < 0: |   if args.rand_seed < 0: | ||||||
|     save_dir, all_indexes, num = None, [], 500 |     save_dir, all_indexes, num, all_times = None, [], 500, [] | ||||||
|     for i in range(num): |     for i in range(num): | ||||||
|       print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num)) |       print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num)) | ||||||
|       args.rand_seed = random.randint(1, 100000) |       args.rand_seed = random.randint(1, 100000) | ||||||
|       save_dir, index = main(args, nas_bench) |       save_dir, index, ctime = main(args, nas_bench) | ||||||
|       all_indexes.append( index )  |       all_indexes.append( index )  | ||||||
|  |       all_times.append( ctime ) | ||||||
|  |     print ('\n average time : {:.3f} s'.format(sum(all_times)/len(all_times))) | ||||||
|     torch.save(all_indexes, save_dir / 'results.pth') |     torch.save(all_indexes, save_dir / 'results.pth') | ||||||
|   else: |   else: | ||||||
|     main(args, nas_bench) |     main(args, nas_bench) | ||||||
|   | |||||||
| @@ -29,5 +29,5 @@ OMP_NUM_THREADS=4 python ./exps/algos/BOHB.py \ | |||||||
| 	--search_space_name ${space} \ | 	--search_space_name ${space} \ | ||||||
| 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ | 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ | ||||||
| 	--time_budget 12000  \ | 	--time_budget 12000  \ | ||||||
| 	--n_iters 28 --num_samples 64 --random_fraction .33 --bandwidth_factor 3 \ | 	--n_iters 50 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 \ | ||||||
| 	--workers 4 --print_freq 200 --rand_seed ${seed} | 	--workers 4 --print_freq 200 --rand_seed ${seed} | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user