add aircraft root
This commit is contained in:
parent
cd80aa277c
commit
4df61fcbb3
@ -25,6 +25,32 @@ import torch
|
||||
from .imagenet16 import *
|
||||
|
||||
|
||||
class CUTOUT(object):
|
||||
def __init__(self, length):
|
||||
self.length = length
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(length={length})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def __call__(self, img):
|
||||
h, w = img.size(1), img.size(2)
|
||||
mask = np.ones((h, w), np.float32)
|
||||
y = np.random.randint(h)
|
||||
x = np.random.randint(w)
|
||||
|
||||
y1 = np.clip(y - self.length // 2, 0, h)
|
||||
y2 = np.clip(y + self.length // 2, 0, h)
|
||||
x1 = np.clip(x - self.length // 2, 0, w)
|
||||
x2 = np.clip(x + self.length // 2, 0, w)
|
||||
|
||||
mask[y1:y2, x1:x2] = 0.0
|
||||
mask = torch.from_numpy(mask)
|
||||
mask = mask.expand_as(img)
|
||||
img *= mask
|
||||
return img
|
||||
|
||||
def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_workers, resize=None, datadir='_dataset'):
|
||||
# print(dataset)
|
||||
if 'ImageNet16' in dataset:
|
||||
@ -74,7 +100,8 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean,std),
|
||||
])
|
||||
root = '/nfs/data3/hanzhang/MeCo/data'
|
||||
root = '/home/iicd/MeCo/data'
|
||||
aircraft_dataset_root = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data'
|
||||
|
||||
if dataset == 'cifar10':
|
||||
train_dataset = CIFAR10(datadir, True, train_transform, download=True)
|
||||
@ -84,18 +111,18 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker
|
||||
test_dataset = CIFAR100(datadir, False, test_transform, download=True)
|
||||
elif dataset == 'aircraft':
|
||||
lists = [transforms.RandomCrop(size, padding=pad), transforms.ToTensor(), transforms.Normalize(mean, std)]
|
||||
# if resize != None :
|
||||
# print(resize)
|
||||
# lists += [CUTOUT(resize)]
|
||||
if resize != None :
|
||||
print(resize)
|
||||
lists += [CUTOUT(resize)]
|
||||
train_transform = transforms.Compose(lists)
|
||||
test_transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)])
|
||||
train_data = dset.ImageFolder(os.path.join(root, 'train_sorted_images'), train_transform)
|
||||
test_data = dset.ImageFolder(os.path.join(root, 'test_sorted_images'), test_transform)
|
||||
train_data = dset.ImageFolder(os.path.join(aircraft_dataset_root, 'train_sorted_images'), train_transform)
|
||||
test_data = dset.ImageFolder(os.path.join(aircraft_dataset_root, 'test_sorted_images'), test_transform)
|
||||
elif dataset == 'oxford':
|
||||
lists = [transforms.RandomCrop(size, padding=pad), transforms.ToTensor(), transforms.Normalize(mean, std)]
|
||||
# if resize != None :
|
||||
# print(resize)
|
||||
# lists += [CUTOUT(resize)]
|
||||
if resize != None :
|
||||
print(resize)
|
||||
lists += [CUTOUT(resize)]
|
||||
train_transform = transforms.Compose(lists)
|
||||
test_transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)])
|
||||
|
||||
@ -172,4 +199,4 @@ if __name__ == '__main__':
|
||||
tr, te = get_cifar_dataloaders(64, 64, 'random', 2, resize=None, datadir='_dataset')
|
||||
for x, y in tr:
|
||||
print(x.size(), y.size())
|
||||
break
|
||||
break
|
||||
|
Loading…
Reference in New Issue
Block a user