clarify restrictions
This commit is contained in:
parent
db44e56fb6
commit
e6ca3628ce
@ -105,28 +105,27 @@ def main(xargs):
|
||||
logger = prepare_logger(args)
|
||||
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
|
||||
#config_path = 'configs/nas-benchmark/algos/DARTS.config'
|
||||
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
if xargs.dataset == 'cifar10':
|
||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
cifar_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
train_split, valid_split = cifar_split.train, cifar_split.valid # search over the proposed training and validation set
|
||||
logger.log('Load split file from {:}'.format(split_Fpath)) # they are two disjoint groups in the original CIFAR-10 training set
|
||||
# To split data
|
||||
train_data_v2 = deepcopy(train_data)
|
||||
train_data_v2.transform = valid_data.transform
|
||||
valid_data = train_data_v2
|
||||
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||
# data loader
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||
elif xargs.dataset == 'cifar100':
|
||||
raise ValueError('not support yet : {:}'.format(xargs.dataset))
|
||||
elif xargs.dataset.startswith('ImageNet16'):
|
||||
split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
|
||||
imagenet16_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
raise ValueError('not support yet : {:}'.format(xargs.dataset))
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||
config_path = 'configs/nas-benchmark/algos/DARTS.config'
|
||||
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
# To split data
|
||||
train_data_v2 = deepcopy(train_data)
|
||||
train_data_v2.transform = valid_data.transform
|
||||
valid_data = train_data_v2
|
||||
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||
# data loader
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
|
||||
@ -231,6 +230,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--data_path', type=str, help='Path to dataset')
|
||||
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
|
||||
# channels and number-of-cells
|
||||
parser.add_argument('--config_path', type=str, help='The config paths.')
|
||||
parser.add_argument('--search_space_name', type=str, help='The search space name.')
|
||||
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
|
||||
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||
|
@ -184,6 +184,7 @@ def main(xargs):
|
||||
logger = prepare_logger(args)
|
||||
|
||||
train_data, test_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
|
||||
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
|
||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
cifar_split = load_config(split_Fpath, None, None)
|
||||
|
@ -81,6 +81,7 @@ def main(xargs):
|
||||
logger = prepare_logger(args)
|
||||
|
||||
train_data, _, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
|
||||
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
|
||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
cifar_split = load_config(split_Fpath, None, None)
|
||||
|
@ -135,6 +135,7 @@ def main(xargs):
|
||||
logger = prepare_logger(args)
|
||||
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
|
||||
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
|
||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
cifar_split = load_config(split_Fpath, None, None)
|
||||
|
@ -5,7 +5,6 @@ import os, sys, torch
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
import torchvision.datasets as dset
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torchvision.transforms as transforms
|
||||
from PIL import Image
|
||||
from .DownsampledImageNet import ImageNet16
|
||||
|
@ -33,6 +33,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V1.py \
|
||||
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||
--dataset ${dataset} --data_path ${data_path} \
|
||||
--search_space_name ${space} \
|
||||
--config_path configs/nas-benchmark/algos/DARTS.config \
|
||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
||||
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||
|
Loading…
Reference in New Issue
Block a user