Refine lib -> xautodl
This commit is contained in:
		
							
								
								
									
										91
									
								
								exps/TAS/prepare.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										91
									
								
								exps/TAS/prepare.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,91 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 # | ||||
| ##################################################### | ||||
| # python exps/prepare.py --name cifar10     --root $TORCH_HOME/cifar.python --save ./data/cifar10.split.pth | ||||
| # python exps/prepare.py --name cifar100    --root $TORCH_HOME/cifar.python --save ./data/cifar100.split.pth | ||||
| # python exps/prepare.py --name imagenet-1k --root $TORCH_HOME/ILSVRC2012   --save ./data/imagenet-1k.split.pth | ||||
| ##################################################### | ||||
| import sys, time, torch, random, argparse | ||||
| from collections import defaultdict | ||||
| import os.path as osp | ||||
| from PIL import ImageFile | ||||
|  | ||||
| ImageFile.LOAD_TRUNCATED_IMAGES = True | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
| import torchvision | ||||
| import torchvision.datasets as dset | ||||
|  | ||||
| parser = argparse.ArgumentParser( | ||||
|     description="Prepare splits for searching", | ||||
|     formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
| ) | ||||
| parser.add_argument("--name", type=str, help="The dataset name.") | ||||
| parser.add_argument("--root", type=str, help="The directory to the dataset.") | ||||
| parser.add_argument("--save", type=str, help="The save path.") | ||||
| parser.add_argument("--ratio", type=float, help="The save path.") | ||||
| args = parser.parse_args() | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|     save_path = Path(args.save) | ||||
|     save_dir = save_path.parent | ||||
|     name = args.name | ||||
|     save_dir.mkdir(parents=True, exist_ok=True) | ||||
|     assert not save_path.exists(), "{:} already exists".format(save_path) | ||||
|     print("torchvision version : {:}".format(torchvision.__version__)) | ||||
|  | ||||
|     if name == "cifar10": | ||||
|         dataset = dset.CIFAR10(args.root, train=True, download=True) | ||||
|     elif name == "cifar100": | ||||
|         dataset = dset.CIFAR100(args.root, train=True, download=True) | ||||
|     elif name == "imagenet-1k": | ||||
|         dataset = dset.ImageFolder(osp.join(args.root, "train")) | ||||
|     else: | ||||
|         raise TypeError("Unknow dataset : {:}".format(name)) | ||||
|  | ||||
|     if hasattr(dataset, "targets"): | ||||
|         targets = dataset.targets | ||||
|     elif hasattr(dataset, "train_labels"): | ||||
|         targets = dataset.train_labels | ||||
|     elif hasattr(dataset, "imgs"): | ||||
|         targets = [x[1] for x in dataset.imgs] | ||||
|     else: | ||||
|         raise ValueError("invalid pattern") | ||||
|     print("There are {:} samples in this dataset.".format(len(targets))) | ||||
|  | ||||
|     class2index = defaultdict(list) | ||||
|     train, valid = [], [] | ||||
|     random.seed(111) | ||||
|     for index, cls in enumerate(targets): | ||||
|         class2index[cls].append(index) | ||||
|     classes = sorted(list(class2index.keys())) | ||||
|     for cls in classes: | ||||
|         xlist = class2index[cls] | ||||
|         xtrain = random.sample(xlist, int(len(xlist) * args.ratio)) | ||||
|         xvalid = list(set(xlist) - set(xtrain)) | ||||
|         train += xtrain | ||||
|         valid += xvalid | ||||
|     train.sort() | ||||
|     valid.sort() | ||||
|     ## for statistics | ||||
|     class2numT, class2numV = defaultdict(int), defaultdict(int) | ||||
|     for index in train: | ||||
|         class2numT[targets[index]] += 1 | ||||
|     for index in valid: | ||||
|         class2numV[targets[index]] += 1 | ||||
|     class2numT, class2numV = dict(class2numT), dict(class2numV) | ||||
|     torch.save( | ||||
|         { | ||||
|             "train": train, | ||||
|             "valid": valid, | ||||
|             "class2numTrain": class2numT, | ||||
|             "class2numValid": class2numV, | ||||
|         }, | ||||
|         save_path, | ||||
|     ) | ||||
|     print("-" * 80) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     main() | ||||
		Reference in New Issue
	
	Block a user