Update yaml configs
configs/data.yaml/cifar10.test (deleted)
| @@ -1,7 +0,0 @@ | ||||
| class_or_func: CIFAR10 | ||||
| module_path: torchvision.datasets | ||||
| args: [] | ||||
| kwargs: | ||||
|   train: False | ||||
|   download: True | ||||
|   transform: null | ||||
configs/data.yaml/cifar10.train (deleted)
| @@ -1,7 +0,0 @@ | ||||
| class_or_func: CIFAR10 | ||||
| module_path: torchvision.datasets | ||||
| args: [] | ||||
| kwargs: | ||||
|   train: True | ||||
|   download: True | ||||
|   transform: null | ||||
							
								
								
									
configs/yaml.data/cifar10.test (new file)
									
								
							| @@ -0,0 +1,22 @@ | ||||
| class_or_func: CIFAR10 | ||||
| module_path: torchvision.datasets | ||||
| args: [] | ||||
| kwargs: | ||||
|   train: False | ||||
|   download: True | ||||
|   transform: | ||||
|     class_or_func: Compose | ||||
|     module_path: torchvision.transforms | ||||
|     args: | ||||
|       - | ||||
|         - class_or_func: ToTensor | ||||
|           module_path: torchvision.transforms | ||||
|           args: [] | ||||
|           kwargs: {} | ||||
|         - class_or_func: Normalize | ||||
|           module_path: torchvision.transforms | ||||
|           args: [] | ||||
|           kwargs: | ||||
|             mean: [0.491, 0.482, 0.447] | ||||
|             std: [0.247, 0.244, 0.262] | ||||
|     kwargs: {} | ||||
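
For reference, a minimal sketch of what this nested config unfolds to in plain Python, assuming nested_call_by_yaml resolves the inner class_or_func dicts first and forwards the CLI --data_path as the positional root argument:

import torchvision.datasets as datasets
import torchvision.transforms as transforms

data_path = "./data"  # stands in for the CLI --data_path
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.244, 0.262]),
    ]
)
test_data = datasets.CIFAR10(data_path, train=False, download=True, transform=transform)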
							
								
								
									
configs/yaml.data/cifar10.train (new file)
									
								
							| @@ -0,0 +1,30 @@ | ||||
| class_or_func: CIFAR10 | ||||
| module_path: torchvision.datasets | ||||
| args: [] | ||||
| kwargs: | ||||
|   train: True | ||||
|   download: True | ||||
|   transform: | ||||
|     class_or_func: Compose | ||||
|     module_path: torchvision.transforms | ||||
|     args: | ||||
|       - | ||||
|         - class_or_func: RandomHorizontalFlip | ||||
|           module_path: torchvision.transforms | ||||
|           args: [] | ||||
|           kwargs: {} | ||||
|         - class_or_func: RandomCrop | ||||
|           module_path: torchvision.transforms | ||||
|           args: [32] | ||||
|           kwargs: {padding: 4} | ||||
|         - class_or_func: ToTensor | ||||
|           module_path: torchvision.transforms | ||||
|           args: [] | ||||
|           kwargs: {} | ||||
|         - class_or_func: Normalize | ||||
|           module_path: torchvision.transforms | ||||
|           args: [] | ||||
|           kwargs: | ||||
|             mean: [0.491, 0.482, 0.447] | ||||
|             std: [0.247, 0.244, 0.262] | ||||
|     kwargs: {} | ||||
							
								
								
									
configs/yaml.loss/cross-entropy (new file)
									
								
							| @@ -0,0 +1,4 @@ | ||||
| class_or_func: CrossEntropyLoss | ||||
| module_path: torch.nn | ||||
| args: [] | ||||
| kwargs: {} | ||||
							
								
								
									
configs/yaml.model/vit-cifar10.s0 (new file)
									
								
							| @@ -0,0 +1,4 @@ | ||||
| class_or_func: get_transformer | ||||
| module_path: xautodl.xmodels.transformers | ||||
| args: [vit-cifar10-p4-d4-h4-c32] | ||||
| kwargs: {} | ||||
							
								
								
									
configs/yaml.opt/vit.cifar (new file)
									
								
							| @@ -0,0 +1,7 @@ | ||||
| class_or_func: Adam | ||||
| module_path: torch.optim | ||||
| args: [] | ||||
| kwargs: | ||||
|   betas: [0.9, 0.999] | ||||
|   weight_decay: 0.1 | ||||
|   amsgrad: False | ||||
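
Note that lr is deliberately absent from this YAML; xmain.py supplies it, together with the model parameters, at call time. Also note that weight_decay appears both here (0.1) and on the command line below (0.3), so which value wins depends on how nested_call_by_yaml merges keyword arguments. A minimal sketch of the call, mirroring xmain.py:

import torch
from xautodl import xmisc

model = torch.nn.Linear(8, 2)  # stand-in for the ViT model
optimizer = xmisc.nested_call_by_yaml(
    "./configs/yaml.opt/vit.cifar",
    model.parameters(),  # appended to the YAML's (empty) args
    lr=0.003,            # merged into the YAML's kwargs
)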
exps/basic/xmain.py
| @@ -3,7 +3,7 @@ | ||||
| ##################################################### | ||||
| # python exps/basic/xmain.py --save_dir outputs/x   # | ||||
| ##################################################### | ||||
| import sys, time, torch, random, argparse | ||||
| import os, sys, time, torch, random, argparse | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| @@ -12,24 +12,38 @@ print("LIB-DIR: {:}".format(lib_dir)) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from xautodl.xmisc import nested_call_by_yaml | ||||
| from xautodl import xmisc | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|  | ||||
|     train_data = nested_call_by_yaml(args.train_data_config, args.data_path) | ||||
|     valid_data = nested_call_by_yaml(args.valid_data_config, args.data_path) | ||||
|     train_data = xmisc.nested_call_by_yaml(args.train_data_config, args.data_path) | ||||
|     valid_data = xmisc.nested_call_by_yaml(args.valid_data_config, args.data_path) | ||||
|     logger = xmisc.Logger(args.save_dir, prefix="seed-{:}-".format(args.rand_seed)) | ||||
|  | ||||
|     import pdb | ||||
|  | ||||
|     pdb.set_trace() | ||||
|  | ||||
|     prepare_seed(args.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|  | ||||
|     train_data, valid_data, xshape, class_num = get_datasets( | ||||
|         args.dataset, args.data_path, args.cutout_length | ||||
|     logger.log("Create the logger: {:}".format(logger)) | ||||
|     logger.log("Arguments : -------------------------------") | ||||
|     for name, value in args._get_kwargs(): | ||||
|         logger.log("{:16} : {:}".format(name, value)) | ||||
|     logger.log("Python  Version  : {:}".format(sys.version.replace("\n", " "))) | ||||
|     logger.log("PyTorch Version  : {:}".format(torch.__version__)) | ||||
|     logger.log("cuDNN   Version  : {:}".format(torch.backends.cudnn.version())) | ||||
|     logger.log("CUDA available   : {:}".format(torch.cuda.is_available())) | ||||
|     logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count())) | ||||
|     logger.log( | ||||
|         "CUDA_VISIBLE_DEVICES : {:}".format( | ||||
|             os.environ["CUDA_VISIBLE_DEVICES"] | ||||
|             if "CUDA_VISIBLE_DEVICES" in os.environ | ||||
|             else "None" | ||||
|         ) | ||||
|     ) | ||||
|     logger.log("The training data is:\n{:}".format(train_data)) | ||||
|     logger.log("The validation data is:\n{:}".format(valid_data)) | ||||
|  | ||||
|     model = xmisc.nested_call_by_yaml(args.model_config) | ||||
|     logger.log("The model is:\n{:}".format(model)) | ||||
|     logger.log("The model size is {:.4f} M".format(xmisc.count_parameters(model))) | ||||
|  | ||||
|     train_loader = torch.utils.data.DataLoader( | ||||
|         train_data, | ||||
|         batch_size=args.batch_size, | ||||
| @@ -44,100 +58,25 @@ def main(args): | ||||
|         num_workers=args.workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     # get configures | ||||
|     model_config = load_config(args.model_config, {"class_num": class_num}, logger) | ||||
|     optim_config = load_config(args.optim_config, {"class_num": class_num}, logger) | ||||
|  | ||||
|     if args.model_source == "normal": | ||||
|         base_model = obtain_model(model_config) | ||||
|     elif args.model_source == "nas": | ||||
|         base_model = obtain_nas_infer_model(model_config, args.extra_model_path) | ||||
|     elif args.model_source == "autodl-searched": | ||||
|         base_model = obtain_model(model_config, args.extra_model_path) | ||||
|     elif args.model_source in ("x", "xmodel"): | ||||
|         base_model = obtain_xmodel(model_config) | ||||
|     else: | ||||
|         raise ValueError("invalid model-source : {:}".format(args.model_source)) | ||||
|     flop, param = get_model_infos(base_model, xshape) | ||||
|     logger.log("model ====>>>>:\n{:}".format(base_model)) | ||||
|     logger.log("model information : {:}".format(base_model.get_message())) | ||||
|     logger.log("-" * 50) | ||||
|     logger.log( | ||||
|         "Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G".format( | ||||
|             param, flop, flop / 1e3 | ||||
|     logger.log("The training loader: {:}".format(train_loader)) | ||||
|     logger.log("The validation loader: {:}".format(valid_loader)) | ||||
|     optimizer = xmisc.nested_call_by_yaml( | ||||
|         args.optim_config, | ||||
|         model.parameters(), | ||||
|         lr=args.lr, | ||||
|         weight_decay=args.weight_decay, | ||||
|     ) | ||||
|     ) | ||||
|     logger.log("-" * 50) | ||||
|     logger.log("train_data : {:}".format(train_data)) | ||||
|     logger.log("valid_data : {:}".format(valid_data)) | ||||
|     optimizer, scheduler, criterion = get_optim_scheduler( | ||||
|         base_model.parameters(), optim_config | ||||
|     ) | ||||
|     logger.log("optimizer  : {:}".format(optimizer)) | ||||
|     logger.log("scheduler  : {:}".format(scheduler)) | ||||
|     logger.log("criterion  : {:}".format(criterion)) | ||||
|     loss = xmisc.nested_call_by_yaml(args.loss_config) | ||||
|  | ||||
|     last_info, model_base_path, model_best_path = ( | ||||
|         logger.path("info"), | ||||
|         logger.path("model"), | ||||
|         logger.path("best"), | ||||
|     ) | ||||
|     network, criterion = torch.nn.DataParallel(base_model).cuda(), criterion.cuda() | ||||
|     logger.log("The optimizer is:\n{:}".format(optimizer)) | ||||
|     logger.log("The loss is {:}".format(loss)) | ||||
|  | ||||
|     if last_info.exists():  # automatically resume from previous checkpoint | ||||
|         logger.log( | ||||
|             "=> loading checkpoint of the last-info '{:}' start".format(last_info) | ||||
|         ) | ||||
|         last_infox = torch.load(last_info) | ||||
|         start_epoch = last_infox["epoch"] + 1 | ||||
|         last_checkpoint_path = last_infox["last_checkpoint"] | ||||
|         if not last_checkpoint_path.exists(): | ||||
|             logger.log( | ||||
|                 "Does not find {:}, try another path".format(last_checkpoint_path) | ||||
|             ) | ||||
|             last_checkpoint_path = ( | ||||
|                 last_info.parent | ||||
|                 / last_checkpoint_path.parent.name | ||||
|                 / last_checkpoint_path.name | ||||
|             ) | ||||
|         checkpoint = torch.load(last_checkpoint_path) | ||||
|         base_model.load_state_dict(checkpoint["base-model"]) | ||||
|         scheduler.load_state_dict(checkpoint["scheduler"]) | ||||
|         optimizer.load_state_dict(checkpoint["optimizer"]) | ||||
|         valid_accuracies = checkpoint["valid_accuracies"] | ||||
|         max_bytes = checkpoint["max_bytes"] | ||||
|         logger.log( | ||||
|             "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format( | ||||
|                 last_info, start_epoch | ||||
|             ) | ||||
|         ) | ||||
|     elif args.resume is not None: | ||||
|         assert Path(args.resume).exists(), "Can not find the resume file : {:}".format( | ||||
|             args.resume | ||||
|         ) | ||||
|         checkpoint = torch.load(args.resume) | ||||
|         start_epoch = checkpoint["epoch"] + 1 | ||||
|         base_model.load_state_dict(checkpoint["base-model"]) | ||||
|         scheduler.load_state_dict(checkpoint["scheduler"]) | ||||
|         optimizer.load_state_dict(checkpoint["optimizer"]) | ||||
|         valid_accuracies = checkpoint["valid_accuracies"] | ||||
|         max_bytes = checkpoint["max_bytes"] | ||||
|         logger.log( | ||||
|             "=> loading checkpoint from '{:}' start with {:}-th epoch.".format( | ||||
|                 args.resume, start_epoch | ||||
|             ) | ||||
|         ) | ||||
|     elif args.init_model is not None: | ||||
|         assert Path( | ||||
|             args.init_model | ||||
|         ).exists(), "Can not find the initialization file : {:}".format(args.init_model) | ||||
|         checkpoint = torch.load(args.init_model) | ||||
|         base_model.load_state_dict(checkpoint["base-model"]) | ||||
|         start_epoch, valid_accuracies, max_bytes = 0, {"best": -1}, {} | ||||
|         logger.log("=> initialize the model from {:}".format(args.init_model)) | ||||
|     else: | ||||
|         logger.log("=> do not find the last-info file : {:}".format(last_info)) | ||||
|         start_epoch, valid_accuracies, max_bytes = 0, {"best": -1}, {} | ||||
|     model, loss = torch.nn.DataParallel(model).cuda(), loss.cuda() | ||||
|  | ||||
|     import pdb | ||||
|  | ||||
|     pdb.set_trace() | ||||
|  | ||||
|     train_func, valid_func = get_procedures(args.procedure) | ||||
|  | ||||
| @@ -284,7 +223,7 @@ def main(args): | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a model with a loss function.", | ||||
|         description="Train a classification model with a loss function.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument( | ||||
| @@ -293,27 +232,21 @@ if __name__ == "__main__": | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||
|     parser.add_argument("--model_config", type=str, help="The path to the model config") | ||||
|     parser.add_argument("--optim_config", type=str, help="The optimizer config file.") | ||||
|     parser.add_argument("--loss_config", type=str, help="The loss config file.") | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer config" | ||||
|         "--train_data_config", type=str, help="The training dataset config path." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--train_data_config", type=str, help="The dataset config path." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--valid_data_config", type=str, help="The dataset config path." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--data_path", type=str, help="The path to the dataset." | ||||
|         "--valid_data_config", type=str, help="The validation dataset config path." | ||||
|     ) | ||||
|     parser.add_argument("--data_path", type=str, help="The path to the dataset.") | ||||
|     parser.add_argument("--algorithm", type=str, help="The algorithm.") | ||||
|     # Optimization options | ||||
|     parser.add_argument("--lr", type=float, help="The learning rate") | ||||
|     parser.add_argument("--weight_decay", type=float, help="The weight decay") | ||||
|     parser.add_argument("--batch_size", type=int, default=2, help="The batch size.") | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=8, | ||||
|         help="number of data loading workers (default: 8)", | ||||
|     ) | ||||
|     parser.add_argument("--workers", type=int, default=4, help="The number of workers") | ||||
|     # Random Seed | ||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||
|  | ||||
|   | ||||
| @@ -22,6 +22,10 @@ save_dir=./outputs/${dataset}/vit-experimental | ||||
| python --version | ||||
|  | ||||
| python ./exps/basic/xmain.py --save_dir ${save_dir} --rand_seed ${rseed} \ | ||||
| 	--train_data_config ./configs/data.yaml/${dataset}.train \ | ||||
| 	--valid_data_config ./configs/data.yaml/${dataset}.test \ | ||||
| 	--data_path $TORCH_HOME/cifar.python | ||||
| 	--train_data_config ./configs/yaml.data/${dataset}.train \ | ||||
| 	--valid_data_config ./configs/yaml.data/${dataset}.test \ | ||||
| 	--data_path $TORCH_HOME/cifar.python \ | ||||
| 	--model_config ./configs/yaml.model/vit-cifar10.s0 \ | ||||
| 	--optim_config ./configs/yaml.opt/vit.cifar \ | ||||
| 	--loss_config ./configs/yaml.loss/cross-entropy \ | ||||
| 	--lr 0.003 --weight_decay 0.3  | ||||
|   | ||||
tests/test_basic_space.py
| @@ -3,10 +3,8 @@ | ||||
| ##################################################### | ||||
| # pytest tests/test_basic_space.py -s               # | ||||
| ##################################################### | ||||
| import sys, random | ||||
| import random | ||||
| import unittest | ||||
| import pytest | ||||
| from pathlib import Path | ||||
|  | ||||
| from xautodl.spaces import Categorical | ||||
| from xautodl.spaces import Continuous | ||||
|   | ||||
tests/test_import.py
| @@ -3,12 +3,6 @@ | ||||
| ##################################################### | ||||
| # pytest ./tests/test_import.py                     # | ||||
| ##################################################### | ||||
| import os, sys, time, torch | ||||
| import pickle | ||||
| import tempfile | ||||
| from pathlib import Path | ||||
|  | ||||
|  | ||||
| def test_import(): | ||||
|     from xautodl import config_utils | ||||
|     from xautodl import datasets | ||||
| @@ -19,6 +13,9 @@ def test_import(): | ||||
|     from xautodl import spaces | ||||
|     from xautodl import trade_models | ||||
|     from xautodl import utils | ||||
|  | ||||
|     from xautodl import xlayers | ||||
|     from xautodl import xmisc | ||||
|     from xautodl import xmmodels | ||||
|  | ||||
|     print("Check all imports done") | ||||
|   | ||||
tests/test_super_att.py
| @@ -3,13 +3,11 @@ | ||||
| ##################################################### | ||||
| # pytest ./tests/test_super_att.py -s               # | ||||
| ##################################################### | ||||
| import sys, random | ||||
| import random | ||||
| import unittest | ||||
| from parameterized import parameterized | ||||
| from pathlib import Path | ||||
|  | ||||
| import torch | ||||
|  | ||||
| from xautodl import spaces | ||||
| from xautodl.xlayers import super_core | ||||
|  | ||||
|   | ||||
tests/test_super_container.py
| @@ -3,10 +3,9 @@ | ||||
| ##################################################### | ||||
| # pytest ./tests/test_super_container.py -s         # | ||||
| ##################################################### | ||||
| import sys, random | ||||
| import random | ||||
| import unittest | ||||
| import pytest | ||||
| from pathlib import Path | ||||
|  | ||||
| import torch | ||||
| from xautodl import spaces | ||||
|   | ||||
tests/test_super_rearrange.py
| @@ -3,7 +3,6 @@ | ||||
| ##################################################### | ||||
| # pytest ./tests/test_super_rearrange.py -s         # | ||||
| ##################################################### | ||||
| import sys | ||||
| import unittest | ||||
|  | ||||
| import torch | ||||
|   | ||||
tests/test_super_vit.py
| @@ -3,8 +3,8 @@ | ||||
| ##################################################### | ||||
| # pytest ./tests/test_super_vit.py -s               # | ||||
| ##################################################### | ||||
| import sys | ||||
| import unittest | ||||
| from parameterized import parameterized | ||||
|  | ||||
| import torch | ||||
| from xautodl.xmodels import transformers | ||||
| @@ -16,20 +16,23 @@ class TestSuperViT(unittest.TestCase): | ||||
|  | ||||
|     def test_super_vit(self): | ||||
|         model = transformers.get_transformer("vit-base-16") | ||||
|         tensor = torch.rand((16, 3, 224, 224)) | ||||
|         tensor = torch.rand((2, 3, 224, 224)) | ||||
|         print("The tensor shape: {:}".format(tensor.shape)) | ||||
|         # print(model) | ||||
|         outs = model(tensor) | ||||
|         print("The output tensor shape: {:}".format(outs.shape)) | ||||
|  | ||||
|     def test_imagenet(self): | ||||
|         name2config = transformers.name2config | ||||
|         print("There are {:} models in total.".format(len(name2config))) | ||||
|         for name, config in name2config.items(): | ||||
|             if "cifar" in name: | ||||
|                 tensor = torch.rand((16, 3, 32, 32)) | ||||
|             else: | ||||
|                 tensor = torch.rand((16, 3, 224, 224)) | ||||
|     @parameterized.expand( | ||||
|         [ | ||||
|             ["vit-cifar10-p4-d4-h4-c32", 32], | ||||
|             ["vit-base-16", 224], | ||||
|             ["vit-large-16", 224], | ||||
|             ["vit-huge-14", 224], | ||||
|         ] | ||||
|     ) | ||||
|     def test_imagenet(self, name, resolution): | ||||
|         tensor = torch.rand((2, 3, resolution, resolution)) | ||||
|         config = transformers.name2config[name] | ||||
|         model = transformers.get_transformer(config) | ||||
|         outs = model(tensor) | ||||
|         size = count_parameters(model, "mb") | ||||
|   | ||||
xautodl/xmisc/__init__.py
| @@ -6,3 +6,7 @@ from .module_utils import call_by_yaml | ||||
| from .module_utils import nested_call_by_dict | ||||
| from .module_utils import nested_call_by_yaml | ||||
| from .yaml_utils import load_yaml | ||||
|  | ||||
| from .torch_utils import count_parameters | ||||
|  | ||||
| from .logger_utils import Logger | ||||
|   | ||||
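
After this change, everything xmain.py uses resolves through the single xmisc namespace:

from xautodl import xmisc

xmisc.nested_call_by_yaml  # from .module_utils
xmisc.count_parameters     # from .torch_utils
xmisc.Logger               # from .logger_utils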
							
								
								
									
xautodl/xmisc/logger_utils.py (new file)
									
								
							| @@ -0,0 +1,49 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 # | ||||
| ##################################################### | ||||
| import sys | ||||
| from pathlib import Path | ||||
|  | ||||
| from .time_utils import time_for_file, time_string | ||||
|  | ||||
|  | ||||
| class Logger: | ||||
|     """A logger used in xautodl.""" | ||||
|  | ||||
|     def __init__(self, root_dir, prefix="", log_time=True): | ||||
|         """Create a summary writer logging to log_dir.""" | ||||
|         self.root_dir = Path(root_dir) | ||||
|         self.log_dir = self.root_dir / "logs" | ||||
|         self.log_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|         self._prefix = prefix | ||||
|         self._log_time = log_time | ||||
|         self.logger_path = self.log_dir / "{:}{:}.log".format( | ||||
|             self._prefix, time_for_file() | ||||
|         ) | ||||
|         self._logger_file = open(self.logger_path, "w") | ||||
|  | ||||
|     @property | ||||
|     def logger(self): | ||||
|         return self._logger_file | ||||
|  | ||||
|     def log(self, string, save=True, stdout=False): | ||||
|         string = "{:} {:}".format(time_string(), string) if self._log_time else string | ||||
|         if stdout: | ||||
|             sys.stdout.write(string) | ||||
|             sys.stdout.flush() | ||||
|         else: | ||||
|             print(string) | ||||
|         if save: | ||||
|             self._logger_file.write("{:}\n".format(string)) | ||||
|             self._logger_file.flush() | ||||
|  | ||||
|     def close(self): | ||||
|         self._logger_file.close() | ||||
|         if getattr(self, "writer", None) is not None: | ||||
|             self.writer.close() | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}(dir={log_dir}, prefix={_prefix}, log_time={_log_time})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
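
A minimal usage sketch for the new Logger; each message is printed and appended to <root_dir>/logs/<prefix><timestamp>.log:

from xautodl.xmisc import Logger

logger = Logger("outputs/x", prefix="seed-0-")
logger.log("Create the logger: {:}".format(logger))
logger.close()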
xautodl/xmisc/module_utils.py
| @@ -62,8 +62,15 @@ def call_by_yaml(path, *args, **kwargs) -> object: | ||||
|  | ||||
| def nested_call_by_dict(config: Union[Dict[Text, Any], Any], *args, **kwargs) -> object: | ||||
|     """Similar to `call_by_dict`, but differently, the args may contain another dict needs to be called.""" | ||||
|     if not has_key_words(config): | ||||
|     if isinstance(config, list): | ||||
|         return [nested_call_by_dict(x) for x in config] | ||||
|     elif isinstance(config, tuple): | ||||
|         return tuple(nested_call_by_dict(x) for x in config) | ||||
|     elif not isinstance(config, dict): | ||||
|         return config | ||||
|     elif not has_key_words(config): | ||||
|         return {key: nested_call_by_dict(x) for key, x in config.items()} | ||||
|     else: | ||||
|         module = get_module_by_module_path(config["module_path"]) | ||||
|         cls_or_func = getattr(module, config[CLS_FUNC_KEY]) | ||||
|         args = tuple(list(config["args"]) + list(args)) | ||||
|   | ||||
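
The new list branch is what allows the YAML configs above to nest a list of transform dicts inside Compose's args. A sketch with a plain dict, assuming torchvision is installed and that the elided tail of the function resolves args recursively the same way:

from xautodl.xmisc import nested_call_by_dict

config = {
    "class_or_func": "Compose",
    "module_path": "torchvision.transforms",
    "args": [
        [
            {
                "class_or_func": "ToTensor",
                "module_path": "torchvision.transforms",
                "args": [],
                "kwargs": {},
            }
        ]
    ],
    "kwargs": {},
}
transform = nested_call_by_dict(config)  # equivalent to Compose([ToTensor()])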
							
								
								
									
xautodl/xmisc/scheduler_utils.py (new file)
									
								
| @@ -0,0 +1,139 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 # | ||||
| ##################################################### | ||||
| import math | ||||
| import warnings | ||||
|  | ||||
| from torch.optim.lr_scheduler import _LRScheduler | ||||
|  | ||||
|  | ||||
| class CosineDecayWithWarmup(_LRScheduler): | ||||
|     r"""Set the learning rate of each parameter group using a cosine annealing | ||||
|     schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` | ||||
|     is the number of epochs since the last restart and :math:`T_{i}` is the number | ||||
|     of epochs between two warm restarts in SGDR: | ||||
|     .. math:: | ||||
|         \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + | ||||
|         \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right) | ||||
|     When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`. | ||||
|     When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`. | ||||
|     It has been proposed in | ||||
|     `SGDR: Stochastic Gradient Descent with Warm Restarts`_. | ||||
|     Args: | ||||
|         optimizer (Optimizer): Wrapped optimizer. | ||||
|         T_0 (int): Number of iterations for the first restart. | ||||
|         T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1. | ||||
|         eta_min (float, optional): Minimum learning rate. Default: 0. | ||||
|         last_epoch (int, optional): The index of last epoch. Default: -1. | ||||
|         verbose (bool): If ``True``, prints a message to stdout for | ||||
|             each update. Default: ``False``. | ||||
|     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: | ||||
|         https://arxiv.org/abs/1608.03983 | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False | ||||
|     ): | ||||
|         if T_0 <= 0 or not isinstance(T_0, int): | ||||
|             raise ValueError("Expected positive integer T_0, but got {}".format(T_0)) | ||||
|         if T_mult < 1 or not isinstance(T_mult, int): | ||||
|             raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult)) | ||||
|         self.T_0 = T_0 | ||||
|         self.T_i = T_0 | ||||
|         self.T_mult = T_mult | ||||
|         self.eta_min = eta_min | ||||
|  | ||||
|         super(CosineDecayWithWarmup, self).__init__(optimizer, last_epoch, verbose) | ||||
|  | ||||
|         self.T_cur = self.last_epoch | ||||
|  | ||||
|     def get_lr(self): | ||||
|         if not self._get_lr_called_within_step: | ||||
|             warnings.warn( | ||||
|                 "To get the last learning rate computed by the scheduler, " | ||||
|                 "please use `get_last_lr()`.", | ||||
|                 UserWarning, | ||||
|             ) | ||||
|  | ||||
|         return [ | ||||
|             self.eta_min | ||||
|             + (base_lr - self.eta_min) | ||||
|             * (1 + math.cos(math.pi * self.T_cur / self.T_i)) | ||||
|             / 2 | ||||
|             for base_lr in self.base_lrs | ||||
|         ] | ||||
|  | ||||
|     def step(self, epoch=None): | ||||
|         """Step could be called after every batch update | ||||
|         Example: | ||||
|             >>> scheduler = CosineDecayWithWarmup(optimizer, T_0, T_mult) | ||||
|             >>> iters = len(dataloader) | ||||
|             >>> for epoch in range(20): | ||||
|             >>>     for i, sample in enumerate(dataloader): | ||||
|             >>>         inputs, labels = sample['inputs'], sample['labels'] | ||||
|             >>>         optimizer.zero_grad() | ||||
|             >>>         outputs = net(inputs) | ||||
|             >>>         loss = criterion(outputs, labels) | ||||
|             >>>         loss.backward() | ||||
|             >>>         optimizer.step() | ||||
|             >>>         scheduler.step(epoch + i / iters) | ||||
|         This function can be called in an interleaved way. | ||||
|         Example: | ||||
|             >>> scheduler = CosineDecayWithWarmup(optimizer, T_0, T_mult) | ||||
|             >>> for epoch in range(20): | ||||
|             >>>     scheduler.step() | ||||
|             >>> scheduler.step(26) | ||||
|             >>> scheduler.step() # scheduler.step(27), instead of scheduler(20) | ||||
|         """ | ||||
|  | ||||
|         if epoch is None and self.last_epoch < 0: | ||||
|             epoch = 0 | ||||
|  | ||||
|         if epoch is None: | ||||
|             epoch = self.last_epoch + 1 | ||||
|             self.T_cur = self.T_cur + 1 | ||||
|             if self.T_cur >= self.T_i: | ||||
|                 self.T_cur = self.T_cur - self.T_i | ||||
|                 self.T_i = self.T_i * self.T_mult | ||||
|         else: | ||||
|             if epoch < 0: | ||||
|                 raise ValueError( | ||||
|                     "Expected non-negative epoch, but got {}".format(epoch) | ||||
|                 ) | ||||
|             if epoch >= self.T_0: | ||||
|                 if self.T_mult == 1: | ||||
|                     self.T_cur = epoch % self.T_0 | ||||
|                 else: | ||||
|                     n = int( | ||||
|                         math.log( | ||||
|                             (epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult | ||||
|                         ) | ||||
|                     ) | ||||
|                     self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / ( | ||||
|                         self.T_mult - 1 | ||||
|                     ) | ||||
|                     self.T_i = self.T_0 * self.T_mult ** (n) | ||||
|             else: | ||||
|                 self.T_i = self.T_0 | ||||
|                 self.T_cur = epoch | ||||
|         self.last_epoch = math.floor(epoch) | ||||
|  | ||||
|         class _enable_get_lr_call: | ||||
|             def __init__(self, o): | ||||
|                 self.o = o | ||||
|  | ||||
|             def __enter__(self): | ||||
|                 self.o._get_lr_called_within_step = True | ||||
|                 return self | ||||
|  | ||||
|             def __exit__(self, type, value, traceback): | ||||
|                 self.o._get_lr_called_within_step = False | ||||
|                 return self | ||||
|  | ||||
|         with _enable_get_lr_call(self): | ||||
|             for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())): | ||||
|                 param_group, lr = data | ||||
|                 param_group["lr"] = lr | ||||
|                 self.print_lr(self.verbose, i, lr, epoch) | ||||
|  | ||||
|         self._last_lr = [group["lr"] for group in self.optimizer.param_groups] | ||||
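
A minimal sketch with per-epoch stepping (the docstring above shows per-batch stepping); with T_0=10 and T_mult=2 the schedule restarts at epochs 10, 30, 70, and so on:

import torch
from xautodl.xmisc.scheduler_utils import CosineDecayWithWarmup

net = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
scheduler = CosineDecayWithWarmup(optimizer, T_0=10, T_mult=2)
for epoch in range(30):
    optimizer.step()  # stand-in for one training epoch
    scheduler.step()
print(scheduler.get_last_lr())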
							
								
								
									
xautodl/xmisc/time_utils.py (new file)
									
								
							| @@ -0,0 +1,26 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 # | ||||
| ##################################################### | ||||
| import time | ||||
|  | ||||
|  | ||||
| def time_for_file(): | ||||
|     ISOTIMEFORMAT = "%d-%h-at-%H-%M-%S" | ||||
|     return "{:}".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time()))) | ||||
|  | ||||
|  | ||||
| def time_string(): | ||||
|     ISOTIMEFORMAT = "%Y-%m-%d %X" | ||||
|     string = "[{:}]".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time()))) | ||||
|     return string | ||||
|  | ||||
|  | ||||
| def convert_secs2time(epoch_time, return_str=False): | ||||
|     need_hour = int(epoch_time / 3600) | ||||
|     need_mins = int((epoch_time - 3600 * need_hour) / 60) | ||||
|     need_secs = int(epoch_time - 3600 * need_hour - 60 * need_mins) | ||||
|     if return_str: | ||||
|         time_str = "[{:02d}:{:02d}:{:02d}]".format(need_hour, need_mins, need_secs) | ||||
|         return time_str | ||||
|     else: | ||||
|         return need_hour, need_mins, need_secs | ||||
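
Sample outputs, given the strftime formats above (all timestamps are in GMT):

from xautodl.xmisc.time_utils import time_for_file, time_string, convert_secs2time

time_for_file()                           # e.g. '28-Jun-at-07-26-55'
time_string()                             # e.g. '[2021-06-28 07:26:55]'
convert_secs2time(3725)                   # (1, 2, 5)
convert_secs2time(3725, return_str=True)  # '[01:02:05]'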
							
								
								
									
xautodl/xmisc/torch_utils.py (new file)
									
								
| @@ -0,0 +1,28 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 # | ||||
| ##################################################### | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
| def count_parameters(model_or_parameters, unit="mb"): | ||||
|     if isinstance(model_or_parameters, nn.Module): | ||||
|         counts = sum(np.prod(v.size()) for v in model_or_parameters.parameters()) | ||||
|     elif isinstance(model_or_parameters, nn.Parameter): | ||||
|         counts = model_or_parameters.numel() | ||||
|     elif isinstance(model_or_parameters, (list, tuple)): | ||||
|         counts = sum(count_parameters(x, None) for x in model_or_parameters) | ||||
|     else: | ||||
|         counts = sum(np.prod(v.size()) for v in model_or_parameters) | ||||
|     if unit is None: | ||||
|         return counts | ||||
|     if unit.lower() in ("kb", "k"): | ||||
|         counts /= 1e3 | ||||
|     elif unit.lower() in ("mb", "m"): | ||||
|         counts /= 1e6 | ||||
|     elif unit.lower() in ("gb", "g"): | ||||
|         counts /= 1e9 | ||||
|     else: | ||||
|         raise ValueError("Unknown unit: {:}".format(unit)) | ||||
|     return counts | ||||
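
A quick sanity check of the unit handling with a toy module:

import torch.nn as nn
from xautodl.xmisc import count_parameters

layer = nn.Linear(10, 10)      # 10*10 weights + 10 biases = 110 parameters
count_parameters(layer, None)  # 110
count_parameters(layer, "k")   # 0.11
count_parameters(layer)        # 0.00011, since the default unit is "mb"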