update to oxford and aircraft

commit a6e411a94b
parent 0d830dd2f6
@@ -71,7 +71,7 @@ def parse_arguments():
     parser.add_argument('--write_freq', type=int, default=1, help='frequency of write to file')
     parser.add_argument('--start', type=int, default=0, help='start index')
     parser.add_argument('--end', type=int, default=0, help='end index')
-    parser.add_argument('--noacc', default=False, action='store_true',
+    parser.add_argument('--noacc', default=True, action='store_true',
                         help='avoid loading NASBench2 api an instead load a pickle file with tuple (index, arch_str)')
     args = parser.parse_args()
     args.device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
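Note on the hunk above: with action='store_true' and default=True, --noacc is effectively always on and can no longer be switched off from the command line. A minimal sketch of one way to keep the flag toggleable (an assumption for illustration, not part of this commit; requires Python 3.9+ for argparse.BooleanOptionalAction):

    import argparse

    parser = argparse.ArgumentParser()
    # BooleanOptionalAction adds a paired --no-noacc switch automatically
    parser.add_argument('--noacc', default=True, action=argparse.BooleanOptionalAction,
                        help='avoid loading the NASBench2 api and instead load a pickle file')
    args = parser.parse_args(['--no-noacc'])  # example: explicitly turn the flag off
    print(args.noacc)  # False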
@@ -94,7 +94,14 @@ if __name__ == '__main__':
     x, y = next(iter(train_loader))
 
     cached_res = []
-    pre = 'cf' if 'cifar' in args.dataset else 'im'
+    if 'cifar' in args.dataset :
+        pre = 'cf'
+    elif 'Image' in args.dataset:
+        pre = 'im'
+    elif 'oxford' in args.dataset:
+        pre = 'ox'
+    elif 'air' in args.dataset:
+        pre = 'ai'
     pfn = f'nb2_{args.search_space}_{pre}{get_num_classes(args)}_seed{args.seed}_dl{args.dataload}_dlinfo{args.dataload_info}_initw{args.init_w_type}_initb{args.init_b_type}_{args.batch_size}.p'
     op = os.path.join(args.outdir, pfn)
 
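The new if/elif chain above picks a short dataset prefix for the cache filename; note that a dataset string matching none of the four substrings leaves pre unassigned, so the f-string after it would raise NameError. An equivalent sketch of the same mapping as an ordered lookup (illustrative only, the names are not from the commit):

    # first matching substring wins, mirroring the if/elif order in the diff
    PREFIXES = [('cifar', 'cf'), ('Image', 'im'), ('oxford', 'ox'), ('air', 'ai')]

    def dataset_prefix(name):
        return next((pre for key, pre in PREFIXES if key in name), None)

    assert dataset_prefix('cifar10') == 'cf'
    assert dataset_prefix('oxford') == 'ox'
    assert dataset_prefix('aircraft') == 'ai'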
@@ -18,6 +18,7 @@
 from torchvision.datasets import MNIST, CIFAR10, CIFAR100, SVHN
 from torchvision.transforms import Compose, ToTensor, Normalize
 from torchvision import transforms
+import torchvision.datasets as dset
 from torch.utils.data import TensorDataset, DataLoader
 import torch
 
@@ -44,6 +45,14 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker
         mean = (0.485, 0.456, 0.406)
         std = (0.229, 0.224, 0.225)
         #resize = 256
+    elif dataset == 'aircraft':
+        mean = (0.4785, 0.5100, 0.5338)
+        std = (0.1845, 0.1830, 0.2060)
+        size, pad = 224, 2
+    elif dataset == 'oxford':
+        mean = (0.4811, 0.4492, 0.3957)
+        std = (0.2260, 0.2231, 0.2249)
+        size, pad = 32, 0
     elif 'random' in dataset:
         mean = (0.5, 0.5, 0.5)
         std = (1, 1, 1)
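The two new branches above hard-code per-channel normalization statistics plus crop size/padding for the aircraft and oxford datasets. For reference, a generic sketch of how such channel means/stds can be computed over a dataset of ToTensor() images (an illustrative recipe, not necessarily how the committed values were produced):

    import torch
    from torch.utils.data import DataLoader

    def channel_stats(dataset, batch_size=256):
        # assumes all images share the same spatial size and lie in [0, 1]
        loader = DataLoader(dataset, batch_size=batch_size)
        n, mean, sq = 0, torch.zeros(3), torch.zeros(3)
        for images, _ in loader:                  # images: (B, 3, H, W)
            b = images.size(0)
            flat = images.view(b, 3, -1)
            mean += flat.mean(dim=2).sum(dim=0)
            sq += flat.pow(2).mean(dim=2).sum(dim=0)
            n += b
        mean /= n
        std = (sq / n - mean.pow(2)).sqrt()
        return tuple(mean.tolist()), tuple(std.tolist())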
@@ -65,6 +74,7 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker
         transforms.ToTensor(),
         transforms.Normalize(mean,std),
     ])
+    root = '/nfs/data3/hanzhang/MeCo/data'
 
     if dataset == 'cifar10':
         train_dataset = CIFAR10(datadir, True, train_transform, download=True)
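The added root points at a fixed NFS path. A small sketch of making it overridable (an assumed alternative, not part of the commit; the environment variable name is hypothetical):

    import os

    # fall back to the committed NFS path when MECO_DATA_ROOT is not set
    root = os.environ.get('MECO_DATA_ROOT', '/nfs/data3/hanzhang/MeCo/data')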
@@ -72,6 +82,40 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker
     elif dataset == 'cifar100':
         train_dataset = CIFAR100(datadir, True, train_transform, download=True)
         test_dataset = CIFAR100(datadir, False, test_transform, download=True)
+    elif dataset == 'aircraft':
+        lists = [transforms.RandomCrop(size, padding=pad), transforms.ToTensor(), transforms.Normalize(mean, std)]
+        # if resize != None :
+        #     print(resize)
+        #     lists += [CUTOUT(resize)]
+        train_transform = transforms.Compose(lists)
+        test_transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)])
+        train_data = dset.ImageFolder(os.path.join(root, 'train_sorted_images'), train_transform)
+        test_data = dset.ImageFolder(os.path.join(root, 'test_sorted_images'), test_transform)
+    elif dataset == 'oxford':
+        lists = [transforms.RandomCrop(size, padding=pad), transforms.ToTensor(), transforms.Normalize(mean, std)]
+        # if resize != None :
+        #     print(resize)
+        #     lists += [CUTOUT(resize)]
+        train_transform = transforms.Compose(lists)
+        test_transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)])
+
+        train_data = torch.load(os.path.join(root, 'train85.pth'))
+        test_data = torch.load(os.path.join(root, 'test15.pth'))
+
+        train_tensor_data = [(image, label) for image, label in train_data]
+        test_tensor_data = [(image, label) for image, label in test_data]
+        sum_data = train_tensor_data + test_tensor_data
+
+        train_images = [image for image, label in train_tensor_data]
+        train_labels = torch.tensor([label for image, label in train_tensor_data])
+        test_images = [image for image, label in test_tensor_data]
+        test_labels = torch.tensor([label for image, label in test_tensor_data])
+
+        train_tensors = torch.stack([train_transform(image) for image in train_images])
+        test_tensors = torch.stack([test_transform(image) for image in test_images])
+
+        train_dataset = TensorDataset(train_tensors, train_labels)
+        test_dataset = TensorDataset(test_tensors, test_labels)
     elif dataset == 'svhn':
         train_dataset = SVHN(datadir, split='train', transform=train_transform, download=True)
         test_dataset = SVHN(datadir, split='test', transform=test_transform, download=True)
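In the oxford branch above, the saved (image, label) pairs from train85.pth/test15.pth are transformed eagerly, stacked, and wrapped in TensorDataset, so the RandomCrop augmentation is applied exactly once and the whole split is held in memory; sum_data is built but not referenced later in this hunk. A self-contained sketch of the same eager TensorDataset pattern on dummy data (assumes each saved item is a (PIL.Image, int) pair):

    import torch
    from PIL import Image
    from torchvision import transforms
    from torch.utils.data import TensorDataset, DataLoader

    transform = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])
    pairs = [(Image.new('RGB', (64, 64)), i % 2) for i in range(8)]  # stand-in for torch.load(...)

    images = torch.stack([transform(img) for img, _ in pairs])  # (8, 3, 32, 32)
    labels = torch.tensor([label for _, label in pairs])        # (8,)
    loader = DataLoader(TensorDataset(images, labels), batch_size=4, shuffle=True)
    x, y = next(iter(loader))  # x: (4, 3, 32, 32), y: (4,)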
@@ -97,8 +141,6 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker
         shuffle=False,
         num_workers=num_workers,
         pin_memory=True)
-
-
     return train_loader, test_loader
 
 
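With these hunks applied, the loader can presumably be requested for the new dataset strings the same way as before; a usage sketch (argument order taken from the truncated signature in the hunk headers, so treated as an assumption):

    # positional args: train_batch_size, test_batch_size, dataset, num_workers
    train_loader, test_loader = get_cifar_dataloaders(64, 64, 'oxford', 2)
    x, y = next(iter(train_loader))
    print(x.shape, y.shape)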