Update xmisc with yaml
This commit is contained in:
		
							
								
								
									
										2
									
								
								.github/workflows/basic_test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/basic_test.yml
									
									
									
									
										vendored
									
									
								
							| @@ -32,7 +32,7 @@ jobs: | |||||||
|           echo $PWD ; ls |           echo $PWD ; ls | ||||||
|           python -m black ./exps -l 88 --check --diff --verbose |           python -m black ./exps -l 88 --check --diff --verbose | ||||||
|           python -m black ./tests -l 88 --check --diff --verbose |           python -m black ./tests -l 88 --check --diff --verbose | ||||||
|           python -m black ./xautodl/xlayers -l 88 --check --diff --verbose |           python -m black ./xautodl/x* -l 88 --check --diff --verbose | ||||||
|           python -m black ./xautodl/spaces -l 88 --check --diff --verbose |           python -m black ./xautodl/spaces -l 88 --check --diff --verbose | ||||||
|           python -m black ./xautodl/trade_models -l 88 --check --diff --verbose |           python -m black ./xautodl/trade_models -l 88 --check --diff --verbose | ||||||
|           python -m black ./xautodl/procedures -l 88 --check --diff --verbose |           python -m black ./xautodl/procedures -l 88 --check --diff --verbose | ||||||
|   | |||||||
							
								
								
									
										7
									
								
								configs/data.yaml/cifar10.test
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								configs/data.yaml/cifar10.test
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | |||||||
|  | class_or_func: CIFAR10 | ||||||
|  | module_path: torchvision.datasets | ||||||
|  | args: [] | ||||||
|  | kwargs: | ||||||
|  |   train: False | ||||||
|  |   download: True | ||||||
|  |   transform: null | ||||||
							
								
								
									
										7
									
								
								configs/data.yaml/cifar10.train
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								configs/data.yaml/cifar10.train
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | |||||||
|  | class_or_func: CIFAR10 | ||||||
|  | module_path: torchvision.datasets | ||||||
|  | args: [] | ||||||
|  | kwargs: | ||||||
|  |   train: True | ||||||
|  |   download: True | ||||||
|  |   transform: null | ||||||
| @@ -1,35 +1,28 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 # | ||||||
|  | ##################################################### | ||||||
|  | # python exps/basic/xmain.py --save_dir outputs/x   # | ||||||
| ##################################################### | ##################################################### | ||||||
| import sys, time, torch, random, argparse | import sys, time, torch, random, argparse | ||||||
| from PIL import ImageFile |  | ||||||
|  |  | ||||||
| ImageFile.LOAD_TRUNCATED_IMAGES = True |  | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
|  |  | ||||||
| from xautodl.datasets import get_datasets | lib_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||||
| from xautodl.config_utils import load_config, obtain_basic_args as obtain_args | print("LIB-DIR: {:}".format(lib_dir)) | ||||||
| from xautodl.procedures import ( | if str(lib_dir) not in sys.path: | ||||||
|     prepare_seed, |     sys.path.insert(0, str(lib_dir)) | ||||||
|     prepare_logger, |  | ||||||
|     save_checkpoint, | from xautodl.xmisc import nested_call_by_yaml | ||||||
|     copy_checkpoint, |  | ||||||
| ) |  | ||||||
| from xautodl.procedures import get_optim_scheduler, get_procedures |  | ||||||
| from xautodl.models import obtain_model |  | ||||||
| from xautodl.xmodels import obtain_model as obtain_xmodel |  | ||||||
| from xautodl.nas_infer_model import obtain_nas_infer_model |  | ||||||
| from xautodl.utils import get_model_infos |  | ||||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(args): | def main(args): | ||||||
|     assert torch.cuda.is_available(), "CUDA is not available." |  | ||||||
|     torch.backends.cudnn.enabled = True |     train_data = nested_call_by_yaml(args.train_data_config, args.data_path) | ||||||
|     torch.backends.cudnn.benchmark = True |     valid_data = nested_call_by_yaml(args.valid_data_config, args.data_path) | ||||||
|     # torch.backends.cudnn.deterministic = True |  | ||||||
|     # torch.set_num_threads(args.workers) |     import pdb | ||||||
|  |  | ||||||
|  |     pdb.set_trace() | ||||||
|  |  | ||||||
|     prepare_seed(args.rand_seed) |     prepare_seed(args.rand_seed) | ||||||
|     logger = prepare_logger(args) |     logger = prepare_logger(args) | ||||||
| @@ -290,5 +283,44 @@ def main(args): | |||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     args = obtain_args() |     parser = argparse.ArgumentParser( | ||||||
|  |         description="Train a model with a loss function.", | ||||||
|  |         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--save_dir", type=str, help="Folder to save checkpoints and log." | ||||||
|  |     ) | ||||||
|  |     parser.add_argument("--resume", type=str, help="Resume path.") | ||||||
|  |     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||||
|  |     parser.add_argument("--model_config", type=str, help="The path to the model config") | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--optim_config", type=str, help="The path to the optimizer config" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--train_data_config", type=str, help="The dataset config path." | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--valid_data_config", type=str, help="The dataset config path." | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--data_path", type=str, help="The path to the dataset." | ||||||
|  |     ) | ||||||
|  |     parser.add_argument("--algorithm", type=str, help="The algorithm.") | ||||||
|  |     # Optimization options | ||||||
|  |     parser.add_argument("--batch_size", type=int, default=2, help="The batch size.") | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--workers", | ||||||
|  |         type=int, | ||||||
|  |         default=8, | ||||||
|  |         help="number of data loading workers (default: 8)", | ||||||
|  |     ) | ||||||
|  |     # Random Seed | ||||||
|  |     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||||
|  |  | ||||||
|  |     args = parser.parse_args() | ||||||
|  |     if args.rand_seed is None or args.rand_seed < 0: | ||||||
|  |         args.rand_seed = random.randint(1, 100000) | ||||||
|  |     if args.save_dir is None: | ||||||
|  |         raise ValueError("The save-path argument can not be None") | ||||||
|  |  | ||||||
|     main(args) |     main(args) | ||||||
|   | |||||||
							
								
								
									
										27
									
								
								scripts/experimental/train-vit.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								scripts/experimental/train-vit.sh
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,27 @@ | |||||||
|  | #!/bin/bash | ||||||
|  | # bash ./scripts/experimental/train-vit.sh cifar10 -1 | ||||||
|  | echo script name: $0 | ||||||
|  | echo $# arguments | ||||||
|  | if [ "$#" -ne 2 ] ;then | ||||||
|  |   echo "Input illegal number of parameters " $# | ||||||
|  |   echo "Need 2 parameters for dataset and random-seed" | ||||||
|  |   exit 1 | ||||||
|  | fi | ||||||
|  | if [ "$TORCH_HOME" = "" ]; then | ||||||
|  |   echo "Must set TORCH_HOME envoriment variable for data dir saving" | ||||||
|  |   exit 1 | ||||||
|  | else | ||||||
|  |   echo "TORCH_HOME : $TORCH_HOME" | ||||||
|  | fi | ||||||
|  |  | ||||||
|  | dataset=$1 | ||||||
|  | rseed=$2 | ||||||
|  |  | ||||||
|  | save_dir=./outputs/${dataset}/vit-experimental | ||||||
|  |  | ||||||
|  | python --version | ||||||
|  |  | ||||||
|  | python ./exps/basic/xmain.py --save_dir ${save_dir} --rand_seed ${rseed} \ | ||||||
|  | 	--train_data_config ./configs/data.yaml/${dataset}.train \ | ||||||
|  | 	--valid_data_config ./configs/data.yaml/${dataset}.test \ | ||||||
|  | 	--data_path $TORCH_HOME/cifar.python | ||||||
| @@ -5,34 +5,41 @@ import torch.nn as nn | |||||||
| class ImageNetHEAD(nn.Sequential): | class ImageNetHEAD(nn.Sequential): | ||||||
|     def __init__(self, C, stride=2): |     def __init__(self, C, stride=2): | ||||||
|         super(ImageNetHEAD, self).__init__() |         super(ImageNetHEAD, self).__init__() | ||||||
|     self.add_module('conv1', nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False)) |         self.add_module( | ||||||
|     self.add_module('bn1'  , nn.BatchNorm2d(C // 2)) |             "conv1", | ||||||
|     self.add_module('relu1', nn.ReLU(inplace=True)) |             nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False), | ||||||
|     self.add_module('conv2', nn.Conv2d(C // 2, C, kernel_size=3, stride=stride, padding=1, bias=False)) |         ) | ||||||
|     self.add_module('bn2'  , nn.BatchNorm2d(C)) |         self.add_module("bn1", nn.BatchNorm2d(C // 2)) | ||||||
|  |         self.add_module("relu1", nn.ReLU(inplace=True)) | ||||||
|  |         self.add_module( | ||||||
|  |             "conv2", | ||||||
|  |             nn.Conv2d(C // 2, C, kernel_size=3, stride=stride, padding=1, bias=False), | ||||||
|  |         ) | ||||||
|  |         self.add_module("bn2", nn.BatchNorm2d(C)) | ||||||
|  |  | ||||||
|  |  | ||||||
| class CifarHEAD(nn.Sequential): | class CifarHEAD(nn.Sequential): | ||||||
|     def __init__(self, C): |     def __init__(self, C): | ||||||
|         super(CifarHEAD, self).__init__() |         super(CifarHEAD, self).__init__() | ||||||
|     self.add_module('conv', nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False)) |         self.add_module("conv", nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False)) | ||||||
|     self.add_module('bn', nn.BatchNorm2d(C)) |         self.add_module("bn", nn.BatchNorm2d(C)) | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuxiliaryHeadCIFAR(nn.Module): | class AuxiliaryHeadCIFAR(nn.Module): | ||||||
|  |  | ||||||
|     def __init__(self, C, num_classes): |     def __init__(self, C, num_classes): | ||||||
|         """assuming input size 8x8""" |         """assuming input size 8x8""" | ||||||
|         super(AuxiliaryHeadCIFAR, self).__init__() |         super(AuxiliaryHeadCIFAR, self).__init__() | ||||||
|         self.features = nn.Sequential( |         self.features = nn.Sequential( | ||||||
|             nn.ReLU(inplace=True), |             nn.ReLU(inplace=True), | ||||||
|       nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2 |             nn.AvgPool2d( | ||||||
|  |                 5, stride=3, padding=0, count_include_pad=False | ||||||
|  |             ),  # image size = 2 x 2 | ||||||
|             nn.Conv2d(C, 128, 1, bias=False), |             nn.Conv2d(C, 128, 1, bias=False), | ||||||
|             nn.BatchNorm2d(128), |             nn.BatchNorm2d(128), | ||||||
|             nn.ReLU(inplace=True), |             nn.ReLU(inplace=True), | ||||||
|             nn.Conv2d(128, 768, 2, bias=False), |             nn.Conv2d(128, 768, 2, bias=False), | ||||||
|             nn.BatchNorm2d(768), |             nn.BatchNorm2d(768), | ||||||
|       nn.ReLU(inplace=True) |             nn.ReLU(inplace=True), | ||||||
|         ) |         ) | ||||||
|         self.classifier = nn.Linear(768, num_classes) |         self.classifier = nn.Linear(768, num_classes) | ||||||
|  |  | ||||||
| @@ -43,7 +50,6 @@ class AuxiliaryHeadCIFAR(nn.Module): | |||||||
|  |  | ||||||
|  |  | ||||||
| class AuxiliaryHeadImageNet(nn.Module): | class AuxiliaryHeadImageNet(nn.Module): | ||||||
|  |  | ||||||
|     def __init__(self, C, num_classes): |     def __init__(self, C, num_classes): | ||||||
|         """assuming input size 14x14""" |         """assuming input size 14x14""" | ||||||
|         super(AuxiliaryHeadImageNet, self).__init__() |         super(AuxiliaryHeadImageNet, self).__init__() | ||||||
| @@ -55,7 +61,7 @@ class AuxiliaryHeadImageNet(nn.Module): | |||||||
|             nn.ReLU(inplace=True), |             nn.ReLU(inplace=True), | ||||||
|             nn.Conv2d(128, 768, 2, bias=False), |             nn.Conv2d(128, 768, 2, bias=False), | ||||||
|             nn.BatchNorm2d(768), |             nn.BatchNorm2d(768), | ||||||
|       nn.ReLU(inplace=True) |             nn.ReLU(inplace=True), | ||||||
|         ) |         ) | ||||||
|         self.classifier = nn.Linear(768, num_classes) |         self.classifier = nn.Linear(768, num_classes) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,6 +1,8 @@ | |||||||
| ################################################## | ################################################## | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||||
| ################################################## | ###################################################################### | ||||||
|  | # This folder is deprecated, which is re-organized in "xalgorithms". # | ||||||
|  | ###################################################################### | ||||||
| from .starts import prepare_seed | from .starts import prepare_seed | ||||||
| from .starts import prepare_logger | from .starts import prepare_logger | ||||||
| from .starts import get_machine_info | from .starts import get_machine_info | ||||||
|   | |||||||
| @@ -47,7 +47,7 @@ class SuperSelfAttention(SuperModule): | |||||||
|         self.v_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias) |         self.v_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias) | ||||||
|  |  | ||||||
|         self.attn_drop = SuperDrop(attn_drop or 0.0, [-1, -1, -1, -1], recover=True) |         self.attn_drop = SuperDrop(attn_drop or 0.0, [-1, -1, -1, -1], recover=True) | ||||||
|         if proj_dim is None: |         if proj_dim is not None: | ||||||
|             self.proj = SuperLinear(input_dim, proj_dim) |             self.proj = SuperLinear(input_dim, proj_dim) | ||||||
|             self.proj_drop = SuperDropout(proj_drop or 0.0) |             self.proj_drop = SuperDropout(proj_drop or 0.0) | ||||||
|         else: |         else: | ||||||
|   | |||||||
							
								
								
									
										8
									
								
								xautodl/xmisc/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								xautodl/xmisc/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,8 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 # | ||||||
|  | ##################################################### | ||||||
|  | from .module_utils import call_by_dict | ||||||
|  | from .module_utils import call_by_yaml | ||||||
|  | from .module_utils import nested_call_by_dict | ||||||
|  | from .module_utils import nested_call_by_yaml | ||||||
|  | from .yaml_utils import load_yaml | ||||||
							
								
								
									
										81
									
								
								xautodl/xmisc/module_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										81
									
								
								xautodl/xmisc/module_utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,81 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 # | ||||||
|  | ##################################################### | ||||||
|  | from typing import Union, Dict, Text, Any | ||||||
|  | import importlib | ||||||
|  |  | ||||||
|  | from .yaml_utils import load_yaml | ||||||
|  |  | ||||||
|  | CLS_FUNC_KEY = "class_or_func" | ||||||
|  | KEYS = (CLS_FUNC_KEY, "module_path", "args", "kwargs") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def has_key_words(xdict): | ||||||
|  |     if not isinstance(xdict, dict): | ||||||
|  |         return False | ||||||
|  |     key_set = set(KEYS) | ||||||
|  |     cur_set = set(xdict.keys()) | ||||||
|  |     return key_set.intersection(cur_set) == key_set | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_module_by_module_path(module_path): | ||||||
|  |     """Load the module from the path.""" | ||||||
|  |  | ||||||
|  |     if module_path.endswith(".py"): | ||||||
|  |         module_spec = importlib.util.spec_from_file_location("", module_path) | ||||||
|  |         module = importlib.util.module_from_spec(module_spec) | ||||||
|  |         module_spec.loader.exec_module(module) | ||||||
|  |     else: | ||||||
|  |         module = importlib.import_module(module_path) | ||||||
|  |  | ||||||
|  |     return module | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def call_by_dict(config: Dict[Text, Any], *args, **kwargs) -> object: | ||||||
|  |     """ | ||||||
|  |     get initialized instance with config | ||||||
|  |     Parameters | ||||||
|  |     ---------- | ||||||
|  |     config : a dictionary, such as: | ||||||
|  |             { | ||||||
|  |                 'cls_or_func': 'ClassName', | ||||||
|  |                 'args': list, | ||||||
|  |                 'kwargs': dict, | ||||||
|  |                 'model_path': a string indicating the path, | ||||||
|  |             } | ||||||
|  |     Returns | ||||||
|  |     ------- | ||||||
|  |     object: | ||||||
|  |         An initialized object based on the config info | ||||||
|  |     """ | ||||||
|  |     module = get_module_by_module_path(config["module_path"]) | ||||||
|  |     cls_or_func = getattr(module, config[CLS_FUNC_KEY]) | ||||||
|  |     args = tuple(list(config["args"]) + list(args)) | ||||||
|  |     kwargs = {**config["kwargs"], **kwargs} | ||||||
|  |     return cls_or_func(*args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def call_by_yaml(path, *args, **kwargs) -> object: | ||||||
|  |     config = load_yaml(path) | ||||||
|  |     return call_by_config(config, *args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def nested_call_by_dict(config: Union[Dict[Text, Any], Any], *args, **kwargs) -> object: | ||||||
|  |     """Similar to `call_by_dict`, but differently, the args may contain another dict needs to be called.""" | ||||||
|  |     if not has_key_words(config): | ||||||
|  |         return config | ||||||
|  |     module = get_module_by_module_path(config["module_path"]) | ||||||
|  |     cls_or_func = getattr(module, config[CLS_FUNC_KEY]) | ||||||
|  |     args = tuple(list(config["args"]) + list(args)) | ||||||
|  |     kwargs = {**config["kwargs"], **kwargs} | ||||||
|  |     # check whether there are nested special dict | ||||||
|  |     new_args = [nested_call_by_dict(x) for x in args] | ||||||
|  |     new_kwargs = {} | ||||||
|  |     for key, x in kwargs.items(): | ||||||
|  |         new_kwargs[key] = nested_call_by_dict(x) | ||||||
|  |     return cls_or_func(*new_args, **new_kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def nested_call_by_yaml(path, *args, **kwargs) -> object: | ||||||
|  |     config = load_yaml(path) | ||||||
|  |     return nested_call_by_dict(config, *args, **kwargs) | ||||||
							
								
								
									
										13
									
								
								xautodl/xmisc/yaml_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								xautodl/xmisc/yaml_utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,13 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 # | ||||||
|  | ##################################################### | ||||||
|  | import os | ||||||
|  | import yaml | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def load_yaml(path): | ||||||
|  |     if not os.path.isfile(path): | ||||||
|  |         raise ValueError("{:} is not a file.".format(path)) | ||||||
|  |     with open(path, "r") as stream: | ||||||
|  |         data = yaml.safe_load(stream) | ||||||
|  |     return data | ||||||
		Reference in New Issue
	
	Block a user