##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
from bisect import bisect_right


class MultiStepLR(torch.optim.lr_scheduler._LRScheduler):
  """Multi-step decay scheduler where every milestone carries its own decay
  factor, unlike ``torch.optim.lr_scheduler.MultiStepLR`` which applies a
  single shared ``gamma`` at every milestone."""

  def __init__(self, optimizer, milestones, gammas, last_epoch=-1):
    if not list(milestones) == sorted(milestones):
      raise ValueError('Milestones should be a list of'
                       ' increasing integers. Got {:}'.format(milestones))
    assert len(milestones) == len(gammas), '{:} vs {:}'.format(milestones, gammas)
    self.milestones = milestones
    self.gammas = gammas
    super(MultiStepLR, self).__init__(optimizer, last_epoch)

  def get_lr(self):
    # Multiply together the gammas of all milestones already passed.
    LR = 1
    for x in self.gammas[:bisect_right(self.milestones, self.last_epoch)]:
      LR = LR * x
    return [base_lr * LR for base_lr in self.base_lrs]


def obtain_scheduler(config, optimizer):
  """Build a learning rate scheduler from ``config.type`` ('multistep' or 'cosine')."""
  if config.type == 'multistep':
    scheduler = MultiStepLR(optimizer, milestones=config.milestones, gammas=config.gammas)
  elif config.type == 'cosine':
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config.epochs)
  else:
    raise ValueError('Unknown learning rate scheduler type : {:}'.format(config.type))
  return scheduler
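

# Minimal usage sketch, assuming a namespace-style config that exposes only the
# fields read above: `type`, plus `milestones`/`gammas` for 'multistep' or
# `epochs` for 'cosine'. The model, optimizer settings, and milestone values
# below are placeholders chosen for illustration.
if __name__ == '__main__':
  from types import SimpleNamespace
  model = torch.nn.Linear(10, 2)
  optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
  config = SimpleNamespace(type='multistep', milestones=[30, 60, 90], gammas=[0.1, 0.5, 0.1])
  scheduler = obtain_scheduler(config, optimizer)
  for epoch in range(100):
    # ... run one training epoch, calling optimizer.step() inside ...
    scheduler.step()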