From db44e56fb6b756e03b3d04228fccb6b0fd1060c2 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Thu, 2 Jan 2020 16:49:16 +1100 Subject: [PATCH] update hp of BOHB --- exps/NAS-Bench-102/test-correlation.py | 12 +++++++++--- exps/NAS-Bench-102/visualize.py | 9 +++++---- exps/algos/BOHB.py | 8 +++++--- scripts-search/algos/BOHB.sh | 2 +- 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/exps/NAS-Bench-102/test-correlation.py b/exps/NAS-Bench-102/test-correlation.py index 8a9cbe9..2cb6261 100644 --- a/exps/NAS-Bench-102/test-correlation.py +++ b/exps/NAS-Bench-102/test-correlation.py @@ -148,6 +148,7 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, n api = meta_file else: api = API(str(meta_file)) + cifar10_currs = [] cifar10_valid = [] cifar10_test = [] cifar100_valid = [] @@ -156,6 +157,9 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, n imagenet_valid = [] for idx, arch in enumerate(api): results = api.get_more_info(idx, 'cifar10-valid' , test_epoch-1, use_less_or_not, is_rand) + cifar10_currs.append( results['valid-accuracy'] ) + # --->>>>> + results = api.get_more_info(idx, 'cifar10-valid' , None, False, is_rand) cifar10_valid.append( results['valid-accuracy'] ) results = api.get_more_info(idx, 'cifar10' , None, False, is_rand) cifar10_test.append( results['test-accuracy'] ) @@ -168,8 +172,8 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, n def get_cor(A, B): return float(np.corrcoef(A, B)[0,1]) cors = [] - for basestr, xlist in zip(['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'], [cifar10_test, cifar100_valid, cifar100_test, imagenet_valid, imagenet_test]): - correlation = get_cor(cifar10_valid, xlist) + for basestr, xlist in zip(['C-010-V', 'C-010-T', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'], [cifar10_valid, cifar10_test, cifar100_valid, cifar100_test, imagenet_valid, imagenet_test]): + correlation = get_cor(cifar10_currs, xlist) if need_print: print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, '012' if use_less_or_not else '200', basestr, correlation)) cors.append( correlation ) #print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist))) @@ -183,7 +187,8 @@ def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand): for i in tqdm(range(100)): x = check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand, False) corrs.append( x ) - xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'] + #xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'] + xstrs = ['C-010-V', 'C-010-T', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'] correlations = np.array(corrs) print('------>>>>>>>> {:03d}/{:} >>>>>>>> ------'.format(test_epoch, '012' if use_less_or_not else '200')) for idx, xstr in enumerate(xstrs): @@ -213,5 +218,6 @@ if __name__ == '__main__': check_cor_for_bandit_v2(api, 24, False, True) check_cor_for_bandit_v2(api, 100, False, True) check_cor_for_bandit_v2(api, 150, False, True) + check_cor_for_bandit_v2(api, 175, False, True) check_cor_for_bandit_v2(api, 200, False, True) print('----') diff --git a/exps/NAS-Bench-102/visualize.py b/exps/NAS-Bench-102/visualize.py index e08d474..a41b53c 100644 --- a/exps/NAS-Bench-102/visualize.py +++ b/exps/NAS-Bench-102/visualize.py @@ -5,6 +5,7 @@ ################################################## import os, sys, time, argparse, collections from tqdm import tqdm +from collections import OrderedDict import numpy as np import torch import torch.nn as nn @@ -412,7 +413,7 @@ def plot_results_nas(api, dataset, xset, root, file_name, y_lims): def just_show(api): - xtimes = {'RSPS': [8082.5, 7794.2, 8144.7], + xtimes = {'RSPS' : [8082.5, 7794.2, 8144.7], 'DARTS-V1': [11582.1, 11347.0, 11948.2], 'DARTS-V2': [35694.7, 36132.7, 35518.0], 'GDAS' : [31334.1, 31478.6, 32016.7], @@ -420,7 +421,7 @@ def just_show(api): 'ENAS' : [14340.2, 13817.3, 14018.9]} for xkey, xlist in xtimes.items(): xlist = np.array(xlist) - print ('{:4s} : mean-time={:.1f} s'.format(xkey, xlist.mean())) + print ('{:4s} : mean-time={:.2f} s'.format(xkey, xlist.mean())) xpaths = {'RSPS' : 'output/search-cell-nas-bench-102/RANDOM-NAS-cifar10/checkpoint/', 'DARTS-V1': 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/', @@ -546,6 +547,7 @@ if __name__ == '__main__': #visualize_relative_ranking(vis_save_dir) api = API(args.api_path) + """ for x_maxs in [50, 250]: show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) show_nas_sharing_w(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) @@ -553,12 +555,11 @@ if __name__ == '__main__': show_nas_sharing_w(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) show_nas_sharing_w(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) show_nas_sharing_w(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) - """ just_show(api) + """ plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1)) plot_results_nas(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1)) plot_results_nas(api, 'cifar100' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3)) plot_results_nas(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-com.pdf', (55,75, 3)) plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3)) plot_results_nas(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-com.pdf', (35,50, 3)) - """ diff --git a/exps/algos/BOHB.py b/exps/algos/BOHB.py index e39a2e8..244556e 100644 --- a/exps/algos/BOHB.py +++ b/exps/algos/BOHB.py @@ -184,7 +184,7 @@ def main(xargs, nas_bench): logger.log('workers : {:.1f}s with {:} archs'.format(workers[0].time_budget, len(workers[0].seen_archs))) logger.close() - return logger.log_dir, nas_bench.query_index_by_arch( best_arch ) + return logger.log_dir, nas_bench.query_index_by_arch( best_arch ), real_cost_time @@ -219,12 +219,14 @@ if __name__ == '__main__': print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset)) nas_bench = API(args.arch_nas_dataset) if args.rand_seed < 0: - save_dir, all_indexes, num = None, [], 500 + save_dir, all_indexes, num, all_times = None, [], 500, [] for i in range(num): print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num)) args.rand_seed = random.randint(1, 100000) - save_dir, index = main(args, nas_bench) + save_dir, index, ctime = main(args, nas_bench) all_indexes.append( index ) + all_times.append( ctime ) + print ('\n average time : {:.3f} s'.format(sum(all_times)/len(all_times))) torch.save(all_indexes, save_dir / 'results.pth') else: main(args, nas_bench) diff --git a/scripts-search/algos/BOHB.sh b/scripts-search/algos/BOHB.sh index dac73b3..124558f 100644 --- a/scripts-search/algos/BOHB.sh +++ b/scripts-search/algos/BOHB.sh @@ -29,5 +29,5 @@ OMP_NUM_THREADS=4 python ./exps/algos/BOHB.py \ --search_space_name ${space} \ --arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ --time_budget 12000 \ - --n_iters 28 --num_samples 64 --random_fraction .33 --bandwidth_factor 3 \ + --n_iters 50 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 \ --workers 4 --print_freq 200 --rand_seed ${seed}