2020-02-23 00:30:37 +01:00
|
|
|
#####################################################
|
|
|
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
|
|
|
#####################################################
|
2019-09-28 10:24:47 +02:00
|
|
|
import math, torch
|
|
|
|
import torch.nn as nn
|
|
|
|
from bisect import bisect_right
|
|
|
|
from torch.optim import Optimizer
|
|
|
|
|
|
|
|
|
|
|
|
class _LRScheduler(object):
|
2021-03-07 04:09:47 +01:00
|
|
|
def __init__(self, optimizer, warmup_epochs, epochs):
|
|
|
|
if not isinstance(optimizer, Optimizer):
|
|
|
|
raise TypeError("{:} is not an Optimizer".format(type(optimizer).__name__))
|
|
|
|
self.optimizer = optimizer
|
|
|
|
for group in optimizer.param_groups:
|
|
|
|
group.setdefault("initial_lr", group["lr"])
|
2021-03-19 16:57:23 +01:00
|
|
|
self.base_lrs = list(
|
|
|
|
map(lambda group: group["initial_lr"], optimizer.param_groups)
|
|
|
|
)
|
2021-03-07 04:09:47 +01:00
|
|
|
self.max_epochs = epochs
|
|
|
|
self.warmup_epochs = warmup_epochs
|
|
|
|
self.current_epoch = 0
|
|
|
|
self.current_iter = 0
|
|
|
|
|
|
|
|
def extra_repr(self):
|
|
|
|
return ""
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
return "{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}".format(
|
|
|
|
name=self.__class__.__name__, **self.__dict__
|
|
|
|
) + ", {:})".format(
|
|
|
|
self.extra_repr()
|
|
|
|
)
|
|
|
|
|
|
|
|
def state_dict(self):
|
2021-03-19 16:57:23 +01:00
|
|
|
return {
|
|
|
|
key: value for key, value in self.__dict__.items() if key != "optimizer"
|
|
|
|
}
|
2021-03-07 04:09:47 +01:00
|
|
|
|
|
|
|
def load_state_dict(self, state_dict):
|
|
|
|
self.__dict__.update(state_dict)
|
|
|
|
|
|
|
|
def get_lr(self):
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
def get_min_info(self):
|
|
|
|
lrs = self.get_lr()
|
|
|
|
return "#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#".format(
|
|
|
|
min(lrs), max(lrs), self.current_epoch, self.current_iter
|
|
|
|
)
|
|
|
|
|
|
|
|
def get_min_lr(self):
|
|
|
|
return min(self.get_lr())
|
|
|
|
|
|
|
|
def update(self, cur_epoch, cur_iter):
|
|
|
|
if cur_epoch is not None:
|
2021-03-19 16:57:23 +01:00
|
|
|
assert (
|
|
|
|
isinstance(cur_epoch, int) and cur_epoch >= 0
|
|
|
|
), "invalid cur-epoch : {:}".format(cur_epoch)
|
2021-03-07 04:09:47 +01:00
|
|
|
self.current_epoch = cur_epoch
|
|
|
|
if cur_iter is not None:
|
2021-03-19 16:57:23 +01:00
|
|
|
assert (
|
|
|
|
isinstance(cur_iter, float) and cur_iter >= 0
|
|
|
|
), "invalid cur-iter : {:}".format(cur_iter)
|
2021-03-07 04:09:47 +01:00
|
|
|
self.current_iter = cur_iter
|
|
|
|
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
|
|
|
|
param_group["lr"] = lr
|
2019-09-28 10:24:47 +02:00
|
|
|
|
|
|
|
|
|
|
|
class CosineAnnealingLR(_LRScheduler):
|
2021-03-07 04:09:47 +01:00
|
|
|
def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min):
|
|
|
|
self.T_max = T_max
|
|
|
|
self.eta_min = eta_min
|
|
|
|
super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs)
|
|
|
|
|
|
|
|
def extra_repr(self):
|
2021-03-19 16:57:23 +01:00
|
|
|
return "type={:}, T-max={:}, eta-min={:}".format(
|
|
|
|
"cosine", self.T_max, self.eta_min
|
|
|
|
)
|
2021-03-07 04:09:47 +01:00
|
|
|
|
|
|
|
def get_lr(self):
|
|
|
|
lrs = []
|
|
|
|
for base_lr in self.base_lrs:
|
2021-03-19 16:57:23 +01:00
|
|
|
if (
|
|
|
|
self.current_epoch >= self.warmup_epochs
|
|
|
|
and self.current_epoch < self.max_epochs
|
|
|
|
):
|
2021-03-07 04:09:47 +01:00
|
|
|
last_epoch = self.current_epoch - self.warmup_epochs
|
|
|
|
# if last_epoch < self.T_max:
|
|
|
|
# if last_epoch < self.max_epochs:
|
2021-03-19 16:57:23 +01:00
|
|
|
lr = (
|
|
|
|
self.eta_min
|
|
|
|
+ (base_lr - self.eta_min)
|
|
|
|
* (1 + math.cos(math.pi * last_epoch / self.T_max))
|
|
|
|
/ 2
|
|
|
|
)
|
2021-03-07 04:09:47 +01:00
|
|
|
# else:
|
|
|
|
# lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2
|
|
|
|
elif self.current_epoch >= self.max_epochs:
|
|
|
|
lr = self.eta_min
|
|
|
|
else:
|
2021-03-19 16:57:23 +01:00
|
|
|
lr = (
|
|
|
|
self.current_epoch / self.warmup_epochs
|
|
|
|
+ self.current_iter / self.warmup_epochs
|
|
|
|
) * base_lr
|
2021-03-07 04:09:47 +01:00
|
|
|
lrs.append(lr)
|
|
|
|
return lrs
|
2019-09-28 10:24:47 +02:00
|
|
|
|
|
|
|
|
|
|
|
class MultiStepLR(_LRScheduler):
|
2021-03-07 04:09:47 +01:00
|
|
|
def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas):
|
2021-03-19 16:57:23 +01:00
|
|
|
assert len(milestones) == len(gammas), "invalid {:} vs {:}".format(
|
|
|
|
len(milestones), len(gammas)
|
|
|
|
)
|
2021-03-07 04:09:47 +01:00
|
|
|
self.milestones = milestones
|
|
|
|
self.gammas = gammas
|
|
|
|
super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs)
|
|
|
|
|
|
|
|
def extra_repr(self):
|
|
|
|
return "type={:}, milestones={:}, gammas={:}, base-lrs={:}".format(
|
|
|
|
"multistep", self.milestones, self.gammas, self.base_lrs
|
|
|
|
)
|
|
|
|
|
|
|
|
def get_lr(self):
|
|
|
|
lrs = []
|
|
|
|
for base_lr in self.base_lrs:
|
|
|
|
if self.current_epoch >= self.warmup_epochs:
|
|
|
|
last_epoch = self.current_epoch - self.warmup_epochs
|
|
|
|
idx = bisect_right(self.milestones, last_epoch)
|
|
|
|
lr = base_lr
|
|
|
|
for x in self.gammas[:idx]:
|
|
|
|
lr *= x
|
|
|
|
else:
|
2021-03-19 16:57:23 +01:00
|
|
|
lr = (
|
|
|
|
self.current_epoch / self.warmup_epochs
|
|
|
|
+ self.current_iter / self.warmup_epochs
|
|
|
|
) * base_lr
|
2021-03-07 04:09:47 +01:00
|
|
|
lrs.append(lr)
|
|
|
|
return lrs
|
2019-09-28 10:24:47 +02:00
|
|
|
|
|
|
|
|
|
|
|
class ExponentialLR(_LRScheduler):
|
2021-03-07 04:09:47 +01:00
|
|
|
def __init__(self, optimizer, warmup_epochs, epochs, gamma):
|
|
|
|
self.gamma = gamma
|
|
|
|
super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs)
|
|
|
|
|
|
|
|
def extra_repr(self):
|
2021-03-19 16:57:23 +01:00
|
|
|
return "type={:}, gamma={:}, base-lrs={:}".format(
|
|
|
|
"exponential", self.gamma, self.base_lrs
|
|
|
|
)
|
2021-03-07 04:09:47 +01:00
|
|
|
|
|
|
|
def get_lr(self):
|
|
|
|
lrs = []
|
|
|
|
for base_lr in self.base_lrs:
|
|
|
|
if self.current_epoch >= self.warmup_epochs:
|
|
|
|
last_epoch = self.current_epoch - self.warmup_epochs
|
|
|
|
assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
|
2022-03-21 07:18:23 +01:00
|
|
|
lr = base_lr * (self.gamma**last_epoch)
|
2021-03-07 04:09:47 +01:00
|
|
|
else:
|
2021-03-19 16:57:23 +01:00
|
|
|
lr = (
|
|
|
|
self.current_epoch / self.warmup_epochs
|
|
|
|
+ self.current_iter / self.warmup_epochs
|
|
|
|
) * base_lr
|
2021-03-07 04:09:47 +01:00
|
|
|
lrs.append(lr)
|
|
|
|
return lrs
|
2019-09-28 10:24:47 +02:00
|
|
|
|
|
|
|
|
|
|
|
class LinearLR(_LRScheduler):
|
2021-03-07 04:09:47 +01:00
|
|
|
def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR):
|
|
|
|
self.max_LR = max_LR
|
|
|
|
self.min_LR = min_LR
|
|
|
|
super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs)
|
|
|
|
|
|
|
|
def extra_repr(self):
|
|
|
|
return "type={:}, max_LR={:}, min_LR={:}, base-lrs={:}".format(
|
|
|
|
"LinearLR", self.max_LR, self.min_LR, self.base_lrs
|
|
|
|
)
|
|
|
|
|
|
|
|
def get_lr(self):
|
|
|
|
lrs = []
|
|
|
|
for base_lr in self.base_lrs:
|
|
|
|
if self.current_epoch >= self.warmup_epochs:
|
|
|
|
last_epoch = self.current_epoch - self.warmup_epochs
|
|
|
|
assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
|
2021-03-19 16:57:23 +01:00
|
|
|
ratio = (
|
|
|
|
(self.max_LR - self.min_LR)
|
|
|
|
* last_epoch
|
|
|
|
/ self.max_epochs
|
|
|
|
/ self.max_LR
|
|
|
|
)
|
2021-03-07 04:09:47 +01:00
|
|
|
lr = base_lr * (1 - ratio)
|
|
|
|
else:
|
2021-03-19 16:57:23 +01:00
|
|
|
lr = (
|
|
|
|
self.current_epoch / self.warmup_epochs
|
|
|
|
+ self.current_iter / self.warmup_epochs
|
|
|
|
) * base_lr
|
2021-03-07 04:09:47 +01:00
|
|
|
lrs.append(lr)
|
|
|
|
return lrs
|
2019-09-28 10:24:47 +02:00
|
|
|
|
|
|
|
|
|
|
|
class CrossEntropyLabelSmooth(nn.Module):
|
2021-03-07 04:09:47 +01:00
|
|
|
def __init__(self, num_classes, epsilon):
|
|
|
|
super(CrossEntropyLabelSmooth, self).__init__()
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.epsilon = epsilon
|
|
|
|
self.logsoftmax = nn.LogSoftmax(dim=1)
|
2019-09-28 10:24:47 +02:00
|
|
|
|
2021-03-07 04:09:47 +01:00
|
|
|
def forward(self, inputs, targets):
|
|
|
|
log_probs = self.logsoftmax(inputs)
|
|
|
|
targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
|
|
|
|
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
|
|
|
|
loss = (-targets * log_probs).mean(0).sum()
|
|
|
|
return loss
|
2019-09-28 10:24:47 +02:00
|
|
|
|
|
|
|
|
|
|
|
def get_optim_scheduler(parameters, config):
|
2021-03-07 04:09:47 +01:00
|
|
|
assert (
|
2021-03-19 16:57:23 +01:00
|
|
|
hasattr(config, "optim")
|
|
|
|
and hasattr(config, "scheduler")
|
|
|
|
and hasattr(config, "criterion")
|
|
|
|
), "config must have optim / scheduler / criterion keys instead of {:}".format(
|
|
|
|
config
|
|
|
|
)
|
2021-03-07 04:09:47 +01:00
|
|
|
if config.optim == "SGD":
|
|
|
|
optim = torch.optim.SGD(
|
2021-03-19 16:57:23 +01:00
|
|
|
parameters,
|
|
|
|
config.LR,
|
|
|
|
momentum=config.momentum,
|
|
|
|
weight_decay=config.decay,
|
|
|
|
nesterov=config.nesterov,
|
2021-03-07 04:09:47 +01:00
|
|
|
)
|
|
|
|
elif config.optim == "RMSprop":
|
2021-03-19 16:57:23 +01:00
|
|
|
optim = torch.optim.RMSprop(
|
|
|
|
parameters, config.LR, momentum=config.momentum, weight_decay=config.decay
|
|
|
|
)
|
2021-03-07 04:09:47 +01:00
|
|
|
else:
|
|
|
|
raise ValueError("invalid optim : {:}".format(config.optim))
|
|
|
|
|
|
|
|
if config.scheduler == "cos":
|
|
|
|
T_max = getattr(config, "T_max", config.epochs)
|
2021-03-19 16:57:23 +01:00
|
|
|
scheduler = CosineAnnealingLR(
|
|
|
|
optim, config.warmup, config.epochs, T_max, config.eta_min
|
|
|
|
)
|
2021-03-07 04:09:47 +01:00
|
|
|
elif config.scheduler == "multistep":
|
2021-03-19 16:57:23 +01:00
|
|
|
scheduler = MultiStepLR(
|
|
|
|
optim, config.warmup, config.epochs, config.milestones, config.gammas
|
|
|
|
)
|
2021-03-07 04:09:47 +01:00
|
|
|
elif config.scheduler == "exponential":
|
|
|
|
scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma)
|
|
|
|
elif config.scheduler == "linear":
|
2021-03-19 16:57:23 +01:00
|
|
|
scheduler = LinearLR(
|
|
|
|
optim, config.warmup, config.epochs, config.LR, config.LR_min
|
|
|
|
)
|
2021-03-07 04:09:47 +01:00
|
|
|
else:
|
|
|
|
raise ValueError("invalid scheduler : {:}".format(config.scheduler))
|
|
|
|
|
|
|
|
if config.criterion == "Softmax":
|
|
|
|
criterion = torch.nn.CrossEntropyLoss()
|
|
|
|
elif config.criterion == "SmoothSoftmax":
|
|
|
|
criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth)
|
|
|
|
else:
|
|
|
|
raise ValueError("invalid criterion : {:}".format(config.criterion))
|
|
|
|
return optim, scheduler, criterion
|