# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import os

from torchvision.datasets import MNIST, CIFAR10, CIFAR100, SVHN
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision import transforms
import torchvision.datasets as dset
from torch.utils.data import TensorDataset, DataLoader
import torch

from .imagenet16 import *


def _normalization_for(dataset):
    """Return ``(mean, std, size, pad)`` for the dataset called *dataset*.

    ``mean``/``std`` are the per-channel values handed to
    ``transforms.Normalize``; ``size``/``pad`` are the ``RandomCrop`` size and
    padding used by the training transform.

    Raises:
        ValueError: if *dataset* matches none of the known names.
    """
    # The order of these checks matters: the first three (and the last one)
    # are substring matches, the others are exact matches.
    if 'ImageNet16' in dataset:
        mean = [x / 255 for x in [122.68, 116.66, 104.01]]
        std = [x / 255 for x in [63.22, 61.26, 65.09]]
        size, pad = 16, 2
    elif 'cifar' in dataset:
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
        size, pad = 32, 4
    elif 'svhn' in dataset:
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)
        size, pad = 32, 0
    elif dataset == 'ImageNet1k':
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
        size, pad = 224, 2
    elif dataset == 'aircraft':
        mean = (0.4785, 0.5100, 0.5338)
        std = (0.1845, 0.1830, 0.2060)
        size, pad = 224, 2
    elif dataset == 'oxford':
        mean = (0.4811, 0.4492, 0.3957)
        std = (0.2260, 0.2231, 0.2249)
        size, pad = 32, 0
    elif 'random' in dataset:
        mean = (0.5, 0.5, 0.5)
        std = (1, 1, 1)
        size, pad = 32, 0
    else:
        # Previously this fell through and crashed later on an unbound `size`.
        raise ValueError('Unknown dataset: {!r}'.format(dataset))
    return mean, std, size, pad


def _to_tensor_dataset(samples, transform):
    """Apply *transform* to every image of *samples* and wrap the result.

    *samples* is iterated once as ``(image, label)`` pairs.  The transform is
    applied here, a single time, so any random augmentation it contains is
    fixed when the dataset is built rather than re-drawn on each epoch.
    """
    images, labels = [], []
    for image, label in samples:
        images.append(transform(image))
        labels.append(label)
    return TensorDataset(torch.stack(images), torch.tensor(labels))


def _make_loaders(train_dataset, test_dataset, train_batch_size, test_batch_size, num_workers):
    """Build the ``(train_loader, test_loader)`` pair; only the train loader shuffles."""
    train_loader = DataLoader(
        train_dataset,
        train_batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True)
    test_loader = DataLoader(
        test_dataset,
        test_batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True)
    return train_loader, test_loader


def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_workers, resize=None,
                          datadir='_dataset', root='/nfs/data3/hanzhang/MeCo/data'):
    """Create train and test ``DataLoader`` objects for the named dataset.

    Args:
        train_batch_size: batch size of the training loader.
        test_batch_size: batch size of the test loader.
        dataset: one of ``'cifar10'``, ``'cifar100'``, ``'svhn'``,
            ``'ImageNet16-120'``, ``'ImageNet1k'``, ``'aircraft'`` or
            ``'oxford'``.
        num_workers: worker process count passed to both loaders.
        resize: target of the ``Resize`` step in the default transforms;
            defaults to the dataset's crop size.  The ``'aircraft'`` and
            ``'oxford'`` branches build their own transforms and ignore it.
        datadir: directory for the torchvision / ImageNet16 / HDF5 datasets
            (CIFAR and SVHN are downloaded there if missing).
        root: directory holding the local ``'aircraft'`` image folders
            (``train_sorted_images`` / ``test_sorted_images``) and the
            ``'oxford'`` files (``train85.pth`` / ``test15.pth``).

    Returns:
        ``(train_loader, test_loader)``.

    Raises:
        ValueError: if *dataset* is unknown, or has normalization statistics
            but no dataset implementation (currently the ``'random'`` names).
    """
    mean, std, size, pad = _normalization_for(dataset)

    if resize is None:
        resize = size

    train_transform = transforms.Compose([
        transforms.RandomCrop(size, padding=pad),
        transforms.Resize(resize),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    test_transform = transforms.Compose([
        transforms.Resize(resize),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    if dataset == 'cifar10':
        train_dataset = CIFAR10(datadir, True, train_transform, download=True)
        test_dataset = CIFAR10(datadir, False, test_transform, download=True)
    elif dataset == 'cifar100':
        train_dataset = CIFAR100(datadir, True, train_transform, download=True)
        test_dataset = CIFAR100(datadir, False, test_transform, download=True)
    elif dataset in ('aircraft', 'oxford'):
        # Both local datasets use the same crop-only training transform (no
        # flip, no resize) and a fixed 224x224 resize for evaluation.
        # NOTE(review): for 'oxford' the crop size is 32 while test images are
        # resized to 224x224 -- confirm this mismatch is intended.
        train_transform = transforms.Compose([
            transforms.RandomCrop(size, padding=pad),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        test_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        if dataset == 'aircraft':
            # Fix: these used to be bound to train_data/test_data, leaving
            # train_dataset/test_dataset undefined when the loaders were built.
            train_dataset = dset.ImageFolder(os.path.join(root, 'train_sorted_images'), train_transform)
            test_dataset = dset.ImageFolder(os.path.join(root, 'test_sorted_images'), test_transform)
        else:
            # torch.load unpickles the files: only point `root` at trusted data.
            train_data = torch.load(os.path.join(root, 'train85.pth'))
            test_data = torch.load(os.path.join(root, 'test15.pth'))
            train_dataset = _to_tensor_dataset(train_data, train_transform)
            test_dataset = _to_tensor_dataset(test_data, test_transform)
    elif dataset == 'svhn':
        train_dataset = SVHN(datadir, split='train', transform=train_transform, download=True)
        test_dataset = SVHN(datadir, split='test', transform=test_transform, download=True)
    elif dataset == 'ImageNet16-120':
        train_dataset = ImageNet16(os.path.join(datadir, 'ImageNet16'), True, train_transform, 120)
        test_dataset = ImageNet16(os.path.join(datadir, 'ImageNet16'), False, test_transform, 120)
    elif dataset == 'ImageNet1k':
        # Imported lazily so the HDF5 reader is only needed for this dataset.
        from .h5py_dataset import H5Dataset
        train_dataset = H5Dataset(os.path.join(datadir, 'imagenet-train-256.h5'), transform=train_transform)
        test_dataset = H5Dataset(os.path.join(datadir, 'imagenet-val-256.h5'), transform=test_transform)
    else:
        # Reached by names that only match a substring check above, e.g.
        # 'random' or 'cifar10-valid'.
        raise ValueError('There are no more cifars or imagenets.')

    return _make_loaders(train_dataset, test_dataset, train_batch_size, test_batch_size, num_workers)


def get_mnist_dataloaders(train_batch_size, val_batch_size, num_workers):
    """Create train and test ``DataLoader`` objects for MNIST.

    The data is stored in (and downloaded to, if missing) ``"_dataset"``.
    Images are only converted to tensors; no normalization is applied.

    Returns:
        ``(train_loader, test_loader)``.
    """
    data_transform = Compose([transforms.ToTensor()])

    # Normalise? transforms.Normalize((0.1307,), (0.3081,))

    train_dataset = MNIST("_dataset", True, data_transform, download=True)
    test_dataset = MNIST("_dataset", False, data_transform, download=True)

    return _make_loaders(train_dataset, test_dataset, train_batch_size, val_batch_size, num_workers)


if __name__ == '__main__':
    # NOTE(review): 'random' has normalization statistics but no dataset
    # branch, so this call raises ValueError -- confirm the intended dataset.
    tr, te = get_cifar_dataloaders(64, 64, 'random', 2, resize=None, datadir='_dataset')
    for x, y in tr:
        print(x.size(), y.size())
        break