Upgrade NAS-Bench-201 to APIv1.3/FILEv1.1
This commit is contained in:
		| @@ -18,14 +18,18 @@ You can simply type `pip install nas-bench-201` to install our api. Please see s | |||||||
|  |  | ||||||
| ### Preparation and Download | ### Preparation and Download | ||||||
|  |  | ||||||
| The benchmark file of NAS-Bench-201 can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) or [Baidu-Wangpan (code:6u5d)](https://pan.baidu.com/s/1CiaNH6C12zuZf7q-Ilm09w). | [deprecated] The benchmark file of NAS-Bench-201 can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) or [Baidu-Wangpan (code:6u5d)](https://pan.baidu.com/s/1CiaNH6C12zuZf7q-Ilm09w). | ||||||
|  |  | ||||||
|  | [recommended] The benchmark file of NAS-Bench-201 can be downloaded from [Google Drive](https://drive.google.com/open?id=1OOfVPpt-lA4u2HJrXbgrRd42IbfvJMyE). The files for model weight are too large (431G) and I need some time to upload it. Please be patient, thanks for your understanding. | ||||||
|  |  | ||||||
| You can move it to anywhere you want and send its path to our API for initialization. | You can move it to anywhere you want and send its path to our API for initialization. | ||||||
| - [2020.02.25] v1.0: `NAS-Bench-201-v1_0-e61699.pth`, where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial. | - [2020.02.25] APIv1.0/FILEv1.0: `NAS-Bench-201-v1_0-e61699.pth` (2.2G), where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial. | ||||||
| - [2020.02.25] v1.0: The full data of each architecture can be download from [ | - [2020.02.25] APIv1.0/FILEv1.0: The full data of each architecture can be download from [ | ||||||
| NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the the trained weights. | NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the the trained weights. | ||||||
| - [2020.02.25] v1.0: Checkpoints for 3 runs of each baseline NAS algorithm are provided in [Google Drive](https://drive.google.com/open?id=1eAgLZQAViP3r6dA0_ZOOGG9zPLXhGwXi). | - [2020.02.25] APIv1.0/FILEv1.0: Checkpoints for 3 runs of each baseline NAS algorithm are provided in [Google Drive](https://drive.google.com/open?id=1eAgLZQAViP3r6dA0_ZOOGG9zPLXhGwXi). | ||||||
| - [2020.03.09] v1.2: More robust API with more functions and descriptions | - [2020.03.09] APIv1.2/FILEv1.0: More robust API with more functions and descriptions | ||||||
| - [2020.04.01] v2.0: coming soon (results of two set of hyper-parameters avaliable on all three datasets) | - [2020.03.16] APIv1.3/FILEv1.1: `NAS-Bench-201-v1_1-096897.pth` (4.7G), where `096897` is the last six digits for this file. It contains information of more trials compared to `NAS-Bench-201-v1_0-e61699.pth`, especially all models trained by 12 epochs on all datasets are avaliable. | ||||||
|  | - [2020.06.01] APIv2.0/FILEv2.0: coming soon! | ||||||
|  |  | ||||||
|  |  | ||||||
| The training and evaluation data used in NAS-Bench-201 can be downloaded from [Google Drive](https://drive.google.com/open?id=1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7) or [Baidu-Wangpan (code:4fg7)](https://pan.baidu.com/s/1XAzavPKq3zcat1yBA1L2tQ). | The training and evaluation data used in NAS-Bench-201 can be downloaded from [Google Drive](https://drive.google.com/open?id=1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7) or [Baidu-Wangpan (code:4fg7)](https://pan.baidu.com/s/1XAzavPKq3zcat1yBA1L2tQ). | ||||||
| @@ -92,7 +96,9 @@ print(network) # show the structure of this architecture | |||||||
| ``` | ``` | ||||||
| If you want to load the trained weights of this created network, you need to use `api.get_net_param(123, ...)` to obtain the weights and then load it to the network. | If you want to load the trained weights of this created network, you need to use `api.get_net_param(123, ...)` to obtain the weights and then load it to the network. | ||||||
|  |  | ||||||
| 6. For other usages, please see `lib/nas_201_api/api.py`. We provide some usage information in the comments for the corresponding functions. If what you want is not provided, please feel free to open an issue for discussion, and I am happy to answer any questions regarding NAS-Bench-201. | 6. `api.get_more_info(...)` can return the loss / accuracy / time on training / validation / test sets, which is very helpful. For more details, please look at the comments in the get_more_info function. | ||||||
|  |  | ||||||
|  | 7. For other usages, please see `lib/nas_201_api/api.py`. We provide some usage information in the comments for the corresponding functions. If what you want is not provided, please feel free to open an issue for discussion, and I am happy to answer any questions regarding NAS-Bench-201. | ||||||
|  |  | ||||||
|  |  | ||||||
| ### Detailed Instruction | ### Detailed Instruction | ||||||
| @@ -213,12 +219,14 @@ If researchers can provide better results with different hyper-parameters, we ar | |||||||
| - [5] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/ENAS.sh     cifar10 1 -1` | - [5] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/ENAS.sh     cifar10 1 -1` | ||||||
| - [6] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/RANDOM-NAS.sh cifar10 1 -1` | - [6] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/RANDOM-NAS.sh cifar10 1 -1` | ||||||
| - [7] `bash ./scripts-search/algos/R-EA.sh cifar10 3 -1` | - [7] `bash ./scripts-search/algos/R-EA.sh cifar10 3 -1` | ||||||
| - [8] `bash ./scripts-search/algos/Random.sh -1` | - [8] `bash ./scripts-search/algos/Random.sh cifar10 -1` | ||||||
| - [9] `bash ./scripts-search/algos/REINFORCE.sh 0.5 -1` | - [9] `bash ./scripts-search/algos/REINFORCE.sh cifar10 0.5 -1` | ||||||
| - [10] `bash ./scripts-search/algos/BOHB.sh -1` | - [10] `bash ./scripts-search/algos/BOHB.sh cifar10 -1` | ||||||
|  |  | ||||||
| In commands [1-6], the first args `cifar10` indicates the dataset name, the second args `1` indicates the behavior of BN, and the first args `-1` indicates the random seed. | In commands [1-6], the first args `cifar10` indicates the dataset name, the second args `1` indicates the behavior of BN, and the first args `-1` indicates the random seed. | ||||||
|  |  | ||||||
|  | **Note that** since 2020 March 16, in these scripts, the default NAS-Bench-201 benchmark file has changed from `NAS-Bench-201-v1_0-e61699.pth` to `NAS-Bench-201-v1_1-096897.pth`, and thus the results could be slightly different from the original paper. | ||||||
|  |  | ||||||
|  |  | ||||||
| # Citation | # Citation | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,36 +1,84 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | ||||||
| ######################################################## | ############################################################################################### | ||||||
| # python exps/NAS-Bench-201/test-weights.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth | # Before run these commands, the files must be properly put. | ||||||
| ######################################################## | # python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_0-e61699 | ||||||
|  | # python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 | ||||||
|  | ############################################################################################### | ||||||
| import os, sys, time, glob, random, argparse | import os, sys, time, glob, random, argparse | ||||||
| import numpy as np | import numpy as np | ||||||
| import torch | import torch | ||||||
|  | import torch.nn as nn | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
|  | from tqdm import tqdm | ||||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||||
| from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | ||||||
| from nas_201_api import NASBench201API as API | from nas_201_api import NASBench201API as API | ||||||
|  | from log_utils import time_string | ||||||
|  | from models import get_cell_based_tiny_net | ||||||
| from utils import weight_watcher | from utils import weight_watcher | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(meta_file, weight_dir, save_dir): | def get_cor(A, B): | ||||||
|   import pdb; |   return float(np.corrcoef(A, B)[0,1]) | ||||||
|   pdb.set_trace() |  | ||||||
|  |  | ||||||
|  | def evaluate(api, weight_dir, data: str, use_12epochs_result: bool, valid_or_test: bool): | ||||||
|  |   norms, accs = [], [] | ||||||
|  |   for idx in tqdm(range(len(api))): | ||||||
|  |     info = api.get_more_info(idx, data, use_12epochs_result=use_12epochs_result, is_random=False) | ||||||
|  |     if valid_or_test: | ||||||
|  |       accs.append(info['valid-accuracy']) | ||||||
|  |     else: | ||||||
|  |       accs.append(info['test-accuracy']) | ||||||
|  |     config = api.get_net_config(idx, data) | ||||||
|  |     net = get_cell_based_tiny_net(config) | ||||||
|  |     api.reload(weight_dir, idx) | ||||||
|  |     params = api.get_net_param(idx, data, None) | ||||||
|  |     cur_norms = [] | ||||||
|  |     for seed, param in params.items(): | ||||||
|  |       net.load_state_dict(param) | ||||||
|  |       _, summary = weight_watcher.analyze(net, alphas=False) | ||||||
|  |       cur_norms.append( summary['lognorm'] ) | ||||||
|  |     norms.append( float(np.mean(cur_norms)) ) | ||||||
|  |     api.clear_params(idx, use_12epochs_result) | ||||||
|  |   correlation = get_cor(norms, accs) | ||||||
|  |   print('For {:} with {:} epochs on {:} : the correlation is {:}'.format(data, 12 if use_12epochs_result else 200, 'valid' if valid_or_test else 'test', correlation)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def main(meta_file: str, weight_dir, save_dir): | ||||||
|  |   api = API(meta_file) | ||||||
|  |   datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'] | ||||||
|  |   print(time_string() + ' ' + '='*50) | ||||||
|  |   for data in datasets: | ||||||
|  |     nums = api.statistics(data, True) | ||||||
|  |     total = sum([k*v for k, v in nums.items()]) | ||||||
|  |     print('Using 012 epochs, trained on {:20s} : {:} trials in total ({:}).'.format(data, total, nums)) | ||||||
|  |   print(time_string() + ' ' + '='*50) | ||||||
|  |   for data in datasets: | ||||||
|  |     nums = api.statistics(data, False) | ||||||
|  |     total = sum([k*v for k, v in nums.items()]) | ||||||
|  |     print('Using 200 epochs, trained on {:20s} : {:} trials in total ({:}).'.format(data, total, nums)) | ||||||
|  |   print(time_string() + ' ' + '='*50) | ||||||
|  |  | ||||||
|  |   evaluate(api, weight_dir, 'cifar10-valid', False, True) | ||||||
|  |    | ||||||
|  |   print('{:} finish this test.'.format(time_string())) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|   parser = argparse.ArgumentParser("Analysis of NAS-Bench-201") |   parser = argparse.ArgumentParser("Analysis of NAS-Bench-201") | ||||||
|   parser.add_argument('--save_dir',   type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.') |   parser.add_argument('--save_dir',   type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.') | ||||||
|   parser.add_argument('--api_path',   type=str, default=None, help='The path to the NAS-Bench-201 benchmark file.') |   parser.add_argument('--base_path',  type=str, default=None, help='The path to the NAS-Bench-201 benchmark file and weight dir.') | ||||||
|   parser.add_argument('--weight_dir', type=str, default=None, help='The directory path to the weights of every NAS-Bench-201 architecture.') |  | ||||||
|   args = parser.parse_args() |   args = parser.parse_args() | ||||||
|  |  | ||||||
|   save_dir = Path(args.save_dir) |   save_dir = Path(args.save_dir) | ||||||
|   save_dir.mkdir(parents=True, exist_ok=True) |   save_dir.mkdir(parents=True, exist_ok=True) | ||||||
|   meta_file = Path(args.api_path) |   meta_file = Path(args.base_path + '.pth') | ||||||
|   weight_dir = Path(args.weight_dir) |   weight_dir = Path(args.base_path + '-archive') | ||||||
|   assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file) |   assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file) | ||||||
|  |   assert weight_dir.exists() and weight_dir.is_dir(), 'invalid path for weight dir : {:}'.format(weight_dir) | ||||||
|  |  | ||||||
|   main(meta_file, weight_dir, save_dir) |   main(str(meta_file), weight_dir, save_dir) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -50,10 +50,11 @@ def config2structure_func(max_nodes): | |||||||
|  |  | ||||||
| class MyWorker(Worker): | class MyWorker(Worker): | ||||||
|  |  | ||||||
|   def __init__(self, *args, convert_func=None, nas_bench=None, time_budget=None, **kwargs): |   def __init__(self, *args, convert_func=None, dataname=None, nas_bench=None, time_budget=None, **kwargs): | ||||||
|     super().__init__(*args, **kwargs) |     super().__init__(*args, **kwargs) | ||||||
|     self.convert_func   = convert_func |     self.convert_func   = convert_func | ||||||
|     self.nas_bench      = nas_bench |     self._dataname      = dataname | ||||||
|  |     self._nas_bench     = nas_bench | ||||||
|     self.time_budget    = time_budget |     self.time_budget    = time_budget | ||||||
|     self.seen_archs     = [] |     self.seen_archs     = [] | ||||||
|     self.sim_cost_time  = 0 |     self.sim_cost_time  = 0 | ||||||
| @@ -64,7 +65,7 @@ class MyWorker(Worker): | |||||||
|     assert len(self.seen_archs) > 0 |     assert len(self.seen_archs) > 0 | ||||||
|     best_index, best_acc = -1, None |     best_index, best_acc = -1, None | ||||||
|     for arch_index in self.seen_archs: |     for arch_index in self.seen_archs: | ||||||
|       info = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True) |       info = self._nas_bench.get_more_info(arch_index, self._dataname, None, True, True) | ||||||
|       vacc = info['valid-accuracy'] |       vacc = info['valid-accuracy'] | ||||||
|       if best_acc is None or best_acc < vacc: |       if best_acc is None or best_acc < vacc: | ||||||
|         best_acc = vacc |         best_acc = vacc | ||||||
| @@ -75,8 +76,8 @@ class MyWorker(Worker): | |||||||
|   def compute(self, config, budget, **kwargs): |   def compute(self, config, budget, **kwargs): | ||||||
|     start_time = time.time() |     start_time = time.time() | ||||||
|     structure  = self.convert_func( config ) |     structure  = self.convert_func( config ) | ||||||
|     arch_index = self.nas_bench.query_index_by_arch( structure ) |     arch_index = self._nas_bench.query_index_by_arch( structure ) | ||||||
|     info       = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True) |     info       = self._nas_bench.get_more_info(arch_index, self._dataname, None, True, True) | ||||||
|     cur_time   = info['train-all-time'] + info['valid-per-time'] |     cur_time   = info['train-all-time'] + info['valid-per-time'] | ||||||
|     cur_vacc   = info['valid-accuracy'] |     cur_vacc   = info['valid-accuracy'] | ||||||
|     self.real_cost_time += (time.time() - start_time) |     self.real_cost_time += (time.time() - start_time) | ||||||
| @@ -106,7 +107,10 @@ def main(xargs, nas_bench): | |||||||
|   prepare_seed(xargs.rand_seed) |   prepare_seed(xargs.rand_seed) | ||||||
|   logger = prepare_logger(args) |   logger = prepare_logger(args) | ||||||
|  |  | ||||||
|   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' |   if xargs.dataset == 'cifar10': | ||||||
|  |     dataname = 'cifar10-valid' | ||||||
|  |   else: | ||||||
|  |     dataname = xargs.dataset | ||||||
|   if xargs.data_path is not None: |   if xargs.data_path is not None: | ||||||
|     train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) |     train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||||
|     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' |     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||||
| @@ -148,7 +152,7 @@ def main(xargs, nas_bench): | |||||||
|   #logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string())) |   #logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string())) | ||||||
|   workers = [] |   workers = [] | ||||||
|   for i in range(num_workers): |   for i in range(num_workers): | ||||||
|     w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, time_budget=xargs.time_budget, run_id=hb_run_id, id=i) |     w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, dataname=dataname, nas_bench=nas_bench, time_budget=xargs.time_budget, run_id=hb_run_id, id=i) | ||||||
|     w.run(background=True) |     w.run(background=True) | ||||||
|     workers.append(w) |     workers.append(w) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -28,7 +28,10 @@ def main(xargs, nas_bench): | |||||||
|   prepare_seed(xargs.rand_seed) |   prepare_seed(xargs.rand_seed) | ||||||
|   logger = prepare_logger(args) |   logger = prepare_logger(args) | ||||||
|  |  | ||||||
|   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' |   if xargs.dataset == 'cifar10': | ||||||
|  |     dataname = 'cifar10-valid' | ||||||
|  |   else: | ||||||
|  |     dataname = xargs.dataset | ||||||
|   if xargs.data_path is not None: |   if xargs.data_path is not None: | ||||||
|     train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) |     train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||||
|     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' |     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||||
| @@ -62,7 +65,7 @@ def main(xargs, nas_bench): | |||||||
|   #for idx in range(xargs.random_num): |   #for idx in range(xargs.random_num): | ||||||
|   while total_time_cost < xargs.time_budget: |   while total_time_cost < xargs.time_budget: | ||||||
|     arch = random_arch() |     arch = random_arch() | ||||||
|     accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info) |     accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info, dataname) | ||||||
|     if total_time_cost + cost_time > xargs.time_budget: break |     if total_time_cost + cost_time > xargs.time_budget: break | ||||||
|     else: total_time_cost += cost_time |     else: total_time_cost += cost_time | ||||||
|     history.append(arch) |     history.append(arch) | ||||||
|   | |||||||
| @@ -33,19 +33,21 @@ class Model(object): | |||||||
|  |  | ||||||
| # This function is to mimic the training and evaluatinig procedure for a single architecture `arch`. | # This function is to mimic the training and evaluatinig procedure for a single architecture `arch`. | ||||||
| # The time_cost is calculated as the total training time for a few (e.g., 12 epochs) plus the evaluation time for one epoch. | # The time_cost is calculated as the total training time for a few (e.g., 12 epochs) plus the evaluation time for one epoch. | ||||||
| # For use_converged_LR = True, the architecture is trained for 12 epochs, with LR being decaded from 0.1 to 0. | # For use_012_epoch_training = True, the architecture is trained for 12 epochs, with LR being decaded from 0.1 to 0. | ||||||
| #       In this case, the LR schedular is converged. | #       In this case, the LR schedular is converged. | ||||||
| # For use_converged_LR = False, the architecture is planed to be trained for 200 epochs, but we early stop its procedure. | # For use_012_epoch_training = False, the architecture is planed to be trained for 200 epochs, but we early stop its procedure. | ||||||
| #        | #        | ||||||
| def train_and_eval(arch, nas_bench, extra_info, dataname='cifar10-valid', use_converged_LR=True): | def train_and_eval(arch, nas_bench, extra_info, dataname='cifar10-valid', use_012_epoch_training=True): | ||||||
|   if use_converged_LR and nas_bench is not None: |  | ||||||
|  |   if use_012_epoch_training and nas_bench is not None: | ||||||
|     arch_index = nas_bench.query_index_by_arch( arch ) |     arch_index = nas_bench.query_index_by_arch( arch ) | ||||||
|     assert arch_index >= 0, 'can not find this arch : {:}'.format(arch) |     assert arch_index >= 0, 'can not find this arch : {:}'.format(arch) | ||||||
|     info = nas_bench.get_more_info(arch_index, dataname, None, True) |     info = nas_bench.get_more_info(arch_index, dataname, None, True) | ||||||
|     valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time'] |     valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time'] | ||||||
|     #_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs |     #_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs | ||||||
|   elif not use_converged_LR and nas_bench is not None: |   elif not use_012_epoch_training and nas_bench is not None: | ||||||
|     # Please use `use_converged_LR=False` for cifar10 only. |     # Please contact me if you want to use the following logic, because it has some potential issues. | ||||||
|  |     # Please use `use_012_epoch_training=False` for cifar10 only. | ||||||
|     # It did return values for cifar100 and ImageNet16-120, but it has some potential issues. (Please email me for more details) |     # It did return values for cifar100 and ImageNet16-120, but it has some potential issues. (Please email me for more details) | ||||||
|     arch_index, nepoch = nas_bench.query_index_by_arch( arch ), 25 |     arch_index, nepoch = nas_bench.query_index_by_arch( arch ), 25 | ||||||
|     assert arch_index >= 0, 'can not find this arch : {:}'.format(arch) |     assert arch_index >= 0, 'can not find this arch : {:}'.format(arch) | ||||||
| @@ -64,7 +66,7 @@ def train_and_eval(arch, nas_bench, extra_info, dataname='cifar10-valid', use_co | |||||||
|     try: |     try: | ||||||
|       valid_acc, time_cost = info['valid-accuracy'], estimated_train_cost + estimated_valid_cost |       valid_acc, time_cost = info['valid-accuracy'], estimated_train_cost + estimated_valid_cost | ||||||
|     except: |     except: | ||||||
|       valid_acc, time_cost = info['est-valid-accuracy'], estimated_train_cost + estimated_valid_cost |       valid_acc, time_cost = info['valtest-accuracy'], estimated_train_cost + estimated_valid_cost | ||||||
|   else: |   else: | ||||||
|     # train a model from scratch. |     # train a model from scratch. | ||||||
|     raise ValueError('NOT IMPLEMENT YET') |     raise ValueError('NOT IMPLEMENT YET') | ||||||
| @@ -127,7 +129,7 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran | |||||||
|   while len(population) < population_size: |   while len(population) < population_size: | ||||||
|     model = Model() |     model = Model() | ||||||
|     model.arch = random_arch() |     model.arch = random_arch() | ||||||
|     model.accuracy, time_cost = train_and_eval(model.arch, nas_bench, extra_info) |     model.accuracy, time_cost = train_and_eval(model.arch, nas_bench, extra_info, dataname) | ||||||
|     population.append(model) |     population.append(model) | ||||||
|     history.append(model) |     history.append(model) | ||||||
|     total_time_cost += time_cost |     total_time_cost += time_cost | ||||||
| @@ -152,7 +154,7 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran | |||||||
|     child = Model() |     child = Model() | ||||||
|     child.arch = mutate_arch(parent.arch) |     child.arch = mutate_arch(parent.arch) | ||||||
|     total_time_cost += time.time() - start_time |     total_time_cost += time.time() - start_time | ||||||
|     child.accuracy, time_cost = train_and_eval(child.arch, nas_bench, extra_info) |     child.accuracy, time_cost = train_and_eval(child.arch, nas_bench, extra_info, dataname) | ||||||
|     if total_time_cost + time_cost > time_budget: # return |     if total_time_cost + time_cost > time_budget: # return | ||||||
|       return history, total_time_cost |       return history, total_time_cost | ||||||
|     else: |     else: | ||||||
| @@ -174,7 +176,6 @@ def main(xargs, nas_bench): | |||||||
|   prepare_seed(xargs.rand_seed) |   prepare_seed(xargs.rand_seed) | ||||||
|   logger = prepare_logger(args) |   logger = prepare_logger(args) | ||||||
|  |  | ||||||
|   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' |  | ||||||
|   if xargs.dataset == 'cifar10': |   if xargs.dataset == 'cifar10': | ||||||
|     dataname = 'cifar10-valid' |     dataname = 'cifar10-valid' | ||||||
|   else: |   else: | ||||||
|   | |||||||
| @@ -98,7 +98,10 @@ def main(xargs, nas_bench): | |||||||
|   prepare_seed(xargs.rand_seed) |   prepare_seed(xargs.rand_seed) | ||||||
|   logger = prepare_logger(args) |   logger = prepare_logger(args) | ||||||
|  |  | ||||||
|   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' |   if xargs.dataset == 'cifar10': | ||||||
|  |     dataname = 'cifar10-valid' | ||||||
|  |   else: | ||||||
|  |     dataname = xargs.dataset | ||||||
|   if xargs.data_path is not None: |   if xargs.data_path is not None: | ||||||
|     train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) |     train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||||
|     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' |     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||||
| @@ -148,7 +151,7 @@ def main(xargs, nas_bench): | |||||||
|     start_time = time.time() |     start_time = time.time() | ||||||
|     log_prob, action = select_action( policy ) |     log_prob, action = select_action( policy ) | ||||||
|     arch   = policy.generate_arch( action ) |     arch   = policy.generate_arch( action ) | ||||||
|     reward, cost_time = train_and_eval(arch, nas_bench, extra_info) |     reward, cost_time = train_and_eval(arch, nas_bench, extra_info, dataname) | ||||||
|     trace.append( (reward, arch) ) |     trace.append( (reward, arch) ) | ||||||
|     # accumulate time |     # accumulate time | ||||||
|     if total_costs + cost_time < xargs.time_budget: |     if total_costs + cost_time < xargs.time_budget: | ||||||
|   | |||||||
| @@ -5,4 +5,5 @@ from .api import NASBench201API | |||||||
| from .api import ArchResults, ResultsCount | from .api import ArchResults, ResultsCount | ||||||
|  |  | ||||||
| # NAS_BENCH_201_API_VERSION="v1.1"  # [2020.02.25] | # NAS_BENCH_201_API_VERSION="v1.1"  # [2020.02.25] | ||||||
| NAS_BENCH_201_API_VERSION="v1.2"  # [2020.03.09] | # NAS_BENCH_201_API_VERSION="v1.2"  # [2020.03.09] | ||||||
|  | NAS_BENCH_201_API_VERSION="v1.3"  # [2020.03.16] | ||||||
|   | |||||||
| @@ -3,11 +3,14 @@ | |||||||
| ############################################################################################ | ############################################################################################ | ||||||
| # NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 # | # NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 # | ||||||
| ############################################################################################ | ############################################################################################ | ||||||
|  | # The history of benchmark files: | ||||||
| # [2020.02.25] NAS-Bench-201-v1_0-e61699.pth : 6219 architectures are trained once, 1621 architectures are trained twice, 7785 architectures are trained three times. `LESS` only supports CIFAR10-VALID. | # [2020.02.25] NAS-Bench-201-v1_0-e61699.pth : 6219 architectures are trained once, 1621 architectures are trained twice, 7785 architectures are trained three times. `LESS` only supports CIFAR10-VALID. | ||||||
| # [2020.03.08] Next version (coming soon) | # [2020.03.16] NAS-Bench-201-v1_1-096897.pth : 2225 architectures are trained once, 5439 archiitectures are trained twice, 7961 architectures are trained three times on all training sets. For the hyper-parameters with the total epochs of 12, each model is trained on CIFAR-10, CIFAR-100, ImageNet16-120 once, and is trained on CIFAR-10-VALID twice. | ||||||
| # | # | ||||||
|  | # I'm still actively enhancing this benchmark. Please feel free to contact me if you have any question w.r.t. NAS-Bench-201. | ||||||
| # | # | ||||||
| import os, copy, random, torch, numpy as np | import os, copy, random, torch, numpy as np | ||||||
|  | from pathlib import Path | ||||||
| from typing import List, Text, Union, Dict | from typing import List, Text, Union, Dict | ||||||
| from collections import OrderedDict, defaultdict | from collections import OrderedDict, defaultdict | ||||||
|  |  | ||||||
| @@ -44,9 +47,12 @@ class NASBench201API(object): | |||||||
|  |  | ||||||
|   """ The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """ |   """ The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """ | ||||||
|   def __init__(self, file_path_or_dict: Union[Text, Dict], verbose: bool=True): |   def __init__(self, file_path_or_dict: Union[Text, Dict], verbose: bool=True): | ||||||
|     if isinstance(file_path_or_dict, str): |     self.filename = None | ||||||
|  |     if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path): | ||||||
|  |       file_path_or_dict = str(file_path_or_dict) | ||||||
|       if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict)) |       if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict)) | ||||||
|       assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict) |       assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict) | ||||||
|  |       self.filename = Path(file_path_or_dict).name | ||||||
|       file_path_or_dict = torch.load(file_path_or_dict) |       file_path_or_dict = torch.load(file_path_or_dict) | ||||||
|     elif isinstance(file_path_or_dict, dict): |     elif isinstance(file_path_or_dict, dict): | ||||||
|       file_path_or_dict = copy.deepcopy( file_path_or_dict ) |       file_path_or_dict = copy.deepcopy( file_path_or_dict ) | ||||||
| @@ -76,7 +82,7 @@ class NASBench201API(object): | |||||||
|     return len(self.meta_archs) |     return len(self.meta_archs) | ||||||
|  |  | ||||||
|   def __repr__(self): |   def __repr__(self): | ||||||
|     return ('{name}({num}/{total} architectures)'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs))) |     return ('{name}({num}/{total} architectures, file={filename})'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs), filename=self.filename)) | ||||||
|  |  | ||||||
|   def random(self): |   def random(self): | ||||||
|     """Return a random index of all architectures.""" |     """Return a random index of all architectures.""" | ||||||
| @@ -98,9 +104,10 @@ class NASBench201API(object): | |||||||
|     else: arch_index = -1 |     else: arch_index = -1 | ||||||
|     return arch_index |     return arch_index | ||||||
|  |  | ||||||
|   # Overwrite all information of the 'index'-th architecture in the search space. |  | ||||||
|   # It will load its data from 'archive_root'. |  | ||||||
|   def reload(self, archive_root: Text, index: int): |   def reload(self, archive_root: Text, index: int): | ||||||
|  |     """Overwrite all information of the 'index'-th architecture in the search space. | ||||||
|  |          It will load its data from 'archive_root'. | ||||||
|  |     """ | ||||||
|     assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root) |     assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root) | ||||||
|     xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index)) |     xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index)) | ||||||
|     assert 0 <= index < len(self.meta_archs), 'invalid index of {:}'.format(index) |     assert 0 <= index < len(self.meta_archs), 'invalid index of {:}'.format(index) | ||||||
| @@ -110,6 +117,13 @@ class NASBench201API(object): | |||||||
|     self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] ) |     self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] ) | ||||||
|     self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] ) |     self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] ) | ||||||
|  |  | ||||||
|  |   def clear_params(self, index: int, use_12epochs_result: bool): | ||||||
|  |     """Remove the architecture's weights to save memory.""" | ||||||
|  |     if use_12epochs_result: arch2infos = self.arch2infos_less | ||||||
|  |     else                  : arch2infos = self.arch2infos_full | ||||||
|  |     archresult = arch2infos[index] | ||||||
|  |     archresult.clear_params() | ||||||
|  |    | ||||||
|   # This function is used to query the information of a specific archiitecture |   # This function is used to query the information of a specific archiitecture | ||||||
|   # 'arch' can be an architecture index or an architecture string |   # 'arch' can be an architecture index or an architecture string | ||||||
|   # When use_12epochs_result=True, the hyper-parameters used to train a model are in 'configs/nas-benchmark/CIFAR.config' |   # When use_12epochs_result=True, the hyper-parameters used to train a model are in 'configs/nas-benchmark/CIFAR.config' | ||||||
| @@ -162,6 +176,7 @@ class NASBench201API(object): | |||||||
|     return archInfo |     return archInfo | ||||||
|  |  | ||||||
|   def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None, use_12epochs_result=False): |   def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None, use_12epochs_result=False): | ||||||
|  |     """Find the architecture with the highest accuracy based on some constraints.""" | ||||||
|     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less |     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||||
|     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full |     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||||
|     best_index, highest_accuracy = -1, None |     best_index, highest_accuracy = -1, None | ||||||
| @@ -255,6 +270,65 @@ class NASBench201API(object): | |||||||
|   # `is_random` |   # `is_random` | ||||||
|   #   When is_random=True, the performance of a random architecture will be returned |   #   When is_random=True, the performance of a random architecture will be returned | ||||||
|   #   When is_random=False, the performanceo of all trials will be averaged. |   #   When is_random=False, the performanceo of all trials will be averaged. | ||||||
|  |   def get_more_info(self, index: int, dataset, iepoch=None, use_12epochs_result=False, is_random=True): | ||||||
|  |     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||||
|  |     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||||
|  |     archresult = arch2infos[index] | ||||||
|  |     # if randomly select one trial, select the seed at first | ||||||
|  |     if isinstance(is_random, bool) and is_random: | ||||||
|  |       seeds = archresult.get_dataset_seeds(dataset) | ||||||
|  |       is_random = random.choice(seeds) | ||||||
|  |     # collect the training information | ||||||
|  |     train_info = archresult.get_metrics(dataset, 'train', iepoch=iepoch, is_random=is_random) | ||||||
|  |     total = train_info['iepoch'] + 1 | ||||||
|  |     xinfo = {'train-loss'    : train_info['loss'], | ||||||
|  |              'train-accuracy': train_info['accuracy'], | ||||||
|  |              'train-per-time': train_info['all_time'] / total, | ||||||
|  |              'train-all-time': train_info['all_time']} | ||||||
|  |     # collect the evaluation information | ||||||
|  |     if dataset == 'cifar10-valid': | ||||||
|  |       valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random) | ||||||
|  |       try: | ||||||
|  |         test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) | ||||||
|  |       except: | ||||||
|  |         test_info = None | ||||||
|  |       valtest_info = None | ||||||
|  |     else: | ||||||
|  |       try: # collect results on the proposed test set | ||||||
|  |         if dataset == 'cifar10': | ||||||
|  |           test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) | ||||||
|  |         else: | ||||||
|  |           test_info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random) | ||||||
|  |       except: | ||||||
|  |         test_info = None | ||||||
|  |       try: # collect results on the proposed validation set | ||||||
|  |         valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random) | ||||||
|  |       except: | ||||||
|  |         valid_info = None | ||||||
|  |       try: | ||||||
|  |         if dataset != 'cifar10': | ||||||
|  |           valtest_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) | ||||||
|  |         else: | ||||||
|  |           valtest_info = None | ||||||
|  |       except: | ||||||
|  |         valtest_info = None | ||||||
|  |     if valid_info is not None: | ||||||
|  |       xinfo['valid-loss'] = valid_info['loss'] | ||||||
|  |       xinfo['valid-accuracy'] = valid_info['accuracy'] | ||||||
|  |       xinfo['valid-per-time'] = valid_info['all_time'] / total | ||||||
|  |       xinfo['valid-all-time'] = valid_info['all_time'] | ||||||
|  |     if test_info is not None: | ||||||
|  |       xinfo['test-loss'] = test_info['loss'] | ||||||
|  |       xinfo['test-accuracy'] = test_info['accuracy'] | ||||||
|  |       xinfo['test-per-time'] = test_info['all_time'] / total | ||||||
|  |       xinfo['test-all-time'] = test_info['all_time'] | ||||||
|  |     if valtest_info is not None: | ||||||
|  |       xinfo['valtest-loss'] = valtest_info['loss'] | ||||||
|  |       xinfo['valtest-accuracy'] = valtest_info['accuracy'] | ||||||
|  |       xinfo['valtest-per-time'] = valtest_info['all_time'] / total | ||||||
|  |       xinfo['valtest-all-time'] = valtest_info['all_time'] | ||||||
|  |     return xinfo | ||||||
|  |   """ # The following logic is deprecated after March 15 2020, where the benchmark file upgrades from NAS-Bench-201-v1_0-e61699.pth to NAS-Bench-201-v1_1-096897.pth. | ||||||
|   def get_more_info(self, index: int, dataset, iepoch=None, use_12epochs_result=False, is_random=True): |   def get_more_info(self, index: int, dataset, iepoch=None, use_12epochs_result=False, is_random=True): | ||||||
|     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less |     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||||
|     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full |     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||||
| @@ -312,6 +386,7 @@ class NASBench201API(object): | |||||||
|         xifo['est-valid-loss'] = est_valid_info['loss'] |         xifo['est-valid-loss'] = est_valid_info['loss'] | ||||||
|         xifo['est-valid-accuracy'] = est_valid_info['accuracy'] |         xifo['est-valid-accuracy'] = est_valid_info['accuracy'] | ||||||
|       return xifo |       return xifo | ||||||
|  |   """ | ||||||
|  |  | ||||||
|  |  | ||||||
|   def show(self, index: int = -1) -> None: |   def show(self, index: int = -1) -> None: | ||||||
| @@ -349,6 +424,26 @@ class NASBench201API(object): | |||||||
|         print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs))) |         print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs))) | ||||||
|  |  | ||||||
|  |  | ||||||
|  |   def statistics(self, dataset: Text, use_12epochs_result: bool) -> Dict[int, int]: | ||||||
|  |     """ | ||||||
|  |     This function will count the number of total trials. | ||||||
|  |     """ | ||||||
|  |     valid_datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'] | ||||||
|  |     if dataset not in valid_datasets: | ||||||
|  |       raise ValueError('{:} not in {:}'.format(dataset, valid_datasets)) | ||||||
|  |     if use_12epochs_result: arch2infos = self.arch2infos_less | ||||||
|  |     else                  : arch2infos = self.arch2infos_full | ||||||
|  |     nums = defaultdict(lambda: 0) | ||||||
|  |     for index in range(len(self)): | ||||||
|  |       archInfo = arch2infos[index] | ||||||
|  |       dataset_seed = archInfo.dataset_seed | ||||||
|  |       if dataset not in dataset_seed: | ||||||
|  |         nums[0] += 1 | ||||||
|  |       else: | ||||||
|  |         nums[len(dataset_seed[dataset])] += 1 | ||||||
|  |     return dict(nums) | ||||||
|  |  | ||||||
|  |  | ||||||
|   @staticmethod |   @staticmethod | ||||||
|   def str2lists(arch_str: Text) -> List[tuple]: |   def str2lists(arch_str: Text) -> List[tuple]: | ||||||
|     """ |     """ | ||||||
|   | |||||||
| @@ -2,9 +2,9 @@ | |||||||
| # bash ./scripts-search/algos/BOHB.sh -1 | # bash ./scripts-search/algos/BOHB.sh -1 | ||||||
| echo script name: $0 | echo script name: $0 | ||||||
| echo $# arguments | echo $# arguments | ||||||
| if [ "$#" -ne 1 ] ;then | if [ "$#" -ne 2 ] ;then | ||||||
|   echo "Input illegal number of parameters " $# |   echo "Input illegal number of parameters " $# | ||||||
|   echo "Need 1 parameters for seed" |   echo "Need 2 parameters for dataset and seed" | ||||||
|   exit 1 |   exit 1 | ||||||
| fi | fi | ||||||
| if [ "$TORCH_HOME" = "" ]; then | if [ "$TORCH_HOME" = "" ]; then | ||||||
| @@ -14,12 +14,14 @@ else | |||||||
|   echo "TORCH_HOME : $TORCH_HOME" |   echo "TORCH_HOME : $TORCH_HOME" | ||||||
| fi | fi | ||||||
|  |  | ||||||
| dataset=cifar10 | dataset=$1 | ||||||
| seed=$1 | seed=$2 | ||||||
| channel=16 | channel=16 | ||||||
| num_cells=5 | num_cells=5 | ||||||
| max_nodes=4 | max_nodes=4 | ||||||
| space=nas-bench-201 | space=nas-bench-201 | ||||||
|  | #benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth | ||||||
|  | benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth | ||||||
|  |  | ||||||
| save_dir=./output/search-cell-${space}/BOHB-${dataset} | save_dir=./output/search-cell-${space}/BOHB-${dataset} | ||||||
|  |  | ||||||
| @@ -27,7 +29,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/BOHB.py \ | |||||||
| 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | ||||||
| 	--dataset ${dataset} \ | 	--dataset ${dataset} \ | ||||||
| 	--search_space_name ${space} \ | 	--search_space_name ${space} \ | ||||||
| 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \ | 	--arch_nas_dataset ${benchmark_file} \ | ||||||
| 	--time_budget 12000  \ | 	--time_budget 12000  \ | ||||||
| 	--n_iters 50 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 \ | 	--n_iters 50 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 \ | ||||||
| 	--workers 4 --print_freq 200 --rand_seed ${seed} | 	--workers 4 --print_freq 200 --rand_seed ${seed} | ||||||
|   | |||||||
| @@ -27,6 +27,8 @@ if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then | |||||||
| else | else | ||||||
|   data_path="$TORCH_HOME/cifar.python/ImageNet16" |   data_path="$TORCH_HOME/cifar.python/ImageNet16" | ||||||
| fi | fi | ||||||
|  | #benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth | ||||||
|  | benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth | ||||||
|  |  | ||||||
| save_dir=./output/search-cell-${space}/DARTS-V1-${dataset}-BN${BN} | save_dir=./output/search-cell-${space}/DARTS-V1-${dataset}-BN${BN} | ||||||
|  |  | ||||||
| @@ -35,7 +37,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V1.py \ | |||||||
| 	--dataset ${dataset} --data_path ${data_path} \ | 	--dataset ${dataset} --data_path ${data_path} \ | ||||||
| 	--search_space_name ${space} \ | 	--search_space_name ${space} \ | ||||||
| 	--config_path configs/nas-benchmark/algos/DARTS.config \ | 	--config_path configs/nas-benchmark/algos/DARTS.config \ | ||||||
| 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \ | 	--arch_nas_dataset ${benchmark_file} \ | ||||||
| 	--track_running_stats ${BN} \ | 	--track_running_stats ${BN} \ | ||||||
| 	--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ | 	--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ | ||||||
| 	--workers 4 --print_freq 200 --rand_seed ${seed} | 	--workers 4 --print_freq 200 --rand_seed ${seed} | ||||||
|   | |||||||
| @@ -27,6 +27,8 @@ if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then | |||||||
| else | else | ||||||
|   data_path="$TORCH_HOME/cifar.python/ImageNet16" |   data_path="$TORCH_HOME/cifar.python/ImageNet16" | ||||||
| fi | fi | ||||||
|  | #benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth | ||||||
|  | benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth | ||||||
|  |  | ||||||
| save_dir=./output/search-cell-${space}/DARTS-V2-${dataset}-BN${BN} | save_dir=./output/search-cell-${space}/DARTS-V2-${dataset}-BN${BN} | ||||||
|  |  | ||||||
| @@ -35,7 +37,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V2.py \ | |||||||
| 	--dataset ${dataset} --data_path ${data_path} \ | 	--dataset ${dataset} --data_path ${data_path} \ | ||||||
| 	--search_space_name ${space} \ | 	--search_space_name ${space} \ | ||||||
| 	--config_path configs/nas-benchmark/algos/DARTS.config \ | 	--config_path configs/nas-benchmark/algos/DARTS.config \ | ||||||
| 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \ | 	--arch_nas_dataset ${benchmark_file} \ | ||||||
| 	--track_running_stats ${BN} \ | 	--track_running_stats ${BN} \ | ||||||
| 	--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ | 	--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ | ||||||
| 	--workers 4 --print_freq 200 --rand_seed ${seed} | 	--workers 4 --print_freq 200 --rand_seed ${seed} | ||||||
|   | |||||||
| @@ -28,6 +28,8 @@ if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then | |||||||
| else | else | ||||||
|   data_path="$TORCH_HOME/cifar.python/ImageNet16" |   data_path="$TORCH_HOME/cifar.python/ImageNet16" | ||||||
| fi | fi | ||||||
|  | #benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth | ||||||
|  | benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth | ||||||
|  |  | ||||||
| save_dir=./output/search-cell-${space}/ENAS-${dataset}-BN${BN} | save_dir=./output/search-cell-${space}/ENAS-${dataset}-BN${BN} | ||||||
|  |  | ||||||
| @@ -35,7 +37,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/ENAS.py \ | |||||||
| 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | ||||||
| 	--dataset ${dataset} --data_path ${data_path} \ | 	--dataset ${dataset} --data_path ${data_path} \ | ||||||
| 	--search_space_name ${space} \ | 	--search_space_name ${space} \ | ||||||
| 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \ | 	--arch_nas_dataset ${benchmark_file} \ | ||||||
| 	--track_running_stats ${BN} \ | 	--track_running_stats ${BN} \ | ||||||
| 	--config_path ./configs/nas-benchmark/algos/ENAS.config \ | 	--config_path ./configs/nas-benchmark/algos/ENAS.config \ | ||||||
| 	--controller_entropy_weight 0.0001 \ | 	--controller_entropy_weight 0.0001 \ | ||||||
|   | |||||||
| @@ -27,6 +27,8 @@ if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then | |||||||
| else | else | ||||||
|   data_path="$TORCH_HOME/cifar.python/ImageNet16" |   data_path="$TORCH_HOME/cifar.python/ImageNet16" | ||||||
| fi | fi | ||||||
|  | #benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth | ||||||
|  | benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth | ||||||
|  |  | ||||||
| save_dir=./output/search-cell-${space}/GDAS-${dataset}-BN${BN} | save_dir=./output/search-cell-${space}/GDAS-${dataset}-BN${BN} | ||||||
|  |  | ||||||
| @@ -34,7 +36,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/GDAS.py \ | |||||||
| 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | ||||||
| 	--dataset ${dataset} --data_path ${data_path} \ | 	--dataset ${dataset} --data_path ${data_path} \ | ||||||
| 	--search_space_name ${space} \ | 	--search_space_name ${space} \ | ||||||
| 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \ | 	--arch_nas_dataset ${benchmark_file} \ | ||||||
| 	--config_path configs/nas-benchmark/algos/GDAS.config \ | 	--config_path configs/nas-benchmark/algos/GDAS.config \ | ||||||
| 	--tau_max 10 --tau_min 0.1 --track_running_stats ${BN} \ | 	--tau_max 10 --tau_min 0.1 --track_running_stats ${BN} \ | ||||||
| 	--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ | 	--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ | ||||||
|   | |||||||
| @@ -23,6 +23,8 @@ channel=16 | |||||||
| num_cells=5 | num_cells=5 | ||||||
| max_nodes=4 | max_nodes=4 | ||||||
| space=nas-bench-201 | space=nas-bench-201 | ||||||
|  | #benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth | ||||||
|  | benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth | ||||||
|  |  | ||||||
| save_dir=./output/search-cell-${space}/R-EA-${dataset}-SS${sample_size} | save_dir=./output/search-cell-${space}/R-EA-${dataset}-SS${sample_size} | ||||||
|  |  | ||||||
| @@ -30,7 +32,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/R_EA.py \ | |||||||
| 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | ||||||
| 	--dataset ${dataset} \ | 	--dataset ${dataset} \ | ||||||
| 	--search_space_name ${space} \ | 	--search_space_name ${space} \ | ||||||
| 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \ | 	--arch_nas_dataset ${benchmark_file} \ | ||||||
| 	--time_budget 12000 \ | 	--time_budget 12000 \ | ||||||
| 	--ea_cycles 200 --ea_population 10 --ea_sample_size ${sample_size} --ea_fast_by_api 1 \ | 	--ea_cycles 200 --ea_population 10 --ea_sample_size ${sample_size} --ea_fast_by_api 1 \ | ||||||
| 	--workers 4 --print_freq 200 --rand_seed ${seed} | 	--workers 4 --print_freq 200 --rand_seed ${seed} | ||||||
|   | |||||||
| @@ -28,6 +28,8 @@ if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then | |||||||
| else | else | ||||||
|   data_path="$TORCH_HOME/cifar.python/ImageNet16" |   data_path="$TORCH_HOME/cifar.python/ImageNet16" | ||||||
| fi | fi | ||||||
|  | #benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth | ||||||
|  | benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth | ||||||
|  |  | ||||||
| save_dir=./output/search-cell-${space}/RANDOM-NAS-${dataset}-BN${BN} | save_dir=./output/search-cell-${space}/RANDOM-NAS-${dataset}-BN${BN} | ||||||
|  |  | ||||||
| @@ -36,7 +38,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/RANDOM-NAS.py \ | |||||||
| 	--dataset ${dataset} --data_path ${data_path} \ | 	--dataset ${dataset} --data_path ${data_path} \ | ||||||
| 	--search_space_name ${space} \ | 	--search_space_name ${space} \ | ||||||
| 	--track_running_stats ${BN} \ | 	--track_running_stats ${BN} \ | ||||||
| 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \ | 	--arch_nas_dataset ${benchmark_file} \ | ||||||
| 	--config_path ./configs/nas-benchmark/algos/RANDOM.config \ | 	--config_path ./configs/nas-benchmark/algos/RANDOM.config \ | ||||||
| 	--select_num 100 \ | 	--select_num 100 \ | ||||||
| 	--workers 4 --print_freq 200 --rand_seed ${seed} | 	--workers 4 --print_freq 200 --rand_seed ${seed} | ||||||
|   | |||||||
| @@ -2,9 +2,9 @@ | |||||||
| # bash ./scripts-search/algos/REINFORCE.sh 0.001 -1 | # bash ./scripts-search/algos/REINFORCE.sh 0.001 -1 | ||||||
| echo script name: $0 | echo script name: $0 | ||||||
| echo $# arguments | echo $# arguments | ||||||
| if [ "$#" -ne 2 ] ;then | if [ "$#" -ne 3 ] ;then | ||||||
|   echo "Input illegal number of parameters " $# |   echo "Input illegal number of parameters " $# | ||||||
|   echo "Need 2 parameters for LR and seed" |   echo "Need 3 parameters for dataset, LR, and seed" | ||||||
|   exit 1 |   exit 1 | ||||||
| fi | fi | ||||||
| if [ "$TORCH_HOME" = "" ]; then | if [ "$TORCH_HOME" = "" ]; then | ||||||
| @@ -14,13 +14,15 @@ else | |||||||
|   echo "TORCH_HOME : $TORCH_HOME" |   echo "TORCH_HOME : $TORCH_HOME" | ||||||
| fi | fi | ||||||
|  |  | ||||||
| dataset=cifar10 | dataset=$1 | ||||||
| LR=$1 | LR=$2 | ||||||
| seed=$2 | seed=$3 | ||||||
| channel=16 | channel=16 | ||||||
| num_cells=5 | num_cells=5 | ||||||
| max_nodes=4 | max_nodes=4 | ||||||
| space=nas-bench-201 | space=nas-bench-201 | ||||||
|  | #benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth | ||||||
|  | benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth | ||||||
|  |  | ||||||
| save_dir=./output/search-cell-${space}/REINFORCE-${dataset}-${LR} | save_dir=./output/search-cell-${space}/REINFORCE-${dataset}-${LR} | ||||||
|  |  | ||||||
| @@ -28,7 +30,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/reinforce.py \ | |||||||
| 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | ||||||
| 	--dataset ${dataset} \ | 	--dataset ${dataset} \ | ||||||
| 	--search_space_name ${space} \ | 	--search_space_name ${space} \ | ||||||
| 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \ | 	--arch_nas_dataset ${benchmark_file} \ | ||||||
| 	--time_budget 12000 \ | 	--time_budget 12000 \ | ||||||
| 	--learning_rate ${LR} --EMA_momentum 0.9 \ | 	--learning_rate ${LR} --EMA_momentum 0.9 \ | ||||||
| 	--workers 4 --print_freq 200 --rand_seed ${seed} | 	--workers 4 --print_freq 200 --rand_seed ${seed} | ||||||
|   | |||||||
| @@ -2,9 +2,9 @@ | |||||||
| # bash ./scripts-search/algos/Random.sh -1 | # bash ./scripts-search/algos/Random.sh -1 | ||||||
| echo script name: $0 | echo script name: $0 | ||||||
| echo $# arguments | echo $# arguments | ||||||
| if [ "$#" -ne 1 ] ;then | if [ "$#" -ne 2 ] ;then | ||||||
|   echo "Input illegal number of parameters " $# |   echo "Input illegal number of parameters " $# | ||||||
|   echo "Need 1 parameters for seed" |   echo "Need 2 parameters for dataset and seed" | ||||||
|   exit 1 |   exit 1 | ||||||
| fi | fi | ||||||
| if [ "$TORCH_HOME" = "" ]; then | if [ "$TORCH_HOME" = "" ]; then | ||||||
| @@ -14,12 +14,14 @@ else | |||||||
|   echo "TORCH_HOME : $TORCH_HOME" |   echo "TORCH_HOME : $TORCH_HOME" | ||||||
| fi | fi | ||||||
|  |  | ||||||
| dataset=cifar10 | dataset=$1 | ||||||
| seed=$1 | seed=$2 | ||||||
| channel=16 | channel=16 | ||||||
| num_cells=5 | num_cells=5 | ||||||
| max_nodes=4 | max_nodes=4 | ||||||
| space=nas-bench-201 | space=nas-bench-201 | ||||||
|  | #benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth | ||||||
|  | benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth | ||||||
|  |  | ||||||
| save_dir=./output/search-cell-${space}/RAND-${dataset} | save_dir=./output/search-cell-${space}/RAND-${dataset} | ||||||
|  |  | ||||||
| @@ -27,7 +29,6 @@ OMP_NUM_THREADS=4 python ./exps/algos/RANDOM.py \ | |||||||
| 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | ||||||
| 	--dataset ${dataset} \ | 	--dataset ${dataset} \ | ||||||
| 	--search_space_name ${space} \ | 	--search_space_name ${space} \ | ||||||
| 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \ | 	--arch_nas_dataset ${benchmark_file} \ | ||||||
| 	--time_budget 12000 \ | 	--time_budget 12000 \ | ||||||
| 	--workers 4 --print_freq 200 --rand_seed ${seed} | 	--workers 4 --print_freq 200 --rand_seed ${seed} | ||||||
| #	--random_num 100 \ |  | ||||||
|   | |||||||
| @@ -28,6 +28,8 @@ if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then | |||||||
| else | else | ||||||
|   data_path="$TORCH_HOME/cifar.python/ImageNet16" |   data_path="$TORCH_HOME/cifar.python/ImageNet16" | ||||||
| fi | fi | ||||||
|  | #benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth | ||||||
|  | benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth | ||||||
|  |  | ||||||
| save_dir=./output/search-cell-${space}/SETN-${dataset}-BN${BN} | save_dir=./output/search-cell-${space}/SETN-${dataset}-BN${BN} | ||||||
|  |  | ||||||
| @@ -35,7 +37,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/SETN.py \ | |||||||
| 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | ||||||
| 	--dataset ${dataset} --data_path ${data_path} \ | 	--dataset ${dataset} --data_path ${data_path} \ | ||||||
| 	--search_space_name ${space} \ | 	--search_space_name ${space} \ | ||||||
| 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \ | 	--arch_nas_dataset ${benchmark_file} \ | ||||||
| 	--config_path configs/nas-benchmark/algos/SETN.config \ | 	--config_path configs/nas-benchmark/algos/SETN.config \ | ||||||
| 	--track_running_stats ${BN} \ | 	--track_running_stats ${BN} \ | ||||||
| 	--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ | 	--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ | ||||||
|   | |||||||
| @@ -28,6 +28,8 @@ if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then | |||||||
| else | else | ||||||
|   data_path="$TORCH_HOME/cifar.python/ImageNet16" |   data_path="$TORCH_HOME/cifar.python/ImageNet16" | ||||||
| fi | fi | ||||||
|  | #benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth | ||||||
|  | benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth | ||||||
| 
 | 
 | ||||||
| save_dir=./output/search-cell-${space}/DARTS-V1-${dataset}-BN${BN}-Gradient${gradient_clip} | save_dir=./output/search-cell-${space}/DARTS-V1-${dataset}-BN${BN}-Gradient${gradient_clip} | ||||||
| 
 | 
 | ||||||
| @@ -36,7 +38,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V1.py \ | |||||||
| 	--dataset ${dataset} --data_path ${data_path} \ | 	--dataset ${dataset} --data_path ${data_path} \ | ||||||
| 	--search_space_name ${space} \ | 	--search_space_name ${space} \ | ||||||
| 	--config_path configs/nas-benchmark/algos/DARTS.config \ | 	--config_path configs/nas-benchmark/algos/DARTS.config \ | ||||||
| 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \ | 	--arch_nas_dataset ${benchmark_file} \ | ||||||
| 	--track_running_stats ${BN} --gradient_clip ${gradient_clip} \ | 	--track_running_stats ${BN} --gradient_clip ${gradient_clip} \ | ||||||
| 	--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ | 	--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ | ||||||
| 	--workers 4 --print_freq 200 --rand_seed ${seed} | 	--workers 4 --print_freq 200 --rand_seed ${seed} | ||||||
		Reference in New Issue
	
	Block a user