Update yaml configs

D-X-Y 2021-06-10 21:53:22 +08:00
parent 1a7440d2af
commit 9bf0fa5f04
21 changed files with 410 additions and 178 deletions

View File

@@ -1,7 +0,0 @@
class_or_func: CIFAR10
module_path: torchvision.datasets
args: []
kwargs:
  train: False
  download: True
  transform: null

View File

@@ -1,7 +0,0 @@
class_or_func: CIFAR10
module_path: torchvision.datasets
args: []
kwargs:
  train: True
  download: True
  transform: null

View File

@@ -0,0 +1,22 @@
class_or_func: CIFAR10
module_path: torchvision.datasets
args: []
kwargs:
  train: False
  download: True
  transform:
    class_or_func: Compose
    module_path: torchvision.transforms
    args:
      -
        - class_or_func: ToTensor
          module_path: torchvision.transforms
          args: []
          kwargs: {}
        - class_or_func: Normalize
          module_path: torchvision.transforms
          args: []
          kwargs:
            mean: [0.491, 0.482, 0.447]
            std: [0.247, 0.244, 0.262]
    kwargs: {}
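
For reference, the config above is the declarative form of the plain torchvision call below (a sketch: nested_call_by_yaml resolves each class_or_func/module_path node recursively, and the positional data-path argument is supplied at call time via --data_path; the train config that follows uses the same schema plus two augmentation transforms). Note that the mean/std values are written as YAML lists so they reach Normalize as sequences rather than strings.

# Sketch of what the YAML above expands to; "path/to/data" stands in for --data_path.
import torchvision.datasets as datasets
import torchvision.transforms as transforms

test_data = datasets.CIFAR10(
    "path/to/data",
    train=False,
    download=True,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.491, 0.482, 0.447], std=[0.247, 0.244, 0.262]
            ),
        ]
    ),
)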

View File

@@ -0,0 +1,30 @@
class_or_func: CIFAR10
module_path: torchvision.datasets
args: []
kwargs:
  train: True
  download: True
  transform:
    class_or_func: Compose
    module_path: torchvision.transforms
    args:
      -
        - class_or_func: RandomHorizontalFlip
          module_path: torchvision.transforms
          args: []
          kwargs: {}
        - class_or_func: RandomCrop
          module_path: torchvision.transforms
          args: [32]
          kwargs: {padding: 4}
        - class_or_func: ToTensor
          module_path: torchvision.transforms
          args: []
          kwargs: {}
        - class_or_func: Normalize
          module_path: torchvision.transforms
          args: []
          kwargs:
            mean: [0.491, 0.482, 0.447]
            std: [0.247, 0.244, 0.262]
    kwargs: {}

View File

@@ -0,0 +1,4 @@
class_or_func: CrossEntropyLoss
module_path: torch.nn
args: []
kwargs: {}

View File

@@ -0,0 +1,4 @@
class_or_func: get_transformer
module_path: xautodl.xmodels.transformers
args: [vit-cifar10-p4-d4-h4-c32]
kwargs: {}

View File

@@ -0,0 +1,7 @@
class_or_func: Adam
module_path: torch.optim
args: []
kwargs:
  betas: [0.9, 0.999]
  weight_decay: 0.1
  amsgrad: False
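
Note that lr is deliberately absent from kwargs here: as the module_utils diff below shows, kwargs passed at call time are merged over the config's kwargs, so the learning rate comes from the command line and a call-time weight_decay overrides the 0.1 default above. A minimal sketch of how xmain.py consumes this file:

import torch
from xautodl import xmisc

model = torch.nn.Linear(8, 2)  # hypothetical stand-in model
optimizer = xmisc.nested_call_by_yaml(
    "./configs/yaml.opt/vit.cifar",  # this config
    model.parameters(),  # appended to the empty args list above
    lr=0.003,  # merged into kwargs at call time
    weight_decay=0.3,  # call-time value overrides the 0.1 above
)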

View File

@@ -3,7 +3,7 @@
 #####################################################
 # python exps/basic/xmain.py --save_dir outputs/x #
 #####################################################
-import sys, time, torch, random, argparse
+import os, sys, time, torch, random, argparse
 from copy import deepcopy
 from pathlib import Path
@@ -12,24 +12,38 @@ print("LIB-DIR: {:}".format(lib_dir))
 if str(lib_dir) not in sys.path:
     sys.path.insert(0, str(lib_dir))
-from xautodl.xmisc import nested_call_by_yaml
+from xautodl import xmisc


 def main(args):
-    train_data = nested_call_by_yaml(args.train_data_config, args.data_path)
-    valid_data = nested_call_by_yaml(args.valid_data_config, args.data_path)
-
-    import pdb
-
-    pdb.set_trace()
-
-    prepare_seed(args.rand_seed)
-    logger = prepare_logger(args)
-
-    train_data, valid_data, xshape, class_num = get_datasets(
-        args.dataset, args.data_path, args.cutout_length
-    )
+    train_data = xmisc.nested_call_by_yaml(args.train_data_config, args.data_path)
+    valid_data = xmisc.nested_call_by_yaml(args.valid_data_config, args.data_path)
+
+    logger = xmisc.Logger(args.save_dir, prefix="seed-{:}-".format(args.rand_seed))
+    logger.log("Create the logger: {:}".format(logger))
+    logger.log("Arguments : -------------------------------")
+    for name, value in args._get_kwargs():
+        logger.log("{:16} : {:}".format(name, value))
+    logger.log("Python Version : {:}".format(sys.version.replace("\n", " ")))
+    logger.log("PyTorch Version : {:}".format(torch.__version__))
+    logger.log("cuDNN Version : {:}".format(torch.backends.cudnn.version()))
+    logger.log("CUDA available : {:}".format(torch.cuda.is_available()))
+    logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count()))
+    logger.log(
+        "CUDA_VISIBLE_DEVICES : {:}".format(
+            os.environ["CUDA_VISIBLE_DEVICES"]
+            if "CUDA_VISIBLE_DEVICES" in os.environ
+            else "None"
+        )
+    )
+    logger.log("The training data is:\n{:}".format(train_data))
+    logger.log("The validation data is:\n{:}".format(valid_data))
+
+    model = xmisc.nested_call_by_yaml(args.model_config)
+    logger.log("The model is:\n{:}".format(model))
+    logger.log("The model size is {:.4f} M".format(xmisc.count_parameters(model)))

     train_loader = torch.utils.data.DataLoader(
         train_data,
         batch_size=args.batch_size,
@@ -44,100 +58,25 @@ def main(args):
         num_workers=args.workers,
         pin_memory=True,
     )
-
-    # get configures
-    model_config = load_config(args.model_config, {"class_num": class_num}, logger)
-    optim_config = load_config(args.optim_config, {"class_num": class_num}, logger)
-
-    if args.model_source == "normal":
-        base_model = obtain_model(model_config)
-    elif args.model_source == "nas":
-        base_model = obtain_nas_infer_model(model_config, args.extra_model_path)
-    elif args.model_source == "autodl-searched":
-        base_model = obtain_model(model_config, args.extra_model_path)
-    elif args.model_source in ("x", "xmodel"):
-        base_model = obtain_xmodel(model_config)
-    else:
-        raise ValueError("invalid model-source : {:}".format(args.model_source))
-    flop, param = get_model_infos(base_model, xshape)
-    logger.log("model ====>>>>:\n{:}".format(base_model))
-    logger.log("model information : {:}".format(base_model.get_message()))
-    logger.log("-" * 50)
-    logger.log(
-        "Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G".format(
-            param, flop, flop / 1e3
-        )
-    )
-    logger.log("-" * 50)
-    logger.log("train_data : {:}".format(train_data))
-    logger.log("valid_data : {:}".format(valid_data))
-    optimizer, scheduler, criterion = get_optim_scheduler(
-        base_model.parameters(), optim_config
-    )
-    logger.log("optimizer : {:}".format(optimizer))
-    logger.log("scheduler : {:}".format(scheduler))
-    logger.log("criterion : {:}".format(criterion))
-
-    last_info, model_base_path, model_best_path = (
-        logger.path("info"),
-        logger.path("model"),
-        logger.path("best"),
-    )
-    network, criterion = torch.nn.DataParallel(base_model).cuda(), criterion.cuda()
-
-    if last_info.exists():  # automatically resume from previous checkpoint
-        logger.log(
-            "=> loading checkpoint of the last-info '{:}' start".format(last_info)
-        )
-        last_infox = torch.load(last_info)
-        start_epoch = last_infox["epoch"] + 1
-        last_checkpoint_path = last_infox["last_checkpoint"]
-        if not last_checkpoint_path.exists():
-            logger.log(
-                "Does not find {:}, try another path".format(last_checkpoint_path)
-            )
-            last_checkpoint_path = (
-                last_info.parent
-                / last_checkpoint_path.parent.name
-                / last_checkpoint_path.name
-            )
-        checkpoint = torch.load(last_checkpoint_path)
-        base_model.load_state_dict(checkpoint["base-model"])
-        scheduler.load_state_dict(checkpoint["scheduler"])
-        optimizer.load_state_dict(checkpoint["optimizer"])
-        valid_accuracies = checkpoint["valid_accuracies"]
-        max_bytes = checkpoint["max_bytes"]
-        logger.log(
-            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(
-                last_info, start_epoch
-            )
-        )
-    elif args.resume is not None:
-        assert Path(args.resume).exists(), "Can not find the resume file : {:}".format(
-            args.resume
-        )
-        checkpoint = torch.load(args.resume)
-        start_epoch = checkpoint["epoch"] + 1
-        base_model.load_state_dict(checkpoint["base-model"])
-        scheduler.load_state_dict(checkpoint["scheduler"])
-        optimizer.load_state_dict(checkpoint["optimizer"])
-        valid_accuracies = checkpoint["valid_accuracies"]
-        max_bytes = checkpoint["max_bytes"]
-        logger.log(
-            "=> loading checkpoint from '{:}' start with {:}-th epoch.".format(
-                args.resume, start_epoch
-            )
-        )
-    elif args.init_model is not None:
-        assert Path(
-            args.init_model
-        ).exists(), "Can not find the initialization file : {:}".format(args.init_model)
-        checkpoint = torch.load(args.init_model)
-        base_model.load_state_dict(checkpoint["base-model"])
-        start_epoch, valid_accuracies, max_bytes = 0, {"best": -1}, {}
-        logger.log("=> initialize the model from {:}".format(args.init_model))
-    else:
-        logger.log("=> do not find the last-info file : {:}".format(last_info))
-        start_epoch, valid_accuracies, max_bytes = 0, {"best": -1}, {}
+
+    logger.log("The training loader: {:}".format(train_loader))
+    logger.log("The validation loader: {:}".format(valid_loader))
+    optimizer = xmisc.nested_call_by_yaml(
+        args.optim_config,
+        model.parameters(),
+        lr=args.lr,
+        weight_decay=args.weight_decay,
+    )
+    loss = xmisc.nested_call_by_yaml(args.loss_config)
+
+    logger.log("The optimizer is:\n{:}".format(optimizer))
+    logger.log("The loss is {:}".format(loss))
+
+    model, loss = torch.nn.DataParallel(model).cuda(), loss.cuda()
+
+    import pdb
+
+    pdb.set_trace()

     train_func, valid_func = get_procedures(args.procedure)
@@ -284,7 +223,7 @@ def main(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="Train a model with a loss function.",
+        description="Train a classification model with a loss function.",
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
     parser.add_argument(
@@ -293,27 +232,21 @@ if __name__ == "__main__":
     parser.add_argument("--resume", type=str, help="Resume path.")
     parser.add_argument("--init_model", type=str, help="The initialization model path.")
     parser.add_argument("--model_config", type=str, help="The path to the model config")
-    parser.add_argument(
-        "--optim_config", type=str, help="The path to the optimizer config"
-    )
-    parser.add_argument(
-        "--train_data_config", type=str, help="The dataset config path."
-    )
-    parser.add_argument(
-        "--valid_data_config", type=str, help="The dataset config path."
-    )
-    parser.add_argument(
-        "--data_path", type=str, help="The path to the dataset."
-    )
+    parser.add_argument("--optim_config", type=str, help="The optimizer config file.")
+    parser.add_argument("--loss_config", type=str, help="The loss config file.")
+    parser.add_argument(
+        "--train_data_config", type=str, help="The training dataset config path."
+    )
+    parser.add_argument(
+        "--valid_data_config", type=str, help="The validation dataset config path."
+    )
+    parser.add_argument("--data_path", type=str, help="The path to the dataset.")
     parser.add_argument("--algorithm", type=str, help="The algorithm.")
     # Optimization options
+    parser.add_argument("--lr", type=float, help="The learning rate")
+    parser.add_argument("--weight_decay", type=float, help="The weight decay")
     parser.add_argument("--batch_size", type=int, default=2, help="The batch size.")
-    parser.add_argument(
-        "--workers",
-        type=int,
-        default=8,
-        help="number of data loading workers (default: 8)",
-    )
+    parser.add_argument("--workers", type=int, default=4, help="The number of workers")
     # Random Seed
     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")

View File

@@ -22,6 +22,10 @@ save_dir=./outputs/${dataset}/vit-experimental
 python --version

 python ./exps/basic/xmain.py --save_dir ${save_dir} --rand_seed ${rseed} \
-	--train_data_config ./configs/data.yaml/${dataset}.train \
-	--valid_data_config ./configs/data.yaml/${dataset}.test \
-	--data_path $TORCH_HOME/cifar.python
+	--train_data_config ./configs/yaml.data/${dataset}.train \
+	--valid_data_config ./configs/yaml.data/${dataset}.test \
+	--data_path $TORCH_HOME/cifar.python \
+	--model_config ./configs/yaml.model/vit-cifar10.s0 \
+	--optim_config ./configs/yaml.opt/vit.cifar \
+	--loss_config ./configs/yaml.loss/cross-entropy \
+	--lr 0.003 --weight_decay 0.3

View File

@@ -3,10 +3,8 @@
 #####################################################
 # pytest tests/test_basic_space.py -s #
 #####################################################
-import sys, random
+import random
 import unittest
-import pytest
-from pathlib import Path

 from xautodl.spaces import Categorical
 from xautodl.spaces import Continuous

View File

@@ -3,12 +3,6 @@
 #####################################################
 # pytest ./tests/test_import.py #
 #####################################################
-import os, sys, time, torch
-import pickle
-import tempfile
-from pathlib import Path
-

 def test_import():
     from xautodl import config_utils
     from xautodl import datasets
@@ -19,6 +13,9 @@ def test_import():
     from xautodl import spaces
     from xautodl import trade_models
     from xautodl import utils
     from xautodl import xlayers
+    from xautodl import xmisc
+    from xautodl import xmmodels

     print("Check all imports done")

View File

@@ -3,13 +3,11 @@
 #####################################################
 # pytest ./tests/test_super_att.py -s #
 #####################################################
-import sys, random
+import random
 import unittest
 from parameterized import parameterized
-from pathlib import Path

 import torch

 from xautodl import spaces
 from xautodl.xlayers import super_core

View File

@@ -3,10 +3,9 @@
 #####################################################
 # pytest ./tests/test_super_container.py -s #
 #####################################################
-import sys, random
+import random
 import unittest
 import pytest
-from pathlib import Path

 import torch

 from xautodl import spaces

View File

@@ -3,7 +3,6 @@
 #####################################################
 # pytest ./tests/test_super_rearrange.py -s #
 #####################################################
-import sys
 import unittest

 import torch

View File

@@ -3,8 +3,8 @@
 #####################################################
 # pytest ./tests/test_super_vit.py -s #
 #####################################################
-import sys
 import unittest
+from parameterized import parameterized

 import torch

 from xautodl.xmodels import transformers
@@ -16,25 +16,28 @@ class TestSuperViT(unittest.TestCase):

     def test_super_vit(self):
         model = transformers.get_transformer("vit-base-16")
-        tensor = torch.rand((16, 3, 224, 224))
+        tensor = torch.rand((2, 3, 224, 224))
         print("The tensor shape: {:}".format(tensor.shape))
         # print(model)
         outs = model(tensor)
         print("The output tensor shape: {:}".format(outs.shape))

-    def test_imagenet(self):
-        name2config = transformers.name2config
-        print("There are {:} models in total.".format(len(name2config)))
-        for name, config in name2config.items():
-            if "cifar" in name:
-                tensor = torch.rand((16, 3, 32, 32))
-            else:
-                tensor = torch.rand((16, 3, 224, 224))
-            model = transformers.get_transformer(config)
-            outs = model(tensor)
-            size = count_parameters(model, "mb", True)
-            print(
-                "{:10s} : size={:.2f}MB, out-shape: {:}".format(
-                    name, size, tuple(outs.shape)
-                )
-            )
+    @parameterized.expand(
+        [
+            ["vit-cifar10-p4-d4-h4-c32", 32],
+            ["vit-base-16", 224],
+            ["vit-large-16", 224],
+            ["vit-huge-14", 224],
+        ]
+    )
+    def test_imagenet(self, name, resolution):
+        tensor = torch.rand((2, 3, resolution, resolution))
+        config = transformers.name2config[name]
+        model = transformers.get_transformer(config)
+        outs = model(tensor)
+        size = count_parameters(model, "mb", True)
+        print(
+            "{:10s} : size={:.2f}MB, out-shape: {:}".format(
+                name, size, tuple(outs.shape)
+            )
+        )

View File

@@ -6,3 +6,7 @@ from .module_utils import call_by_yaml
 from .module_utils import nested_call_by_dict
 from .module_utils import nested_call_by_yaml

 from .yaml_utils import load_yaml
+
+from .torch_utils import count_parameters
+
+from .logger_utils import Logger

View File

@@ -0,0 +1,49 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
#####################################################
import sys
from pathlib import Path
from .time_utils import time_for_file, time_string
class Logger:
    """A logger used in xautodl."""

    def __init__(self, root_dir, prefix="", log_time=True):
        """Create a summary writer logging to log_dir."""
        self.root_dir = Path(root_dir)
        self.log_dir = self.root_dir / "logs"
        self.log_dir.mkdir(parents=True, exist_ok=True)

        self._prefix = prefix
        self._log_time = log_time
        self.logger_path = self.log_dir / "{:}{:}.log".format(
            self._prefix, time_for_file()
        )
        self._logger_file = open(self.logger_path, "w")

    @property
    def logger(self):
        return self._logger_file

    def log(self, string, save=True, stdout=False):
        string = "{:} {:}".format(time_string(), string) if self._log_time else string
        if stdout:
            sys.stdout.write(string)
            sys.stdout.flush()
        else:
            print(string)
        if save:
            self._logger_file.write("{:}\n".format(string))
            self._logger_file.flush()

    def close(self):
        self._logger_file.close()
        # this class never creates a summary writer; guard so close() cannot raise
        if getattr(self, "writer", None) is not None:
            self.writer.close()

    def __repr__(self):
        return "{name}(dir={log_dir}, prefix={_prefix}, log_time={_log_time})".format(
            name=self.__class__.__name__, **self.__dict__
        )
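
A minimal usage sketch of this Logger (the directory name is illustrative):

from xautodl.xmisc import Logger

logger = Logger("./outputs/demo", prefix="seed-1-")
logger.log("hello")  # prints with a timestamp and appends to ./outputs/demo/logs/seed-1-<time>.log
logger.close()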

View File

@@ -62,18 +62,25 @@ def call_by_yaml(path, *args, **kwargs) -> object:

 def nested_call_by_dict(config: Union[Dict[Text, Any], Any], *args, **kwargs) -> object:
     """Similar to `call_by_dict`, but differently, the args may contain another dict needs to be called."""
-    if not has_key_words(config):
+    if isinstance(config, list):
+        return [nested_call_by_dict(x) for x in config]
+    elif isinstance(config, tuple):
+        return tuple(nested_call_by_dict(x) for x in config)
+    elif not isinstance(config, dict):
         return config
-    module = get_module_by_module_path(config["module_path"])
-    cls_or_func = getattr(module, config[CLS_FUNC_KEY])
-    args = tuple(list(config["args"]) + list(args))
-    kwargs = {**config["kwargs"], **kwargs}
-    # check whether there are nested special dict
-    new_args = [nested_call_by_dict(x) for x in args]
-    new_kwargs = {}
-    for key, x in kwargs.items():
-        new_kwargs[key] = nested_call_by_dict(x)
-    return cls_or_func(*new_args, **new_kwargs)
+    elif not has_key_words(config):
+        return {key: nested_call_by_dict(x) for key, x in config.items()}
+    else:
+        module = get_module_by_module_path(config["module_path"])
+        cls_or_func = getattr(module, config[CLS_FUNC_KEY])
+        args = tuple(list(config["args"]) + list(args))
+        kwargs = {**config["kwargs"], **kwargs}
+        # check whether there are nested special dict
+        new_args = [nested_call_by_dict(x) for x in args]
+        new_kwargs = {}
+        for key, x in kwargs.items():
+            new_kwargs[key] = nested_call_by_dict(x)
+        return cls_or_func(*new_args, **new_kwargs)


 def nested_call_by_yaml(path, *args, **kwargs) -> object:
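
A small usage sketch of the recursive resolution above; the dict mirrors the CIFAR-10 transform configs earlier in this commit, and the inner list exercises the new isinstance(config, list) branch:

config = {
    "class_or_func": "Compose",
    "module_path": "torchvision.transforms",
    "args": [
        [
            {
                "class_or_func": "ToTensor",
                "module_path": "torchvision.transforms",
                "args": [],
                "kwargs": {},
            },
        ],
    ],
    "kwargs": {},
}
transform = nested_call_by_dict(config)  # -> Compose([ToTensor()])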

View File

@@ -0,0 +1,136 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
#####################################################
import math
import warnings

from torch.optim.lr_scheduler import _LRScheduler


class CosineDecayWithWarmup(_LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
    is the number of epochs since the last restart and :math:`T_{i}` is the number
    of epochs between two warm restarts in SGDR:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)

    When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
    When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_0 (int): Number of iterations for the first restart.
        T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.
        eta_min (float, optional): Minimum learning rate. Default: 0.
        last_epoch (int, optional): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(
        self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False
    ):
        if T_0 <= 0 or not isinstance(T_0, int):
            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
        if T_mult < 1 or not isinstance(T_mult, int):
            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
        self.T_0 = T_0
        self.T_i = T_0
        self.T_mult = T_mult
        self.eta_min = eta_min

        super(CosineDecayWithWarmup, self).__init__(optimizer, last_epoch, verbose)

        self.T_cur = self.last_epoch

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, "
                "please use `get_last_lr()`.",
                UserWarning,
            )

        return [
            self.eta_min
            + (base_lr - self.eta_min)
            * (1 + math.cos(math.pi * self.T_cur / self.T_i))
            / 2
            for base_lr in self.base_lrs
        ]

    def step(self, epoch=None):
        """Step could be called after every batch update

        Example:
            >>> scheduler = CosineDecayWithWarmup(optimizer, T_0, T_mult)
            >>> iters = len(dataloader)
            >>> for epoch in range(20):
            >>>     for i, sample in enumerate(dataloader):
            >>>         inputs, labels = sample['inputs'], sample['labels']
            >>>         optimizer.zero_grad()
            >>>         outputs = net(inputs)
            >>>         loss = criterion(outputs, labels)
            >>>         loss.backward()
            >>>         optimizer.step()
            >>>         scheduler.step(epoch + i / iters)

        This function can be called in an interleaved way.

        Example:
            >>> scheduler = CosineDecayWithWarmup(optimizer, T_0, T_mult)
            >>> for epoch in range(20):
            >>>     scheduler.step()
            >>> scheduler.step(26)
            >>> scheduler.step()  # scheduler.step(27), instead of scheduler(20)
        """
        if epoch is None and self.last_epoch < 0:
            epoch = 0

        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                self.T_cur = self.T_cur - self.T_i
                self.T_i = self.T_i * self.T_mult
        else:
            if epoch < 0:
                raise ValueError(
                    "Expected non-negative epoch, but got {}".format(epoch)
                )
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    self.T_cur = epoch % self.T_0
                else:
                    n = int(
                        math.log(
                            (epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult
                        )
                    )
                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (
                        self.T_mult - 1
                    )
                    self.T_i = self.T_0 * self.T_mult ** (n)
            else:
                self.T_i = self.T_0
                self.T_cur = epoch
        self.last_epoch = math.floor(epoch)

        class _enable_get_lr_call:
            def __init__(self, o):
                self.o = o

            def __enter__(self):
                self.o._get_lr_called_within_step = True
                return self

            def __exit__(self, type, value, traceback):
                self.o._get_lr_called_within_step = False
                return self

        with _enable_get_lr_call(self):
            for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
                param_group, lr = data
                param_group["lr"] = lr
                self.print_lr(self.verbose, i, lr, epoch)

        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
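
As a quick check of the restart arithmetic: with T_0=5 and T_mult=2, T_i doubles after every restart, so the learning rate resets to its maximum at epochs 5, 15, and 35 and follows the cosine curve in between; with T_mult=1 it simply restarts every T_0 epochs.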

View File

@@ -0,0 +1,26 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
#####################################################
import time
def time_for_file():
    ISOTIMEFORMAT = "%d-%h-at-%H-%M-%S"
    return "{:}".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time())))


def time_string():
    ISOTIMEFORMAT = "%Y-%m-%d %X"
    string = "[{:}]".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time())))
    return string


def convert_secs2time(epoch_time, return_str=False):
    need_hour = int(epoch_time / 3600)
    need_mins = int((epoch_time - 3600 * need_hour) / 60)
    need_secs = int(epoch_time - 3600 * need_hour - 60 * need_mins)
    if return_str:
        return "[{:02d}:{:02d}:{:02d}]".format(need_hour, need_mins, need_secs)
    else:
        return need_hour, need_mins, need_secs
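
For example, convert_secs2time(3725, return_str=True) gives "[01:02:05]", while convert_secs2time(3725) gives (1, 2, 5).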

View File

@@ -0,0 +1,26 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
#####################################################
import torch
import torch.nn as nn
import numpy as np
def count_parameters(model_or_parameters, unit="mb"):
    if isinstance(model_or_parameters, nn.Module):
        counts = sum(np.prod(v.size()) for v in model_or_parameters.parameters())
    elif isinstance(model_or_parameters, nn.Parameter):
        counts = model_or_parameters.numel()
    elif isinstance(model_or_parameters, (list, tuple)):
        counts = sum(count_parameters(x, None) for x in model_or_parameters)
    else:
        counts = sum(np.prod(v.size()) for v in model_or_parameters)
    # return the raw count when no unit is requested (also used by the recursive call above)
    if unit is None:
        return counts
    if unit.lower() == "kb" or unit.lower() == "k":
        counts /= 1e3
    elif unit.lower() == "mb" or unit.lower() == "m":
        counts /= 1e6
    elif unit.lower() == "gb" or unit.lower() == "g":
        counts /= 1e9
    else:
        raise ValueError("Unknown unit: {:}".format(unit))
    return counts
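
A quick usage sketch (count_parameters is re-exported from xautodl.xmisc per the __init__.py diff above):

import torch.nn as nn
from xautodl.xmisc import count_parameters

model = nn.Linear(100, 10)  # 100 * 10 weights + 10 biases = 1010 parameters
print(count_parameters(model, unit=None))  # 1010
print(count_parameters(model, unit="kb"))  # 1.01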