update README
This commit is contained in:
		| @@ -26,9 +26,10 @@ It is recommended to put these data into `$TORCH_HOME` (`~/.torch/` by default). | |||||||
|  |  | ||||||
| 1. Creating an API instance from a file: | 1. Creating an API instance from a file: | ||||||
| ``` | ``` | ||||||
| from nas_102_api import NASBench102API | from nas_102_api import NASBench102API as API | ||||||
| api = NASBench102API('$path_to_meta_nas_bench_file') | api = API('$path_to_meta_nas_bench_file') | ||||||
| api = NASBench102API('NAS-Bench-102-v1_0-e61699.pth') | api = API('NAS-Bench-102-v1_0-e61699.pth') | ||||||
|  | api = API('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-Bench-102-v1_0-e61699.pth')) | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| 2. Show the number of architectures `len(api)` and each architecture `api[i]`: | 2. Show the number of architectures `len(api)` and each architecture `api[i]`: | ||||||
| @@ -45,12 +46,12 @@ api.show(1) | |||||||
| api.show(2) | api.show(2) | ||||||
|  |  | ||||||
| # show the mean loss and accuracy of an architecture | # show the mean loss and accuracy of an architecture | ||||||
| info = api.query_meta_info_by_index(1) | info = api.query_meta_info_by_index(1)  # This is an instance of `ArchResults` | ||||||
| res_metrics = info.get_metrics('cifar10', 'train') | res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric names as keys | ||||||
| cost_metrics = info.get_comput_costs('cifar100') | cost_metrics = info.get_comput_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency | ||||||
|  |  | ||||||
| # get the detailed information | # get the detailed information | ||||||
| results = api.query_by_index(1, 'cifar100') | results = api.query_by_index(1, 'cifar100') # a list of all trials on cifar100 | ||||||
| print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1])) | print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1])) | ||||||
| print ('Latency : {:}'.format(results[0].get_latency())) | print ('Latency : {:}'.format(results[0].get_latency())) | ||||||
| print ('Train Info : {:}'.format(results[0].get_train())) | print ('Train Info : {:}'.format(results[0].get_train())) | ||||||
|   | |||||||
| @@ -35,6 +35,8 @@ We build a new benchmark for neural architecture search, please see more details | |||||||
| The benchmark data file (v1.0) is `NAS-Bench-102-v1_0-e61699.pth`, which can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs). | The benchmark data file (v1.0) is `NAS-Bench-102-v1_0-e61699.pth`, which can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs). | ||||||
|  |  | ||||||
| ## [Network Pruning via Transformable Architecture Search](https://arxiv.org/abs/1905.09717) | ## [Network Pruning via Transformable Architecture Search](https://arxiv.org/abs/1905.09717) | ||||||
|  | [Papers with Code: Network Pruning on CIFAR-100](https://paperswithcode.com/sota/network-pruning-on-cifar-100?p=network-pruning-via-transformable) | ||||||
|  |  | ||||||
| In this paper, we proposed a differentiable searching strategy for transformable architectures, i.e., searching for the depth and width of a deep neural network. | In this paper, we proposed a differentiable searching strategy for transformable architectures, i.e., searching for the depth and width of a deep neural network. | ||||||
| You can see the highlights of our Transformable Architecture Search (TAS) on our [project page](https://xuanyidong.com/assets/projects/NeurIPS-2019-TAS.html). | You can see the highlights of our Transformable Architecture Search (TAS) on our [project page](https://xuanyidong.com/assets/projects/NeurIPS-2019-TAS.html). | ||||||
|  |  | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ | |||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||||
| ################################################## | ################################################## | ||||||
| # required to install hpbandster ################# | # required to install hpbandster ################# | ||||||
|  | # bash ./scripts-search/algos/BOHB.sh -1         # | ||||||
| ################################################## | ################################################## | ||||||
| import os, sys, time, glob, random, argparse | import os, sys, time, glob, random, argparse | ||||||
| import numpy as np, collections | import numpy as np, collections | ||||||
| @@ -19,7 +20,6 @@ from utils        import get_model_infos, obtain_accuracy | |||||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | from log_utils    import AverageMeter, time_string, convert_secs2time | ||||||
| from nas_102_api  import NASBench102API as API | from nas_102_api  import NASBench102API as API | ||||||
| from models       import CellStructure, get_search_spaces | from models       import CellStructure, get_search_spaces | ||||||
| from R_EA import train_and_eval |  | ||||||
| # BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018 | # BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018 | ||||||
| import ConfigSpace | import ConfigSpace | ||||||
| from hpbandster.optimizers.bohb import BOHB | from hpbandster.optimizers.bohb import BOHB | ||||||
| @@ -53,21 +53,44 @@ def config2structure_func(max_nodes): | |||||||
|  |  | ||||||
| class MyWorker(Worker): | class MyWorker(Worker): | ||||||
|  |  | ||||||
|   def __init__(self, *args, sleep_interval=0, convert_func=None, nas_bench=None, **kwargs): |   def __init__(self, *args, convert_func=None, nas_bench=None, time_scale=None, **kwargs): | ||||||
|     super().__init__(*args, **kwargs) |     super().__init__(*args, **kwargs) | ||||||
|     self.sleep_interval = sleep_interval |  | ||||||
|     self.convert_func   = convert_func |     self.convert_func   = convert_func | ||||||
|     self.nas_bench      = nas_bench |     self.nas_bench      = nas_bench | ||||||
|     self.test_time      = 0 |     self.time_scale     = time_scale | ||||||
|  |     self.seen_arch      = 0 | ||||||
|  |     self.sim_cost_time  = 0 | ||||||
|  |     self.real_cost_time = 0 | ||||||
|  |  | ||||||
|   def compute(self, config, budget, **kwargs): |   def compute(self, config, budget, **kwargs): | ||||||
|  |     start_time = time.time() | ||||||
|     structure  = self.convert_func( config ) |     structure  = self.convert_func( config ) | ||||||
|     reward, time_cost = train_and_eval(structure, self.nas_bench, None) |     arch_index = self.nas_bench.query_index_by_arch( structure ) | ||||||
|  |     iepoch     = 0 | ||||||
|  |     while iepoch < 12: | ||||||
|  |       info     = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', iepoch, True) | ||||||
|  |       cur_time = info['train-all-time'] + info['valid-per-time'] | ||||||
|  |       cur_vacc = info['valid-accuracy'] | ||||||
|  |       if time.time() - start_time + cur_time / self.time_scale > budget: | ||||||
|  |         break | ||||||
|  |       else: | ||||||
|  |         iepoch += 1 | ||||||
|  |     self.sim_cost_time += cur_time | ||||||
|  |     self.seen_arch += 1 | ||||||
|  |     remaining_time = cur_time / self.time_scale - (time.time() - start_time) | ||||||
|  |     if remaining_time > 0: | ||||||
|  |       time.sleep(remaining_time) | ||||||
|  |     else: | ||||||
|       import pdb; pdb.set_trace() |       import pdb; pdb.set_trace() | ||||||
|     self.test_time += 1 |     self.real_cost_time += (time.time() - start_time) | ||||||
|     return ({ |     return ({ | ||||||
|             'loss': float(100-reward), |             'loss': 100 - float(cur_vacc), | ||||||
|             'info': time_cost}) |             'info': {'seen-arch'     : self.seen_arch, | ||||||
|  |                      'sim-test-time' : self.sim_cost_time, | ||||||
|  |                      'real-test-time': self.real_cost_time, | ||||||
|  |                      'current-arch'  : arch_index, | ||||||
|  |                      'current-budget': budget} | ||||||
|  |             }) | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(xargs, nas_bench): | def main(xargs, nas_bench): | ||||||
| @@ -116,26 +139,30 @@ def main(xargs, nas_bench): | |||||||
|   #logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string())) |   #logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string())) | ||||||
|   workers = [] |   workers = [] | ||||||
|   for i in range(num_workers): |   for i in range(num_workers): | ||||||
|     w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, run_id=hb_run_id, id=i) |     w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, time_scale=xargs.time_scale, run_id=hb_run_id, id=i) | ||||||
|     w.run(background=True) |     w.run(background=True) | ||||||
|     workers.append(w) |     workers.append(w) | ||||||
|  |  | ||||||
|  |   simulate_time_budge = xargs.time_budget // xargs.time_scale | ||||||
|  |   start_time = time.time() | ||||||
|  |   logger.log('simulate_time_budge : {:} (in seconds).'.format(simulate_time_budge)) | ||||||
|   bohb = BOHB(configspace=cs, |   bohb = BOHB(configspace=cs, | ||||||
|             run_id=hb_run_id, |             run_id=hb_run_id, | ||||||
|             eta=3, min_budget=3, max_budget=xargs.time_budget, |             eta=3, min_budget=simulate_time_budge//3, max_budget=simulate_time_budge, | ||||||
|             nameserver=ns_host, |             nameserver=ns_host, | ||||||
|             nameserver_port=ns_port, |             nameserver_port=ns_port, | ||||||
|             num_samples=xargs.num_samples, |             num_samples=xargs.num_samples, | ||||||
|             random_fraction=xargs.random_fraction, bandwidth_factor=xargs.bandwidth_factor, |             random_fraction=xargs.random_fraction, bandwidth_factor=xargs.bandwidth_factor, | ||||||
|             ping_interval=10, min_bandwidth=xargs.min_bandwidth) |             ping_interval=10, min_bandwidth=xargs.min_bandwidth) | ||||||
|   #          optimization_strategy=xargs.strategy, num_samples=xargs.num_samples, |  | ||||||
|    |    | ||||||
|   results = bohb.run(xargs.n_iters, min_n_workers=num_workers) |   results = bohb.run(xargs.n_iters, min_n_workers=num_workers) | ||||||
|   import pdb; pdb.set_trace() |  | ||||||
|  |  | ||||||
|   bohb.shutdown(shutdown_workers=True) |   bohb.shutdown(shutdown_workers=True) | ||||||
|   NS.shutdown() |   NS.shutdown() | ||||||
|  |  | ||||||
|  |   real_cost_time = time.time() - start_time | ||||||
|  |   import pdb; pdb.set_trace() | ||||||
|  |  | ||||||
|   id2config = results.get_id2config_mapping() |   id2config = results.get_id2config_mapping() | ||||||
|   incumbent = results.get_incumbent_id() |   incumbent = results.get_incumbent_id() | ||||||
|  |  | ||||||
| @@ -163,6 +190,7 @@ if __name__ == '__main__': | |||||||
|   parser.add_argument('--channel',            type=int,   help='The number of channels.') |   parser.add_argument('--channel',            type=int,   help='The number of channels.') | ||||||
|   parser.add_argument('--num_cells',          type=int,   help='The number of cells in one stage.') |   parser.add_argument('--num_cells',          type=int,   help='The number of cells in one stage.') | ||||||
|   parser.add_argument('--time_budget',        type=int,   help='The total time cost budge for searching (in seconds).') |   parser.add_argument('--time_budget',        type=int,   help='The total time cost budge for searching (in seconds).') | ||||||
|  |   parser.add_argument('--time_scale' ,        type=int,   help='The time scale to accelerate the time budget.') | ||||||
|   # BOHB |   # BOHB | ||||||
|   parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function') |   parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function') | ||||||
|   parser.add_argument('--min_bandwidth',    default=.3, type=float, nargs='?', help='minimum bandwidth for KDE') |   parser.add_argument('--min_bandwidth',    default=.3, type=float, nargs='?', help='minimum bandwidth for KDE') | ||||||
|   | |||||||
| @@ -59,7 +59,7 @@ def train_and_eval(arch, nas_bench, extra_info): | |||||||
|   if nas_bench is not None: |   if nas_bench is not None: | ||||||
|     arch_index = nas_bench.query_index_by_arch( arch ) |     arch_index = nas_bench.query_index_by_arch( arch ) | ||||||
|     assert arch_index >= 0, 'can not find this arch : {:}'.format(arch) |     assert arch_index >= 0, 'can not find this arch : {:}'.format(arch) | ||||||
|     info = nas_bench.get_more_info(arch_index, 'cifar10-valid', True) |     info = nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True) | ||||||
|     valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time'] |     valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time'] | ||||||
|     #_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs |     #_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs | ||||||
|   else: |   else: | ||||||
|   | |||||||
| @@ -147,14 +147,14 @@ class NASBench102API(object): | |||||||
|     archresult = arch2infos[index] |     archresult = arch2infos[index] | ||||||
|     return archresult.get_net_param(dataset, seed) |     return archresult.get_net_param(dataset, seed) | ||||||
|  |  | ||||||
|   def get_more_info(self, index, dataset, use_12epochs_result=False): |   def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False): | ||||||
|     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less |     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||||
|     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full |     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||||
|     archresult = arch2infos[index] |     archresult = arch2infos[index] | ||||||
|     if dataset == 'cifar10-valid': |     if dataset == 'cifar10-valid': | ||||||
|       train_info = archresult.get_metrics(dataset, 'train', is_random=True) |       train_info = archresult.get_metrics(dataset, 'train'   , iepoch=iepoch, is_random=True) | ||||||
|       valid_info = archresult.get_metrics(dataset, 'x-valid', is_random=True) |       valid_info = archresult.get_metrics(dataset, 'x-valid' , iepoch=iepoch, is_random=True) | ||||||
|       test__info = archresult.get_metrics(dataset, 'ori-test', is_random=True) |       test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=True) | ||||||
|       total      = train_info['iepoch'] + 1 |       total      = train_info['iepoch'] + 1 | ||||||
|       return {'train-loss'    : train_info['loss'], |       return {'train-loss'    : train_info['loss'], | ||||||
|               'train-accuracy': train_info['accuracy'], |               'train-accuracy': train_info['accuracy'], | ||||||
|   | |||||||
| @@ -34,6 +34,6 @@ OMP_NUM_THREADS=4 python ./exps/algos/BOHB.py \ | |||||||
| 	--dataset ${dataset} --data_path ${data_path} \ | 	--dataset ${dataset} --data_path ${data_path} \ | ||||||
| 	--search_space_name ${space} \ | 	--search_space_name ${space} \ | ||||||
| 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ | 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ | ||||||
| 	--time_budget 12000 \ | 	--time_budget 12000 --time_scale 200 \ | ||||||
| 	--n_iters 100 --num_samples 4 --random_fraction 0 \ | 	--n_iters 64 --num_samples 4 --random_fraction 0 \ | ||||||
| 	--workers 4 --print_freq 200 --rand_seed ${seed} | 	--workers 4 --print_freq 200 --rand_seed ${seed} | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user