v2
This commit is contained in:
		| @@ -3,3 +3,4 @@ | ||||
| ################################################## | ||||
| from .get_dataset_with_transform import get_datasets, get_nas_search_loaders | ||||
| from .SearchDatasetWrap import SearchDataset | ||||
| from .data import get_data | ||||
|   | ||||
							
								
								
									
										69
									
								
								datasets/data.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								datasets/data.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,69 @@ | ||||
| from datasets import get_datasets | ||||
| from config_utils import load_config | ||||
| import torch | ||||
| import torchvision | ||||
|  | ||||
| class AddGaussianNoise(object): | ||||
|     def __init__(self, mean=0., std=0.001): | ||||
|         self.std = std | ||||
|         self.mean = mean | ||||
|                                      | ||||
|     def __call__(self, tensor): | ||||
|         return tensor + torch.randn(tensor.size()) * self.std + self.mean | ||||
|                                                      | ||||
|     def __repr__(self): | ||||
|         return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| class RepeatSampler(torch.utils.data.sampler.Sampler): | ||||
|     def __init__(self, samp, repeat): | ||||
|         self.samp = samp | ||||
|         self.repeat = repeat | ||||
|     def __iter__(self): | ||||
|         for i in self.samp: | ||||
|             for j in range(self.repeat): | ||||
|                 yield i | ||||
|     def __len__(self): | ||||
|         return self.repeat*len(self.samp) | ||||
|  | ||||
|  | ||||
| def get_data(dataset, data_loc, trainval, batch_size, augtype, repeat, args, pin_memory=True): | ||||
|     train_data, valid_data, xshape, class_num = get_datasets(dataset, data_loc, cutout=0) | ||||
|     if augtype == 'gaussnoise': | ||||
|         train_data.transform.transforms = train_data.transform.transforms[2:] | ||||
|         train_data.transform.transforms.append(AddGaussianNoise(std=args.sigma)) | ||||
|     elif augtype == 'cutout': | ||||
|         train_data.transform.transforms = train_data.transform.transforms[2:] | ||||
|         train_data.transform.transforms.append(torchvision.transforms.RandomErasing(p=0.9, scale=(0.02, 0.04))) | ||||
|     elif augtype == 'none': | ||||
|         train_data.transform.transforms = train_data.transform.transforms[2:] | ||||
|      | ||||
|     if dataset == 'cifar10': | ||||
|         acc_type = 'ori-test' | ||||
|         val_acc_type = 'x-valid' | ||||
|      | ||||
|     else: | ||||
|         acc_type = 'x-test' | ||||
|         val_acc_type = 'x-valid' | ||||
|      | ||||
|     if trainval and 'cifar10' in dataset: | ||||
|         cifar_split = load_config('config_utils/cifar-split.txt', None, None) | ||||
|         train_split, valid_split = cifar_split.train, cifar_split.valid | ||||
|         if repeat > 0: | ||||
|             train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, | ||||
|                                                        num_workers=0, pin_memory=pin_memory, sampler= RepeatSampler(torch.utils.data.sampler.SubsetRandomSampler(train_split), repeat)) | ||||
|         else: | ||||
|             train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, | ||||
|                                                        num_workers=0, pin_memory=pin_memory, sampler= torch.utils.data.sampler.SubsetRandomSampler(train_split)) | ||||
|          | ||||
|      | ||||
|     else: | ||||
|         if repeat > 0: | ||||
|             train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, #shuffle=True, | ||||
|                                                        num_workers=0, pin_memory=pin_memory, sampler= RepeatSampler(torch.utils.data.sampler.SubsetRandomSampler(range(len(train_data))), repeat)) | ||||
|         else: | ||||
|             train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, | ||||
|                                                        num_workers=0, pin_memory=pin_memory) | ||||
|     return train_loader | ||||
| @@ -16,7 +16,9 @@ from config_utils import load_config | ||||
|  | ||||
| Dataset2Class = {'cifar10' : 10, | ||||
|                  'cifar100': 100, | ||||
|                  'fake':10, | ||||
|                  'imagenet-1k-s':1000, | ||||
|                  'imagenette2' : 10, | ||||
|                  'imagenet-1k' : 1000, | ||||
|                  'ImageNet16'  : 1000, | ||||
|                  'ImageNet16-150': 150, | ||||
| @@ -98,8 +100,13 @@ def get_datasets(name, root, cutout): | ||||
|   elif name == 'cifar100': | ||||
|     mean = [x / 255 for x in [129.3, 124.1, 112.4]] | ||||
|     std  = [x / 255 for x in [68.2, 65.4, 70.4]] | ||||
|   elif name == 'fake': | ||||
|     mean = [x / 255 for x in [129.3, 124.1, 112.4]] | ||||
|     std  = [x / 255 for x in [68.2, 65.4, 70.4]] | ||||
|   elif name.startswith('imagenet-1k'): | ||||
|     mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] | ||||
|   elif name.startswith('imagenette'): | ||||
|     mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] | ||||
|   elif name.startswith('ImageNet16'): | ||||
|     mean = [x / 255 for x in [122.68, 116.66, 104.01]] | ||||
|     std  = [x / 255 for x in [63.22,  61.26 , 65.09]] | ||||
| @@ -113,6 +120,12 @@ def get_datasets(name, root, cutout): | ||||
|     train_transform = transforms.Compose(lists) | ||||
|     test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||
|     xshape = (1, 3, 32, 32) | ||||
|   elif name == 'fake': | ||||
|     lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)] | ||||
|     if cutout > 0 : lists += [CUTOUT(cutout)] | ||||
|     train_transform = transforms.Compose(lists) | ||||
|     test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||
|     xshape = (1, 3, 32, 32) | ||||
|   elif name.startswith('ImageNet16'): | ||||
|     lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)] | ||||
|     if cutout > 0 : lists += [CUTOUT(cutout)] | ||||
| @@ -125,6 +138,15 @@ def get_datasets(name, root, cutout): | ||||
|     train_transform = transforms.Compose(lists) | ||||
|     test_transform  = transforms.Compose([transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||
|     xshape = (1, 3, 32, 32) | ||||
|   elif name.startswith('imagenette'): | ||||
|     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | ||||
|     xlists = [] | ||||
|     xlists.append( transforms.ToTensor() ) | ||||
|     xlists.append( normalize ) | ||||
|     #train_transform = transforms.Compose(xlists) | ||||
|     train_transform  = transforms.Compose([normalize, normalize, transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize]) | ||||
|     test_transform  = transforms.Compose([normalize, normalize, transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize]) | ||||
|     xshape = (1, 3, 224, 224) | ||||
|   elif name.startswith('imagenet-1k'): | ||||
|     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | ||||
|     if name == 'imagenet-1k': | ||||
| @@ -156,6 +178,12 @@ def get_datasets(name, root, cutout): | ||||
|     train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True) | ||||
|     test_data  = dset.CIFAR100(root, train=False, transform=test_transform , download=True) | ||||
|     assert len(train_data) == 50000 and len(test_data) == 10000 | ||||
|   elif name == 'fake': | ||||
|     train_data = dset.FakeData(size=50000, image_size=(3, 32, 32), transform=train_transform) | ||||
|     test_data = dset.FakeData(size=10000, image_size=(3, 32, 32), transform=test_transform) | ||||
|   elif name.startswith('imagenette2'): | ||||
|     train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform) | ||||
|     test_data  = dset.ImageFolder(osp.join(root, 'val'),   test_transform) | ||||
|   elif name.startswith('imagenet-1k'): | ||||
|     train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform) | ||||
|     test_data  = dset.ImageFolder(osp.join(root, 'val'),   test_transform) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user