#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import math, torch
import torch.nn as nn
from bisect import bisect_right
from torch.optim import Optimizer
class _LRScheduler(object):
  """Base scheduler: tracks (epoch, iter) progress and writes the computed LR into each param group."""

  def __init__(self, optimizer, warmup_epochs, epochs):
    if not isinstance(optimizer, Optimizer):
      raise TypeError('{:} is not an Optimizer'.format(type(optimizer).__name__))
    self.optimizer = optimizer
    for group in optimizer.param_groups:
      group.setdefault('initial_lr', group['lr'])
    self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
    self.max_epochs = epochs
    self.warmup_epochs = warmup_epochs
    self.current_epoch = 0
    self.current_iter = 0

  def extra_repr(self):
    return ''

  def __repr__(self):
    return ('{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}'.format(name=self.__class__.__name__, **self.__dict__)
            + ', {:})'.format(self.extra_repr()))

  def state_dict(self):
    # exclude the optimizer itself so the scheduler state can be serialized on its own
    return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

  def load_state_dict(self, state_dict):
    self.__dict__.update(state_dict)

  def get_lr(self):
    raise NotImplementedError

  def get_min_info(self):
    lrs = self.get_lr()
    return '#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#'.format(min(lrs), max(lrs), self.current_epoch, self.current_iter)

  def get_min_lr(self):
    return min(self.get_lr())

  def update(self, cur_epoch, cur_iter):
    if cur_epoch is not None:
      assert isinstance(cur_epoch, int) and cur_epoch >= 0, 'invalid cur-epoch : {:}'.format(cur_epoch)
      self.current_epoch = cur_epoch
    if cur_iter is not None:
      assert isinstance(cur_iter, float) and cur_iter >= 0, 'invalid cur-iter : {:}'.format(cur_iter)
      self.current_iter = cur_iter
    # push the freshly computed learning rates into the optimizer
    for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
      param_group['lr'] = lr
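
# Usage sketch (an assumption about the calling convention, inferred from the warmup
# formulas below rather than stated in this file): `cur_epoch` is an int and `cur_iter`
# is the fractional progress within the current epoch in [0, 1), e.g.
#   scheduler.update(epoch, 0.0)                         # at each epoch boundary
#   scheduler.update(None, float(step) / num_batches)    # before each mini-batch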
class CosineAnnealingLR(_LRScheduler):

  def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min):
    self.T_max = T_max
    self.eta_min = eta_min
    super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs)

  def extra_repr(self):
    return 'type={:}, T-max={:}, eta-min={:}'.format('cosine', self.T_max, self.eta_min)

  def get_lr(self):
    lrs = []
    for base_lr in self.base_lrs:
      if self.current_epoch >= self.warmup_epochs and self.current_epoch < self.max_epochs:
        # cosine annealing over the epochs after warmup
        last_epoch = self.current_epoch - self.warmup_epochs
        # if last_epoch < self.T_max:
        # if last_epoch < self.max_epochs:
        lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * last_epoch / self.T_max)) / 2
        # else:
        #   lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max - 1.0) / self.T_max)) / 2
      elif self.current_epoch >= self.max_epochs:
        lr = self.eta_min
      else:
        # linear warmup from 0 to base_lr over the first warmup_epochs epochs
        lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
      lrs.append(lr)
    return lrs
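
# Worked example (illustrative numbers, not defaults of this module): assuming
# warmup=5, epochs=100, T_max=95, base_lr=0.1, eta_min=0.0:
#   epoch 2, iter 0.5  -> warmup:  (2/5 + 0.5/5) * 0.1 = 0.05
#   epoch 5            -> cosine:  0.0 + 0.1 * (1 + cos(pi * 0/95)) / 2 = 0.1
#   epoch 52           -> cosine:  0.1 * (1 + cos(pi * 47/95)) / 2 ~= 0.051
#   epoch >= 100       -> eta_min = 0.0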
class MultiStepLR(_LRScheduler):

  def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas):
    assert len(milestones) == len(gammas), 'invalid {:} vs {:}'.format(len(milestones), len(gammas))
    self.milestones = milestones
    self.gammas = gammas
    super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs)

  def extra_repr(self):
    return 'type={:}, milestones={:}, gammas={:}, base-lrs={:}'.format('multistep', self.milestones, self.gammas, self.base_lrs)

  def get_lr(self):
    lrs = []
    for base_lr in self.base_lrs:
      if self.current_epoch >= self.warmup_epochs:
        last_epoch = self.current_epoch - self.warmup_epochs
        # count how many milestones have been passed, then apply their decay factors in turn
        idx = bisect_right(self.milestones, last_epoch)
        lr = base_lr
        for x in self.gammas[:idx]:
          lr *= x
      else:
        lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
      lrs.append(lr)
    return lrs
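
# Worked example (illustrative numbers): with warmup=0, milestones=[30, 60] and
# gammas=[0.1, 0.1], a base LR of 0.1 stays at 0.1 for epochs 0-29, drops to 0.01
# at epoch 30, and to 0.001 at epoch 60 (bisect_right counts the milestones passed).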
class ExponentialLR(_LRScheduler):

  def __init__(self, optimizer, warmup_epochs, epochs, gamma):
    self.gamma = gamma
    super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs)

  def extra_repr(self):
    return 'type={:}, gamma={:}, base-lrs={:}'.format('exponential', self.gamma, self.base_lrs)

  def get_lr(self):
    lrs = []
    for base_lr in self.base_lrs:
      if self.current_epoch >= self.warmup_epochs:
        last_epoch = self.current_epoch - self.warmup_epochs
        assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch)
        lr = base_lr * (self.gamma ** last_epoch)
      else:
        lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
      lrs.append(lr)
    return lrs
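
# Worked example (illustrative numbers): with warmup=0 and gamma=0.97, the LR after
# 10 epochs is base_lr * 0.97**10 ~= 0.737 * base_lr, i.e. multiplicative decay per epoch.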
class LinearLR(_LRScheduler):

  def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR):
    self.max_LR = max_LR
    self.min_LR = min_LR
    super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs)

  def extra_repr(self):
    return 'type={:}, max_LR={:}, min_LR={:}, base-lrs={:}'.format('LinearLR', self.max_LR, self.min_LR, self.base_lrs)

  def get_lr(self):
    lrs = []
    for base_lr in self.base_lrs:
      if self.current_epoch >= self.warmup_epochs:
        last_epoch = self.current_epoch - self.warmup_epochs
        assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch)
        # linear decay from base_lr towards base_lr * min_LR / max_LR across max_epochs epochs
        ratio = (self.max_LR - self.min_LR) * last_epoch / self.max_epochs / self.max_LR
        lr = base_lr * (1 - ratio)
      else:
        lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
      lrs.append(lr)
    return lrs
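
# Worked example (illustrative numbers): with warmup=0, epochs=100, max_LR=0.1 and
# min_LR=0.001, a base LR of 0.1 decays linearly; at epoch 50 the ratio is
# (0.1 - 0.001) * 50 / 100 / 0.1 = 0.495, so lr = 0.1 * (1 - 0.495) = 0.0505, and at
# epoch 100 it would reach 0.1 * min_LR / max_LR = 0.001.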
class CrossEntropyLabelSmooth(nn.Module):

  def __init__(self, num_classes, epsilon):
    super(CrossEntropyLabelSmooth, self).__init__()
    self.num_classes = num_classes
    self.epsilon = epsilon
    self.logsoftmax = nn.LogSoftmax(dim=1)

  def forward(self, inputs, targets):
    log_probs = self.logsoftmax(inputs)
    # one-hot encode the targets, then mix with a uniform distribution over classes
    targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
    targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
    loss = (-targets * log_probs).mean(0).sum()
    return loss
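
# Worked example (illustrative numbers): with num_classes=5 and epsilon=0.1, the
# smoothed target for the true class is 0.9 + 0.1/5 = 0.92 and 0.1/5 = 0.02 for each
# of the other four classes, so the target distribution still sums to 1.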
def get_optim_scheduler(parameters, config):
  assert hasattr(config, 'optim') and hasattr(config, 'scheduler') and hasattr(config, 'criterion'), 'config must have optim / scheduler / criterion keys instead of {:}'.format(config)
  if config.optim == 'SGD':
    optim = torch.optim.SGD(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=config.nesterov)
  elif config.optim == 'RMSprop':
    optim = torch.optim.RMSprop(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay)
  else:
    raise ValueError('invalid optim : {:}'.format(config.optim))

  if config.scheduler == 'cos':
    T_max = getattr(config, 'T_max', config.epochs)
    scheduler = CosineAnnealingLR(optim, config.warmup, config.epochs, T_max, config.eta_min)
  elif config.scheduler == 'multistep':
    scheduler = MultiStepLR(optim, config.warmup, config.epochs, config.milestones, config.gammas)
  elif config.scheduler == 'exponential':
    scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma)
  elif config.scheduler == 'linear':
    scheduler = LinearLR(optim, config.warmup, config.epochs, config.LR, config.LR_min)
  else:
    raise ValueError('invalid scheduler : {:}'.format(config.scheduler))

  if config.criterion == 'Softmax':
    criterion = torch.nn.CrossEntropyLoss()
  elif config.criterion == 'SmoothSoftmax':
    criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth)
  else:
    raise ValueError('invalid criterion : {:}'.format(config.criterion))
  return optim, scheduler, criterion
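

# Minimal usage sketch (not part of the original module): it builds a config with
# types.SimpleNamespace, which is only an assumption about how `config` may be
# supplied -- any object exposing the attributes read above would work.
if __name__ == '__main__':
  from types import SimpleNamespace

  model = nn.Linear(16, 10)
  config = SimpleNamespace(
      optim='SGD', LR=0.1, momentum=0.9, decay=5e-4, nesterov=True,
      scheduler='cos', warmup=5, epochs=100, eta_min=0.0,
      criterion='SmoothSoftmax', class_num=10, label_smooth=0.1)
  optim, scheduler, criterion = get_optim_scheduler(model.parameters(), config)
  print(scheduler)

  num_batches = 4  # a toy number of mini-batches per epoch
  for epoch in range(config.epochs):
    scheduler.update(epoch, 0.0)                             # refresh the LR at the epoch boundary
    for step in range(num_batches):
      scheduler.update(None, float(step) / num_batches)      # fractional progress within the epoch
      inputs = torch.randn(8, 16)
      targets = torch.randint(0, 10, (8,))
      loss = criterion(model(inputs), targets)
      optim.zero_grad()
      loss.backward()
      optim.step()
    if epoch % 20 == 0:
      print(scheduler.get_min_info())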