#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 #
#####################################################
# python exps/prepare.py --name cifar10 --root $TORCH_HOME/cifar.python --save ./data/cifar10.split.pth
# python exps/prepare.py --name cifar100 --root $TORCH_HOME/cifar.python --save ./data/cifar100.split.pth
# python exps/prepare.py --name imagenet-1k --root $TORCH_HOME/ILSVRC2012 --save ./data/imagenet-1k.split.pth
#####################################################
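# A minimal sketch of reading a produced split file back (assuming the cifar10
# command above); the keys match the dictionary saved by main() below:
#   split = torch.load("./data/cifar10.split.pth")
#   train_indices, valid_indices = split["train"], split["valid"]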
import sys, time, torch, random, argparse
from collections import defaultdict
import os.path as osp
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True
from copy import deepcopy
from pathlib import Path
import torchvision
import torchvision.datasets as dset

parser = argparse.ArgumentParser(
    description="Prepare splits for searching",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--name", type=str, help="The dataset name.")
parser.add_argument("--root", type=str, help="The directory to the dataset.")
parser.add_argument("--save", type=str, help="The save path.")
# --ratio has no default, so it must be supplied on the command line
# (e.g., 0.5 keeps half of every class in the train split).
parser.add_argument("--ratio", type=float, help="The ratio of samples per class assigned to the train split.")
args = parser.parse_args()


def main():
    save_path = Path(args.save)
    save_dir = save_path.parent
    name = args.name
    save_dir.mkdir(parents=True, exist_ok=True)
    assert not save_path.exists(), "{:} already exists".format(save_path)
    print("torchvision version : {:}".format(torchvision.__version__))

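    # Load the training portion of the requested dataset; only its labels are used below.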
    if name == "cifar10":
        dataset = dset.CIFAR10(args.root, train=True, download=True)
    elif name == "cifar100":
        dataset = dset.CIFAR100(args.root, train=True, download=True)
    elif name == "imagenet-1k":
        dataset = dset.ImageFolder(osp.join(args.root, "train"))
    else:
        raise TypeError("Unknown dataset : {:}".format(name))

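    # Collect the per-sample class labels; the attribute that holds them differs
    # across torchvision dataset classes and versions.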
    if hasattr(dataset, "targets"):
        targets = dataset.targets
    elif hasattr(dataset, "train_labels"):
        targets = dataset.train_labels
    elif hasattr(dataset, "imgs"):
        targets = [x[1] for x in dataset.imgs]
    else:
        raise ValueError("invalid pattern")
    print("There are {:} samples in this dataset.".format(len(targets)))

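    # Build a class-balanced split: for every class, randomly sample `ratio` of its
    # indices into the train split and put the remainder into the valid split.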
    class2index = defaultdict(list)
    train, valid = [], []
    random.seed(111)
    for index, cls in enumerate(targets):
        class2index[cls].append(index)
    classes = sorted(list(class2index.keys()))
    for cls in classes:
        xlist = class2index[cls]
        xtrain = random.sample(xlist, int(len(xlist) * args.ratio))
        xvalid = list(set(xlist) - set(xtrain))
        train += xtrain
        valid += xvalid
    train.sort()
    valid.sort()
    ## for statistics: count the samples per class in each split
    class2numT, class2numV = defaultdict(int), defaultdict(int)
    for index in train:
        class2numT[targets[index]] += 1
    for index in valid:
        class2numV[targets[index]] += 1
    class2numT, class2numV = dict(class2numT), dict(class2numV)
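    # Persist the index lists plus the per-class counts; downstream code can
    # restore the split with torch.load(save_path).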
    torch.save(
        {
            "train": train,
            "valid": valid,
            "class2numTrain": class2numT,
            "class2numValid": class2numV,
        },
        save_path,
    )
    print("-" * 80)


if __name__ == "__main__":
    main()