MeCo/correlation/foresight/dataset.py

175 lines
6.9 KiB
Python

# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import os

import torch
import torchvision.datasets as dset
from torch.utils.data import TensorDataset, DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10, CIFAR100, SVHN
from torchvision.transforms import Compose, ToTensor, Normalize

from .imagenet16 import *
def _cifar_norm_stats(dataset):
    """Return ``(mean, std, size, pad)`` used to build transforms for *dataset*.

    ``mean``/``std`` are per-channel ``Normalize`` statistics, ``size`` is the
    ``RandomCrop`` size (and the default ``Resize`` target) and ``pad`` is the
    ``RandomCrop`` padding. Tests are evaluated in order, so the substring
    matches ('ImageNet16', 'cifar', 'svhn', 'random') apply as listed.

    Raises:
        ValueError: if no statistics are defined for *dataset*.
    """
    if 'ImageNet16' in dataset:
        mean = [x / 255 for x in [122.68, 116.66, 104.01]]
        std = [x / 255 for x in [63.22, 61.26, 65.09]]
        return mean, std, 16, 2
    if 'cifar' in dataset:
        return (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010), 32, 4
    if 'svhn' in dataset:
        return (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), 32, 0
    if dataset == 'ImageNet1k':
        return (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), 224, 2
    if dataset == 'aircraft':
        return (0.4785, 0.5100, 0.5338), (0.1845, 0.1830, 0.2060), 224, 2
    if dataset == 'oxford':
        return (0.4811, 0.4492, 0.3957), (0.2260, 0.2231, 0.2249), 32, 0
    if 'random' in dataset:
        return (0.5, 0.5, 0.5), (1, 1, 1), 32, 0
    # Previously this case fell through and crashed with UnboundLocalError.
    raise ValueError(f'Unsupported dataset {dataset!r}: no normalization statistics defined.')


def _pretransformed_tensor_dataset(path, transform):
    """Load ``(image, label)`` pairs saved with ``torch.save`` at *path* and
    return a ``TensorDataset`` of the transformed, stacked images.

    *transform* (including any random augmentation in it) is applied exactly
    once here, at load time, not per epoch.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only use trusted files.
    pairs = [(image, label) for image, label in torch.load(path)]
    images = torch.stack([transform(image) for image, _ in pairs])
    labels = torch.tensor([label for _, label in pairs])
    return TensorDataset(images, labels)


def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_workers, resize=None, datadir='_dataset', root=None):
    """Build ``(train_loader, test_loader)`` for *dataset*.

    Args:
        train_batch_size: batch size of the shuffled training loader.
        test_batch_size: batch size of the unshuffled test loader.
        dataset: one of 'cifar10', 'cifar100', 'svhn', 'ImageNet16-120',
            'ImageNet1k', 'aircraft', 'oxford'.
        num_workers: number of DataLoader worker processes for both loaders.
        resize: target passed to ``transforms.Resize``. Defaults to the
            dataset's crop size. Ignored by 'aircraft' and 'oxford'.
        datadir: directory for the torchvision, ImageNet16 and h5 datasets.
        root: directory holding the 'aircraft' and 'oxford' data. Defaults to
            the previously hard-coded location.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects.

    Raises:
        ValueError: if *dataset* is not supported.
    """
    mean, std, size, pad = _cifar_norm_stats(dataset)
    if resize is None:
        resize = size
    if root is None:
        root = '/nfs/data3/hanzhang/MeCo/data'

    # Default pipelines; 'aircraft' and 'oxford' replace them below.
    train_transform = transforms.Compose([
        transforms.RandomCrop(size, padding=pad),
        transforms.Resize(resize),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    test_transform = transforms.Compose([
        transforms.Resize(resize),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    if dataset == 'cifar10':
        train_dataset = CIFAR10(datadir, True, train_transform, download=True)
        test_dataset = CIFAR10(datadir, False, test_transform, download=True)
    elif dataset == 'cifar100':
        train_dataset = CIFAR100(datadir, True, train_transform, download=True)
        test_dataset = CIFAR100(datadir, False, test_transform, download=True)
    elif dataset in ('aircraft', 'oxford'):
        # These ignore `resize`: training only crops (no Resize/flip) and
        # testing resizes to a fixed 224x224.
        # NOTE(review): for 'oxford' (size 32) train images are 32x32 crops but
        # test images are 224x224 -- confirm this asymmetry is intended.
        train_transform = transforms.Compose([
            transforms.RandomCrop(size, padding=pad),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        test_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        if dataset == 'aircraft':
            train_dataset = dset.ImageFolder(os.path.join(root, 'train_sorted_images'), train_transform)
            test_dataset = dset.ImageFolder(os.path.join(root, 'test_sorted_images'), test_transform)
        else:
            train_dataset = _pretransformed_tensor_dataset(os.path.join(root, 'train85.pth'), train_transform)
            test_dataset = _pretransformed_tensor_dataset(os.path.join(root, 'test15.pth'), test_transform)
    elif dataset == 'svhn':
        train_dataset = SVHN(datadir, split='train', transform=train_transform, download=True)
        test_dataset = SVHN(datadir, split='test', transform=test_transform, download=True)
    elif dataset == 'ImageNet16-120':
        train_dataset = ImageNet16(os.path.join(datadir, 'ImageNet16'), True, train_transform, 120)
        test_dataset = ImageNet16(os.path.join(datadir, 'ImageNet16'), False, test_transform, 120)
    elif dataset == 'ImageNet1k':
        from .h5py_dataset import H5Dataset
        train_dataset = H5Dataset(os.path.join(datadir, 'imagenet-train-256.h5'), transform=train_transform)
        test_dataset = H5Dataset(os.path.join(datadir, 'imagenet-val-256.h5'), transform=test_transform)
    else:
        # Reached by names with normalization stats but no loader, e.g.
        # 'random' or 'ImageNet16' variants other than 'ImageNet16-120'.
        raise ValueError('There are no more cifars or imagenets.')

    train_loader = DataLoader(
        train_dataset,
        train_batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True)
    test_loader = DataLoader(
        test_dataset,
        test_batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True)
    return train_loader, test_loader
def get_mnist_dataloaders(train_batch_size, val_batch_size, num_workers, datadir='_dataset'):
    """Build ``(train_loader, test_loader)`` for MNIST.

    Images are only converted to tensors; no normalization is applied.

    Args:
        train_batch_size: batch size of the shuffled training loader.
        val_batch_size: batch size of the unshuffled test loader.
        num_workers: number of DataLoader worker processes for both loaders.
        datadir: download/cache directory. Defaults to '_dataset', the
            previously hard-coded location.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects.
    """
    # NOTE(review): normalization is intentionally left out; if wanted, the
    # candidate is Normalize((0.1307,), (0.3081,)).
    data_transform = Compose([ToTensor()])
    train_dataset = MNIST(datadir, True, data_transform, download=True)
    test_dataset = MNIST(datadir, False, data_transform, download=True)
    train_loader = DataLoader(
        train_dataset,
        train_batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True)
    test_loader = DataLoader(
        test_dataset,
        val_batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True)
    return train_loader, test_loader
if __name__ == '__main__':
    # Smoke test: build the CIFAR-10 loaders and print the shape of one batch.
    # ('random' is not usable here: get_cifar_dataloaders has normalization
    # stats for it but no dataset branch, so it raises ValueError.)
    tr, _ = get_cifar_dataloaders(64, 64, 'cifar10', 2, resize=None, datadir='_dataset')
    x, y = next(iter(tr))
    print(x.size(), y.size())