Refine lib -> xautodl
This commit is contained in:
		| @@ -7,8 +7,8 @@ In this paper, we proposed a differentiable searching strategy for transformable | |||||||
| You could see the highlight of our Transformable Architecture Search (TAS) at our [project page](https://xuanyidong.com/assets/projects/NeurIPS-2019-TAS.html). | You could see the highlight of our Transformable Architecture Search (TAS) at our [project page](https://xuanyidong.com/assets/projects/NeurIPS-2019-TAS.html). | ||||||
|  |  | ||||||
| <p float="left"> | <p float="left"> | ||||||
| <img src="https://d-x-y.github.com/resources/paper-icon/NIPS-2019-TAS.png" width="680px"/> | <img src="http://xuanyidong.com/resources/paper-icon/NIPS-2019-TAS.png" width="680px"/> | ||||||
| <img src="https://d-x-y.github.com/resources/videos/NeurIPS-2019-TAS/TAS-arch.gif?raw=true" width="180px"/> | <img src="http://xuanyidong.com/resources/videos/NeurIPS-2019-TAS/TAS-arch.gif?raw=true" width="180px"/> | ||||||
| </p> | </p> | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -24,7 +24,7 @@ We provide some logs at [Google Drive](https://drive.google.com/open?id=1_qUY4DT | |||||||
|  |  | ||||||
| ## Usage | ## Usage | ||||||
|  |  | ||||||
| Use `bash ./scripts/prepare.sh` to prepare data splits for `CIFAR-10`, `CIFARR-100`, and `ILSVRC2012`. | Use `bash ./scripts/TAS/prepare.sh` to prepare data splits for `CIFAR-10`, `CIFARR-100`, and `ILSVRC2012`. | ||||||
| If you do not have `ILSVRC2012` data, please comment L12 in `./scripts/prepare.sh`. | If you do not have `ILSVRC2012` data, please comment L12 in `./scripts/prepare.sh`. | ||||||
|  |  | ||||||
| args: `cifar10` indicates the dataset name, `ResNet56` indicates the basemodel name, `CIFARX` indicates the searching hyper-parameters, `0.47/0.57` indicates the expected FLOP ratio, `-1` indicates the random seed. | args: `cifar10` indicates the dataset name, `ResNet56` indicates the basemodel name, `CIFARX` indicates the searching hyper-parameters, `0.47/0.57` indicates the expected FLOP ratio, `-1` indicates the random seed. | ||||||
|   | |||||||
| @@ -27,8 +27,8 @@ from xautodl.datasets.synthetic_core import get_synthetic_env, EnvSampler | |||||||
| from xautodl.models.xcore import get_model | from xautodl.models.xcore import get_model | ||||||
| from xautodl.xlayers import super_core, trunc_normal_ | from xautodl.xlayers import super_core, trunc_normal_ | ||||||
|  |  | ||||||
| from xautodl.lfna_utils import lfna_setup, train_model, TimeData | from lfna_utils import lfna_setup, train_model, TimeData | ||||||
| from xautodl.lfna_meta_model import LFNA_Meta | from lfna_meta_model import LFNA_Meta | ||||||
|  |  | ||||||
|  |  | ||||||
| def epoch_train(loader, meta_model, base_model, optimizer, criterion, device, logger): | def epoch_train(loader, meta_model, base_model, optimizer, criterion, device, logger): | ||||||
|   | |||||||
| @@ -4,8 +4,8 @@ | |||||||
| import copy | import copy | ||||||
| import torch | import torch | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
| from procedures import prepare_seed, prepare_logger | from xautodl.procedures import prepare_seed, prepare_logger | ||||||
| from datasets.synthetic_core import get_synthetic_env | from xautodl.datasets.synthetic_core import get_synthetic_env | ||||||
|  |  | ||||||
|  |  | ||||||
| def lfna_setup(args): | def lfna_setup(args): | ||||||
|   | |||||||
| @@ -665,7 +665,7 @@ if __name__ == "__main__": | |||||||
|                     len(args.datasets), len(args.xpaths), len(args.splits) |                     len(args.datasets), len(args.xpaths), len(args.splits) | ||||||
|                 ) |                 ) | ||||||
|             ) |             ) | ||||||
|         if args.workers <= 0: |         if args.workers < 0: | ||||||
|             raise ValueError("invalid number of workers : {:}".format(args.workers)) |             raise ValueError("invalid number of workers : {:}".format(args.workers)) | ||||||
|  |  | ||||||
|         target_indexes = filter_indexes( |         target_indexes = filter_indexes( | ||||||
| @@ -675,7 +675,7 @@ if __name__ == "__main__": | |||||||
|         assert torch.cuda.is_available(), "CUDA is not available." |         assert torch.cuda.is_available(), "CUDA is not available." | ||||||
|         torch.backends.cudnn.enabled = True |         torch.backends.cudnn.enabled = True | ||||||
|         torch.backends.cudnn.deterministic = True |         torch.backends.cudnn.deterministic = True | ||||||
|         torch.set_num_threads(args.workers) |         torch.set_num_threads(args.workers if args.workers > 0 else 1) | ||||||
|  |  | ||||||
|         main( |         main( | ||||||
|             save_dir, |             save_dir, | ||||||
|   | |||||||
| @@ -1,6 +1,10 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 # | ||||||
|  | ##################################################### | ||||||
| # python exps/prepare.py --name cifar10     --root $TORCH_HOME/cifar.python --save ./data/cifar10.split.pth | # python exps/prepare.py --name cifar10     --root $TORCH_HOME/cifar.python --save ./data/cifar10.split.pth | ||||||
| # python exps/prepare.py --name cifar100    --root $TORCH_HOME/cifar.python --save ./data/cifar100.split.pth | # python exps/prepare.py --name cifar100    --root $TORCH_HOME/cifar.python --save ./data/cifar100.split.pth | ||||||
| # python exps/prepare.py --name imagenet-1k --root $TORCH_HOME/ILSVRC2012   --save ./data/imagenet-1k.split.pth | # python exps/prepare.py --name imagenet-1k --root $TORCH_HOME/ILSVRC2012   --save ./data/imagenet-1k.split.pth | ||||||
|  | ##################################################### | ||||||
| import sys, time, torch, random, argparse | import sys, time, torch, random, argparse | ||||||
| from collections import defaultdict | from collections import defaultdict | ||||||
| import os.path as osp | import os.path as osp | ||||||
| @@ -12,9 +16,6 @@ from pathlib import Path | |||||||
| import torchvision | import torchvision | ||||||
| import torchvision.datasets as dset | import torchvision.datasets as dset | ||||||
| 
 | 
 | ||||||
| lib_dir = (Path(__file__).parent / ".." / "lib").resolve() |  | ||||||
| if str(lib_dir) not in sys.path: |  | ||||||
|     sys.path.insert(0, str(lib_dir)) |  | ||||||
| parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||||
|     description="Prepare splits for searching", |     description="Prepare splits for searching", | ||||||
|     formatter_class=argparse.ArgumentDefaultsHelpFormatter, |     formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||||
| @@ -35,9 +36,9 @@ def main(): | |||||||
|     print("torchvision version : {:}".format(torchvision.__version__)) |     print("torchvision version : {:}".format(torchvision.__version__)) | ||||||
| 
 | 
 | ||||||
|     if name == "cifar10": |     if name == "cifar10": | ||||||
|         dataset = dset.CIFAR10(args.root, train=True) |         dataset = dset.CIFAR10(args.root, train=True, download=True) | ||||||
|     elif name == "cifar100": |     elif name == "cifar100": | ||||||
|         dataset = dset.CIFAR100(args.root, train=True) |         dataset = dset.CIFAR100(args.root, train=True, download=True) | ||||||
|     elif name == "imagenet-1k": |     elif name == "imagenet-1k": | ||||||
|         dataset = dset.ImageFolder(osp.join(args.root, "train")) |         dataset = dset.ImageFolder(osp.join(args.root, "train")) | ||||||
|     else: |     else: | ||||||
							
								
								
									
										13
									
								
								scripts/TAS/prepare.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								scripts/TAS/prepare.sh
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,13 @@ | |||||||
|  | #!/bin/bash | ||||||
|  | # bash ./scripts/TAS/prepare.sh | ||||||
|  | #datasets='cifar10 cifar100 imagenet-1k' | ||||||
|  | #ratios='0.5 0.8 0.9' | ||||||
|  | ratios='0.5' | ||||||
|  | save_dir=./.latent-data/splits | ||||||
|  |  | ||||||
|  | for ratio in ${ratios} | ||||||
|  | do | ||||||
|  |   python ./exps/TAS/prepare.py --name cifar10  --root $TORCH_HOME/cifar.python  --save ${save_dir}/cifar10-${ratio}.pth --ratio ${ratio} | ||||||
|  |   python ./exps/TAS/prepare.py --name cifar100 --root $TORCH_HOME/cifar.python  --save ${save_dir}/cifar100-${ratio}.pth --ratio ${ratio} | ||||||
|  |   python ./exps/TAS/prepare.py --name imagenet-1k --root $TORCH_HOME/ILSVRC2012 --save ${save_dir}/imagenet-1k-${ratio}.pth --ratio ${ratio} | ||||||
|  | done | ||||||
| @@ -1,13 +0,0 @@ | |||||||
| #!/bin/bash |  | ||||||
| # bash ./scripts/prepare.sh |  | ||||||
| #datasets='cifar10 cifar100 imagenet-1k' |  | ||||||
| #ratios='0.5 0.8 0.9' |  | ||||||
| ratios='0.5' |  | ||||||
| save_dir=./.latent-data/splits |  | ||||||
|  |  | ||||||
| for ratio in ${ratios} |  | ||||||
| do |  | ||||||
|   python ./exps/prepare.py --name cifar10  --root $TORCH_HOME/cifar.python  --save ${save_dir}/cifar10-${ratio}.pth --ratio ${ratio} |  | ||||||
|   python ./exps/prepare.py --name cifar100 --root $TORCH_HOME/cifar.python  --save ${save_dir}/cifar100-${ratio}.pth --ratio ${ratio} |  | ||||||
|   python ./exps/prepare.py --name imagenet-1k --root $TORCH_HOME/ILSVRC2012 --save ${save_dir}/imagenet-1k-${ratio}.pth --ratio ${ratio} |  | ||||||
| done |  | ||||||
| @@ -6,7 +6,7 @@ import torch | |||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
|  |  | ||||||
| from models.cell_operations import OPS | from xautodl.models.cell_operations import OPS | ||||||
|  |  | ||||||
|  |  | ||||||
| # Cell for NAS-Bench-201 | # Cell for NAS-Bench-201 | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ | |||||||
| import torch | import torch | ||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
|  |  | ||||||
| from .cells import NASNetInferCell as InferCell, AuxiliaryHeadCIFAR | from .cells import NASNetInferCell as InferCell, AuxiliaryHeadCIFAR | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -9,11 +9,11 @@ import torch | |||||||
| __all__ = ["get_model"] | __all__ = ["get_model"] | ||||||
|  |  | ||||||
|  |  | ||||||
| from xlayers.super_core import SuperSequential | from xautodl.xlayers.super_core import SuperSequential | ||||||
| from xlayers.super_core import SuperLinear | from xautodl.xlayers.super_core import SuperLinear | ||||||
| from xlayers.super_core import SuperDropout | from xautodl.xlayers.super_core import SuperDropout | ||||||
| from xlayers.super_core import super_name2norm | from xautodl.xlayers.super_core import super_name2norm | ||||||
| from xlayers.super_core import super_name2activation | from xautodl.xlayers.super_core import super_name2activation | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_model(config: Dict[Text, Any], **kwargs): | def get_model(config: Dict[Text, Any], **kwargs): | ||||||
|   | |||||||
| @@ -7,8 +7,7 @@ import os, sys, time, torch | |||||||
| from typing import Optional, Text, Callable | from typing import Optional, Text, Callable | ||||||
|  |  | ||||||
| # modules in AutoDL | # modules in AutoDL | ||||||
| from log_utils import AverageMeter | from xautodl.log_utils import AverageMeter, time_string | ||||||
| from log_utils import time_string |  | ||||||
| from .eval_funcs import obtain_accuracy | from .eval_funcs import obtain_accuracy | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -4,8 +4,7 @@ | |||||||
| import os, sys, time, torch | import os, sys, time, torch | ||||||
|  |  | ||||||
| # modules in AutoDL | # modules in AutoDL | ||||||
| from log_utils import AverageMeter | from xautodl.log_utils import AverageMeter, time_string | ||||||
| from log_utils import time_string |  | ||||||
| from .eval_funcs import obtain_accuracy | from .eval_funcs import obtain_accuracy | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -15,6 +15,6 @@ def obtain_accuracy(output, target, topk=(1,)): | |||||||
|  |  | ||||||
|     res = [] |     res = [] | ||||||
|     for k in topk: |     for k in topk: | ||||||
|         correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) |         correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) | ||||||
|         res.append(correct_k.mul_(100.0 / batch_size)) |         res.append(correct_k.mul_(100.0 / batch_size)) | ||||||
|     return res |     return res | ||||||
|   | |||||||
| @@ -4,7 +4,7 @@ | |||||||
| import os, time, copy, torch, pathlib | import os, time, copy, torch, pathlib | ||||||
|  |  | ||||||
| # modules in AutoDL | # modules in AutoDL | ||||||
| import xautodl.datasets | from xautodl import datasets | ||||||
| from xautodl.config_utils import load_config | from xautodl.config_utils import load_config | ||||||
| from xautodl.procedures import prepare_seed, get_optim_scheduler | from xautodl.procedures import prepare_seed, get_optim_scheduler | ||||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||||
|   | |||||||
| @@ -8,7 +8,6 @@ import pprint | |||||||
| import logging | import logging | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
|  |  | ||||||
| from log_utils import pickle_load |  | ||||||
| import qlib | import qlib | ||||||
| from qlib.utils import init_instance_by_config | from qlib.utils import init_instance_by_config | ||||||
| from qlib.workflow import R | from qlib.workflow import R | ||||||
|   | |||||||
| @@ -2,8 +2,9 @@ | |||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||||
| ################################################## | ################################################## | ||||||
| import os, sys, time, torch | import os, sys, time, torch | ||||||
| from log_utils import AverageMeter, time_string |  | ||||||
| from models import change_key | from xautodl.log_utils import AverageMeter, time_string | ||||||
|  | from xautodl.models import change_key | ||||||
|  |  | ||||||
| from .eval_funcs import obtain_accuracy | from .eval_funcs import obtain_accuracy | ||||||
|  |  | ||||||
|   | |||||||
| @@ -4,8 +4,8 @@ | |||||||
| import os, sys, time, torch | import os, sys, time, torch | ||||||
|  |  | ||||||
| # modules in AutoDL | # modules in AutoDL | ||||||
| from log_utils import AverageMeter, time_string | from xautodl.log_utils import AverageMeter, time_string | ||||||
| from models import change_key | from xautodl.models import change_key | ||||||
| from .eval_funcs import obtain_accuracy | from .eval_funcs import obtain_accuracy | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -5,7 +5,7 @@ import os, sys, time, torch | |||||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||||
|  |  | ||||||
| # modules in AutoDL | # modules in AutoDL | ||||||
| from log_utils import AverageMeter, time_string | from xautodl.log_utils import AverageMeter, time_string | ||||||
| from .eval_funcs import obtain_accuracy | from .eval_funcs import obtain_accuracy | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -16,7 +16,7 @@ def prepare_seed(rand_seed): | |||||||
|  |  | ||||||
| def prepare_logger(xargs): | def prepare_logger(xargs): | ||||||
|     args = copy.deepcopy(xargs) |     args = copy.deepcopy(xargs) | ||||||
|     from log_utils import Logger |     from xautodl.log_utils import Logger | ||||||
|  |  | ||||||
|     logger = Logger(args.save_dir, args.rand_seed) |     logger = Logger(args.save_dir, args.rand_seed) | ||||||
|     logger.log("Main Function with logger : {:}".format(logger)) |     logger.log("Main Function with logger : {:}".format(logger)) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user