Upgrade NAS-Bench-201 to APIv1.3/FILEv1.1
This commit is contained in:
parent
c53a9ce407
commit
fb76814369
@ -18,14 +18,18 @@ You can simply type `pip install nas-bench-201` to install our api. Please see s
|
|||||||
|
|
||||||
### Preparation and Download
|
### Preparation and Download
|
||||||
|
|
||||||
The benchmark file of NAS-Bench-201 can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) or [Baidu-Wangpan (code:6u5d)](https://pan.baidu.com/s/1CiaNH6C12zuZf7q-Ilm09w).
|
[deprecated] The benchmark file of NAS-Bench-201 can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) or [Baidu-Wangpan (code:6u5d)](https://pan.baidu.com/s/1CiaNH6C12zuZf7q-Ilm09w).
|
||||||
|
|
||||||
|
[recommended] The benchmark file of NAS-Bench-201 can be downloaded from [Google Drive](https://drive.google.com/open?id=1OOfVPpt-lA4u2HJrXbgrRd42IbfvJMyE). The files for model weight are too large (431G) and I need some time to upload it. Please be patient, thanks for your understanding.
|
||||||
|
|
||||||
You can move it to anywhere you want and send its path to our API for initialization.
|
You can move it to anywhere you want and send its path to our API for initialization.
|
||||||
- [2020.02.25] v1.0: `NAS-Bench-201-v1_0-e61699.pth`, where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial.
|
- [2020.02.25] APIv1.0/FILEv1.0: `NAS-Bench-201-v1_0-e61699.pth` (2.2G), where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial.
|
||||||
- [2020.02.25] v1.0: The full data of each architecture can be download from [
|
- [2020.02.25] APIv1.0/FILEv1.0: The full data of each architecture can be download from [
|
||||||
NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the the trained weights.
|
NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the the trained weights.
|
||||||
- [2020.02.25] v1.0: Checkpoints for 3 runs of each baseline NAS algorithm are provided in [Google Drive](https://drive.google.com/open?id=1eAgLZQAViP3r6dA0_ZOOGG9zPLXhGwXi).
|
- [2020.02.25] APIv1.0/FILEv1.0: Checkpoints for 3 runs of each baseline NAS algorithm are provided in [Google Drive](https://drive.google.com/open?id=1eAgLZQAViP3r6dA0_ZOOGG9zPLXhGwXi).
|
||||||
- [2020.03.09] v1.2: More robust API with more functions and descriptions
|
- [2020.03.09] APIv1.2/FILEv1.0: More robust API with more functions and descriptions
|
||||||
- [2020.04.01] v2.0: coming soon (results of two set of hyper-parameters avaliable on all three datasets)
|
- [2020.03.16] APIv1.3/FILEv1.1: `NAS-Bench-201-v1_1-096897.pth` (4.7G), where `096897` is the last six digits for this file. It contains information of more trials compared to `NAS-Bench-201-v1_0-e61699.pth`, especially all models trained by 12 epochs on all datasets are avaliable.
|
||||||
|
- [2020.06.01] APIv2.0/FILEv2.0: coming soon!
|
||||||
|
|
||||||
|
|
||||||
The training and evaluation data used in NAS-Bench-201 can be downloaded from [Google Drive](https://drive.google.com/open?id=1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7) or [Baidu-Wangpan (code:4fg7)](https://pan.baidu.com/s/1XAzavPKq3zcat1yBA1L2tQ).
|
The training and evaluation data used in NAS-Bench-201 can be downloaded from [Google Drive](https://drive.google.com/open?id=1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7) or [Baidu-Wangpan (code:4fg7)](https://pan.baidu.com/s/1XAzavPKq3zcat1yBA1L2tQ).
|
||||||
@ -92,7 +96,9 @@ print(network) # show the structure of this architecture
|
|||||||
```
|
```
|
||||||
If you want to load the trained weights of this created network, you need to use `api.get_net_param(123, ...)` to obtain the weights and then load it to the network.
|
If you want to load the trained weights of this created network, you need to use `api.get_net_param(123, ...)` to obtain the weights and then load it to the network.
|
||||||
|
|
||||||
6. For other usages, please see `lib/nas_201_api/api.py`. We provide some usage information in the comments for the corresponding functions. If what you want is not provided, please feel free to open an issue for discussion, and I am happy to answer any questions regarding NAS-Bench-201.
|
6. `api.get_more_info(...)` can return the loss / accuracy / time on training / validation / test sets, which is very helpful. For more details, please look at the comments in the get_more_info function.
|
||||||
|
|
||||||
|
7. For other usages, please see `lib/nas_201_api/api.py`. We provide some usage information in the comments for the corresponding functions. If what you want is not provided, please feel free to open an issue for discussion, and I am happy to answer any questions regarding NAS-Bench-201.
|
||||||
|
|
||||||
|
|
||||||
### Detailed Instruction
|
### Detailed Instruction
|
||||||
@ -213,12 +219,14 @@ If researchers can provide better results with different hyper-parameters, we ar
|
|||||||
- [5] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/ENAS.sh cifar10 1 -1`
|
- [5] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/ENAS.sh cifar10 1 -1`
|
||||||
- [6] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/RANDOM-NAS.sh cifar10 1 -1`
|
- [6] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/RANDOM-NAS.sh cifar10 1 -1`
|
||||||
- [7] `bash ./scripts-search/algos/R-EA.sh cifar10 3 -1`
|
- [7] `bash ./scripts-search/algos/R-EA.sh cifar10 3 -1`
|
||||||
- [8] `bash ./scripts-search/algos/Random.sh -1`
|
- [8] `bash ./scripts-search/algos/Random.sh cifar10 -1`
|
||||||
- [9] `bash ./scripts-search/algos/REINFORCE.sh 0.5 -1`
|
- [9] `bash ./scripts-search/algos/REINFORCE.sh cifar10 0.5 -1`
|
||||||
- [10] `bash ./scripts-search/algos/BOHB.sh -1`
|
- [10] `bash ./scripts-search/algos/BOHB.sh cifar10 -1`
|
||||||
|
|
||||||
In commands [1-6], the first args `cifar10` indicates the dataset name, the second args `1` indicates the behavior of BN, and the first args `-1` indicates the random seed.
|
In commands [1-6], the first args `cifar10` indicates the dataset name, the second args `1` indicates the behavior of BN, and the first args `-1` indicates the random seed.
|
||||||
|
|
||||||
|
**Note that** since 2020 March 16, in these scripts, the default NAS-Bench-201 benchmark file has changed from `NAS-Bench-201-v1_0-e61699.pth` to `NAS-Bench-201-v1_1-096897.pth`, and thus the results could be slightly different from the original paper.
|
||||||
|
|
||||||
|
|
||||||
# Citation
|
# Citation
|
||||||
|
|
||||||
|
@ -1,36 +1,84 @@
|
|||||||
#####################################################
|
#####################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
||||||
########################################################
|
###############################################################################################
|
||||||
# python exps/NAS-Bench-201/test-weights.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth
|
# Before run these commands, the files must be properly put.
|
||||||
########################################################
|
# python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_0-e61699
|
||||||
|
# python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897
|
||||||
|
###############################################################################################
|
||||||
import os, sys, time, glob, random, argparse
|
import os, sys, time, glob, random, argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from tqdm import tqdm
|
||||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||||
from nas_201_api import NASBench201API as API
|
from nas_201_api import NASBench201API as API
|
||||||
|
from log_utils import time_string
|
||||||
|
from models import get_cell_based_tiny_net
|
||||||
from utils import weight_watcher
|
from utils import weight_watcher
|
||||||
|
|
||||||
|
|
||||||
def main(meta_file, weight_dir, save_dir):
|
def get_cor(A, B):
|
||||||
import pdb;
|
return float(np.corrcoef(A, B)[0,1])
|
||||||
pdb.set_trace()
|
|
||||||
|
|
||||||
|
def evaluate(api, weight_dir, data: str, use_12epochs_result: bool, valid_or_test: bool):
|
||||||
|
norms, accs = [], []
|
||||||
|
for idx in tqdm(range(len(api))):
|
||||||
|
info = api.get_more_info(idx, data, use_12epochs_result=use_12epochs_result, is_random=False)
|
||||||
|
if valid_or_test:
|
||||||
|
accs.append(info['valid-accuracy'])
|
||||||
|
else:
|
||||||
|
accs.append(info['test-accuracy'])
|
||||||
|
config = api.get_net_config(idx, data)
|
||||||
|
net = get_cell_based_tiny_net(config)
|
||||||
|
api.reload(weight_dir, idx)
|
||||||
|
params = api.get_net_param(idx, data, None)
|
||||||
|
cur_norms = []
|
||||||
|
for seed, param in params.items():
|
||||||
|
net.load_state_dict(param)
|
||||||
|
_, summary = weight_watcher.analyze(net, alphas=False)
|
||||||
|
cur_norms.append( summary['lognorm'] )
|
||||||
|
norms.append( float(np.mean(cur_norms)) )
|
||||||
|
api.clear_params(idx, use_12epochs_result)
|
||||||
|
correlation = get_cor(norms, accs)
|
||||||
|
print('For {:} with {:} epochs on {:} : the correlation is {:}'.format(data, 12 if use_12epochs_result else 200, 'valid' if valid_or_test else 'test', correlation))
|
||||||
|
|
||||||
|
|
||||||
|
def main(meta_file: str, weight_dir, save_dir):
|
||||||
|
api = API(meta_file)
|
||||||
|
datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
|
||||||
|
print(time_string() + ' ' + '='*50)
|
||||||
|
for data in datasets:
|
||||||
|
nums = api.statistics(data, True)
|
||||||
|
total = sum([k*v for k, v in nums.items()])
|
||||||
|
print('Using 012 epochs, trained on {:20s} : {:} trials in total ({:}).'.format(data, total, nums))
|
||||||
|
print(time_string() + ' ' + '='*50)
|
||||||
|
for data in datasets:
|
||||||
|
nums = api.statistics(data, False)
|
||||||
|
total = sum([k*v for k, v in nums.items()])
|
||||||
|
print('Using 200 epochs, trained on {:20s} : {:} trials in total ({:}).'.format(data, total, nums))
|
||||||
|
print(time_string() + ' ' + '='*50)
|
||||||
|
|
||||||
|
evaluate(api, weight_dir, 'cifar10-valid', False, True)
|
||||||
|
|
||||||
|
print('{:} finish this test.'.format(time_string()))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
|
parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
|
||||||
parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.')
|
parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.')
|
||||||
parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-201 benchmark file.')
|
parser.add_argument('--base_path', type=str, default=None, help='The path to the NAS-Bench-201 benchmark file and weight dir.')
|
||||||
parser.add_argument('--weight_dir', type=str, default=None, help='The directory path to the weights of every NAS-Bench-201 architecture.')
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
save_dir = Path(args.save_dir)
|
save_dir = Path(args.save_dir)
|
||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
meta_file = Path(args.api_path)
|
meta_file = Path(args.base_path + '.pth')
|
||||||
weight_dir = Path(args.weight_dir)
|
weight_dir = Path(args.base_path + '-archive')
|
||||||
assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
|
assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
|
||||||
|
assert weight_dir.exists() and weight_dir.is_dir(), 'invalid path for weight dir : {:}'.format(weight_dir)
|
||||||
|
|
||||||
main(meta_file, weight_dir, save_dir)
|
main(str(meta_file), weight_dir, save_dir)
|
||||||
|
|
||||||
|
@ -50,10 +50,11 @@ def config2structure_func(max_nodes):
|
|||||||
|
|
||||||
class MyWorker(Worker):
|
class MyWorker(Worker):
|
||||||
|
|
||||||
def __init__(self, *args, convert_func=None, nas_bench=None, time_budget=None, **kwargs):
|
def __init__(self, *args, convert_func=None, dataname=None, nas_bench=None, time_budget=None, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.convert_func = convert_func
|
self.convert_func = convert_func
|
||||||
self.nas_bench = nas_bench
|
self._dataname = dataname
|
||||||
|
self._nas_bench = nas_bench
|
||||||
self.time_budget = time_budget
|
self.time_budget = time_budget
|
||||||
self.seen_archs = []
|
self.seen_archs = []
|
||||||
self.sim_cost_time = 0
|
self.sim_cost_time = 0
|
||||||
@ -64,7 +65,7 @@ class MyWorker(Worker):
|
|||||||
assert len(self.seen_archs) > 0
|
assert len(self.seen_archs) > 0
|
||||||
best_index, best_acc = -1, None
|
best_index, best_acc = -1, None
|
||||||
for arch_index in self.seen_archs:
|
for arch_index in self.seen_archs:
|
||||||
info = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True)
|
info = self._nas_bench.get_more_info(arch_index, self._dataname, None, True, True)
|
||||||
vacc = info['valid-accuracy']
|
vacc = info['valid-accuracy']
|
||||||
if best_acc is None or best_acc < vacc:
|
if best_acc is None or best_acc < vacc:
|
||||||
best_acc = vacc
|
best_acc = vacc
|
||||||
@ -75,8 +76,8 @@ class MyWorker(Worker):
|
|||||||
def compute(self, config, budget, **kwargs):
|
def compute(self, config, budget, **kwargs):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
structure = self.convert_func( config )
|
structure = self.convert_func( config )
|
||||||
arch_index = self.nas_bench.query_index_by_arch( structure )
|
arch_index = self._nas_bench.query_index_by_arch( structure )
|
||||||
info = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True)
|
info = self._nas_bench.get_more_info(arch_index, self._dataname, None, True, True)
|
||||||
cur_time = info['train-all-time'] + info['valid-per-time']
|
cur_time = info['train-all-time'] + info['valid-per-time']
|
||||||
cur_vacc = info['valid-accuracy']
|
cur_vacc = info['valid-accuracy']
|
||||||
self.real_cost_time += (time.time() - start_time)
|
self.real_cost_time += (time.time() - start_time)
|
||||||
@ -106,7 +107,10 @@ def main(xargs, nas_bench):
|
|||||||
prepare_seed(xargs.rand_seed)
|
prepare_seed(xargs.rand_seed)
|
||||||
logger = prepare_logger(args)
|
logger = prepare_logger(args)
|
||||||
|
|
||||||
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
|
if xargs.dataset == 'cifar10':
|
||||||
|
dataname = 'cifar10-valid'
|
||||||
|
else:
|
||||||
|
dataname = xargs.dataset
|
||||||
if xargs.data_path is not None:
|
if xargs.data_path is not None:
|
||||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||||
@ -148,7 +152,7 @@ def main(xargs, nas_bench):
|
|||||||
#logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string()))
|
#logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string()))
|
||||||
workers = []
|
workers = []
|
||||||
for i in range(num_workers):
|
for i in range(num_workers):
|
||||||
w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, time_budget=xargs.time_budget, run_id=hb_run_id, id=i)
|
w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, dataname=dataname, nas_bench=nas_bench, time_budget=xargs.time_budget, run_id=hb_run_id, id=i)
|
||||||
w.run(background=True)
|
w.run(background=True)
|
||||||
workers.append(w)
|
workers.append(w)
|
||||||
|
|
||||||
|
@ -28,7 +28,10 @@ def main(xargs, nas_bench):
|
|||||||
prepare_seed(xargs.rand_seed)
|
prepare_seed(xargs.rand_seed)
|
||||||
logger = prepare_logger(args)
|
logger = prepare_logger(args)
|
||||||
|
|
||||||
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
|
if xargs.dataset == 'cifar10':
|
||||||
|
dataname = 'cifar10-valid'
|
||||||
|
else:
|
||||||
|
dataname = xargs.dataset
|
||||||
if xargs.data_path is not None:
|
if xargs.data_path is not None:
|
||||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||||
@ -62,7 +65,7 @@ def main(xargs, nas_bench):
|
|||||||
#for idx in range(xargs.random_num):
|
#for idx in range(xargs.random_num):
|
||||||
while total_time_cost < xargs.time_budget:
|
while total_time_cost < xargs.time_budget:
|
||||||
arch = random_arch()
|
arch = random_arch()
|
||||||
accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info)
|
accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info, dataname)
|
||||||
if total_time_cost + cost_time > xargs.time_budget: break
|
if total_time_cost + cost_time > xargs.time_budget: break
|
||||||
else: total_time_cost += cost_time
|
else: total_time_cost += cost_time
|
||||||
history.append(arch)
|
history.append(arch)
|
||||||
|
@ -33,19 +33,21 @@ class Model(object):
|
|||||||
|
|
||||||
# This function is to mimic the training and evaluatinig procedure for a single architecture `arch`.
|
# This function is to mimic the training and evaluatinig procedure for a single architecture `arch`.
|
||||||
# The time_cost is calculated as the total training time for a few (e.g., 12 epochs) plus the evaluation time for one epoch.
|
# The time_cost is calculated as the total training time for a few (e.g., 12 epochs) plus the evaluation time for one epoch.
|
||||||
# For use_converged_LR = True, the architecture is trained for 12 epochs, with LR being decaded from 0.1 to 0.
|
# For use_012_epoch_training = True, the architecture is trained for 12 epochs, with LR being decaded from 0.1 to 0.
|
||||||
# In this case, the LR schedular is converged.
|
# In this case, the LR schedular is converged.
|
||||||
# For use_converged_LR = False, the architecture is planed to be trained for 200 epochs, but we early stop its procedure.
|
# For use_012_epoch_training = False, the architecture is planed to be trained for 200 epochs, but we early stop its procedure.
|
||||||
#
|
#
|
||||||
def train_and_eval(arch, nas_bench, extra_info, dataname='cifar10-valid', use_converged_LR=True):
|
def train_and_eval(arch, nas_bench, extra_info, dataname='cifar10-valid', use_012_epoch_training=True):
|
||||||
if use_converged_LR and nas_bench is not None:
|
|
||||||
|
if use_012_epoch_training and nas_bench is not None:
|
||||||
arch_index = nas_bench.query_index_by_arch( arch )
|
arch_index = nas_bench.query_index_by_arch( arch )
|
||||||
assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
|
assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
|
||||||
info = nas_bench.get_more_info(arch_index, dataname, None, True)
|
info = nas_bench.get_more_info(arch_index, dataname, None, True)
|
||||||
valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
|
valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
|
||||||
#_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs
|
#_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs
|
||||||
elif not use_converged_LR and nas_bench is not None:
|
elif not use_012_epoch_training and nas_bench is not None:
|
||||||
# Please use `use_converged_LR=False` for cifar10 only.
|
# Please contact me if you want to use the following logic, because it has some potential issues.
|
||||||
|
# Please use `use_012_epoch_training=False` for cifar10 only.
|
||||||
# It did return values for cifar100 and ImageNet16-120, but it has some potential issues. (Please email me for more details)
|
# It did return values for cifar100 and ImageNet16-120, but it has some potential issues. (Please email me for more details)
|
||||||
arch_index, nepoch = nas_bench.query_index_by_arch( arch ), 25
|
arch_index, nepoch = nas_bench.query_index_by_arch( arch ), 25
|
||||||
assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
|
assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
|
||||||
@ -64,7 +66,7 @@ def train_and_eval(arch, nas_bench, extra_info, dataname='cifar10-valid', use_co
|
|||||||
try:
|
try:
|
||||||
valid_acc, time_cost = info['valid-accuracy'], estimated_train_cost + estimated_valid_cost
|
valid_acc, time_cost = info['valid-accuracy'], estimated_train_cost + estimated_valid_cost
|
||||||
except:
|
except:
|
||||||
valid_acc, time_cost = info['est-valid-accuracy'], estimated_train_cost + estimated_valid_cost
|
valid_acc, time_cost = info['valtest-accuracy'], estimated_train_cost + estimated_valid_cost
|
||||||
else:
|
else:
|
||||||
# train a model from scratch.
|
# train a model from scratch.
|
||||||
raise ValueError('NOT IMPLEMENT YET')
|
raise ValueError('NOT IMPLEMENT YET')
|
||||||
@ -127,7 +129,7 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
|
|||||||
while len(population) < population_size:
|
while len(population) < population_size:
|
||||||
model = Model()
|
model = Model()
|
||||||
model.arch = random_arch()
|
model.arch = random_arch()
|
||||||
model.accuracy, time_cost = train_and_eval(model.arch, nas_bench, extra_info)
|
model.accuracy, time_cost = train_and_eval(model.arch, nas_bench, extra_info, dataname)
|
||||||
population.append(model)
|
population.append(model)
|
||||||
history.append(model)
|
history.append(model)
|
||||||
total_time_cost += time_cost
|
total_time_cost += time_cost
|
||||||
@ -152,7 +154,7 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
|
|||||||
child = Model()
|
child = Model()
|
||||||
child.arch = mutate_arch(parent.arch)
|
child.arch = mutate_arch(parent.arch)
|
||||||
total_time_cost += time.time() - start_time
|
total_time_cost += time.time() - start_time
|
||||||
child.accuracy, time_cost = train_and_eval(child.arch, nas_bench, extra_info)
|
child.accuracy, time_cost = train_and_eval(child.arch, nas_bench, extra_info, dataname)
|
||||||
if total_time_cost + time_cost > time_budget: # return
|
if total_time_cost + time_cost > time_budget: # return
|
||||||
return history, total_time_cost
|
return history, total_time_cost
|
||||||
else:
|
else:
|
||||||
@ -174,7 +176,6 @@ def main(xargs, nas_bench):
|
|||||||
prepare_seed(xargs.rand_seed)
|
prepare_seed(xargs.rand_seed)
|
||||||
logger = prepare_logger(args)
|
logger = prepare_logger(args)
|
||||||
|
|
||||||
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
|
|
||||||
if xargs.dataset == 'cifar10':
|
if xargs.dataset == 'cifar10':
|
||||||
dataname = 'cifar10-valid'
|
dataname = 'cifar10-valid'
|
||||||
else:
|
else:
|
||||||
|
@ -98,7 +98,10 @@ def main(xargs, nas_bench):
|
|||||||
prepare_seed(xargs.rand_seed)
|
prepare_seed(xargs.rand_seed)
|
||||||
logger = prepare_logger(args)
|
logger = prepare_logger(args)
|
||||||
|
|
||||||
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
|
if xargs.dataset == 'cifar10':
|
||||||
|
dataname = 'cifar10-valid'
|
||||||
|
else:
|
||||||
|
dataname = xargs.dataset
|
||||||
if xargs.data_path is not None:
|
if xargs.data_path is not None:
|
||||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||||
@ -148,7 +151,7 @@ def main(xargs, nas_bench):
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
log_prob, action = select_action( policy )
|
log_prob, action = select_action( policy )
|
||||||
arch = policy.generate_arch( action )
|
arch = policy.generate_arch( action )
|
||||||
reward, cost_time = train_and_eval(arch, nas_bench, extra_info)
|
reward, cost_time = train_and_eval(arch, nas_bench, extra_info, dataname)
|
||||||
trace.append( (reward, arch) )
|
trace.append( (reward, arch) )
|
||||||
# accumulate time
|
# accumulate time
|
||||||
if total_costs + cost_time < xargs.time_budget:
|
if total_costs + cost_time < xargs.time_budget:
|
||||||
|
@ -5,4 +5,5 @@ from .api import NASBench201API
|
|||||||
from .api import ArchResults, ResultsCount
|
from .api import ArchResults, ResultsCount
|
||||||
|
|
||||||
# NAS_BENCH_201_API_VERSION="v1.1" # [2020.02.25]
|
# NAS_BENCH_201_API_VERSION="v1.1" # [2020.02.25]
|
||||||
NAS_BENCH_201_API_VERSION="v1.2" # [2020.03.09]
|
# NAS_BENCH_201_API_VERSION="v1.2" # [2020.03.09]
|
||||||
|
NAS_BENCH_201_API_VERSION="v1.3" # [2020.03.16]
|
||||||
|
@ -3,11 +3,14 @@
|
|||||||
############################################################################################
|
############################################################################################
|
||||||
# NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 #
|
# NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 #
|
||||||
############################################################################################
|
############################################################################################
|
||||||
|
# The history of benchmark files:
|
||||||
# [2020.02.25] NAS-Bench-201-v1_0-e61699.pth : 6219 architectures are trained once, 1621 architectures are trained twice, 7785 architectures are trained three times. `LESS` only supports CIFAR10-VALID.
|
# [2020.02.25] NAS-Bench-201-v1_0-e61699.pth : 6219 architectures are trained once, 1621 architectures are trained twice, 7785 architectures are trained three times. `LESS` only supports CIFAR10-VALID.
|
||||||
# [2020.03.08] Next version (coming soon)
|
# [2020.03.16] NAS-Bench-201-v1_1-096897.pth : 2225 architectures are trained once, 5439 archiitectures are trained twice, 7961 architectures are trained three times on all training sets. For the hyper-parameters with the total epochs of 12, each model is trained on CIFAR-10, CIFAR-100, ImageNet16-120 once, and is trained on CIFAR-10-VALID twice.
|
||||||
#
|
#
|
||||||
|
# I'm still actively enhancing this benchmark. Please feel free to contact me if you have any question w.r.t. NAS-Bench-201.
|
||||||
#
|
#
|
||||||
import os, copy, random, torch, numpy as np
|
import os, copy, random, torch, numpy as np
|
||||||
|
from pathlib import Path
|
||||||
from typing import List, Text, Union, Dict
|
from typing import List, Text, Union, Dict
|
||||||
from collections import OrderedDict, defaultdict
|
from collections import OrderedDict, defaultdict
|
||||||
|
|
||||||
@ -44,9 +47,12 @@ class NASBench201API(object):
|
|||||||
|
|
||||||
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
|
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
|
||||||
def __init__(self, file_path_or_dict: Union[Text, Dict], verbose: bool=True):
|
def __init__(self, file_path_or_dict: Union[Text, Dict], verbose: bool=True):
|
||||||
if isinstance(file_path_or_dict, str):
|
self.filename = None
|
||||||
|
if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
|
||||||
|
file_path_or_dict = str(file_path_or_dict)
|
||||||
if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict))
|
if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict))
|
||||||
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
|
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
|
||||||
|
self.filename = Path(file_path_or_dict).name
|
||||||
file_path_or_dict = torch.load(file_path_or_dict)
|
file_path_or_dict = torch.load(file_path_or_dict)
|
||||||
elif isinstance(file_path_or_dict, dict):
|
elif isinstance(file_path_or_dict, dict):
|
||||||
file_path_or_dict = copy.deepcopy( file_path_or_dict )
|
file_path_or_dict = copy.deepcopy( file_path_or_dict )
|
||||||
@ -76,7 +82,7 @@ class NASBench201API(object):
|
|||||||
return len(self.meta_archs)
|
return len(self.meta_archs)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return ('{name}({num}/{total} architectures)'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs)))
|
return ('{name}({num}/{total} architectures, file={filename})'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs), filename=self.filename))
|
||||||
|
|
||||||
def random(self):
|
def random(self):
|
||||||
"""Return a random index of all architectures."""
|
"""Return a random index of all architectures."""
|
||||||
@ -98,9 +104,10 @@ class NASBench201API(object):
|
|||||||
else: arch_index = -1
|
else: arch_index = -1
|
||||||
return arch_index
|
return arch_index
|
||||||
|
|
||||||
# Overwrite all information of the 'index'-th architecture in the search space.
|
|
||||||
# It will load its data from 'archive_root'.
|
|
||||||
def reload(self, archive_root: Text, index: int):
|
def reload(self, archive_root: Text, index: int):
|
||||||
|
"""Overwrite all information of the 'index'-th architecture in the search space.
|
||||||
|
It will load its data from 'archive_root'.
|
||||||
|
"""
|
||||||
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
|
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
|
||||||
xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index))
|
xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index))
|
||||||
assert 0 <= index < len(self.meta_archs), 'invalid index of {:}'.format(index)
|
assert 0 <= index < len(self.meta_archs), 'invalid index of {:}'.format(index)
|
||||||
@ -109,6 +116,13 @@ class NASBench201API(object):
|
|||||||
assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path)
|
assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path)
|
||||||
self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] )
|
self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] )
|
||||||
self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] )
|
self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] )
|
||||||
|
|
||||||
|
def clear_params(self, index: int, use_12epochs_result: bool):
|
||||||
|
"""Remove the architecture's weights to save memory."""
|
||||||
|
if use_12epochs_result: arch2infos = self.arch2infos_less
|
||||||
|
else : arch2infos = self.arch2infos_full
|
||||||
|
archresult = arch2infos[index]
|
||||||
|
archresult.clear_params()
|
||||||
|
|
||||||
# This function is used to query the information of a specific archiitecture
|
# This function is used to query the information of a specific archiitecture
|
||||||
# 'arch' can be an architecture index or an architecture string
|
# 'arch' can be an architecture index or an architecture string
|
||||||
@ -162,6 +176,7 @@ class NASBench201API(object):
|
|||||||
return archInfo
|
return archInfo
|
||||||
|
|
||||||
def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None, use_12epochs_result=False):
|
def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None, use_12epochs_result=False):
|
||||||
|
"""Find the architecture with the highest accuracy based on some constraints."""
|
||||||
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||||
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||||
best_index, highest_accuracy = -1, None
|
best_index, highest_accuracy = -1, None
|
||||||
@ -255,6 +270,65 @@ class NASBench201API(object):
|
|||||||
# `is_random`
|
# `is_random`
|
||||||
# When is_random=True, the performance of a random architecture will be returned
|
# When is_random=True, the performance of a random architecture will be returned
|
||||||
# When is_random=False, the performanceo of all trials will be averaged.
|
# When is_random=False, the performanceo of all trials will be averaged.
|
||||||
|
def get_more_info(self, index: int, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
|
||||||
|
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||||
|
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||||
|
archresult = arch2infos[index]
|
||||||
|
# if randomly select one trial, select the seed at first
|
||||||
|
if isinstance(is_random, bool) and is_random:
|
||||||
|
seeds = archresult.get_dataset_seeds(dataset)
|
||||||
|
is_random = random.choice(seeds)
|
||||||
|
# collect the training information
|
||||||
|
train_info = archresult.get_metrics(dataset, 'train', iepoch=iepoch, is_random=is_random)
|
||||||
|
total = train_info['iepoch'] + 1
|
||||||
|
xinfo = {'train-loss' : train_info['loss'],
|
||||||
|
'train-accuracy': train_info['accuracy'],
|
||||||
|
'train-per-time': train_info['all_time'] / total,
|
||||||
|
'train-all-time': train_info['all_time']}
|
||||||
|
# collect the evaluation information
|
||||||
|
if dataset == 'cifar10-valid':
|
||||||
|
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
|
||||||
|
try:
|
||||||
|
test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
|
||||||
|
except:
|
||||||
|
test_info = None
|
||||||
|
valtest_info = None
|
||||||
|
else:
|
||||||
|
try: # collect results on the proposed test set
|
||||||
|
if dataset == 'cifar10':
|
||||||
|
test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
|
||||||
|
else:
|
||||||
|
test_info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
|
||||||
|
except:
|
||||||
|
test_info = None
|
||||||
|
try: # collect results on the proposed validation set
|
||||||
|
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
|
||||||
|
except:
|
||||||
|
valid_info = None
|
||||||
|
try:
|
||||||
|
if dataset != 'cifar10':
|
||||||
|
valtest_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
|
||||||
|
else:
|
||||||
|
valtest_info = None
|
||||||
|
except:
|
||||||
|
valtest_info = None
|
||||||
|
if valid_info is not None:
|
||||||
|
xinfo['valid-loss'] = valid_info['loss']
|
||||||
|
xinfo['valid-accuracy'] = valid_info['accuracy']
|
||||||
|
xinfo['valid-per-time'] = valid_info['all_time'] / total
|
||||||
|
xinfo['valid-all-time'] = valid_info['all_time']
|
||||||
|
if test_info is not None:
|
||||||
|
xinfo['test-loss'] = test_info['loss']
|
||||||
|
xinfo['test-accuracy'] = test_info['accuracy']
|
||||||
|
xinfo['test-per-time'] = test_info['all_time'] / total
|
||||||
|
xinfo['test-all-time'] = test_info['all_time']
|
||||||
|
if valtest_info is not None:
|
||||||
|
xinfo['valtest-loss'] = valtest_info['loss']
|
||||||
|
xinfo['valtest-accuracy'] = valtest_info['accuracy']
|
||||||
|
xinfo['valtest-per-time'] = valtest_info['all_time'] / total
|
||||||
|
xinfo['valtest-all-time'] = valtest_info['all_time']
|
||||||
|
return xinfo
|
||||||
|
""" # The following logic is deprecated after March 15 2020, where the benchmark file upgrades from NAS-Bench-201-v1_0-e61699.pth to NAS-Bench-201-v1_1-096897.pth.
|
||||||
def get_more_info(self, index: int, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
|
def get_more_info(self, index: int, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
|
||||||
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||||
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||||
@ -312,6 +386,7 @@ class NASBench201API(object):
|
|||||||
xifo['est-valid-loss'] = est_valid_info['loss']
|
xifo['est-valid-loss'] = est_valid_info['loss']
|
||||||
xifo['est-valid-accuracy'] = est_valid_info['accuracy']
|
xifo['est-valid-accuracy'] = est_valid_info['accuracy']
|
||||||
return xifo
|
return xifo
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def show(self, index: int = -1) -> None:
|
def show(self, index: int = -1) -> None:
|
||||||
@ -349,6 +424,26 @@ class NASBench201API(object):
|
|||||||
print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))
|
print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))
|
||||||
|
|
||||||
|
|
||||||
|
def statistics(self, dataset: Text, use_12epochs_result: bool) -> Dict[int, int]:
|
||||||
|
"""
|
||||||
|
This function will count the number of total trials.
|
||||||
|
"""
|
||||||
|
valid_datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
|
||||||
|
if dataset not in valid_datasets:
|
||||||
|
raise ValueError('{:} not in {:}'.format(dataset, valid_datasets))
|
||||||
|
if use_12epochs_result: arch2infos = self.arch2infos_less
|
||||||
|
else : arch2infos = self.arch2infos_full
|
||||||
|
nums = defaultdict(lambda: 0)
|
||||||
|
for index in range(len(self)):
|
||||||
|
archInfo = arch2infos[index]
|
||||||
|
dataset_seed = archInfo.dataset_seed
|
||||||
|
if dataset not in dataset_seed:
|
||||||
|
nums[0] += 1
|
||||||
|
else:
|
||||||
|
nums[len(dataset_seed[dataset])] += 1
|
||||||
|
return dict(nums)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def str2lists(arch_str: Text) -> List[tuple]:
|
def str2lists(arch_str: Text) -> List[tuple]:
|
||||||
"""
|
"""
|
||||||
|
@ -2,9 +2,9 @@
|
|||||||
# bash ./scripts-search/algos/BOHB.sh -1
|
# bash ./scripts-search/algos/BOHB.sh -1
|
||||||
echo script name: $0
|
echo script name: $0
|
||||||
echo $# arguments
|
echo $# arguments
|
||||||
if [ "$#" -ne 1 ] ;then
|
if [ "$#" -ne 2 ] ;then
|
||||||
echo "Input illegal number of parameters " $#
|
echo "Input illegal number of parameters " $#
|
||||||
echo "Need 1 parameters for seed"
|
echo "Need 2 parameters for dataset and seed"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
if [ "$TORCH_HOME" = "" ]; then
|
if [ "$TORCH_HOME" = "" ]; then
|
||||||
@ -14,12 +14,14 @@ else
|
|||||||
echo "TORCH_HOME : $TORCH_HOME"
|
echo "TORCH_HOME : $TORCH_HOME"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
dataset=cifar10
|
dataset=$1
|
||||||
seed=$1
|
seed=$2
|
||||||
channel=16
|
channel=16
|
||||||
num_cells=5
|
num_cells=5
|
||||||
max_nodes=4
|
max_nodes=4
|
||||||
space=nas-bench-201
|
space=nas-bench-201
|
||||||
|
#benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth
|
||||||
|
benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth
|
||||||
|
|
||||||
save_dir=./output/search-cell-${space}/BOHB-${dataset}
|
save_dir=./output/search-cell-${space}/BOHB-${dataset}
|
||||||
|
|
||||||
@ -27,7 +29,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/BOHB.py \
|
|||||||
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
--dataset ${dataset} \
|
--dataset ${dataset} \
|
||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \
|
--arch_nas_dataset ${benchmark_file} \
|
||||||
--time_budget 12000 \
|
--time_budget 12000 \
|
||||||
--n_iters 50 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 \
|
--n_iters 50 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 \
|
||||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||||
|
@ -27,6 +27,8 @@ if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
|||||||
else
|
else
|
||||||
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
||||||
fi
|
fi
|
||||||
|
#benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth
|
||||||
|
benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth
|
||||||
|
|
||||||
save_dir=./output/search-cell-${space}/DARTS-V1-${dataset}-BN${BN}
|
save_dir=./output/search-cell-${space}/DARTS-V1-${dataset}-BN${BN}
|
||||||
|
|
||||||
@ -35,7 +37,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V1.py \
|
|||||||
--dataset ${dataset} --data_path ${data_path} \
|
--dataset ${dataset} --data_path ${data_path} \
|
||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--config_path configs/nas-benchmark/algos/DARTS.config \
|
--config_path configs/nas-benchmark/algos/DARTS.config \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \
|
--arch_nas_dataset ${benchmark_file} \
|
||||||
--track_running_stats ${BN} \
|
--track_running_stats ${BN} \
|
||||||
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
||||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||||
|
@ -27,6 +27,8 @@ if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
|||||||
else
|
else
|
||||||
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
||||||
fi
|
fi
|
||||||
|
#benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth
|
||||||
|
benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth
|
||||||
|
|
||||||
save_dir=./output/search-cell-${space}/DARTS-V2-${dataset}-BN${BN}
|
save_dir=./output/search-cell-${space}/DARTS-V2-${dataset}-BN${BN}
|
||||||
|
|
||||||
@ -35,7 +37,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V2.py \
|
|||||||
--dataset ${dataset} --data_path ${data_path} \
|
--dataset ${dataset} --data_path ${data_path} \
|
||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--config_path configs/nas-benchmark/algos/DARTS.config \
|
--config_path configs/nas-benchmark/algos/DARTS.config \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \
|
--arch_nas_dataset ${benchmark_file} \
|
||||||
--track_running_stats ${BN} \
|
--track_running_stats ${BN} \
|
||||||
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
||||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||||
|
@ -28,6 +28,8 @@ if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
|||||||
else
|
else
|
||||||
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
||||||
fi
|
fi
|
||||||
|
#benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth
|
||||||
|
benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth
|
||||||
|
|
||||||
save_dir=./output/search-cell-${space}/ENAS-${dataset}-BN${BN}
|
save_dir=./output/search-cell-${space}/ENAS-${dataset}-BN${BN}
|
||||||
|
|
||||||
@ -35,7 +37,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/ENAS.py \
|
|||||||
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
--dataset ${dataset} --data_path ${data_path} \
|
--dataset ${dataset} --data_path ${data_path} \
|
||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \
|
--arch_nas_dataset ${benchmark_file} \
|
||||||
--track_running_stats ${BN} \
|
--track_running_stats ${BN} \
|
||||||
--config_path ./configs/nas-benchmark/algos/ENAS.config \
|
--config_path ./configs/nas-benchmark/algos/ENAS.config \
|
||||||
--controller_entropy_weight 0.0001 \
|
--controller_entropy_weight 0.0001 \
|
||||||
|
@ -27,6 +27,8 @@ if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
|||||||
else
|
else
|
||||||
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
||||||
fi
|
fi
|
||||||
|
#benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth
|
||||||
|
benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth
|
||||||
|
|
||||||
save_dir=./output/search-cell-${space}/GDAS-${dataset}-BN${BN}
|
save_dir=./output/search-cell-${space}/GDAS-${dataset}-BN${BN}
|
||||||
|
|
||||||
@ -34,7 +36,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/GDAS.py \
|
|||||||
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
--dataset ${dataset} --data_path ${data_path} \
|
--dataset ${dataset} --data_path ${data_path} \
|
||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \
|
--arch_nas_dataset ${benchmark_file} \
|
||||||
--config_path configs/nas-benchmark/algos/GDAS.config \
|
--config_path configs/nas-benchmark/algos/GDAS.config \
|
||||||
--tau_max 10 --tau_min 0.1 --track_running_stats ${BN} \
|
--tau_max 10 --tau_min 0.1 --track_running_stats ${BN} \
|
||||||
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
||||||
|
@ -23,6 +23,8 @@ channel=16
|
|||||||
num_cells=5
|
num_cells=5
|
||||||
max_nodes=4
|
max_nodes=4
|
||||||
space=nas-bench-201
|
space=nas-bench-201
|
||||||
|
#benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth
|
||||||
|
benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth
|
||||||
|
|
||||||
save_dir=./output/search-cell-${space}/R-EA-${dataset}-SS${sample_size}
|
save_dir=./output/search-cell-${space}/R-EA-${dataset}-SS${sample_size}
|
||||||
|
|
||||||
@ -30,7 +32,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/R_EA.py \
|
|||||||
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
--dataset ${dataset} \
|
--dataset ${dataset} \
|
||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \
|
--arch_nas_dataset ${benchmark_file} \
|
||||||
--time_budget 12000 \
|
--time_budget 12000 \
|
||||||
--ea_cycles 200 --ea_population 10 --ea_sample_size ${sample_size} --ea_fast_by_api 1 \
|
--ea_cycles 200 --ea_population 10 --ea_sample_size ${sample_size} --ea_fast_by_api 1 \
|
||||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||||
|
@ -28,6 +28,8 @@ if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
|||||||
else
|
else
|
||||||
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
||||||
fi
|
fi
|
||||||
|
#benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth
|
||||||
|
benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth
|
||||||
|
|
||||||
save_dir=./output/search-cell-${space}/RANDOM-NAS-${dataset}-BN${BN}
|
save_dir=./output/search-cell-${space}/RANDOM-NAS-${dataset}-BN${BN}
|
||||||
|
|
||||||
@ -36,7 +38,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/RANDOM-NAS.py \
|
|||||||
--dataset ${dataset} --data_path ${data_path} \
|
--dataset ${dataset} --data_path ${data_path} \
|
||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--track_running_stats ${BN} \
|
--track_running_stats ${BN} \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \
|
--arch_nas_dataset ${benchmark_file} \
|
||||||
--config_path ./configs/nas-benchmark/algos/RANDOM.config \
|
--config_path ./configs/nas-benchmark/algos/RANDOM.config \
|
||||||
--select_num 100 \
|
--select_num 100 \
|
||||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||||
|
@ -2,9 +2,9 @@
|
|||||||
# bash ./scripts-search/algos/REINFORCE.sh 0.001 -1
|
# bash ./scripts-search/algos/REINFORCE.sh 0.001 -1
|
||||||
echo script name: $0
|
echo script name: $0
|
||||||
echo $# arguments
|
echo $# arguments
|
||||||
if [ "$#" -ne 2 ] ;then
|
if [ "$#" -ne 3 ] ;then
|
||||||
echo "Input illegal number of parameters " $#
|
echo "Input illegal number of parameters " $#
|
||||||
echo "Need 2 parameters for LR and seed"
|
echo "Need 3 parameters for dataset, LR, and seed"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
if [ "$TORCH_HOME" = "" ]; then
|
if [ "$TORCH_HOME" = "" ]; then
|
||||||
@ -14,13 +14,15 @@ else
|
|||||||
echo "TORCH_HOME : $TORCH_HOME"
|
echo "TORCH_HOME : $TORCH_HOME"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
dataset=cifar10
|
dataset=$1
|
||||||
LR=$1
|
LR=$2
|
||||||
seed=$2
|
seed=$3
|
||||||
channel=16
|
channel=16
|
||||||
num_cells=5
|
num_cells=5
|
||||||
max_nodes=4
|
max_nodes=4
|
||||||
space=nas-bench-201
|
space=nas-bench-201
|
||||||
|
#benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth
|
||||||
|
benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth
|
||||||
|
|
||||||
save_dir=./output/search-cell-${space}/REINFORCE-${dataset}-${LR}
|
save_dir=./output/search-cell-${space}/REINFORCE-${dataset}-${LR}
|
||||||
|
|
||||||
@ -28,7 +30,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/reinforce.py \
|
|||||||
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
--dataset ${dataset} \
|
--dataset ${dataset} \
|
||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \
|
--arch_nas_dataset ${benchmark_file} \
|
||||||
--time_budget 12000 \
|
--time_budget 12000 \
|
||||||
--learning_rate ${LR} --EMA_momentum 0.9 \
|
--learning_rate ${LR} --EMA_momentum 0.9 \
|
||||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||||
|
@ -2,9 +2,9 @@
|
|||||||
# bash ./scripts-search/algos/Random.sh -1
|
# bash ./scripts-search/algos/Random.sh -1
|
||||||
echo script name: $0
|
echo script name: $0
|
||||||
echo $# arguments
|
echo $# arguments
|
||||||
if [ "$#" -ne 1 ] ;then
|
if [ "$#" -ne 2 ] ;then
|
||||||
echo "Input illegal number of parameters " $#
|
echo "Input illegal number of parameters " $#
|
||||||
echo "Need 1 parameters for seed"
|
echo "Need 2 parameters for dataset and seed"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
if [ "$TORCH_HOME" = "" ]; then
|
if [ "$TORCH_HOME" = "" ]; then
|
||||||
@ -14,12 +14,14 @@ else
|
|||||||
echo "TORCH_HOME : $TORCH_HOME"
|
echo "TORCH_HOME : $TORCH_HOME"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
dataset=cifar10
|
dataset=$1
|
||||||
seed=$1
|
seed=$2
|
||||||
channel=16
|
channel=16
|
||||||
num_cells=5
|
num_cells=5
|
||||||
max_nodes=4
|
max_nodes=4
|
||||||
space=nas-bench-201
|
space=nas-bench-201
|
||||||
|
#benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth
|
||||||
|
benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth
|
||||||
|
|
||||||
save_dir=./output/search-cell-${space}/RAND-${dataset}
|
save_dir=./output/search-cell-${space}/RAND-${dataset}
|
||||||
|
|
||||||
@ -27,7 +29,6 @@ OMP_NUM_THREADS=4 python ./exps/algos/RANDOM.py \
|
|||||||
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
--dataset ${dataset} \
|
--dataset ${dataset} \
|
||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \
|
--arch_nas_dataset ${benchmark_file} \
|
||||||
--time_budget 12000 \
|
--time_budget 12000 \
|
||||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||||
# --random_num 100 \
|
|
||||||
|
@ -28,6 +28,8 @@ if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
|||||||
else
|
else
|
||||||
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
||||||
fi
|
fi
|
||||||
|
#benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth
|
||||||
|
benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth
|
||||||
|
|
||||||
save_dir=./output/search-cell-${space}/SETN-${dataset}-BN${BN}
|
save_dir=./output/search-cell-${space}/SETN-${dataset}-BN${BN}
|
||||||
|
|
||||||
@ -35,7 +37,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/SETN.py \
|
|||||||
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
--dataset ${dataset} --data_path ${data_path} \
|
--dataset ${dataset} --data_path ${data_path} \
|
||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \
|
--arch_nas_dataset ${benchmark_file} \
|
||||||
--config_path configs/nas-benchmark/algos/SETN.config \
|
--config_path configs/nas-benchmark/algos/SETN.config \
|
||||||
--track_running_stats ${BN} \
|
--track_running_stats ${BN} \
|
||||||
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
||||||
|
@ -28,6 +28,8 @@ if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
|||||||
else
|
else
|
||||||
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
||||||
fi
|
fi
|
||||||
|
#benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth
|
||||||
|
benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth
|
||||||
|
|
||||||
save_dir=./output/search-cell-${space}/DARTS-V1-${dataset}-BN${BN}-Gradient${gradient_clip}
|
save_dir=./output/search-cell-${space}/DARTS-V1-${dataset}-BN${BN}-Gradient${gradient_clip}
|
||||||
|
|
||||||
@ -36,7 +38,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V1.py \
|
|||||||
--dataset ${dataset} --data_path ${data_path} \
|
--dataset ${dataset} --data_path ${data_path} \
|
||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--config_path configs/nas-benchmark/algos/DARTS.config \
|
--config_path configs/nas-benchmark/algos/DARTS.config \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \
|
--arch_nas_dataset ${benchmark_file} \
|
||||||
--track_running_stats ${BN} --gradient_clip ${gradient_clip} \
|
--track_running_stats ${BN} --gradient_clip ${gradient_clip} \
|
||||||
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
||||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
--workers 4 --print_freq 200 --rand_seed ${seed}
|
Loading…
Reference in New Issue
Block a user