##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
from bisect import bisect_right


class MultiStepLR(torch.optim.lr_scheduler._LRScheduler):
  """Step-wise learning-rate decay with its own multiplier for every milestone.

  Each milestone ``milestones[i]`` is paired with a factor ``gammas[i]``.
  Once ``last_epoch >= milestones[i]``, that factor is applied, so the
  learning rate of each param group is
  ``base_lr * gammas[0] * ... * gammas[k-1]``, where ``k`` is the number of
  milestones already reached.

  Args:
    optimizer: the wrapped optimizer.
    milestones: non-decreasing sequence of epoch indices.
    gammas: multiplicative factors, one per milestone.
    last_epoch: index of the last epoch; ``-1`` starts from scratch.

  Raises:
    ValueError: if ``milestones`` is not sorted, or if ``milestones`` and
      ``gammas`` have different lengths.
  """

  def __init__(self, optimizer, milestones, gammas, last_epoch=-1):
    if not list(milestones) == sorted(milestones):
      raise ValueError('Milestones should be a list of'
                       ' increasing integers. Got {:}'.format(milestones))
    # An explicit exception (not ``assert``) so the check survives ``python -O``.
    if len(milestones) != len(gammas):
      raise ValueError('milestones and gammas must have the same length: '
                       '{:} vs {:}'.format(milestones, gammas))
    self.milestones = milestones
    self.gammas = gammas
    # Set the attributes before calling the base constructor, because it
    # performs an initial scheduler step that calls ``get_lr``.
    super(MultiStepLR, self).__init__(optimizer, last_epoch)

  def get_lr(self):
    """Return the learning rate for each param group at ``self.last_epoch``."""
    # bisect_right counts the milestones that are <= last_epoch, i.e. those
    # already reached; only their gammas contribute to the decay.
    num_reached = bisect_right(self.milestones, self.last_epoch)
    factor = 1
    for gamma in self.gammas[:num_reached]:
      factor = factor * gamma
    return [base_lr * factor for base_lr in self.base_lrs]


def obtain_scheduler(config, optimizer):
  """Build a learning-rate scheduler for ``optimizer`` as described by ``config``.

  ``config.type`` selects the scheduler:
    * ``'multistep'`` -> ``MultiStepLR`` using ``config.milestones`` and
      ``config.gammas``;
    * ``'cosine'`` -> ``torch.optim.lr_scheduler.CosineAnnealingLR`` with
      ``config.epochs`` passed as its ``T_max``.

  Raises:
    ValueError: if ``config.type`` is neither of the above.
  """
  if config.type == 'multistep':
    return MultiStepLR(optimizer, milestones=config.milestones, gammas=config.gammas)
  if config.type == 'cosine':
    return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config.epochs)
  raise ValueError('Unknown learning rate scheduler type : {:}'.format(config.type))