Update REA, REINFORCE, RANDOM, and BOHB
.gitignore
@@ -123,3 +123,5 @@ scripts-search/l2s-algos
 TEMP-L.sh
 
 .nfs00*
+*.swo
+*/*.swo

exps/algos-v2/bohb.py

@@ -5,9 +5,9 @@
 # required to install hpbandster ##################################
 # pip install hpbandster         ##################################
 ###################################################################
-# python exps/algos-v2/bohb.py --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3
+# OMP_NUM_THREADS=4 python exps/algos-v2/bohb.py --search_space tss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1
 ###################################################################
-import os, sys, time, random, argparse
+import os, sys, time, random, argparse, collections
 from copy import deepcopy
 from pathlib import Path
 import torch
@@ -17,7 +17,7 @@ from config_utils import load_config
 from datasets     import get_datasets, SearchDataset
 from procedures   import prepare_seed, prepare_logger
 from log_utils    import AverageMeter, time_string, convert_secs2time
-from nas_201_api  import NASBench201API as API
+from nas_201_api  import NASBench201API, NASBench301API
 from models       import CellStructure, get_search_spaces
 # BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018
 import ConfigSpace
@@ -63,52 +63,21 @@ def config2topology_func(max_nodes=4):
 
 class MyWorker(Worker):
 
-  def __init__(self, *args, convert_func=None, dataname=None, nas_bench=None, time_budget=None, **kwargs):
+  def __init__(self, *args, convert_func=None, dataset=None, api=None, **kwargs):
     super().__init__(*args, **kwargs)
     self.convert_func   = convert_func
-    self._dataname      = dataname
-    self._nas_bench     = nas_bench
-    self.time_budget    = time_budget
-    self.seen_archs     = []
-    self.sim_cost_time  = 0
-    self.real_cost_time = 0
-    self.is_end         = False
-
-  def get_the_best(self):
-    assert len(self.seen_archs) > 0
-    best_index, best_acc = -1, None
-    for arch_index in self.seen_archs:
-      info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp='200', is_random=True)
-      vacc = info['valid-accuracy']
-      if best_acc is None or best_acc < vacc:
-        best_acc = vacc
-        best_index = arch_index
-    assert best_index != -1
-    return best_index
+    self._dataset       = dataset
+    self._api           = api
+    self.total_times    = []
+    self.trajectory     = []
 
   def compute(self, config, budget, **kwargs):
-    start_time = time.time()
-    structure  = self.convert_func( config )
-    arch_index = self._nas_bench.query_index_by_arch( structure )
-    info       = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp='200', is_random=True)
-    cur_time   = info['train-all-time'] + info['valid-per-time']
-    cur_vacc   = info['valid-accuracy']
-    self.real_cost_time += (time.time() - start_time)
-    if self.sim_cost_time + cur_time <= self.time_budget and not self.is_end:
-      self.sim_cost_time += cur_time
-      self.seen_archs.append( arch_index )
-      return ({'loss': 100 - float(cur_vacc),
-               'info': {'seen-arch'     : len(self.seen_archs),
-                        'sim-test-time' : self.sim_cost_time,
-                        'current-arch'  : arch_index}
-            })
-    else:
-      self.is_end = True
-      return ({'loss': 100,
-               'info': {'seen-arch'     : len(self.seen_archs),
-                        'sim-test-time' : self.sim_cost_time,
-                        'current-arch'  : None}
-            })
+    arch  = self.convert_func( config )
+    accuracy, latency, time_cost, total_time = self._api.simulate_train_eval(arch, self._dataset, iepoch=int(budget)-1, hp='12')
+    self.trajectory.append((accuracy, arch))
+    self.total_times.append(total_time)
+    return ({'loss': 100 - accuracy,
+             'info': self._api.query_index_by_arch(arch)})
 
 
 def main(xargs, api):
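Note: the rewritten worker drops all local time bookkeeping and defers it to the benchmark API; compute() now just reports a loss that BOHB minimizes plus the benchmark index it evaluated. A minimal runnable sketch of that contract, with a stub standing in for the real NATS-Bench-style api (illustrative values only):

    # Stub mimicking the two API calls compute() makes (hypothetical numbers).
    class StubAPI:
      def simulate_train_eval(self, arch, dataset, iepoch=None, hp='12'):
        return 91.5, 0.01, 110.0, 110.0   # accuracy, latency, time_cost, total_time
      def query_index_by_arch(self, arch):
        return 123
    api = StubAPI()
    accuracy, _, _, total_time = api.simulate_train_eval('fake-arch', 'cifar10', iepoch=11, hp='12')
    result = {'loss': 100 - accuracy, 'info': api.query_index_by_arch('fake-arch')}
    # BOHB minimizes result['loss']; 'info' records which benchmark entry was tried.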
@@ -117,12 +86,13 @@ def main(xargs, api):
   logger = prepare_logger(args)
 
   logger.log('{:} use api : {:}'.format(time_string(), api))
+  api.reset_time()
   search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
   if xargs.search_space == 'tss':
-    cs = get_topology_config_space(xargs.max_nodes, search_space)
-    config2structure = config2topology_func(xargs.max_nodes)
+    cs = get_topology_config_space(search_space)
+    config2structure = config2topology_func()
   else:
-    cs = get_size_config_space(xargs.max_nodes, search_space)
+    cs = get_size_config_space(search_space)
+    import pdb; pdb.set_trace()
 
   hb_run_id = '0'
@@ -133,14 +103,13 @@ def main(xargs, api):
 
   workers = []
   for i in range(num_workers):
-    w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, dataname=dataname, nas_bench=nas_bench, time_budget=xargs.time_budget, run_id=hb_run_id, id=i)
+    w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, dataset=xargs.dataset, api=api, run_id=hb_run_id, id=i)
     w.run(background=True)
     workers.append(w)
 
   start_time = time.time()
-  bohb = BOHB(configspace=cs,
-            run_id=hb_run_id,
-            eta=3, min_budget=12, max_budget=200,
+  bohb = BOHB(configspace=cs, run_id=hb_run_id,
+      eta=3, min_budget=1, max_budget=12,
       nameserver=ns_host,
       nameserver_port=ns_port,
       num_samples=xargs.num_samples,
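Note: with eta=3, min_budget=1 and max_budget=12, the budgets HpBandSter hands to compute() are geometrically spaced, and compute() converts each one to an epoch index of the 12-epoch schedule via iepoch=int(budget)-1. A quick arithmetic sketch (assuming the standard successive-halving rung layout):

    eta, max_budget = 3, 12
    budgets = [max_budget / eta ** i for i in reversed(range(3))]
    print(budgets)                         # [1.333..., 4.0, 12.0]
    print([int(b) - 1 for b in budgets])   # iepoch values: [0, 3, 11]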
@@ -152,22 +121,23 @@ def main(xargs, api):
   bohb.shutdown(shutdown_workers=True)
   NS.shutdown()
 
-  real_cost_time = time.time() - start_time
+  # print('There are {:} runs.'.format(len(results.get_all_runs())))
+  # workers[0].total_times
+  # workers[0].trajectory
+  current_best_index = []
+  for idx in range(len(workers[0].trajectory)):
+    trajectory = workers[0].trajectory[:idx+1]
+    arch = max(trajectory, key=lambda x: x[0])[1]
+    current_best_index.append(api.query_index_by_arch(arch))
+
   id2config = results.get_id2config_mapping()
   incumbent = results.get_incumbent_id()
-  logger.log('Best found configuration: {:} within {:.3f} s'.format(id2config[incumbent]['config'], real_cost_time))
-  best_arch = config2structure( id2config[incumbent]['config'] )
-
-  info = nas_bench.query_by_arch(best_arch, '200')
-  if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
-  else           : logger.log('{:}'.format(info))
+  best_arch = max(workers[0].trajectory, key=lambda x: x[0])[1]
+  logger.log('Best found configuration: {:} within {:.3f} s'.format(best_arch, workers[0].total_times[-1]))
+  info = api.query_info_str_by_arch(best_arch, '200' if xargs.search_space == 'tss' else '90')
+  logger.log('{:}'.format(info))
   logger.log('-'*100)
-
-  logger.log('workers : {:.1f}s with {:} archs'.format(workers[0].time_budget, len(workers[0].seen_archs)))
   logger.close()
-  return logger.log_dir, nas_bench.query_index_by_arch( best_arch ), real_cost_time
+
+  return logger.log_dir, current_best_index, workers[0].total_times
 
 
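Note: main() no longer returns a single incumbent index; it returns the anytime trajectory (current_best_index plus the cumulative simulated times), which the plotting script below consumes. A self-contained sketch of the best-so-far logic above, with made-up accuracies:

    trajectory = [(90.2, 'arch-a'), (89.7, 'arch-b'), (93.1, 'arch-c')]  # (accuracy, arch)
    best_so_far = [max(trajectory[:i + 1], key=lambda x: x[0])[1] for i in range(len(trajectory))]
    print(best_so_far)  # ['arch-a', 'arch-a', 'arch-c']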
@@ -185,8 +155,8 @@ if __name__ == '__main__':
   parser.add_argument('--bandwidth_factor', default=3,   type=int, nargs='?', help='factor multiplied to the bandwidth')
   parser.add_argument('--n_iters',          default=300, type=int, nargs='?', help='number of iterations for optimization method')
   # log
-  parser.add_argument('--save_dir',           type=str,   help='Folder to save checkpoints and log.')
-  parser.add_argument('--rand_seed',          type=int,   help='manual seed')
+  parser.add_argument('--save_dir',           type=str,  default='./output/search', help='Folder to save checkpoints and log.')
+  parser.add_argument('--rand_seed',          type=int,  default=-1, help='manual seed')
   args = parser.parse_args()
 
   if args.search_space == 'tss':

exps/algos-v2/random_wo_share.py
@@ -43,7 +43,7 @@ def main(xargs, api):
   current_best_index = []
   while len(total_time_cost) == 0 or total_time_cost[-1] < xargs.time_budget:
     arch = random_arch()
-    accuracy, _, _, total_cost = api.simulate_train_eval(arch, xargs.dataset, '12')
+    accuracy, _, _, total_cost = api.simulate_train_eval(arch, xargs.dataset, hp='12')
     total_time_cost.append(total_cost)
     history.append(arch)
     if best_arch is None or best_acc < accuracy:
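Note: the switch from a positional '12' to hp='12' here (and in the REA and REINFORCE hunks below) is forced by the simulate_train_eval signature change at the end of this commit, which inserts iepoch before hp; left positional, '12' would now bind to iepoch. Sketch:

    # new signature: simulate_train_eval(self, arch, dataset, iepoch=None, hp='12', account_time=True)
    api.simulate_train_eval(arch, xargs.dataset, '12')     # wrong now: '12' binds to iepoch
    api.simulate_train_eval(arch, xargs.dataset, hp='12')  # correct: 12-epoch schedule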

exps/algos-v2/regularized_ea.py

@@ -160,7 +160,7 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
   while len(population) < population_size:
     model = Model()
     model.arch = random_arch()
-    model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, '12')
+    model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp='12')
     # Append the info
     population.append(model)
     history.append((model.accuracy, model.arch))
@@ -184,7 +184,7 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
     # Create the child model and store it.
     child = Model()
     child.arch = mutate_arch(parent.arch)
-    child.accuracy, _, _, total_cost = api.simulate_train_eval(child.arch, dataset, '12')
+    child.accuracy, _, _, total_cost = api.simulate_train_eval(child.arch, dataset, hp='12')
     # Append the info
     population.append(child)
     history.append((child.accuracy, child.arch))
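Note: these two hunks touch the warm-up loop and the mutation step of regularized evolution (Real et al., 2019). For orientation, the surrounding cycle, sketched from the published algorithm rather than from this diff, keeps the population at a fixed size by aging out its oldest member:

    import collections
    population = collections.deque()   # the warm-up loop above fills this to population_size
    # each cycle: sample candidates, take the most accurate as parent,
    # mutate it into a child (the hunk above), then age out the oldest:
    # population.popleft()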

exps/algos-v2/reinforce.py

@@ -150,7 +150,7 @@ def main(xargs, api):
     start_time = time.time()
     log_prob, action = select_action( policy )
     arch   = policy.generate_arch( action )
-    reward, _, _, current_total_cost = api.simulate_train_eval(arch, xargs.dataset, '12')
+    reward, _, _, current_total_cost = api.simulate_train_eval(arch, xargs.dataset, hp='12')
     trace.append((reward, arch))
     total_costs.append(current_total_cost)
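Note: this loop is only the sampling half of REINFORCE; the simulated validation accuracy serves as the reward. A one-line reminder of the policy-gradient step it feeds (variable names are from the algorithm, not this diff; the baseline is typically a moving average of rewards):

    log_prob, reward, baseline = -1.2, 92.4, 91.0    # illustrative scalars
    policy_loss = -log_prob * (reward - baseline)    # minimizing this ascends log-prob of above-baseline archs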

exps/algos-v2/run-all.sh

@@ -1,18 +1,19 @@
 #!/bin/bash
 # bash ./exps/algos-v2/run-all.sh
 set -e
 echo script name: $0
 echo $# arguments
 
 datasets="cifar10 cifar100 ImageNet16-120"
 search_spaces="tss sss"
 
 
 for dataset in ${datasets}
 do
   for search_space in ${search_spaces}
   do
-    # python ./exps/algos-v2/reinforce.py --dataset ${dataset} --search_space ${search_space} --learning_rate 0.001
+    python ./exps/algos-v2/reinforce.py --dataset ${dataset} --search_space ${search_space} --learning_rate 0.001
     python ./exps/algos-v2/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --ea_cycles 200 --ea_population 10 --ea_sample_size 3
-    # python ./exps/algos-v2/random_wo_share.py --dataset ${dataset} --search_space ${search_space}
+    python ./exps/algos-v2/random_wo_share.py --dataset ${dataset} --search_space ${search_space}
+    python exps/algos-v2/bohb.py --dataset ${dataset} --search_space ${search_space} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3
   done
 done
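Note: with every line active, one invocation of this script launches 3 datasets x 2 search spaces x 4 algorithms = 24 search runs.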

exps/experimental/vis-bench-algos.py

@@ -5,7 +5,7 @@
 ###############################################################
 # Usage: python exps/experimental/vis-bench-algos.py          #
 ###############################################################
-import os, sys, time, torch, argparse
+import os, gc, sys, time, torch, argparse
 import numpy as np
 from typing import List, Text, Dict, Any
 from shutil import copyfile
@@ -31,6 +31,7 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
   alg2name['REA'] = 'R-EA-SS3'
   alg2name['REINFORCE'] = 'REINFORCE-0.001'
   alg2name['RANDOM'] = 'RANDOM'
+  alg2name['BOHB'] = 'BOHB'
   for alg, name in alg2name.items():
     alg2path[alg] = os.path.join(ss_dir, dataset, name, 'results.pth')
     assert os.path.isfile(alg2path[alg]), 'invalid path : {:}'.format(alg2path[alg])
@@ -58,14 +59,27 @@ def query_performance(api, data, dataset, ticket):
     results.append(interplate)
   return sum(results) / len(results)
 
+y_min_s = {('cifar10', 'tss'): 90,
+           ('cifar10', 'sss'): 92,
+           ('cifar100', 'tss'): 65,
+           ('cifar100', 'sss'): 65,
+           ('ImageNet16-120', 'tss'): 36,
+           ('ImageNet16-120', 'sss'): 40}
+
+y_max_s = {('cifar10', 'tss'): 94.5,
+           ('cifar10', 'sss'): 93.3,
+           ('cifar100', 'tss'): 72,
+           ('cifar100', 'sss'): 70,
+           ('ImageNet16-120', 'tss'): 44,
+           ('ImageNet16-120', 'sss'): 46}
 
 def visualize_curve(api, vis_save_dir, search_space, max_time):
   vis_save_dir = vis_save_dir.resolve()
   vis_save_dir.mkdir(parents=True, exist_ok=True)
 
-  dpi, width, height = 250, 5100, 1500
+  dpi, width, height = 250, 5200, 1400
   figsize = width / float(dpi), height / float(dpi)
-  LabelSize, LegendFontsize = 14, 14
+  LabelSize, LegendFontsize = 16, 16
 
   def sub_plot_fn(ax, dataset):
     alg2data = fetch_data(search_space=search_space, dataset=dataset)
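Note: the figure dimensions are specified in pixels and divided by dpi, so the new values give a wide strip suited to side-by-side dataset panels:

    dpi, width, height = 250, 5200, 1400
    figsize = width / float(dpi), height / float(dpi)
    print(figsize)  # (20.8, 5.6) inches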
@@ -73,6 +87,8 @@ def visualize_curve(api, vis_save_dir, search_space, max_time):
     total_tickets = 150
     time_tickets = [float(i) / total_tickets * max_time for i in range(total_tickets)]
     colors = ['b', 'g', 'c', 'm', 'y']
+    ax.set_xlim(0, 200)
+    ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)])
     for idx, (alg, data) in enumerate(alg2data.items()):
       print('plot alg : {:}'.format(alg))
       accuracies = []
@@ -107,5 +123,7 @@ if __name__ == '__main__':
 
   api201 = NASBench201API(verbose=False)
   visualize_curve(api201, save_dir, 'tss', args.max_time)
+  del api201
+  gc.collect()
   api301 = NASBench301API(verbose=False)
   visualize_curve(api301, save_dir, 'sss', args.max_time)
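Note: each benchmark API object holds its full result tables in memory, so the del api201 / gc.collect() pair presumably releases the topology benchmark before the size benchmark is loaded; that is also why gc joins the imports at the top of this file.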

@@ -68,14 +68,14 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
   def reset_time(self):
     self._used_time = 0
 
-  def simulate_train_eval(self, arch, dataset, hp='12', account_time=True):
+  def simulate_train_eval(self, arch, dataset, iepoch=None, hp='12', account_time=True):
     index = self.query_index_by_arch(arch)
     all_names = ('cifar10', 'cifar100', 'ImageNet16-120')
     assert dataset in all_names, 'Invalid dataset name : {:} vs {:}'.format(dataset, all_names)
     if dataset == 'cifar10':
-      info = self.get_more_info(index, 'cifar10-valid', iepoch=None, hp=hp, is_random=True)
+      info = self.get_more_info(index, 'cifar10-valid', iepoch=iepoch, hp=hp, is_random=True)
     else:
-      info = self.get_more_info(index, dataset, iepoch=None, hp=hp, is_random=True)
+      info = self.get_more_info(index, dataset, iepoch=iepoch, hp=hp, is_random=True)
     valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
     latency = self.get_latency(index, dataset)
     if account_time:
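Note: iepoch threads straight through to get_more_info(), letting callers query the checkpoint at an intermediate epoch; the BOHB worker above passes iepoch=int(budget)-1, implying 0-indexed epochs. A hedged usage sketch against a loaded API:

    acc, latency, time_cost, total_time = api.simulate_train_eval(
        arch, 'cifar10', iepoch=3, hp='12')   # validation accuracy after the 4th of 12 epochs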