update hp of BOHB
This commit is contained in:
		| @@ -148,6 +148,7 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, n | ||||
|     api = meta_file | ||||
|   else: | ||||
|     api = API(str(meta_file)) | ||||
|   cifar10_currs     = [] | ||||
|   cifar10_valid     = [] | ||||
|   cifar10_test      = [] | ||||
|   cifar100_valid    = [] | ||||
| @@ -156,6 +157,9 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, n | ||||
|   imagenet_valid    = [] | ||||
|   for idx, arch in enumerate(api): | ||||
|     results = api.get_more_info(idx, 'cifar10-valid' , test_epoch-1, use_less_or_not, is_rand) | ||||
|     cifar10_currs.append( results['valid-accuracy'] ) | ||||
|     # --->>>>> | ||||
|     results = api.get_more_info(idx, 'cifar10-valid' , None, False, is_rand) | ||||
|     cifar10_valid.append( results['valid-accuracy'] ) | ||||
|     results = api.get_more_info(idx, 'cifar10'       , None, False, is_rand) | ||||
|     cifar10_test.append( results['test-accuracy'] ) | ||||
| @@ -168,8 +172,8 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, n | ||||
|   def get_cor(A, B): | ||||
|     return float(np.corrcoef(A, B)[0,1]) | ||||
|   cors = [] | ||||
|   for basestr, xlist in zip(['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'], [cifar10_test, cifar100_valid, cifar100_test, imagenet_valid, imagenet_test]): | ||||
|     correlation = get_cor(cifar10_valid, xlist) | ||||
|   for basestr, xlist in zip(['C-010-V', 'C-010-T', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'], [cifar10_valid, cifar10_test, cifar100_valid, cifar100_test, imagenet_valid, imagenet_test]): | ||||
|     correlation = get_cor(cifar10_currs, xlist) | ||||
|     if need_print: print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, '012' if use_less_or_not else '200', basestr, correlation)) | ||||
|     cors.append( correlation ) | ||||
|     #print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist))) | ||||
| @@ -183,7 +187,8 @@ def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand): | ||||
|   for i in tqdm(range(100)): | ||||
|     x = check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand, False) | ||||
|     corrs.append( x ) | ||||
|   xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'] | ||||
|   #xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'] | ||||
|   xstrs = ['C-010-V', 'C-010-T', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'] | ||||
|   correlations = np.array(corrs) | ||||
|   print('------>>>>>>>> {:03d}/{:} >>>>>>>> ------'.format(test_epoch, '012' if use_less_or_not else '200')) | ||||
|   for idx, xstr in enumerate(xstrs): | ||||
| @@ -213,5 +218,6 @@ if __name__ == '__main__': | ||||
|   check_cor_for_bandit_v2(api,  24, False, True) | ||||
|   check_cor_for_bandit_v2(api, 100, False, True) | ||||
|   check_cor_for_bandit_v2(api, 150, False, True) | ||||
|   check_cor_for_bandit_v2(api, 175, False, True) | ||||
|   check_cor_for_bandit_v2(api, 200, False, True) | ||||
|   print('----') | ||||
|   | ||||
| @@ -5,6 +5,7 @@ | ||||
| ################################################## | ||||
| import os, sys, time, argparse, collections | ||||
| from tqdm import tqdm | ||||
| from collections import OrderedDict | ||||
| import numpy as np | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| @@ -412,7 +413,7 @@ def plot_results_nas(api, dataset, xset, root, file_name, y_lims): | ||||
|  | ||||
|  | ||||
| def just_show(api): | ||||
|   xtimes = {'RSPS': [8082.5, 7794.2, 8144.7], | ||||
|   xtimes = {'RSPS'    : [8082.5, 7794.2, 8144.7], | ||||
|             'DARTS-V1': [11582.1, 11347.0, 11948.2], | ||||
|             'DARTS-V2': [35694.7, 36132.7, 35518.0], | ||||
|             'GDAS'    : [31334.1, 31478.6, 32016.7], | ||||
| @@ -420,7 +421,7 @@ def just_show(api): | ||||
|             'ENAS'    : [14340.2, 13817.3, 14018.9]} | ||||
|   for xkey, xlist in xtimes.items(): | ||||
|     xlist = np.array(xlist) | ||||
|     print ('{:4s} : mean-time={:.1f} s'.format(xkey, xlist.mean())) | ||||
|     print ('{:4s} : mean-time={:.2f} s'.format(xkey, xlist.mean())) | ||||
|  | ||||
|   xpaths = {'RSPS'    : 'output/search-cell-nas-bench-102/RANDOM-NAS-cifar10/checkpoint/', | ||||
|             'DARTS-V1': 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/', | ||||
| @@ -546,6 +547,7 @@ if __name__ == '__main__': | ||||
|   #visualize_relative_ranking(vis_save_dir) | ||||
|  | ||||
|   api = API(args.api_path) | ||||
|   """ | ||||
|   for x_maxs in [50, 250]: | ||||
|     show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||
|     show_nas_sharing_w(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||
| @@ -553,12 +555,11 @@ if __name__ == '__main__': | ||||
|     show_nas_sharing_w(api, 'cifar100'      , 'x-test'  , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||
|     show_nas_sharing_w(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||
|     show_nas_sharing_w(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||
|   """ | ||||
|   just_show(api) | ||||
|   """ | ||||
|   plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1)) | ||||
|   plot_results_nas(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1)) | ||||
|   plot_results_nas(api, 'cifar100'      , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3)) | ||||
|   plot_results_nas(api, 'cifar100'      , 'x-test'  , vis_save_dir, 'nas-com.pdf', (55,75, 3)) | ||||
|   plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3)) | ||||
|   plot_results_nas(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-com.pdf', (35,50, 3)) | ||||
|   """ | ||||
|   | ||||
		Reference in New Issue
	
	Block a user