Update xmisc.scheduler/sampler

.github/workflows/test-misc.yml: 41 lines added (new file, vendored)

							| @@ -0,0 +1,41 @@ | ||||
| name: Test Xmisc | ||||
| on: | ||||
|   push: | ||||
|     branches: | ||||
|       - main | ||||
|   pull_request: | ||||
|     branches: | ||||
|       - main | ||||
|  | ||||
|  | ||||
| jobs: | ||||
|   build: | ||||
|     strategy: | ||||
|       matrix: | ||||
|         os: [ubuntu-16.04, ubuntu-18.04, ubuntu-20.04, macos-latest] | ||||
|         python-version: [3.6, 3.7, 3.8, 3.9] | ||||
|  | ||||
|     runs-on: ${{ matrix.os }} | ||||
|     steps: | ||||
|       - uses: actions/checkout@v2 | ||||
|  | ||||
|       - name: Set up Python ${{ matrix.python-version }} | ||||
|         uses: actions/setup-python@v2 | ||||
|         with: | ||||
|           python-version: ${{ matrix.python-version }} | ||||
|  | ||||
|       - name: Install XAutoDL from source | ||||
|         run: | | ||||
|           python setup.py install | ||||
|  | ||||
|       - name: Test Xmisc | ||||
|         run: | | ||||
|           python -m pip install pytest numpy | ||||
|           python -m pip install torch torchvision | ||||
|           python -m pip install parameterized | ||||
|           echo $PWD | ||||
|           echo "Show what we have here:" | ||||
|           ls | ||||
|           python --version | ||||
|           python -m pytest ./tests/test_misc* -s | ||||
|         shell: bash | ||||
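The workflow above installs the package with `python setup.py install` and then runs `python -m pytest ./tests/test_misc* -s`, so any module matching that glob is collected. A minimal sketch of such a module (the file name `tests/test_misc_smoke.py` is hypothetical; it relies only on the `BatchSampler` and `get_scheduler` symbols added by this commit):

# tests/test_misc_smoke.py -- hypothetical file name, picked up by the
# "./tests/test_misc*" glob used in the workflow above.
from xautodl import xmisc


def test_xmisc_exports():
    # Both symbols are introduced by this commit (sampler_utils / __init__.py).
    assert hasattr(xmisc, "BatchSampler")
    assert hasattr(xmisc, "get_scheduler")
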
| @@ -46,8 +46,7 @@ def main(args): | ||||
|  | ||||
|     train_loader = torch.utils.data.DataLoader( | ||||
|         train_data, | ||||
|         batch_size=args.batch_size, | ||||
|         shuffle=True, | ||||
|         batch_sampler=xmisc.BatchSampler(train_data, args.batch_size, args.steps), | ||||
|         num_workers=args.workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
| @@ -57,6 +56,7 @@ def main(args): | ||||
|         shuffle=False, | ||||
|         num_workers=args.workers, | ||||
|         pin_memory=True, | ||||
|         drop_last=False, | ||||
|     ) | ||||
|  | ||||
|     logger.log("The training loader: {:}".format(train_loader)) | ||||
| @@ -73,6 +73,9 @@ def main(args): | ||||
|     logger.log("The loss is {:}".format(loss)) | ||||
|  | ||||
|     model, loss = torch.nn.DataParallel(model).cuda(), loss.cuda() | ||||
|     scheduler = xmisc.LRMultiplier( | ||||
|         optimizer, xmisc.get_scheduler(args.scheduler, args.lr), args.steps | ||||
|     ) | ||||
|  | ||||
|     import pdb | ||||
|  | ||||
| @@ -241,10 +244,11 @@ if __name__ == "__main__": | ||||
|         "--valid_data_config", type=str, help="The validation dataset config path." | ||||
|     ) | ||||
|     parser.add_argument("--data_path", type=str, help="The path to the dataset.") | ||||
|     parser.add_argument("--algorithm", type=str, help="The algorithm.") | ||||
|     # Optimization options | ||||
|     parser.add_argument("--lr", type=float, help="The learning rate") | ||||
|     parser.add_argument("--weight_decay", type=float, help="The weight decay") | ||||
|     parser.add_argument("--scheduler", type=str, help="The scheduler indicator.") | ||||
|     parser.add_argument("--steps", type=int, help="The total number of steps.") | ||||
|     parser.add_argument("--batch_size", type=int, default=2, help="The batch size.") | ||||
|     parser.add_argument("--workers", type=int, default=4, help="The number of workers") | ||||
|     # Random Seed | ||||
|   | ||||
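The hunks above replace the epoch-based loader with a step-based one and drive the learning rate once per iteration over the same horizon. A self-contained sketch of how the pieces fit together (the toy dataset, linear model, and a base learning rate of 1.0 are illustrative assumptions; xmain.py's own model and optimizer construction is not shown in this diff):

# Sketch of the training wiring: a DataLoader driven by xmisc.BatchSampler
# yields exactly `steps` batches, and LRMultiplier advances the learning rate
# once per optimizer step over the same number of steps.
import torch
from xautodl import xmisc

steps, batch_size, lr = 100, 8, 0.003
train_data = torch.utils.data.TensorDataset(torch.randn(64, 16), torch.randn(64, 1))
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_sampler=xmisc.BatchSampler(train_data, batch_size, steps),
    num_workers=0,
)

model = torch.nn.Linear(16, 1)
criterion = torch.nn.MSELoss()
# Base lr is set to 1.0 here because the "warm-cos" multiplier already carries
# the target lr; this is an assumption, the optimizer setup in xmain.py is not
# part of this diff.
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
scheduler = xmisc.LRMultiplier(optimizer, xmisc.get_scheduler("warm-cos", lr), steps)

for step, (inputs, targets) in enumerate(train_loader):  # exactly `steps` batches
    optimizer.zero_grad()
    criterion(model(inputs), targets).backward()
    optimizer.step()
    scheduler.step()  # one scheduler step per optimizer step, as in xmain.py
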
							
								
								
									
notebooks/spaces-xmisc/scheduler.ipynb: 119 lines added (new file)
File diff suppressed because one or more lines are too long

							| @@ -1,76 +0,0 @@ | ||||
| import os | ||||
| import sys | ||||
| import qlib | ||||
| import pprint | ||||
| import numpy as np | ||||
| import pandas as pd | ||||
|  | ||||
| from pathlib import Path | ||||
| import torch | ||||
|  | ||||
| __file__ = os.path.dirname(os.path.realpath("__file__")) | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / "lib").resolve() | ||||
| print("library path: {:}".format(lib_dir)) | ||||
| assert lib_dir.exists(), "{:} does not exist".format(lib_dir) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from trade_models import get_transformer | ||||
|  | ||||
| from qlib import config as qconfig | ||||
| from qlib.utils import init_instance_by_config | ||||
| from qlib.model.base import Model | ||||
| from qlib.data.dataset import DatasetH | ||||
| from qlib.data.dataset.handler import DataHandlerLP | ||||
|  | ||||
| qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region=qconfig.REG_CN) | ||||
|  | ||||
| dataset_config = { | ||||
|     "class": "DatasetH", | ||||
|     "module_path": "qlib.data.dataset", | ||||
|     "kwargs": { | ||||
|         "handler": { | ||||
|             "class": "Alpha360", | ||||
|             "module_path": "qlib.contrib.data.handler", | ||||
|             "kwargs": { | ||||
|                 "start_time": "2008-01-01", | ||||
|                 "end_time": "2020-08-01", | ||||
|                 "fit_start_time": "2008-01-01", | ||||
|                 "fit_end_time": "2014-12-31", | ||||
|                 "instruments": "csi100", | ||||
|             }, | ||||
|         }, | ||||
|         "segments": { | ||||
|             "train": ("2008-01-01", "2014-12-31"), | ||||
|             "valid": ("2015-01-01", "2016-12-31"), | ||||
|             "test": ("2017-01-01", "2020-08-01"), | ||||
|         }, | ||||
|     }, | ||||
| } | ||||
| pprint.pprint(dataset_config) | ||||
| dataset = init_instance_by_config(dataset_config) | ||||
|  | ||||
| df_train, df_valid, df_test = dataset.prepare( | ||||
|     ["train", "valid", "test"], | ||||
|     col_set=["feature", "label"], | ||||
|     data_key=DataHandlerLP.DK_L, | ||||
| ) | ||||
| model = get_transformer(None) | ||||
| print(model) | ||||
|  | ||||
| features = torch.from_numpy(df_train["feature"].values).float() | ||||
| labels = torch.from_numpy(df_train["label"].values).squeeze().float() | ||||
|  | ||||
| batch = list(range(2000)) | ||||
| predicts = model(features[batch]) | ||||
| mask = ~torch.isnan(labels[batch]) | ||||
|  | ||||
| pred = predicts[mask] | ||||
| label = labels[batch][mask] | ||||
|  | ||||
| loss = torch.nn.functional.mse_loss(pred, label) | ||||
|  | ||||
| from sklearn.metrics import mean_squared_error | ||||
|  | ||||
| mse_loss = mean_squared_error(pred.numpy(), label.numpy()) | ||||
| @@ -28,4 +28,4 @@ python ./exps/basic/xmain.py --save_dir ${save_dir} --rand_seed ${rseed} \ | ||||
| 	--model_config ./configs/yaml.model/vit-cifar10.s0 \ | ||||
| 	--optim_config ./configs/yaml.opt/vit.cifar \ | ||||
| 	--loss_config ./configs/yaml.loss/cross-entropy \ | ||||
| 	--lr 0.003 --weight_decay 0.3  | ||||
| 	--lr 0.003 --weight_decay 0.3 --scheduler warm-cos --steps 10000 | ||||
|   | ||||
							
								
								
									
tests/test_misc_scheduler.py: 73 lines added (new file)

							| @@ -0,0 +1,73 @@ | ||||
| #################################################### | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. # | ||||
| #################################################### | ||||
| # Inspired from https://github.com/facebookresearch/detectron2/blob/master/tests/test_scheduler.py | ||||
| #################################################### | ||||
| import math | ||||
| import numpy as np | ||||
| from unittest import TestCase | ||||
|  | ||||
| import torch | ||||
|  | ||||
| from xautodl.xmisc.scheduler_utils import CosineParamScheduler, MultiStepParamScheduler | ||||
| from xautodl.xmisc.scheduler_utils import LRMultiplier, WarmupParamScheduler | ||||
|  | ||||
|  | ||||
| class TestScheduler(TestCase): | ||||
|     """Test the scheduler.""" | ||||
|  | ||||
|     def test_warmup_multistep(self): | ||||
|         p = torch.nn.Parameter(torch.zeros(0)) | ||||
|         opt = torch.optim.SGD([p], lr=5) | ||||
|  | ||||
|         multiplier = WarmupParamScheduler( | ||||
|             MultiStepParamScheduler( | ||||
|                 [1, 0.1, 0.01, 0.001], | ||||
|                 milestones=[10, 15, 20], | ||||
|                 num_updates=30, | ||||
|             ), | ||||
|             0.001, | ||||
|             5 / 30, | ||||
|         ) | ||||
|         sched = LRMultiplier(opt, multiplier, 30) | ||||
|         # This is an equivalent of: | ||||
|         # sched = WarmupMultiStepLR( | ||||
|         # opt, milestones=[10, 15, 20], gamma=0.1, warmup_factor=0.001, warmup_iters=5) | ||||
|  | ||||
|         p.sum().backward() | ||||
|         opt.step() | ||||
|  | ||||
|         lrs = [0.005] | ||||
|         for _ in range(30): | ||||
|             sched.step() | ||||
|             lrs.append(opt.param_groups[0]["lr"]) | ||||
|         self.assertTrue(np.allclose(lrs[:5], [0.005, 1.004, 2.003, 3.002, 4.001])) | ||||
|         self.assertTrue(np.allclose(lrs[5:10], 5.0)) | ||||
|         self.assertTrue(np.allclose(lrs[10:15], 0.5)) | ||||
|         self.assertTrue(np.allclose(lrs[15:20], 0.05)) | ||||
|         self.assertTrue(np.allclose(lrs[20:], 0.005)) | ||||
|  | ||||
|     def test_warmup_cosine(self): | ||||
|         p = torch.nn.Parameter(torch.zeros(0)) | ||||
|         opt = torch.optim.SGD([p], lr=5) | ||||
|         multiplier = WarmupParamScheduler( | ||||
|             CosineParamScheduler(1, 0), | ||||
|             0.001, | ||||
|             5 / 30, | ||||
|         ) | ||||
|         sched = LRMultiplier(opt, multiplier, 30) | ||||
|  | ||||
|         p.sum().backward() | ||||
|         opt.step() | ||||
|         self.assertEqual(opt.param_groups[0]["lr"], 0.005) | ||||
|         lrs = [0.005] | ||||
|  | ||||
|         for _ in range(30): | ||||
|             sched.step() | ||||
|             lrs.append(opt.param_groups[0]["lr"]) | ||||
|         for idx, lr in enumerate(lrs): | ||||
|             expected_cosine = 2.5 * (1.0 + math.cos(math.pi * idx / 30)) | ||||
|             if idx >= 5: | ||||
|                 self.assertAlmostEqual(lr, expected_cosine) | ||||
|             else: | ||||
|                 self.assertNotAlmostEqual(lr, expected_cosine) | ||||
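A quick hand-check of the warmup numbers asserted above: with a base learning rate of 5, `warmup_factor=0.001`, and a warmup spanning 5 of the 30 steps, the multiplier ramps linearly from 0.001 to 1, so the learning rate climbs from 0.005 in increments of (5 - 0.005) / 5 = 0.999. A plain-Python sketch of that arithmetic (no torch needed):

# Reproducing the first warmup assertion by hand (values from the test above).
base_lr, warmup_factor, warmup_steps, post_warmup = 5.0, 0.001, 5, 1.0
start = warmup_factor * post_warmup            # multiplier at step 0 -> 0.001
for k in range(5):
    frac = k / warmup_steps                    # linear warmup progress
    multiplier = start * (1 - frac) + post_warmup * frac
    print(base_lr * multiplier)                # 0.005, 1.004, 2.003, 3.002, 4.001
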
| @@ -10,3 +10,23 @@ from .yaml_utils import load_yaml | ||||
| from .torch_utils import count_parameters | ||||
|  | ||||
| from .logger_utils import Logger | ||||
|  | ||||
| # sampler | ||||
| from .sampler_utils import BatchSampler | ||||
|  | ||||
| # scheduler related | ||||
| from .scheduler_utils import CosineParamScheduler, WarmupParamScheduler, LRMultiplier | ||||
|  | ||||
|  | ||||
| def get_scheduler(indicator, lr): | ||||
|     if indicator == "warm-cos": | ||||
|         multiplier = WarmupParamScheduler( | ||||
|             CosineParamScheduler(lr, lr * 1e-3), | ||||
|             warmup_factor=0.001, | ||||
|             warmup_length=0.05, | ||||
|             warmup_method="linear", | ||||
|         ) | ||||
|  | ||||
|     else: | ||||
|         raise ValueError("Unknown indicator: {:}".format(indicator)) | ||||
|     return multiplier | ||||
|   | ||||
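Note that `get_scheduler` returns a bare multiplier rather than an LR scheduler: it is meant to be wrapped by `LRMultiplier` as in xmain.py above, and the effective learning rate is the optimizer's base LR times the multiplier. A small sketch that inspects the "warm-cos" multiplier directly (the probe points are arbitrary):

# Sketch: ParamSchedulers are plain callables over training progress in [0, 1),
# so the "warm-cos" multiplier can be inspected without building an optimizer.
from xautodl.xmisc import get_scheduler

multiplier = get_scheduler("warm-cos", 0.003)
for where in (0.0, 0.025, 0.05, 0.5, 0.99):
    print(where, multiplier(where))
# Linear warmup over the first 5% of training, then cosine decay from roughly
# 0.003 down towards 0.003 * 1e-3.
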
							
								
								
									
xautodl/xmisc/sampler_utils.py: 32 lines added (new file)

							| @@ -0,0 +1,32 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 # | ||||
| ##################################################### | ||||
| import random | ||||
|  | ||||
|  | ||||
| class BatchSampler: | ||||
|     """A batch sampler used for single machine training.""" | ||||
|  | ||||
|     def __init__(self, dataset, batch, steps): | ||||
|         self._num_per_epoch = len(dataset) | ||||
|         self._iter_per_epoch = self._num_per_epoch // batch | ||||
|         self._steps = steps | ||||
|         self._batch = batch | ||||
|         if self._num_per_epoch < self._batch: | ||||
|             raise ValueError( | ||||
|                 "The dataset size must be larger than batch={:}".format(batch) | ||||
|             ) | ||||
|         self._indexes = list(range(self._num_per_epoch)) | ||||
|  | ||||
|     def __iter__(self): | ||||
|         """ | ||||
|         yield a batch of indexes using random sampling | ||||
|         """ | ||||
|         for i in range(self._steps): | ||||
|             if i % self._iter_per_epoch == 0: | ||||
|                 random.shuffle(self._indexes) | ||||
|             j = i % self._iter_per_epoch | ||||
|             yield self._indexes[j * self._batch : (j + 1) * self._batch] | ||||
|  | ||||
|     def __len__(self): | ||||
|         return self._steps | ||||
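`BatchSampler` yields exactly `steps` batches of indices and reshuffles whenever it wraps around the dataset; the trailing partial batch of each shuffle (here 10 % 4 = 2 indices) is simply skipped. A stand-alone check of that behaviour, no torch required:

# Stand-alone check of the sampling behaviour (indices only).
from xautodl.xmisc.sampler_utils import BatchSampler

dataset = list(range(10))                 # anything with len() works
sampler = BatchSampler(dataset, batch=4, steps=7)

print(len(sampler))                       # 7, one entry per requested step
for batch_indices in sampler:
    print(batch_indices)                  # 4 indices per step; reshuffled every
                                          # 2 steps (10 // 4 iterations per epoch)
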
| @@ -1,136 +1,532 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 # | ||||
| ##################################################### | ||||
| from torch.optim.lr_scheduler import _LRScheduler | ||||
| #################################################### | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. # | ||||
| #################################################### | ||||
| # Borrowed from https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/param_scheduler.py | ||||
| #           and https://github.com/facebookresearch/detectron2/blob/master/detectron2/solver/lr_scheduler.py | ||||
| #################################################### | ||||
| import torch | ||||
|  | ||||
| import bisect | ||||
| import math | ||||
| from typing import List, Optional, Sequence, Union | ||||
|  | ||||
| __all__ = [ | ||||
|     "ParamScheduler", | ||||
|     "ConstantParamScheduler", | ||||
|     "CosineParamScheduler", | ||||
|     "ExponentialParamScheduler", | ||||
|     "LinearParamScheduler", | ||||
|     "CompositeParamScheduler", | ||||
|     "MultiStepParamScheduler", | ||||
|     "StepParamScheduler", | ||||
|     "StepWithFixedGammaParamScheduler", | ||||
|     "PolynomialDecayParamScheduler", | ||||
|     "WarmupParamScheduler", | ||||
|     "LRMultiplier", | ||||
| ] | ||||
|  | ||||
|  | ||||
| class CosineDecayWithWarmup(_LRScheduler): | ||||
|     r"""Set the learning rate of each parameter group using a cosine annealing | ||||
|     schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` | ||||
|     is the number of epochs since the last restart and :math:`T_{i}` is the number | ||||
|     of epochs between two warm restarts in SGDR: | ||||
|     .. math:: | ||||
|         \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + | ||||
|         \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right) | ||||
|     When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`. | ||||
|     When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`. | ||||
|     It has been proposed in | ||||
|     `SGDR: Stochastic Gradient Descent with Warm Restarts`_. | ||||
|     Args: | ||||
|         optimizer (Optimizer): Wrapped optimizer. | ||||
|         T_0 (int): Number of iterations for the first restart. | ||||
|         T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1. | ||||
|         eta_min (float, optional): Minimum learning rate. Default: 0. | ||||
|         last_epoch (int, optional): The index of last epoch. Default: -1. | ||||
|         verbose (bool): If ``True``, prints a message to stdout for | ||||
|             each update. Default: ``False``. | ||||
|     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: | ||||
|         https://arxiv.org/abs/1608.03983 | ||||
| class ParamScheduler: | ||||
|     """ | ||||
|     Base class for parameter schedulers. | ||||
|     A parameter scheduler defines a mapping from a progress value in [0, 1) to | ||||
|     a number (e.g. learning rate). | ||||
|     """ | ||||
|  | ||||
|     # To be used for comparisons with where | ||||
|     WHERE_EPSILON = 1e-6 | ||||
|  | ||||
|     def __call__(self, where: float) -> float: | ||||
|         """ | ||||
|         Get the value of the param for a given point at training. | ||||
|  | ||||
|         We update params (such as learning rate) based on the percent progress | ||||
|         of training completed. This allows a scheduler to be agnostic to the | ||||
|         exact length of a particular run (e.g. 120 epochs vs 90 epochs), as | ||||
|         long as the relative progress where params should be updated is the same. | ||||
|         However, it assumes that the total length of training is known. | ||||
|  | ||||
|         Args: | ||||
|             where: A float in [0,1) that represents how far training has progressed | ||||
|  | ||||
|         """ | ||||
|         raise NotImplementedError("Param schedulers must override __call__") | ||||
|  | ||||
|  | ||||
| class ConstantParamScheduler(ParamScheduler): | ||||
|     """ | ||||
|     Returns a constant value for a param. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, value: float) -> None: | ||||
|         self._value = value | ||||
|  | ||||
|     def __call__(self, where: float) -> float: | ||||
|         if where >= 1.0: | ||||
|             raise RuntimeError( | ||||
|                 f"where in ParamScheduler must be in [0, 1]: got {where}" | ||||
|             ) | ||||
|         return self._value | ||||
|  | ||||
|  | ||||
| class CosineParamScheduler(ParamScheduler): | ||||
|     """ | ||||
|     Cosine decay or cosine warmup schedules based on start and end values. | ||||
|     The schedule is updated based on the fraction of training progress. | ||||
|     The schedule was proposed in 'SGDR: Stochastic Gradient Descent with | ||||
|     Warm Restarts' (https://arxiv.org/abs/1608.03983). Note that this class | ||||
|     only implements the cosine annealing part of SGDR, and not the restarts. | ||||
|  | ||||
|     Example: | ||||
|  | ||||
|         .. code-block:: python | ||||
|  | ||||
|           CosineParamScheduler(start_value=0.1, end_value=0.0001) | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False | ||||
|     ): | ||||
|         if T_0 <= 0 or not isinstance(T_0, int): | ||||
|             raise ValueError("Expected positive integer T_0, but got {}".format(T_0)) | ||||
|         if T_mult < 1 or not isinstance(T_mult, int): | ||||
|             raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult)) | ||||
|         self.T_0 = T_0 | ||||
|         self.T_i = T_0 | ||||
|         self.T_mult = T_mult | ||||
|         self.eta_min = eta_min | ||||
|         self, | ||||
|         start_value: float, | ||||
|         end_value: float, | ||||
|     ) -> None: | ||||
|         self._start_value = start_value | ||||
|         self._end_value = end_value | ||||
|  | ||||
|         super(CosineDecayWithWarmup, self).__init__(optimizer, last_epoch, verbose) | ||||
|     def __call__(self, where: float) -> float: | ||||
|         return self._end_value + 0.5 * (self._start_value - self._end_value) * ( | ||||
|             1 + math.cos(math.pi * where) | ||||
|         ) | ||||
|  | ||||
|         self.T_cur = self.last_epoch | ||||
|  | ||||
|     def get_lr(self): | ||||
|         if not self._get_lr_called_within_step: | ||||
|             warnings.warn( | ||||
|                 "To get the last learning rate computed by the scheduler, " | ||||
|                 "please use `get_last_lr()`.", | ||||
|                 UserWarning, | ||||
| class ExponentialParamScheduler(ParamScheduler): | ||||
|     """ | ||||
|     Exponential schedule parameterized by a start value and decay. | ||||
|     The schedule is updated based on the fraction of training | ||||
|     progress, `where`, with the formula | ||||
|     `param_t = start_value * (decay ** where)`. | ||||
|  | ||||
|     Example: | ||||
|  | ||||
|         .. code-block:: python | ||||
|             ExponentialParamScheduler(start_value=2.0, decay=0.02) | ||||
|  | ||||
|     Corresponds to a decreasing schedule with values in [2.0, 0.04). | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         start_value: float, | ||||
|         decay: float, | ||||
|     ) -> None: | ||||
|         self._start_value = start_value | ||||
|         self._decay = decay | ||||
|  | ||||
|     def __call__(self, where: float) -> float: | ||||
|         return self._start_value * (self._decay ** where) | ||||
|  | ||||
|  | ||||
| class LinearParamScheduler(ParamScheduler): | ||||
|     """ | ||||
|     Linearly interpolates parameter between ``start_value`` and ``end_value``. | ||||
|     Can be used for either warmup or decay based on start and end values. | ||||
|     The schedule is updated after every train step by default. | ||||
|  | ||||
|     Example: | ||||
|  | ||||
|         .. code-block:: python | ||||
|  | ||||
|             LinearParamScheduler(start_value=0.0001, end_value=0.01) | ||||
|  | ||||
|     Corresponds to a linear increasing schedule with values in [0.0001, 0.01) | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         start_value: float, | ||||
|         end_value: float, | ||||
|     ) -> None: | ||||
|         self._start_value = start_value | ||||
|         self._end_value = end_value | ||||
|  | ||||
|     def __call__(self, where: float) -> float: | ||||
|         # interpolate between start and end values | ||||
|         return self._end_value * where + self._start_value * (1 - where) | ||||
|  | ||||
|  | ||||
| class MultiStepParamScheduler(ParamScheduler): | ||||
|     """ | ||||
|     Takes a predefined schedule for a param value, and a list of epochs or steps | ||||
|     which stand for the upper boundary (excluded) of each range. | ||||
|  | ||||
|     Example: | ||||
|  | ||||
|         .. code-block:: python | ||||
|  | ||||
|           MultiStepParamScheduler( | ||||
|             values=[0.1, 0.01, 0.001, 0.0001], | ||||
|             milestones=[30, 60, 80, 120] | ||||
|           ) | ||||
|  | ||||
|     Then the param value will be 0.1 for epochs 0-29, 0.01 for | ||||
|     epochs 30-59, 0.001 for epochs 60-79, 0.0001 for epochs 80-120. | ||||
|     Note that the length of values must be equal to the length of milestones | ||||
|     plus one. | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         values: List[float], | ||||
|         num_updates: Optional[int] = None, | ||||
|         milestones: Optional[List[int]] = None, | ||||
|     ) -> None: | ||||
|         """ | ||||
|         Args: | ||||
|             values: param value in each range | ||||
|             num_updates: the end of the last range. If None, will use ``milestones[-1]`` | ||||
|             milestones: the boundary of each range. If None, will evenly split ``num_updates`` | ||||
|  | ||||
|         For example, all the following combinations define the same scheduler: | ||||
|  | ||||
|         * num_updates=90, milestones=[30, 60], values=[1, 0.1, 0.01] | ||||
|         * num_updates=90, values=[1, 0.1, 0.01] | ||||
|         * milestones=[30, 60, 90], values=[1, 0.1, 0.01] | ||||
|         * milestones=[3, 6, 9], values=[1, 0.1, 0.01]  (ParamScheduler is scale-invariant) | ||||
|         """ | ||||
|         if num_updates is None and milestones is None: | ||||
|             raise ValueError("num_updates and milestones cannot both be None") | ||||
|         if milestones is None: | ||||
|             # Default equispaced drop_epochs behavior | ||||
|             milestones = [] | ||||
|             step_width = math.ceil(num_updates / float(len(values))) | ||||
|             for idx in range(len(values) - 1): | ||||
|                 milestones.append(step_width * (idx + 1)) | ||||
|         else: | ||||
|             if not ( | ||||
|                 isinstance(milestones, Sequence) | ||||
|                 and len(milestones) == len(values) - int(num_updates is not None) | ||||
|             ): | ||||
|                 raise ValueError( | ||||
|                     "MultiStep scheduler requires a list of %d milestones" | ||||
|                     % (len(values) - int(num_updates is not None)) | ||||
|                 ) | ||||
|  | ||||
|         if num_updates is None: | ||||
|             num_updates, milestones = milestones[-1], milestones[:-1] | ||||
|         if num_updates < len(values): | ||||
|             raise ValueError( | ||||
|                 "Total num_updates must be greater than length of param schedule" | ||||
|             ) | ||||
|  | ||||
|         return [ | ||||
|             self.eta_min | ||||
|             + (base_lr - self.eta_min) | ||||
|             * (1 + math.cos(math.pi * self.T_cur / self.T_i)) | ||||
|             / 2 | ||||
|             for base_lr in self.base_lrs | ||||
|         ] | ||||
|         self._param_schedule = values | ||||
|         self._num_updates = num_updates | ||||
|         self._milestones: List[int] = milestones | ||||
|  | ||||
|     def step(self, epoch=None): | ||||
|         """Step could be called after every batch update | ||||
|         Example: | ||||
|             >>> scheduler = CosineDecayWithWarmup(optimizer, T_0, T_mult) | ||||
|             >>> iters = len(dataloader) | ||||
|             >>> for epoch in range(20): | ||||
|             >>>     for i, sample in enumerate(dataloader): | ||||
|             >>>         inputs, labels = sample['inputs'], sample['labels'] | ||||
|             >>>         optimizer.zero_grad() | ||||
|             >>>         outputs = net(inputs) | ||||
|             >>>         loss = criterion(outputs, labels) | ||||
|             >>>         loss.backward() | ||||
|             >>>         optimizer.step() | ||||
|             >>>         scheduler.step(epoch + i / iters) | ||||
|         This function can be called in an interleaved way. | ||||
|         Example: | ||||
|             >>> scheduler = CosineDecayWithWarmup(optimizer, T_0, T_mult) | ||||
|             >>> for epoch in range(20): | ||||
|             >>>     scheduler.step() | ||||
|             >>> scheduler.step(26) | ||||
|             >>> scheduler.step() # scheduler.step(27), instead of scheduler(20) | ||||
|         """ | ||||
|  | ||||
|         if epoch is None and self.last_epoch < 0: | ||||
|             epoch = 0 | ||||
|  | ||||
|         if epoch is None: | ||||
|             epoch = self.last_epoch + 1 | ||||
|             self.T_cur = self.T_cur + 1 | ||||
|             if self.T_cur >= self.T_i: | ||||
|                 self.T_cur = self.T_cur - self.T_i | ||||
|                 self.T_i = self.T_i * self.T_mult | ||||
|         else: | ||||
|             if epoch < 0: | ||||
|         start_epoch = 0 | ||||
|         for milestone in self._milestones: | ||||
|             # Do not exceed the total number of epochs | ||||
|             if milestone >= self._num_updates: | ||||
|                 raise ValueError( | ||||
|                     "Expected non-negative epoch, but got {}".format(epoch) | ||||
|                     "Milestone must be smaller than total number of updates: " | ||||
|                     "num_updates=%d, milestone=%d" % (self._num_updates, milestone) | ||||
|                 ) | ||||
|             if epoch >= self.T_0: | ||||
|                 if self.T_mult == 1: | ||||
|                     self.T_cur = epoch % self.T_0 | ||||
|                 else: | ||||
|                     n = int( | ||||
|                         math.log( | ||||
|                             (epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult | ||||
|                         ) | ||||
|                     ) | ||||
|                     self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / ( | ||||
|                         self.T_mult - 1 | ||||
|                     ) | ||||
|                     self.T_i = self.T_0 * self.T_mult ** (n) | ||||
|             else: | ||||
|                 self.T_i = self.T_0 | ||||
|                 self.T_cur = epoch | ||||
|         self.last_epoch = math.floor(epoch) | ||||
|             # Must be in ascending order | ||||
|             if start_epoch >= milestone: | ||||
|                 raise ValueError( | ||||
|                     "Milestone must be smaller than start epoch: start_epoch=%d, milestone=%d" | ||||
|                     % (start_epoch, milestone) | ||||
|                 ) | ||||
|             start_epoch = milestone | ||||
|  | ||||
|         class _enable_get_lr_call: | ||||
|             def __init__(self, o): | ||||
|                 self.o = o | ||||
|     def __call__(self, where: float) -> float: | ||||
|         if where > 1.0: | ||||
|             raise RuntimeError( | ||||
|                 f"where in ParamScheduler must be in [0, 1]: got {where}" | ||||
|             ) | ||||
|         epoch_num = int((where + self.WHERE_EPSILON) * self._num_updates) | ||||
|         return self._param_schedule[bisect.bisect_right(self._milestones, epoch_num)] | ||||
|  | ||||
|             def __enter__(self): | ||||
|                 self.o._get_lr_called_within_step = True | ||||
|                 return self | ||||
|  | ||||
|             def __exit__(self, type, value, traceback): | ||||
|                 self.o._get_lr_called_within_step = False | ||||
|                 return self | ||||
| class PolynomialDecayParamScheduler(ParamScheduler): | ||||
|     """ | ||||
|     Decays the param value after every epoch according to a | ||||
|     polynomial function with a fixed power. | ||||
|     The schedule is updated after every train step by default. | ||||
|  | ||||
|         with _enable_get_lr_call(self): | ||||
|             for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())): | ||||
|                 param_group, lr = data | ||||
|                 param_group["lr"] = lr | ||||
|                 self.print_lr(self.verbose, i, lr, epoch) | ||||
|     Example: | ||||
|  | ||||
|         self._last_lr = [group["lr"] for group in self.optimizer.param_groups] | ||||
|         .. code-block:: python | ||||
|  | ||||
|           PolynomialDecayParamScheduler(base_value=0.1, power=0.9) | ||||
|  | ||||
|     Then the param value will be 0.1 for epoch 0, 0.099 for epoch 1, and | ||||
|     so on. | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         base_value: float, | ||||
|         power: float, | ||||
|     ) -> None: | ||||
|         self._base_value = base_value | ||||
|         self._power = power | ||||
|  | ||||
|     def __call__(self, where: float) -> float: | ||||
|         return self._base_value * (1 - where) ** self._power | ||||
|  | ||||
|  | ||||
| class StepParamScheduler(ParamScheduler): | ||||
|     """ | ||||
|     Takes a fixed schedule for a param value.  If the length of the | ||||
|     fixed schedule is less than the number of epochs, then the epochs | ||||
|     are divided evenly among the param schedule. | ||||
|     The schedule is updated after every train epoch by default. | ||||
|  | ||||
|     Example: | ||||
|  | ||||
|         .. code-block:: python | ||||
|  | ||||
|           StepParamScheduler(values=[0.1, 0.01, 0.001, 0.0001], num_updates=120) | ||||
|  | ||||
|     Then the param value will be 0.1 for epochs 0-29, 0.01 for | ||||
|     epochs 30-59, 0.001 for epoch 60-89, 0.0001 for epochs 90-119. | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         num_updates: Union[int, float], | ||||
|         values: List[float], | ||||
|     ) -> None: | ||||
|         if num_updates <= 0: | ||||
|             raise ValueError("Number of updates must be larger than 0") | ||||
|         if not (isinstance(values, Sequence) and len(values) > 0): | ||||
|             raise ValueError( | ||||
|                 "Step scheduler requires a list of at least one param value" | ||||
|             ) | ||||
|         self._param_schedule = values | ||||
|  | ||||
|     def __call__(self, where: float) -> float: | ||||
|         ind = int((where + self.WHERE_EPSILON) * len(self._param_schedule)) | ||||
|         return self._param_schedule[ind] | ||||
|  | ||||
|  | ||||
| class StepWithFixedGammaParamScheduler(ParamScheduler): | ||||
|     """ | ||||
|     Decays the param value by gamma at equal number of steps so as to have the | ||||
|     specified total number of decays. | ||||
|  | ||||
|     Example: | ||||
|  | ||||
|         .. code-block:: python | ||||
|  | ||||
|           StepWithFixedGammaParamScheduler( | ||||
|             base_value=0.1, gamma=0.1, num_decays=3, num_updates=120) | ||||
|  | ||||
|     Then the param value will be 0.1 for epochs 0-29, 0.01 for | ||||
|     epochs 30-59, 0.001 for epoch 60-89, 0.0001 for epochs 90-119. | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         base_value: float, | ||||
|         num_decays: int, | ||||
|         gamma: float, | ||||
|         num_updates: int, | ||||
|     ) -> None: | ||||
|         for k in [base_value, gamma]: | ||||
|             if not (isinstance(k, (int, float)) and k > 0): | ||||
|                 raise ValueError("base_value and gamma must be positive numbers") | ||||
|         for k in [num_decays, num_updates]: | ||||
|             if not (isinstance(k, int) and k > 0): | ||||
|                 raise ValueError("num_decays and num_updates must be positive integers") | ||||
|  | ||||
|         self.base_value = base_value | ||||
|         self.num_decays = num_decays | ||||
|         self.gamma = gamma | ||||
|         self.num_updates = num_updates | ||||
|         values = [base_value] | ||||
|         for _ in range(num_decays): | ||||
|             values.append(values[-1] * gamma) | ||||
|  | ||||
|         self._step_param_scheduler = StepParamScheduler( | ||||
|             num_updates=num_updates, values=values | ||||
|         ) | ||||
|  | ||||
|     def __call__(self, where: float) -> float: | ||||
|         return self._step_param_scheduler(where) | ||||
|  | ||||
|  | ||||
| class CompositeParamScheduler(ParamScheduler): | ||||
|     """ | ||||
|     Composite parameter scheduler composed of intermediate schedulers. | ||||
|     Takes a list of schedulers and a list of lengths corresponding to | ||||
|     percentage of training each scheduler should run for. Schedulers | ||||
|     are run in order. All values in lengths should sum to 1.0. | ||||
|  | ||||
|     Each scheduler also has a corresponding interval scale. If interval | ||||
|     scale is 'fixed', the intermediate scheduler will be run without any rescaling | ||||
|     of the time. If interval scale is 'rescaled', intermediate scheduler is | ||||
|     run such that each scheduler will start and end at the same values as it | ||||
|     would if it were the only scheduler. Default is 'rescaled' for all schedulers. | ||||
|  | ||||
|     Example: | ||||
|  | ||||
|         .. code-block:: python | ||||
|  | ||||
|               schedulers = [ | ||||
|                 ConstantParamScheduler(value=0.42), | ||||
|                 CosineParamScheduler(start_value=0.42, end_value=1e-4) | ||||
|               ] | ||||
|               CompositeParamScheduler( | ||||
|                 schedulers=schedulers, | ||||
|                 interval_scaling=['rescaled', 'rescaled'], | ||||
|                 lengths=[0.3, 0.7]) | ||||
|  | ||||
|     The parameter value will be 0.42 for the first [0%, 30%) of steps, | ||||
|     and then will cosine decay from 0.42 to 0.0001 for [30%, 100%) of | ||||
|     training. | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         schedulers: Sequence[ParamScheduler], | ||||
|         lengths: List[float], | ||||
|         interval_scaling: Sequence[str], | ||||
|     ) -> None: | ||||
|         if len(schedulers) != len(lengths): | ||||
|             raise ValueError("Schedulers and lengths must be same length") | ||||
|         if len(schedulers) == 0: | ||||
|             raise ValueError( | ||||
|                 "There must be at least one scheduler in the composite scheduler" | ||||
|             ) | ||||
|         if abs(sum(lengths) - 1.0) >= 1e-3: | ||||
|             raise ValueError("The sum of all values in lengths must be 1") | ||||
|         if sum(lengths) != 1.0: | ||||
|             lengths[-1] = 1.0 - sum(lengths[:-1]) | ||||
|         for s in interval_scaling: | ||||
|             if s not in ["rescaled", "fixed"]: | ||||
|                 raise ValueError(f"Unsupported interval_scaling: {s}") | ||||
|  | ||||
|         self._lengths = lengths | ||||
|         self._schedulers = schedulers | ||||
|         self._interval_scaling = interval_scaling | ||||
|  | ||||
|     def __call__(self, where: float) -> float: | ||||
|         # Find scheduler corresponding to where | ||||
|         i = 0 | ||||
|         running_total = self._lengths[i] | ||||
|         while (where + self.WHERE_EPSILON) > running_total and i < len( | ||||
|             self._schedulers | ||||
|         ) - 1: | ||||
|             i += 1 | ||||
|             running_total += self._lengths[i] | ||||
|         scheduler = self._schedulers[i] | ||||
|         scheduler_where = where | ||||
|         interval_scale = self._interval_scaling[i] | ||||
|         if interval_scale == "rescaled": | ||||
|             # Calculate corresponding where % for scheduler | ||||
|             scheduler_start = running_total - self._lengths[i] | ||||
|             scheduler_where = (where - scheduler_start) / self._lengths[i] | ||||
|         return scheduler(scheduler_where) | ||||
|  | ||||
|  | ||||
| class WarmupParamScheduler(CompositeParamScheduler): | ||||
|     """ | ||||
|     Add an initial warmup stage to another scheduler. | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         scheduler: ParamScheduler, | ||||
|         warmup_factor: float, | ||||
|         warmup_length: float, | ||||
|         warmup_method: str = "linear", | ||||
|     ): | ||||
|         """ | ||||
|         Args: | ||||
|             scheduler: warmup will be added at the beginning of this scheduler | ||||
|             warmup_factor: the factor w.r.t the initial value of ``scheduler``, e.g. 0.001 | ||||
|             warmup_length: the relative length (in [0, 1]) of warmup steps w.r.t the entire | ||||
|                 training, e.g. 0.01 | ||||
|             warmup_method: one of "linear" or "constant" | ||||
|         """ | ||||
|         end_value = scheduler(warmup_length)  # the value to reach when warmup ends | ||||
|         start_value = warmup_factor * scheduler(0.0) | ||||
|         if warmup_method == "constant": | ||||
|             warmup = ConstantParamScheduler(start_value) | ||||
|         elif warmup_method == "linear": | ||||
|             warmup = LinearParamScheduler(start_value, end_value) | ||||
|         else: | ||||
|             raise ValueError("Unknown warmup method: {}".format(warmup_method)) | ||||
|         super().__init__( | ||||
|             [warmup, scheduler], | ||||
|             interval_scaling=["rescaled", "fixed"], | ||||
|             lengths=[warmup_length, 1 - warmup_length], | ||||
|         ) | ||||
|  | ||||
|  | ||||
| ##### LR Scheduler | ||||
|  | ||||
|  | ||||
| class LRMultiplier(torch.optim.lr_scheduler._LRScheduler): | ||||
|     """ | ||||
|     A LRScheduler which uses fvcore :class:`ParamScheduler` to multiply the | ||||
|     learning rate of each param in the optimizer. | ||||
|     Every step, the learning rate of each parameter becomes its initial value | ||||
|     multiplied by the output of the given :class:`ParamScheduler`. | ||||
|     The absolute learning rate value of each parameter can be different. | ||||
|     This scheduler can be used as long as the relative scale among them does | ||||
|     not change during training. | ||||
|     Examples: | ||||
|     :: | ||||
|         LRMultiplier( | ||||
|             opt, | ||||
|             WarmupParamScheduler( | ||||
|                 MultiStepParamScheduler( | ||||
|                     [1, 0.1, 0.01], | ||||
|                     milestones=[60000, 80000], | ||||
|                     num_updates=90000, | ||||
|                 ), 0.001, 100 / 90000 | ||||
|             ), | ||||
|             max_iter=90000 | ||||
|         ) | ||||
|     """ | ||||
|  | ||||
|     # NOTES: in the most general case, every LR can use its own scheduler. | ||||
|     # Supporting this requires interaction with the optimizer when its parameter | ||||
|     # group is initialized. For example, classyvision implements its own optimizer | ||||
|     # that allows different schedulers for every parameter group. | ||||
|     # To avoid this complexity, we use this class to support the most common cases | ||||
|     # where the relative scale among all LRs stay unchanged during training.  In this | ||||
|     # case we only need a total of one scheduler that defines the relative LR multiplier. | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         optimizer: torch.optim.Optimizer, | ||||
|         multiplier: ParamScheduler, | ||||
|         max_iter: int, | ||||
|         last_iter: int = -1, | ||||
|     ): | ||||
|         """ | ||||
|         Args: | ||||
|             optimizer, last_iter: See ``torch.optim.lr_scheduler._LRScheduler``. | ||||
|                 ``last_iter`` is the same as ``last_epoch``. | ||||
|             multiplier: a fvcore ParamScheduler that defines the multiplier on | ||||
|                 every LR of the optimizer | ||||
|             max_iter: the total number of training iterations | ||||
|         """ | ||||
|         if not isinstance(multiplier, ParamScheduler): | ||||
|             raise ValueError( | ||||
|                 "LRMultiplier(multiplier=) must be an instance of " | ||||
|                 f"ParamScheduler. Got {multiplier} instead." | ||||
|             ) | ||||
|         self._multiplier = multiplier | ||||
|         self._max_iter = max_iter | ||||
|         super().__init__(optimizer, last_epoch=last_iter) | ||||
|  | ||||
|     def state_dict(self): | ||||
|         # fvcore schedulers are stateless. Only keep pytorch scheduler states | ||||
|         return {"base_lrs": self.base_lrs, "last_epoch": self.last_epoch} | ||||
|  | ||||
|     def get_lr(self) -> List[float]: | ||||
|         multiplier = self._multiplier(self.last_epoch / self._max_iter) | ||||
|         return [base_lr * multiplier for base_lr in self.base_lrs] | ||||
|   | ||||
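All of the schedulers above are plain callables that map a progress fraction in [0, 1) to a value; only `LRMultiplier` touches a torch optimizer. A short sketch composing a warmup with a multi-step schedule, mirroring the test file earlier in this commit (the milestone values below are illustrative):

# Sketch: compose and probe schedulers directly, without an optimizer.
from xautodl.xmisc.scheduler_utils import (
    MultiStepParamScheduler,
    WarmupParamScheduler,
)

multistep = MultiStepParamScheduler(
    values=[1, 0.1, 0.01], milestones=[30, 60], num_updates=90
)
warmed = WarmupParamScheduler(multistep, warmup_factor=0.001, warmup_length=0.05)

for where in (0.0, 0.02, 0.1, 0.4, 0.8):
    print(where, warmed(where))
# Linear warmup over the first 5% of training, then 1 until 1/3 of training,
# 0.1 until 2/3, and 0.01 afterwards.
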
| @@ -5,6 +5,3 @@ | ||||
| ##################################################### | ||||
|  | ||||
| from .transformers import get_transformer | ||||
|  | ||||
| def obtain_model(config): | ||||
|   raise NotImplementedError | ||||
|   | ||||