add aircraft in utilities

This commit is contained in:
Mhrooz 2024-08-31 15:52:00 +02:00
parent 4df5615380
commit 24f15ad0fe

View File

@ -13,7 +13,8 @@ Dataset2Class = {'cifar10': 10,
'ImageNet16' : 1000,
'ImageNet16-120': 120,
'ImageNet16-150': 150,
'ImageNet16-200': 200}
'ImageNet16-200': 200,
'aircraft': 100}
class RandChannel(object):
# randomly pick channels from input
@ -46,6 +47,10 @@ def get_datasets(name, root, input_size, cutout=-1):
elif name.startswith('ImageNet16'):
mean = [0.481098, 0.45749, 0.407882]
std = [0.247922, 0.240235, 0.255255]
elif name == 'aircraft':
mean = [0.4785, 0.5100, 0.5338]
std = [0.1845, 0.1830, 0.2060]
else:
raise TypeError("Unknow dataset : {:}".format(name))
@ -55,6 +60,12 @@ def get_datasets(name, root, input_size, cutout=-1):
if cutout > 0 : lists += [CUTOUT(cutout)]
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
elif name == 'aircraft':
lists = [transforms.RandomCrop(input_size[1], padding=0), transforms.ToTensor(), transforms.Normalize(mean, std)]
if cutout > 0 : lists += [CUTOUT(cutout)]
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)])
elif name.startswith('ImageNet16'):
lists = [transforms.RandomCrop(input_size[1], padding=0), transforms.ToTensor(), transforms.Normalize(mean, std), RandChannel(input_size[0])]
if cutout > 0 : lists += [CUTOUT(cutout)]
@ -86,9 +97,12 @@ def get_datasets(name, root, input_size, cutout=-1):
train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True)
test_data = dset.CIFAR100(root, train=False, transform=test_transform , download=True)
assert len(train_data) == 50000 and len(test_data) == 10000
elif name == 'aircraft':
train_data = dset.ImageFolder(osp.join(root, 'train_sorted_images'), train_transform)
test_data = dset.ImageFolder(osp.join(root, 'test_sorted_images'), test_transform)
elif name.startswith('imagenet-1k'):
train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)
test_data = dset.ImageFolder(osp.join(root, 'test'), test_transform)
elif name == 'ImageNet16':
root = osp.join(root, 'ImageNet16')
train_data = ImageNet16(root, True , train_transform)