# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import os

import torch
import torchvision.datasets as dset
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms
from torchvision.datasets import CIFAR10, CIFAR100, MNIST, SVHN
from torchvision.transforms import Compose

from .imagenet16 import *


def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_workers, resize=None, datadir='_dataset'):
    """Build train/test DataLoaders for the image-classification datasets used here."""
    # Per-dataset channel statistics, crop size, and crop padding.
    if 'ImageNet16' in dataset:
        mean = [x / 255 for x in [122.68, 116.66, 104.01]]
        std = [x / 255 for x in [63.22, 61.26, 65.09]]
        size, pad = 16, 2
    elif 'cifar' in dataset:
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
        size, pad = 32, 4
    elif 'svhn' in dataset:
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)
        size, pad = 32, 0
    elif dataset == 'ImageNet1k':
        from .h5py_dataset import H5Dataset
        size, pad = 224, 2
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
    elif dataset == 'aircraft':
        mean = (0.4785, 0.5100, 0.5338)
        std = (0.1845, 0.1830, 0.2060)
        size, pad = 224, 2
    elif dataset == 'oxford':
        mean = (0.4811, 0.4492, 0.3957)
        std = (0.2260, 0.2231, 0.2249)
        size, pad = 32, 0
    elif 'random' in dataset:
        mean = (0.5, 0.5, 0.5)
        std = (1, 1, 1)
        size, pad = 32, 0
    else:
        # Fail early; otherwise `mean`/`std` would be unbound below.
        raise ValueError(f'Unknown dataset: {dataset}')

    if resize is None:
        resize = size

    train_transform = transforms.Compose([
        transforms.RandomCrop(size, padding=pad),
        transforms.Resize(resize),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    test_transform = transforms.Compose([
        transforms.Resize(resize),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # NOTE: hard-coded data root; used only by the 'aircraft' and 'oxford' branches.
    root = '/nfs/data3/hanzhang/MeCo/data'

    if dataset == 'cifar10':
        train_dataset = CIFAR10(datadir, True, train_transform, download=True)
        test_dataset = CIFAR10(datadir, False, test_transform, download=True)
    elif dataset == 'cifar100':
        train_dataset = CIFAR100(datadir, True, train_transform, download=True)
        test_dataset = CIFAR100(datadir, False, test_transform, download=True)
    elif dataset == 'aircraft':
        # Aircraft is stored as ImageFolder trees; rebuild the transforms
        # without the train-time Resize used above.
        train_transform = transforms.Compose([
            transforms.RandomCrop(size, padding=pad),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        test_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        train_dataset = dset.ImageFolder(os.path.join(root, 'train_sorted_images'), train_transform)
        test_dataset = dset.ImageFolder(os.path.join(root, 'test_sorted_images'), test_transform)
    elif dataset == 'oxford':
        train_transform = transforms.Compose([
            transforms.RandomCrop(size, padding=pad),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        test_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

        # The Oxford split is stored as pickled lists of (image, label) pairs.
        train_data = torch.load(os.path.join(root, 'train85.pth'))
        test_data = torch.load(os.path.join(root, 'test15.pth'))

        train_labels = torch.tensor([label for _, label in train_data])
        test_labels = torch.tensor([label for _, label in test_data])

        # Apply the transforms eagerly and wrap the results in TensorDatasets.
        train_tensors = torch.stack([train_transform(image) for image, _ in train_data])
        test_tensors = torch.stack([test_transform(image) for image, _ in test_data])

        train_dataset = TensorDataset(train_tensors, train_labels)
        test_dataset = TensorDataset(test_tensors, test_labels)
    elif dataset == 'svhn':
        train_dataset = SVHN(datadir, split='train', transform=train_transform, download=True)
        test_dataset = SVHN(datadir, split='test', transform=test_transform, download=True)
    elif dataset == 'ImageNet16-120':
        train_dataset = ImageNet16(os.path.join(datadir, 'ImageNet16'), True, train_transform, 120)
        test_dataset = ImageNet16(os.path.join(datadir, 'ImageNet16'), False, test_transform, 120)
    elif dataset == 'ImageNet1k':
        train_dataset = H5Dataset(os.path.join(datadir, 'imagenet-train-256.h5'), transform=train_transform)
        test_dataset = H5Dataset(os.path.join(datadir, 'imagenet-val-256.h5'), transform=test_transform)
    else:
        raise ValueError(f'No dataset constructor for: {dataset}')

    train_loader = DataLoader(
        train_dataset,
        train_batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True)
    test_loader = DataLoader(
        test_dataset,
        test_batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True)

    return train_loader, test_loader
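

# The per-channel mean/std constants above are precomputed dataset statistics.
# As a hedged illustration (not part of the original pipeline), this is roughly
# how such numbers can be recomputed; the helper name `compute_channel_stats`
# is our own, and it assumes the dataset yields (C, H, W) float tensors, e.g.
# compute_channel_stats(CIFAR10('_dataset', True, transforms.ToTensor(), download=True)).
def compute_channel_stats(dataset, batch_size=256, num_workers=2):
    """Estimate per-channel mean and std over an image dataset."""
    loader = DataLoader(dataset, batch_size, shuffle=False, num_workers=num_workers)
    n_pixels = 0
    channel_sum = None
    channel_sq_sum = None
    for images, _ in loader:
        # Flatten spatial dims: (B, C, H, W) -> (B, C, H*W).
        flat = images.flatten(start_dim=2)
        n_pixels += flat.shape[0] * flat.shape[2]
        batch_sum = flat.sum(dim=(0, 2))
        batch_sq = (flat ** 2).sum(dim=(0, 2))
        channel_sum = batch_sum if channel_sum is None else channel_sum + batch_sum
        channel_sq_sum = batch_sq if channel_sq_sum is None else channel_sq_sum + batch_sq
    mean = channel_sum / n_pixels
    # Var[x] = E[x^2] - E[x]^2, computed per channel.
    std = (channel_sq_sum / n_pixels - mean ** 2).sqrt()
    return mean, std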

def get_mnist_dataloaders(train_batch_size, val_batch_size, num_workers):
    """Build train/val DataLoaders for MNIST."""
    # Images are left unnormalized here; the usual MNIST statistics would be
    # transforms.Normalize((0.1307,), (0.3081,)).
    data_transform = Compose([transforms.ToTensor()])

    train_dataset = MNIST("_dataset", True, data_transform, download=True)
    test_dataset = MNIST("_dataset", False, data_transform, download=True)

    train_loader = DataLoader(
        train_dataset,
        train_batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True)
    test_loader = DataLoader(
        test_dataset,
        val_batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True)

    return train_loader, test_loader
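

# `get_cifar_dataloaders` defines normalization statistics for a 'random'
# dataset but has no branch that constructs one, so requesting it raises
# ValueError. A minimal sketch of what such a synthetic loader could look
# like; the helper name, tensor shapes, and defaults below are our assumptions:
def get_random_dataloaders(train_batch_size, test_batch_size, num_workers,
                           num_classes=10, size=32, n_train=1024, n_test=256):
    """Serve Gaussian-noise images with random labels, e.g. for smoke tests."""
    train_dataset = TensorDataset(
        torch.randn(n_train, 3, size, size),
        torch.randint(num_classes, (n_train,)))
    test_dataset = TensorDataset(
        torch.randn(n_test, 3, size, size),
        torch.randint(num_classes, (n_test,)))
    train_loader = DataLoader(train_dataset, train_batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(test_dataset, test_batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=True)
    return train_loader, test_loader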

if __name__ == '__main__':
    # Quick smoke test. 'random' only has normalization statistics above, not
    # a dataset constructor, so use cifar10 here (downloads to ./_dataset).
    tr, te = get_cifar_dataloaders(64, 64, 'cifar10', 2, resize=None, datadir='_dataset')
    for x, y in tr:
        print(x.size(), y.size())
        break
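    # MNIST smoke test using the loader defined above (downloads MNIST to
    # ./_dataset on first run).
    tr_mnist, _ = get_mnist_dataloaders(64, 64, 2)
    for x, y in tr_mnist:
        print('mnist:', x.size(), y.size())
        break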