# MeCo/correlation/foresight/dataset.py
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import os

import numpy as np
import torch
import torchvision.datasets as dset
from torch.utils.data import TensorDataset, DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10, CIFAR100, SVHN
from torchvision.transforms import Compose, ToTensor, Normalize

from .imagenet16 import *
class CUTOUT(object):
    """Cutout augmentation: zero out a random square patch of a CHW tensor.

    A `length` x `length` square centered on a uniformly random pixel is
    masked to zero in place (the square is clipped at the image borders).
    """

    def __init__(self, length):
        self.length = length

    def __repr__(self):
        return f"{self.__class__.__name__}(length={self.length})"

    def __call__(self, img):
        height, width = img.size(1), img.size(2)
        # Pick the patch center uniformly over the image.
        center_y = np.random.randint(height)
        center_x = np.random.randint(width)
        half = self.length // 2
        top = np.clip(center_y - half, 0, height)
        bottom = np.clip(center_y + half, 0, height)
        left = np.clip(center_x - half, 0, width)
        right = np.clip(center_x + half, 0, width)
        # Build a single-channel mask and broadcast it over all channels.
        patch = np.ones((height, width), np.float32)
        patch[top:bottom, left:right] = 0.0
        img *= torch.from_numpy(patch).expand_as(img)
        return img
def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_workers, resize=None, datadir='_dataset'):
    """Build (train_loader, test_loader) for the supported image datasets.

    Args:
        train_batch_size: batch size of the training loader.
        test_batch_size: batch size of the test loader.
        dataset: one of 'cifar10', 'cifar100', 'svhn', 'ImageNet16-120',
            'ImageNet1k', 'aircraft', 'oxford', or any name containing
            'random' (synthetic data, useful for smoke tests).
        num_workers: number of DataLoader worker processes.
        resize: optional target side length; defaults to the dataset's
            native size. For 'aircraft'/'oxford' it is also reused as the
            CUTOUT length.
        datadir: download/cache directory for the torchvision datasets.

    Returns:
        Tuple (train_loader, test_loader) of torch DataLoader objects.

    Raises:
        ValueError: if `dataset` is not recognized.
    """
    # Per-dataset normalization statistics and native input geometry.
    if 'ImageNet16' in dataset:
        mean = [x / 255 for x in [122.68, 116.66, 104.01]]
        std = [x / 255 for x in [63.22, 61.26, 65.09]]
        size, pad = 16, 2
    elif 'cifar' in dataset:
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
        size, pad = 32, 4
    elif 'svhn' in dataset:
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)
        size, pad = 32, 0
    elif dataset == 'ImageNet1k':
        from .h5py_dataset import H5Dataset  # local import: optional h5py dependency
        size, pad = 224, 2
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
        # resize = 256
    elif dataset == 'aircraft':
        mean = (0.4785, 0.5100, 0.5338)
        std = (0.1845, 0.1830, 0.2060)
        size, pad = 224, 2
    elif dataset == 'oxford':
        mean = (0.4811, 0.4492, 0.3957)
        std = (0.2260, 0.2231, 0.2249)
        size, pad = 32, 0
    elif 'random' in dataset:
        # Synthetic data is drawn in [0, 1); mean 0.5 / std 1 keeps it centered.
        mean = (0.5, 0.5, 0.5)
        std = (1, 1, 1)
        size, pad = 32, 0
    else:
        # Previously an unknown name crashed later with NameError on `mean`;
        # fail fast with the same exception type the caller already handles.
        raise ValueError('There are no more cifars or imagenets.')

    if resize is None:
        resize = size

    train_transform = transforms.Compose([
        transforms.RandomCrop(size, padding=pad),
        transforms.Resize(resize),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    test_transform = transforms.Compose([
        transforms.Resize(resize),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # TODO(review): machine-specific absolute paths; consider promoting these
    # to keyword parameters (kept as-is for backward compatibility).
    root = '/home/iicd/MeCo/data'
    aircraft_dataset_root = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data'

    if dataset == 'cifar10':
        train_dataset = CIFAR10(datadir, True, train_transform, download=True)
        test_dataset = CIFAR10(datadir, False, test_transform, download=True)
    elif dataset == 'cifar100':
        train_dataset = CIFAR100(datadir, True, train_transform, download=True)
        test_dataset = CIFAR100(datadir, False, test_transform, download=True)
    elif dataset == 'aircraft':
        aug = [transforms.RandomCrop(size, padding=pad),
               transforms.ToTensor(),
               transforms.Normalize(mean, std)]
        # NOTE(review): resize is never None here (defaulted above), so CUTOUT
        # is always applied; CUTOUT(resize) with resize == size blanks most of
        # the crop — confirm the intended cutout length.
        if resize is not None:
            aug.append(CUTOUT(resize))
        train_transform = transforms.Compose(aug)
        test_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        # BUG FIX: these were assigned to train_data/test_data, leaving
        # train_dataset/test_dataset undefined (NameError at the DataLoader).
        train_dataset = dset.ImageFolder(
            os.path.join(aircraft_dataset_root, 'train_sorted_images'), train_transform)
        test_dataset = dset.ImageFolder(
            os.path.join(aircraft_dataset_root, 'test_sorted_images'), test_transform)
    elif dataset == 'oxford':
        aug = [transforms.RandomCrop(size, padding=pad),
               transforms.ToTensor(),
               transforms.Normalize(mean, std)]
        if resize is not None:  # always true here, see note in 'aircraft'
            aug.append(CUTOUT(resize))
        train_transform = transforms.Compose(aug)
        test_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        # The .pth files hold (image, label) pairs; transform eagerly and
        # wrap the stacked tensors in TensorDatasets.
        train_pairs = torch.load(os.path.join(root, 'train85.pth'))
        test_pairs = torch.load(os.path.join(root, 'test15.pth'))
        train_labels = torch.tensor([label for _, label in train_pairs])
        test_labels = torch.tensor([label for _, label in test_pairs])
        train_tensors = torch.stack([train_transform(image) for image, _ in train_pairs])
        test_tensors = torch.stack([test_transform(image) for image, _ in test_pairs])
        train_dataset = TensorDataset(train_tensors, train_labels)
        test_dataset = TensorDataset(test_tensors, test_labels)
    elif dataset == 'svhn':
        train_dataset = SVHN(datadir, split='train', transform=train_transform, download=True)
        test_dataset = SVHN(datadir, split='test', transform=test_transform, download=True)
    elif dataset == 'ImageNet16-120':
        train_dataset = ImageNet16(os.path.join(datadir, 'ImageNet16'), True, train_transform, 120)
        test_dataset = ImageNet16(os.path.join(datadir, 'ImageNet16'), False, test_transform, 120)
    elif dataset == 'ImageNet1k':
        train_dataset = H5Dataset(os.path.join(datadir, 'imagenet-train-256.h5'), transform=train_transform)
        test_dataset = H5Dataset(os.path.join(datadir, 'imagenet-val-256.h5'), transform=test_transform)
    elif 'random' in dataset:
        # BUG FIX: 'random' had normalization stats but no dataset branch, so
        # the module's own __main__ smoke test raised ValueError. Generate a
        # small deterministic synthetic dataset (10 fake classes).
        gen = torch.Generator().manual_seed(0)
        train_dataset = TensorDataset(
            torch.rand(1024, 3, resize, resize, generator=gen),
            torch.randint(0, 10, (1024,), generator=gen))
        test_dataset = TensorDataset(
            torch.rand(256, 3, resize, resize, generator=gen),
            torch.randint(0, 10, (256,), generator=gen))
    else:
        raise ValueError('There are no more cifars or imagenets.')

    train_loader = DataLoader(
        train_dataset,
        train_batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True)
    test_loader = DataLoader(
        test_dataset,
        test_batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True)
    return train_loader, test_loader
def get_mnist_dataloaders(train_batch_size, val_batch_size, num_workers):
    """Build (train_loader, test_loader) for MNIST.

    Images are only converted to tensors; no normalization is applied
    (the usual stats would be mean 0.1307, std 0.3081).
    """
    tensorize = Compose([transforms.ToTensor()])
    # Normalise? transforms.Normalize((0.1307,), (0.3081,))
    train_set = MNIST("_dataset", True, tensorize, download=True)
    eval_set = MNIST("_dataset", False, tensorize, download=True)
    loader_opts = dict(num_workers=num_workers, pin_memory=True)
    train_loader = DataLoader(train_set, train_batch_size, shuffle=True, **loader_opts)
    test_loader = DataLoader(eval_set, val_batch_size, shuffle=False, **loader_opts)
    return train_loader, test_loader
if __name__ == '__main__':
    # Smoke test: fetch one batch from the synthetic 'random' dataset and
    # print its shapes.
    train_loader, _test_loader = get_cifar_dataloaders(64, 64, 'random', 2, resize=None, datadir='_dataset')
    for images, labels in train_loader:
        print(images.size(), labels.size())
        break