can train aircraft now
This commit is contained in:
parent
ef2608bb42
commit
c6d53f08ae
@ -28,16 +28,30 @@ else
|
||||
mode=cover
|
||||
fi
|
||||
|
||||
# OMP_NUM_THREADS=4 python ./exps/NAS-Bench-201/main.py \
|
||||
# --mode ${mode} --save_dir ${save_dir} --max_node 4 \
|
||||
# --use_less ${use_less} \
|
||||
# --datasets cifar10 cifar10 cifar100 ImageNet16-120 \
|
||||
# --splits 1 0 0 0 \
|
||||
# --xpaths $TORCH_HOME/cifar.python \
|
||||
# $TORCH_HOME/cifar.python \
|
||||
# $TORCH_HOME/cifar.python \
|
||||
# $TORCH_HOME/cifar.python/ImageNet16 \
|
||||
# --channel 16 --num_cells 5 \
|
||||
# --workers 4 \
|
||||
# --srange ${xstart} ${xend} --arch_index ${arch_index} \
|
||||
# --seeds ${all_seeds}
|
||||
|
||||
OMP_NUM_THREADS=4 python ./exps/NAS-Bench-201/main.py \
|
||||
--mode ${mode} --save_dir ${save_dir} --max_node 4 \
|
||||
--use_less ${use_less} \
|
||||
--datasets cifar10 cifar10 cifar100 ImageNet16-120 \
|
||||
--splits 1 0 0 0 \
|
||||
--xpaths $TORCH_HOME/cifar.python \
|
||||
$TORCH_HOME/cifar.python \
|
||||
$TORCH_HOME/cifar.python \
|
||||
$TORCH_HOME/cifar.python/ImageNet16 \
|
||||
--channel 16 --num_cells 5 \
|
||||
--datasets aircraft \
|
||||
--xpaths /lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/fgvc-aircraft-2013b/data/ \
|
||||
--channel 16 \
|
||||
--splits 1 \
|
||||
--num_cells 5 \
|
||||
--workers 4 \
|
||||
--srange ${xstart} ${xend} --arch_index ${arch_index} \
|
||||
--seeds ${all_seeds}
|
||||
|
||||
|
||||
|
@ -24,6 +24,8 @@ Dataset2Class = {
|
||||
"ImageNet16-150": 150,
|
||||
"ImageNet16-120": 120,
|
||||
"ImageNet16-200": 200,
|
||||
"aircraft": 100,
|
||||
"oxford": 102
|
||||
}
|
||||
|
||||
|
||||
@ -109,6 +111,12 @@ def get_datasets(name, root, cutout):
|
||||
elif name.startswith("ImageNet16"):
|
||||
mean = [x / 255 for x in [122.68, 116.66, 104.01]]
|
||||
std = [x / 255 for x in [63.22, 61.26, 65.09]]
|
||||
elif name == 'aircraft':
|
||||
mean = [0.4785, 0.5100, 0.5338]
|
||||
std = [0.1845, 0.1830, 0.2060]
|
||||
elif name == 'oxford':
|
||||
mean = [0.4811, 0.4492, 0.3957]
|
||||
std = [0.2260, 0.2231, 0.2249]
|
||||
else:
|
||||
raise TypeError("Unknow dataset : {:}".format(name))
|
||||
|
||||
@ -127,6 +135,13 @@ def get_datasets(name, root, cutout):
|
||||
[transforms.ToTensor(), transforms.Normalize(mean, std)]
|
||||
)
|
||||
xshape = (1, 3, 32, 32)
|
||||
elif name.startswith("aircraft") or name.startswith("oxford"):
|
||||
lists = [transforms.RandomCrop(16, padding=0), transforms.ToTensor(), transforms.Normalize(mean, std)]
|
||||
if cutout > 0:
|
||||
lists += [CUTOUT(cutout)]
|
||||
train_transform = transforms.Compose(lists)
|
||||
test_transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)])
|
||||
xshape = (1, 3, 16, 16)
|
||||
elif name.startswith("ImageNet16"):
|
||||
lists = [
|
||||
transforms.RandomHorizontalFlip(),
|
||||
@ -207,6 +222,10 @@ def get_datasets(name, root, cutout):
|
||||
root, train=False, transform=test_transform, download=True
|
||||
)
|
||||
assert len(train_data) == 50000 and len(test_data) == 10000
|
||||
elif name == "aircraft":
|
||||
train_data = dset.ImageFolder(root='/lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/fgvc-aircraft-2013b/data/train_sorted_image', transform=train_transform)
|
||||
test_data = dset.ImageFolder(root='/lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/fgvc-aircraft-2013b/data/train_sorted_image', transform=test_transform)
|
||||
|
||||
elif name.startswith("imagenet-1k"):
|
||||
train_data = dset.ImageFolder(osp.join(root, "train"), train_transform)
|
||||
test_data = dset.ImageFolder(osp.join(root, "val"), test_transform)
|
||||
|
Loading…
Reference in New Issue
Block a user