Update new version of BOHB
This commit is contained in:
		| @@ -6,6 +6,7 @@ | |||||||
| # pip install hpbandster         ################################## | # pip install hpbandster         ################################## | ||||||
| ################################################################### | ################################################################### | ||||||
| # OMP_NUM_THREADS=4 python exps/algos-v2/bohb.py --search_space tss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1 | # OMP_NUM_THREADS=4 python exps/algos-v2/bohb.py --search_space tss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1 | ||||||
|  | # OMP_NUM_THREADS=4 python exps/algos-v2/bohb.py --search_space sss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1 | ||||||
| ################################################################### | ################################################################### | ||||||
| import os, sys, time, random, argparse, collections | import os, sys, time, random, argparse, collections | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| @@ -38,12 +39,9 @@ def get_topology_config_space(search_space, max_nodes=4): | |||||||
|  |  | ||||||
| def get_size_config_space(search_space): | def get_size_config_space(search_space): | ||||||
|   cs = ConfigSpace.ConfigurationSpace() |   cs = ConfigSpace.ConfigurationSpace() | ||||||
|   import pdb; pdb.set_trace() |   for ilayer in range(search_space['numbers']): | ||||||
|   #edge2index   = {} |     node_str = 'layer-{:}'.format(ilayer) | ||||||
|   for i in range(1, max_nodes): |     cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space['candidates'])) | ||||||
|     for j in range(i): |  | ||||||
|       node_str = '{:}<-{:}'.format(i, j) |  | ||||||
|       cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space)) |  | ||||||
|   return cs |   return cs | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -61,6 +59,16 @@ def config2topology_func(max_nodes=4): | |||||||
|   return config2structure |   return config2structure | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def config2size_func(search_space): | ||||||
|  |   def config2structure(config): | ||||||
|  |     channels = [] | ||||||
|  |     for ilayer in range(search_space['numbers']): | ||||||
|  |       node_str = 'layer-{:}'.format(ilayer) | ||||||
|  |       channels.append(str(config[node_str])) | ||||||
|  |     return ':'.join(channels) | ||||||
|  |   return config2structure | ||||||
|  |  | ||||||
|  |  | ||||||
| class MyWorker(Worker): | class MyWorker(Worker): | ||||||
|  |  | ||||||
|   def __init__(self, *args, convert_func=None, dataset=None, api=None, **kwargs): |   def __init__(self, *args, convert_func=None, dataset=None, api=None, **kwargs): | ||||||
| @@ -93,7 +101,7 @@ def main(xargs, api): | |||||||
|     config2structure = config2topology_func() |     config2structure = config2topology_func() | ||||||
|   else: |   else: | ||||||
|     cs = get_size_config_space(search_space) |     cs = get_size_config_space(search_space) | ||||||
|     import pdb; pdb.set_trace() |     config2structure = config2size_func(search_space) | ||||||
|    |    | ||||||
|   hb_run_id = '0' |   hb_run_id = '0' | ||||||
|  |  | ||||||
|   | |||||||
| @@ -17,3 +17,6 @@ do | |||||||
|     python exps/algos-v2/bohb.py --dataset ${dataset} --search_space ${search_space} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 |     python exps/algos-v2/bohb.py --dataset ${dataset} --search_space ${search_space} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 | ||||||
|   done |   done | ||||||
| done | done | ||||||
|  |  | ||||||
|  | python exps/experimental/vis-bench-algos.py --search_space tss | ||||||
|  | python exps/experimental/vis-bench-algos.py --search_space sss | ||||||
|   | |||||||
| @@ -3,7 +3,8 @@ | |||||||
| ############################################################### | ############################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06           # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06           # | ||||||
| ############################################################### | ############################################################### | ||||||
| # Usage: python exps/experimental/vis-bench-algos.py          # | # Usage: python exps/experimental/vis-bench-algos.py --search_space tss | ||||||
|  | # Usage: python exps/experimental/vis-bench-algos.py --search_space sss | ||||||
| ############################################################### | ############################################################### | ||||||
| import os, gc, sys, time, torch, argparse | import os, gc, sys, time, torch, argparse | ||||||
| import numpy as np | import numpy as np | ||||||
| @@ -116,14 +117,16 @@ def visualize_curve(api, vis_save_dir, search_space, max_time): | |||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|   parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter) |   parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||||
|   parser.add_argument('--save_dir',     type=str,   default='output/vis-nas-bench/nas-algos', help='Folder to save checkpoints and log.') |   parser.add_argument('--save_dir',     type=str,   default='output/vis-nas-bench/nas-algos', help='Folder to save checkpoints and log.') | ||||||
|  |   parser.add_argument('--search_space', type=str,   choices=['tss', 'sss'], help='Choose the search space.') | ||||||
|   parser.add_argument('--max_time',     type=float, default=20000, help='The maximum time budget.') |   parser.add_argument('--max_time',     type=float, default=20000, help='The maximum time budget.') | ||||||
|   args = parser.parse_args() |   args = parser.parse_args() | ||||||
|  |  | ||||||
|   save_dir = Path(args.save_dir) |   save_dir = Path(args.save_dir) | ||||||
|  |  | ||||||
|   api201 = NASBench201API(verbose=False) |   if args.search_space == 'tss': | ||||||
|   visualize_curve(api201, save_dir, 'tss', args.max_time) |     api = NASBench201API(verbose=False) | ||||||
|   del api201 |   elif args.search_space == 'sss': | ||||||
|   gc.collect() |     api = NASBench301API(verbose=False) | ||||||
|   api301 = NASBench301API(verbose=False) |   else: | ||||||
|   visualize_curve(api301, save_dir, 'sss', args.max_time) |     raise ValueError('Invalid search space : {:}'.format(args.search_space)) | ||||||
|  |   visualize_curve(api, save_dir, args.search_space, args.max_time) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user