add aircraft in utilities
This commit is contained in:
parent
4df5615380
commit
24f15ad0fe
@ -13,7 +13,8 @@ Dataset2Class = {'cifar10': 10,
|
||||
'ImageNet16' : 1000,
|
||||
'ImageNet16-120': 120,
|
||||
'ImageNet16-150': 150,
|
||||
'ImageNet16-200': 200}
|
||||
'ImageNet16-200': 200,
|
||||
'aircraft': 100}
|
||||
|
||||
class RandChannel(object):
|
||||
# randomly pick channels from input
|
||||
@ -46,6 +47,10 @@ def get_datasets(name, root, input_size, cutout=-1):
|
||||
elif name.startswith('ImageNet16'):
|
||||
mean = [0.481098, 0.45749, 0.407882]
|
||||
std = [0.247922, 0.240235, 0.255255]
|
||||
elif name == 'aircraft':
|
||||
mean = [0.4785, 0.5100, 0.5338]
|
||||
std = [0.1845, 0.1830, 0.2060]
|
||||
|
||||
else:
|
||||
raise TypeError("Unknow dataset : {:}".format(name))
|
||||
|
||||
@ -55,6 +60,12 @@ def get_datasets(name, root, input_size, cutout=-1):
|
||||
if cutout > 0 : lists += [CUTOUT(cutout)]
|
||||
train_transform = transforms.Compose(lists)
|
||||
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
|
||||
elif name == 'aircraft':
|
||||
lists = [transforms.RandomCrop(input_size[1], padding=0), transforms.ToTensor(), transforms.Normalize(mean, std)]
|
||||
if cutout > 0 : lists += [CUTOUT(cutout)]
|
||||
train_transform = transforms.Compose(lists)
|
||||
test_transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)])
|
||||
|
||||
elif name.startswith('ImageNet16'):
|
||||
lists = [transforms.RandomCrop(input_size[1], padding=0), transforms.ToTensor(), transforms.Normalize(mean, std), RandChannel(input_size[0])]
|
||||
if cutout > 0 : lists += [CUTOUT(cutout)]
|
||||
@ -86,9 +97,12 @@ def get_datasets(name, root, input_size, cutout=-1):
|
||||
train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True)
|
||||
test_data = dset.CIFAR100(root, train=False, transform=test_transform , download=True)
|
||||
assert len(train_data) == 50000 and len(test_data) == 10000
|
||||
elif name == 'aircraft':
|
||||
train_data = dset.ImageFolder(osp.join(root, 'train_sorted_images'), train_transform)
|
||||
test_data = dset.ImageFolder(osp.join(root, 'test_sorted_images'), test_transform)
|
||||
elif name.startswith('imagenet-1k'):
|
||||
train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
|
||||
test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)
|
||||
test_data = dset.ImageFolder(osp.join(root, 'test'), test_transform)
|
||||
elif name == 'ImageNet16':
|
||||
root = osp.join(root, 'ImageNet16')
|
||||
train_data = ImageNet16(root, True , train_transform)
|
||||
|
Loading…
Reference in New Issue
Block a user