From 3f9b54d99e19addfac876b89e11b3cd1e65095c2 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Fri, 1 Feb 2019 03:23:55 +1100 Subject: [PATCH] update scripts --- README.md | 15 +++- {scripts-cnn => TEMP}/DMS-V-Train.sh | 0 {scripts-cnn => TEMP}/DMS-V-TrainV3.sh | 0 {scripts-cnn => TEMP}/README.md | 0 {scripts-cnn => TEMP}/TRAIN-BASE.sh | 0 {scripts-cnn => TEMP}/batch-base-model.sh | 0 {scripts-cnn => TEMP}/batch-base-search.sh | 0 {scripts-cnn => TEMP}/meta-search.sh | 0 {scripts-cnn => TEMP}/search-acc-simple.sh | 0 {scripts-cnn => TEMP}/search-acc-v2-E150.sh | 0 {scripts-cnn => TEMP}/search-acc-v2-E200.sh | 0 {scripts-cnn => TEMP}/search-acc-v2-E300.sh | 0 {scripts-cnn => TEMP}/search-acc-v2-E50.sh | 0 {scripts-cnn => TEMP}/search-acc-v2.sh | 0 {scripts-cnn => TEMP}/search.sh | 0 {scripts-cnn => TEMP}/vis.sh | 0 exps-cnn/DARTS-Search.py | 1 + exps-cnn/train_base.py | 16 +--- exps-cnn/train_utils.py | 33 +-------- exps-cnn/train_utils_imagenet.py | 30 +------- lib/datasets/__init__.py | 1 + lib/datasets/get_dataset_with_transform.py | 74 +++++++++++++++++++ lib/move.sh | 4 - lib/nas/__init__.py | 4 +- lib/nas/genotypes.py | 14 +++- .../{train-cifar100.sh => train-cifar.sh} | 7 +- scripts-cnn/train-imagenet.sh | 2 +- scripts-cnn/train-model-simple.sh | 25 ------- scripts-cnn/train-model.sh | 26 ------- 29 files changed, 115 insertions(+), 137 deletions(-) rename {scripts-cnn => TEMP}/DMS-V-Train.sh (100%) rename {scripts-cnn => TEMP}/DMS-V-TrainV3.sh (100%) rename {scripts-cnn => TEMP}/README.md (100%) rename {scripts-cnn => TEMP}/TRAIN-BASE.sh (100%) rename {scripts-cnn => TEMP}/batch-base-model.sh (100%) rename {scripts-cnn => TEMP}/batch-base-search.sh (100%) rename {scripts-cnn => TEMP}/meta-search.sh (100%) rename {scripts-cnn => TEMP}/search-acc-simple.sh (100%) rename {scripts-cnn => TEMP}/search-acc-v2-E150.sh (100%) rename {scripts-cnn => TEMP}/search-acc-v2-E200.sh (100%) rename {scripts-cnn => TEMP}/search-acc-v2-E300.sh (100%) rename {scripts-cnn => TEMP}/search-acc-v2-E50.sh (100%) rename {scripts-cnn => TEMP}/search-acc-v2.sh (100%) rename {scripts-cnn => TEMP}/search.sh (100%) rename {scripts-cnn => TEMP}/vis.sh (100%) create mode 100644 lib/datasets/get_dataset_with_transform.py delete mode 100644 lib/move.sh rename scripts-cnn/{train-cifar100.sh => train-cifar.sh} (78%) delete mode 100644 scripts-cnn/train-model-simple.sh delete mode 100644 scripts-cnn/train-model.sh diff --git a/README.md b/README.md index b55172c..fff2d03 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,20 @@ Searching CNNs ``` ``` -Train the Searched RNN +Train the searched CNN on CIFAR +``` +bash ./scripts-cnn/train-imagenet.sh 0 GDAS_F1 52 14 +bash ./scripts-cnn/train-imagenet.sh 0 GDAS_V1 50 14 +``` + +Train the searched CNN on ImageNet +``` +bash ./scripts-cnn/train-imagenet.sh 0 GDAS_F1 52 14 +bash ./scripts-cnn/train-imagenet.sh 0 GDAS_V1 50 14 +``` + + +Train the searched RNN ``` bash ./scripts-rnn/train-PTB.sh 0 DARTS_V1 bash ./scripts-rnn/train-PTB.sh 0 DARTS_V2 diff --git a/scripts-cnn/DMS-V-Train.sh b/TEMP/DMS-V-Train.sh similarity index 100% rename from scripts-cnn/DMS-V-Train.sh rename to TEMP/DMS-V-Train.sh diff --git a/scripts-cnn/DMS-V-TrainV3.sh b/TEMP/DMS-V-TrainV3.sh similarity index 100% rename from scripts-cnn/DMS-V-TrainV3.sh rename to TEMP/DMS-V-TrainV3.sh diff --git a/scripts-cnn/README.md b/TEMP/README.md similarity index 100% rename from scripts-cnn/README.md rename to TEMP/README.md diff --git a/scripts-cnn/TRAIN-BASE.sh b/TEMP/TRAIN-BASE.sh similarity index 100% rename from scripts-cnn/TRAIN-BASE.sh rename to TEMP/TRAIN-BASE.sh diff --git a/scripts-cnn/batch-base-model.sh b/TEMP/batch-base-model.sh similarity index 100% rename from scripts-cnn/batch-base-model.sh rename to TEMP/batch-base-model.sh diff --git a/scripts-cnn/batch-base-search.sh b/TEMP/batch-base-search.sh similarity index 100% rename from scripts-cnn/batch-base-search.sh rename to TEMP/batch-base-search.sh diff --git a/scripts-cnn/meta-search.sh b/TEMP/meta-search.sh similarity index 100% rename from scripts-cnn/meta-search.sh rename to TEMP/meta-search.sh diff --git a/scripts-cnn/search-acc-simple.sh b/TEMP/search-acc-simple.sh similarity index 100% rename from scripts-cnn/search-acc-simple.sh rename to TEMP/search-acc-simple.sh diff --git a/scripts-cnn/search-acc-v2-E150.sh b/TEMP/search-acc-v2-E150.sh similarity index 100% rename from scripts-cnn/search-acc-v2-E150.sh rename to TEMP/search-acc-v2-E150.sh diff --git a/scripts-cnn/search-acc-v2-E200.sh b/TEMP/search-acc-v2-E200.sh similarity index 100% rename from scripts-cnn/search-acc-v2-E200.sh rename to TEMP/search-acc-v2-E200.sh diff --git a/scripts-cnn/search-acc-v2-E300.sh b/TEMP/search-acc-v2-E300.sh similarity index 100% rename from scripts-cnn/search-acc-v2-E300.sh rename to TEMP/search-acc-v2-E300.sh diff --git a/scripts-cnn/search-acc-v2-E50.sh b/TEMP/search-acc-v2-E50.sh similarity index 100% rename from scripts-cnn/search-acc-v2-E50.sh rename to TEMP/search-acc-v2-E50.sh diff --git a/scripts-cnn/search-acc-v2.sh b/TEMP/search-acc-v2.sh similarity index 100% rename from scripts-cnn/search-acc-v2.sh rename to TEMP/search-acc-v2.sh diff --git a/scripts-cnn/search.sh b/TEMP/search.sh similarity index 100% rename from scripts-cnn/search.sh rename to TEMP/search.sh diff --git a/scripts-cnn/vis.sh b/TEMP/vis.sh similarity index 100% rename from scripts-cnn/vis.sh rename to TEMP/vis.sh diff --git a/exps-cnn/DARTS-Search.py b/exps-cnn/DARTS-Search.py index 266b3d1..9f2a88a 100644 --- a/exps-cnn/DARTS-Search.py +++ b/exps-cnn/DARTS-Search.py @@ -1,3 +1,4 @@ +# DARTS First Order, Refer to https://github.com/quark0/darts import os, sys, time, glob, random, argparse import numpy as np from copy import deepcopy diff --git a/exps-cnn/train_base.py b/exps-cnn/train_base.py index 59ef82d..065b2f4 100644 --- a/exps-cnn/train_base.py +++ b/exps-cnn/train_base.py @@ -13,25 +13,11 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) from utils import AverageMeter, time_string, convert_secs2time from utils import print_log, obtain_accuracy from utils import Cutout, count_parameters_in_MB -from nas import DARTS_V1, DARTS_V2, NASNet, PNASNet, AmoebaNet, ENASNet -from nas import DMS_V1, DMS_F1, GDAS_CC -from meta_nas import META_V1, META_V2 +from nas import model_types as models from train_utils import main_procedure from train_utils_imagenet import main_procedure_imagenet from scheduler import load_config -models = {'DARTS_V1': DARTS_V1, - 'DARTS_V2': DARTS_V2, - 'NASNet' : NASNet, - 'PNASNet' : PNASNet, - 'ENASNet' : ENASNet, - 'DMS_V1' : DMS_V1, - 'DMS_F1' : DMS_F1, - 'GDAS_CC' : GDAS_CC, - 'META_V1' : META_V1, - 'META_V2' : META_V2, - 'AmoebaNet' : AmoebaNet} - parser = argparse.ArgumentParser("cifar") parser.add_argument('--data_path', type=str, help='Path to dataset') diff --git a/exps-cnn/train_utils.py b/exps-cnn/train_utils.py index f7875f2..d44c69c 100644 --- a/exps-cnn/train_utils.py +++ b/exps-cnn/train_utils.py @@ -10,6 +10,7 @@ from utils import time_string, convert_secs2time from utils import count_parameters_in_MB from utils import Cutout from nas import NetworkCIFAR as Network +from datasets import get_datasets def obtain_best(accuracies): if len(accuracies) == 0: return (0, 0) @@ -17,38 +18,10 @@ def obtain_best(accuracies): s2b = sorted( tops ) return s2b[-1] + def main_procedure(config, dataset, data_path, args, genotype, init_channels, layers, log): - # Mean + Std - if dataset == 'cifar10': - mean = [x / 255 for x in [125.3, 123.0, 113.9]] - std = [x / 255 for x in [63.0, 62.1, 66.7]] - elif dataset == 'cifar100': - mean = [x / 255 for x in [129.3, 124.1, 112.4]] - std = [x / 255 for x in [68.2, 65.4, 70.4]] - else: - raise TypeError("Unknow dataset : {:}".format(dataset)) - # Dataset Transformation - if dataset == 'cifar10' or dataset == 'cifar100': - lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), - transforms.Normalize(mean, std)] - if config.cutout > 0 : lists += [Cutout(config.cutout)] - train_transform = transforms.Compose(lists) - test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) - else: - raise TypeError("Unknow dataset : {:}".format(dataset)) - # Dataset Defination - if dataset == 'cifar10': - train_data = dset.CIFAR10(data_path, train= True, transform=train_transform, download=True) - test_data = dset.CIFAR10(data_path, train=False, transform=test_transform , download=True) - class_num = 10 - elif dataset == 'cifar100': - train_data = dset.CIFAR100(data_path, train= True, transform=train_transform, download=True) - test_data = dset.CIFAR100(data_path, train=False, transform=test_transform , download=True) - class_num = 100 - else: - raise TypeError("Unknow dataset : {:}".format(dataset)) - + train_data, test_data, class_num = get_datasets(dataset, data_path, args.cutout) print_log('-------------------------------------- main-procedure', log) print_log('config : {:}'.format(config), log) diff --git a/exps-cnn/train_utils_imagenet.py b/exps-cnn/train_utils_imagenet.py index c0880e8..4067521 100644 --- a/exps-cnn/train_utils_imagenet.py +++ b/exps-cnn/train_utils_imagenet.py @@ -12,6 +12,7 @@ from utils import count_parameters_in_MB from utils import print_FLOPs from utils import Cutout from nas import NetworkImageNet as Network +from datasets import get_datasets def obtain_best(accuracies): @@ -40,30 +41,7 @@ class CrossEntropyLabelSmooth(nn.Module): def main_procedure_imagenet(config, data_path, args, genotype, init_channels, layers, log): # training data and testing data - traindir = os.path.join(data_path, 'train') - validdir = os.path.join(data_path, 'val') - normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - train_data = dset.ImageFolder( - traindir, - transforms.Compose([ - transforms.RandomResizedCrop(224), - transforms.RandomHorizontalFlip(), - transforms.ColorJitter( - brightness=0.4, - contrast=0.4, - saturation=0.4, - hue=0.2), - transforms.ToTensor(), - normalize, - ])) - valid_data = dset.ImageFolder( - validdir, - transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - normalize, - ])) + train_data, valid_data, class_num = get_datasets('imagenet-1k', data_path, -1) train_queue = torch.utils.data.DataLoader( train_data, batch_size=config.batch_size, shuffle= True, pin_memory=True, num_workers=args.workers) @@ -73,7 +51,6 @@ def main_procedure_imagenet(config, data_path, args, genotype, init_channels, la class_num = 1000 - print_log('-------------------------------------- main-procedure', log) print_log('config : {:}'.format(config), log) print_log('genotype : {:}'.format(genotype), log) @@ -98,8 +75,7 @@ def main_procedure_imagenet(config, data_path, args, genotype, init_channels, la criterion_smooth = CrossEntropyLabelSmooth(class_num, config.label_smooth).cuda() - optimizer = torch.optim.SGD(model.parameters(), config.LR, momentum=config.momentum, weight_decay=config.decay) - #optimizer = torch.optim.SGD(model.parameters(), config.LR, momentum=config.momentum, weight_decay=config.decay, nestero=True) + optimizer = torch.optim.SGD(model.parameters(), config.LR, momentum=config.momentum, weight_decay=config.decay, nestero=True) if config.type == 'cosine': scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(config.epochs)) elif config.type == 'steplr': diff --git a/lib/datasets/__init__.py b/lib/datasets/__init__.py index 0698df2..d78e8dc 100644 --- a/lib/datasets/__init__.py +++ b/lib/datasets/__init__.py @@ -1,3 +1,4 @@ from .MetaBatchSampler import MetaBatchSampler from .TieredImageNet import TieredImageNet from .LanguageDataset import Corpus +from .get_dataset_with_transform import get_datasets diff --git a/lib/datasets/get_dataset_with_transform.py b/lib/datasets/get_dataset_with_transform.py new file mode 100644 index 0000000..d7f9464 --- /dev/null +++ b/lib/datasets/get_dataset_with_transform.py @@ -0,0 +1,74 @@ +import os, sys, torch +import os.path as osp +import torchvision.datasets as dset +import torch.backends.cudnn as cudnn +import torchvision.transforms as transforms + +from utils import Cutout +from .TieredImageNet import TieredImageNet + +Dataset2Class = {'cifar10' : 10, + 'cifar100': 100, + 'tiered' : -1, + 'imagnet-1k' : 1000, + 'imagenet-100': 100} + + +def get_datasets(name, root, cutout): + + # Mean + Std + if name == 'cifar10': + mean = [x / 255 for x in [125.3, 123.0, 113.9]] + std = [x / 255 for x in [63.0, 62.1, 66.7]] + elif name == 'cifar100': + mean = [x / 255 for x in [129.3, 124.1, 112.4]] + std = [x / 255 for x in [68.2, 65.4, 70.4]] + elif name == 'tiered': + mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] + elif name == 'imagnet-1k' or name == 'imagenet-100': + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + else: raise TypeError("Unknow dataset : {:}".format(name)) + + + # Data Argumentation + if name == 'cifar10' or name == 'cifar100': + lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), + transforms.Normalize(mean, std)] + if cutout > 0 : lists += [Cutout(cutout)] + train_transform = transforms.Compose(lists) + test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) + elif name == 'tiered': + lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(80, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)] + if cutout > 0 : lists += [Cutout(cutout)] + train_transform = transforms.Compose(lists) + test_transform = transforms.Compose([transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(mean, std)]) + elif name == 'imagnet-1k' or name == 'imagenet-100': + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + train_transform = transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ColorJitter( + brightness=0.4, + contrast=0.4, + saturation=0.4, + hue=0.2), + transforms.ToTensor(), + normalize, + ]) + test_transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize]) + else: raise TypeError("Unknow dataset : {:}".format(name)) + train_data = TieredImageNet(root, 'train-val', train_transform) + test_data = None + if name == 'cifar10': + train_data = dset.CIFAR10(root, train=True, transform=train_transform, download=True) + test_data = dset.CIFAR10(root, train=True, transform=test_transform , download=True) + elif name == 'cifar100': + train_data = dset.CIFAR100(root, train=True, transform=train_transform, download=True) + test_data = dset.CIFAR100(root, train=True, transform=test_transform , download=True) + elif name == 'imagnet-1k' or name == 'imagenet-100': + train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform) + test_data = dset.ImageFolder(osp.join(root, 'val'), train_transform) + else: raise TypeError("Unknow dataset : {:}".format(name)) + + class_num = Dataset2Class[name] + return train_data, test_data, class_num diff --git a/lib/move.sh b/lib/move.sh deleted file mode 100644 index 834b16e..0000000 --- a/lib/move.sh +++ /dev/null @@ -1,4 +0,0 @@ -rm -rf pytorch -git clone https://github.com/pytorch/pytorch.git -cp -r ./pytorch/torch/nn xnn -rm -rf pytorch diff --git a/lib/nas/__init__.py b/lib/nas/__init__.py index ff60761..e092e54 100644 --- a/lib/nas/__init__.py +++ b/lib/nas/__init__.py @@ -11,8 +11,6 @@ from .CifarNet import NetworkCIFAR from .ImageNet import NetworkImageNet # genotypes -from .genotypes import DARTS_V1, DARTS_V2 -from .genotypes import NASNet, PNASNet, AmoebaNet, ENASNet -from .genotypes import DMS_V1, DMS_F1, GDAS_CC +from .genotypes import model_types from .construct_utils import return_alphas_str diff --git a/lib/nas/genotypes.py b/lib/nas/genotypes.py index 6fc4bcf..06d3633 100644 --- a/lib/nas/genotypes.py +++ b/lib/nas/genotypes.py @@ -179,7 +179,7 @@ ENASNet = Genotype( DARTS = DARTS_V2 # Search by normal and reduce -DMS_V1 = Genotype( +GDAS_V1 = Genotype( normal=[('skip_connect', 0, 0.13017432391643524), ('skip_connect', 1, 0.12947972118854523), ('skip_connect', 0, 0.13062666356563568), ('sep_conv_5x5', 2, 0.12980839610099792), ('sep_conv_3x3', 3, 0.12923765182495117), ('skip_connect', 0, 0.12901571393013), ('sep_conv_5x5', 4, 0.12938997149467468), ('sep_conv_3x3', 3, 0.1289220005273819)], normal_concat=range(2, 6), reduce=[('sep_conv_5x5', 0, 0.12862831354141235), ('sep_conv_3x3', 1, 0.12783904373645782), ('sep_conv_5x5', 2, 0.12725995481014252), ('sep_conv_5x5', 1, 0.12705285847187042), ('dil_conv_5x5', 2, 0.12797553837299347), ('sep_conv_3x3', 1, 0.12737272679805756), ('sep_conv_5x5', 0, 0.12833961844444275), ('sep_conv_5x5', 1, 0.12758426368236542)], @@ -187,7 +187,7 @@ DMS_V1 = Genotype( ) # Search by normal and fixing reduction -DMS_F1 = Genotype( +GDAS_F1 = Genotype( normal=[('skip_connect', 0, 0.16), ('skip_connect', 1, 0.13), ('skip_connect', 0, 0.17), ('sep_conv_3x3', 2, 0.15), ('skip_connect', 0, 0.17), ('sep_conv_3x3', 2, 0.15), ('skip_connect', 0, 0.16), ('sep_conv_3x3', 2, 0.15)], normal_concat=[2, 3, 4, 5], reduce=None, @@ -201,3 +201,13 @@ GDAS_CC = Genotype( reduce=None, reduce_concat=range(2, 6) ) + +model_types = {'DARTS_V1': DARTS_V1, + 'DARTS_V2': DARTS_V2, + 'NASNet' : NASNet, + 'PNASNet' : PNASNet, + 'AmoebaNet': AmoebaNet, + 'ENASNet' : ENASNet, + 'GDAS_V1' : GDAS_V1, + 'GDAS_F1' : GDAS_F1, + 'GDAS_CC' : GDAS_CC} diff --git a/scripts-cnn/train-cifar100.sh b/scripts-cnn/train-cifar.sh similarity index 78% rename from scripts-cnn/train-cifar100.sh rename to scripts-cnn/train-cifar.sh index 46eaf6a..35ad6a7 100644 --- a/scripts-cnn/train-cifar100.sh +++ b/scripts-cnn/train-cifar.sh @@ -1,7 +1,8 @@ #!/usr/bin/env sh -if [ "$#" -ne 2 ] ;then +# bash scripts-cnn/train-cifar.sh 0 GDAS cifar10 +if [ "$#" -ne 3 ] ;then echo "Input illegal number of parameters " $# - echo "Need 2 parameters for the GPUs, the architecture" + echo "Need 3 parameters for the GPUs, the architecture, and the dataset-name" exit 1 fi if [ "$TORCH_HOME" = "" ]; then @@ -13,7 +14,7 @@ fi gpus=$1 arch=$2 -dataset=cifar100 +dataset=$3 SAVED=./snapshots/NAS/${arch}-${dataset}-E600 CUDA_VISIBLE_DEVICES=${gpus} python ./exps-nas/train_base.py \ diff --git a/scripts-cnn/train-imagenet.sh b/scripts-cnn/train-imagenet.sh index 4164629..c1061dc 100644 --- a/scripts-cnn/train-imagenet.sh +++ b/scripts-cnn/train-imagenet.sh @@ -18,7 +18,7 @@ channels=$3 layers=$4 SAVED=./snapshots/NAS/${arch}-${dataset}-C${channels}-L${layers}-E250 -CUDA_VISIBLE_DEVICES=${gpus} python ./exps-nas/train_base.py \ +CUDA_VISIBLE_DEVICES=${gpus} python ./exps-cnn/train_base.py \ --data_path $TORCH_HOME/ILSVRC2012 \ --dataset ${dataset} --arch ${arch} \ --save_path ${SAVED} \ diff --git a/scripts-cnn/train-model-simple.sh b/scripts-cnn/train-model-simple.sh deleted file mode 100644 index 4b0a918..0000000 --- a/scripts-cnn/train-model-simple.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/usr/bin/env sh -if [ "$#" -ne 2 ] ;then - echo "Input illegal number of parameters " $# - echo "Need 2 parameters for the GPUs and the architecture" - exit 1 -fi -if [ "$TORCH_HOME" = "" ]; then - echo "Must set TORCH_HOME envoriment variable for data dir saving" - exit 1 -else - echo "TORCH_HOME : $TORCH_HOME" -fi - -gpus=$1 -arch=$2 -dataset=cifar10 -SAVED=./snapshots/NAS/${arch}-${dataset}-E100 - -CUDA_VISIBLE_DEVICES=${gpus} python ./exps-nas/train_base.py \ - --data_path $TORCH_HOME/cifar.python \ - --dataset ${dataset} --arch ${arch} \ - --save_path ${SAVED} \ - --grad_clip 5 \ - --model_config ./configs/nas-cifar-cos-simple.config \ - --print_freq 100 --workers 8 diff --git a/scripts-cnn/train-model.sh b/scripts-cnn/train-model.sh deleted file mode 100644 index 6de2089..0000000 --- a/scripts-cnn/train-model.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env sh -if [ "$#" -ne 2 ] ;then - echo "Input illegal number of parameters " $# - echo "Need 2 parameters for the GPUs, the architecture" - exit 1 -fi -if [ "$TORCH_HOME" = "" ]; then - echo "Must set TORCH_HOME envoriment variable for data dir saving" - exit 1 -else - echo "TORCH_HOME : $TORCH_HOME" -fi - -gpus=$1 -arch=$2 -dataset=cifar10 -SAVED=./snapshots/NAS/${arch}-${dataset}-E600 - -CUDA_VISIBLE_DEVICES=${gpus} python ./exps-nas/train_base.py \ - --data_path $TORCH_HOME/cifar.python \ - --dataset ${dataset} --arch ${arch} \ - --save_path ${SAVED} \ - --grad_clip 5 \ - --init_channels 36 --layers 20 \ - --model_config ./configs/nas-cifar-cos.config \ - --print_freq 100 --workers 8