From e6ca3628ce5c192e5d38118ea7ac7386a2b46110 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Sat, 4 Jan 2020 22:16:27 +1100 Subject: [PATCH] clarify restrictions --- exps/algos/DARTS-V1.py | 34 +++++++++++----------- exps/algos/ENAS.py | 1 + exps/algos/GDAS.py | 1 + exps/algos/SETN.py | 1 + lib/datasets/get_dataset_with_transform.py | 1 - scripts-search/algos/DARTS-V1.sh | 1 + 6 files changed, 21 insertions(+), 18 deletions(-) diff --git a/exps/algos/DARTS-V1.py b/exps/algos/DARTS-V1.py index 6173748..da3769a 100644 --- a/exps/algos/DARTS-V1.py +++ b/exps/algos/DARTS-V1.py @@ -105,28 +105,27 @@ def main(xargs): logger = prepare_logger(args) train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) - if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100': + #config_path = 'configs/nas-benchmark/algos/DARTS.config' + config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) + if xargs.dataset == 'cifar10': split_Fpath = 'configs/nas-benchmark/cifar-split.txt' cifar_split = load_config(split_Fpath, None, None) - train_split, valid_split = cifar_split.train, cifar_split.valid - logger.log('Load split file from {:}'.format(split_Fpath)) + train_split, valid_split = cifar_split.train, cifar_split.valid # search over the proposed training and validation set + logger.log('Load split file from {:}'.format(split_Fpath)) # they are two disjoint groups in the original CIFAR-10 training set + # To split data + train_data_v2 = deepcopy(train_data) + train_data_v2.transform = valid_data.transform + valid_data = train_data_v2 + search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split) + # data loader + search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) + valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) + elif xargs.dataset == 'cifar100': + raise ValueError('not support yet : {:}'.format(xargs.dataset)) elif xargs.dataset.startswith('ImageNet16'): - split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset) - imagenet16_split = load_config(split_Fpath, None, None) - train_split, valid_split = imagenet16_split.train, imagenet16_split.valid - logger.log('Load split file from {:}'.format(split_Fpath)) + raise ValueError('not support yet : {:}'.format(xargs.dataset)) else: raise ValueError('invalid dataset : {:}'.format(xargs.dataset)) - config_path = 'configs/nas-benchmark/algos/DARTS.config' - config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) - # To split data - train_data_v2 = deepcopy(train_data) - train_data_v2.transform = valid_data.transform - valid_data = train_data_v2 - search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split) - # data loader - search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) - valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) @@ -231,6 +230,7 @@ if __name__ == '__main__': parser.add_argument('--data_path', type=str, help='Path to dataset') parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') # channels and number-of-cells + parser.add_argument('--config_path', type=str, help='The config paths.') parser.add_argument('--search_space_name', type=str, help='The search space name.') parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.') parser.add_argument('--channel', type=int, help='The number of channels.') diff --git a/exps/algos/ENAS.py b/exps/algos/ENAS.py index 38b3b11..a20c9cf 100644 --- a/exps/algos/ENAS.py +++ b/exps/algos/ENAS.py @@ -184,6 +184,7 @@ def main(xargs): logger = prepare_logger(args) train_data, test_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100': split_Fpath = 'configs/nas-benchmark/cifar-split.txt' cifar_split = load_config(split_Fpath, None, None) diff --git a/exps/algos/GDAS.py b/exps/algos/GDAS.py index 3714204..84432b8 100644 --- a/exps/algos/GDAS.py +++ b/exps/algos/GDAS.py @@ -81,6 +81,7 @@ def main(xargs): logger = prepare_logger(args) train_data, _, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100': split_Fpath = 'configs/nas-benchmark/cifar-split.txt' cifar_split = load_config(split_Fpath, None, None) diff --git a/exps/algos/SETN.py b/exps/algos/SETN.py index e5ebece..8a937f6 100644 --- a/exps/algos/SETN.py +++ b/exps/algos/SETN.py @@ -135,6 +135,7 @@ def main(xargs): logger = prepare_logger(args) train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100': split_Fpath = 'configs/nas-benchmark/cifar-split.txt' cifar_split = load_config(split_Fpath, None, None) diff --git a/lib/datasets/get_dataset_with_transform.py b/lib/datasets/get_dataset_with_transform.py index 416bcde..19323cf 100644 --- a/lib/datasets/get_dataset_with_transform.py +++ b/lib/datasets/get_dataset_with_transform.py @@ -5,7 +5,6 @@ import os, sys, torch import os.path as osp import numpy as np import torchvision.datasets as dset -import torch.backends.cudnn as cudnn import torchvision.transforms as transforms from PIL import Image from .DownsampledImageNet import ImageNet16 diff --git a/scripts-search/algos/DARTS-V1.sh b/scripts-search/algos/DARTS-V1.sh index 2104bda..a096b54 100644 --- a/scripts-search/algos/DARTS-V1.sh +++ b/scripts-search/algos/DARTS-V1.sh @@ -33,6 +33,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V1.py \ --save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ --dataset ${dataset} --data_path ${data_path} \ --search_space_name ${space} \ + --config_path configs/nas-benchmark/algos/DARTS.config \ --arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ --arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ --workers 4 --print_freq 200 --rand_seed ${seed}