Update xmisc.scheduler/sampler
.github/workflows/test-misc.yml (new file, 41 lines, vendored)
@@ -0,0 +1,41 @@
+name: Test Xmisc
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+
+
+jobs:
+  build:
+    strategy:
+      matrix:
+        os: [ubuntu-16.04, ubuntu-18.04, ubuntu-20.04, macos-latest]
+        python-version: [3.6, 3.7, 3.8, 3.9]
+
+    runs-on: ${{ matrix.os }}
+    steps:
+      - uses: actions/checkout@v2
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install XAutoDL from source
+        run: |
+          python setup.py install
+
+      - name: Test Xmisc
+        run: |
+          python -m pip install pytest numpy
+          python -m pip install torch torchvision
+          python -m pip install parameterized
+          echo $PWD
+          echo "Show what we have here:"
+          ls
+          python --version
+          python -m pytest ./tests/test_misc* -s
+        shell: bash
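Note (not part of the commit): the CI test step can be reproduced locally once the package and the listed test dependencies are installed. A minimal sketch, assuming the new tests/test_misc_scheduler.py added further down is the target:

import sys

import pytest

# Run the misc tests the same way the workflow does (-s so stdout is shown).
sys.exit(pytest.main(["-s", "./tests/test_misc_scheduler.py"]))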
@@ -46,8 +46,7 @@ def main(args):
 
     train_loader = torch.utils.data.DataLoader(
         train_data,
-        batch_size=args.batch_size,
-        shuffle=True,
+        batch_sampler=xmisc.BatchSampler(train_data, args.batch_size, args.steps),
         num_workers=args.workers,
         pin_memory=True,
     )
@@ -57,6 +56,7 @@ def main(args):
         shuffle=False,
         num_workers=args.workers,
         pin_memory=True,
+        drop_last=False,
     )
 
     logger.log("The training loader: {:}".format(train_loader))
@@ -73,6 +73,9 @@ def main(args):
     logger.log("The loss is {:}".format(loss))
 
     model, loss = torch.nn.DataParallel(model).cuda(), loss.cuda()
+    scheduler = xmisc.LRMultiplier(
+        optimizer, xmisc.get_scheduler(args.scheduler, args.lr), args.steps
+    )
 
     import pdb
 
@@ -241,10 +244,11 @@ if __name__ == "__main__":
         "--valid_data_config", type=str, help="The validation dataset config path."
     )
     parser.add_argument("--data_path", type=str, help="The path to the dataset.")
-    parser.add_argument("--algorithm", type=str, help="The algorithm.")
     # Optimization options
     parser.add_argument("--lr", type=float, help="The learning rate")
    parser.add_argument("--weight_decay", type=float, help="The weight decay")
+    parser.add_argument("--scheduler", type=str, help="The scheduler indicator.")
+    parser.add_argument("--steps", type=int, help="The total number of steps.")
     parser.add_argument("--batch_size", type=int, default=2, help="The batch size.")
     parser.add_argument("--workers", type=int, default=4, help="The number of workers")
     # Random Seed
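Note (not part of the commit): these hunks replace the shuffled, epoch-based loader with the step-based xmisc.BatchSampler and drive the learning rate through xmisc.LRMultiplier. A minimal sketch of how the pieces are wired together; the model, optimizer, and lr/steps values below are hypothetical stand-ins for what the training script builds from its YAML configs:

import torch
from xautodl import xmisc

model = torch.nn.Linear(8, 2)                              # hypothetical model
optimizer = torch.optim.SGD(model.parameters(), lr=0.003)  # hypothetical optimizer
steps = 10000                                              # mirrors --steps

scheduler = xmisc.LRMultiplier(
    optimizer, xmisc.get_scheduler("warm-cos", 0.003), steps
)
for _ in range(steps):
    # forward / backward / optimizer.step() would go here
    scheduler.step()  # one scheduler step per training iteration, not per epoch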
							
								
								
									
notebooks/spaces-xmisc/scheduler.ipynb (new file, 119 lines)
File diff suppressed because one or more lines are too long
@@ -1,76 +0,0 @@
-import os
-import sys
-import qlib
-import pprint
-import numpy as np
-import pandas as pd
-
-from pathlib import Path
-import torch
-
-__file__ = os.path.dirname(os.path.realpath("__file__"))
-
-lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
-print("library path: {:}".format(lib_dir))
-assert lib_dir.exists(), "{:} does not exist".format(lib_dir)
-if str(lib_dir) not in sys.path:
-    sys.path.insert(0, str(lib_dir))
-
-from trade_models import get_transformer
-
-from qlib import config as qconfig
-from qlib.utils import init_instance_by_config
-from qlib.model.base import Model
-from qlib.data.dataset import DatasetH
-from qlib.data.dataset.handler import DataHandlerLP
-
-qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region=qconfig.REG_CN)
-
-dataset_config = {
-    "class": "DatasetH",
-    "module_path": "qlib.data.dataset",
-    "kwargs": {
-        "handler": {
-            "class": "Alpha360",
-            "module_path": "qlib.contrib.data.handler",
-            "kwargs": {
-                "start_time": "2008-01-01",
-                "end_time": "2020-08-01",
-                "fit_start_time": "2008-01-01",
-                "fit_end_time": "2014-12-31",
-                "instruments": "csi100",
-            },
-        },
-        "segments": {
-            "train": ("2008-01-01", "2014-12-31"),
-            "valid": ("2015-01-01", "2016-12-31"),
-            "test": ("2017-01-01", "2020-08-01"),
-        },
-    },
-}
-pprint.pprint(dataset_config)
-dataset = init_instance_by_config(dataset_config)
-
-df_train, df_valid, df_test = dataset.prepare(
-    ["train", "valid", "test"],
-    col_set=["feature", "label"],
-    data_key=DataHandlerLP.DK_L,
-)
-model = get_transformer(None)
-print(model)
-
-features = torch.from_numpy(df_train["feature"].values).float()
-labels = torch.from_numpy(df_train["label"].values).squeeze().float()
-
-batch = list(range(2000))
-predicts = model(features[batch])
-mask = ~torch.isnan(labels[batch])
-
-pred = predicts[mask]
-label = labels[batch][mask]
-
-loss = torch.nn.functional.mse_loss(pred, label)
-
-from sklearn.metrics import mean_squared_error
-
-mse_loss = mean_squared_error(pred.numpy(), label.numpy())
@@ -28,4 +28,4 @@ python ./exps/basic/xmain.py --save_dir ${save_dir} --rand_seed ${rseed} \
 	--model_config ./configs/yaml.model/vit-cifar10.s0 \
 	--optim_config ./configs/yaml.opt/vit.cifar \
 	--loss_config ./configs/yaml.loss/cross-entropy \
-	--lr 0.003 --weight_decay 0.3
+	--lr 0.003 --weight_decay 0.3 --scheduler warm-cos --steps 10000
							
								
								
									
tests/test_misc_scheduler.py (new file, 73 lines)
@@ -0,0 +1,73 @@
+####################################################
+# Copyright (c) Facebook, Inc. and its affiliates. #
+####################################################
+# Inspired from https://github.com/facebookresearch/detectron2/blob/master/tests/test_scheduler.py
+####################################################
+import math
+import numpy as np
+from unittest import TestCase
+
+import torch
+
+from xautodl.xmisc.scheduler_utils import CosineParamScheduler, MultiStepParamScheduler
+from xautodl.xmisc.scheduler_utils import LRMultiplier, WarmupParamScheduler
+
+
+class TestScheduler(TestCase):
+    """Test the scheduler."""
+
+    def test_warmup_multistep(self):
+        p = torch.nn.Parameter(torch.zeros(0))
+        opt = torch.optim.SGD([p], lr=5)
+
+        multiplier = WarmupParamScheduler(
+            MultiStepParamScheduler(
+                [1, 0.1, 0.01, 0.001],
+                milestones=[10, 15, 20],
+                num_updates=30,
+            ),
+            0.001,
+            5 / 30,
+        )
+        sched = LRMultiplier(opt, multiplier, 30)
+        # This is an equivalent of:
+        # sched = WarmupMultiStepLR(
+        # opt, milestones=[10, 15, 20], gamma=0.1, warmup_factor=0.001, warmup_iters=5)
+
+        p.sum().backward()
+        opt.step()
+
+        lrs = [0.005]
+        for _ in range(30):
+            sched.step()
+            lrs.append(opt.param_groups[0]["lr"])
+        self.assertTrue(np.allclose(lrs[:5], [0.005, 1.004, 2.003, 3.002, 4.001]))
+        self.assertTrue(np.allclose(lrs[5:10], 5.0))
+        self.assertTrue(np.allclose(lrs[10:15], 0.5))
+        self.assertTrue(np.allclose(lrs[15:20], 0.05))
+        self.assertTrue(np.allclose(lrs[20:], 0.005))
+
+    def test_warmup_cosine(self):
+        p = torch.nn.Parameter(torch.zeros(0))
+        opt = torch.optim.SGD([p], lr=5)
+        multiplier = WarmupParamScheduler(
+            CosineParamScheduler(1, 0),
+            0.001,
+            5 / 30,
+        )
+        sched = LRMultiplier(opt, multiplier, 30)
+
+        p.sum().backward()
+        opt.step()
+        self.assertEqual(opt.param_groups[0]["lr"], 0.005)
+        lrs = [0.005]
+
+        for _ in range(30):
+            sched.step()
+            lrs.append(opt.param_groups[0]["lr"])
+        for idx, lr in enumerate(lrs):
+            expected_cosine = 2.5 * (1.0 + math.cos(math.pi * idx / 30))
+            if idx >= 5:
+                self.assertAlmostEqual(lr, expected_cosine)
+            else:
+                self.assertNotAlmostEqual(lr, expected_cosine)
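Note (not part of the commit): the expected values in test_warmup_multistep follow from lr = base_lr * multiplier(where), with a linear warmup from warmup_factor * 1 to 1 over the first 5 of 30 steps and then the multi-step values scaled by the base lr of 5. A quick hand check of the first five entries, assuming the fvcore-style semantics of the schedulers used above:

base_lr, warmup_factor, warmup_iters, total = 5.0, 0.001, 5, 30
for step in range(5):
    w = (step / total) / (warmup_iters / total)        # progress inside the warmup stage
    multiplier = 1.0 * w + warmup_factor * (1.0 - w)   # linear warmup toward a multiplier of 1
    print(round(base_lr * multiplier, 3))              # 0.005, 1.004, 2.003, 3.002, 4.001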
@@ -10,3 +10,23 @@ from .yaml_utils import load_yaml
 from .torch_utils import count_parameters
 
 from .logger_utils import Logger
+
+# sampler
+from .sampler_utils import BatchSampler
+
+# scheduler related
+from .scheduler_utils import CosineParamScheduler, WarmupParamScheduler, LRMultiplier
+
+
+def get_scheduler(indicator, lr):
+    if indicator == "warm-cos":
+        multiplier = WarmupParamScheduler(
+            CosineParamScheduler(lr, lr * 1e-3),
+            warmup_factor=0.001,
+            warmup_length=0.05,
+            warmup_method="linear",
+        )
+
+    else:
+        raise ValueError("Unknown indicator: {:}".format(indicator))
+    return multiplier
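Note (not part of the commit): the new get_scheduler helper only knows the "warm-cos" indicator for now: a linear warmup over the first 5% of training into a cosine curve that decays from lr down to lr * 1e-3. A small sketch of inspecting the resulting multiplier (these are multiplier values, not learning rates; LRMultiplier later multiplies them into the optimizer's base lr):

from xautodl.xmisc import get_scheduler

multiplier = get_scheduler("warm-cos", 0.1)
# `where` is the fraction of training completed, in [0, 1).
for where in (0.0, 0.025, 0.05, 0.5, 0.95):
    print("{:.3f} -> {:.6f}".format(where, multiplier(where)))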
							
								
								
									
xautodl/xmisc/sampler_utils.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
+#####################################################
+import random
+
+
+class BatchSampler:
+    """A batch sampler used for single machine training."""
+
+    def __init__(self, dataset, batch, steps):
+        self._num_per_epoch = len(dataset)
+        self._iter_per_epoch = self._num_per_epoch // batch
+        self._steps = steps
+        self._batch = batch
+        if self._num_per_epoch < self._batch:
+            raise ValueError(
+                "The dataset size must be larger than batch={:}".format(batch)
+            )
+        self._indexes = list(range(self._num_per_epoch))
+
+    def __iter__(self):
+        """
+        yield a batch of indexes using random sampling
+        """
+        for i in range(self._steps):
+            if i % self._iter_per_epoch == 0:
+                random.shuffle(self._indexes)
+            j = i % self._iter_per_epoch
+            yield self._indexes[j * self._batch : (j + 1) * self._batch]
+
+    def __len__(self):
+        return self._steps
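Note (not part of the commit): BatchSampler decouples the loader length from the dataset size. It always yields `steps` batches of `batch` indices and reshuffles once per pass over the data, which is why the training DataLoader in this commit passes batch_sampler= instead of batch_size/shuffle. A small sketch with a toy dataset:

import torch
from xautodl.xmisc import BatchSampler

data = torch.utils.data.TensorDataset(torch.randn(10, 3))  # toy dataset, 10 samples
loader = torch.utils.data.DataLoader(
    data, batch_sampler=BatchSampler(data, batch=4, steps=7)
)
print(len(loader))  # 7 -- the step budget, independent of len(dataset) // batch
for (inputs,) in loader:
    print(inputs.shape)  # torch.Size([4, 3]) at every one of the 7 steps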
@@ -1,136 +1,532 @@
-#####################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
-#####################################################
-from torch.optim.lr_scheduler import _LRScheduler
-
-
-class CosineDecayWithWarmup(_LRScheduler):
-    r"""Set the learning rate of each parameter group using a cosine annealing
-    schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
-    is the number of epochs since the last restart and :math:`T_{i}` is the number
-    of epochs between two warm restarts in SGDR:
-    .. math::
-        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
-        \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)
-    When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
-    When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.
-    It has been proposed in
-    `SGDR: Stochastic Gradient Descent with Warm Restarts`_.
-        Args:
-        optimizer (Optimizer): Wrapped optimizer.
-        T_0 (int): Number of iterations for the first restart.
-        T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.
-        eta_min (float, optional): Minimum learning rate. Default: 0.
-        last_epoch (int, optional): The index of last epoch. Default: -1.
-        verbose (bool): If ``True``, prints a message to stdout for
-            each update. Default: ``False``.
-    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
-        https://arxiv.org/abs/1608.03983
-    """
-
-    def __init__(
-        self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False
-    ):
-        if T_0 <= 0 or not isinstance(T_0, int):
-            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
-        if T_mult < 1 or not isinstance(T_mult, int):
-            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
-        self.T_0 = T_0
-        self.T_i = T_0
-        self.T_mult = T_mult
-        self.eta_min = eta_min
-
-        super(CosineDecayWithWarmup, self).__init__(optimizer, last_epoch, verbose)
-
-        self.T_cur = self.last_epoch
-
-    def get_lr(self):
-        if not self._get_lr_called_within_step:
-            warnings.warn(
-                "To get the last learning rate computed by the scheduler, "
-                "please use `get_last_lr()`.",
-                UserWarning,
-            )
-
-        return [
-            self.eta_min
-            + (base_lr - self.eta_min)
-            * (1 + math.cos(math.pi * self.T_cur / self.T_i))
-            / 2
-            for base_lr in self.base_lrs
-        ]
-
-    def step(self, epoch=None):
-        """Step could be called after every batch update
-        Example:
-            >>> scheduler = CosineDecayWithWarmup(optimizer, T_0, T_mult)
-            >>> iters = len(dataloader)
-            >>> for epoch in range(20):
-            >>>     for i, sample in enumerate(dataloader):
-            >>>         inputs, labels = sample['inputs'], sample['labels']
-            >>>         optimizer.zero_grad()
-            >>>         outputs = net(inputs)
-            >>>         loss = criterion(outputs, labels)
-            >>>         loss.backward()
-            >>>         optimizer.step()
-            >>>         scheduler.step(epoch + i / iters)
-        This function can be called in an interleaved way.
-        Example:
-            >>> scheduler = CosineDecayWithWarmup(optimizer, T_0, T_mult)
-            >>> for epoch in range(20):
-            >>>     scheduler.step()
-            >>> scheduler.step(26)
-            >>> scheduler.step() # scheduler.step(27), instead of scheduler(20)
-        """
-
-        if epoch is None and self.last_epoch < 0:
-            epoch = 0
-
-        if epoch is None:
-            epoch = self.last_epoch + 1
-            self.T_cur = self.T_cur + 1
-            if self.T_cur >= self.T_i:
-                self.T_cur = self.T_cur - self.T_i
-                self.T_i = self.T_i * self.T_mult
-        else:
-            if epoch < 0:
-                raise ValueError(
-                    "Expected non-negative epoch, but got {}".format(epoch)
-                )
-            if epoch >= self.T_0:
-                if self.T_mult == 1:
-                    self.T_cur = epoch % self.T_0
-                else:
-                    n = int(
-                        math.log(
-                            (epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult
-                        )
-                    )
-                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (
-                        self.T_mult - 1
-                    )
-                    self.T_i = self.T_0 * self.T_mult ** (n)
-            else:
-                self.T_i = self.T_0
-                self.T_cur = epoch
-        self.last_epoch = math.floor(epoch)
-
-        class _enable_get_lr_call:
-            def __init__(self, o):
-                self.o = o
-
-            def __enter__(self):
-                self.o._get_lr_called_within_step = True
-                return self
-
-            def __exit__(self, type, value, traceback):
-                self.o._get_lr_called_within_step = False
-                return self
-
-        with _enable_get_lr_call(self):
-            for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
-                param_group, lr = data
-                param_group["lr"] = lr
-                self.print_lr(self.verbose, i, lr, epoch)
-
-        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
+####################################################
+# Copyright (c) Facebook, Inc. and its affiliates. #
+####################################################
+# Borrowed from https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/param_scheduler.py
+#           and https://github.com/facebookresearch/detectron2/blob/master/detectron2/solver/lr_scheduler.py
+####################################################
+import torch
+
+import bisect
+import math
+from typing import List, Optional, Sequence, Union
+
+__all__ = [
+    "ParamScheduler",
+    "ConstantParamScheduler",
+    "CosineParamScheduler",
+    "ExponentialParamScheduler",
+    "LinearParamScheduler",
+    "CompositeParamScheduler",
+    "MultiStepParamScheduler",
+    "StepParamScheduler",
+    "StepWithFixedGammaParamScheduler",
+    "PolynomialDecayParamScheduler",
+    "WarmupParamScheduler",
+    "LRMultiplier",
+]
+
+
+class ParamScheduler:
+    """
+    Base class for parameter schedulers.
+    A parameter scheduler defines a mapping from a progress value in [0, 1) to
+    a number (e.g. learning rate).
+    """
+
+    # To be used for comparisons with where
+    WHERE_EPSILON = 1e-6
+
+    def __call__(self, where: float) -> float:
+        """
+        Get the value of the param for a given point at training.
+
+        We update params (such as learning rate) based on the percent progress
+        of training completed. This allows a scheduler to be agnostic to the
+        exact length of a particular run (e.g. 120 epochs vs 90 epochs), as
+        long as the relative progress where params should be updated is the same.
+        However, it assumes that the total length of training is known.
+
+        Args:
+            where: A float in [0,1) that represents how far training has progressed
+
+        """
+        raise NotImplementedError("Param schedulers must override __call__")
+
+
+class ConstantParamScheduler(ParamScheduler):
+    """
+    Returns a constant value for a param.
+    """
+
+    def __init__(self, value: float) -> None:
+        self._value = value
+
+    def __call__(self, where: float) -> float:
+        if where >= 1.0:
+            raise RuntimeError(
+                f"where in ParamScheduler must be in [0, 1]: got {where}"
+            )
+        return self._value
+
+
+class CosineParamScheduler(ParamScheduler):
+    """
+    Cosine decay or cosine warmup schedules based on start and end values.
+    The schedule is updated based on the fraction of training progress.
+    The schedule was proposed in 'SGDR: Stochastic Gradient Descent with
+    Warm Restarts' (https://arxiv.org/abs/1608.03983). Note that this class
+    only implements the cosine annealing part of SGDR, and not the restarts.
+
+    Example:
+
+        .. code-block:: python
+
+          CosineParamScheduler(start_value=0.1, end_value=0.0001)
+    """
+
+    def __init__(
+        self,
+        start_value: float,
+        end_value: float,
+    ) -> None:
+        self._start_value = start_value
+        self._end_value = end_value
+
+    def __call__(self, where: float) -> float:
+        return self._end_value + 0.5 * (self._start_value - self._end_value) * (
+            1 + math.cos(math.pi * where)
+        )
+
+
+class ExponentialParamScheduler(ParamScheduler):
+    """
+    Exponetial schedule parameterized by a start value and decay.
+    The schedule is updated based on the fraction of training
+    progress, `where`, with the formula
+    `param_t = start_value * (decay ** where)`.
+
+    Example:
+
+        .. code-block:: python
+            ExponentialParamScheduler(start_value=2.0, decay=0.02)
+
+    Corresponds to a decreasing schedule with values in [2.0, 0.04).
+    """
+
+    def __init__(
+        self,
+        start_value: float,
+        decay: float,
+    ) -> None:
+        self._start_value = start_value
+        self._decay = decay
+
+    def __call__(self, where: float) -> float:
+        return self._start_value * (self._decay ** where)
+
+
+class LinearParamScheduler(ParamScheduler):
+    """
+    Linearly interpolates parameter between ``start_value`` and ``end_value``.
+    Can be used for either warmup or decay based on start and end values.
+    The schedule is updated after every train step by default.
+
+    Example:
+
+        .. code-block:: python
+
+            LinearParamScheduler(start_value=0.0001, end_value=0.01)
+
+    Corresponds to a linear increasing schedule with values in [0.0001, 0.01)
+    """
+
+    def __init__(
+        self,
+        start_value: float,
+        end_value: float,
+    ) -> None:
+        self._start_value = start_value
+        self._end_value = end_value
+
+    def __call__(self, where: float) -> float:
+        # interpolate between start and end values
+        return self._end_value * where + self._start_value * (1 - where)
+
+
+class MultiStepParamScheduler(ParamScheduler):
+    """
+    Takes a predefined schedule for a param value, and a list of epochs or steps
+    which stand for the upper boundary (excluded) of each range.
+
+    Example:
+
+        .. code-block:: python
+
+          MultiStepParamScheduler(
+            values=[0.1, 0.01, 0.001, 0.0001],
+            milestones=[30, 60, 80, 120]
+          )
+
+    Then the param value will be 0.1 for epochs 0-29, 0.01 for
+    epochs 30-59, 0.001 for epochs 60-79, 0.0001 for epochs 80-120.
+    Note that the length of values must be equal to the length of milestones
+    plus one.
+    """
+
+    def __init__(
+        self,
+        values: List[float],
+        num_updates: Optional[int] = None,
+        milestones: Optional[List[int]] = None,
+    ) -> None:
+        """
+        Args:
+            values: param value in each range
+            num_updates: the end of the last range. If None, will use ``milestones[-1]``
+            milestones: the boundary of each range. If None, will evenly split ``num_updates``
+
+        For example, all the following combinations define the same scheduler:
+
+        * num_updates=90, milestones=[30, 60], values=[1, 0.1, 0.01]
+        * num_updates=90, values=[1, 0.1, 0.01]
+        * milestones=[30, 60, 90], values=[1, 0.1, 0.01]
+        * milestones=[3, 6, 9], values=[1, 0.1, 0.01]  (ParamScheduler is scale-invariant)
+        """
+        if num_updates is None and milestones is None:
+            raise ValueError("num_updates and milestones cannot both be None")
+        if milestones is None:
+            # Default equispaced drop_epochs behavior
+            milestones = []
+            step_width = math.ceil(num_updates / float(len(values)))
+            for idx in range(len(values) - 1):
+                milestones.append(step_width * (idx + 1))
+        else:
+            if not (
+                isinstance(milestones, Sequence)
+                and len(milestones) == len(values) - int(num_updates is not None)
+            ):
+                raise ValueError(
+                    "MultiStep scheduler requires a list of %d miletones"
+                    % (len(values) - int(num_updates is not None))
+                )
+
+        if num_updates is None:
+            num_updates, milestones = milestones[-1], milestones[:-1]
+        if num_updates < len(values):
+            raise ValueError(
+                "Total num_updates must be greater than length of param schedule"
+            )
+
+        self._param_schedule = values
+        self._num_updates = num_updates
+        self._milestones: List[int] = milestones
+
+        start_epoch = 0
+        for milestone in self._milestones:
+            # Do not exceed the total number of epochs
+            if milestone >= self._num_updates:
+                raise ValueError(
+                    "Milestone must be smaller than total number of updates: "
+                    "num_updates=%d, milestone=%d" % (self._num_updates, milestone)
+                )
+            # Must be in ascending order
+            if start_epoch >= milestone:
+                raise ValueError(
+                    "Milestone must be smaller than start epoch: start_epoch=%d, milestone=%d"
+                    % (start_epoch, milestone)
+                )
+            start_epoch = milestone
+
+    def __call__(self, where: float) -> float:
+        if where > 1.0:
+            raise RuntimeError(
+                f"where in ParamScheduler must be in [0, 1]: got {where}"
+            )
+        epoch_num = int((where + self.WHERE_EPSILON) * self._num_updates)
+        return self._param_schedule[bisect.bisect_right(self._milestones, epoch_num)]
+
+
+class PolynomialDecayParamScheduler(ParamScheduler):
+    """
+    Decays the param value after every epoch according to a
+    polynomial function with a fixed power.
+    The schedule is updated after every train step by default.
+
+    Example:
+
+        .. code-block:: python
+
+          PolynomialDecayParamScheduler(base_value=0.1, power=0.9)
+
+    Then the param value will be 0.1 for epoch 0, 0.099 for epoch 1, and
+    so on.
+    """
+
+    def __init__(
+        self,
+        base_value: float,
+        power: float,
+    ) -> None:
+        self._base_value = base_value
+        self._power = power
+
+    def __call__(self, where: float) -> float:
+        return self._base_value * (1 - where) ** self._power
+
+
+class StepParamScheduler(ParamScheduler):
+    """
+    Takes a fixed schedule for a param value.  If the length of the
+    fixed schedule is less than the number of epochs, then the epochs
+    are divided evenly among the param schedule.
+    The schedule is updated after every train epoch by default.
+
+    Example:
+
+        .. code-block:: python
+
+          StepParamScheduler(values=[0.1, 0.01, 0.001, 0.0001], num_updates=120)
+
+    Then the param value will be 0.1 for epochs 0-29, 0.01 for
+    epochs 30-59, 0.001 for epoch 60-89, 0.0001 for epochs 90-119.
+    """
+
+    def __init__(
+        self,
+        num_updates: Union[int, float],
+        values: List[float],
+    ) -> None:
+        if num_updates <= 0:
+            raise ValueError("Number of updates must be larger than 0")
+        if not (isinstance(values, Sequence) and len(values) > 0):
+            raise ValueError(
+                "Step scheduler requires a list of at least one param value"
+            )
+        self._param_schedule = values
+
+    def __call__(self, where: float) -> float:
+        ind = int((where + self.WHERE_EPSILON) * len(self._param_schedule))
+        return self._param_schedule[ind]
+
+
+class StepWithFixedGammaParamScheduler(ParamScheduler):
+    """
+    Decays the param value by gamma at equal number of steps so as to have the
+    specified total number of decays.
+
+    Example:
+
+        .. code-block:: python
+
+          StepWithFixedGammaParamScheduler(
+            base_value=0.1, gamma=0.1, num_decays=3, num_updates=120)
+
+    Then the param value will be 0.1 for epochs 0-29, 0.01 for
+    epochs 30-59, 0.001 for epoch 60-89, 0.0001 for epochs 90-119.
+    """
+
+    def __init__(
+        self,
+        base_value: float,
+        num_decays: int,
+        gamma: float,
+        num_updates: int,
+    ) -> None:
+        for k in [base_value, gamma]:
+            if not (isinstance(k, (int, float)) and k > 0):
+                raise ValueError("base_value and gamma must be positive numbers")
+        for k in [num_decays, num_updates]:
+            if not (isinstance(k, int) and k > 0):
+                raise ValueError("num_decays and num_updates must be positive integers")
+
+        self.base_value = base_value
+        self.num_decays = num_decays
+        self.gamma = gamma
+        self.num_updates = num_updates
+        values = [base_value]
+        for _ in range(num_decays):
+            values.append(values[-1] * gamma)
+
+        self._step_param_scheduler = StepParamScheduler(
+            num_updates=num_updates, values=values
+        )
+
+    def __call__(self, where: float) -> float:
+        return self._step_param_scheduler(where)
+
+
+class CompositeParamScheduler(ParamScheduler):
+    """
+    Composite parameter scheduler composed of intermediate schedulers.
+    Takes a list of schedulers and a list of lengths corresponding to
+    percentage of training each scheduler should run for. Schedulers
+    are run in order. All values in lengths should sum to 1.0.
+
+    Each scheduler also has a corresponding interval scale. If interval
+    scale is 'fixed', the intermediate scheduler will be run without any rescaling
+    of the time. If interval scale is 'rescaled', intermediate scheduler is
+    run such that each scheduler will start and end at the same values as it
+    would if it were the only scheduler. Default is 'rescaled' for all schedulers.
+
+    Example:
+
+        .. code-block:: python
+
+              schedulers = [
+                ConstantParamScheduler(value=0.42),
+                CosineParamScheduler(start_value=0.42, end_value=1e-4)
+              ]
+              CompositeParamScheduler(
+                schedulers=schedulers,
+                interval_scaling=['rescaled', 'rescaled'],
+                lengths=[0.3, 0.7])
+
+    The parameter value will be 0.42 for the first [0%, 30%) of steps,
+    and then will cosine decay from 0.42 to 0.0001 for [30%, 100%) of
+    training.
+    """
+
+    def __init__(
+        self,
+        schedulers: Sequence[ParamScheduler],
+        lengths: List[float],
+        interval_scaling: Sequence[str],
+    ) -> None:
+        if len(schedulers) != len(lengths):
+            raise ValueError("Schedulers and lengths must be same length")
+        if len(schedulers) == 0:
+            raise ValueError(
+                "There must be at least one scheduler in the composite scheduler"
+            )
+        if abs(sum(lengths) - 1.0) >= 1e-3:
+            raise ValueError("The sum of all values in lengths must be 1")
+        if sum(lengths) != 1.0:
+            lengths[-1] = 1.0 - sum(lengths[:-1])
+        for s in interval_scaling:
+            if s not in ["rescaled", "fixed"]:
+                raise ValueError(f"Unsupported interval_scaling: {s}")
+
+        self._lengths = lengths
+        self._schedulers = schedulers
+        self._interval_scaling = interval_scaling
+
+    def __call__(self, where: float) -> float:
+        # Find scheduler corresponding to where
+        i = 0
+        running_total = self._lengths[i]
+        while (where + self.WHERE_EPSILON) > running_total and i < len(
+            self._schedulers
+        ) - 1:
+            i += 1
+            running_total += self._lengths[i]
+        scheduler = self._schedulers[i]
+        scheduler_where = where
+        interval_scale = self._interval_scaling[i]
+        if interval_scale == "rescaled":
+            # Calculate corresponding where % for scheduler
+            scheduler_start = running_total - self._lengths[i]
+            scheduler_where = (where - scheduler_start) / self._lengths[i]
+        return scheduler(scheduler_where)
+
+
+class WarmupParamScheduler(CompositeParamScheduler):
+    """
+    Add an initial warmup stage to another scheduler.
+    """
+
+    def __init__(
+        self,
+        scheduler: ParamScheduler,
+        warmup_factor: float,
+        warmup_length: float,
+        warmup_method: str = "linear",
+    ):
+        """
+        Args:
+            scheduler: warmup will be added at the beginning of this scheduler
+            warmup_factor: the factor w.r.t the initial value of ``scheduler``, e.g. 0.001
+            warmup_length: the relative length (in [0, 1]) of warmup steps w.r.t the entire
+                training, e.g. 0.01
+            warmup_method: one of "linear" or "constant"
+        """
+        end_value = scheduler(warmup_length)  # the value to reach when warmup ends
+        start_value = warmup_factor * scheduler(0.0)
+        if warmup_method == "constant":
+            warmup = ConstantParamScheduler(start_value)
+        elif warmup_method == "linear":
+            warmup = LinearParamScheduler(start_value, end_value)
+        else:
+            raise ValueError("Unknown warmup method: {}".format(warmup_method))
+        super().__init__(
+            [warmup, scheduler],
+            interval_scaling=["rescaled", "fixed"],
+            lengths=[warmup_length, 1 - warmup_length],
+        )
+
+
+##### LR Scheduler
+
+
+class LRMultiplier(torch.optim.lr_scheduler._LRScheduler):
+    """
+    A LRScheduler which uses fvcore :class:`ParamScheduler` to multiply the
+    learning rate of each param in the optimizer.
+    Every step, the learning rate of each parameter becomes its initial value
+    multiplied by the output of the given :class:`ParamScheduler`.
+    The absolute learning rate value of each parameter can be different.
+    This scheduler can be used as long as the relative scale among them do
+    not change during training.
+    Examples:
+    ::
+        LRMultiplier(
+            opt,
+            WarmupParamScheduler(
+                MultiStepParamScheduler(
+                    [1, 0.1, 0.01],
+                    milestones=[60000, 80000],
+                    num_updates=90000,
+                ), 0.001, 100 / 90000
+            ),
+            max_iter=90000
+        )
+    """
+
+    # NOTES: in the most general case, every LR can use its own scheduler.
+    # Supporting this requires interaction with the optimizer when its parameter
+    # group is initialized. For example, classyvision implements its own optimizer
+    # that allows different schedulers for every parameter group.
+    # To avoid this complexity, we use this class to support the most common cases
+    # where the relative scale among all LRs stay unchanged during training.  In this
+    # case we only need a total of one scheduler that defines the relative LR multiplier.
+
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+        multiplier: ParamScheduler,
+        max_iter: int,
+        last_iter: int = -1,
+    ):
+        """
+        Args:
+            optimizer, last_iter: See ``torch.optim.lr_scheduler._LRScheduler``.
+                ``last_iter`` is the same as ``last_epoch``.
+            multiplier: a fvcore ParamScheduler that defines the multiplier on
+                every LR of the optimizer
+            max_iter: the total number of training iterations
+        """
+        if not isinstance(multiplier, ParamScheduler):
+            raise ValueError(
+                "_LRMultiplier(multiplier=) must be an instance of fvcore "
+                f"ParamScheduler. Got {multiplier} instead."
+            )
+        self._multiplier = multiplier
+        self._max_iter = max_iter
+        super().__init__(optimizer, last_epoch=last_iter)
+
+    def state_dict(self):
+        # fvcore schedulers are stateless. Only keep pytorch scheduler states
+        return {"base_lrs": self.base_lrs, "last_epoch": self.last_epoch}
+
+    def get_lr(self) -> List[float]:
+        multiplier = self._multiplier(self.last_epoch / self._max_iter)
+        return [base_lr * multiplier for base_lr in self.base_lrs]
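Note (not part of the commit): all of these schedulers map a training-progress fraction `where` in [0, 1) to a multiplier, and LRMultiplier applies that multiplier to every base learning rate in the optimizer. An end-to-end sketch with a hypothetical parameter and hyper-parameters, similar in spirit to the "warm-cos" indicator above but with a 1.0 -> 0.0 multiplier:

import torch
from xautodl.xmisc.scheduler_utils import (
    CosineParamScheduler,
    LRMultiplier,
    WarmupParamScheduler,
)

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)

# Linear warmup over the first 10% of training, then cosine decay from 1.0 to 0.0.
multiplier = WarmupParamScheduler(
    CosineParamScheduler(1.0, 0.0), warmup_factor=0.001, warmup_length=0.1
)
scheduler = LRMultiplier(optimizer, multiplier, max_iter=100)

for _ in range(100):
    optimizer.step()      # lr used here is base_lr * multiplier(last_epoch / max_iter)
    scheduler.step()
print(optimizer.param_groups[0]["lr"])  # 0.0 at the end of the cosine decay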
@@ -5,6 +5,3 @@
 #####################################################
 
 from .transformers import get_transformer
-
-def obtain_model(config):
-  raise NotImplementedError