Fix small bugs
This commit is contained in:
		| @@ -60,14 +60,14 @@ def get_datasets(name, root, cutout): | ||||
|   else: raise TypeError("Unknow dataset : {:}".format(name)) | ||||
|  | ||||
|   if name == 'cifar10': | ||||
|     train_data = dset.CIFAR10(root, train=True , transform=train_transform, download=True) | ||||
|     test_data  = dset.CIFAR10(root, train=False, transform=test_transform , download=True) | ||||
|     train_data = dset.CIFAR10 (root, train=True , transform=train_transform, download=True) | ||||
|     test_data  = dset.CIFAR10 (root, train=False, transform=test_transform , download=True) | ||||
|   elif name == 'cifar100': | ||||
|     train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True) | ||||
|     test_data  = dset.CIFAR100(root, train=False, transform=test_transform , download=True) | ||||
|   elif name == 'imagenet-1k' or name == 'imagenet-100': | ||||
|     train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform) | ||||
|     test_data  = dset.ImageFolder(osp.join(root, 'val'), train_transform) | ||||
|     test_data  = dset.ImageFolder(osp.join(root, 'val'),   test_transform) | ||||
|   else: raise TypeError("Unknow dataset : {:}".format(name)) | ||||
|    | ||||
|   class_num = Dataset2Class[name] | ||||
|   | ||||
		Reference in New Issue
	
	Block a user