383 lines
14 KiB
Python
383 lines
14 KiB
Python
##################################################
|
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
|
##################################################
|
|
import os, sys, torch
|
|
import os.path as osp
|
|
import numpy as np
|
|
import torchvision.datasets as dset
|
|
import torchvision.transforms as transforms
|
|
from copy import deepcopy
|
|
from PIL import Image
|
|
|
|
from xautodl.config_utils import load_config
|
|
|
|
from .DownsampledImageNet import ImageNet16
|
|
from .SearchDatasetWrap import SearchDataset
|
|
|
|
|
|
# Number of target classes for every dataset name known to this module.
# Insertion order is kept stable in case callers iterate over the mapping.
Dataset2Class = {
    # CIFAR
    "cifar10": 10,
    "cifar100": 100,
    # full-resolution ImageNet variants
    "imagenet-1k-s": 1000,
    "imagenet-1k": 1000,
    # down-sampled ImageNet and its class-restricted subsets
    "ImageNet16": 1000,
    "ImageNet16-150": 150,
    "ImageNet16-120": 120,
    "ImageNet16-200": 200,
    # additional image-folder style datasets
    "aircraft": 100,
    "oxford": 102,
}
|
|
|
|
|
|
class CUTOUT(object):
    """Transform that zeroes one randomly placed square patch of an image tensor.

    The tensor is modified in place and also returned.  Dimensions 1 and 2 of
    the input are treated as height and width (e.g. a ``C x H x W`` tensor).
    """

    def __init__(self, length):
        # Side length of the square patch; half of it is applied on each side
        # of the randomly drawn centre.
        self.length = length

    def __repr__(self):
        return "{:}(length={:})".format(self.__class__.__name__, self.length)

    def __call__(self, img):
        """Zero a random patch of ``img`` in place and return ``img``."""
        height, width = img.size(1), img.size(2)

        # Draw the patch centre: row first, then column, so the sequence of
        # numbers consumed from NumPy's global RNG stays the same.
        center_y = np.random.randint(height)
        center_x = np.random.randint(width)

        # Clip the patch to the image borders; patches near an edge shrink.
        half = self.length // 2
        top, bottom = np.clip([center_y - half, center_y + half], 0, height)
        left, right = np.clip([center_x - half, center_x + half], 0, width)

        keep = np.ones((height, width), np.float32)
        keep[top:bottom, left:right] = 0.0

        # The 2-D mask is broadcast over the leading dimension(s) of ``img``.
        img *= torch.from_numpy(keep).expand_as(img)
        return img
|
|
|
|
|
|
# PCA statistics used as the default eigenvalues / eigenvectors of the
# ``Lighting`` transform: a length-3 vector and a 3x3 matrix.
_IMAGENET_EIGVAL = [0.2175, 0.0188, 0.0045]
_IMAGENET_EIGVEC = [
    [-0.5675, 0.7192, 0.4009],
    [-0.5808, -0.0045, -0.8140],
    [-0.5836, -0.6948, 0.4203],
]
imagenet_pca = {
    "eigval": np.asarray(_IMAGENET_EIGVAL),
    "eigvec": np.asarray(_IMAGENET_EIGVEC),
}
|
|
|
|
|
|
class Lighting(object):
    """Transform that adds a random per-channel colour offset to an RGB image.

    The offset is ``eigvec @ (alpha * eigval)`` where ``alpha`` is drawn from a
    normal distribution scaled by ``alphastd``.  The result is returned as a
    PIL ``RGB`` image; with ``alphastd == 0`` the input is returned untouched.
    """

    def __init__(
        self, alphastd, eigval=imagenet_pca["eigval"], eigvec=imagenet_pca["eigvec"]
    ):
        self.alphastd = alphastd
        # The offset below is computed for exactly three channels.
        assert eigval.shape == (3,)
        assert eigvec.shape == (3, 3)
        self.eigval = eigval
        self.eigvec = eigvec

    def __call__(self, img):
        """Return ``img`` with the random colour offset applied."""
        if self.alphastd == 0.0:
            return img

        alpha = (np.random.randn(3) * self.alphastd).astype("float32")
        # Remember the input dtype so the output can be cast back to it.
        source_dtype = np.asarray(img).dtype

        scaled = (alpha * self.eigval).reshape((3, 1))
        offset = np.dot(self.eigvec, scaled).reshape((3,))
        shifted = np.add(img, offset)

        # 8-bit images must stay inside the valid pixel range before the cast.
        if source_dtype == np.uint8:
            shifted = np.clip(shifted, 0, 255)
        return Image.fromarray(shifted.astype(source_dtype), "RGB")

    def __repr__(self):
        return "{:}()".format(self.__class__.__name__)
|
|
|
|
|
|
# NOTE(review): the aircraft images are read from this hard-coded location and
# the ``root`` argument of ``get_datasets`` is ignored for that dataset.  The
# same folder is also used for both the train and the test split -- confirm
# that this is intended.
_AIRCRAFT_IMAGE_DIR = '/lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/fgvc-aircraft-2013b/data/train_sorted_image'

# name -> (class-subset argument or None, expected #train, expected #test)
_IMAGENET16_VARIANTS = {
    "ImageNet16": (None, 1281167, 50000),
    "ImageNet16-120": (120, 151700, 6000),
    "ImageNet16-150": (150, 190272, 7500),
    "ImageNet16-200": (200, 254775, 10000),
}


def _normalization_stats(name):
    """Return the per-channel ``(mean, std)`` lists used to normalize ``name``.

    Raises:
        TypeError: if ``name`` is not a supported dataset name.
    """
    if name == "cifar10":
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif name == "cifar100":
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif name.startswith("imagenet-1k"):
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    elif name.startswith("ImageNet16"):
        mean = [x / 255 for x in [122.68, 116.66, 104.01]]
        std = [x / 255 for x in [63.22, 61.26, 65.09]]
    elif name == 'aircraft':
        mean = [0.4785, 0.5100, 0.5338]
        std = [0.1845, 0.1830, 0.2060]
    elif name == 'oxford':
        mean = [0.4811, 0.4492, 0.3957]
        std = [0.2260, 0.2231, 0.2249]
    else:
        raise TypeError("Unknow dataset : {:}".format(name))
    return mean, std


def _with_cutout(ops, cutout):
    """Return ``ops`` with a ``CUTOUT`` transform appended when ``cutout > 0``."""
    if cutout > 0:
        ops = ops + [CUTOUT(cutout)]
    return ops


def _build_transforms(name, mean, std, cutout):
    """Build ``(train_transform, test_transform, xshape)`` for dataset ``name``.

    ``name`` is expected to have passed ``_normalization_stats`` already.

    Raises:
        ValueError: for an ``imagenet-1k*`` name other than the two known ones.
        TypeError: if no augmentation recipe exists for ``name``.
    """
    if name == "cifar10" or name == "cifar100":
        train_ops = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        test_ops = [transforms.ToTensor(), transforms.Normalize(mean, std)]
        train_ops = _with_cutout(train_ops, cutout)
        xshape = (1, 3, 32, 32)
    elif name in ("aircraft", "oxford"):
        # NOTE(review): training crops 16x16 patches while testing resizes to
        # 224x224, and xshape reports 16x16 -- confirm this mismatch is wanted.
        train_ops = [
            transforms.RandomCrop(16, padding=0),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        test_ops = [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        train_ops = _with_cutout(train_ops, cutout)
        xshape = (1, 3, 16, 16)
    elif name.startswith("ImageNet16"):
        train_ops = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(16, padding=2),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        test_ops = [transforms.ToTensor(), transforms.Normalize(mean, std)]
        train_ops = _with_cutout(train_ops, cutout)
        xshape = (1, 3, 16, 16)
    elif name.startswith("imagenet-1k"):
        # ``cutout`` is not used for the full-resolution ImageNet recipes.
        normalize = transforms.Normalize(mean=mean, std=std)
        if name == "imagenet-1k":
            train_ops = [
                transforms.RandomResizedCrop(224),
                transforms.ColorJitter(
                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2
                ),
                Lighting(0.1),
            ]
        elif name == "imagenet-1k-s":
            train_ops = [transforms.RandomResizedCrop(224, scale=(0.2, 1.0))]
        else:
            raise ValueError("invalid name : {:}".format(name))
        train_ops += [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            normalize,
        ]
        test_ops = [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]
        xshape = (1, 3, 224, 224)
    else:
        raise TypeError("Unknow dataset : {:}".format(name))
    return transforms.Compose(train_ops), transforms.Compose(test_ops), xshape


def _build_datasets(name, root, train_transform, test_transform):
    """Instantiate ``(train_data, test_data)`` for ``name`` and check their sizes.

    Raises:
        TypeError: if no dataset loader exists for ``name``.  This currently
            includes "oxford", which has statistics and transforms but no
            loader.
        AssertionError: if a dataset does not have the expected sample counts.
    """
    if name == "cifar10" or name == "cifar100":
        dataset_cls = dset.CIFAR10 if name == "cifar10" else dset.CIFAR100
        train_data = dataset_cls(
            root, train=True, transform=train_transform, download=True
        )
        test_data = dataset_cls(
            root, train=False, transform=test_transform, download=True
        )
        assert len(train_data) == 50000 and len(test_data) == 10000
    elif name == "aircraft":
        train_data = dset.ImageFolder(
            root=_AIRCRAFT_IMAGE_DIR, transform=train_transform
        )
        test_data = dset.ImageFolder(
            root=_AIRCRAFT_IMAGE_DIR, transform=test_transform
        )
    elif name.startswith("imagenet-1k"):
        train_data = dset.ImageFolder(osp.join(root, "train"), train_transform)
        test_data = dset.ImageFolder(osp.join(root, "val"), test_transform)
        assert (
            len(train_data) == 1281167 and len(test_data) == 50000
        ), "invalid number of images : {:} & {:} vs {:} & {:}".format(
            len(train_data), len(test_data), 1281167, 50000
        )
    elif name in _IMAGENET16_VARIANTS:
        num_classes, num_train, num_test = _IMAGENET16_VARIANTS[name]
        # The full dataset is built without the class-subset argument so the
        # constructor's own default is used.
        extra = () if num_classes is None else (num_classes,)
        train_data = ImageNet16(root, True, train_transform, *extra)
        test_data = ImageNet16(root, False, test_transform, *extra)
        assert len(train_data) == num_train and len(test_data) == num_test
    else:
        raise TypeError("Unknow dataset : {:}".format(name))
    return train_data, test_data


def get_datasets(name, root, cutout):
    """Create the train/test datasets, with their transforms, for ``name``.

    Args:
        name: dataset identifier, e.g. "cifar10", "cifar100", "imagenet-1k",
            "imagenet-1k-s", "ImageNet16", "ImageNet16-120/150/200" or
            "aircraft".
        root: directory handed to the dataset constructors (not used for
            "aircraft", which reads from a fixed path).
        cutout: when greater than zero, a ``CUTOUT(cutout)`` transform is
            appended to the training pipeline (ignored for ``imagenet-1k*``).

    Returns:
        ``(train_data, test_data, xshape, class_num)`` where ``xshape`` is the
        shape tuple associated with the dataset and ``class_num`` comes from
        ``Dataset2Class``.

    Raises:
        TypeError: if ``name`` is unknown or has no dataset loader.
        ValueError: for an unrecognised ``imagenet-1k*`` variant.
        AssertionError: if a loaded dataset has an unexpected size.
    """
    mean, std = _normalization_stats(name)

    # Data augmentation
    train_transform, test_transform, xshape = _build_transforms(
        name, mean, std, cutout
    )

    train_data, test_data = _build_datasets(
        name, root, train_transform, test_transform
    )

    class_num = Dataset2Class[name]
    return train_data, test_data, xshape, class_num
|
|
|
|
|
|
def get_nas_search_loaders(
    train_data, valid_data, dataset, config_root, batch_size, workers
):
    """Build the ``(search_loader, train_loader, valid_loader)`` triple.

    Args:
        train_data: training dataset object.
        valid_data: held-out dataset object; its ``transform`` is reused.
        dataset: one of "cifar10", "cifar100" or "ImageNet16-120".
        config_root: directory containing the split description files.
        batch_size: a single size, or a ``(train, test)`` pair of sizes.
        workers: number of worker processes for every loader.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    if isinstance(batch_size, (list, tuple)):
        batch, test_batch = batch_size
    else:
        batch = test_batch = batch_size

    subset_sampler = torch.utils.data.sampler.SubsetRandomSampler

    def make_loader(data, size, **sampling):
        # All loaders share worker count and pinned memory; only the data,
        # the batch size and the sampling strategy differ.
        return torch.utils.data.DataLoader(
            data,
            batch_size=size,
            num_workers=workers,
            pin_memory=True,
            **sampling
        )

    if dataset == "cifar10":
        # The split file provides two index lists (``train`` and ``valid``)
        # that are both applied to ``train_data``.
        cifar_split = load_config("{:}/cifar-split.txt".format(config_root), None, None)
        train_split = cifar_split.train
        valid_split = cifar_split.valid

        # Validation samples come from a copy of the training set that uses
        # the evaluation transform instead of the training one.
        xvalid_data = deepcopy(train_data)
        if hasattr(xvalid_data, "transforms"):  # to avoid a print issue
            xvalid_data.transforms = valid_data.transform
        xvalid_data.transform = deepcopy(valid_data.transform)

        search_data = SearchDataset(dataset, train_data, train_split, valid_split)
        search_loader = make_loader(search_data, batch, shuffle=True)
        train_loader = make_loader(
            train_data, batch, sampler=subset_sampler(train_split)
        )
        valid_loader = make_loader(
            xvalid_data, test_batch, sampler=subset_sampler(valid_split)
        )
    elif dataset in ("cifar100", "ImageNet16-120"):
        if dataset == "cifar100":
            split_file = "cifar100-test-split.txt"
        else:
            split_file = "imagenet-16-120-test-split.txt"
        test_split = load_config(
            "{:}/{:}".format(config_root, split_file), None, None
        )

        # For searching, the held-out set is copied and given the training
        # transform; the untouched ``valid_data`` feeds the validation loader.
        search_train_data = train_data
        search_valid_data = deepcopy(valid_data)
        search_valid_data.transform = train_data.transform

        search_data = SearchDataset(
            dataset,
            [search_train_data, search_valid_data],
            list(range(len(search_train_data))),
            test_split.xvalid,
        )
        search_loader = make_loader(search_data, batch, shuffle=True)
        train_loader = make_loader(train_data, batch, shuffle=True)
        valid_loader = make_loader(
            valid_data, test_batch, sampler=subset_sampler(test_split.xvalid)
        )
    else:
        raise ValueError("invalid dataset : {:}".format(dataset))

    return search_loader, train_loader, valid_loader
|
|
|
|
|
|
# if __name__ == '__main__':
|
|
# train_data, test_data, xshape, class_num = dataset = get_datasets('cifar10', '/data02/dongxuanyi/.torch/cifar.python/', -1)
|
|
# import pdb; pdb.set_trace()
|