Update yaml configs
		| @@ -6,3 +6,7 @@ from .module_utils import call_by_yaml | ||||
| from .module_utils import nested_call_by_dict | ||||
| from .module_utils import nested_call_by_yaml | ||||
| from .yaml_utils import load_yaml | ||||
|  | ||||
| from .torch_utils import count_parameters | ||||
|  | ||||
| from .logger_utils import Logger | ||||
|   | ||||
							
								
								
									
49  xautodl/xmisc/logger_utils.py  Normal file
							| @@ -0,0 +1,49 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 # | ||||
| ##################################################### | ||||
| import sys | ||||
| from pathlib import Path | ||||
|  | ||||
| from .time_utils import time_for_file, time_string | ||||
|  | ||||
|  | ||||
| class Logger: | ||||
|     """A logger used in xautodl.""" | ||||
|  | ||||
|     def __init__(self, root_dir, prefix="", log_time=True): | ||||
|         """Create a summary writer logging to log_dir.""" | ||||
|         self.root_dir = Path(root_dir) | ||||
|         self.log_dir = self.root_dir / "logs" | ||||
|         self.log_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|         self._prefix = prefix | ||||
|         self._log_time = log_time | ||||
|         self.logger_path = self.log_dir / "{:}{:}.log".format( | ||||
|             self._prefix, time_for_file() | ||||
|         ) | ||||
|         self._logger_file = open(self.logger_path, "w") | ||||
|  | ||||
|     @property | ||||
|     def logger(self): | ||||
|         return self._logger_file | ||||
|  | ||||
|     def log(self, string, save=True, stdout=False): | ||||
|         string = "{:} {:}".format(time_string(), string) if self._log_time else string | ||||
|         if stdout: | ||||
|             sys.stdout.write(string) | ||||
|             sys.stdout.flush() | ||||
|         else: | ||||
|             print(string) | ||||
|         if save: | ||||
|             self._logger_file.write("{:}\n".format(string)) | ||||
|             self._logger_file.flush() | ||||
|  | ||||
|     def close(self): | ||||
|         self._logger_file.close() | ||||
|         if getattr(self, "writer", None) is not None:  # optional writer; not created in __init__ | ||||
|             self.writer.close() | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}(dir={log_dir}, prefix={_prefix}, log_time={_log_time})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
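A minimal usage sketch for the Logger class added above, assuming the import path exposed by the __init__ hunk (from xautodl.xmisc import Logger); the output directory name is arbitrary:

    from xautodl.xmisc import Logger

    logger = Logger("./outputs/demo", prefix="train-")
    logger.log("starting experiment")          # timestamped, printed and written to the log file
    logger.log("console only", save=False)     # printed but not written
    print(logger)                              # Logger(dir=..., prefix=train-, log_time=True)
    logger.close()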
| @@ -62,18 +62,25 @@ def call_by_yaml(path, *args, **kwargs) -> object: | ||||
|  | ||||
| def nested_call_by_dict(config: Union[Dict[Text, Any], Any], *args, **kwargs) -> object: | ||||
|     """Similar to `call_by_dict`, except that the args/kwargs may contain nested dicts that also need to be called.""" | ||||
|     if not has_key_words(config): | ||||
|     if isinstance(config, list): | ||||
|         return [nested_call_by_dict(x) for x in config] | ||||
|     elif isinstance(config, tuple): | ||||
|         return tuple(nested_call_by_dict(x) for x in config) | ||||
|     elif not isinstance(config, dict): | ||||
|         return config | ||||
|     module = get_module_by_module_path(config["module_path"]) | ||||
|     cls_or_func = getattr(module, config[CLS_FUNC_KEY]) | ||||
|     args = tuple(list(config["args"]) + list(args)) | ||||
|     kwargs = {**config["kwargs"], **kwargs} | ||||
|     # check whether there are nested special dict | ||||
|     new_args = [nested_call_by_dict(x) for x in args] | ||||
|     new_kwargs = {} | ||||
|     for key, x in kwargs.items(): | ||||
|         new_kwargs[key] = nested_call_by_dict(x) | ||||
|     return cls_or_func(*new_args, **new_kwargs) | ||||
|     elif not has_key_words(config): | ||||
|         return {key: nested_call_by_dict(x) for key, x in config.items()} | ||||
|     else: | ||||
|         module = get_module_by_module_path(config["module_path"]) | ||||
|         cls_or_func = getattr(module, config[CLS_FUNC_KEY]) | ||||
|         args = tuple(list(config["args"]) + list(args)) | ||||
|         kwargs = {**config["kwargs"], **kwargs} | ||||
|         # check whether there are nested special dict | ||||
|         new_args = [nested_call_by_dict(x) for x in args] | ||||
|         new_kwargs = {} | ||||
|         for key, x in kwargs.items(): | ||||
|             new_kwargs[key] = nested_call_by_dict(x) | ||||
|         return cls_or_func(*new_args, **new_kwargs) | ||||
|  | ||||
|  | ||||
| def nested_call_by_yaml(path, *args, **kwargs) -> object: | ||||
|   | ||||
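The hunk above reworks nested_call_by_dict (imported from .module_utils in the __init__ hunk) so that lists, tuples, and plain dicts are recursed into before the usual module_path/args/kwargs call is made. A hedged sketch of the kind of config it accepts follows; the exact key behind CLS_FUNC_KEY is not shown in this diff, so "class_or_func" below is a placeholder, and torch.nn is used only as a convenient importable module_path:

    config = {
        "module_path": "torch.nn",
        "class_or_func": "Sequential",   # placeholder for the real CLS_FUNC_KEY
        "args": [
            {"module_path": "torch.nn", "class_or_func": "Linear", "args": [8, 4], "kwargs": {}},
            {"module_path": "torch.nn", "class_or_func": "ReLU", "args": [], "kwargs": {}},
        ],
        "kwargs": {},
    }
    net = nested_call_by_dict(config)    # roughly nn.Sequential(nn.Linear(8, 4), nn.ReLU())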
							
								
								
									
136  xautodl/xmisc/scheduler_utils.py  Normal file
							| @@ -0,0 +1,136 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 # | ||||
| ##################################################### | ||||
| import math | ||||
| import warnings | ||||
|  | ||||
| from torch.optim.lr_scheduler import _LRScheduler | ||||
|  | ||||
|  | ||||
| class CosineDecayWithWarmup(_LRScheduler): | ||||
|     r"""Set the learning rate of each parameter group using a cosine annealing | ||||
|     schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` | ||||
|     is the number of epochs since the last restart and :math:`T_{i}` is the number | ||||
|     of epochs between two warm restarts in SGDR: | ||||
|     .. math:: | ||||
|         \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + | ||||
|         \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right) | ||||
|     When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`. | ||||
|     When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`. | ||||
|     It has been proposed in | ||||
|     `SGDR: Stochastic Gradient Descent with Warm Restarts`_. | ||||
|     Args: | ||||
|         optimizer (Optimizer): Wrapped optimizer. | ||||
|         T_0 (int): Number of iterations for the first restart. | ||||
|         T_mult (int, optional): A factor by which :math:`T_{i}` is multiplied after a restart. Default: 1. | ||||
|         eta_min (float, optional): Minimum learning rate. Default: 0. | ||||
|         last_epoch (int, optional): The index of last epoch. Default: -1. | ||||
|         verbose (bool): If ``True``, prints a message to stdout for | ||||
|             each update. Default: ``False``. | ||||
|     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: | ||||
|         https://arxiv.org/abs/1608.03983 | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False | ||||
|     ): | ||||
|         if T_0 <= 0 or not isinstance(T_0, int): | ||||
|             raise ValueError("Expected positive integer T_0, but got {}".format(T_0)) | ||||
|         if T_mult < 1 or not isinstance(T_mult, int): | ||||
|             raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult)) | ||||
|         self.T_0 = T_0 | ||||
|         self.T_i = T_0 | ||||
|         self.T_mult = T_mult | ||||
|         self.eta_min = eta_min | ||||
|  | ||||
|         super(CosineDecayWithWarmup, self).__init__(optimizer, last_epoch, verbose) | ||||
|  | ||||
|         self.T_cur = self.last_epoch | ||||
|  | ||||
|     def get_lr(self): | ||||
|         if not self._get_lr_called_within_step: | ||||
|             warnings.warn( | ||||
|                 "To get the last learning rate computed by the scheduler, " | ||||
|                 "please use `get_last_lr()`.", | ||||
|                 UserWarning, | ||||
|             ) | ||||
|  | ||||
|         return [ | ||||
|             self.eta_min | ||||
|             + (base_lr - self.eta_min) | ||||
|             * (1 + math.cos(math.pi * self.T_cur / self.T_i)) | ||||
|             / 2 | ||||
|             for base_lr in self.base_lrs | ||||
|         ] | ||||
|  | ||||
|     def step(self, epoch=None): | ||||
|         """Step could be called after every batch update | ||||
|         Example: | ||||
|             >>> scheduler = CosineDecayWithWarmup(optimizer, T_0, T_mult) | ||||
|             >>> iters = len(dataloader) | ||||
|             >>> for epoch in range(20): | ||||
|             >>>     for i, sample in enumerate(dataloader): | ||||
|             >>>         inputs, labels = sample['inputs'], sample['labels'] | ||||
|             >>>         optimizer.zero_grad() | ||||
|             >>>         outputs = net(inputs) | ||||
|             >>>         loss = criterion(outputs, labels) | ||||
|             >>>         loss.backward() | ||||
|             >>>         optimizer.step() | ||||
|             >>>         scheduler.step(epoch + i / iters) | ||||
|         This function can be called in an interleaved way. | ||||
|         Example: | ||||
|             >>> scheduler = CosineDecayWithWarmup(optimizer, T_0, T_mult) | ||||
|             >>> for epoch in range(20): | ||||
|             >>>     scheduler.step() | ||||
|             >>> scheduler.step(26) | ||||
|             >>> scheduler.step() # scheduler.step(27), instead of scheduler(20) | ||||
|         """ | ||||
|  | ||||
|         if epoch is None and self.last_epoch < 0: | ||||
|             epoch = 0 | ||||
|  | ||||
|         if epoch is None: | ||||
|             epoch = self.last_epoch + 1 | ||||
|             self.T_cur = self.T_cur + 1 | ||||
|             if self.T_cur >= self.T_i: | ||||
|                 self.T_cur = self.T_cur - self.T_i | ||||
|                 self.T_i = self.T_i * self.T_mult | ||||
|         else: | ||||
|             if epoch < 0: | ||||
|                 raise ValueError( | ||||
|                     "Expected non-negative epoch, but got {}".format(epoch) | ||||
|                 ) | ||||
|             if epoch >= self.T_0: | ||||
|                 if self.T_mult == 1: | ||||
|                     self.T_cur = epoch % self.T_0 | ||||
|                 else: | ||||
|                     n = int( | ||||
|                         math.log( | ||||
|                             (epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult | ||||
|                         ) | ||||
|                     ) | ||||
|                     self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / ( | ||||
|                         self.T_mult - 1 | ||||
|                     ) | ||||
|                     self.T_i = self.T_0 * self.T_mult ** (n) | ||||
|             else: | ||||
|                 self.T_i = self.T_0 | ||||
|                 self.T_cur = epoch | ||||
|         self.last_epoch = math.floor(epoch) | ||||
|  | ||||
|         class _enable_get_lr_call: | ||||
|             def __init__(self, o): | ||||
|                 self.o = o | ||||
|  | ||||
|             def __enter__(self): | ||||
|                 self.o._get_lr_called_within_step = True | ||||
|                 return self | ||||
|  | ||||
|             def __exit__(self, type, value, traceback): | ||||
|                 self.o._get_lr_called_within_step = False | ||||
|                 return self | ||||
|  | ||||
|         with _enable_get_lr_call(self): | ||||
|             for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())): | ||||
|                 param_group, lr = data | ||||
|                 param_group["lr"] = lr | ||||
|                 self.print_lr(self.verbose, i, lr, epoch) | ||||
|  | ||||
|         self._last_lr = [group["lr"] for group in self.optimizer.param_groups] | ||||
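A minimal sketch of driving the scheduler above per batch, mirroring the first docstring example; it assumes a PyTorch version whose _LRScheduler base accepts the verbose argument used in __init__, and the model/optimizer below are placeholders:

    import torch

    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = CosineDecayWithWarmup(optimizer, T_0=10, T_mult=2, eta_min=1e-4)

    iters_per_epoch = 100
    for epoch in range(5):
        for it in range(iters_per_epoch):
            optimizer.step()                              # after loss.backward() in real training
            scheduler.step(epoch + it / iters_per_epoch)  # fractional epochs drive the cosine curve
    print(scheduler.get_last_lr())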
							
								
								
									
26  xautodl/xmisc/time_utils.py  Normal file
							| @@ -0,0 +1,26 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 # | ||||
| ##################################################### | ||||
| import time | ||||
|  | ||||
|  | ||||
| def time_for_file(): | ||||
|     ISOTIMEFORMAT = "%d-%h-at-%H-%M-%S" | ||||
|     return "{:}".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time()))) | ||||
|  | ||||
|  | ||||
| def time_string(): | ||||
|     ISOTIMEFORMAT = "%Y-%m-%d %X" | ||||
|     string = "[{:}]".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time()))) | ||||
|     return string | ||||
|  | ||||
|  | ||||
| def convert_secs2time(epoch_time, return_str=False): | ||||
|     need_hour = int(epoch_time / 3600) | ||||
|     need_mins = int((epoch_time - 3600 * need_hour) / 60) | ||||
|     need_secs = int(epoch_time - 3600 * need_hour - 60 * need_mins) | ||||
|     if return_str: | ||||
|         return "[{:02d}:{:02d}:{:02d}]".format(need_hour, need_mins, need_secs) | ||||
|     else: | ||||
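A quick sketch of the helpers above; both timestamps use gmtime (UTC), and 3725 seconds is 1 hour, 2 minutes, 5 seconds:

    print(time_string())                  # e.g. "[2021-06-01 12:34:56]"
    print(time_for_file())                # e.g. "01-Jun-at-12-34-56", safe to embed in file names
    print(convert_secs2time(3725))        # (1, 2, 5)
    print(convert_secs2time(3725, True))  # "[01:02:05]"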
|         return need_hour, need_mins, need_secs | ||||
							
								
								
									
26  xautodl/xmisc/torch_utils.py  Normal file
							| @@ -0,0 +1,26 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 # | ||||
| ##################################################### | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
| def count_parameters(model_or_parameters, unit="mb"): | ||||
|     if isinstance(model_or_parameters, nn.Module): | ||||
|         counts = sum(np.prod(v.size()) for v in model_or_parameters.parameters()) | ||||
|     elif isinstance(model_or_parameters, nn.Parameter): | ||||
|         counts = model_or_parameters.numel() | ||||
|     elif isinstance(model_or_parameters, (list, tuple)): | ||||
|         counts = sum(count_parameters(x, None) for x in model_or_parameters) | ||||
|     else: | ||||
|         counts = sum(np.prod(v.size()) for v in model_or_parameters) | ||||
|     if unit is None: | ||||
|         return counts | ||||
|     if unit.lower() == "kb" or unit.lower() == "k": | ||||
|         counts /= 1e3 | ||||
|     elif unit.lower() == "mb" or unit.lower() == "m": | ||||
|         counts /= 1e6 | ||||
|     elif unit.lower() == "gb" or unit.lower() == "g": | ||||
|         counts /= 1e9 | ||||
|     else: | ||||
|         raise ValueError("Unknown unit: {:}".format(unit)) | ||||
|     return counts | ||||
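A small sketch of count_parameters with the name and unit fixes applied above; nn.Linear(10, 5) has 10*5 weights plus 5 biases, i.e. 55 parameters:

    import torch.nn as nn

    model = nn.Linear(10, 5)
    print(count_parameters(model, unit=None))            # 55
    print(count_parameters(model, unit="kb"))            # 0.055
    print(count_parameters([model, model], unit=None))   # 110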