clarify restrictions

This commit is contained in:
D-X-Y 2020-01-04 22:16:27 +11:00
parent db44e56fb6
commit e6ca3628ce
6 changed files with 21 additions and 18 deletions

View File

@ -105,20 +105,13 @@ def main(xargs):
logger = prepare_logger(args)
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
#config_path = 'configs/nas-benchmark/algos/DARTS.config'
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
if xargs.dataset == 'cifar10':
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
cifar_split = load_config(split_Fpath, None, None)
train_split, valid_split = cifar_split.train, cifar_split.valid
logger.log('Load split file from {:}'.format(split_Fpath))
elif xargs.dataset.startswith('ImageNet16'):
split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
imagenet16_split = load_config(split_Fpath, None, None)
train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
logger.log('Load split file from {:}'.format(split_Fpath))
else:
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
config_path = 'configs/nas-benchmark/algos/DARTS.config'
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
train_split, valid_split = cifar_split.train, cifar_split.valid # search over the proposed training and validation set
logger.log('Load split file from {:}'.format(split_Fpath)) # they are two disjoint groups in the original CIFAR-10 training set
# To split data
train_data_v2 = deepcopy(train_data)
train_data_v2.transform = valid_data.transform
@ -127,6 +120,12 @@ def main(xargs):
# data loader
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
elif xargs.dataset == 'cifar100':
raise ValueError('not support yet : {:}'.format(xargs.dataset))
elif xargs.dataset.startswith('ImageNet16'):
raise ValueError('not support yet : {:}'.format(xargs.dataset))
else:
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
@ -231,6 +230,7 @@ if __name__ == '__main__':
parser.add_argument('--data_path', type=str, help='Path to dataset')
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
# channels and number-of-cells
parser.add_argument('--config_path', type=str, help='The config paths.')
parser.add_argument('--search_space_name', type=str, help='The search space name.')
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
parser.add_argument('--channel', type=int, help='The number of channels.')

View File

@ -184,6 +184,7 @@ def main(xargs):
logger = prepare_logger(args)
train_data, test_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
cifar_split = load_config(split_Fpath, None, None)

View File

@ -81,6 +81,7 @@ def main(xargs):
logger = prepare_logger(args)
train_data, _, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
cifar_split = load_config(split_Fpath, None, None)

View File

@ -135,6 +135,7 @@ def main(xargs):
logger = prepare_logger(args)
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
cifar_split = load_config(split_Fpath, None, None)

View File

@ -5,7 +5,6 @@ import os, sys, torch
import os.path as osp
import numpy as np
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from PIL import Image
from .DownsampledImageNet import ImageNet16

View File

@ -33,6 +33,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V1.py \
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
--dataset ${dataset} --data_path ${data_path} \
--search_space_name ${space} \
--config_path configs/nas-benchmark/algos/DARTS.config \
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
--workers 4 --print_freq 200 --rand_seed ${seed}