#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
"""Warmup-aware learning-rate schedulers, a label-smoothing loss and a factory."""
import math
from bisect import bisect_right

import torch
import torch.nn as nn
from torch.optim import Optimizer


class _LRScheduler:
    """Base class for schedulers that are driven by explicit `update` calls.

    The scheduler keeps the current epoch and iteration itself; `update`
    stores them and writes the rates returned by `get_lr` into the
    optimizer's parameter groups. Subclasses implement `get_lr`.
    """

    def __init__(self, optimizer, warmup_epochs, epochs):
        """Record the base learning rates of `optimizer` and reset the position.

        Args:
            optimizer: a `torch.optim.Optimizer` whose groups will be updated.
            warmup_epochs: number of epochs handled by the warmup formula.
            epochs: stored as `max_epochs` for use by subclasses.

        Raises:
            TypeError: if `optimizer` is not an `Optimizer`.
        """
        if not isinstance(optimizer, Optimizer):
            raise TypeError("{:} is not an Optimizer".format(type(optimizer).__name__))
        self.optimizer = optimizer
        # Keep an existing "initial_lr" so a resumed optimizer keeps its base rate.
        for group in optimizer.param_groups:
            group.setdefault("initial_lr", group["lr"])
        self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups]
        self.max_epochs = epochs
        self.warmup_epochs = warmup_epochs
        self.current_epoch = 0
        self.current_iter = 0

    def extra_repr(self):
        """Return scheduler-specific text appended to `repr`; empty by default."""
        return ""

    def __repr__(self):
        return "{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}".format(
            name=self.__class__.__name__, **self.__dict__
        ) + ", {:})".format(
            self.extra_repr()
        )

    def state_dict(self):
        """Return every instance attribute except the optimizer."""
        return {
            key: value for key, value in self.__dict__.items() if key != "optimizer"
        }

    def load_state_dict(self, state_dict):
        """Overwrite instance attributes with the entries of `state_dict`."""
        self.__dict__.update(state_dict)

    def get_lr(self):
        """Return one learning rate per parameter group (subclass hook)."""
        raise NotImplementedError

    def get_min_info(self):
        """Return a string with the min/max rate and the current position."""
        lrs = self.get_lr()
        return "#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#".format(
            min(lrs), max(lrs), self.current_epoch, self.current_iter
        )

    def get_min_lr(self):
        """Return the smallest rate among all parameter groups."""
        return min(self.get_lr())

    def update(self, cur_epoch, cur_iter):
        """Store the new position and push the resulting rates to the optimizer.

        Args:
            cur_epoch: non-negative int, or None to keep the stored epoch.
            cur_iter: non-negative float, or None to keep the stored value.
                During warmup it is added to the epoch count, so it is
                measured in epochs.

        Raises:
            AssertionError: if a non-None argument has the wrong type or sign.
        """
        if cur_epoch is not None:
            assert (
                isinstance(cur_epoch, int) and cur_epoch >= 0
            ), "invalid cur-epoch : {:}".format(cur_epoch)
            self.current_epoch = cur_epoch
        if cur_iter is not None:
            assert (
                isinstance(cur_iter, float) and cur_iter >= 0
            ), "invalid cur-iter : {:}".format(cur_iter)
            self.current_iter = cur_iter
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group["lr"] = lr

    def _warmup_lr(self, base_lr):
        """Return the warmup rate: `base_lr` scaled by (epoch + iter) / warmup."""
        progress = (
            self.current_epoch / self.warmup_epochs
            + self.current_iter / self.warmup_epochs
        )
        return progress * base_lr


class CosineAnnealingLR(_LRScheduler):
    """Warmup, then cosine decay from the base rate towards `eta_min`."""

    def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min):
        self.T_max = T_max
        self.eta_min = eta_min
        super().__init__(optimizer, warmup_epochs, epochs)

    def extra_repr(self):
        return "type={:}, T-max={:}, eta-min={:}".format(
            "cosine", self.T_max, self.eta_min
        )

    def get_lr(self):
        """Return the per-group rates for the stored epoch and iteration.

        Once `current_epoch` reaches `max_epochs` the rate is `eta_min`;
        between `warmup_epochs` and `max_epochs` it follows the cosine curve
        with period parameter `T_max`; before that it is the warmup rate.
        """
        lrs = []
        for base_lr in self.base_lrs:
            if self.current_epoch >= self.max_epochs:
                lr = self.eta_min
            elif self.current_epoch >= self.warmup_epochs:
                last_epoch = self.current_epoch - self.warmup_epochs
                lr = (
                    self.eta_min
                    + (base_lr - self.eta_min)
                    * (1 + math.cos(math.pi * last_epoch / self.T_max))
                    / 2
                )
            else:
                lr = self._warmup_lr(base_lr)
            lrs.append(lr)
        return lrs


class MultiStepLR(_LRScheduler):
    """Warmup, then multiply the base rate by a gamma at each milestone."""

    def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas):
        assert len(milestones) == len(gammas), "invalid {:} vs {:}".format(
            len(milestones), len(gammas)
        )
        self.milestones = milestones
        self.gammas = gammas
        super().__init__(optimizer, warmup_epochs, epochs)

    def extra_repr(self):
        return "type={:}, milestones={:}, gammas={:}, base-lrs={:}".format(
            "multistep", self.milestones, self.gammas, self.base_lrs
        )

    def get_lr(self):
        """Return the per-group rates for the stored epoch and iteration.

        Milestones are compared against the epoch count after warmup; every
        milestone that is less than or equal to it contributes its gamma.
        """
        lrs = []
        for base_lr in self.base_lrs:
            if self.current_epoch >= self.warmup_epochs:
                last_epoch = self.current_epoch - self.warmup_epochs
                # bisect_right assumes `milestones` is sorted in ascending order.
                idx = bisect_right(self.milestones, last_epoch)
                lr = base_lr
                for gamma in self.gammas[:idx]:
                    lr *= gamma
            else:
                lr = self._warmup_lr(base_lr)
            lrs.append(lr)
        return lrs


class ExponentialLR(_LRScheduler):
    """Warmup, then multiply the base rate by `gamma` once per epoch."""

    def __init__(self, optimizer, warmup_epochs, epochs, gamma):
        self.gamma = gamma
        super().__init__(optimizer, warmup_epochs, epochs)

    def extra_repr(self):
        return "type={:}, gamma={:}, base-lrs={:}".format(
            "exponential", self.gamma, self.base_lrs
        )

    def get_lr(self):
        """Return `base_lr * gamma ** (epoch - warmup)` after warmup."""
        lrs = []
        for base_lr in self.base_lrs:
            if self.current_epoch >= self.warmup_epochs:
                last_epoch = self.current_epoch - self.warmup_epochs
                assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
                lr = base_lr * (self.gamma**last_epoch)
            else:
                lr = self._warmup_lr(base_lr)
            lrs.append(lr)
        return lrs


class LinearLR(_LRScheduler):
    """Warmup, then decrease the base rate linearly with the epoch count."""

    def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR):
        self.max_LR = max_LR
        self.min_LR = min_LR
        super().__init__(optimizer, warmup_epochs, epochs)

    def extra_repr(self):
        return "type={:}, max_LR={:}, min_LR={:}, base-lrs={:}".format(
            "LinearLR", self.max_LR, self.min_LR, self.base_lrs
        )

    def get_lr(self):
        """Return `base_lr * (1 - ratio)` after warmup.

        `ratio` is the relative drop `(max_LR - min_LR) / max_LR` scaled by
        the fraction `(epoch - warmup) / max_epochs`.
        """
        lrs = []
        for base_lr in self.base_lrs:
            if self.current_epoch >= self.warmup_epochs:
                last_epoch = self.current_epoch - self.warmup_epochs
                assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
                ratio = (
                    (self.max_LR - self.min_LR)
                    * last_epoch
                    / self.max_epochs
                    / self.max_LR
                )
                lr = base_lr * (1 - ratio)
            else:
                lr = self._warmup_lr(base_lr)
            lrs.append(lr)
        return lrs


class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy against targets smoothed towards the uniform distribution."""

    def __init__(self, num_classes, epsilon):
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """Return the smoothed loss, averaged over dim 0 and summed over classes.

        `inputs` is passed through a log-softmax over dim 1; `targets` holds
        class indices and is expanded to one-hot rows with `scatter_`.
        """
        log_probs = self.logsoftmax(inputs)
        one_hot = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        loss = (-smoothed * log_probs).mean(0).sum()
        return loss


def get_optim_scheduler(parameters, config):
    """Build the optimizer, scheduler and criterion named by `config`.

    Args:
        parameters: passed to the optimizer constructor.
        config: object with `optim`, `scheduler` and `criterion` attributes
            plus the hyper-parameters read by the selected branches.

    Returns:
        A tuple `(optim, scheduler, criterion)`.

    Raises:
        AssertionError: if one of the three selector attributes is missing.
        ValueError: if a selector names an unsupported choice.
    """
    assert (
        hasattr(config, "optim")
        and hasattr(config, "scheduler")
        and hasattr(config, "criterion")
    ), "config must have optim / scheduler / criterion keys instead of {:}".format(
        config
    )
    if config.optim == "SGD":
        optim = torch.optim.SGD(
            parameters,
            config.LR,
            momentum=config.momentum,
            weight_decay=config.decay,
            nesterov=config.nesterov,
        )
    elif config.optim == "RMSprop":
        optim = torch.optim.RMSprop(
            parameters, config.LR, momentum=config.momentum, weight_decay=config.decay
        )
    else:
        raise ValueError("invalid optim : {:}".format(config.optim))

    if config.scheduler == "cos":
        # T_max falls back to the total number of epochs when not configured.
        T_max = getattr(config, "T_max", config.epochs)
        scheduler = CosineAnnealingLR(
            optim, config.warmup, config.epochs, T_max, config.eta_min
        )
    elif config.scheduler == "multistep":
        scheduler = MultiStepLR(
            optim, config.warmup, config.epochs, config.milestones, config.gammas
        )
    elif config.scheduler == "exponential":
        scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma)
    elif config.scheduler == "linear":
        scheduler = LinearLR(
            optim, config.warmup, config.epochs, config.LR, config.LR_min
        )
    else:
        raise ValueError("invalid scheduler : {:}".format(config.scheduler))

    if config.criterion == "Softmax":
        criterion = torch.nn.CrossEntropyLoss()
    elif config.criterion == "SmoothSoftmax":
        criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth)
    else:
        raise ValueError("invalid criterion : {:}".format(config.criterion))
    return optim, scheduler, criterion