update README
This commit is contained in:
		| @@ -2,6 +2,7 @@ | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| # required to install hpbandster ################# | ||||
| # bash ./scripts-search/algos/BOHB.sh -1         # | ||||
| ################################################## | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np, collections | ||||
| @@ -19,7 +20,6 @@ from utils        import get_model_infos, obtain_accuracy | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from nas_102_api  import NASBench102API as API | ||||
| from models       import CellStructure, get_search_spaces | ||||
| from R_EA import train_and_eval | ||||
| # BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018 | ||||
| import ConfigSpace | ||||
| from hpbandster.optimizers.bohb import BOHB | ||||
| @@ -53,21 +53,44 @@ def config2structure_func(max_nodes): | ||||
|  | ||||
| class MyWorker(Worker): | ||||
|  | ||||
|   def __init__(self, *args, sleep_interval=0, convert_func=None, nas_bench=None, **kwargs): | ||||
|   def __init__(self, *args, convert_func=None, nas_bench=None, time_scale=None, **kwargs): | ||||
|     super().__init__(*args, **kwargs) | ||||
|     self.sleep_interval = sleep_interval | ||||
|     self.convert_func   = convert_func | ||||
|     self.nas_bench      = nas_bench | ||||
|     self.test_time      = 0 | ||||
|     self.time_scale     = time_scale | ||||
|     self.seen_arch      = 0 | ||||
|     self.sim_cost_time  = 0 | ||||
|     self.real_cost_time = 0 | ||||
|  | ||||
|   def compute(self, config, budget, **kwargs): | ||||
|     structure = self.convert_func( config ) | ||||
|     reward, time_cost = train_and_eval(structure, self.nas_bench, None) | ||||
|     import pdb; pdb.set_trace() | ||||
|     self.test_time += 1 | ||||
|     start_time = time.time() | ||||
|     structure  = self.convert_func( config ) | ||||
|     arch_index = self.nas_bench.query_index_by_arch( structure ) | ||||
|     iepoch     = 0 | ||||
|     while iepoch < 12: | ||||
|       info     = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', iepoch, True) | ||||
|       cur_time = info['train-all-time'] + info['valid-per-time'] | ||||
|       cur_vacc = info['valid-accuracy'] | ||||
|       if time.time() - start_time + cur_time / self.time_scale > budget: | ||||
|         break | ||||
|       else: | ||||
|         iepoch += 1 | ||||
|     self.sim_cost_time += cur_time | ||||
|     self.seen_arch += 1 | ||||
|     remaining_time = cur_time / self.time_scale - (time.time() - start_time) | ||||
|     if remaining_time > 0: | ||||
|       time.sleep(remaining_time) | ||||
|     else: | ||||
|       import pdb; pdb.set_trace() | ||||
|     self.real_cost_time += (time.time() - start_time) | ||||
|     return ({ | ||||
|             'loss': float(100-reward), | ||||
|             'info': time_cost}) | ||||
|             'loss': 100 - float(cur_vacc), | ||||
|             'info': {'seen-arch'     : self.seen_arch, | ||||
|                      'sim-test-time' : self.sim_cost_time, | ||||
|                      'real-test-time': self.real_cost_time, | ||||
|                      'current-arch'  : arch_index, | ||||
|                      'current-budget': budget} | ||||
|             }) | ||||
|  | ||||
|  | ||||
| def main(xargs, nas_bench): | ||||
| @@ -116,26 +139,30 @@ def main(xargs, nas_bench): | ||||
|   #logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string())) | ||||
|   workers = [] | ||||
|   for i in range(num_workers): | ||||
|     w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, run_id=hb_run_id, id=i) | ||||
|     w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, time_scale=xargs.time_scale, run_id=hb_run_id, id=i) | ||||
|     w.run(background=True) | ||||
|     workers.append(w) | ||||
|  | ||||
|   simulate_time_budge = xargs.time_budget // xargs.time_scale | ||||
|   start_time = time.time() | ||||
|   logger.log('simulate_time_budge : {:} (in seconds).'.format(simulate_time_budge)) | ||||
|   bohb = BOHB(configspace=cs, | ||||
|             run_id=hb_run_id, | ||||
|             eta=3, min_budget=3, max_budget=xargs.time_budget, | ||||
|             eta=3, min_budget=simulate_time_budge//3, max_budget=simulate_time_budge, | ||||
|             nameserver=ns_host, | ||||
|             nameserver_port=ns_port, | ||||
|             num_samples=xargs.num_samples, | ||||
|             random_fraction=xargs.random_fraction, bandwidth_factor=xargs.bandwidth_factor, | ||||
|             ping_interval=10, min_bandwidth=xargs.min_bandwidth) | ||||
|   #          optimization_strategy=xargs.strategy, num_samples=xargs.num_samples, | ||||
|    | ||||
|   results = bohb.run(xargs.n_iters, min_n_workers=num_workers) | ||||
|   import pdb; pdb.set_trace() | ||||
|  | ||||
|   bohb.shutdown(shutdown_workers=True) | ||||
|   NS.shutdown() | ||||
|  | ||||
|   real_cost_time = time.time() - start_time | ||||
|   import pdb; pdb.set_trace() | ||||
|  | ||||
|   id2config = results.get_id2config_mapping() | ||||
|   incumbent = results.get_incumbent_id() | ||||
|  | ||||
| @@ -163,6 +190,7 @@ if __name__ == '__main__': | ||||
|   parser.add_argument('--channel',            type=int,   help='The number of channels.') | ||||
|   parser.add_argument('--num_cells',          type=int,   help='The number of cells in one stage.') | ||||
|   parser.add_argument('--time_budget',        type=int,   help='The total time cost budge for searching (in seconds).') | ||||
|   parser.add_argument('--time_scale' ,        type=int,   help='The time scale to accelerate the time budget.') | ||||
|   # BOHB | ||||
|   parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function') | ||||
|   parser.add_argument('--min_bandwidth',    default=.3, type=float, nargs='?', help='minimum bandwidth for KDE') | ||||
|   | ||||
| @@ -59,7 +59,7 @@ def train_and_eval(arch, nas_bench, extra_info): | ||||
|   if nas_bench is not None: | ||||
|     arch_index = nas_bench.query_index_by_arch( arch ) | ||||
|     assert arch_index >= 0, 'can not find this arch : {:}'.format(arch) | ||||
|     info = nas_bench.get_more_info(arch_index, 'cifar10-valid', True) | ||||
|     info = nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True) | ||||
|     valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time'] | ||||
|     #_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs | ||||
|   else: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user