add aircraft root

This commit is contained in:
Hanzhang Ma 2024-11-27 16:58:12 +01:00
parent cd80aa277c
commit 4df61fcbb3

View File

@ -25,6 +25,32 @@ import torch
from .imagenet16 import *
class CUTOUT(object):
def __init__(self, length):
self.length = length
def __repr__(self):
return "{name}(length={length})".format(
name=self.__class__.__name__, **self.__dict__
)
def __call__(self, img):
h, w = img.size(1), img.size(2)
mask = np.ones((h, w), np.float32)
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1:y2, x1:x2] = 0.0
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img *= mask
return img
def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_workers, resize=None, datadir='_dataset'):
# print(dataset)
if 'ImageNet16' in dataset:
@ -74,7 +100,8 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker
transforms.ToTensor(),
transforms.Normalize(mean,std),
])
root = '/nfs/data3/hanzhang/MeCo/data'
root = '/home/iicd/MeCo/data'
aircraft_dataset_root = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data'
if dataset == 'cifar10':
train_dataset = CIFAR10(datadir, True, train_transform, download=True)
@ -84,18 +111,18 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker
test_dataset = CIFAR100(datadir, False, test_transform, download=True)
elif dataset == 'aircraft':
lists = [transforms.RandomCrop(size, padding=pad), transforms.ToTensor(), transforms.Normalize(mean, std)]
# if resize != None :
# print(resize)
# lists += [CUTOUT(resize)]
if resize != None :
print(resize)
lists += [CUTOUT(resize)]
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)])
train_data = dset.ImageFolder(os.path.join(root, 'train_sorted_images'), train_transform)
test_data = dset.ImageFolder(os.path.join(root, 'test_sorted_images'), test_transform)
train_data = dset.ImageFolder(os.path.join(aircraft_dataset_root, 'train_sorted_images'), train_transform)
test_data = dset.ImageFolder(os.path.join(aircraft_dataset_root, 'test_sorted_images'), test_transform)
elif dataset == 'oxford':
lists = [transforms.RandomCrop(size, padding=pad), transforms.ToTensor(), transforms.Normalize(mean, std)]
# if resize != None :
# print(resize)
# lists += [CUTOUT(resize)]
if resize != None :
print(resize)
lists += [CUTOUT(resize)]
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)])