From 2c86d6aa67adc70b09574d5c66dc3b1e8d4c0c74 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Fri, 28 Aug 2020 08:31:53 +0000 Subject: [PATCH] Update NATS-Bench (tss version 0.8) --- exps/NATS-Bench/main-sss.py | 26 +--- exps/NATS-Bench/main-tss.py | 195 ++++++++++++++------------- lib/utils/__init__.py | 1 + lib/utils/str_utils.py | 18 +++ scripts/NATS-Bench/train-topology.sh | 43 ++++++ 5 files changed, 172 insertions(+), 111 deletions(-) create mode 100644 lib/utils/str_utils.py create mode 100644 scripts/NATS-Bench/train-topology.sh diff --git a/exps/NATS-Bench/main-sss.py b/exps/NATS-Bench/main-sss.py index 70fdbb9..8a2badb 100644 --- a/exps/NATS-Bench/main-sss.py +++ b/exps/NATS-Bench/main-sss.py @@ -27,6 +27,7 @@ from procedures import bench_evaluate_for_seed from procedures import get_machine_info from datasets import get_datasets from log_utils import Logger, AverageMeter, time_string, convert_secs2time +from utils import split_str2indexes def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Text], @@ -107,7 +108,6 @@ def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text], logger.log('xargs : seeds = {:}'.format(seeds)) logger.log('xargs : cover_mode = {:}'.format(cover_mode)) logger.log('-' * 100) - logger.log( 'Start evaluating range =: {:06d} - {:06d}'.format(min(to_evaluate_indexes), max(to_evaluate_indexes)) +'({:} in total) / {:06d} with cover-mode={:}'.format(len(to_evaluate_indexes), len(nets), cover_mode)) @@ -115,7 +115,6 @@ def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text], logger.log( '--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}'.format(i, len(datasets), dataset, xpath, split)) logger.log('--->>> optimization config : {:}'.format(opt_config)) - #to_evaluate_indexes = list(range(srange[0], srange[1] + 1)) start_time, epoch_time = time.time(), AverageMeter() for i, index in enumerate(to_evaluate_indexes): @@ -136,10 +135,12 @@ def 
main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text], logger.log('Find existing file : {:}, skip this evaluation'.format(to_save_name)) has_continue = True continue - results = evaluate_all_datasets(channelstr, datasets, xpaths, splits, opt_config, seed, workers, logger) + results = evaluate_all_datasets(channelstr, + datasets, xpaths, splits, opt_config, seed, + workers, logger) torch.save(results, to_save_name) - logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] ===>>> {:}'.format(time_string(), i, - len(to_evaluate_indexes), index, len(nets), seeds, to_save_name)) + logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] ===>>> {:}'.format(time_string(), i, + len(to_evaluate_indexes), index, len(nets), seeds, to_save_name)) # measure elapsed time if not has_continue: epoch_time.update(time.time() - start_time) start_time = time.time() @@ -224,20 +225,7 @@ if __name__ == '__main__': raise ValueError('{:} is not a file.'.format(opt_config)) save_dir = Path(args.save_dir) / 'raw-data-{:}'.format(args.hyper) save_dir.mkdir(parents=True, exist_ok=True) - if not isinstance(args.srange, str): - raise ValueError('Invalid scheme for {:}'.format(args.srange)) - srangestr = "".join(args.srange.split()) - to_evaluate_indexes = set() - for srange in srangestr.split(','): - srange = srange.split('-') - if len(srange) != 2: - raise ValueError('invalid srange : {:}'.format(srange)) - assert len(srange[0]) == len(srange[1]) == 5, 'invalid srange : {:}'.format(srange) - srange = (int(srange[0]), int(srange[1])) - if not (0 <= srange[0] <= srange[1] < args.check_N): - raise ValueError('{:} vs {:} vs {:}'.format(srange[0], srange[1], args.check_N)) - for i in range(srange[0], srange[1]+1): - to_evaluate_indexes.add(i) + to_evaluate_indexes = split_str2indexes(args.srange, args.check_N, 5) if not len(args.seeds): raise ValueError('invalid length of seeds args: {:}'.format(args.seeds)) diff --git 
a/exps/NATS-Bench/main-tss.py b/exps/NATS-Bench/main-tss.py index 24dfac7..b862006 100644 --- a/exps/NATS-Bench/main-tss.py +++ b/exps/NATS-Bench/main-tss.py @@ -5,13 +5,20 @@ ############################################################################## # This file is used to train (all) architecture candidate in the topology # # search space in NATS-Bench (tss) with different hyper-parameters. # -# When use mode=meta, -### +# When using mode=new, it automatically detects whether the checkpoint of # +# a trial exists and, if so, skips this trial. When using mode=cover, it # +# ignores any (possibly) existing checkpoint, runs each trial, and saves. # ############################################################################## -# 1, generate meta data: # +# Please use the script scripts/NATS-Bench/train-topology.sh to run. # +# bash scripts/NATS-Bench/train-topology.sh 00000-15624 12 777 # +# bash scripts/NATS-Bench/train-topology.sh 00000-15624 200 '777 888 999' # +# # +################ # +# [Deprecated Function: Generate the meta information] # # python ./exps/NATS-Bench/main-tss.py --mode meta # ############################################################################## import os, sys, time, torch, random, argparse +from typing import List, Text, Dict, Any from PIL import ImageFile ImageFile.LOAD_TRUNCATED_IMAGES = True from copy import deepcopy @@ -19,16 +26,18 @@ from pathlib import Path lib_dir = (Path(__file__).parent / '..' / '..' 
/ 'lib').resolve() if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) -from config_utils import load_config +from config_utils import dict2config, load_config from procedures import bench_evaluate_for_seed from procedures import get_machine_info from datasets import get_datasets from log_utils import Logger, AverageMeter, time_string, convert_secs2time from models import CellStructure, CellArchitectures, get_search_spaces +from utils import split_str2indexes -def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger): - machine_info, arch_config = get_machine_info(), deepcopy(arch_config) +def evaluate_all_datasets(arch: Text, datasets: List[Text], xpaths: List[Text], + splits: List[Text], config_path: Text, seed: int, raw_arch_config, workers, logger): + machine_info, raw_arch_config = get_machine_info(), deepcopy(raw_arch_config) all_infos = {'info': machine_info} all_dataset_keys = [] # look all the datasets @@ -37,19 +46,12 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_c train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1) # load the configuration if dataset == 'cifar10' or dataset == 'cifar100': - if use_less: config_path = 'configs/nas-benchmark/LESS.config' - else : config_path = 'configs/nas-benchmark/CIFAR.config' split_info = load_config('configs/nas-benchmark/cifar-split.txt', None, None) elif dataset.startswith('ImageNet16'): - if use_less: config_path = 'configs/nas-benchmark/LESS.config' - else : config_path = 'configs/nas-benchmark/ImageNet-16.config' split_info = load_config('configs/nas-benchmark/{:}-split.txt'.format(dataset), None, None) else: raise ValueError('invalid dataset : {:}'.format(dataset)) - config = load_config(config_path, \ - {'class_num': class_num, - 'xshape' : xshape}, \ - logger) + config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger) # check whether use splited validation set if 
bool(split): assert dataset == 'cifar10' @@ -89,6 +91,8 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_c logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(dataset_key, config)) for key, value in ValLoaders.items(): logger.log('Evaluate ---->>>> {:10s} with {:} batchs'.format(key, len(value))) + arch_config = dict2config(dict(name='infer.tiny', C=raw_arch_config['channel'], N=raw_arch_config['num_cells'], + genotype=arch, num_classes=config.class_num), None) results = bench_evaluate_for_seed(arch_config, config, train_loader, ValLoaders, seed, logger) all_infos[dataset_key] = results all_dataset_keys.append( dataset_key ) @@ -96,71 +100,59 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_c return all_infos -def main(save_dir, workers, datasets, xpaths, splits, use_less, srange, arch_index, seeds, cover_mode, meta_info, arch_config): - assert torch.cuda.is_available(), 'CUDA is not available.' - torch.backends.cudnn.enabled = True - #torch.backends.cudnn.benchmark = True - torch.backends.cudnn.deterministic = True - torch.set_num_threads( workers ) +def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text], + splits: List[int], seeds: List[int], nets: List[str], opt_config: Dict[Text, Any], + to_evaluate_indexes: tuple, cover_mode: bool, arch_config: Dict[Text, Any]): - assert len(srange) == 2 and 0 <= srange[0] <= srange[1], 'invalid srange : {:}'.format(srange) - - if use_less: - sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}-LESS'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells']) - else: - sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells']) - logger = Logger(str(sub_dir), 0, False) + log_dir = save_dir / 'logs' + log_dir.mkdir(parents=True, exist_ok=True) + logger = Logger(str(log_dir), os.getpid(), False) - all_archs = meta_info['archs'] - assert 
srange[1] < meta_info['total'], 'invalid range : {:}-{:} vs. {:}'.format(srange[0], srange[1], meta_info['total']) - assert arch_index == -1 or srange[0] <= arch_index <= srange[1], 'invalid range : {:} vs. {:} vs. {:}'.format(srange[0], arch_index, srange[1]) - if arch_index == -1: - to_evaluate_indexes = list(range(srange[0], srange[1]+1)) - else: - to_evaluate_indexes = [arch_index] logger.log('xargs : seeds = {:}'.format(seeds)) - logger.log('xargs : arch_index = {:}'.format(arch_index)) logger.log('xargs : cover_mode = {:}'.format(cover_mode)) - logger.log('-'*100) - - logger.log('Start evaluating range =: {:06d} vs. {:06d} vs. {:06d} / {:06d} with cover-mode={:}'.format(srange[0], arch_index, srange[1], meta_info['total'], cover_mode)) + logger.log('-' * 100) + logger.log( + 'Start evaluating range =: {:06d} - {:06d}'.format(min(to_evaluate_indexes), max(to_evaluate_indexes)) + +'({:} in total) / {:06d} with cover-mode={:}'.format(len(to_evaluate_indexes), len(nets), cover_mode)) for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)): - logger.log('--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}'.format(i, len(datasets), dataset, xpath, split)) - logger.log('--->>> architecture config : {:}'.format(arch_config)) - + logger.log( + '--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}'.format(i, len(datasets), dataset, xpath, split)) + logger.log('--->>> optimization config : {:}'.format(opt_config)) start_time, epoch_time = time.time(), AverageMeter() for i, index in enumerate(to_evaluate_indexes): - arch = all_archs[index] - logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th architecture [seeds={:}] {:}'.format('-'*15, i, len(to_evaluate_indexes), index, meta_info['total'], seeds, '-'*15)) - #logger.log('{:} {:} {:}'.format('-'*15, arch.tostr(), '-'*15)) - logger.log('{:} {:} {:}'.format('-'*15, arch, '-'*15)) - + arch = nets[index] + logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] 
{:}'.format(time_string(), i, + len(to_evaluate_indexes), index, len(nets), seeds, '-' * 15)) + logger.log('{:} {:} {:}'.format('-' * 15, arch, '-' * 15)) + # test this arch on different datasets with different seeds has_continue = False for seed in seeds: - to_save_name = sub_dir / 'arch-{:06d}-seed-{:04d}.pth'.format(index, seed) + to_save_name = save_dir / 'arch-{:06d}-seed-{:04d}.pth'.format(index, seed) if to_save_name.exists(): if cover_mode: logger.log('Find existing file : {:}, remove it before evaluation'.format(to_save_name)) os.remove(str(to_save_name)) - else : + else: logger.log('Find existing file : {:}, skip this evaluation'.format(to_save_name)) has_continue = True continue - results = evaluate_all_datasets(CellStructure.str2structure(arch), \ - datasets, xpaths, splits, use_less, seed, \ - arch_config, workers, logger) + results = evaluate_all_datasets(CellStructure.str2structure(arch), + datasets, xpaths, splits, opt_config, seed, + arch_config, workers, logger) torch.save(results, to_save_name) - logger.log('{:} --evaluate-- {:06d}/{:06d} ({:06d}/{:06d})-th seed={:} done, save into {:}'.format('-'*15, i, len(to_evaluate_indexes), index, meta_info['total'], seed, to_save_name)) + logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] ===>>> {:}'.format(time_string(), i, + len(to_evaluate_indexes), index, len(nets), seeds, to_save_name)) # measure elapsed time if not has_continue: epoch_time.update(time.time() - start_time) start_time = time.time() - need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes)-i-1), True) ) - logger.log('This arch costs : {:}'.format( convert_secs2time(epoch_time.val, True) )) - logger.log('{:}'.format('*'*100)) - logger.log('{:} {:74s} {:}'.format('*'*10, '{:06d}/{:06d} ({:06d}/{:06d})-th done, left {:}'.format(i, len(to_evaluate_indexes), index, meta_info['total'], need_time), '*'*10)) - logger.log('{:}'.format('*'*100)) + need_time = 'Time Left: 
{:}'.format(convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes)-i-1), True) ) + logger.log('This arch costs : {:}'.format(convert_secs2time(epoch_time.val, True) )) + logger.log('{:}'.format('*' * 100)) + logger.log('{:} {:74s} {:}'.format('*' * 10, '{:06d}/{:06d} ({:06d}/{:06d})-th done, left {:}'.format(i, len( + to_evaluate_indexes), index, len(nets), need_time), '*' * 10)) + logger.log('{:}'.format('*' * 100)) logger.close() @@ -256,28 +248,34 @@ def generate_meta_info(save_dir, max_node, divide=40): torch.save(info, save_name) print ('save the meta file into {:}'.format(save_name)) - """ - script_name_full = save_dir / 'BENCH-201-N{:}.opt-full.script'.format(max_node) - script_name_less = save_dir / 'BENCH-201-N{:}.opt-less.script'.format(max_node) - full_file = open(str(script_name_full), 'w') - less_file = open(str(script_name_less), 'w') - gaps = total_arch // divide - for start in range(0, total_arch, gaps): - xend = min(start+gaps, total_arch) - full_file.write('bash ./scripts-search/NAS-Bench-201/train-models.sh 0 {:5d} {:5d} -1 \'777 888 999\'\n'.format(start, xend-1)) - less_file.write('bash ./scripts-search/NAS-Bench-201/train-models.sh 1 {:5d} {:5d} -1 \'777 888 999\'\n'.format(start, xend-1)) - print ('save the training script into {:} and {:}'.format(script_name_full, script_name_less)) - full_file.close() - less_file.close() - script_name = save_dir / 'meta-node-{:}.cal-script.txt'.format(max_node) - macro = 'OMP_NUM_THREADS=6 CUDA_VISIBLE_DEVICES=0' - with open(str(script_name), 'w') as cfile: - for start in range(0, total_arch, gaps): - xend = min(start+gaps, total_arch) - cfile.write('{:} python exps/NAS-Bench-201/statistics.py --mode cal --target_dir {:06d}-{:06d}-C16-N5\n'.format(macro, start, xend-1)) - print ('save the post-processing script into {:}'.format(script_name)) - """ +def traverse_net(max_node): + aa_nas_bench_ss = get_search_spaces('cell', 'nats-bench') + archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False) + 
print ('There are {:} archs vs {:}.'.format(len(archs), len(aa_nas_bench_ss) ** ((max_node-1)*max_node/2))) + + random.seed( 88 ) # please do not change this line for reproducibility + random.shuffle( archs ) + assert archs[0 ].tostr() == '|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|', 'please check the 0-th architecture : {:}'.format(archs[0]) + assert archs[9 ].tostr() == '|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|', 'please check the 9-th architecture : {:}'.format(archs[9]) + assert archs[123].tostr() == '|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|', 'please check the 123-th architecture : {:}'.format(archs[123]) + return [x.tostr() for x in archs] + + +def filter_indexes(xlist, mode, save_dir, seeds): + all_indexes = [] + for index in xlist: + if mode == 'cover': + all_indexes.append(index) + else: + for seed in seeds: + temp_path = save_dir / 'arch-{:06d}-seed-{:04d}.pth'.format(index, seed) + if not temp_path.exists(): + all_indexes.append(index) + break + print('{:} [FILTER-INDEXES] : there are {:}/{:} architectures in total'.format(time_string(), len(all_indexes), len(xlist))) + return all_indexes + if __name__ == '__main__': # mode_choices = ['meta', 'new', 'cover'] + ['specific-{:}'.format(_) for _ in CellArchitectures.keys()] @@ -291,11 +289,12 @@ if __name__ == '__main__': parser.add_argument('--datasets', type=str, nargs='+', help='The applied datasets.') parser.add_argument('--xpaths', type=str, nargs='+', help='The root path for this dataset.') parser.add_argument('--splits', type=int, nargs='+', help='The root path for this dataset.') - parser.add_argument('--hyper', type=str, default='12', choices=['01', '12', '90'], help='The tag for hyper-parameters.') + parser.add_argument('--hyper', type=str, default='12', choices=['01', '12', '200'], help='The tag for hyper-parameters.') parser.add_argument('--seeds' , type=int, 
nargs='+', help='The range of models to be evaluated') parser.add_argument('--channel', type=int, default=16, help='The number of channels.') parser.add_argument('--num_cells', type=int, default=5, help='The number of cells in one stage.') + parser.add_argument('--check_N', type=int, default=15625, help='For safety.') args = parser.parse_args() assert args.mode in ['meta', 'new', 'cover'] or args.mode.startswith('specific-'), 'invalid mode : {:}'.format(args.mode) @@ -308,16 +307,28 @@ if __name__ == '__main__': train_single_model(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, args.use_less>0, \ tuple(args.seeds), model_str, {'channel': args.channel, 'num_cells': args.num_cells}) else: - meta_path = Path(args.save_dir) / 'meta-node-{:}.pth'.format(args.max_node) - assert meta_path.exists(), '{:} does not exist.'.format(meta_path) - meta_info = torch.load( meta_path ) - # check whether args is ok - assert len(args.srange) == 2 and args.srange[0] <= args.srange[1], 'invalid length of srange args: {:}'.format(args.srange) - assert len(args.seeds) > 0, 'invalid length of seeds args: {:}'.format(args.seeds) - assert len(args.datasets) == len(args.xpaths) == len(args.splits), 'invalid infos : {:} vs {:} vs {:}'.format(len(args.datasets), len(args.xpaths), len(args.splits)) - assert args.workers > 0, 'invalid number of workers : {:}'.format(args.workers) - - main(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, args.hyper, \ - tuple(args.srange), args.arch_index, tuple(args.seeds), \ - args.mode == 'cover', meta_info, \ - {'channel': args.channel, 'num_cells': args.num_cells}) + nets = traverse_net(args.max_node) + if len(nets) != args.check_N: + raise ValueError('Pre-num-check failed : {:} vs {:}'.format(len(nets), args.check_N)) + opt_config = './configs/nas-benchmark/hyper-opts/{:}E.config'.format(args.hyper) + if not os.path.isfile(opt_config): + raise ValueError('{:} is not a file.'.format(opt_config)) + save_dir = 
Path(args.save_dir) / 'raw-data-{:}'.format(args.hyper) + save_dir.mkdir(parents=True, exist_ok=True) + to_evaluate_indexes = split_str2indexes(args.srange, args.check_N, 5) + if not len(args.seeds): + raise ValueError('invalid length of seeds args: {:}'.format(args.seeds)) + if not (len(args.datasets) == len(args.xpaths) == len(args.splits)): + raise ValueError('invalid infos : {:} vs {:} vs {:}'.format(len(args.datasets), len(args.xpaths), len(args.splits))) + if args.workers <= 0: + raise ValueError('invalid number of workers : {:}'.format(args.workers)) + + target_indexes = filter_indexes(to_evaluate_indexes, args.mode, save_dir, args.seeds) + + assert torch.cuda.is_available(), 'CUDA is not available.' + torch.backends.cudnn.enabled = True + torch.backends.cudnn.deterministic = True + torch.set_num_threads(args.workers) + + main(save_dir, args.workers, args.datasets, args.xpaths, args.splits, tuple(args.seeds), nets, opt_config, target_indexes, args.mode == 'cover', \ + {'name': 'infer.tiny', 'channel': args.channel, 'num_cells': args.num_cells}) diff --git a/lib/utils/__init__.py b/lib/utils/__init__.py index 86b4d37..bceb658 100644 --- a/lib/utils/__init__.py +++ b/lib/utils/__init__.py @@ -4,3 +4,4 @@ from .flop_benchmark import get_model_infos, count_parameters_in_MB from .affine_utils import normalize_points, denormalize_points from .affine_utils import identity2affine, solve2theta, affine2image from .hash_utils import get_md5_file +from .str_utils import split_str2indexes diff --git a/lib/utils/str_utils.py b/lib/utils/str_utils.py new file mode 100644 index 0000000..bd58b6a --- /dev/null +++ b/lib/utils/str_utils.py @@ -0,0 +1,18 @@ + +def split_str2indexes(string: str, max_check: int, length_limit=5): + if not isinstance(string, str): + raise ValueError('Invalid scheme for {:}'.format(string)) + srangestr = "".join(string.split()) + indexes = set() + for srange in srangestr.split(','): + srange = srange.split('-') + if len(srange) != 2: + raise 
ValueError('invalid srange : {:}'.format(srange)) + if length_limit is not None: + assert len(srange[0]) == len(srange[1]) == length_limit, 'invalid srange : {:}'.format(srange) + srange = (int(srange[0]), int(srange[1])) + if not (0 <= srange[0] <= srange[1] < max_check): + raise ValueError('{:} vs {:} vs {:}'.format(srange[0], srange[1], max_check)) + for i in range(srange[0], srange[1]+1): + indexes.add(i) + return indexes diff --git a/scripts/NATS-Bench/train-topology.sh b/scripts/NATS-Bench/train-topology.sh new file mode 100644 index 0000000..01d81cd --- /dev/null +++ b/scripts/NATS-Bench/train-topology.sh @@ -0,0 +1,43 @@ +#!/bin/bash +############################################################################## +# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size # +############################################################################## +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.01 # +############################################################################## +# CUDA_VISIBLE_DEVICES=0 bash scripts/NATS-Bench/train-topology.sh 00000-05000 12 777 +# bash ./scripts/NATS-Bench/train-topology.sh 05001-10000 12 777 +# bash ./scripts/NATS-Bench/train-topology.sh 10001-14500 12 777 +# bash ./scripts/NATS-Bench/train-topology.sh 14501-15624 12 777 +# +############################################################################## +echo script name: $0 +echo $# arguments +if [ "$#" -ne 3 ] ;then + echo "Input illegal number of parameters " $# + echo "Need 3 parameters for start-and-end, hyper-parameters-opt-file, and seeds" + exit 1 +fi +if [ "$TORCH_HOME" = "" ]; then + echo "Must set TORCH_HOME envoriment variable for data dir saving" + exit 1 +else + echo "TORCH_HOME : $TORCH_HOME" +fi + +srange=$1 +opt=$2 +all_seeds=$3 +cpus=4 + +save_dir=./output/NATS-Bench-topology/ + +OMP_NUM_THREADS=${cpus} python exps/NATS-Bench/main-tss.py \ + --mode new --srange ${srange} --hyper ${opt} --save_dir ${save_dir} \ + --datasets cifar10 
cifar10 cifar100 ImageNet16-120 \ + --splits 1 0 0 0 \ + --xpaths $TORCH_HOME/cifar.python \ + $TORCH_HOME/cifar.python \ + $TORCH_HOME/cifar.python \ + $TORCH_HOME/cifar.python/ImageNet16 \ + --workers ${cpus} \ + --seeds ${all_seeds}