# Usage (--ratio = fraction of each class assigned to the 'train' split; 0.5 is only an example):
# python exps/prepare.py --name cifar10     --root $TORCH_HOME/cifar.python --save ./data/cifar10.split.pth     --ratio 0.5
# python exps/prepare.py --name cifar100    --root $TORCH_HOME/cifar.python --save ./data/cifar100.split.pth    --ratio 0.5
# python exps/prepare.py --name imagenet-1k --root $TORCH_HOME/ILSVRC2012   --save ./data/imagenet-1k.split.pth --ratio 0.5
"""Split a dataset's training set into per-class 'train' / 'valid' index lists.

The result (sorted index lists plus per-class sample counts) is written with
``torch.save`` to the path given by ``--save``.
"""
import sys, time, torch, random, argparse
from collections import Counter, defaultdict
import os.path as osp
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from copy import deepcopy
from pathlib import Path
import torchvision
import torchvision.datasets as dset

lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))

# Dataset names understood by main().
DATASETS = ('cifar10', 'cifar100', 'imagenet-1k')


def _ratio(text):
    """argparse ``type``: parse a float and require it to lie in [0, 1]."""
    value = float(text)
    # `not (0 <= v <= 1)` also rejects NaN.
    if not 0.0 <= value <= 1.0:
        raise argparse.ArgumentTypeError('ratio must be in [0, 1], got {:}'.format(value))
    return value


parser = argparse.ArgumentParser(description='Prepare splits for searching', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--name' , type=str, required=True, choices=DATASETS, help='The dataset name.')
parser.add_argument('--root' , type=str, required=True, help='The directory to the dataset.')
parser.add_argument('--save' , type=str, required=True, help='The save path.')
parser.add_argument('--ratio', type=_ratio, required=True, help='The fraction of each class put into the train split; the rest goes to valid.')
args = parser.parse_args()


def _split_indices(targets, ratio, seed=111):
    """Split sample indices class-by-class into sorted ``(train, valid)`` lists.

    For every class (visited in sorted label order), ``int(count * ratio)``
    indices are drawn without replacement into ``train``; the remaining
    indices of that class go to ``valid``.

    Args:
        targets: sequence of hashable, sortable class labels, one per sample.
        ratio: fraction of each class assigned to ``train``.
        seed: seed applied to the module-level ``random`` generator.

    Returns:
        ``(train, valid)``: two ascending lists of sample indices.
    """
    class2index = defaultdict(list)
    for index, cls in enumerate(targets):
        class2index[cls].append(index)
    # Seeding the global RNG and sampling classes in sorted order follows the
    # original sampling sequence, so regenerated splits match saved ones.
    random.seed(seed)
    train, valid = [], []
    for cls in sorted(class2index):
        members = class2index[cls]
        picked = random.sample(members, int(len(members) * ratio))
        chosen = set(picked)
        train.extend(picked)
        valid.extend(i for i in members if i not in chosen)
    train.sort()
    valid.sort()
    return train, valid


def main():
    """Build the train/valid split for ``args.name`` and save it to ``args.save``.

    Raises:
        FileExistsError: if ``args.save`` already exists (never overwritten).
        TypeError: if the dataset name is not recognised.
        ValueError: if the dataset exposes no known label attribute.
    """
    save_path = Path(args.save)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if save_path.exists():
        raise FileExistsError('{:} already exists'.format(save_path))
    save_path.parent.mkdir(parents=True, exist_ok=True)

    name = args.name
    print('torchvision version : {:}'.format(torchvision.__version__))
    if name == 'cifar10':
        dataset = dset.CIFAR10(args.root, train=True)
    elif name == 'cifar100':
        dataset = dset.CIFAR100(args.root, train=True)
    elif name == 'imagenet-1k':
        dataset = dset.ImageFolder(osp.join(args.root, 'train'))
    else:
        # Unreachable via the CLI (argparse `choices`), kept as a safeguard.
        raise TypeError("Unknown dataset : {:}".format(name))

    # Labels may live under different attributes depending on the dataset
    # class / torchvision version, so try each known location in turn.
    if hasattr(dataset, 'targets'):
        targets = dataset.targets
    elif hasattr(dataset, 'train_labels'):
        targets = dataset.train_labels
    elif hasattr(dataset, 'imgs'):
        targets = [x[1] for x in dataset.imgs]
    else:
        raise ValueError('invalid pattern')
    print('There are {:} samples in this dataset.'.format(len(targets)))

    train, valid = _split_indices(targets, args.ratio, seed=111)

    # Per-class sample counts for each split, stored alongside the indices.
    class2numT = dict(Counter(targets[i] for i in train))
    class2numV = dict(Counter(targets[i] for i in valid))
    torch.save({'train': train, 'valid': valid,
                'class2numTrain': class2numT, 'class2numValid': class2numV}, save_path)
    print('Saved {:} train / {:} valid indices ({:} classes) to {:}'.format(
        len(train), len(valid), len(set(class2numT) | set(class2numV)), save_path))
    print('-' * 80)


if __name__ == '__main__':
    main()