From 78439408465da6f65aebe2edced10f6eafb0da53 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Thu, 14 Nov 2019 13:55:42 +1100 Subject: [PATCH] update baseline NAS algos --- .gitignore | 2 +- AA-NAS-Bench.md | 68 +++++++-- exps/AA-NAS-statistics.py | 8 +- exps/AA-NAS-test-API.py | 6 +- exps/algos/BOHB.py | 177 +++++++++++++++++++++++ exps/algos/RANDOM.py | 95 ++++++++++++ exps/algos/R_EA.py | 230 ++++++++++++++++++++++++++++++ exps/algos/reinforce.py | 187 ++++++++++++++++++++++++ lib/aa_nas_api/api.py | 35 +++-- scripts-search/algos/BOHB.sh | 37 +++++ scripts-search/algos/R-EA.sh | 38 +++++ scripts-search/algos/REINFORCE.sh | 37 +++++ scripts-search/algos/Random.sh | 37 +++++ 13 files changed, 924 insertions(+), 33 deletions(-) create mode 100644 exps/algos/BOHB.py create mode 100644 exps/algos/RANDOM.py create mode 100644 exps/algos/R_EA.py create mode 100644 exps/algos/reinforce.py create mode 100644 scripts-search/algos/BOHB.sh create mode 100644 scripts-search/algos/R-EA.sh create mode 100644 scripts-search/algos/REINFORCE.sh create mode 100644 scripts-search/algos/Random.sh diff --git a/.gitignore b/.gitignore index 562720b..3a2a7fb 100644 --- a/.gitignore +++ b/.gitignore @@ -110,4 +110,4 @@ logs # snapshot a.pth -cal-merge.sh +cal-merge*.sh diff --git a/AA-NAS-Bench.md b/AA-NAS-Bench.md index 5fb88b5..81a3689 100644 --- a/AA-NAS-Bench.md +++ b/AA-NAS-Bench.md @@ -9,6 +9,51 @@ In this Markdown file, we provide: Note: please use `PyTorch >= 1.1.0` and `Python >= 3.6.0`. +## How to Use AA-NAS-Bench + +1. Creating AA-NAS-Bench API from a file: +``` +from aa_nas_api import AANASBenchAPI +api = AANASBenchAPI('$path_to_meta_aa_nas_bench_file') +``` + +2. Show the number of architectures `len(api)` and each architecture `api[i]`: +``` +num = len(api) +for i, arch_str in enumerate(api): + print ('{:5d}/{:5d} : {:}'.format(i, len(api), arch_str)) +``` + +3. Show the results of all trials for a single architecture: +``` +# show all information for a specific architecture +api.show(1) +api.show(2) + +# show the mean loss and accuracy of an architecture +info = api.query_meta_info_by_index(1) +loss, accuracy = info.get_metrics('cifar10', 'train') +flops, params, latency = info.get_comput_costs('cifar100') + +# get the detailed information +results = api.query_by_index(1, 'cifar100') +print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1])) +print ('Latency : {:}'.format(results[0].get_latency())) +print ('Train Info : {:}'.format(results[0].get_train())) +print ('Valid Info : {:}'.format(results[0].get_eval('x-valid'))) +print ('Test Info : {:}'.format(results[0].get_eval('x-test'))) +# for the metric after a specific epoch +print ('Train Info [10-th epoch] : {:}'.format(results[0].get_train(10))) +``` + +4. Query the index of an architecture by string +``` +index = api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|') +api.show(index) +``` + +5. For other usages, please see `lib/aa_nas_api/api.py` + ## Instruction to Generate AA-NAS-Bench 1. generate the meta file for AA-NAS-Bench using the following script, where `AA-NAS-BENCH` indicates the name and `4` indicates the maximum number of nodes in a cell. @@ -46,19 +91,18 @@ CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/AA-NAS-train-net.sh resnet 16 5 CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/AA-NAS-train-net.sh '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|skip_connect~1|skip_connect~2|' 16 5 ``` -[option] load the parameters of a trained network. -``` - -``` - -## To reproduce 10 baseline NAS algorithms in AA-NAS-Bench +## To Reproduce 10 Baseline NAS Algorithms in AA-NAS-Bench We have tried our best to implement each method. However, still, some algorithms might obtain non-optimal results since their hyper-parameters might not fit our AA-NAS-Bench. If researchers can provide better results with different hyper-parameters, we are happy to update results according to the new experimental results. We also welcome more NAS algorithms to test on our dataset and would include them accordingly. -- `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V1.sh cifar10 -1` -- `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 -1` -- `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 -1` -- `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/SETN.sh cifar10 -1` -- `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/ENAS.sh cifar10 -1` -- `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/RANDOM-NAS.sh cifar10 -1` +-[1] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V1.sh cifar10 -1` +-[2] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 -1` +-[3] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 -1` +-[4] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/SETN.sh cifar10 -1` +-[5] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/ENAS.sh cifar10 -1` +-[6] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/RANDOM-NAS.sh cifar10 -1` +-[7] `bash ./scripts-search/algos/R-EA.sh -1` +-[8] `bash ./scripts-search/algos/Random.sh -1` +-[9] `bash ./scripts-search/algos/REINFORCE.sh -1` +-[10] `bash ./scripts-search/algos/BOHB.sh -1` diff --git a/exps/AA-NAS-statistics.py b/exps/AA-NAS-statistics.py index 855ea40..a08a9ad 100644 --- a/exps/AA-NAS-statistics.py +++ b/exps/AA-NAS-statistics.py @@ -1,6 +1,3 @@ -################################################## -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # -################################################## import os, sys, time, argparse, collections from copy import deepcopy import torch @@ -167,7 +164,6 @@ def simplify(save_dir, meta_file, basestr, target_dir): arch_time = AverageMeter() for idx, arch_index in enumerate(arch_indexes): checkpoints = list(target_directory.glob('arch-{:}-seed-*.pth'.format(arch_index))) - arch_info = account_one_arch(arch_index, meta_archs[int(arch_index)], checkpoints, datasets, dataloader_dict) try: arch_info = account_one_arch(arch_index, meta_archs[int(arch_index)], checkpoints, datasets, dataloader_dict) num_seeds[ len(checkpoints) ] += 1 @@ -181,7 +177,7 @@ def simplify(save_dir, meta_file, basestr, target_dir): torch.save(arch_info.state_dict(), to_save_allarc / '{:}-FULL.pth'.format(arch_index)) #torch.save(arch_info, to_save_allarc / '{:}-FULL.pth'.format(arch_index)) arch_info.clear_params() - torch.save(arch_info, to_save_allarc / '{:}-SIMPLE.pth'.format(arch_index)) + torch.save(arch_info.state_dict(), to_save_allarc / '{:}-SIMPLE.pth'.format(arch_index)) # measure elapsed time arch_time.update(time.time() - end_time) end_time = time.time() @@ -241,7 +237,7 @@ def merge_all(save_dir, meta_file, basestr): xevalindexs = sub_ckps['evaluated_indexes'] for eval_index in xevalindexs: assert eval_index not in evaluated_indexes and eval_index not in arch2infos - arch2infos[eval_index] = xarch2infos[eval_index] + arch2infos[eval_index] = xarch2infos[eval_index].state_dict() evaluated_indexes.add( eval_index ) print ('{:} [{:03d}/{:03d}] merge data from {:} with {:} models.'.format(time_string(), IDX, len(subdir2archs), ckp_path, len(xevalindexs))) else: diff --git a/exps/AA-NAS-test-API.py b/exps/AA-NAS-test-API.py index 6967ba8..e5cd03f 100644 --- a/exps/AA-NAS-test-API.py +++ b/exps/AA-NAS-test-API.py @@ -58,8 +58,10 @@ def test_aa_nas_api(): arch_result = ArchResults.create_from_state_dict('output/AA-NAS-BENCH-4/simplifies/architectures/000002-FULL.pth') arch_result.show(True) result = arch_result.query('cifar100') - #xfile = '/home/dxy/search-configures/output/TINY-NAS-BENCHMARK-4/simplifies/C16-N5-final-infos.pth' - #api = AANASBenchAPI(xfile) + #xfile = 'output/AA-NAS-BENCH-4/simplifies/000000-000389-C16-N5.pth' + api = AANASBenchAPI('output/AA-NAS-BENCH-4/simplifies/C16-N5-final-infos.pth') + results = api.query_by_index(1, 'cifar100') + print ('There are {:} trials for this architecture [{:}] on cifar10'.format(len(results), api[1])) import pdb; pdb.set_trace() if __name__ == '__main__': diff --git a/exps/algos/BOHB.py b/exps/algos/BOHB.py new file mode 100644 index 0000000..aa0e253 --- /dev/null +++ b/exps/algos/BOHB.py @@ -0,0 +1,177 @@ +################################################## +# required to install hpbandster ################# +################################################## +import os, sys, time, glob, random, argparse +import numpy as np, collections +from copy import deepcopy +from pathlib import Path +import torch +import torch.nn as nn +from torch.distributions import Categorical +lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() +if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +from config_utils import load_config, dict2config, configure2str +from datasets import get_datasets, SearchDataset +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from utils import get_model_infos, obtain_accuracy +from log_utils import AverageMeter, time_string, convert_secs2time +from aa_nas_api import AANASBenchAPI +from models import CellStructure, get_search_spaces +from R_EA import train_and_eval +# BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018 +import ConfigSpace +from hpbandster.optimizers.bohb import BOHB +import hpbandster.core.nameserver as hpns +from hpbandster.core.worker import Worker + + +def get_configuration_space(max_nodes, search_space): + cs = ConfigSpace.ConfigurationSpace() + #edge2index = {} + for i in range(1, max_nodes): + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space)) + return cs + + +def config2structure_func(max_nodes): + def config2structure(config): + genotypes = [] + for i in range(1, max_nodes): + xlist = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + op_name = config[node_str] + xlist.append((op_name, j)) + genotypes.append( tuple(xlist) ) + return CellStructure( genotypes ) + return config2structure + + +class MyWorker(Worker): + + def __init__(self, *args, sleep_interval=0, convert_func=None, nas_bench=None, **kwargs): + super().__init__(*args, **kwargs) + self.sleep_interval = sleep_interval + self.convert_func = convert_func + self.nas_bench = nas_bench + self.test_time = 0 + + def compute(self, config, budget, **kwargs): + structure = self.convert_func( config ) + reward = train_and_eval(structure, self.nas_bench, None) + self.test_time += 1 + return ({ + 'loss': float(100-reward), + 'info': None}) + + +def main(xargs): + assert torch.cuda.is_available(), 'CUDA is not available.' + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.set_num_threads( xargs.workers ) + prepare_seed(xargs.rand_seed) + logger = prepare_logger(args) + + assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' + train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + split_Fpath = 'configs/nas-benchmark/cifar-split.txt' + cifar_split = load_config(split_Fpath, None, None) + train_split, valid_split = cifar_split.train, cifar_split.valid + logger.log('Load split file from {:}'.format(split_Fpath)) + config_path = 'configs/nas-benchmark/algos/R-EA.config' + config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) + # To split data + train_data_v2 = deepcopy(train_data) + train_data_v2.transform = valid_data.transform + valid_data = train_data_v2 + search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split) + # data loader + train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True) + valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) + logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) + logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) + extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader} + + + # nas dataset load + assert xargs.arch_nas_dataset is not None and os.path.isfile(xargs.arch_nas_dataset) + search_space = get_search_spaces('cell', xargs.search_space_name) + cs = get_configuration_space(xargs.max_nodes, search_space) + + config2structure = config2structure_func(xargs.max_nodes) + hb_run_id = '0' + + NS = hpns.NameServer(run_id=hb_run_id, host='localhost', port=0) + ns_host, ns_port = NS.start() + num_workers = 1 + + nas_bench = AANASBenchAPI(xargs.arch_nas_dataset) + logger.log('{:} Create AA-NAS-BENCH-API DONE'.format(time_string())) + workers = [] + for i in range(num_workers): + w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, run_id=hb_run_id, id=i) + w.run(background=True) + workers.append(w) + + bohb = BOHB(configspace=cs, + run_id=hb_run_id, + eta=3, min_budget=3, max_budget=108, + nameserver=ns_host, + nameserver_port=ns_port, + num_samples=xargs.num_samples, + random_fraction=xargs.random_fraction, bandwidth_factor=xargs.bandwidth_factor, + ping_interval=10, min_bandwidth=xargs.min_bandwidth) + # optimization_strategy=xargs.strategy, num_samples=xargs.num_samples, + + results = bohb.run(xargs.n_iters, min_n_workers=num_workers) + + bohb.shutdown(shutdown_workers=True) + NS.shutdown() + + id2config = results.get_id2config_mapping() + incumbent = results.get_incumbent_id() + + logger.log('Best found configuration: {:}'.format(id2config[incumbent]['config'])) + best_arch = config2structure( id2config[incumbent]['config'] ) + + if nas_bench is not None: + info = nas_bench.query_by_arch( best_arch ) + if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) + else : logger.log('{:}'.format(info)) + logger.log('-'*100) + + logger.log('workers : {:}'.format(workers[0].test_time)) + + logger.close() + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser("Regularized Evolution Algorithm") + parser.add_argument('--data_path', type=str, help='Path to dataset') + parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') + # channels and number-of-cells + parser.add_argument('--search_space_name', type=str, help='The search space name.') + parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.') + parser.add_argument('--channel', type=int, help='The number of channels.') + parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.') + # BOHB + parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function') + parser.add_argument('--min_bandwidth', default=.3, type=float, nargs='?', help='minimum bandwidth for KDE') + parser.add_argument('--num_samples', default=64, type=int, nargs='?', help='number of samples for the acquisition function') + parser.add_argument('--random_fraction', default=.33, type=float, nargs='?', help='fraction of random configurations') + parser.add_argument('--bandwidth_factor', default=3, type=int, nargs='?', help='factor multiplied to the bandwidth') + parser.add_argument('--n_iters', default=100, type=int, nargs='?', help='number of iterations for optimization method') + # log + parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)') + parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.') + parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).') + parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)') + parser.add_argument('--rand_seed', type=int, help='manual seed') + args = parser.parse_args() + if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) + main(args) diff --git a/exps/algos/RANDOM.py b/exps/algos/RANDOM.py new file mode 100644 index 0000000..f7cb6f7 --- /dev/null +++ b/exps/algos/RANDOM.py @@ -0,0 +1,95 @@ +import os, sys, time, glob, random, argparse +import numpy as np, collections +from copy import deepcopy +import torch +import torch.nn as nn +from pathlib import Path +lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() +if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +from config_utils import load_config, dict2config, configure2str +from datasets import get_datasets, SearchDataset +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from utils import get_model_infos, obtain_accuracy +from log_utils import AverageMeter, time_string, convert_secs2time +from models import get_search_spaces +from aa_nas_api import AANASBenchAPI +from R_EA import train_and_eval, random_architecture_func + + +def main(xargs): + assert torch.cuda.is_available(), 'CUDA is not available.' + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.set_num_threads( xargs.workers ) + prepare_seed(xargs.rand_seed) + logger = prepare_logger(args) + + assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' + train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + split_Fpath = 'configs/nas-benchmark/cifar-split.txt' + cifar_split = load_config(split_Fpath, None, None) + train_split, valid_split = cifar_split.train, cifar_split.valid + logger.log('Load split file from {:}'.format(split_Fpath)) + config_path = 'configs/nas-benchmark/algos/R-EA.config' + config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) + # To split data + train_data_v2 = deepcopy(train_data) + train_data_v2.transform = valid_data.transform + valid_data = train_data_v2 + search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split) + # data loader + train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True) + valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) + logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) + logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) + extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader} + + search_space = get_search_spaces('cell', xargs.search_space_name) + random_arch = random_architecture_func(xargs.max_nodes, search_space) + #x =random_arch() ; y = mutate_arch(x) + if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset): + logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset)) + nas_bench = None + else: + logger.log('{:} build NAS-Benchmark-API from {:}'.format(time_string(), xargs.arch_nas_dataset)) + nas_bench = AANASBenchAPI(xargs.arch_nas_dataset) + logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench)) + best_arch, best_acc = None, -1 + for idx in range(xargs.random_num): + arch = random_arch() + accuracy = train_and_eval(arch, nas_bench, extra_info) + if best_arch is None or best_acc < accuracy: + best_acc, best_arch = accuracy, arch + logger.log('[{:03d}/{:03d}] : {:} : accuracy = {:.2f}%'.format(idx, xargs.random_num, arch, accuracy)) + logger.log('{:} best arch is {:}, accuracy = {:.2f}%'.format(time_string(), best_arch, best_acc)) + + if nas_bench is not None: + info = nas_bench.query_by_arch( best_arch ) + if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) + else : logger.log('{:}'.format(info)) + logger.log('-'*100) + + logger.close() + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser("Regularized Evolution Algorithm") + parser.add_argument('--data_path', type=str, help='Path to dataset') + parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') + # channels and number-of-cells + parser.add_argument('--search_space_name', type=str, help='The search space name.') + parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.') + parser.add_argument('--channel', type=int, help='The number of channels.') + parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.') + parser.add_argument('--random_num', type=int, help='The number of random selected architectures.') + # log + parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)') + parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.') + parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).') + parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)') + parser.add_argument('--rand_seed', type=int, help='manual seed') + args = parser.parse_args() + if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) + main(args) diff --git a/exps/algos/R_EA.py b/exps/algos/R_EA.py new file mode 100644 index 0000000..c3efecc --- /dev/null +++ b/exps/algos/R_EA.py @@ -0,0 +1,230 @@ +import os, sys, time, glob, random, argparse +import numpy as np, collections +from copy import deepcopy +import torch +import torch.nn as nn +from pathlib import Path +lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() +if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +from config_utils import load_config, dict2config, configure2str +from datasets import get_datasets, SearchDataset +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from utils import get_model_infos, obtain_accuracy +from log_utils import AverageMeter, time_string, convert_secs2time +from aa_nas_api import AANASBenchAPI +from models import CellStructure, get_search_spaces + + +# Regularized Evolution for Image Classifier Architecture Search +class Model(object): + + def __init__(self): + self.arch = None + self.accuracy = None + + def __str__(self): + """Prints a readable version of this bitstring.""" + return '{:}'.format(self.arch) + + +def valid_func(xloader, network, criterion): + data_time, batch_time = AverageMeter(), AverageMeter() + arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() + network.train() + end = time.time() + with torch.no_grad(): + for step, (arch_inputs, arch_targets) in enumerate(xloader): + arch_targets = arch_targets.cuda(non_blocking=True) + # measure data loading time + data_time.update(time.time() - end) + # prediction + _, logits = network(arch_inputs) + arch_loss = criterion(logits, arch_targets) + # record + arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) + arch_losses.update(arch_loss.item(), arch_inputs.size(0)) + arch_top1.update (arch_prec1.item(), arch_inputs.size(0)) + arch_top5.update (arch_prec5.item(), arch_inputs.size(0)) + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + return arch_losses.avg, arch_top1.avg, arch_top5.avg + + +def train_and_eval(arch, nas_bench, extra_info): + if nas_bench is not None: + arch_index = nas_bench.query_index_by_arch( arch ) + assert arch_index >= 0, 'can not find this arch : {:}'.format(arch) + info = nas_bench.arch2infos[ arch_index ] + _, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25) # use the validation accuracy after 25 training epochs + else: + # train a model from scratch. + raise ValueError('NOT IMPLEMENT YET') + return valid_acc + + +def random_architecture_func(max_nodes, op_names): + # return a random architecture + def random_architecture(): + genotypes = [] + for i in range(1, max_nodes): + xlist = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + op_name = random.choice( op_names ) + xlist.append((op_name, j)) + genotypes.append( tuple(xlist) ) + return CellStructure( genotypes ) + return random_architecture + + +def mutate_arch_func(op_names): + """Computes the architecture for a child of the given parent architecture. + The parent architecture is cloned and mutated to produce the child architecture. The child architecture is mutated by randomly switch one operation to another. + """ + def mutate_arch_func(parent_arch): + child_arch = deepcopy( parent_arch ) + node_id = random.randint(0, len(child_arch.nodes)-1) + node_info = list( child_arch.nodes[node_id] ) + snode_id = random.randint(0, len(node_info)-1) + xop = random.choice( op_names ) + while xop == node_info[snode_id][0]: + xop = random.choice( op_names ) + node_info[snode_id] = (xop, node_info[snode_id][1]) + child_arch.nodes[node_id] = tuple( node_info ) + return child_arch + return mutate_arch_func + + +def regularized_evolution(cycles, population_size, sample_size, random_arch, mutate_arch, nas_bench, extra_info): + """Algorithm for regularized evolution (i.e. aging evolution). + + Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image + Classifier Architecture Search". + + Args: + cycles: the number of cycles the algorithm should run for. + population_size: the number of individuals to keep in the population. + sample_size: the number of individuals that should participate in each tournament. + + Returns: + history: a list of `Model` instances, representing all the models computed + during the evolution experiment. + """ + population = collections.deque() + history = [] # Not used by the algorithm, only used to report results. + + # Initialize the population with random models. + while len(population) < population_size: + model = Model() + model.arch = random_arch() + model.accuracy = train_and_eval(model.arch, nas_bench, extra_info) + population.append(model) + history.append(model) + + # Carry out evolution in cycles. Each cycle produces a model and removes + # another. + while len(history) < cycles: + # Sample randomly chosen models from the current population. + sample = [] + while len(sample) < sample_size: + # Inefficient, but written this way for clarity. In the case of neural + # nets, the efficiency of this line is irrelevant because training neural + # nets is the rate-determining step. + candidate = random.choice(list(population)) + sample.append(candidate) + + # The parent is the best model in the sample. + parent = max(sample, key=lambda i: i.accuracy) + + # Create the child model and store it. + child = Model() + child.arch = mutate_arch(parent.arch) + child.accuracy = train_and_eval(child.arch, nas_bench, extra_info) + population.append(child) + history.append(child) + + # Remove the oldest model. + population.popleft() + return history + + +def main(xargs): + assert torch.cuda.is_available(), 'CUDA is not available.' + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.set_num_threads( xargs.workers ) + prepare_seed(xargs.rand_seed) + logger = prepare_logger(args) + + assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' + train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + split_Fpath = 'configs/nas-benchmark/cifar-split.txt' + cifar_split = load_config(split_Fpath, None, None) + train_split, valid_split = cifar_split.train, cifar_split.valid + logger.log('Load split file from {:}'.format(split_Fpath)) + config_path = 'configs/nas-benchmark/algos/R-EA.config' + config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) + # To split data + train_data_v2 = deepcopy(train_data) + train_data_v2.transform = valid_data.transform + valid_data = train_data_v2 + search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split) + # data loader + train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True) + valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) + logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) + logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) + extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader} + + search_space = get_search_spaces('cell', xargs.search_space_name) + random_arch = random_architecture_func(xargs.max_nodes, search_space) + mutate_arch = mutate_arch_func(search_space) + #x =random_arch() ; y = mutate_arch(x) + if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset): + logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset)) + nas_bench = None + else: + logger.log('{:} build NAS-Benchmark-API from {:}'.format(time_string(), xargs.arch_nas_dataset)) + nas_bench = AANASBenchAPI(xargs.arch_nas_dataset) + logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench)) + history = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, random_arch, mutate_arch, nas_bench if args.ea_fast_by_api else None, extra_info) + logger.log('{:} regularized_evolution finish with history of {:} arch.'.format(time_string(), len(history))) + best_arch = max(history, key=lambda i: i.accuracy) + best_arch = best_arch.arch + logger.log('{:} best arch is {:}'.format(time_string(), best_arch)) + + if nas_bench is not None: + info = nas_bench.query_by_arch( best_arch ) + if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) + else : logger.log('{:}'.format(info)) + logger.log('-'*100) + + logger.close() + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser("Regularized Evolution Algorithm") + parser.add_argument('--data_path', type=str, help='Path to dataset') + parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') + # channels and number-of-cells + parser.add_argument('--search_space_name', type=str, help='The search space name.') + parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.') + parser.add_argument('--channel', type=int, help='The number of channels.') + parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.') + parser.add_argument('--ea_cycles', type=int, help='The number of cycles in EA.') + parser.add_argument('--ea_population', type=int, help='The population size in EA.') + parser.add_argument('--ea_sample_size', type=int, help='The sample size in EA.') + parser.add_argument('--ea_fast_by_api', type=int, help='Use our API to speed up the experiments or not.') + # log + parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)') + parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.') + parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).') + parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)') + parser.add_argument('--rand_seed', type=int, help='manual seed') + args = parser.parse_args() + if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) + args.ea_fast_by_api = args.ea_fast_by_api > 0 + main(args) diff --git a/exps/algos/reinforce.py b/exps/algos/reinforce.py new file mode 100644 index 0000000..94b9735 --- /dev/null +++ b/exps/algos/reinforce.py @@ -0,0 +1,187 @@ +################################################## +# modified from https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py +################################################## +import os, sys, time, glob, random, argparse +import numpy as np, collections +from copy import deepcopy +from pathlib import Path +import torch +import torch.nn as nn +from torch.distributions import Categorical +lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() +if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +from config_utils import load_config, dict2config, configure2str +from datasets import get_datasets, SearchDataset +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from utils import get_model_infos, obtain_accuracy +from log_utils import AverageMeter, time_string, convert_secs2time +from aa_nas_api import AANASBenchAPI +from models import CellStructure, get_search_spaces +from R_EA import train_and_eval + + +class Policy(nn.Module): + + def __init__(self, max_nodes, search_space): + super(Policy, self).__init__() + self.max_nodes = max_nodes + self.search_space = deepcopy(search_space) + self.edge2index = {} + for i in range(1, max_nodes): + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + self.edge2index[ node_str ] = len(self.edge2index) + self.arch_parameters = nn.Parameter( 1e-3*torch.randn(len(self.edge2index), len(search_space)) ) + + def generate_arch(self, actions): + genotypes = [] + for i in range(1, self.max_nodes): + xlist = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + op_name = self.search_space[ actions[ self.edge2index[ node_str ] ] ] + xlist.append((op_name, j)) + genotypes.append( tuple(xlist) ) + return CellStructure( genotypes ) + + def genotype(self): + genotypes = [] + for i in range(1, self.max_nodes): + xlist = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + with torch.no_grad(): + weights = self.arch_parameters[ self.edge2index[node_str] ] + op_name = self.search_space[ weights.argmax().item() ] + xlist.append((op_name, j)) + genotypes.append( tuple(xlist) ) + return CellStructure( genotypes ) + + def forward(self): + alphas = nn.functional.softmax(self.arch_parameters, dim=-1) + return alphas + + +class ExponentialMovingAverage(object): + """Class that maintains an exponential moving average.""" + + def __init__(self, momentum): + self._numerator = 0 + self._denominator = 0 + self._momentum = momentum + + def update(self, value): + self._numerator = self._momentum * self._numerator + (1 - self._momentum) * value + self._denominator = self._momentum * self._denominator + (1 - self._momentum) + + def value(self): + """Return the current value of the moving average""" + return self._numerator / self._denominator + + +def select_action(policy): + probs = policy() + m = Categorical(probs) + action = m.sample() + #policy.saved_log_probs.append(m.log_prob(action)) + return m.log_prob(action), action.cpu().tolist() + + +def main(xargs): + assert torch.cuda.is_available(), 'CUDA is not available.' + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.set_num_threads( xargs.workers ) + prepare_seed(xargs.rand_seed) + logger = prepare_logger(args) + + assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' + train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + split_Fpath = 'configs/nas-benchmark/cifar-split.txt' + cifar_split = load_config(split_Fpath, None, None) + train_split, valid_split = cifar_split.train, cifar_split.valid + logger.log('Load split file from {:}'.format(split_Fpath)) + config_path = 'configs/nas-benchmark/algos/R-EA.config' + config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) + # To split data + train_data_v2 = deepcopy(train_data) + train_data_v2.transform = valid_data.transform + valid_data = train_data_v2 + search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split) + # data loader + train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True) + valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) + logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) + logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) + extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader} + + search_space = get_search_spaces('cell', xargs.search_space_name) + policy = Policy(xargs.max_nodes, search_space) + optimizer = torch.optim.Adam(policy.parameters(), lr=xargs.learning_rate) + eps = np.finfo(np.float32).eps.item() + baseline = ExponentialMovingAverage(xargs.EMA_momentum) + logger.log('policy : {:}'.format(policy)) + logger.log('optimizer : {:}'.format(optimizer)) + logger.log('eps : {:}'.format(eps)) + + # nas dataset load + if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset): + logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset)) + nas_bench = None + else: + logger.log('{:} build NAS-Benchmark-API from {:}'.format(time_string(), xargs.arch_nas_dataset)) + nas_bench = AANASBenchAPI(xargs.arch_nas_dataset) + logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench)) + + # REINFORCE + # attempts = 0 + for istep in range(xargs.RL_steps): + log_prob, action = select_action( policy ) + arch = policy.generate_arch( action ) + reward = train_and_eval(arch, nas_bench, extra_info) + + baseline.update(reward) + # calculate loss + policy_loss = ( -log_prob * (reward - baseline.value()) ).sum() + optimizer.zero_grad() + policy_loss.backward() + optimizer.step() + + logger.log('step [{:3d}/{:3d}] : average-reward={:.3f} : policy_loss={:.4f} : {:}'.format(istep, xargs.RL_steps, baseline.value(), policy_loss.item(), policy.genotype())) + #logger.log('----> {:}'.format(policy.arch_parameters)) + logger.log('') + + best_arch = policy.genotype() + + if nas_bench is not None: + info = nas_bench.query_by_arch( best_arch ) + if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) + else : logger.log('{:}'.format(info)) + logger.log('-'*100) + + logger.close() + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser("Regularized Evolution Algorithm") + parser.add_argument('--data_path', type=str, help='Path to dataset') + parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') + # channels and number-of-cells + parser.add_argument('--search_space_name', type=str, help='The search space name.') + parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.') + parser.add_argument('--channel', type=int, help='The number of channels.') + parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.') + parser.add_argument('--learning_rate', type=float, help='The learning rate for REINFORCE.') + parser.add_argument('--RL_steps', type=int, help='The steps for REINFORCE.') + parser.add_argument('--EMA_momentum', type=float, help='The momentum value for EMA.') + # log + parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)') + parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.') + parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).') + parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)') + parser.add_argument('--rand_seed', type=int, help='manual seed') + args = parser.parse_args() + if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) + main(args) diff --git a/lib/aa_nas_api/api.py b/lib/aa_nas_api/api.py index a7f60b1..01c4b84 100644 --- a/lib/aa_nas_api/api.py +++ b/lib/aa_nas_api/api.py @@ -1,5 +1,5 @@ import os, sys, copy, torch, numpy as np - +from collections import OrderedDict def print_information(information, extra_info=None, show=False): @@ -29,20 +29,26 @@ def print_information(information, extra_info=None, show=False): class AANASBenchAPI(object): - def __init__(self, file_path_or_dict): + def __init__(self, file_path_or_dict, verbose=True): if isinstance(file_path_or_dict, str): + if verbose: print('try to create AA-NAS-Bench api from {:}'.format(file_path_or_dict)) assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict) file_path_or_dict = torch.load(file_path_or_dict) + else: + file_path_or_dict = copy.deepcopy( file_path_or_dict ) assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict)) keys = ('meta_archs', 'arch2infos', 'evaluated_indexes') for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key) self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] ) - self.arch2infos = copy.deepcopy( file_path_or_dict['arch2infos'] ) - self.evaluated_indexes = sorted(list( copy.deepcopy( file_path_or_dict['evaluated_indexes'] ) )) + self.arch2infos = OrderedDict() + for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())): + self.arch2infos[xkey] = ArchResults.create_from_state_dict( file_path_or_dict['arch2infos'][xkey] ) + self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes'])) self.archstr2index = {} for idx, arch in enumerate(self.meta_archs): - assert arch.tostr() not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch.tostr()]) - self.archstr2index[ arch.tostr() ] = idx + #assert arch.tostr() not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch.tostr()]) + assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch]) + self.archstr2index[ arch ] = idx def __getitem__(self, index): return copy.deepcopy( self.meta_archs[index] ) @@ -54,12 +60,12 @@ class AANASBenchAPI(object): return ('{name}({num}/{total} architectures)'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs))) def query_index_by_arch(self, arch): - if arch.tostr() in self.archstr2index: - arch_index = self.archstr2index[ arch.tostr() ] - #else: - # arch_str = Structure.str2fullstructure( arch.tostr() ).tostr() - # if arch_str in self.archstr2index: - # arch_index = self.archstr2index[ arch_str ] + if isinstance(arch, str): + if arch in self.archstr2index: arch_index = self.archstr2index[ arch ] + else : arch_index = -1 + elif hasattr(arch, 'tostr'): + if arch.tostr() in self.archstr2index: arch_index = self.archstr2index[ arch.tostr() ] + else : arch_index = -1 else: arch_index = -1 return arch_index @@ -80,6 +86,11 @@ class AANASBenchAPI(object): info = archInfo.query(dataname) return info + def query_meta_info_by_index(self, arch_index): + assert arch_index in self.arch2infos, 'arch_index [{:}] does not in arch2info'.format(arch_index) + archInfo = copy.deepcopy( self.arch2infos[ arch_index ] ) + return archInfo + def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None): best_index, highest_accuracy = -1, None for i, idx in enumerate(self.evaluated_indexes): diff --git a/scripts-search/algos/BOHB.sh b/scripts-search/algos/BOHB.sh new file mode 100644 index 0000000..31a1c91 --- /dev/null +++ b/scripts-search/algos/BOHB.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# bash ./scripts-search/algos/BOHB.sh -1 +echo script name: $0 +echo $# arguments +if [ "$#" -ne 1 ] ;then + echo "Input illegal number of parameters " $# + echo "Need 1 parameters for seed" + exit 1 +fi +if [ "$TORCH_HOME" = "" ]; then + echo "Must set TORCH_HOME envoriment variable for data dir saving" + exit 1 +else + echo "TORCH_HOME : $TORCH_HOME" +fi + +dataset=cifar10 +seed=$1 +channel=16 +num_cells=5 +max_nodes=4 + +if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then + data_path="$TORCH_HOME/cifar.python" +else + data_path="$TORCH_HOME/cifar.python/ImageNet16" +fi + +save_dir=./output/cell-search-tiny/BOHB-${dataset} + +OMP_NUM_THREADS=4 python ./exps/algos/BOHB.py \ + --save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ + --dataset ${dataset} --data_path ${data_path} \ + --search_space_name aa-nas \ + --arch_nas_dataset ./output/AA-NAS-BENCH-4/simplifies/C16-N5-final-infos.pth \ + --n_iters 6 --num_samples 3 \ + --workers 4 --print_freq 200 --rand_seed ${seed} diff --git a/scripts-search/algos/R-EA.sh b/scripts-search/algos/R-EA.sh new file mode 100644 index 0000000..078d0c7 --- /dev/null +++ b/scripts-search/algos/R-EA.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# Regularized Evolution for Image Classifier Architecture Search, AAAI 2019 +# bash ./scripts-search/algos/R-EA.sh -1 +echo script name: $0 +echo $# arguments +if [ "$#" -ne 1 ] ;then + echo "Input illegal number of parameters " $# + echo "Need 1 parameters for seed" + exit 1 +fi +if [ "$TORCH_HOME" = "" ]; then + echo "Must set TORCH_HOME envoriment variable for data dir saving" + exit 1 +else + echo "TORCH_HOME : $TORCH_HOME" +fi + +dataset=cifar10 +seed=$1 +channel=16 +num_cells=5 +max_nodes=4 + +if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then + data_path="$TORCH_HOME/cifar.python" +else + data_path="$TORCH_HOME/cifar.python/ImageNet16" +fi + +save_dir=./output/cell-search-tiny/R-EA-${dataset} + +OMP_NUM_THREADS=4 python ./exps/algos/R_EA.py \ + --save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ + --dataset ${dataset} --data_path ${data_path} \ + --search_space_name aa-nas \ + --arch_nas_dataset ./output/AA-NAS-BENCH-4/simplifies/C16-N5-final-infos.pth \ + --ea_cycles 30 --ea_population 10 --ea_sample_size 3 --ea_fast_by_api 1 \ + --workers 4 --print_freq 200 --rand_seed ${seed} diff --git a/scripts-search/algos/REINFORCE.sh b/scripts-search/algos/REINFORCE.sh new file mode 100644 index 0000000..f77f8ba --- /dev/null +++ b/scripts-search/algos/REINFORCE.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# bash ./scripts-search/algos/REINFORCE.sh -1 +echo script name: $0 +echo $# arguments +if [ "$#" -ne 1 ] ;then + echo "Input illegal number of parameters " $# + echo "Need 1 parameters for seed" + exit 1 +fi +if [ "$TORCH_HOME" = "" ]; then + echo "Must set TORCH_HOME envoriment variable for data dir saving" + exit 1 +else + echo "TORCH_HOME : $TORCH_HOME" +fi + +dataset=cifar10 +seed=$1 +channel=16 +num_cells=5 +max_nodes=4 + +if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then + data_path="$TORCH_HOME/cifar.python" +else + data_path="$TORCH_HOME/cifar.python/ImageNet16" +fi + +save_dir=./output/cell-search-tiny/REINFORCE-${dataset} + +OMP_NUM_THREADS=4 python ./exps/algos/reinforce.py \ + --save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ + --dataset ${dataset} --data_path ${data_path} \ + --search_space_name aa-nas \ + --arch_nas_dataset ./output/AA-NAS-BENCH-4/simplifies/C16-N5-final-infos.pth \ + --learning_rate 0.001 --RL_steps 100 --EMA_momentum 0.9 \ + --workers 4 --print_freq 200 --rand_seed ${seed} diff --git a/scripts-search/algos/Random.sh b/scripts-search/algos/Random.sh new file mode 100644 index 0000000..2945b08 --- /dev/null +++ b/scripts-search/algos/Random.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# bash ./scripts-search/algos/Random.sh -1 +echo script name: $0 +echo $# arguments +if [ "$#" -ne 1 ] ;then + echo "Input illegal number of parameters " $# + echo "Need 1 parameters for seed" + exit 1 +fi +if [ "$TORCH_HOME" = "" ]; then + echo "Must set TORCH_HOME envoriment variable for data dir saving" + exit 1 +else + echo "TORCH_HOME : $TORCH_HOME" +fi + +dataset=cifar10 +seed=$1 +channel=16 +num_cells=5 +max_nodes=4 + +if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then + data_path="$TORCH_HOME/cifar.python" +else + data_path="$TORCH_HOME/cifar.python/ImageNet16" +fi + +save_dir=./output/cell-search-tiny/RAND-${dataset} + +OMP_NUM_THREADS=4 python ./exps/algos/RANDOM.py \ + --save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ + --dataset ${dataset} --data_path ${data_path} \ + --search_space_name aa-nas \ + --arch_nas_dataset ./output/AA-NAS-BENCH-4/simplifies/C16-N5-final-infos.pth \ + --random_num 100 \ + --workers 4 --print_freq 200 --rand_seed ${seed}