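# Example invocations. NOTE: --ratio (the per-class fraction of samples assigned to
# the train split) has no default and must also be passed on the command line.
# python exps/prepare.py --name cifar10     --root $TORCH_HOME/cifar.python --save ./data/cifar10.split.pth
# python exps/prepare.py --name cifar100    --root $TORCH_HOME/cifar.python --save ./data/cifar100.split.pth
# python exps/prepare.py --name imagenet-1k --root $TORCH_HOME/ILSVRC2012   --save ./data/imagenet-1k.split.pth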
import sys, time, torch, random, argparse
from collections import defaultdict
import os.path as osp
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True
from copy import deepcopy
from pathlib import Path
import torchvision
import torchvision.datasets as dset

# Resolve the sibling "lib" directory (one level above this file's directory)
# and put it at the front of the import search path unless it is already listed.
lib_dir = Path(__file__).parent.joinpath("..", "lib").resolve()
if sys.path.count(str(lib_dir)) == 0:
    sys.path.insert(0, str(lib_dir))
# Command-line interface. Arguments are parsed at module import time and the
# result is stored in the module-level `args`, which `main()` reads directly.
# Every option is required: `main()` uses all four unconditionally, so a missing
# one would otherwise only surface later as an unrelated TypeError.
parser = argparse.ArgumentParser(
    description="Prepare splits for searching",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    "--name",
    type=str,
    required=True,
    help="The dataset name (cifar10, cifar100 or imagenet-1k).",
)
parser.add_argument(
    "--root", type=str, required=True, help="The directory to the dataset."
)
parser.add_argument("--save", type=str, required=True, help="The save path.")
parser.add_argument(
    "--ratio",
    type=float,
    required=True,
    help="The fraction of samples in each class assigned to the train split; "
    "the remaining samples form the valid split.",
)
args = parser.parse_args()


def _build_dataset(name, root):
    """Return the torchvision training set selected by `name`, rooted at `root`.

    Raises:
        TypeError: if `name` is not one of the supported dataset names.
    """
    if name == "cifar10":
        return dset.CIFAR10(root, train=True)
    if name == "cifar100":
        return dset.CIFAR100(root, train=True)
    if name == "imagenet-1k":
        return dset.ImageFolder(osp.join(root, "train"))
    raise TypeError("Unknown dataset : {:}".format(name))


def _extract_targets(dataset):
    """Return the per-sample class labels of `dataset`.

    Tries the attributes `targets`, `train_labels` and `imgs` in that order;
    for `imgs` the label is the second element of each entry.

    Raises:
        ValueError: if the dataset exposes none of those attributes.
    """
    if hasattr(dataset, "targets"):
        return dataset.targets
    if hasattr(dataset, "train_labels"):
        return dataset.train_labels
    if hasattr(dataset, "imgs"):
        return [x[1] for x in dataset.imgs]
    raise ValueError("invalid pattern")


def _stratified_split(targets, ratio, seed=111):
    """Split sample indices into sorted train/valid lists, class by class.

    For every class (visited in sorted order) `int(count * ratio)` indices are
    drawn at random for train; the class's remaining indices go to valid.
    """
    class2index = defaultdict(list)
    for index, cls in enumerate(targets):
        class2index[cls].append(index)

    # Seed the global RNG and sample the classes in sorted order so that the
    # same inputs always produce the same split.
    random.seed(seed)
    train, valid = [], []
    for cls in sorted(class2index):
        xlist = class2index[cls]
        xtrain = random.sample(xlist, int(len(xlist) * ratio))
        picked = set(xtrain)
        train.extend(xtrain)
        valid.extend(index for index in xlist if index not in picked)
    train.sort()
    valid.sort()
    return train, valid


def _count_per_class(indexes, targets):
    """Return a plain dict mapping class label -> number of `indexes` in it."""
    counts = defaultdict(int)
    for index in indexes:
        counts[targets[index]] += 1
    return dict(counts)


def main():
    """Build a per-class random train/valid index split and save it.

    Reads `name`, `root`, `save` and `ratio` from the module-level `args`.
    The saved object is a dict with the keys "train" and "valid" (sorted lists
    of sample indices) and "class2numTrain" / "class2numValid" (per-class
    sample counts of each split), written with `torch.save` to `args.save`.

    Raises:
        ValueError: if `args.ratio` is missing or outside [0, 1], or if the
            dataset exposes no known label attribute.
        FileExistsError: if the save path already exists.
        TypeError: if `args.name` is not a supported dataset name.
    """
    save_path = Path(args.save)
    name = args.name
    ratio = args.ratio

    # Validate the cheap things first, before the dataset is loaded.
    if ratio is None or not 0.0 <= ratio <= 1.0:
        raise ValueError("--ratio must be a number in [0, 1], got {:}".format(ratio))
    # An explicit raise (not `assert`) so the overwrite guard survives `python -O`.
    if save_path.exists():
        raise FileExistsError("{:} already exists".format(save_path))
    save_path.parent.mkdir(parents=True, exist_ok=True)
    print("torchvision version : {:}".format(torchvision.__version__))

    dataset = _build_dataset(name, args.root)
    targets = _extract_targets(dataset)
    print("There are {:} samples in this dataset.".format(len(targets)))

    train, valid = _stratified_split(targets, ratio)

    torch.save(
        {
            "train": train,
            "valid": valid,
            "class2numTrain": _count_per_class(train, targets),
            "class2numValid": _count_per_class(valid, targets),
        },
        save_path,
    )
    print("-" * 80)


# Run the split preparation only when this file is executed as a script.
if __name__ == "__main__":
    main()