From 4c144b74376bb7bf87912d615c97c5b1f018a96e Mon Sep 17 00:00:00 2001
From: D-X-Y <280835372@qq.com>
Date: Sat, 28 Dec 2019 15:42:36 +1100
Subject: [PATCH] update README

---
 NAS-Bench-102.md             | 15 +++++-----
 README.md                    |  2 ++
 exps/algos/BOHB.py           | 56 +++++++++++++++++++++++++++---------
 exps/algos/R_EA.py           |  2 +-
 lib/nas_102_api/api.py       |  8 +++---
 scripts-search/algos/BOHB.sh |  4 +--
 6 files changed, 59 insertions(+), 28 deletions(-)

diff --git a/NAS-Bench-102.md b/NAS-Bench-102.md
index 8871234..a3fb46f 100644
--- a/NAS-Bench-102.md
+++ b/NAS-Bench-102.md
@@ -26,9 +26,10 @@ It is recommended to put these data into `$TORCH_HOME` (`~/.torch/` by default).
 1. Creating an API instance from a file:
 ```
-from nas_102_api import NASBench102API
-api = NASBench102API('$path_to_meta_nas_bench_file')
-api = NASBench102API('NAS-Bench-102-v1_0-e61699.pth')
+from nas_102_api import NASBench102API as API
+api = API('$path_to_meta_nas_bench_file')
+api = API('NAS-Bench-102-v1_0-e61699.pth')
+api = API('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-Bench-102-v1_0-e61699.pth'))
 ```
 
 2. Show the number of architectures `len(api)` and each architecture `api[i]`:
@@ -45,12 +46,12 @@ api.show(1)
 api.show(2)
 
 # show the mean loss and accuracy of an architecture
-info = api.query_meta_info_by_index(1)
-res_metrics = info.get_metrics('cifar10', 'train')
-cost_metrics = info.get_comput_costs('cifar100')
+info = api.query_meta_info_by_index(1)  # This is an instance of `ArchResults`
+res_metrics = info.get_metrics('cifar10', 'train')  # This is a dict with metric names as keys
+cost_metrics = info.get_comput_costs('cifar100')  # This is a dict with metric names as keys, e.g., flops, params, latency
 
 # get the detailed information
-results = api.query_by_index(1, 'cifar100')
+results = api.query_by_index(1, 'cifar100')  # a list of all trials on cifar100
 print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1]))
 print ('Latency : {:}'.format(results[0].get_latency()))
 print ('Train Info : {:}'.format(results[0].get_train()))
diff --git a/README.md b/README.md
index ed94dfc..e8b0843 100644
--- a/README.md
+++ b/README.md
@@ -35,6 +35,8 @@ We build a new benchmark for neural architecture search, please see more details
 The benchmark data file (v1.0) is `NAS-Bench-102-v1_0-e61699.pth`, which can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs).
 
 ## [Network Pruning via Transformable Architecture Search](https://arxiv.org/abs/1905.09717)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/network-pruning-via-transformable/network-pruning-on-cifar-100)](https://paperswithcode.com/sota/network-pruning-on-cifar-100?p=network-pruning-via-transformable)
+
 In this paper, we propose a differentiable searching strategy for transformable architectures, i.e., searching for the depth and width of a deep neural network. You can see the highlights of our Transformable Architecture Search (TAS) on our [project page](https://xuanyidong.com/assets/projects/NeurIPS-2019-TAS.html).
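As a quick sanity check of the documented workflow, the snippet below strings the calls from the NAS-Bench-102.md hunk together. It is a minimal sketch, not part of the patch: it assumes the benchmark file sits under `$TORCH_HOME`, and that, as the inline comments above state, the `get_metrics` dict carries an `accuracy` entry and the `get_comput_costs` dict carries a `flops` entry.

```
import os
from nas_102_api import NASBench102API as API

# load the benchmark file from $TORCH_HOME (as recommended above)
api = API('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-Bench-102-v1_0-e61699.pth'))
print ('{:} architectures in total'.format(len(api)))

info = api.query_meta_info_by_index(1)                # an `ArchResults` instance
res_metrics  = info.get_metrics('cifar10', 'train')   # dict with metric names as keys
cost_metrics = info.get_comput_costs('cifar100')      # dict with flops / params / latency
print ('arch [{:}] : train accuracy = {:}, flops = {:}'.format(api[1], res_metrics['accuracy'], cost_metrics['flops']))
```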
diff --git a/exps/algos/BOHB.py b/exps/algos/BOHB.py
index c846e0b..eea14bc 100644
--- a/exps/algos/BOHB.py
+++ b/exps/algos/BOHB.py
@@ -2,6 +2,7 @@
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
 ##################################################
 # required to install hpbandster #################
+# bash ./scripts-search/algos/BOHB.sh -1         #
 ##################################################
 import os, sys, time, glob, random, argparse
 import numpy as np, collections
@@ -19,7 +20,6 @@ from utils import get_model_infos, obtain_accuracy
 from log_utils import AverageMeter, time_string, convert_secs2time
 from nas_102_api import NASBench102API as API
 from models import CellStructure, get_search_spaces
-from R_EA import train_and_eval
 # BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018
 import ConfigSpace
 from hpbandster.optimizers.bohb import BOHB
@@ -53,21 +53,44 @@ def config2structure_func(max_nodes):
 
 class MyWorker(Worker):
 
-  def __init__(self, *args, sleep_interval=0, convert_func=None, nas_bench=None, **kwargs):
+  def __init__(self, *args, convert_func=None, nas_bench=None, time_scale=None, **kwargs):
     super().__init__(*args, **kwargs)
-    self.sleep_interval = sleep_interval
     self.convert_func   = convert_func
     self.nas_bench      = nas_bench
-    self.test_time      = 0
+    self.time_scale     = time_scale
+    self.seen_arch      = 0
+    self.sim_cost_time  = 0
+    self.real_cost_time = 0
 
   def compute(self, config, budget, **kwargs):
-    structure = self.convert_func( config )
-    reward, time_cost = train_and_eval(structure, self.nas_bench, None)
-    import pdb; pdb.set_trace()
-    self.test_time += 1
+    start_time = time.time()
+    structure  = self.convert_func( config )
+    arch_index = self.nas_bench.query_index_by_arch( structure )
+    # train epoch-by-epoch on the simulated clock until the budget is hit
+    iepoch = 0
+    while iepoch < 12:
+      info = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', iepoch, True)
+      cur_time = info['train-all-time'] + info['valid-per-time']
+      cur_vacc = info['valid-accuracy']
+      if time.time() - start_time + cur_time / self.time_scale > budget:
+        break
+      else:
+        iepoch += 1
+    self.sim_cost_time += cur_time
+    self.seen_arch += 1
+    # sleep for the rest of the scaled simulated time, so that the real clock
+    # advances by about cur_time / time_scale per evaluation
+    remaining_time = cur_time / self.time_scale - (time.time() - start_time)
+    if remaining_time > 0:
+      time.sleep(remaining_time)
+    self.real_cost_time += (time.time() - start_time)
     return ({
-             'loss': float(100-reward),
-             'info': time_cost})
+             'loss': 100 - float(cur_vacc),
+             'info': {'seen-arch'     : self.seen_arch,
+                      'sim-test-time' : self.sim_cost_time,
+                      'real-test-time': self.real_cost_time,
+                      'current-arch'  : arch_index,
+                      'current-budget': budget}
+            })
 
 
 def main(xargs, nas_bench):
@@ -116,26 +139,30 @@ def main(xargs, nas_bench):
   #logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string()))
   workers = []
   for i in range(num_workers):
-    w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, run_id=hb_run_id, id=i)
+    w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, time_scale=xargs.time_scale, run_id=hb_run_id, id=i)
     w.run(background=True)
     workers.append(w)
 
+  simulate_time_budget = xargs.time_budget // xargs.time_scale
+  start_time = time.time()
+  logger.log('simulate_time_budget : {:} (in seconds).'.format(simulate_time_budget))
   bohb = BOHB(configspace=cs,
           run_id=hb_run_id,
-          eta=3, min_budget=3, max_budget=xargs.time_budget,
+          eta=3, min_budget=simulate_time_budget//3, max_budget=simulate_time_budget,
           nameserver=ns_host,
           nameserver_port=ns_port,
           num_samples=xargs.num_samples, random_fraction=xargs.random_fraction, bandwidth_factor=xargs.bandwidth_factor,
           ping_interval=10, min_bandwidth=xargs.min_bandwidth)
-  #       optimization_strategy=xargs.strategy, num_samples=xargs.num_samples,
 
   results = bohb.run(xargs.n_iters, min_n_workers=num_workers)
-  import pdb; pdb.set_trace()
   bohb.shutdown(shutdown_workers=True)
   NS.shutdown()
 
+  real_cost_time = time.time() - start_time
+  logger.log('real cost time : {:} (in seconds).'.format(real_cost_time))
+
   id2config = results.get_id2config_mapping()
   incumbent = results.get_incumbent_id()
@@ -163,6 +190,7 @@ if __name__ == '__main__':
   parser.add_argument('--channel', type=int, help='The number of channels.')
   parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
   parser.add_argument('--time_budget', type=int, help='The total time cost budget for searching (in seconds).')
+  parser.add_argument('--time_scale' , type=int, help='The time scale used to compress the simulated time budget into real time.')
   # BOHB
   parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function')
   parser.add_argument('--min_bandwidth', default=.3, type=float, nargs='?', help='minimum bandwidth for KDE')
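The core of the new `MyWorker.compute` is a compressed clock: each architecture is charged its benchmark-recorded train+valid time, but one simulated second costs only `1/time_scale` real seconds of sleeping, so hpbandster's wall-clock `budget` effectively meters simulated training time. Below is a standalone sketch of just that mapping; `charge_simulated_time` is an illustrative helper, not a function from this repository.

```
import time

def charge_simulated_time(sim_seconds, time_scale, start_time):
  # sleep so that `sim_seconds` of simulated time elapse on the real clock;
  # work already done since `start_time` counts toward the charge
  remaining = sim_seconds / time_scale - (time.time() - start_time)
  if remaining > 0:
    time.sleep(remaining)

# with time_scale=200, an evaluation whose recorded train+valid time is
# 1000 simulated seconds occupies the worker for about 5 real seconds
start = time.time()
charge_simulated_time(1000, 200, start)
print ('real seconds spent : {:.2f}'.format(time.time() - start))
```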
diff --git a/exps/algos/R_EA.py b/exps/algos/R_EA.py
index bd66c9b..bc3345b 100644
--- a/exps/algos/R_EA.py
+++ b/exps/algos/R_EA.py
@@ -59,7 +59,7 @@ def train_and_eval(arch, nas_bench, extra_info):
   if nas_bench is not None:
     arch_index = nas_bench.query_index_by_arch( arch )
     assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
-    info = nas_bench.get_more_info(arch_index, 'cifar10-valid', True)
+    info = nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True)
     valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
     #_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs
   else:
diff --git a/lib/nas_102_api/api.py b/lib/nas_102_api/api.py
index cbc9968..f459380 100644
--- a/lib/nas_102_api/api.py
+++ b/lib/nas_102_api/api.py
@@ -147,14 +147,14 @@ class NASBench102API(object):
     archresult = arch2infos[index]
     return archresult.get_net_param(dataset, seed)
 
-  def get_more_info(self, index, dataset, use_12epochs_result=False):
+  def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False):
     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full
     archresult = arch2infos[index]
     if dataset == 'cifar10-valid':
-      train_info = archresult.get_metrics(dataset, 'train', is_random=True)
-      valid_info = archresult.get_metrics(dataset, 'x-valid', is_random=True)
-      test__info = archresult.get_metrics(dataset, 'ori-test', is_random=True)
+      train_info = archresult.get_metrics(dataset, 'train'   , iepoch=iepoch, is_random=True)
+      valid_info = archresult.get_metrics(dataset, 'x-valid' , iepoch=iepoch, is_random=True)
+      test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=True)
     total = train_info['iepoch'] + 1
     return {'train-loss'     : train_info['loss'],
             'train-accuracy' : train_info['accuracy'],
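The new `iepoch` argument lets callers read metrics at an intermediate epoch instead of only the final one, which is what the BOHB worker relies on to stop an architecture early. A small sketch of such a query, assuming the benchmark file from the setup above is available under `$TORCH_HOME`:

```
import os
from nas_102_api import NASBench102API as API

nas_bench = API('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-Bench-102-v1_0-e61699.pth'))
# validation accuracy and accumulated training time of architecture 1
# after each of the first 12 epochs (use_12epochs_result=True)
for iepoch in range(12):
  info = nas_bench.get_more_info(1, 'cifar10-valid', iepoch, True)
  print ('epoch {:02d} : valid-accuracy={:} train-all-time={:}'.format(
      iepoch, info['valid-accuracy'], info['train-all-time']))
```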
diff --git a/scripts-search/algos/BOHB.sh b/scripts-search/algos/BOHB.sh
index 59ea98d..4d07f0a 100644
--- a/scripts-search/algos/BOHB.sh
+++ b/scripts-search/algos/BOHB.sh
@@ -34,6 +34,6 @@ OMP_NUM_THREADS=4 python ./exps/algos/BOHB.py \
 	--dataset ${dataset} --data_path ${data_path} \
 	--search_space_name ${space} \
 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
-	--time_budget 12000 \
-	--n_iters 100 --num_samples 4 --random_fraction 0 \
+	--time_budget 12000 --time_scale 200 \
+	--n_iters 64 --num_samples 4 --random_fraction 0 \
 	--workers 4 --print_freq 200 --rand_seed ${seed}
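For these defaults, the budget arithmetic in `main` works out as below; a back-of-the-envelope check that assumes nothing beyond the numbers in this script:

```
time_budget, time_scale = 12000, 200              # values passed by BOHB.sh
simulate_time_budget = time_budget // time_scale  # 60 real seconds handed to BOHB as max_budget
min_budget = simulate_time_budget // 3            # 20 real seconds for the cheapest rung
# i.e., the 12000-second simulated search is compressed 200x, so one
# full-budget evaluation corresponds to at most 60 real seconds of sleeping
print (simulate_time_budget, min_budget)
```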