add aircraft root
This commit is contained in:
		| @@ -25,6 +25,32 @@ import torch | |||||||
| from .imagenet16 import * | from .imagenet16 import * | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class CUTOUT(object): | ||||||
|  |     def __init__(self, length): | ||||||
|  |         self.length = length | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name}(length={length})".format( | ||||||
|  |             name=self.__class__.__name__, **self.__dict__ | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def __call__(self, img): | ||||||
|  |         h, w = img.size(1), img.size(2) | ||||||
|  |         mask = np.ones((h, w), np.float32) | ||||||
|  |         y = np.random.randint(h) | ||||||
|  |         x = np.random.randint(w) | ||||||
|  |  | ||||||
|  |         y1 = np.clip(y - self.length // 2, 0, h) | ||||||
|  |         y2 = np.clip(y + self.length // 2, 0, h) | ||||||
|  |         x1 = np.clip(x - self.length // 2, 0, w) | ||||||
|  |         x2 = np.clip(x + self.length // 2, 0, w) | ||||||
|  |  | ||||||
|  |         mask[y1:y2, x1:x2] = 0.0 | ||||||
|  |         mask = torch.from_numpy(mask) | ||||||
|  |         mask = mask.expand_as(img) | ||||||
|  |         img *= mask | ||||||
|  |         return img | ||||||
|  |      | ||||||
| def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_workers, resize=None, datadir='_dataset'): | def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_workers, resize=None, datadir='_dataset'): | ||||||
|     # print(dataset) |     # print(dataset) | ||||||
|     if 'ImageNet16' in dataset: |     if 'ImageNet16' in dataset: | ||||||
| @@ -74,7 +100,8 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker | |||||||
|         transforms.ToTensor(), |         transforms.ToTensor(), | ||||||
|         transforms.Normalize(mean,std), |         transforms.Normalize(mean,std), | ||||||
|     ]) |     ]) | ||||||
|     root = '/nfs/data3/hanzhang/MeCo/data' |     root = '/home/iicd/MeCo/data' | ||||||
|  |     aircraft_dataset_root = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data' | ||||||
|  |  | ||||||
|     if dataset == 'cifar10': |     if dataset == 'cifar10': | ||||||
|         train_dataset = CIFAR10(datadir, True, train_transform, download=True) |         train_dataset = CIFAR10(datadir, True, train_transform, download=True) | ||||||
| @@ -84,18 +111,18 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker | |||||||
|         test_dataset = CIFAR100(datadir, False, test_transform, download=True) |         test_dataset = CIFAR100(datadir, False, test_transform, download=True) | ||||||
|     elif dataset == 'aircraft':  |     elif dataset == 'aircraft':  | ||||||
|         lists = [transforms.RandomCrop(size, padding=pad), transforms.ToTensor(), transforms.Normalize(mean, std)] |         lists = [transforms.RandomCrop(size, padding=pad), transforms.ToTensor(), transforms.Normalize(mean, std)] | ||||||
|         # if resize != None :  |         if resize != None :  | ||||||
|         #     print(resize) |             print(resize) | ||||||
|         #     lists += [CUTOUT(resize)] |             lists += [CUTOUT(resize)] | ||||||
|         train_transform = transforms.Compose(lists) |         train_transform = transforms.Compose(lists) | ||||||
|         test_transform  = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)]) |         test_transform  = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||||
|         train_data = dset.ImageFolder(os.path.join(root, 'train_sorted_images'), train_transform) |         train_data = dset.ImageFolder(os.path.join(aircraft_dataset_root, 'train_sorted_images'), train_transform) | ||||||
|         test_data  = dset.ImageFolder(os.path.join(root, 'test_sorted_images'),  test_transform) |         test_data  = dset.ImageFolder(os.path.join(aircraft_dataset_root, 'test_sorted_images'),  test_transform) | ||||||
|     elif dataset == 'oxford': |     elif dataset == 'oxford': | ||||||
|         lists = [transforms.RandomCrop(size, padding=pad), transforms.ToTensor(), transforms.Normalize(mean, std)] |         lists = [transforms.RandomCrop(size, padding=pad), transforms.ToTensor(), transforms.Normalize(mean, std)] | ||||||
|         # if resize != None :  |         if resize != None :  | ||||||
|         #     print(resize) |             print(resize) | ||||||
|         #     lists += [CUTOUT(resize)] |             lists += [CUTOUT(resize)] | ||||||
|         train_transform = transforms.Compose(lists) |         train_transform = transforms.Compose(lists) | ||||||
|         test_transform  = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)]) |         test_transform  = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user