From 9bf0fa5f04c9ce7e81f624a1b3ccad19292d3cb5 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Thu, 10 Jun 2021 21:53:22 +0800 Subject: [PATCH] Update yaml configs --- configs/data.yaml/cifar10.test | 7 -- configs/data.yaml/cifar10.train | 7 -- configs/yaml.data/cifar10.test | 22 ++++ configs/yaml.data/cifar10.train | 30 ++++++ configs/yaml.loss/cross-entropy | 4 + configs/yaml.model/vit-cifar10.s0 | 4 + configs/yaml.opt/vit.cifar | 7 ++ exps/basic/xmain.py | 169 +++++++++--------------------- scripts/experimental/train-vit.sh | 10 +- tests/test_basic_space.py | 4 +- tests/test_import.py | 9 +- tests/test_super_att.py | 4 +- tests/test_super_container.py | 3 +- tests/test_super_rearrange.py | 1 - tests/test_super_vit.py | 37 ++++--- xautodl/xmisc/__init__.py | 4 + xautodl/xmisc/logger_utils.py | 49 +++++++++ xautodl/xmisc/module_utils.py | 29 +++-- xautodl/xmisc/scheduler_utils.py | 136 ++++++++++++++++++++++++ xautodl/xmisc/time_utils.py | 26 +++++ xautodl/xmisc/torch_utils.py | 26 +++++ 21 files changed, 410 insertions(+), 178 deletions(-) delete mode 100644 configs/data.yaml/cifar10.test delete mode 100644 configs/data.yaml/cifar10.train create mode 100644 configs/yaml.data/cifar10.test create mode 100644 configs/yaml.data/cifar10.train create mode 100644 configs/yaml.loss/cross-entropy create mode 100644 configs/yaml.model/vit-cifar10.s0 create mode 100644 configs/yaml.opt/vit.cifar create mode 100644 xautodl/xmisc/logger_utils.py create mode 100644 xautodl/xmisc/scheduler_utils.py create mode 100644 xautodl/xmisc/time_utils.py create mode 100644 xautodl/xmisc/torch_utils.py diff --git a/configs/data.yaml/cifar10.test b/configs/data.yaml/cifar10.test deleted file mode 100644 index 284a6ea..0000000 --- a/configs/data.yaml/cifar10.test +++ /dev/null @@ -1,7 +0,0 @@ -class_or_func: CIFAR10 -module_path: torchvision.datasets -args: [] -kwargs: - train: False - download: True - transform: null diff --git a/configs/data.yaml/cifar10.train b/configs/data.yaml/cifar10.train deleted file mode 100644 index 26d57f3..0000000 --- a/configs/data.yaml/cifar10.train +++ /dev/null @@ -1,7 +0,0 @@ -class_or_func: CIFAR10 -module_path: torchvision.datasets -args: [] -kwargs: - train: True - download: True - transform: null diff --git a/configs/yaml.data/cifar10.test b/configs/yaml.data/cifar10.test new file mode 100644 index 0000000..3b8e8d8 --- /dev/null +++ b/configs/yaml.data/cifar10.test @@ -0,0 +1,22 @@ +class_or_func: CIFAR10 +module_path: torchvision.datasets +args: [] +kwargs: + train: False + download: True + transform: + class_or_func: Compose + module_path: torchvision.transforms + args: + - + - class_or_func: ToTensor + module_path: torchvision.transforms + args: [] + kwargs: {} + - class_or_func: Normalize + module_path: torchvision.transforms + args: [] + kwargs: + mean: (0.491, 0.482, 0.447) + std: (0.247, 0.244, 0.262) + kwargs: {} diff --git a/configs/yaml.data/cifar10.train b/configs/yaml.data/cifar10.train new file mode 100644 index 0000000..f787228 --- /dev/null +++ b/configs/yaml.data/cifar10.train @@ -0,0 +1,30 @@ +class_or_func: CIFAR10 +module_path: torchvision.datasets +args: [] +kwargs: + train: True + download: True + transform: + class_or_func: Compose + module_path: torchvision.transforms + args: + - + - class_or_func: RandomHorizontalFlip + module_path: torchvision.transforms + args: [] + kwargs: {} + - class_or_func: RandomCrop + module_path: torchvision.transforms + args: [32] + kwargs: {padding: 4} + - class_or_func: ToTensor + module_path: torchvision.transforms + args: [] + kwargs: {} + - class_or_func: Normalize + module_path: torchvision.transforms + args: [] + kwargs: + mean: (0.491, 0.482, 0.447) + std: (0.247, 0.244, 0.262) + kwargs: {} diff --git a/configs/yaml.loss/cross-entropy b/configs/yaml.loss/cross-entropy new file mode 100644 index 0000000..7c86921 --- /dev/null +++ b/configs/yaml.loss/cross-entropy @@ -0,0 +1,4 @@ +class_or_func: CrossEntropyLoss +module_path: torch.nn +args: [] +kwargs: {} diff --git a/configs/yaml.model/vit-cifar10.s0 b/configs/yaml.model/vit-cifar10.s0 new file mode 100644 index 0000000..b88fafa --- /dev/null +++ b/configs/yaml.model/vit-cifar10.s0 @@ -0,0 +1,4 @@ +class_or_func: get_transformer +module_path: xautodl.xmodels.transformers +args: [vit-cifar10-p4-d4-h4-c32] +kwargs: {} diff --git a/configs/yaml.opt/vit.cifar b/configs/yaml.opt/vit.cifar new file mode 100644 index 0000000..d01e9ed --- /dev/null +++ b/configs/yaml.opt/vit.cifar @@ -0,0 +1,7 @@ +class_or_func: Adam +module_path: torch.optim +args: [] +kwargs: + betas: [0.9, 0.999] + weight_decay: 0.1 + amsgrad: False diff --git a/exps/basic/xmain.py b/exps/basic/xmain.py index dfb430c..9ee0a16 100644 --- a/exps/basic/xmain.py +++ b/exps/basic/xmain.py @@ -3,7 +3,7 @@ ##################################################### # python exps/basic/xmain.py --save_dir outputs/x # ##################################################### -import sys, time, torch, random, argparse +import os, sys, time, torch, random, argparse from copy import deepcopy from pathlib import Path @@ -12,24 +12,38 @@ print("LIB-DIR: {:}".format(lib_dir)) if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) -from xautodl.xmisc import nested_call_by_yaml +from xautodl import xmisc def main(args): - train_data = nested_call_by_yaml(args.train_data_config, args.data_path) - valid_data = nested_call_by_yaml(args.valid_data_config, args.data_path) + train_data = xmisc.nested_call_by_yaml(args.train_data_config, args.data_path) + valid_data = xmisc.nested_call_by_yaml(args.valid_data_config, args.data_path) + logger = xmisc.Logger(args.save_dir, prefix="seed-{:}-".format(args.rand_seed)) - import pdb - - pdb.set_trace() - - prepare_seed(args.rand_seed) - logger = prepare_logger(args) - - train_data, valid_data, xshape, class_num = get_datasets( - args.dataset, args.data_path, args.cutout_length + logger.log("Create the logger: {:}".format(logger)) + logger.log("Arguments : -------------------------------") + for name, value in args._get_kwargs(): + logger.log("{:16} : {:}".format(name, value)) + logger.log("Python Version : {:}".format(sys.version.replace("\n", " "))) + logger.log("PyTorch Version : {:}".format(torch.__version__)) + logger.log("cuDNN Version : {:}".format(torch.backends.cudnn.version())) + logger.log("CUDA available : {:}".format(torch.cuda.is_available())) + logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count())) + logger.log( + "CUDA_VISIBLE_DEVICES : {:}".format( + os.environ["CUDA_VISIBLE_DEVICES"] + if "CUDA_VISIBLE_DEVICES" in os.environ + else "None" + ) ) + logger.log("The training data is:\n{:}".format(train_data)) + logger.log("The validation data is:\n{:}".format(valid_data)) + + model = xmisc.nested_call_by_yaml(args.model_config) + logger.log("The model is:\n{:}".format(model)) + logger.log("The model size is {:.4f} M".format(xmisc.count_parameters(model))) + train_loader = torch.utils.data.DataLoader( train_data, batch_size=args.batch_size, @@ -44,100 +58,25 @@ def main(args): num_workers=args.workers, pin_memory=True, ) - # get configures - model_config = load_config(args.model_config, {"class_num": class_num}, logger) - optim_config = load_config(args.optim_config, {"class_num": class_num}, logger) - if args.model_source == "normal": - base_model = obtain_model(model_config) - elif args.model_source == "nas": - base_model = obtain_nas_infer_model(model_config, args.extra_model_path) - elif args.model_source == "autodl-searched": - base_model = obtain_model(model_config, args.extra_model_path) - elif args.model_source in ("x", "xmodel"): - base_model = obtain_xmodel(model_config) - else: - raise ValueError("invalid model-source : {:}".format(args.model_source)) - flop, param = get_model_infos(base_model, xshape) - logger.log("model ====>>>>:\n{:}".format(base_model)) - logger.log("model information : {:}".format(base_model.get_message())) - logger.log("-" * 50) - logger.log( - "Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G".format( - param, flop, flop / 1e3 - ) + logger.log("The training loader: {:}".format(train_loader)) + logger.log("The validation loader: {:}".format(valid_loader)) + optimizer = xmisc.nested_call_by_yaml( + args.optim_config, + model.parameters(), + lr=args.lr, + weight_decay=args.weight_decay, ) - logger.log("-" * 50) - logger.log("train_data : {:}".format(train_data)) - logger.log("valid_data : {:}".format(valid_data)) - optimizer, scheduler, criterion = get_optim_scheduler( - base_model.parameters(), optim_config - ) - logger.log("optimizer : {:}".format(optimizer)) - logger.log("scheduler : {:}".format(scheduler)) - logger.log("criterion : {:}".format(criterion)) + loss = xmisc.nested_call_by_yaml(args.loss_config) - last_info, model_base_path, model_best_path = ( - logger.path("info"), - logger.path("model"), - logger.path("best"), - ) - network, criterion = torch.nn.DataParallel(base_model).cuda(), criterion.cuda() + logger.log("The optimizer is:\n{:}".format(optimizer)) + logger.log("The loss is {:}".format(loss)) - if last_info.exists(): # automatically resume from previous checkpoint - logger.log( - "=> loading checkpoint of the last-info '{:}' start".format(last_info) - ) - last_infox = torch.load(last_info) - start_epoch = last_infox["epoch"] + 1 - last_checkpoint_path = last_infox["last_checkpoint"] - if not last_checkpoint_path.exists(): - logger.log( - "Does not find {:}, try another path".format(last_checkpoint_path) - ) - last_checkpoint_path = ( - last_info.parent - / last_checkpoint_path.parent.name - / last_checkpoint_path.name - ) - checkpoint = torch.load(last_checkpoint_path) - base_model.load_state_dict(checkpoint["base-model"]) - scheduler.load_state_dict(checkpoint["scheduler"]) - optimizer.load_state_dict(checkpoint["optimizer"]) - valid_accuracies = checkpoint["valid_accuracies"] - max_bytes = checkpoint["max_bytes"] - logger.log( - "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format( - last_info, start_epoch - ) - ) - elif args.resume is not None: - assert Path(args.resume).exists(), "Can not find the resume file : {:}".format( - args.resume - ) - checkpoint = torch.load(args.resume) - start_epoch = checkpoint["epoch"] + 1 - base_model.load_state_dict(checkpoint["base-model"]) - scheduler.load_state_dict(checkpoint["scheduler"]) - optimizer.load_state_dict(checkpoint["optimizer"]) - valid_accuracies = checkpoint["valid_accuracies"] - max_bytes = checkpoint["max_bytes"] - logger.log( - "=> loading checkpoint from '{:}' start with {:}-th epoch.".format( - args.resume, start_epoch - ) - ) - elif args.init_model is not None: - assert Path( - args.init_model - ).exists(), "Can not find the initialization file : {:}".format(args.init_model) - checkpoint = torch.load(args.init_model) - base_model.load_state_dict(checkpoint["base-model"]) - start_epoch, valid_accuracies, max_bytes = 0, {"best": -1}, {} - logger.log("=> initialize the model from {:}".format(args.init_model)) - else: - logger.log("=> do not find the last-info file : {:}".format(last_info)) - start_epoch, valid_accuracies, max_bytes = 0, {"best": -1}, {} + model, loss = torch.nn.DataParallel(model).cuda(), loss.cuda() + + import pdb + + pdb.set_trace() train_func, valid_func = get_procedures(args.procedure) @@ -284,7 +223,7 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Train a model with a loss function.", + description="Train a classification model with a loss function.", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( @@ -293,27 +232,21 @@ if __name__ == "__main__": parser.add_argument("--resume", type=str, help="Resume path.") parser.add_argument("--init_model", type=str, help="The initialization model path.") parser.add_argument("--model_config", type=str, help="The path to the model config") + parser.add_argument("--optim_config", type=str, help="The optimizer config file.") + parser.add_argument("--loss_config", type=str, help="The loss config file.") parser.add_argument( - "--optim_config", type=str, help="The path to the optimizer config" + "--train_data_config", type=str, help="The training dataset config path." ) parser.add_argument( - "--train_data_config", type=str, help="The dataset config path." - ) - parser.add_argument( - "--valid_data_config", type=str, help="The dataset config path." - ) - parser.add_argument( - "--data_path", type=str, help="The path to the dataset." + "--valid_data_config", type=str, help="The validation dataset config path." ) + parser.add_argument("--data_path", type=str, help="The path to the dataset.") parser.add_argument("--algorithm", type=str, help="The algorithm.") # Optimization options + parser.add_argument("--lr", type=float, help="The learning rate") + parser.add_argument("--weight_decay", type=float, help="The weight decay") parser.add_argument("--batch_size", type=int, default=2, help="The batch size.") - parser.add_argument( - "--workers", - type=int, - default=8, - help="number of data loading workers (default: 8)", - ) + parser.add_argument("--workers", type=int, default=4, help="The number of workers") # Random Seed parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") diff --git a/scripts/experimental/train-vit.sh b/scripts/experimental/train-vit.sh index 43831cc..4dc1772 100644 --- a/scripts/experimental/train-vit.sh +++ b/scripts/experimental/train-vit.sh @@ -22,6 +22,10 @@ save_dir=./outputs/${dataset}/vit-experimental python --version python ./exps/basic/xmain.py --save_dir ${save_dir} --rand_seed ${rseed} \ - --train_data_config ./configs/data.yaml/${dataset}.train \ - --valid_data_config ./configs/data.yaml/${dataset}.test \ - --data_path $TORCH_HOME/cifar.python + --train_data_config ./configs/yaml.data/${dataset}.train \ + --valid_data_config ./configs/yaml.data/${dataset}.test \ + --data_path $TORCH_HOME/cifar.python \ + --model_config ./configs/yaml.model/vit-cifar10.s0 \ + --optim_config ./configs/yaml.opt/vit.cifar \ + --loss_config ./configs/yaml.loss/cross-entropy \ + --lr 0.003 --weight_decay 0.3 diff --git a/tests/test_basic_space.py b/tests/test_basic_space.py index ce4b3e3..f0a7fab 100644 --- a/tests/test_basic_space.py +++ b/tests/test_basic_space.py @@ -3,10 +3,8 @@ ##################################################### # pytest tests/test_basic_space.py -s # ##################################################### -import sys, random +import random import unittest -import pytest -from pathlib import Path from xautodl.spaces import Categorical from xautodl.spaces import Continuous diff --git a/tests/test_import.py b/tests/test_import.py index 53f47c7..8b4e442 100644 --- a/tests/test_import.py +++ b/tests/test_import.py @@ -3,12 +3,6 @@ ##################################################### # pytest ./tests/test_import.py # ##################################################### -import os, sys, time, torch -import pickle -import tempfile -from pathlib import Path - - def test_import(): from xautodl import config_utils from xautodl import datasets @@ -19,6 +13,9 @@ def test_import(): from xautodl import spaces from xautodl import trade_models from xautodl import utils + from xautodl import xlayers + from xautodl import xmisc + from xautodl import xmmodels print("Check all imports done") diff --git a/tests/test_super_att.py b/tests/test_super_att.py index 6df7e33..8fbdb35 100644 --- a/tests/test_super_att.py +++ b/tests/test_super_att.py @@ -3,13 +3,11 @@ ##################################################### # pytest ./tests/test_super_att.py -s # ##################################################### -import sys, random +import random import unittest from parameterized import parameterized -from pathlib import Path import torch - from xautodl import spaces from xautodl.xlayers import super_core diff --git a/tests/test_super_container.py b/tests/test_super_container.py index 37e4523..a14f539 100644 --- a/tests/test_super_container.py +++ b/tests/test_super_container.py @@ -3,10 +3,9 @@ ##################################################### # pytest ./tests/test_super_container.py -s # ##################################################### -import sys, random +import random import unittest import pytest -from pathlib import Path import torch from xautodl import spaces diff --git a/tests/test_super_rearrange.py b/tests/test_super_rearrange.py index 3b86d37..df1862b 100644 --- a/tests/test_super_rearrange.py +++ b/tests/test_super_rearrange.py @@ -3,7 +3,6 @@ ##################################################### # pytest ./tests/test_super_rearrange.py -s # ##################################################### -import sys import unittest import torch diff --git a/tests/test_super_vit.py b/tests/test_super_vit.py index 1b5d390..05b13b1 100644 --- a/tests/test_super_vit.py +++ b/tests/test_super_vit.py @@ -3,8 +3,8 @@ ##################################################### # pytest ./tests/test_super_vit.py -s # ##################################################### -import sys import unittest +from parameterized import parameterized import torch from xautodl.xmodels import transformers @@ -16,25 +16,28 @@ class TestSuperViT(unittest.TestCase): def test_super_vit(self): model = transformers.get_transformer("vit-base-16") - tensor = torch.rand((16, 3, 224, 224)) + tensor = torch.rand((2, 3, 224, 224)) print("The tensor shape: {:}".format(tensor.shape)) # print(model) outs = model(tensor) print("The output tensor shape: {:}".format(outs.shape)) - def test_imagenet(self): - name2config = transformers.name2config - print("There are {:} models in total.".format(len(name2config))) - for name, config in name2config.items(): - if "cifar" in name: - tensor = torch.rand((16, 3, 32, 32)) - else: - tensor = torch.rand((16, 3, 224, 224)) - model = transformers.get_transformer(config) - outs = model(tensor) - size = count_parameters(model, "mb", True) - print( - "{:10s} : size={:.2f}MB, out-shape: {:}".format( - name, size, tuple(outs.shape) - ) + @parameterized.expand( + [ + ["vit-cifar10-p4-d4-h4-c32", 32], + ["vit-base-16", 224], + ["vit-large-16", 224], + ["vit-huge-14", 224], + ] + ) + def test_imagenet(self, name, resolution): + tensor = torch.rand((2, 3, resolution, resolution)) + config = transformers.name2config[name] + model = transformers.get_transformer(config) + outs = model(tensor) + size = count_parameters(model, "mb", True) + print( + "{:10s} : size={:.2f}MB, out-shape: {:}".format( + name, size, tuple(outs.shape) ) + ) diff --git a/xautodl/xmisc/__init__.py b/xautodl/xmisc/__init__.py index 76c963c..dad05e2 100644 --- a/xautodl/xmisc/__init__.py +++ b/xautodl/xmisc/__init__.py @@ -6,3 +6,7 @@ from .module_utils import call_by_yaml from .module_utils import nested_call_by_dict from .module_utils import nested_call_by_yaml from .yaml_utils import load_yaml + +from .torch_utils import count_parameters + +from .logger_utils import Logger diff --git a/xautodl/xmisc/logger_utils.py b/xautodl/xmisc/logger_utils.py new file mode 100644 index 0000000..2a8bbd5 --- /dev/null +++ b/xautodl/xmisc/logger_utils.py @@ -0,0 +1,49 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 # +##################################################### +import sys +from pathlib import Path + +from .time_utils import time_for_file, time_string + + +class Logger: + """A logger used in xautodl.""" + + def __init__(self, root_dir, prefix="", log_time=True): + """Create a summary writer logging to log_dir.""" + self.root_dir = Path(root_dir) + self.log_dir = self.root_dir / "logs" + self.log_dir.mkdir(parents=True, exist_ok=True) + + self._prefix = prefix + self._log_time = log_time + self.logger_path = self.log_dir / "{:}{:}.log".format( + self._prefix, time_for_file() + ) + self._logger_file = open(self.logger_path, "w") + + @property + def logger(self): + return self._logger_file + + def log(self, string, save=True, stdout=False): + string = "{:} {:}".format(time_string(), string) if self._log_time else string + if stdout: + sys.stdout.write(string) + sys.stdout.flush() + else: + print(string) + if save: + self._logger_file.write("{:}\n".format(string)) + self._logger_file.flush() + + def close(self): + self._logger_file.close() + if self.writer is not None: + self.writer.close() + + def __repr__(self): + return "{name}(dir={log_dir}, prefix={_prefix}, log_time={_log_time})".format( + name=self.__class__.__name__, **self.__dict__ + ) diff --git a/xautodl/xmisc/module_utils.py b/xautodl/xmisc/module_utils.py index 93deb18..2357f37 100644 --- a/xautodl/xmisc/module_utils.py +++ b/xautodl/xmisc/module_utils.py @@ -62,18 +62,25 @@ def call_by_yaml(path, *args, **kwargs) -> object: def nested_call_by_dict(config: Union[Dict[Text, Any], Any], *args, **kwargs) -> object: """Similar to `call_by_dict`, but differently, the args may contain another dict needs to be called.""" - if not has_key_words(config): + if isinstance(config, list): + return [nested_call_by_dict(x) for x in config] + elif isinstance(config, tuple): + return (nested_call_by_dict(x) for x in config) + elif not isinstance(config, dict): return config - module = get_module_by_module_path(config["module_path"]) - cls_or_func = getattr(module, config[CLS_FUNC_KEY]) - args = tuple(list(config["args"]) + list(args)) - kwargs = {**config["kwargs"], **kwargs} - # check whether there are nested special dict - new_args = [nested_call_by_dict(x) for x in args] - new_kwargs = {} - for key, x in kwargs.items(): - new_kwargs[key] = nested_call_by_dict(x) - return cls_or_func(*new_args, **new_kwargs) + elif not has_key_words(config): + return {key: nested_call_by_dict(x) for x, key in config.items()} + else: + module = get_module_by_module_path(config["module_path"]) + cls_or_func = getattr(module, config[CLS_FUNC_KEY]) + args = tuple(list(config["args"]) + list(args)) + kwargs = {**config["kwargs"], **kwargs} + # check whether there are nested special dict + new_args = [nested_call_by_dict(x) for x in args] + new_kwargs = {} + for key, x in kwargs.items(): + new_kwargs[key] = nested_call_by_dict(x) + return cls_or_func(*new_args, **new_kwargs) def nested_call_by_yaml(path, *args, **kwargs) -> object: diff --git a/xautodl/xmisc/scheduler_utils.py b/xautodl/xmisc/scheduler_utils.py new file mode 100644 index 0000000..4dc5ade --- /dev/null +++ b/xautodl/xmisc/scheduler_utils.py @@ -0,0 +1,136 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 # +##################################################### +from torch.optim.lr_scheduler import _LRScheduler + + +class CosineDecayWithWarmup(_LRScheduler): + r"""Set the learning rate of each parameter group using a cosine annealing + schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` + is the number of epochs since the last restart and :math:`T_{i}` is the number + of epochs between two warm restarts in SGDR: + .. math:: + \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right) + When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`. + When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`. + It has been proposed in + `SGDR: Stochastic Gradient Descent with Warm Restarts`_. + Args: + optimizer (Optimizer): Wrapped optimizer. + T_0 (int): Number of iterations for the first restart. + T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1. + eta_min (float, optional): Minimum learning rate. Default: 0. + last_epoch (int, optional): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + """ + + def __init__( + self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False + ): + if T_0 <= 0 or not isinstance(T_0, int): + raise ValueError("Expected positive integer T_0, but got {}".format(T_0)) + if T_mult < 1 or not isinstance(T_mult, int): + raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult)) + self.T_0 = T_0 + self.T_i = T_0 + self.T_mult = T_mult + self.eta_min = eta_min + + super(CosineDecayWithWarmup, self).__init__(optimizer, last_epoch, verbose) + + self.T_cur = self.last_epoch + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + UserWarning, + ) + + return [ + self.eta_min + + (base_lr - self.eta_min) + * (1 + math.cos(math.pi * self.T_cur / self.T_i)) + / 2 + for base_lr in self.base_lrs + ] + + def step(self, epoch=None): + """Step could be called after every batch update + Example: + >>> scheduler = CosineDecayWithWarmup(optimizer, T_0, T_mult) + >>> iters = len(dataloader) + >>> for epoch in range(20): + >>> for i, sample in enumerate(dataloader): + >>> inputs, labels = sample['inputs'], sample['labels'] + >>> optimizer.zero_grad() + >>> outputs = net(inputs) + >>> loss = criterion(outputs, labels) + >>> loss.backward() + >>> optimizer.step() + >>> scheduler.step(epoch + i / iters) + This function can be called in an interleaved way. + Example: + >>> scheduler = CosineDecayWithWarmup(optimizer, T_0, T_mult) + >>> for epoch in range(20): + >>> scheduler.step() + >>> scheduler.step(26) + >>> scheduler.step() # scheduler.step(27), instead of scheduler(20) + """ + + if epoch is None and self.last_epoch < 0: + epoch = 0 + + if epoch is None: + epoch = self.last_epoch + 1 + self.T_cur = self.T_cur + 1 + if self.T_cur >= self.T_i: + self.T_cur = self.T_cur - self.T_i + self.T_i = self.T_i * self.T_mult + else: + if epoch < 0: + raise ValueError( + "Expected non-negative epoch, but got {}".format(epoch) + ) + if epoch >= self.T_0: + if self.T_mult == 1: + self.T_cur = epoch % self.T_0 + else: + n = int( + math.log( + (epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult + ) + ) + self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / ( + self.T_mult - 1 + ) + self.T_i = self.T_0 * self.T_mult ** (n) + else: + self.T_i = self.T_0 + self.T_cur = epoch + self.last_epoch = math.floor(epoch) + + class _enable_get_lr_call: + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_lr_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_lr_called_within_step = False + return self + + with _enable_get_lr_call(self): + for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())): + param_group, lr = data + param_group["lr"] = lr + self.print_lr(self.verbose, i, lr, epoch) + + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] diff --git a/xautodl/xmisc/time_utils.py b/xautodl/xmisc/time_utils.py new file mode 100644 index 0000000..a55d9eb --- /dev/null +++ b/xautodl/xmisc/time_utils.py @@ -0,0 +1,26 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 # +##################################################### +import time + + +def time_for_file(): + ISOTIMEFORMAT = "%d-%h-at-%H-%M-%S" + return "{:}".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time()))) + + +def time_string(): + ISOTIMEFORMAT = "%Y-%m-%d %X" + string = "[{:}]".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time()))) + return string + + +def convert_secs2time(epoch_time, return_str=False): + need_hour = int(epoch_time / 3600) + need_mins = int((epoch_time - 3600 * need_hour) / 60) + need_secs = int(epoch_time - 3600 * need_hour - 60 * need_mins) + if return_str: + str = "[{:02d}:{:02d}:{:02d}]".format(need_hour, need_mins, need_secs) + return str + else: + return need_hour, need_mins, need_secs diff --git a/xautodl/xmisc/torch_utils.py b/xautodl/xmisc/torch_utils.py new file mode 100644 index 0000000..f873a6e --- /dev/null +++ b/xautodl/xmisc/torch_utils.py @@ -0,0 +1,26 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 # +##################################################### +import torch +import torch.nn as nn +import numpy as np + + +def count_parameters(model_or_parameters, unit="mb"): + if isinstance(model_or_parameters, nn.Module): + counts = sum(np.prod(v.size()) for v in model_or_parameters.parameters()) + elif isinstance(model_or_parameters, nn.Parameter): + counts = models_or_parameters.numel() + elif isinstance(model_or_parameters, (list, tuple)): + counts = sum(count_parameters(x, None) for x in models_or_parameters) + else: + counts = sum(np.prod(v.size()) for v in model_or_parameters) + if unit.lower() == "kb" or unit.lower() == "k": + counts /= 1e3 + elif unit.lower() == "mb" or unit.lower() == "m": + counts /= 1e6 + elif unit.lower() == "gb" or unit.lower() == "g": + counts /= 1e9 + elif unit is not None: + raise ValueError("Unknow unit: {:}".format(unit)) + return counts