From 168b08d9e66adeb1b5f3460c92020388882971d4 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Tue, 14 Jul 2020 06:10:34 +0000 Subject: [PATCH] Update REA, REINFORCE, and RANDOM --- docs/NAS-Bench-201-PURE.md | 180 ++++++++++++++++++++++ exps/algos-v2/bohb.py | 214 +++++++++++++++++++++++++++ exps/algos-v2/regularized_ea.py | 4 +- exps/algos-v2/run-all.sh | 4 +- exps/algos/BOHB.py | 2 +- exps/experimental/vis-bench-algos.py | 4 +- 6 files changed, 401 insertions(+), 7 deletions(-) create mode 100644 docs/NAS-Bench-201-PURE.md create mode 100644 exps/algos-v2/bohb.py diff --git a/docs/NAS-Bench-201-PURE.md b/docs/NAS-Bench-201-PURE.md new file mode 100644 index 0000000..e9980cb --- /dev/null +++ b/docs/NAS-Bench-201-PURE.md @@ -0,0 +1,180 @@ +# [NAS-BENCH-201: Extending the Scope of Reproducible Neural Architecture Search](https://openreview.net/forum?id=HJxyZkBKDr) + +We propose an algorithm-agnostic NAS benchmark (NAS-Bench-201) with a fixed search space, which provides a unified benchmark for almost any up-to-date NAS algorithms. +The design of our search space is inspired by that used in the most popular cell-based searching algorithms, where a cell is represented as a directed acyclic graph. +Each edge here is associated with an operation selected from a predefined operation set. +For it to be applicable for all NAS algorithms, the search space defined in NAS-Bench-201 includes 4 nodes and 5 associated operation options, which generates 15,625 neural cell candidates in total. + +In this Markdown file, we provide: +- [How to Use NAS-Bench-201](#how-to-use-nas-bench-201) + +For the following two things, please use [AutoDL-Projects](https://github.com/D-X-Y/AutoDL-Projects): +- [Instruction to re-generate NAS-Bench-201](#instruction-to-re-generate-nas-bench-201) +- [10 NAS algorithms evaluated in our paper](#to-reproduce-10-baseline-nas-algorithms-in-nas-bench-201) + +Note: please use `PyTorch >= 1.2.0` and `Python >= 3.6.0`. + +You can simply type `pip install nas-bench-201` to install our api. Please see source codes of `nas-bench-201` module in [this repo](https://github.com/D-X-Y/NAS-Bench-201). + +**If you have any questions or issues, please post it at [here](https://github.com/D-X-Y/AutoDL-Projects/issues) or email me.** + +### Preparation and Download + +[deprecated] The **old** benchmark file of NAS-Bench-201 can be downloaded from [Google Drive](https://drive.google.com/file/d/1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs/view?usp=sharing) or [Baidu-Wangpan (code:6u5d)](https://pan.baidu.com/s/1CiaNH6C12zuZf7q-Ilm09w). + +[recommended] The **latest** benchmark file of NAS-Bench-201 (`NAS-Bench-201-v1_1-096897.pth`) can be downloaded from [Google Drive](https://drive.google.com/file/d/16Y0UwGisiouVRxW-W5hEtbxmcHw_0hF_/view?usp=sharing). The files for model weight are too large (431G) and I need some time to upload it. Please be patient, thanks for your understanding. + +You can move it to anywhere you want and send its path to our API for initialization. +- [2020.02.25] APIv1.0/FILEv1.0: [`NAS-Bench-201-v1_0-e61699.pth`](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) (2.2G), where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial. +- [2020.02.25] APIv1.0/FILEv1.0: The full data of each architecture can be download from [ +NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the the trained weights. +- [2020.02.25] APIv1.0/FILEv1.0: Checkpoints for 3 runs of each baseline NAS algorithm are provided in [Google Drive](https://drive.google.com/open?id=1eAgLZQAViP3r6dA0_ZOOGG9zPLXhGwXi). +- [2020.03.09] APIv1.2/FILEv1.0: More robust API with more functions and descriptions +- [2020.03.16] APIv1.3/FILEv1.1: [`NAS-Bench-201-v1_1-096897.pth`](https://drive.google.com/open?id=16Y0UwGisiouVRxW-W5hEtbxmcHw_0hF_) (4.7G), where `096897` is the last six digits for this file. It contains information of more trials compared to `NAS-Bench-201-v1_0-e61699.pth`, especially all models trained by 12 epochs on all datasets are avaliable. +- [2020.06.30] APIv2.0: Use abstract class (NASBenchMetaAPI) for APIs of NAS-Bench-x0y. +- [2020.06.30] FILEv2.0: coming soon! + +**We recommend to use `NAS-Bench-201-v1_1-096897.pth`** + + +The training and evaluation data used in NAS-Bench-201 can be downloaded from [Google Drive](https://drive.google.com/open?id=1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7) or [Baidu-Wangpan (code:4fg7)](https://pan.baidu.com/s/1XAzavPKq3zcat1yBA1L2tQ). +It is recommended to put these data into `$TORCH_HOME` (`~/.torch/` by default). If you want to generate NAS-Bench-201 or similar NAS datasets or training models by yourself, you need these data. + +## How to Use NAS-Bench-201 + +**More usage can be found in [our test codes](https://github.com/D-X-Y/AutoDL-Projects/blob/master/exps/NAS-Bench-201/test-nas-api.py)**. + +1. Creating an API instance from a file: +``` +from nas_201_api import NASBench201API as API +api = API('$path_to_meta_nas_bench_file') +# Create an API without the verbose log +api = API('NAS-Bench-201-v1_1-096897.pth', verbose=False) +# The default path for benchmark file is '{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-Bench-201-v1_1-096897.pth') +api = API(None) +``` + +2. Show the number of architectures `len(api)` and each architecture `api[i]`: +``` +num = len(api) +for i, arch_str in enumerate(api): + print ('{:5d}/{:5d} : {:}'.format(i, len(api), arch_str)) +``` + +3. Show the results of all trials for a single architecture: +``` +# show all information for a specific architecture +api.show(1) +api.show(2) + +# show the mean loss and accuracy of an architecture +info = api.query_meta_info_by_index(1) # This is an instance of `ArchResults` +res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric names as keys +cost_metrics = info.get_comput_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency + +# get the detailed information +results = api.query_by_index(1, 'cifar100') # a dict of all trials for 1st net on cifar100, where the key is the seed +print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1])) +print ('Latency : {:}'.format(results[0].get_latency())) +print ('Train Info : {:}'.format(results[0].get_train())) +print ('Valid Info : {:}'.format(results[0].get_eval('x-valid'))) +print ('Test Info : {:}'.format(results[0].get_eval('x-test'))) +# for the metric after a specific epoch +print ('Train Info [10-th epoch] : {:}'.format(results[0].get_train(10))) +``` + +4. Query the index of an architecture by string +``` +index = api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|') +api.show(index) +``` +This string `|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|` means: +``` +node-0: the input tensor +node-1: conv-3x3( node-0 ) +node-2: conv-3x3( node-0 ) + avg-pool-3x3( node-1 ) +node-3: skip-connect( node-0 ) + conv-3x3( node-1 ) + skip-connect( node-2 ) +``` + +5. Create the network from api: +``` +config = api.get_net_config(123, 'cifar10') # obtain the network configuration for the 123-th architecture on the CIFAR-10 dataset +from models import get_cell_based_tiny_net # this module is in AutoDL-Projects/lib/models +network = get_cell_based_tiny_net(config) # create the network from configurration +print(network) # show the structure of this architecture +``` +If you want to load the trained weights of this created network, you need to use `api.get_net_param(123, ...)` to obtain the weights and then load it to the network. + +6. `api.get_more_info(...)` can return the loss / accuracy / time on training / validation / test sets, which is very helpful. For more details, please look at the comments in the get_more_info function. + +7. For other usages, please see `lib/nas_201_api/api.py`. We provide some usage information in the comments for the corresponding functions. If what you want is not provided, please feel free to open an issue for discussion, and I am happy to answer any questions regarding NAS-Bench-201. + + +### Detailed Instruction + +In `nas_201_api`, we define three classes: `NASBench201API`, `ArchResults`, `ResultsCount`. + +`ResultsCount` maintains all information of a specific trial. One can instantiate ResultsCount and get the info via the following codes (`000157-FULL.pth` saves all information of all trials of 157-th architecture): +``` +from nas_201_api import ResultsCount +xdata = torch.load('000157-FULL.pth') +odata = xdata['full']['all_results'][('cifar10-valid', 777)] +result = ResultsCount.create_from_state_dict( odata ) +print(result) # print it +print(result.get_train()) # print the final training loss/accuracy/[optional:time-cost-of-a-training-epoch] +print(result.get_train(11)) # print the training info of the 11-th epoch +print(result.get_eval('x-valid')) # print the final evaluation info on the validation set +print(result.get_eval('x-valid', 11)) # print the info on the validation set of the 11-th epoch +print(result.get_latency()) # print the evaluation latency [in batch] +result.get_net_param() # the trained parameters of this trial +arch_config = result.get_config(CellStructure.str2structure) # create the network with params +net_config = dict2config(arch_config, None) +network = get_cell_based_tiny_net(net_config) +network.load_state_dict(result.get_net_param()) +``` + +`ArchResults` maintains all information of all trials of an architecture. Please see the following usages: +``` +from nas_201_api import ArchResults +xdata = torch.load('000157-FULL.pth') +archRes = ArchResults.create_from_state_dict(xdata['less']) # load trials trained with 12 epochs +archRes = ArchResults.create_from_state_dict(xdata['full']) # load trials trained with 200 epochs + +print(archRes.arch_idx_str()) # print the index of this architecture +print(archRes.get_dataset_names()) # print the supported training data +print(archRes.get_comput_costs('cifar10-valid')) # print all computational info when training on cifar10-valid +print(archRes.get_metrics('cifar10-valid', 'x-valid', None, False)) # print the average loss/accuracy/time on all trials +print(archRes.get_metrics('cifar10-valid', 'x-valid', None, True)) # print loss/accuracy/time of a randomly selected trial +``` + +`NASBench201API` is the topest level api. Please see the following usages: +``` +from nas_201_api import NASBench201API as API +api = API('NAS-Bench-201-v1_1-096897.pth') # This will load all the information of NAS-Bench-201 except the trained weights +api = API('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-Bench-201-v1_1-096897.pth')) # The same as the above line while I usually save NAS-Bench-201-v1_1-096897.pth in ~/.torch/. +api.show(-1) # show info of all architectures +api.reload('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-BENCH-201-4-v1.0-archive'), 3) # This code will reload the information 3-th architecture with the trained weights + +weights = api.get_net_param(3, 'cifar10', None) # Obtaining the weights of all trials for the 3-th architecture on cifar10. It will returns a dict, where the key is the seed and the value is the trained weights. +``` + +To obtain the training and evaluation information (please see the comments [here](https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/nas_201_api/api_201.py#L142)): +``` +api.get_more_info(112, 'cifar10', None, hp='200', is_random=True) +# Query info of last training epoch for 112-th architecture +# using 200-epoch-hyper-parameter and randomly select a trial. +api.get_more_info(112, 'ImageNet16-120', None, hp='200', is_random=True) +``` + +# Citation + +If you find that NAS-Bench-201 helps your research, please consider citing it: +``` +@inproceedings{dong2020nasbench201, + title = {NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search}, + author = {Dong, Xuanyi and Yang, Yi}, + booktitle = {International Conference on Learning Representations (ICLR)}, + url = {https://openreview.net/forum?id=HJxyZkBKDr}, + year = {2020} +} +``` diff --git a/exps/algos-v2/bohb.py b/exps/algos-v2/bohb.py new file mode 100644 index 0000000..b14e903 --- /dev/null +++ b/exps/algos-v2/bohb.py @@ -0,0 +1,214 @@ +################################################## +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # +################################################################### +# BOHB: Robust and Efficient Hyperparameter Optimization at Scale # +# required to install hpbandster ################################## +# pip install hpbandster ################################## +################################################################### +# python exps/algos-v2/bohb.py --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 +################################################################### +import os, sys, time, random, argparse +from copy import deepcopy +from pathlib import Path +import torch +lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() +if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +from config_utils import load_config +from datasets import get_datasets, SearchDataset +from procedures import prepare_seed, prepare_logger +from log_utils import AverageMeter, time_string, convert_secs2time +from nas_201_api import NASBench201API as API +from models import CellStructure, get_search_spaces +# BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018 +import ConfigSpace +from hpbandster.optimizers.bohb import BOHB +import hpbandster.core.nameserver as hpns +from hpbandster.core.worker import Worker + + +def get_topology_config_space(search_space, max_nodes=4): + cs = ConfigSpace.ConfigurationSpace() + #edge2index = {} + for i in range(1, max_nodes): + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space)) + return cs + + +def get_size_config_space(search_space): + cs = ConfigSpace.ConfigurationSpace() + import pdb; pdb.set_trace() + #edge2index = {} + for i in range(1, max_nodes): + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space)) + return cs + + +def config2topology_func(max_nodes=4): + def config2structure(config): + genotypes = [] + for i in range(1, max_nodes): + xlist = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + op_name = config[node_str] + xlist.append((op_name, j)) + genotypes.append( tuple(xlist) ) + return CellStructure( genotypes ) + return config2structure + + +class MyWorker(Worker): + + def __init__(self, *args, convert_func=None, dataname=None, nas_bench=None, time_budget=None, **kwargs): + super().__init__(*args, **kwargs) + self.convert_func = convert_func + self._dataname = dataname + self._nas_bench = nas_bench + self.time_budget = time_budget + self.seen_archs = [] + self.sim_cost_time = 0 + self.real_cost_time = 0 + self.is_end = False + + def get_the_best(self): + assert len(self.seen_archs) > 0 + best_index, best_acc = -1, None + for arch_index in self.seen_archs: + info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp='200', is_random=True) + vacc = info['valid-accuracy'] + if best_acc is None or best_acc < vacc: + best_acc = vacc + best_index = arch_index + assert best_index != -1 + return best_index + + def compute(self, config, budget, **kwargs): + start_time = time.time() + structure = self.convert_func( config ) + arch_index = self._nas_bench.query_index_by_arch( structure ) + info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp='200', is_random=True) + cur_time = info['train-all-time'] + info['valid-per-time'] + cur_vacc = info['valid-accuracy'] + self.real_cost_time += (time.time() - start_time) + if self.sim_cost_time + cur_time <= self.time_budget and not self.is_end: + self.sim_cost_time += cur_time + self.seen_archs.append( arch_index ) + return ({'loss': 100 - float(cur_vacc), + 'info': {'seen-arch' : len(self.seen_archs), + 'sim-test-time' : self.sim_cost_time, + 'current-arch' : arch_index} + }) + else: + self.is_end = True + return ({'loss': 100, + 'info': {'seen-arch' : len(self.seen_archs), + 'sim-test-time' : self.sim_cost_time, + 'current-arch' : None} + }) + + +def main(xargs, api): + torch.set_num_threads(4) + prepare_seed(xargs.rand_seed) + logger = prepare_logger(args) + + logger.log('{:} use api : {:}'.format(time_string(), api)) + search_space = get_search_spaces(xargs.search_space, 'nas-bench-301') + if xargs.search_space == 'tss': + cs = get_topology_config_space(xargs.max_nodes, search_space) + config2structure = config2topology_func(xargs.max_nodes) + else: + cs = get_size_config_space(xargs.max_nodes, search_space) + import pdb; pdb.set_trace() + + hb_run_id = '0' + + NS = hpns.NameServer(run_id=hb_run_id, host='localhost', port=0) + ns_host, ns_port = NS.start() + num_workers = 1 + + workers = [] + for i in range(num_workers): + w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, dataname=dataname, nas_bench=nas_bench, time_budget=xargs.time_budget, run_id=hb_run_id, id=i) + w.run(background=True) + workers.append(w) + + start_time = time.time() + bohb = BOHB(configspace=cs, + run_id=hb_run_id, + eta=3, min_budget=12, max_budget=200, + nameserver=ns_host, + nameserver_port=ns_port, + num_samples=xargs.num_samples, + random_fraction=xargs.random_fraction, bandwidth_factor=xargs.bandwidth_factor, + ping_interval=10, min_bandwidth=xargs.min_bandwidth) + + results = bohb.run(xargs.n_iters, min_n_workers=num_workers) + + bohb.shutdown(shutdown_workers=True) + NS.shutdown() + + real_cost_time = time.time() - start_time + + id2config = results.get_id2config_mapping() + incumbent = results.get_incumbent_id() + logger.log('Best found configuration: {:} within {:.3f} s'.format(id2config[incumbent]['config'], real_cost_time)) + best_arch = config2structure( id2config[incumbent]['config'] ) + + info = nas_bench.query_by_arch(best_arch, '200') + if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) + else : logger.log('{:}'.format(info)) + logger.log('-'*100) + + logger.log('workers : {:.1f}s with {:} archs'.format(workers[0].time_budget, len(workers[0].seen_archs))) + logger.close() + return logger.log_dir, nas_bench.query_index_by_arch( best_arch ), real_cost_time + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser("BOHB: Robust and Efficient Hyperparameter Optimization at Scale") + parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') + # general arg + parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.') + parser.add_argument('--time_budget', type=int, default=20000, help='The total time cost budge for searching (in seconds).') + parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.') + # BOHB + parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function') + parser.add_argument('--min_bandwidth', default=.3, type=float, nargs='?', help='minimum bandwidth for KDE') + parser.add_argument('--num_samples', default=64, type=int, nargs='?', help='number of samples for the acquisition function') + parser.add_argument('--random_fraction', default=.33, type=float, nargs='?', help='fraction of random configurations') + parser.add_argument('--bandwidth_factor', default=3, type=int, nargs='?', help='factor multiplied to the bandwidth') + parser.add_argument('--n_iters', default=300, type=int, nargs='?', help='number of iterations for optimization method') + # log + parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.') + parser.add_argument('--rand_seed', type=int, help='manual seed') + args = parser.parse_args() + + if args.search_space == 'tss': + api = NASBench201API(verbose=False) + elif args.search_space == 'sss': + api = NASBench301API(verbose=False) + else: + raise ValueError('Invalid search space : {:}'.format(args.search_space)) + + args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'BOHB') + print('save-dir : {:}'.format(args.save_dir)) + + if args.rand_seed < 0: + save_dir, all_info = None, collections.OrderedDict() + for i in range(args.loops_if_rand): + print ('{:} : {:03d}/{:03d}'.format(time_string(), i, args.loops_if_rand)) + args.rand_seed = random.randint(1, 100000) + save_dir, all_archs, all_total_times = main(args, api) + all_info[i] = {'all_archs': all_archs, + 'all_total_times': all_total_times} + save_path = save_dir / 'results.pth' + print('save into {:}'.format(save_path)) + torch.save(all_info, save_path) + else: + main(args, api) diff --git a/exps/algos-v2/regularized_ea.py b/exps/algos-v2/regularized_ea.py index 845bd28..8ebdfe4 100644 --- a/exps/algos-v2/regularized_ea.py +++ b/exps/algos-v2/regularized_ea.py @@ -214,8 +214,7 @@ def main(xargs, api): logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget)) history, current_best_index, total_times = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, api, xargs.dataset) logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).'.format(time_string(), len(history), total_times[-1], time.time()-x_start_time)) - best_arch = max(history, key=lambda i: i.accuracy) - best_arch = best_arch.arch + best_arch = max(history, key=lambda x: x[0])[1] logger.log('{:} best arch is {:}'.format(time_string(), best_arch)) info = api.query_info_str_by_arch(best_arch, '200' if xargs.search_space == 'tss' else '90') @@ -249,6 +248,7 @@ if __name__ == '__main__': args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'R-EA-SS{:}'.format(args.ea_sample_size)) print('save-dir : {:}'.format(args.save_dir)) + print('xargs : {:}'.format(args)) if args.rand_seed < 0: save_dir, all_info = None, collections.OrderedDict() diff --git a/exps/algos-v2/run-all.sh b/exps/algos-v2/run-all.sh index 41a907b..f900a67 100644 --- a/exps/algos-v2/run-all.sh +++ b/exps/algos-v2/run-all.sh @@ -11,8 +11,8 @@ for dataset in ${datasets} do for search_space in ${search_spaces} do - python ./exps/algos-v2/reinforce.py --dataset ${dataset} --search_space ${search_space} --learning_rate 0.001 + # python ./exps/algos-v2/reinforce.py --dataset ${dataset} --search_space ${search_space} --learning_rate 0.001 python ./exps/algos-v2/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --ea_cycles 200 --ea_population 10 --ea_sample_size 3 - python ./exps/algos-v2/random_wo_share.py --dataset ${dataset} --search_space ${search_space} + # python ./exps/algos-v2/random_wo_share.py --dataset ${dataset} --search_space ${search_space} done done diff --git a/exps/algos/BOHB.py b/exps/algos/BOHB.py index f4c6e50..c6c45f2 100644 --- a/exps/algos/BOHB.py +++ b/exps/algos/BOHB.py @@ -192,7 +192,7 @@ def main(xargs, nas_bench): if __name__ == '__main__': - parser = argparse.ArgumentParser("Regularized Evolution Algorithm") + parser = argparse.ArgumentParser("BOHB: Robust and Efficient Hyperparameter Optimization at Scale") parser.add_argument('--data_path', type=str, help='Path to dataset') parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') # channels and number-of-cells diff --git a/exps/experimental/vis-bench-algos.py b/exps/experimental/vis-bench-algos.py index 2cc1f51..13ab151 100644 --- a/exps/experimental/vis-bench-algos.py +++ b/exps/experimental/vis-bench-algos.py @@ -30,10 +30,10 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None): alg2name, alg2path = OrderedDict(), OrderedDict() alg2name['REA'] = 'R-EA-SS3' alg2name['REINFORCE'] = 'REINFORCE-0.001' - # alg2name['RANDOM'] = 'RANDOM' + alg2name['RANDOM'] = 'RANDOM' for alg, name in alg2name.items(): alg2path[alg] = os.path.join(ss_dir, dataset, name, 'results.pth') - assert os.path.isfile(alg2path[alg]) + assert os.path.isfile(alg2path[alg]), 'invalid path : {:}'.format(alg2path[alg]) alg2data = OrderedDict() for alg, path in alg2path.items(): data = torch.load(path)