update to oxford and aircraft

This commit is contained in:
mhz 2024-11-26 11:02:56 +01:00
parent 0d830dd2f6
commit a6e411a94b
2 changed files with 53 additions and 4 deletions

View File

@ -71,7 +71,7 @@ def parse_arguments():
parser.add_argument('--write_freq', type=int, default=1, help='frequency of write to file') parser.add_argument('--write_freq', type=int, default=1, help='frequency of write to file')
parser.add_argument('--start', type=int, default=0, help='start index') parser.add_argument('--start', type=int, default=0, help='start index')
parser.add_argument('--end', type=int, default=0, help='end index') parser.add_argument('--end', type=int, default=0, help='end index')
parser.add_argument('--noacc', default=False, action='store_true', parser.add_argument('--noacc', default=True, action='store_true',
help='avoid loading NASBench2 api an instead load a pickle file with tuple (index, arch_str)') help='avoid loading NASBench2 api an instead load a pickle file with tuple (index, arch_str)')
args = parser.parse_args() args = parser.parse_args()
args.device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu") args.device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
@ -94,7 +94,14 @@ if __name__ == '__main__':
x, y = next(iter(train_loader)) x, y = next(iter(train_loader))
cached_res = [] cached_res = []
pre = 'cf' if 'cifar' in args.dataset else 'im' if 'cifar' in args.dataset :
pre = 'cf'
elif 'Image' in args.dataset:
pre = 'im'
elif 'oxford' in args.dataset:
pre = 'ox'
elif 'air' in args.dataset:
pre = 'ai'
pfn = f'nb2_{args.search_space}_{pre}{get_num_classes(args)}_seed{args.seed}_dl{args.dataload}_dlinfo{args.dataload_info}_initw{args.init_w_type}_initb{args.init_b_type}_{args.batch_size}.p' pfn = f'nb2_{args.search_space}_{pre}{get_num_classes(args)}_seed{args.seed}_dl{args.dataload}_dlinfo{args.dataload_info}_initw{args.init_w_type}_initb{args.init_b_type}_{args.batch_size}.p'
op = os.path.join(args.outdir, pfn) op = os.path.join(args.outdir, pfn)

View File

@ -18,6 +18,7 @@
from torchvision.datasets import MNIST, CIFAR10, CIFAR100, SVHN from torchvision.datasets import MNIST, CIFAR10, CIFAR100, SVHN
from torchvision.transforms import Compose, ToTensor, Normalize from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision import transforms from torchvision import transforms
import torchvision.datasets as dset
from torch.utils.data import TensorDataset, DataLoader from torch.utils.data import TensorDataset, DataLoader
import torch import torch
@ -44,6 +45,14 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker
mean = (0.485, 0.456, 0.406) mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225) std = (0.229, 0.224, 0.225)
#resize = 256 #resize = 256
elif dataset == 'aircraft':
mean = (0.4785, 0.5100, 0.5338)
std = (0.1845, 0.1830, 0.2060)
size, pad = 224, 2
elif dataset == 'oxford':
mean = (0.4811, 0.4492, 0.3957)
std = (0.2260, 0.2231, 0.2249)
size, pad = 32, 0
elif 'random' in dataset: elif 'random' in dataset:
mean = (0.5, 0.5, 0.5) mean = (0.5, 0.5, 0.5)
std = (1, 1, 1) std = (1, 1, 1)
@ -65,6 +74,7 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize(mean,std), transforms.Normalize(mean,std),
]) ])
root = '/nfs/data3/hanzhang/MeCo/data'
if dataset == 'cifar10': if dataset == 'cifar10':
train_dataset = CIFAR10(datadir, True, train_transform, download=True) train_dataset = CIFAR10(datadir, True, train_transform, download=True)
@ -72,6 +82,40 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker
elif dataset == 'cifar100': elif dataset == 'cifar100':
train_dataset = CIFAR100(datadir, True, train_transform, download=True) train_dataset = CIFAR100(datadir, True, train_transform, download=True)
test_dataset = CIFAR100(datadir, False, test_transform, download=True) test_dataset = CIFAR100(datadir, False, test_transform, download=True)
elif dataset == 'aircraft':
lists = [transforms.RandomCrop(size, padding=pad), transforms.ToTensor(), transforms.Normalize(mean, std)]
# if resize != None :
# print(resize)
# lists += [CUTOUT(resize)]
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)])
train_data = dset.ImageFolder(os.path.join(root, 'train_sorted_images'), train_transform)
test_data = dset.ImageFolder(os.path.join(root, 'test_sorted_images'), test_transform)
elif dataset == 'oxford':
lists = [transforms.RandomCrop(size, padding=pad), transforms.ToTensor(), transforms.Normalize(mean, std)]
# if resize != None :
# print(resize)
# lists += [CUTOUT(resize)]
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)])
train_data = torch.load(os.path.join(root, 'train85.pth'))
test_data = torch.load(os.path.join(root, 'test15.pth'))
train_tensor_data = [(image, label) for image, label in train_data]
test_tensor_data = [(image, label) for image, label in test_data]
sum_data = train_tensor_data + test_tensor_data
train_images = [image for image, label in train_tensor_data]
train_labels = torch.tensor([label for image, label in train_tensor_data])
test_images = [image for image, label in test_tensor_data]
test_labels = torch.tensor([label for image, label in test_tensor_data])
train_tensors = torch.stack([train_transform(image) for image in train_images])
test_tensors = torch.stack([test_transform(image) for image in test_images])
train_dataset = TensorDataset(train_tensors, train_labels)
test_dataset = TensorDataset(test_tensors, test_labels)
elif dataset == 'svhn': elif dataset == 'svhn':
train_dataset = SVHN(datadir, split='train', transform=train_transform, download=True) train_dataset = SVHN(datadir, split='train', transform=train_transform, download=True)
test_dataset = SVHN(datadir, split='test', transform=test_transform, download=True) test_dataset = SVHN(datadir, split='test', transform=test_transform, download=True)
@ -97,8 +141,6 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker
shuffle=False, shuffle=False,
num_workers=num_workers, num_workers=num_workers,
pin_memory=True) pin_memory=True)
return train_loader, test_loader return train_loader, test_loader