Upgrade NAS-Bench-201 to APIv1.3/FILEv1.1

This commit is contained in:
D-X-Y 2020-03-15 22:50:17 +11:00
parent c53a9ce407
commit fb76814369
20 changed files with 259 additions and 75 deletions

View File

@ -18,14 +18,18 @@ You can simply type `pip install nas-bench-201` to install our api. Please see s
### Preparation and Download ### Preparation and Download
The benchmark file of NAS-Bench-201 can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) or [Baidu-Wangpan (code:6u5d)](https://pan.baidu.com/s/1CiaNH6C12zuZf7q-Ilm09w). [deprecated] The benchmark file of NAS-Bench-201 can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) or [Baidu-Wangpan (code:6u5d)](https://pan.baidu.com/s/1CiaNH6C12zuZf7q-Ilm09w).
[recommended] The benchmark file of NAS-Bench-201 can be downloaded from [Google Drive](https://drive.google.com/open?id=1OOfVPpt-lA4u2HJrXbgrRd42IbfvJMyE). The files for model weight are too large (431G) and I need some time to upload it. Please be patient, thanks for your understanding.
You can move it to anywhere you want and send its path to our API for initialization. You can move it to anywhere you want and send its path to our API for initialization.
- [2020.02.25] v1.0: `NAS-Bench-201-v1_0-e61699.pth`, where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial. - [2020.02.25] APIv1.0/FILEv1.0: `NAS-Bench-201-v1_0-e61699.pth` (2.2G), where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial.
- [2020.02.25] v1.0: The full data of each architecture can be download from [ - [2020.02.25] APIv1.0/FILEv1.0: The full data of each architecture can be download from [
NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the the trained weights. NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the the trained weights.
- [2020.02.25] v1.0: Checkpoints for 3 runs of each baseline NAS algorithm are provided in [Google Drive](https://drive.google.com/open?id=1eAgLZQAViP3r6dA0_ZOOGG9zPLXhGwXi). - [2020.02.25] APIv1.0/FILEv1.0: Checkpoints for 3 runs of each baseline NAS algorithm are provided in [Google Drive](https://drive.google.com/open?id=1eAgLZQAViP3r6dA0_ZOOGG9zPLXhGwXi).
- [2020.03.09] v1.2: More robust API with more functions and descriptions - [2020.03.09] APIv1.2/FILEv1.0: More robust API with more functions and descriptions
- [2020.04.01] v2.0: coming soon (results of two set of hyper-parameters avaliable on all three datasets) - [2020.03.16] APIv1.3/FILEv1.1: `NAS-Bench-201-v1_1-096897.pth` (4.7G), where `096897` is the last six digits for this file. It contains information of more trials compared to `NAS-Bench-201-v1_0-e61699.pth`, especially all models trained by 12 epochs on all datasets are avaliable.
- [2020.06.01] APIv2.0/FILEv2.0: coming soon!
The training and evaluation data used in NAS-Bench-201 can be downloaded from [Google Drive](https://drive.google.com/open?id=1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7) or [Baidu-Wangpan (code:4fg7)](https://pan.baidu.com/s/1XAzavPKq3zcat1yBA1L2tQ). The training and evaluation data used in NAS-Bench-201 can be downloaded from [Google Drive](https://drive.google.com/open?id=1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7) or [Baidu-Wangpan (code:4fg7)](https://pan.baidu.com/s/1XAzavPKq3zcat1yBA1L2tQ).
@ -92,7 +96,9 @@ print(network) # show the structure of this architecture
``` ```
If you want to load the trained weights of this created network, you need to use `api.get_net_param(123, ...)` to obtain the weights and then load it to the network. If you want to load the trained weights of this created network, you need to use `api.get_net_param(123, ...)` to obtain the weights and then load it to the network.
6. For other usages, please see `lib/nas_201_api/api.py`. We provide some usage information in the comments for the corresponding functions. If what you want is not provided, please feel free to open an issue for discussion, and I am happy to answer any questions regarding NAS-Bench-201. 6. `api.get_more_info(...)` can return the loss / accuracy / time on training / validation / test sets, which is very helpful. For more details, please look at the comments in the get_more_info function.
7. For other usages, please see `lib/nas_201_api/api.py`. We provide some usage information in the comments for the corresponding functions. If what you want is not provided, please feel free to open an issue for discussion, and I am happy to answer any questions regarding NAS-Bench-201.
### Detailed Instruction ### Detailed Instruction
@ -213,12 +219,14 @@ If researchers can provide better results with different hyper-parameters, we ar
- [5] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/ENAS.sh cifar10 1 -1` - [5] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/ENAS.sh cifar10 1 -1`
- [6] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/RANDOM-NAS.sh cifar10 1 -1` - [6] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/RANDOM-NAS.sh cifar10 1 -1`
- [7] `bash ./scripts-search/algos/R-EA.sh cifar10 3 -1` - [7] `bash ./scripts-search/algos/R-EA.sh cifar10 3 -1`
- [8] `bash ./scripts-search/algos/Random.sh -1` - [8] `bash ./scripts-search/algos/Random.sh cifar10 -1`
- [9] `bash ./scripts-search/algos/REINFORCE.sh 0.5 -1` - [9] `bash ./scripts-search/algos/REINFORCE.sh cifar10 0.5 -1`
- [10] `bash ./scripts-search/algos/BOHB.sh -1` - [10] `bash ./scripts-search/algos/BOHB.sh cifar10 -1`
In commands [1-6], the first args `cifar10` indicates the dataset name, the second args `1` indicates the behavior of BN, and the first args `-1` indicates the random seed. In commands [1-6], the first args `cifar10` indicates the dataset name, the second args `1` indicates the behavior of BN, and the first args `-1` indicates the random seed.
**Note that** since 2020 March 16, in these scripts, the default NAS-Bench-201 benchmark file has changed from `NAS-Bench-201-v1_0-e61699.pth` to `NAS-Bench-201-v1_1-096897.pth`, and thus the results could be slightly different from the original paper.
# Citation # Citation

View File

@ -1,36 +1,84 @@
##################################################### #####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
######################################################## ###############################################################################################
# python exps/NAS-Bench-201/test-weights.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth # Before run these commands, the files must be properly put.
######################################################## # python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_0-e61699
# python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897
###############################################################################################
import os, sys, time, glob, random, argparse import os, sys, time, glob, random, argparse
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn
from pathlib import Path from pathlib import Path
from tqdm import tqdm
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from nas_201_api import NASBench201API as API from nas_201_api import NASBench201API as API
from log_utils import time_string
from models import get_cell_based_tiny_net
from utils import weight_watcher from utils import weight_watcher
def main(meta_file, weight_dir, save_dir): def get_cor(A, B):
import pdb; return float(np.corrcoef(A, B)[0,1])
pdb.set_trace()
def evaluate(api, weight_dir, data: str, use_12epochs_result: bool, valid_or_test: bool):
norms, accs = [], []
for idx in tqdm(range(len(api))):
info = api.get_more_info(idx, data, use_12epochs_result=use_12epochs_result, is_random=False)
if valid_or_test:
accs.append(info['valid-accuracy'])
else:
accs.append(info['test-accuracy'])
config = api.get_net_config(idx, data)
net = get_cell_based_tiny_net(config)
api.reload(weight_dir, idx)
params = api.get_net_param(idx, data, None)
cur_norms = []
for seed, param in params.items():
net.load_state_dict(param)
_, summary = weight_watcher.analyze(net, alphas=False)
cur_norms.append( summary['lognorm'] )
norms.append( float(np.mean(cur_norms)) )
api.clear_params(idx, use_12epochs_result)
correlation = get_cor(norms, accs)
print('For {:} with {:} epochs on {:} : the correlation is {:}'.format(data, 12 if use_12epochs_result else 200, 'valid' if valid_or_test else 'test', correlation))
def main(meta_file: str, weight_dir, save_dir):
api = API(meta_file)
datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
print(time_string() + ' ' + '='*50)
for data in datasets:
nums = api.statistics(data, True)
total = sum([k*v for k, v in nums.items()])
print('Using 012 epochs, trained on {:20s} : {:} trials in total ({:}).'.format(data, total, nums))
print(time_string() + ' ' + '='*50)
for data in datasets:
nums = api.statistics(data, False)
total = sum([k*v for k, v in nums.items()])
print('Using 200 epochs, trained on {:20s} : {:} trials in total ({:}).'.format(data, total, nums))
print(time_string() + ' ' + '='*50)
evaluate(api, weight_dir, 'cifar10-valid', False, True)
print('{:} finish this test.'.format(time_string()))
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser("Analysis of NAS-Bench-201") parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.') parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.')
parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-201 benchmark file.') parser.add_argument('--base_path', type=str, default=None, help='The path to the NAS-Bench-201 benchmark file and weight dir.')
parser.add_argument('--weight_dir', type=str, default=None, help='The directory path to the weights of every NAS-Bench-201 architecture.')
args = parser.parse_args() args = parser.parse_args()
save_dir = Path(args.save_dir) save_dir = Path(args.save_dir)
save_dir.mkdir(parents=True, exist_ok=True) save_dir.mkdir(parents=True, exist_ok=True)
meta_file = Path(args.api_path) meta_file = Path(args.base_path + '.pth')
weight_dir = Path(args.weight_dir) weight_dir = Path(args.base_path + '-archive')
assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file) assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
assert weight_dir.exists() and weight_dir.is_dir(), 'invalid path for weight dir : {:}'.format(weight_dir)
main(meta_file, weight_dir, save_dir) main(str(meta_file), weight_dir, save_dir)

View File

@ -50,10 +50,11 @@ def config2structure_func(max_nodes):
class MyWorker(Worker): class MyWorker(Worker):
def __init__(self, *args, convert_func=None, nas_bench=None, time_budget=None, **kwargs): def __init__(self, *args, convert_func=None, dataname=None, nas_bench=None, time_budget=None, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.convert_func = convert_func self.convert_func = convert_func
self.nas_bench = nas_bench self._dataname = dataname
self._nas_bench = nas_bench
self.time_budget = time_budget self.time_budget = time_budget
self.seen_archs = [] self.seen_archs = []
self.sim_cost_time = 0 self.sim_cost_time = 0
@ -64,7 +65,7 @@ class MyWorker(Worker):
assert len(self.seen_archs) > 0 assert len(self.seen_archs) > 0
best_index, best_acc = -1, None best_index, best_acc = -1, None
for arch_index in self.seen_archs: for arch_index in self.seen_archs:
info = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True) info = self._nas_bench.get_more_info(arch_index, self._dataname, None, True, True)
vacc = info['valid-accuracy'] vacc = info['valid-accuracy']
if best_acc is None or best_acc < vacc: if best_acc is None or best_acc < vacc:
best_acc = vacc best_acc = vacc
@ -75,8 +76,8 @@ class MyWorker(Worker):
def compute(self, config, budget, **kwargs): def compute(self, config, budget, **kwargs):
start_time = time.time() start_time = time.time()
structure = self.convert_func( config ) structure = self.convert_func( config )
arch_index = self.nas_bench.query_index_by_arch( structure ) arch_index = self._nas_bench.query_index_by_arch( structure )
info = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True) info = self._nas_bench.get_more_info(arch_index, self._dataname, None, True, True)
cur_time = info['train-all-time'] + info['valid-per-time'] cur_time = info['train-all-time'] + info['valid-per-time']
cur_vacc = info['valid-accuracy'] cur_vacc = info['valid-accuracy']
self.real_cost_time += (time.time() - start_time) self.real_cost_time += (time.time() - start_time)
@ -106,7 +107,10 @@ def main(xargs, nas_bench):
prepare_seed(xargs.rand_seed) prepare_seed(xargs.rand_seed)
logger = prepare_logger(args) logger = prepare_logger(args)
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' if xargs.dataset == 'cifar10':
dataname = 'cifar10-valid'
else:
dataname = xargs.dataset
if xargs.data_path is not None: if xargs.data_path is not None:
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
split_Fpath = 'configs/nas-benchmark/cifar-split.txt' split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
@ -148,7 +152,7 @@ def main(xargs, nas_bench):
#logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string())) #logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string()))
workers = [] workers = []
for i in range(num_workers): for i in range(num_workers):
w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, time_budget=xargs.time_budget, run_id=hb_run_id, id=i) w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, dataname=dataname, nas_bench=nas_bench, time_budget=xargs.time_budget, run_id=hb_run_id, id=i)
w.run(background=True) w.run(background=True)
workers.append(w) workers.append(w)

View File

@ -28,7 +28,10 @@ def main(xargs, nas_bench):
prepare_seed(xargs.rand_seed) prepare_seed(xargs.rand_seed)
logger = prepare_logger(args) logger = prepare_logger(args)
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' if xargs.dataset == 'cifar10':
dataname = 'cifar10-valid'
else:
dataname = xargs.dataset
if xargs.data_path is not None: if xargs.data_path is not None:
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
split_Fpath = 'configs/nas-benchmark/cifar-split.txt' split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
@ -62,7 +65,7 @@ def main(xargs, nas_bench):
#for idx in range(xargs.random_num): #for idx in range(xargs.random_num):
while total_time_cost < xargs.time_budget: while total_time_cost < xargs.time_budget:
arch = random_arch() arch = random_arch()
accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info) accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info, dataname)
if total_time_cost + cost_time > xargs.time_budget: break if total_time_cost + cost_time > xargs.time_budget: break
else: total_time_cost += cost_time else: total_time_cost += cost_time
history.append(arch) history.append(arch)

View File

@ -33,19 +33,21 @@ class Model(object):
# This function is to mimic the training and evaluatinig procedure for a single architecture `arch`. # This function is to mimic the training and evaluatinig procedure for a single architecture `arch`.
# The time_cost is calculated as the total training time for a few (e.g., 12 epochs) plus the evaluation time for one epoch. # The time_cost is calculated as the total training time for a few (e.g., 12 epochs) plus the evaluation time for one epoch.
# For use_converged_LR = True, the architecture is trained for 12 epochs, with LR being decaded from 0.1 to 0. # For use_012_epoch_training = True, the architecture is trained for 12 epochs, with LR being decaded from 0.1 to 0.
# In this case, the LR schedular is converged. # In this case, the LR schedular is converged.
# For use_converged_LR = False, the architecture is planed to be trained for 200 epochs, but we early stop its procedure. # For use_012_epoch_training = False, the architecture is planed to be trained for 200 epochs, but we early stop its procedure.
# #
def train_and_eval(arch, nas_bench, extra_info, dataname='cifar10-valid', use_converged_LR=True): def train_and_eval(arch, nas_bench, extra_info, dataname='cifar10-valid', use_012_epoch_training=True):
if use_converged_LR and nas_bench is not None:
if use_012_epoch_training and nas_bench is not None:
arch_index = nas_bench.query_index_by_arch( arch ) arch_index = nas_bench.query_index_by_arch( arch )
assert arch_index >= 0, 'can not find this arch : {:}'.format(arch) assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
info = nas_bench.get_more_info(arch_index, dataname, None, True) info = nas_bench.get_more_info(arch_index, dataname, None, True)
valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time'] valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
#_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs #_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs
elif not use_converged_LR and nas_bench is not None: elif not use_012_epoch_training and nas_bench is not None:
# Please use `use_converged_LR=False` for cifar10 only. # Please contact me if you want to use the following logic, because it has some potential issues.
# Please use `use_012_epoch_training=False` for cifar10 only.
# It did return values for cifar100 and ImageNet16-120, but it has some potential issues. (Please email me for more details) # It did return values for cifar100 and ImageNet16-120, but it has some potential issues. (Please email me for more details)
arch_index, nepoch = nas_bench.query_index_by_arch( arch ), 25 arch_index, nepoch = nas_bench.query_index_by_arch( arch ), 25
assert arch_index >= 0, 'can not find this arch : {:}'.format(arch) assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
@ -64,7 +66,7 @@ def train_and_eval(arch, nas_bench, extra_info, dataname='cifar10-valid', use_co
try: try:
valid_acc, time_cost = info['valid-accuracy'], estimated_train_cost + estimated_valid_cost valid_acc, time_cost = info['valid-accuracy'], estimated_train_cost + estimated_valid_cost
except: except:
valid_acc, time_cost = info['est-valid-accuracy'], estimated_train_cost + estimated_valid_cost valid_acc, time_cost = info['valtest-accuracy'], estimated_train_cost + estimated_valid_cost
else: else:
# train a model from scratch. # train a model from scratch.
raise ValueError('NOT IMPLEMENT YET') raise ValueError('NOT IMPLEMENT YET')
@ -127,7 +129,7 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
while len(population) < population_size: while len(population) < population_size:
model = Model() model = Model()
model.arch = random_arch() model.arch = random_arch()
model.accuracy, time_cost = train_and_eval(model.arch, nas_bench, extra_info) model.accuracy, time_cost = train_and_eval(model.arch, nas_bench, extra_info, dataname)
population.append(model) population.append(model)
history.append(model) history.append(model)
total_time_cost += time_cost total_time_cost += time_cost
@ -152,7 +154,7 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
child = Model() child = Model()
child.arch = mutate_arch(parent.arch) child.arch = mutate_arch(parent.arch)
total_time_cost += time.time() - start_time total_time_cost += time.time() - start_time
child.accuracy, time_cost = train_and_eval(child.arch, nas_bench, extra_info) child.accuracy, time_cost = train_and_eval(child.arch, nas_bench, extra_info, dataname)
if total_time_cost + time_cost > time_budget: # return if total_time_cost + time_cost > time_budget: # return
return history, total_time_cost return history, total_time_cost
else: else:
@ -174,7 +176,6 @@ def main(xargs, nas_bench):
prepare_seed(xargs.rand_seed) prepare_seed(xargs.rand_seed)
logger = prepare_logger(args) logger = prepare_logger(args)
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
if xargs.dataset == 'cifar10': if xargs.dataset == 'cifar10':
dataname = 'cifar10-valid' dataname = 'cifar10-valid'
else: else:

View File

@ -98,7 +98,10 @@ def main(xargs, nas_bench):
prepare_seed(xargs.rand_seed) prepare_seed(xargs.rand_seed)
logger = prepare_logger(args) logger = prepare_logger(args)
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' if xargs.dataset == 'cifar10':
dataname = 'cifar10-valid'
else:
dataname = xargs.dataset
if xargs.data_path is not None: if xargs.data_path is not None:
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
split_Fpath = 'configs/nas-benchmark/cifar-split.txt' split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
@ -148,7 +151,7 @@ def main(xargs, nas_bench):
start_time = time.time() start_time = time.time()
log_prob, action = select_action( policy ) log_prob, action = select_action( policy )
arch = policy.generate_arch( action ) arch = policy.generate_arch( action )
reward, cost_time = train_and_eval(arch, nas_bench, extra_info) reward, cost_time = train_and_eval(arch, nas_bench, extra_info, dataname)
trace.append( (reward, arch) ) trace.append( (reward, arch) )
# accumulate time # accumulate time
if total_costs + cost_time < xargs.time_budget: if total_costs + cost_time < xargs.time_budget:

View File

@ -5,4 +5,5 @@ from .api import NASBench201API
from .api import ArchResults, ResultsCount from .api import ArchResults, ResultsCount
# NAS_BENCH_201_API_VERSION="v1.1" # [2020.02.25] # NAS_BENCH_201_API_VERSION="v1.1" # [2020.02.25]
NAS_BENCH_201_API_VERSION="v1.2" # [2020.03.09] # NAS_BENCH_201_API_VERSION="v1.2" # [2020.03.09]
NAS_BENCH_201_API_VERSION="v1.3" # [2020.03.16]

View File

@ -3,11 +3,14 @@
############################################################################################ ############################################################################################
# NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 # # NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 #
############################################################################################ ############################################################################################
# The history of benchmark files:
# [2020.02.25] NAS-Bench-201-v1_0-e61699.pth : 6219 architectures are trained once, 1621 architectures are trained twice, 7785 architectures are trained three times. `LESS` only supports CIFAR10-VALID. # [2020.02.25] NAS-Bench-201-v1_0-e61699.pth : 6219 architectures are trained once, 1621 architectures are trained twice, 7785 architectures are trained three times. `LESS` only supports CIFAR10-VALID.
# [2020.03.08] Next version (coming soon) # [2020.03.16] NAS-Bench-201-v1_1-096897.pth : 2225 architectures are trained once, 5439 archiitectures are trained twice, 7961 architectures are trained three times on all training sets. For the hyper-parameters with the total epochs of 12, each model is trained on CIFAR-10, CIFAR-100, ImageNet16-120 once, and is trained on CIFAR-10-VALID twice.
# #
# I'm still actively enhancing this benchmark. Please feel free to contact me if you have any question w.r.t. NAS-Bench-201.
# #
import os, copy, random, torch, numpy as np import os, copy, random, torch, numpy as np
from pathlib import Path
from typing import List, Text, Union, Dict from typing import List, Text, Union, Dict
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
@ -44,9 +47,12 @@ class NASBench201API(object):
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """ """ The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
def __init__(self, file_path_or_dict: Union[Text, Dict], verbose: bool=True): def __init__(self, file_path_or_dict: Union[Text, Dict], verbose: bool=True):
if isinstance(file_path_or_dict, str): self.filename = None
if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
file_path_or_dict = str(file_path_or_dict)
if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict)) if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict))
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict) assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
self.filename = Path(file_path_or_dict).name
file_path_or_dict = torch.load(file_path_or_dict) file_path_or_dict = torch.load(file_path_or_dict)
elif isinstance(file_path_or_dict, dict): elif isinstance(file_path_or_dict, dict):
file_path_or_dict = copy.deepcopy( file_path_or_dict ) file_path_or_dict = copy.deepcopy( file_path_or_dict )
@ -76,7 +82,7 @@ class NASBench201API(object):
return len(self.meta_archs) return len(self.meta_archs)
def __repr__(self): def __repr__(self):
return ('{name}({num}/{total} architectures)'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs))) return ('{name}({num}/{total} architectures, file={filename})'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs), filename=self.filename))
def random(self): def random(self):
"""Return a random index of all architectures.""" """Return a random index of all architectures."""
@ -98,9 +104,10 @@ class NASBench201API(object):
else: arch_index = -1 else: arch_index = -1
return arch_index return arch_index
# Overwrite all information of the 'index'-th architecture in the search space.
# It will load its data from 'archive_root'.
def reload(self, archive_root: Text, index: int): def reload(self, archive_root: Text, index: int):
"""Overwrite all information of the 'index'-th architecture in the search space.
It will load its data from 'archive_root'.
"""
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root) assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index)) xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index))
assert 0 <= index < len(self.meta_archs), 'invalid index of {:}'.format(index) assert 0 <= index < len(self.meta_archs), 'invalid index of {:}'.format(index)
@ -109,6 +116,13 @@ class NASBench201API(object):
assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path) assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path)
self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] ) self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] )
self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] ) self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] )
def clear_params(self, index: int, use_12epochs_result: bool):
"""Remove the architecture's weights to save memory."""
if use_12epochs_result: arch2infos = self.arch2infos_less
else : arch2infos = self.arch2infos_full
archresult = arch2infos[index]
archresult.clear_params()
# This function is used to query the information of a specific archiitecture # This function is used to query the information of a specific archiitecture
# 'arch' can be an architecture index or an architecture string # 'arch' can be an architecture index or an architecture string
@ -162,6 +176,7 @@ class NASBench201API(object):
return archInfo return archInfo
def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None, use_12epochs_result=False): def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None, use_12epochs_result=False):
"""Find the architecture with the highest accuracy based on some constraints."""
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
else : basestr, arch2infos = '200epochs', self.arch2infos_full else : basestr, arch2infos = '200epochs', self.arch2infos_full
best_index, highest_accuracy = -1, None best_index, highest_accuracy = -1, None
@ -255,6 +270,65 @@ class NASBench201API(object):
# `is_random` # `is_random`
# When is_random=True, the performance of a random architecture will be returned # When is_random=True, the performance of a random architecture will be returned
# When is_random=False, the performanceo of all trials will be averaged. # When is_random=False, the performanceo of all trials will be averaged.
def get_more_info(self, index: int, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
else : basestr, arch2infos = '200epochs', self.arch2infos_full
archresult = arch2infos[index]
# if randomly select one trial, select the seed at first
if isinstance(is_random, bool) and is_random:
seeds = archresult.get_dataset_seeds(dataset)
is_random = random.choice(seeds)
# collect the training information
train_info = archresult.get_metrics(dataset, 'train', iepoch=iepoch, is_random=is_random)
total = train_info['iepoch'] + 1
xinfo = {'train-loss' : train_info['loss'],
'train-accuracy': train_info['accuracy'],
'train-per-time': train_info['all_time'] / total,
'train-all-time': train_info['all_time']}
# collect the evaluation information
if dataset == 'cifar10-valid':
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
try:
test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
except:
test_info = None
valtest_info = None
else:
try: # collect results on the proposed test set
if dataset == 'cifar10':
test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else:
test_info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
except:
test_info = None
try: # collect results on the proposed validation set
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
except:
valid_info = None
try:
if dataset != 'cifar10':
valtest_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else:
valtest_info = None
except:
valtest_info = None
if valid_info is not None:
xinfo['valid-loss'] = valid_info['loss']
xinfo['valid-accuracy'] = valid_info['accuracy']
xinfo['valid-per-time'] = valid_info['all_time'] / total
xinfo['valid-all-time'] = valid_info['all_time']
if test_info is not None:
xinfo['test-loss'] = test_info['loss']
xinfo['test-accuracy'] = test_info['accuracy']
xinfo['test-per-time'] = test_info['all_time'] / total
xinfo['test-all-time'] = test_info['all_time']
if valtest_info is not None:
xinfo['valtest-loss'] = valtest_info['loss']
xinfo['valtest-accuracy'] = valtest_info['accuracy']
xinfo['valtest-per-time'] = valtest_info['all_time'] / total
xinfo['valtest-all-time'] = valtest_info['all_time']
return xinfo
""" # The following logic is deprecated after March 15 2020, where the benchmark file upgrades from NAS-Bench-201-v1_0-e61699.pth to NAS-Bench-201-v1_1-096897.pth.
def get_more_info(self, index: int, dataset, iepoch=None, use_12epochs_result=False, is_random=True): def get_more_info(self, index: int, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
else : basestr, arch2infos = '200epochs', self.arch2infos_full else : basestr, arch2infos = '200epochs', self.arch2infos_full
@ -312,6 +386,7 @@ class NASBench201API(object):
xifo['est-valid-loss'] = est_valid_info['loss'] xifo['est-valid-loss'] = est_valid_info['loss']
xifo['est-valid-accuracy'] = est_valid_info['accuracy'] xifo['est-valid-accuracy'] = est_valid_info['accuracy']
return xifo return xifo
"""
def show(self, index: int = -1) -> None: def show(self, index: int = -1) -> None:
@ -349,6 +424,26 @@ class NASBench201API(object):
print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs))) print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))
def statistics(self, dataset: Text, use_12epochs_result: bool) -> Dict[int, int]:
"""
This function will count the number of total trials.
"""
valid_datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
if dataset not in valid_datasets:
raise ValueError('{:} not in {:}'.format(dataset, valid_datasets))
if use_12epochs_result: arch2infos = self.arch2infos_less
else : arch2infos = self.arch2infos_full
nums = defaultdict(lambda: 0)
for index in range(len(self)):
archInfo = arch2infos[index]
dataset_seed = archInfo.dataset_seed
if dataset not in dataset_seed:
nums[0] += 1
else:
nums[len(dataset_seed[dataset])] += 1
return dict(nums)
@staticmethod @staticmethod
def str2lists(arch_str: Text) -> List[tuple]: def str2lists(arch_str: Text) -> List[tuple]:
""" """

View File

@ -2,9 +2,9 @@
# bash ./scripts-search/algos/BOHB.sh -1 # bash ./scripts-search/algos/BOHB.sh -1
echo script name: $0 echo script name: $0
echo $# arguments echo $# arguments
if [ "$#" -ne 1 ] ;then if [ "$#" -ne 2 ] ;then
echo "Input illegal number of parameters " $# echo "Input illegal number of parameters " $#
echo "Need 1 parameters for seed" echo "Need 2 parameters for dataset and seed"
exit 1 exit 1
fi fi
if [ "$TORCH_HOME" = "" ]; then if [ "$TORCH_HOME" = "" ]; then
@ -14,12 +14,14 @@ else
echo "TORCH_HOME : $TORCH_HOME" echo "TORCH_HOME : $TORCH_HOME"
fi fi
dataset=cifar10 dataset=$1
seed=$1 seed=$2
channel=16 channel=16
num_cells=5 num_cells=5
max_nodes=4 max_nodes=4
space=nas-bench-201 space=nas-bench-201
#benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth
benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth
save_dir=./output/search-cell-${space}/BOHB-${dataset} save_dir=./output/search-cell-${space}/BOHB-${dataset}
@ -27,7 +29,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/BOHB.py \
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ --save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
--dataset ${dataset} \ --dataset ${dataset} \
--search_space_name ${space} \ --search_space_name ${space} \
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \ --arch_nas_dataset ${benchmark_file} \
--time_budget 12000 \ --time_budget 12000 \
--n_iters 50 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 \ --n_iters 50 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 \
--workers 4 --print_freq 200 --rand_seed ${seed} --workers 4 --print_freq 200 --rand_seed ${seed}

View File

@ -27,6 +27,8 @@ if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
else else
data_path="$TORCH_HOME/cifar.python/ImageNet16" data_path="$TORCH_HOME/cifar.python/ImageNet16"
fi fi
#benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth
benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth
save_dir=./output/search-cell-${space}/DARTS-V1-${dataset}-BN${BN} save_dir=./output/search-cell-${space}/DARTS-V1-${dataset}-BN${BN}
@ -35,7 +37,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V1.py \
--dataset ${dataset} --data_path ${data_path} \ --dataset ${dataset} --data_path ${data_path} \
--search_space_name ${space} \ --search_space_name ${space} \
--config_path configs/nas-benchmark/algos/DARTS.config \ --config_path configs/nas-benchmark/algos/DARTS.config \
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \ --arch_nas_dataset ${benchmark_file} \
--track_running_stats ${BN} \ --track_running_stats ${BN} \
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ --arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
--workers 4 --print_freq 200 --rand_seed ${seed} --workers 4 --print_freq 200 --rand_seed ${seed}

View File

@ -27,6 +27,8 @@ if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
else else
data_path="$TORCH_HOME/cifar.python/ImageNet16" data_path="$TORCH_HOME/cifar.python/ImageNet16"
fi fi
#benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth
benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth
save_dir=./output/search-cell-${space}/DARTS-V2-${dataset}-BN${BN} save_dir=./output/search-cell-${space}/DARTS-V2-${dataset}-BN${BN}
@ -35,7 +37,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V2.py \
--dataset ${dataset} --data_path ${data_path} \ --dataset ${dataset} --data_path ${data_path} \
--search_space_name ${space} \ --search_space_name ${space} \
--config_path configs/nas-benchmark/algos/DARTS.config \ --config_path configs/nas-benchmark/algos/DARTS.config \
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \ --arch_nas_dataset ${benchmark_file} \
--track_running_stats ${BN} \ --track_running_stats ${BN} \
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ --arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
--workers 4 --print_freq 200 --rand_seed ${seed} --workers 4 --print_freq 200 --rand_seed ${seed}

View File

@ -28,6 +28,8 @@ if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
else else
data_path="$TORCH_HOME/cifar.python/ImageNet16" data_path="$TORCH_HOME/cifar.python/ImageNet16"
fi fi
#benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth
benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth
save_dir=./output/search-cell-${space}/ENAS-${dataset}-BN${BN} save_dir=./output/search-cell-${space}/ENAS-${dataset}-BN${BN}
@ -35,7 +37,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/ENAS.py \
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ --save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
--dataset ${dataset} --data_path ${data_path} \ --dataset ${dataset} --data_path ${data_path} \
--search_space_name ${space} \ --search_space_name ${space} \
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \ --arch_nas_dataset ${benchmark_file} \
--track_running_stats ${BN} \ --track_running_stats ${BN} \
--config_path ./configs/nas-benchmark/algos/ENAS.config \ --config_path ./configs/nas-benchmark/algos/ENAS.config \
--controller_entropy_weight 0.0001 \ --controller_entropy_weight 0.0001 \

View File

@ -27,6 +27,8 @@ if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
else else
data_path="$TORCH_HOME/cifar.python/ImageNet16" data_path="$TORCH_HOME/cifar.python/ImageNet16"
fi fi
#benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth
benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth
save_dir=./output/search-cell-${space}/GDAS-${dataset}-BN${BN} save_dir=./output/search-cell-${space}/GDAS-${dataset}-BN${BN}
@ -34,7 +36,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/GDAS.py \
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ --save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
--dataset ${dataset} --data_path ${data_path} \ --dataset ${dataset} --data_path ${data_path} \
--search_space_name ${space} \ --search_space_name ${space} \
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \ --arch_nas_dataset ${benchmark_file} \
--config_path configs/nas-benchmark/algos/GDAS.config \ --config_path configs/nas-benchmark/algos/GDAS.config \
--tau_max 10 --tau_min 0.1 --track_running_stats ${BN} \ --tau_max 10 --tau_min 0.1 --track_running_stats ${BN} \
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ --arch_learning_rate 0.0003 --arch_weight_decay 0.001 \

View File

@ -23,6 +23,8 @@ channel=16
num_cells=5 num_cells=5
max_nodes=4 max_nodes=4
space=nas-bench-201 space=nas-bench-201
#benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth
benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth
save_dir=./output/search-cell-${space}/R-EA-${dataset}-SS${sample_size} save_dir=./output/search-cell-${space}/R-EA-${dataset}-SS${sample_size}
@ -30,7 +32,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/R_EA.py \
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ --save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
--dataset ${dataset} \ --dataset ${dataset} \
--search_space_name ${space} \ --search_space_name ${space} \
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \ --arch_nas_dataset ${benchmark_file} \
--time_budget 12000 \ --time_budget 12000 \
--ea_cycles 200 --ea_population 10 --ea_sample_size ${sample_size} --ea_fast_by_api 1 \ --ea_cycles 200 --ea_population 10 --ea_sample_size ${sample_size} --ea_fast_by_api 1 \
--workers 4 --print_freq 200 --rand_seed ${seed} --workers 4 --print_freq 200 --rand_seed ${seed}

View File

@ -28,6 +28,8 @@ if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
else else
data_path="$TORCH_HOME/cifar.python/ImageNet16" data_path="$TORCH_HOME/cifar.python/ImageNet16"
fi fi
#benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth
benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth
save_dir=./output/search-cell-${space}/RANDOM-NAS-${dataset}-BN${BN} save_dir=./output/search-cell-${space}/RANDOM-NAS-${dataset}-BN${BN}
@ -36,7 +38,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/RANDOM-NAS.py \
--dataset ${dataset} --data_path ${data_path} \ --dataset ${dataset} --data_path ${data_path} \
--search_space_name ${space} \ --search_space_name ${space} \
--track_running_stats ${BN} \ --track_running_stats ${BN} \
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \ --arch_nas_dataset ${benchmark_file} \
--config_path ./configs/nas-benchmark/algos/RANDOM.config \ --config_path ./configs/nas-benchmark/algos/RANDOM.config \
--select_num 100 \ --select_num 100 \
--workers 4 --print_freq 200 --rand_seed ${seed} --workers 4 --print_freq 200 --rand_seed ${seed}

View File

@ -2,9 +2,9 @@
# bash ./scripts-search/algos/REINFORCE.sh 0.001 -1 # bash ./scripts-search/algos/REINFORCE.sh 0.001 -1
echo script name: $0 echo script name: $0
echo $# arguments echo $# arguments
if [ "$#" -ne 2 ] ;then if [ "$#" -ne 3 ] ;then
echo "Input illegal number of parameters " $# echo "Input illegal number of parameters " $#
echo "Need 2 parameters for LR and seed" echo "Need 3 parameters for dataset, LR, and seed"
exit 1 exit 1
fi fi
if [ "$TORCH_HOME" = "" ]; then if [ "$TORCH_HOME" = "" ]; then
@ -14,13 +14,15 @@ else
echo "TORCH_HOME : $TORCH_HOME" echo "TORCH_HOME : $TORCH_HOME"
fi fi
dataset=cifar10 dataset=$1
LR=$1 LR=$2
seed=$2 seed=$3
channel=16 channel=16
num_cells=5 num_cells=5
max_nodes=4 max_nodes=4
space=nas-bench-201 space=nas-bench-201
#benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth
benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth
save_dir=./output/search-cell-${space}/REINFORCE-${dataset}-${LR} save_dir=./output/search-cell-${space}/REINFORCE-${dataset}-${LR}
@ -28,7 +30,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/reinforce.py \
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ --save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
--dataset ${dataset} \ --dataset ${dataset} \
--search_space_name ${space} \ --search_space_name ${space} \
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \ --arch_nas_dataset ${benchmark_file} \
--time_budget 12000 \ --time_budget 12000 \
--learning_rate ${LR} --EMA_momentum 0.9 \ --learning_rate ${LR} --EMA_momentum 0.9 \
--workers 4 --print_freq 200 --rand_seed ${seed} --workers 4 --print_freq 200 --rand_seed ${seed}

View File

@ -2,9 +2,9 @@
# bash ./scripts-search/algos/Random.sh -1 # bash ./scripts-search/algos/Random.sh -1
echo script name: $0 echo script name: $0
echo $# arguments echo $# arguments
if [ "$#" -ne 1 ] ;then if [ "$#" -ne 2 ] ;then
echo "Input illegal number of parameters " $# echo "Input illegal number of parameters " $#
echo "Need 1 parameters for seed" echo "Need 2 parameters for dataset and seed"
exit 1 exit 1
fi fi
if [ "$TORCH_HOME" = "" ]; then if [ "$TORCH_HOME" = "" ]; then
@ -14,12 +14,14 @@ else
echo "TORCH_HOME : $TORCH_HOME" echo "TORCH_HOME : $TORCH_HOME"
fi fi
dataset=cifar10 dataset=$1
seed=$1 seed=$2
channel=16 channel=16
num_cells=5 num_cells=5
max_nodes=4 max_nodes=4
space=nas-bench-201 space=nas-bench-201
#benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth
benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth
save_dir=./output/search-cell-${space}/RAND-${dataset} save_dir=./output/search-cell-${space}/RAND-${dataset}
@ -27,7 +29,6 @@ OMP_NUM_THREADS=4 python ./exps/algos/RANDOM.py \
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ --save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
--dataset ${dataset} \ --dataset ${dataset} \
--search_space_name ${space} \ --search_space_name ${space} \
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \ --arch_nas_dataset ${benchmark_file} \
--time_budget 12000 \ --time_budget 12000 \
--workers 4 --print_freq 200 --rand_seed ${seed} --workers 4 --print_freq 200 --rand_seed ${seed}
# --random_num 100 \

View File

@ -28,6 +28,8 @@ if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
else else
data_path="$TORCH_HOME/cifar.python/ImageNet16" data_path="$TORCH_HOME/cifar.python/ImageNet16"
fi fi
#benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth
benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth
save_dir=./output/search-cell-${space}/SETN-${dataset}-BN${BN} save_dir=./output/search-cell-${space}/SETN-${dataset}-BN${BN}
@ -35,7 +37,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/SETN.py \
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ --save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
--dataset ${dataset} --data_path ${data_path} \ --dataset ${dataset} --data_path ${data_path} \
--search_space_name ${space} \ --search_space_name ${space} \
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \ --arch_nas_dataset ${benchmark_file} \
--config_path configs/nas-benchmark/algos/SETN.config \ --config_path configs/nas-benchmark/algos/SETN.config \
--track_running_stats ${BN} \ --track_running_stats ${BN} \
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ --arch_learning_rate 0.0003 --arch_weight_decay 0.001 \

View File

@ -28,6 +28,8 @@ if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
else else
data_path="$TORCH_HOME/cifar.python/ImageNet16" data_path="$TORCH_HOME/cifar.python/ImageNet16"
fi fi
#benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth
benchmark_file=${TORCH_HOME}/NAS-Bench-201-v1_1-096897.pth
save_dir=./output/search-cell-${space}/DARTS-V1-${dataset}-BN${BN}-Gradient${gradient_clip} save_dir=./output/search-cell-${space}/DARTS-V1-${dataset}-BN${BN}-Gradient${gradient_clip}
@ -36,7 +38,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V1.py \
--dataset ${dataset} --data_path ${data_path} \ --dataset ${dataset} --data_path ${data_path} \
--search_space_name ${space} \ --search_space_name ${space} \
--config_path configs/nas-benchmark/algos/DARTS.config \ --config_path configs/nas-benchmark/algos/DARTS.config \
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \ --arch_nas_dataset ${benchmark_file} \
--track_running_stats ${BN} --gradient_clip ${gradient_clip} \ --track_running_stats ${BN} --gradient_clip ${gradient_clip} \
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ --arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
--workers 4 --print_freq 200 --rand_seed ${seed} --workers 4 --print_freq 200 --rand_seed ${seed}