diff --git a/correlation/foresight/dataset.py b/correlation/foresight/dataset.py
index 25de156..005cc5a 100644
--- a/correlation/foresight/dataset.py
+++ b/correlation/foresight/dataset.py
@@ -25,6 +25,32 @@ import torch
 
 from .imagenet16 import *
 
+class CUTOUT(object):
+    def __init__(self, length):
+        self.length = length
+
+    def __repr__(self):
+        return "{name}(length={length})".format(
+            name=self.__class__.__name__, **self.__dict__
+        )
+
+    def __call__(self, img):
+        h, w = img.size(1), img.size(2)
+        mask = np.ones((h, w), np.float32)
+        y = np.random.randint(h)
+        x = np.random.randint(w)
+
+        y1 = np.clip(y - self.length // 2, 0, h)
+        y2 = np.clip(y + self.length // 2, 0, h)
+        x1 = np.clip(x - self.length // 2, 0, w)
+        x2 = np.clip(x + self.length // 2, 0, w)
+
+        mask[y1:y2, x1:x2] = 0.0
+        mask = torch.from_numpy(mask)
+        mask = mask.expand_as(img)
+        img *= mask
+        return img
+
 def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_workers, resize=None, datadir='_dataset'):
     # print(dataset)
     if 'ImageNet16' in dataset:
@@ -74,7 +100,8 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker
             transforms.ToTensor(),
             transforms.Normalize(mean,std),
         ])
-        root = '/nfs/data3/hanzhang/MeCo/data'
+        root = '/home/iicd/MeCo/data'
+        aircraft_dataset_root = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data'
 
     if dataset == 'cifar10':
         train_dataset = CIFAR10(datadir, True, train_transform, download=True)
@@ -84,18 +111,18 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker
         test_dataset = CIFAR100(datadir, False, test_transform, download=True)
     elif dataset == 'aircraft':
         lists = [transforms.RandomCrop(size, padding=pad), transforms.ToTensor(), transforms.Normalize(mean, std)]
-        # if resize != None :
-        #     print(resize)
-        #     lists += [CUTOUT(resize)]
+        if resize != None :
+            print(resize)
+            lists += [CUTOUT(resize)]
         train_transform = transforms.Compose(lists)
         test_transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)])
-        train_data = dset.ImageFolder(os.path.join(root, 'train_sorted_images'), train_transform)
-        test_data = dset.ImageFolder(os.path.join(root, 'test_sorted_images'), test_transform)
+        train_data = dset.ImageFolder(os.path.join(aircraft_dataset_root, 'train_sorted_images'), train_transform)
+        test_data = dset.ImageFolder(os.path.join(aircraft_dataset_root, 'test_sorted_images'), test_transform)
     elif dataset == 'oxford':
         lists = [transforms.RandomCrop(size, padding=pad), transforms.ToTensor(), transforms.Normalize(mean, std)]
-        # if resize != None :
-        #     print(resize)
-        #     lists += [CUTOUT(resize)]
+        if resize != None :
+            print(resize)
+            lists += [CUTOUT(resize)]
         train_transform = transforms.Compose(lists)
         test_transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)])
 
@@ -172,4 +199,4 @@ if __name__ == '__main__':
    tr, te = get_cifar_dataloaders(64, 64, 'random', 2, resize=None, datadir='_dataset')
    for x, y in tr:
        print(x.size(), y.size())
-        break
\ No newline at end of file
+        break
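
For review purposes, below is a minimal, self-contained sketch of how the CUTOUT transform added in this diff composes with the aircraft/oxford training pipelines. The values of `size`, `pad`, `mean`, `std`, and `resize` are placeholders (the real ones are computed earlier in `get_cifar_dataloaders`), and the import assumes the repository root is on `PYTHONPATH` with the `correlation.foresight` package importable; this is not how the module invokes it verbatim.

```python
# Sketch only: size, pad, mean, std, and resize are assumed placeholder values,
# not the settings computed inside get_cifar_dataloaders.
import numpy as np
from PIL import Image
from torchvision import transforms

# Assumes the repo root is on PYTHONPATH; CUTOUT is the class added in this diff.
from correlation.foresight.dataset import CUTOUT

size, pad = 224, 4                               # assumed crop settings
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]  # assumed ImageNet stats
resize = 16                                      # side length of the square CUTOUT zeroes out

# Same construction as the aircraft/oxford branches: CUTOUT is appended last,
# so it multiplies the already-normalized CHW float tensor by a {0, 1} mask.
lists = [transforms.RandomCrop(size, padding=pad),
         transforms.ToTensor(),
         transforms.Normalize(mean, std)]
if resize is not None:
    lists += [CUTOUT(resize)]
train_transform = transforms.Compose(lists)

# Apply to a dummy RGB image and confirm a region was zeroed out.
dummy = Image.fromarray(np.uint8(np.random.rand(256, 256, 3) * 255))
out = train_transform(dummy)
print(out.shape)                  # torch.Size([3, 224, 224])
print((out == 0).sum().item())    # roughly 3 * resize * resize zeroed elements
```

Because CUTOUT runs after Normalize, the erased square is zero in normalized space, which corresponds to the per-channel mean in raw pixel terms.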