update baseline NAS algos

This commit is contained in:
D-X-Y 2019-11-14 13:55:42 +11:00
parent 5c73aeb50b
commit 7843940846
13 changed files with 924 additions and 33 deletions

.gitignore

@ -110,4 +110,4 @@ logs
# snapshot
a.pth
-cal-merge.sh
+cal-merge*.sh


@ -9,6 +9,51 @@ In this Markdown file, we provide:
Note: please use `PyTorch >= 1.1.0` and `Python >= 3.6.0`.
## How to Use AA-NAS-Bench
1. Create an AA-NAS-Bench API instance from a file:
```
from aa_nas_api import AANASBenchAPI
api = AANASBenchAPI('$path_to_meta_aa_nas_bench_file')
```
2. Show the number of architectures `len(api)` and each architecture `api[i]`:
```
num = len(api)
for i, arch_str in enumerate(api):
  print ('{:5d}/{:5d} : {:}'.format(i, len(api), arch_str))
```
3. Show the results of all trials for a single architecture:
```
# show all information for a specific architecture
api.show(1)
api.show(2)
# show the mean loss and accuracy of an architecture
info = api.query_meta_info_by_index(1)
loss, accuracy = info.get_metrics('cifar10', 'train')
flops, params, latency = info.get_comput_costs('cifar100')
# get the detailed information
results = api.query_by_index(1, 'cifar100')
print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1]))
print ('Latency : {:}'.format(results[0].get_latency()))
print ('Train Info : {:}'.format(results[0].get_train()))
print ('Valid Info : {:}'.format(results[0].get_eval('x-valid')))
print ('Test Info : {:}'.format(results[0].get_eval('x-test')))
# for the metric after a specific epoch
print ('Train Info [10-th epoch] : {:}'.format(results[0].get_train(10)))
```
4. Query the index of an architecture by its string:
```
index = api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|')
api.show(index)
```
5. For other usages, please see `lib/aa_nas_api/api.py`; a combined example follows below.
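
The following snippet (an editor's sketch, not part of this commit) combines the calls above. Note that `query_index_by_arch` also accepts a plain architecture string, and the return value of `find_best` is assumed here to be an `(index, accuracy)` pair:
```
# find the best evaluated architecture on the CIFAR-10 validation set
best_index, best_accuracy = api.find_best('cifar10-valid', 'x-valid')
api.show(best_index)
# look up an architecture by its string and fetch its meta information
index = api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|skip_connect~1|skip_connect~2|')
info  = api.query_meta_info_by_index(index)
```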
## Instructions to Generate AA-NAS-Bench
1. Generate the meta file for AA-NAS-Bench using the following script, where `AA-NAS-BENCH` indicates the name and `4` indicates the maximum number of nodes in a cell.
@ -46,19 +91,18 @@ CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/AA-NAS-train-net.sh resnet 16 5
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/AA-NAS-train-net.sh '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|skip_connect~1|skip_connect~2|' 16 5
```
-[option] load the parameters of a trained network.
-```
-```
-## To reproduce 10 baseline NAS algorithms in AA-NAS-Bench
+## To Reproduce 10 Baseline NAS Algorithms in AA-NAS-Bench
We have tried our best to implement each method; however, some algorithms may still obtain non-optimal results because their hyper-parameters may not fit AA-NAS-Bench.
If researchers can provide better results with different hyper-parameters, we are happy to update the results accordingly. We also welcome more NAS algorithms to be tested on our dataset and would include them accordingly.
- [1] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V1.sh cifar10 -1`
- [2] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 -1`
- [3] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 -1`
- [4] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/SETN.sh cifar10 -1`
- [5] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/ENAS.sh cifar10 -1`
- [6] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/RANDOM-NAS.sh cifar10 -1`
- [7] `bash ./scripts-search/algos/R-EA.sh -1`
- [8] `bash ./scripts-search/algos/Random.sh -1`
- [9] `bash ./scripts-search/algos/REINFORCE.sh -1`
- [10] `bash ./scripts-search/algos/BOHB.sh -1`


@ -1,6 +1,3 @@
-##################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
-##################################################
import os, sys, time, argparse, collections
from copy import deepcopy
import torch
@ -167,7 +164,6 @@ def simplify(save_dir, meta_file, basestr, target_dir):
  arch_time = AverageMeter()
  for idx, arch_index in enumerate(arch_indexes):
    checkpoints = list(target_directory.glob('arch-{:}-seed-*.pth'.format(arch_index)))
-   arch_info = account_one_arch(arch_index, meta_archs[int(arch_index)], checkpoints, datasets, dataloader_dict)
    try:
      arch_info = account_one_arch(arch_index, meta_archs[int(arch_index)], checkpoints, datasets, dataloader_dict)
      num_seeds[ len(checkpoints) ] += 1
@ -181,7 +177,7 @@ def simplify(save_dir, meta_file, basestr, target_dir):
      torch.save(arch_info.state_dict(), to_save_allarc / '{:}-FULL.pth'.format(arch_index))
      #torch.save(arch_info, to_save_allarc / '{:}-FULL.pth'.format(arch_index))
      arch_info.clear_params()
-     torch.save(arch_info, to_save_allarc / '{:}-SIMPLE.pth'.format(arch_index))
+     torch.save(arch_info.state_dict(), to_save_allarc / '{:}-SIMPLE.pth'.format(arch_index))
    # measure elapsed time
    arch_time.update(time.time() - end_time)
    end_time = time.time()
@ -241,7 +237,7 @@ def merge_all(save_dir, meta_file, basestr):
      xevalindexs = sub_ckps['evaluated_indexes']
      for eval_index in xevalindexs:
        assert eval_index not in evaluated_indexes and eval_index not in arch2infos
-       arch2infos[eval_index] = xarch2infos[eval_index]
+       arch2infos[eval_index] = xarch2infos[eval_index].state_dict()
        evaluated_indexes.add( eval_index )
      print ('{:} [{:03d}/{:03d}] merge data from {:} with {:} models.'.format(time_string(), IDX, len(subdir2archs), ckp_path, len(xevalindexs)))
    else:


@ -58,8 +58,10 @@ def test_aa_nas_api():
  arch_result = ArchResults.create_from_state_dict('output/AA-NAS-BENCH-4/simplifies/architectures/000002-FULL.pth')
  arch_result.show(True)
  result = arch_result.query('cifar100')
- #xfile = '/home/dxy/search-configures/output/TINY-NAS-BENCHMARK-4/simplifies/C16-N5-final-infos.pth'
- #api = AANASBenchAPI(xfile)
+ #xfile = 'output/AA-NAS-BENCH-4/simplifies/000000-000389-C16-N5.pth'
+ api = AANASBenchAPI('output/AA-NAS-BENCH-4/simplifies/C16-N5-final-infos.pth')
+ results = api.query_by_index(1, 'cifar100')
+ print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1]))
  import pdb; pdb.set_trace()

if __name__ == '__main__':

exps/algos/BOHB.py

@ -0,0 +1,177 @@
##################################################
# requires the `hpbandster` package to be installed
# (e.g. via `pip install hpbandster`)
##################################################
import os, sys, time, glob, random, argparse
import numpy as np, collections
from copy import deepcopy
from pathlib import Path
import torch
import torch.nn as nn
from torch.distributions import Categorical

lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from aa_nas_api import AANASBenchAPI
from models import CellStructure, get_search_spaces
from R_EA import train_and_eval

# BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018
import ConfigSpace
from hpbandster.optimizers.bohb import BOHB
import hpbandster.core.nameserver as hpns
from hpbandster.core.worker import Worker
def get_configuration_space(max_nodes, search_space):
  cs = ConfigSpace.ConfigurationSpace()
  # one categorical hyperparameter per edge of the cell
  for i in range(1, max_nodes):
    for j in range(i):
      node_str = '{:}<-{:}'.format(i, j)
      cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space))
  return cs


def config2structure_func(max_nodes):
  def config2structure(config):
    genotypes = []
    for i in range(1, max_nodes):
      xlist = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        op_name  = config[node_str]
        xlist.append((op_name, j))
      genotypes.append( tuple(xlist) )
    return CellStructure( genotypes )
  return config2structure
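
# ----------------------------------------------------------------------
# Editor's note (not part of this commit): a quick illustration of how the
# two helpers above fit together; the operation names are examples taken
# from the architecture strings in the README.
#   ops    = ['nor_conv_3x3', 'avg_pool_3x3', 'skip_connect']
#   cs     = get_configuration_space(4, ops)   # 3+2+1 = 6 categorical edges
#   config = cs.sample_configuration()         # random operation per edge
#   arch   = config2structure_func(4)(config)  # decode into a CellStructure
# ----------------------------------------------------------------------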
class MyWorker(Worker):

  def __init__(self, *args, sleep_interval=0, convert_func=None, nas_bench=None, **kwargs):
    super().__init__(*args, **kwargs)
    self.sleep_interval = sleep_interval
    self.convert_func   = convert_func
    self.nas_bench      = nas_bench
    self.test_time      = 0

  def compute(self, config, budget, **kwargs):
    structure = self.convert_func( config )
    # NOTE: `budget` is ignored here; the simulated evaluation in
    # train_and_eval always returns the accuracy after 25 training epochs.
    reward = train_and_eval(structure, self.nas_bench, None)
    self.test_time += 1
    return ({'loss': float(100 - reward),
             'info': None})
def main(xargs):
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.benchmark = False
  torch.backends.cudnn.deterministic = True
  torch.set_num_threads( xargs.workers )
  prepare_seed(xargs.rand_seed)
  logger = prepare_logger(xargs)

  assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
  train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
  split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
  cifar_split = load_config(split_Fpath, None, None)
  train_split, valid_split = cifar_split.train, cifar_split.valid
  logger.log('Load split file from {:}'.format(split_Fpath))
  config_path = 'configs/nas-benchmark/algos/R-EA.config'
  config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
  # re-use the training data (with the validation transform) as the validation set
  train_data_v2 = deepcopy(train_data)
  train_data_v2.transform = valid_data.transform
  valid_data = train_data_v2
  search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
  # data loaders
  train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), num_workers=xargs.workers, pin_memory=True)
  valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
  logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
  logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
  extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}

  # load the NAS benchmark
  assert xargs.arch_nas_dataset is not None and os.path.isfile(xargs.arch_nas_dataset)
  search_space = get_search_spaces('cell', xargs.search_space_name)
  cs = get_configuration_space(xargs.max_nodes, search_space)

  config2structure = config2structure_func(xargs.max_nodes)
  hb_run_id = '0'

  NS = hpns.NameServer(run_id=hb_run_id, host='localhost', port=0)
  ns_host, ns_port = NS.start()
  num_workers = 1

  nas_bench = AANASBenchAPI(xargs.arch_nas_dataset)
  logger.log('{:} Create AA-NAS-BENCH-API DONE'.format(time_string()))
  workers = []
  for i in range(num_workers):
    w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, run_id=hb_run_id, id=i)
    w.run(background=True)
    workers.append(w)

  bohb = BOHB(configspace=cs,
              run_id=hb_run_id,
              eta=3, min_budget=3, max_budget=108,
              nameserver=ns_host,
              nameserver_port=ns_port,
              num_samples=xargs.num_samples,
              random_fraction=xargs.random_fraction, bandwidth_factor=xargs.bandwidth_factor,
              ping_interval=10, min_bandwidth=xargs.min_bandwidth)
  # optimization_strategy=xargs.strategy, num_samples=xargs.num_samples,

  results = bohb.run(xargs.n_iters, min_n_workers=num_workers)

  bohb.shutdown(shutdown_workers=True)
  NS.shutdown()

  id2config = results.get_id2config_mapping()
  incumbent = results.get_incumbent_id()

  logger.log('Best found configuration: {:}'.format(id2config[incumbent]['config']))
  best_arch = config2structure( id2config[incumbent]['config'] )
  if nas_bench is not None:
    info = nas_bench.query_by_arch( best_arch )
    if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
    else           : logger.log('{:}'.format(info))
  logger.log('-'*100)
  logger.log('workers : {:}'.format(workers[0].test_time))
  logger.close()
if __name__ == '__main__':
  parser = argparse.ArgumentParser("BOHB: Robust and Efficient Hyperparameter Optimization at Scale")
  parser.add_argument('--data_path', type=str, help='Path to dataset')
  parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between CIFAR-10/100 and ImageNet-16.')
  # channels and number-of-cells
  parser.add_argument('--search_space_name', type=str, help='The search space name.')
  parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
  parser.add_argument('--channel', type=int, help='The number of channels.')
  parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
  # BOHB
  parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function')
  parser.add_argument('--min_bandwidth', default=.3, type=float, nargs='?', help='minimum bandwidth for KDE')
  parser.add_argument('--num_samples', default=64, type=int, nargs='?', help='number of samples for the acquisition function')
  parser.add_argument('--random_fraction', default=.33, type=float, nargs='?', help='fraction of random configurations')
  parser.add_argument('--bandwidth_factor', default=3, type=int, nargs='?', help='factor multiplied to the bandwidth')
  parser.add_argument('--n_iters', default=100, type=int, nargs='?', help='number of iterations for the optimization method')
  # log
  parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
  parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
  parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).')
  parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
  parser.add_argument('--rand_seed', type=int, help='manual seed')
  args = parser.parse_args()
  if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
  main(args)

exps/algos/RANDOM.py

@ -0,0 +1,95 @@
import os, sys, time, glob, random, argparse
import numpy as np, collections
from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path

lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_search_spaces
from aa_nas_api import AANASBenchAPI
from R_EA import train_and_eval, random_architecture_func
def main(xargs):
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.benchmark = False
  torch.backends.cudnn.deterministic = True
  torch.set_num_threads( xargs.workers )
  prepare_seed(xargs.rand_seed)
  logger = prepare_logger(xargs)

  assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
  train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
  split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
  cifar_split = load_config(split_Fpath, None, None)
  train_split, valid_split = cifar_split.train, cifar_split.valid
  logger.log('Load split file from {:}'.format(split_Fpath))
  config_path = 'configs/nas-benchmark/algos/R-EA.config'
  config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
  # re-use the training data (with the validation transform) as the validation set
  train_data_v2 = deepcopy(train_data)
  train_data_v2.transform = valid_data.transform
  valid_data = train_data_v2
  search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
  # data loaders
  train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), num_workers=xargs.workers, pin_memory=True)
  valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
  logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
  logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
  extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}

  search_space = get_search_spaces('cell', xargs.search_space_name)
  random_arch = random_architecture_func(xargs.max_nodes, search_space)
  if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
    logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
    nas_bench = None
  else:
    logger.log('{:} build NAS-Benchmark-API from {:}'.format(time_string(), xargs.arch_nas_dataset))
    nas_bench = AANASBenchAPI(xargs.arch_nas_dataset)
  logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))

  best_arch, best_acc = None, -1
  for idx in range(xargs.random_num):
    arch = random_arch()
    accuracy = train_and_eval(arch, nas_bench, extra_info)
    if best_arch is None or best_acc < accuracy:
      best_acc, best_arch = accuracy, arch
    logger.log('[{:03d}/{:03d}] : {:} : accuracy = {:.2f}%'.format(idx, xargs.random_num, arch, accuracy))
  logger.log('{:} best arch is {:}, accuracy = {:.2f}%'.format(time_string(), best_arch, best_acc))

  if nas_bench is not None:
    info = nas_bench.query_by_arch( best_arch )
    if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
    else           : logger.log('{:}'.format(info))
  logger.log('-'*100)
  logger.close()
if __name__ == '__main__':
  parser = argparse.ArgumentParser("Random search baseline")
  parser.add_argument('--data_path', type=str, help='Path to dataset')
  parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between CIFAR-10/100 and ImageNet-16.')
  # channels and number-of-cells
  parser.add_argument('--search_space_name', type=str, help='The search space name.')
  parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
  parser.add_argument('--channel', type=int, help='The number of channels.')
  parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
  parser.add_argument('--random_num', type=int, help='The number of randomly selected architectures.')
  # log
  parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
  parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
  parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).')
  parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
  parser.add_argument('--rand_seed', type=int, help='manual seed')
  args = parser.parse_args()
  if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
  main(args)

exps/algos/R_EA.py

@ -0,0 +1,230 @@
import os, sys, time, glob, random, argparse
import numpy as np, collections
from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path

lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from aa_nas_api import AANASBenchAPI
from models import CellStructure, get_search_spaces
# Regularized Evolution for Image Classifier Architecture Search
class Model(object):

  def __init__(self):
    self.arch = None
    self.accuracy = None

  def __str__(self):
    """Prints a readable version of this bitstring."""
    return '{:}'.format(self.arch)


def valid_func(xloader, network, criterion):
  data_time, batch_time = AverageMeter(), AverageMeter()
  arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
  network.train()
  end = time.time()
  with torch.no_grad():
    for step, (arch_inputs, arch_targets) in enumerate(xloader):
      arch_targets = arch_targets.cuda(non_blocking=True)
      # measure data loading time
      data_time.update(time.time() - end)
      # prediction
      _, logits = network(arch_inputs)
      arch_loss = criterion(logits, arch_targets)
      # record
      arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
      arch_losses.update(arch_loss.item(),  arch_inputs.size(0))
      arch_top1.update  (arch_prec1.item(), arch_inputs.size(0))
      arch_top5.update  (arch_prec5.item(), arch_inputs.size(0))
      # measure elapsed time
      batch_time.update(time.time() - end)
      end = time.time()
  return arch_losses.avg, arch_top1.avg, arch_top5.avg


def train_and_eval(arch, nas_bench, extra_info):
  if nas_bench is not None:
    arch_index = nas_bench.query_index_by_arch( arch )
    assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
    info = nas_bench.arch2infos[ arch_index ]
    _, valid_acc = info.get_metrics('cifar10-valid', 'x-valid', 25)  # use the validation accuracy after 25 training epochs
  else:
    # train a model from scratch
    raise ValueError('NOT IMPLEMENTED YET')
  return valid_acc
def random_architecture_func(max_nodes, op_names):
  # return a function that generates a random architecture
  def random_architecture():
    genotypes = []
    for i in range(1, max_nodes):
      xlist = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        op_name  = random.choice( op_names )
        xlist.append((op_name, j))
      genotypes.append( tuple(xlist) )
    return CellStructure( genotypes )
  return random_architecture


def mutate_arch_func(op_names):
  """Computes the architecture for a child of the given parent architecture.

  The parent architecture is cloned and mutated to produce the child
  architecture. The child differs from the parent by randomly switching
  one operation to another.
  """
  def mutate_arch_func(parent_arch):
    child_arch = deepcopy( parent_arch )
    node_id = random.randint(0, len(child_arch.nodes)-1)
    node_info = list( child_arch.nodes[node_id] )
    snode_id = random.randint(0, len(node_info)-1)
    xop = random.choice( op_names )
    while xop == node_info[snode_id][0]:
      xop = random.choice( op_names )
    node_info[snode_id] = (xop, node_info[snode_id][1])
    child_arch.nodes[node_id] = tuple( node_info )
    return child_arch
  return mutate_arch_func
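
# ----------------------------------------------------------------------
# Editor's note (not part of this commit): a minimal sketch of how the two
# factories above compose; the operation names are examples from this
# search space.
#   ops = ['nor_conv_3x3', 'avg_pool_3x3', 'skip_connect']
#   random_arch = random_architecture_func(4, ops)
#   mutate_arch = mutate_arch_func(ops)
#   parent = random_arch()        # a random cell over 4 nodes
#   child  = mutate_arch(parent)  # same as parent except one edge's operation
# ----------------------------------------------------------------------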
def regularized_evolution(cycles, population_size, sample_size, random_arch, mutate_arch, nas_bench, extra_info):
  """Algorithm for regularized evolution (i.e. aging evolution).

  Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image
  Classifier Architecture Search".

  Args:
    cycles: the number of cycles the algorithm should run for.
    population_size: the number of individuals to keep in the population.
    sample_size: the number of individuals that should participate in each tournament.

  Returns:
    history: a list of `Model` instances, representing all the models computed
      during the evolution experiment.
  """
  population = collections.deque()
  history = []  # Not used by the algorithm, only used to report results.

  # Initialize the population with random models.
  while len(population) < population_size:
    model = Model()
    model.arch = random_arch()
    model.accuracy = train_and_eval(model.arch, nas_bench, extra_info)
    population.append(model)
    history.append(model)

  # Carry out evolution in cycles. Each cycle produces a model and removes another.
  while len(history) < cycles:
    # Sample randomly chosen models from the current population.
    sample = []
    while len(sample) < sample_size:
      # Inefficient, but written this way for clarity. In the case of neural
      # nets, the efficiency of this line is irrelevant because training neural
      # nets is the rate-determining step.
      candidate = random.choice(list(population))
      sample.append(candidate)

    # The parent is the best model in the sample.
    parent = max(sample, key=lambda i: i.accuracy)

    # Create the child model and store it.
    child = Model()
    child.arch = mutate_arch(parent.arch)
    child.accuracy = train_and_eval(child.arch, nas_bench, extra_info)
    population.append(child)
    history.append(child)

    # Remove the oldest model.
    population.popleft()
  return history
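
# ----------------------------------------------------------------------
# Editor's note (not part of this commit): the tournament-plus-aging logic
# above is independent of NAS. This self-contained toy runs the same
# algorithm on bit-strings, where the fitness is simply the number of ones.
def toy_aging_evolution(cycles=50, population_size=10, sample_size=3, n_bits=8):
  rand_arch = lambda: [random.randint(0, 1) for _ in range(n_bits)]
  population, history = collections.deque(), []
  while len(population) < population_size:            # random initialization
    arch = rand_arch()
    population.append((arch, sum(arch))); history.append(population[-1])
  while len(history) < cycles:
    sample = [random.choice(list(population)) for _ in range(sample_size)]
    parent = max(sample, key=lambda pair: pair[1])    # tournament selection
    child  = list(parent[0])
    child[random.randrange(n_bits)] ^= 1              # mutate exactly one bit
    population.append((child, sum(child))); history.append(population[-1])
    population.popleft()                              # aging: evict the oldest
  return max(history, key=lambda pair: pair[1])
# ----------------------------------------------------------------------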
def main(xargs):
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.benchmark = False
  torch.backends.cudnn.deterministic = True
  torch.set_num_threads( xargs.workers )
  prepare_seed(xargs.rand_seed)
  logger = prepare_logger(xargs)

  assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
  train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
  split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
  cifar_split = load_config(split_Fpath, None, None)
  train_split, valid_split = cifar_split.train, cifar_split.valid
  logger.log('Load split file from {:}'.format(split_Fpath))
  config_path = 'configs/nas-benchmark/algos/R-EA.config'
  config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
  # re-use the training data (with the validation transform) as the validation set
  train_data_v2 = deepcopy(train_data)
  train_data_v2.transform = valid_data.transform
  valid_data = train_data_v2
  search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
  # data loaders
  train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), num_workers=xargs.workers, pin_memory=True)
  valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
  logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
  logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
  extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}

  search_space = get_search_spaces('cell', xargs.search_space_name)
  random_arch = random_architecture_func(xargs.max_nodes, search_space)
  mutate_arch = mutate_arch_func(search_space)
  if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
    logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
    nas_bench = None
  else:
    logger.log('{:} build NAS-Benchmark-API from {:}'.format(time_string(), xargs.arch_nas_dataset))
    nas_bench = AANASBenchAPI(xargs.arch_nas_dataset)
  logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))

  history = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, random_arch, mutate_arch, nas_bench if xargs.ea_fast_by_api else None, extra_info)
  logger.log('{:} regularized_evolution finish with history of {:} arch.'.format(time_string(), len(history)))
  best_arch = max(history, key=lambda i: i.accuracy)
  best_arch = best_arch.arch
  logger.log('{:} best arch is {:}'.format(time_string(), best_arch))

  if nas_bench is not None:
    info = nas_bench.query_by_arch( best_arch )
    if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
    else           : logger.log('{:}'.format(info))
  logger.log('-'*100)
  logger.close()
if __name__ == '__main__':
  parser = argparse.ArgumentParser("Regularized Evolution Algorithm")
  parser.add_argument('--data_path', type=str, help='Path to dataset')
  parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between CIFAR-10/100 and ImageNet-16.')
  # channels and number-of-cells
  parser.add_argument('--search_space_name', type=str, help='The search space name.')
  parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
  parser.add_argument('--channel', type=int, help='The number of channels.')
  parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
  parser.add_argument('--ea_cycles', type=int, help='The number of cycles in EA.')
  parser.add_argument('--ea_population', type=int, help='The population size in EA.')
  parser.add_argument('--ea_sample_size', type=int, help='The sample size in EA.')
  parser.add_argument('--ea_fast_by_api', type=int, help='Use our API to speed up the experiments or not.')
  # log
  parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
  parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
  parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).')
  parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
  parser.add_argument('--rand_seed', type=int, help='manual seed')
  args = parser.parse_args()
  if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
  args.ea_fast_by_api = args.ea_fast_by_api > 0
  main(args)

exps/algos/reinforce.py

@ -0,0 +1,187 @@
##################################################
# modified from https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py
##################################################
import os, sys, time, glob, random, argparse
import numpy as np, collections
from copy import deepcopy
from pathlib import Path
import torch
import torch.nn as nn
from torch.distributions import Categorical

lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from aa_nas_api import AANASBenchAPI
from models import CellStructure, get_search_spaces
from R_EA import train_and_eval
class Policy(nn.Module):

  def __init__(self, max_nodes, search_space):
    super(Policy, self).__init__()
    self.max_nodes    = max_nodes
    self.search_space = deepcopy(search_space)
    self.edge2index   = {}
    for i in range(1, max_nodes):
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        self.edge2index[ node_str ] = len(self.edge2index)
    # one learnable logit per (edge, operation) pair
    self.arch_parameters = nn.Parameter( 1e-3*torch.randn(len(self.edge2index), len(search_space)) )

  def generate_arch(self, actions):
    genotypes = []
    for i in range(1, self.max_nodes):
      xlist = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        op_name  = self.search_space[ actions[ self.edge2index[ node_str ] ] ]
        xlist.append((op_name, j))
      genotypes.append( tuple(xlist) )
    return CellStructure( genotypes )

  def genotype(self):
    genotypes = []
    for i in range(1, self.max_nodes):
      xlist = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        with torch.no_grad():
          weights = self.arch_parameters[ self.edge2index[node_str] ]
          op_name = self.search_space[ weights.argmax().item() ]
        xlist.append((op_name, j))
      genotypes.append( tuple(xlist) )
    return CellStructure( genotypes )

  def forward(self):
    alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
    return alphas
class ExponentialMovingAverage(object):
  """Class that maintains an exponential moving average."""

  def __init__(self, momentum):
    self._numerator   = 0
    self._denominator = 0
    self._momentum    = momentum

  def update(self, value):
    self._numerator   = self._momentum * self._numerator + (1 - self._momentum) * value
    self._denominator = self._momentum * self._denominator + (1 - self._momentum)

  def value(self):
    """Return the current value of the moving average."""
    return self._numerator / self._denominator
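
# ----------------------------------------------------------------------
# Editor's note (not part of this commit): a worked example with momentum
# 0.9. Unlike a plain EMA initialized at zero, the growing denominator
# keeps early estimates unbiased:
#   ema = ExponentialMovingAverage(0.9)
#   ema.update(1.0)  # numerator = 0.1,  denominator = 0.1  -> value() = 1.0
#   ema.update(0.0)  # numerator = 0.09, denominator = 0.19 -> value() ~ 0.474
# ----------------------------------------------------------------------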
def select_action(policy):
  probs = policy()
  m = Categorical(probs)
  action = m.sample()
  #policy.saved_log_probs.append(m.log_prob(action))
  return m.log_prob(action), action.cpu().tolist()
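
# ----------------------------------------------------------------------
# Editor's note (not part of this commit): policy() returns a
# (num_edges, num_ops) matrix of probabilities; Categorical treats the last
# dimension as the event space, so sample() draws one operation index per
# edge and log_prob() returns one log-probability per edge. For example:
#   policy = Policy(4, ['nor_conv_3x3', 'avg_pool_3x3', 'skip_connect'])
#   log_prob, action = select_action(policy)  # action: 6 indices, one per edge
#   arch = policy.generate_arch(action)       # decode into a CellStructure
# ----------------------------------------------------------------------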
def main(xargs):
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.benchmark = False
  torch.backends.cudnn.deterministic = True
  torch.set_num_threads( xargs.workers )
  prepare_seed(xargs.rand_seed)
  logger = prepare_logger(xargs)

  assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
  train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
  split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
  cifar_split = load_config(split_Fpath, None, None)
  train_split, valid_split = cifar_split.train, cifar_split.valid
  logger.log('Load split file from {:}'.format(split_Fpath))
  config_path = 'configs/nas-benchmark/algos/R-EA.config'
  config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
  # re-use the training data (with the validation transform) as the validation set
  train_data_v2 = deepcopy(train_data)
  train_data_v2.transform = valid_data.transform
  valid_data = train_data_v2
  search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
  # data loaders
  train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), num_workers=xargs.workers, pin_memory=True)
  valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
  logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
  logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
  extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}

  search_space = get_search_spaces('cell', xargs.search_space_name)
  policy    = Policy(xargs.max_nodes, search_space)
  optimizer = torch.optim.Adam(policy.parameters(), lr=xargs.learning_rate)
  eps       = np.finfo(np.float32).eps.item()
  baseline  = ExponentialMovingAverage(xargs.EMA_momentum)
  logger.log('policy    : {:}'.format(policy))
  logger.log('optimizer : {:}'.format(optimizer))
  logger.log('eps       : {:}'.format(eps))

  # load the NAS benchmark
  if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
    logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
    nas_bench = None
  else:
    logger.log('{:} build NAS-Benchmark-API from {:}'.format(time_string(), xargs.arch_nas_dataset))
    nas_bench = AANASBenchAPI(xargs.arch_nas_dataset)
  logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))

  # REINFORCE
  for istep in range(xargs.RL_steps):
    log_prob, action = select_action( policy )
    arch   = policy.generate_arch( action )
    reward = train_and_eval(arch, nas_bench, extra_info)
    baseline.update(reward)
    # calculate loss: advantage = reward minus the moving-average baseline
    policy_loss = ( -log_prob * (reward - baseline.value()) ).sum()
    optimizer.zero_grad()
    policy_loss.backward()
    optimizer.step()
    logger.log('step [{:3d}/{:3d}] : average-reward={:.3f} : policy_loss={:.4f} : {:}'.format(istep, xargs.RL_steps, baseline.value(), policy_loss.item(), policy.genotype()))
    #logger.log('----> {:}'.format(policy.arch_parameters))
    logger.log('')

  best_arch = policy.genotype()
  if nas_bench is not None:
    info = nas_bench.query_by_arch( best_arch )
    if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
    else           : logger.log('{:}'.format(info))
  logger.log('-'*100)
  logger.close()
if __name__ == '__main__':
  parser = argparse.ArgumentParser("The REINFORCE Algorithm")
  parser.add_argument('--data_path', type=str, help='Path to dataset')
  parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between CIFAR-10/100 and ImageNet-16.')
  # channels and number-of-cells
  parser.add_argument('--search_space_name', type=str, help='The search space name.')
  parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
  parser.add_argument('--channel', type=int, help='The number of channels.')
  parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
  parser.add_argument('--learning_rate', type=float, help='The learning rate for REINFORCE.')
  parser.add_argument('--RL_steps', type=int, help='The steps for REINFORCE.')
  parser.add_argument('--EMA_momentum', type=float, help='The momentum value for EMA.')
  # log
  parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
  parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
  parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).')
  parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
  parser.add_argument('--rand_seed', type=int, help='manual seed')
  args = parser.parse_args()
  if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
  main(args)


@ -1,5 +1,5 @@
import os, sys, copy, torch, numpy as np
+from collections import OrderedDict

def print_information(information, extra_info=None, show=False):
@ -29,20 +29,26 @@ def print_information(information, extra_info=None, show=False):
class AANASBenchAPI(object):

- def __init__(self, file_path_or_dict):
+ def __init__(self, file_path_or_dict, verbose=True):
    if isinstance(file_path_or_dict, str):
+     if verbose: print('try to create AA-NAS-Bench api from {:}'.format(file_path_or_dict))
      assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
      file_path_or_dict = torch.load(file_path_or_dict)
+   else:
+     file_path_or_dict = copy.deepcopy( file_path_or_dict )
    assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
    keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
    for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
    self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
-   self.arch2infos = copy.deepcopy( file_path_or_dict['arch2infos'] )
-   self.evaluated_indexes = sorted(list( copy.deepcopy( file_path_or_dict['evaluated_indexes'] ) ))
+   self.arch2infos = OrderedDict()
+   for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
+     self.arch2infos[xkey] = ArchResults.create_from_state_dict( file_path_or_dict['arch2infos'][xkey] )
+   self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes']))
    self.archstr2index = {}
    for idx, arch in enumerate(self.meta_archs):
-     assert arch.tostr() not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch.tostr()])
-     self.archstr2index[ arch.tostr() ] = idx
+     #assert arch.tostr() not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch.tostr()])
+     assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
+     self.archstr2index[ arch ] = idx

  def __getitem__(self, index):
    return copy.deepcopy( self.meta_archs[index] )
@ -54,12 +60,12 @@ class AANASBenchAPI(object):
    return ('{name}({num}/{total} architectures)'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs)))

  def query_index_by_arch(self, arch):
-   if arch.tostr() in self.archstr2index:
-     arch_index = self.archstr2index[ arch.tostr() ]
-   #else:
-   #  arch_str = Structure.str2fullstructure( arch.tostr() ).tostr()
-   #  if arch_str in self.archstr2index:
-   #    arch_index = self.archstr2index[ arch_str ]
+   if isinstance(arch, str):
+     if arch in self.archstr2index: arch_index = self.archstr2index[ arch ]
+     else                         : arch_index = -1
+   elif hasattr(arch, 'tostr'):
+     if arch.tostr() in self.archstr2index: arch_index = self.archstr2index[ arch.tostr() ]
+     else                                  : arch_index = -1
    else: arch_index = -1
    return arch_index
@ -80,6 +86,11 @@ class AANASBenchAPI(object):
    info = archInfo.query(dataname)
    return info

+ def query_meta_info_by_index(self, arch_index):
+   assert arch_index in self.arch2infos, 'arch_index [{:}] is not in arch2infos'.format(arch_index)
+   archInfo = copy.deepcopy( self.arch2infos[ arch_index ] )
+   return archInfo

  def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None):
    best_index, highest_accuracy = -1, None
    for i, idx in enumerate(self.evaluated_indexes):


@ -0,0 +1,37 @@
#!/bin/bash
# bash ./scripts-search/algos/BOHB.sh -1
echo script name: $0
echo $# arguments
if [ "$#" -ne 1 ] ;then
echo "Input illegal number of parameters " $#
echo "Need 1 parameters for seed"
exit 1
fi
if [ "$TORCH_HOME" = "" ]; then
echo "Must set TORCH_HOME envoriment variable for data dir saving"
exit 1
else
echo "TORCH_HOME : $TORCH_HOME"
fi
dataset=cifar10
seed=$1
channel=16
num_cells=5
max_nodes=4
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
data_path="$TORCH_HOME/cifar.python"
else
data_path="$TORCH_HOME/cifar.python/ImageNet16"
fi
save_dir=./output/cell-search-tiny/BOHB-${dataset}
OMP_NUM_THREADS=4 python ./exps/algos/BOHB.py \
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
--dataset ${dataset} --data_path ${data_path} \
--search_space_name aa-nas \
--arch_nas_dataset ./output/AA-NAS-BENCH-4/simplifies/C16-N5-final-infos.pth \
--n_iters 6 --num_samples 3 \
--workers 4 --print_freq 200 --rand_seed ${seed}


@ -0,0 +1,38 @@
#!/bin/bash
# Regularized Evolution for Image Classifier Architecture Search, AAAI 2019
# bash ./scripts-search/algos/R-EA.sh -1
echo script name: $0
echo $# arguments
if [ "$#" -ne 1 ] ;then
echo "Input illegal number of parameters " $#
echo "Need 1 parameters for seed"
exit 1
fi
if [ "$TORCH_HOME" = "" ]; then
echo "Must set TORCH_HOME envoriment variable for data dir saving"
exit 1
else
echo "TORCH_HOME : $TORCH_HOME"
fi
dataset=cifar10
seed=$1
channel=16
num_cells=5
max_nodes=4
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
data_path="$TORCH_HOME/cifar.python"
else
data_path="$TORCH_HOME/cifar.python/ImageNet16"
fi
save_dir=./output/cell-search-tiny/R-EA-${dataset}
OMP_NUM_THREADS=4 python ./exps/algos/R_EA.py \
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
--dataset ${dataset} --data_path ${data_path} \
--search_space_name aa-nas \
--arch_nas_dataset ./output/AA-NAS-BENCH-4/simplifies/C16-N5-final-infos.pth \
--ea_cycles 30 --ea_population 10 --ea_sample_size 3 --ea_fast_by_api 1 \
--workers 4 --print_freq 200 --rand_seed ${seed}


@ -0,0 +1,37 @@
#!/bin/bash
# bash ./scripts-search/algos/REINFORCE.sh -1
echo script name: $0
echo $# arguments
if [ "$#" -ne 1 ] ;then
echo "Input illegal number of parameters " $#
echo "Need 1 parameters for seed"
exit 1
fi
if [ "$TORCH_HOME" = "" ]; then
echo "Must set TORCH_HOME envoriment variable for data dir saving"
exit 1
else
echo "TORCH_HOME : $TORCH_HOME"
fi
dataset=cifar10
seed=$1
channel=16
num_cells=5
max_nodes=4
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
data_path="$TORCH_HOME/cifar.python"
else
data_path="$TORCH_HOME/cifar.python/ImageNet16"
fi
save_dir=./output/cell-search-tiny/REINFORCE-${dataset}
OMP_NUM_THREADS=4 python ./exps/algos/reinforce.py \
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
--dataset ${dataset} --data_path ${data_path} \
--search_space_name aa-nas \
--arch_nas_dataset ./output/AA-NAS-BENCH-4/simplifies/C16-N5-final-infos.pth \
--learning_rate 0.001 --RL_steps 100 --EMA_momentum 0.9 \
--workers 4 --print_freq 200 --rand_seed ${seed}


@ -0,0 +1,37 @@
#!/bin/bash
# bash ./scripts-search/algos/Random.sh -1
echo script name: $0
echo $# arguments
if [ "$#" -ne 1 ] ;then
echo "Input illegal number of parameters " $#
echo "Need 1 parameters for seed"
exit 1
fi
if [ "$TORCH_HOME" = "" ]; then
echo "Must set TORCH_HOME envoriment variable for data dir saving"
exit 1
else
echo "TORCH_HOME : $TORCH_HOME"
fi
dataset=cifar10
seed=$1
channel=16
num_cells=5
max_nodes=4
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
data_path="$TORCH_HOME/cifar.python"
else
data_path="$TORCH_HOME/cifar.python/ImageNet16"
fi
save_dir=./output/cell-search-tiny/RAND-${dataset}
OMP_NUM_THREADS=4 python ./exps/algos/RANDOM.py \
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
--dataset ${dataset} --data_path ${data_path} \
--search_space_name aa-nas \
--arch_nas_dataset ./output/AA-NAS-BENCH-4/simplifies/C16-N5-final-infos.pth \
--random_num 100 \
--workers 4 --print_freq 200 --rand_seed ${seed}