# MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018
from torch import nn
from .initialization import initialize_resnet


class ConvBNReLU(nn.Module):
  """A Conv2d -> BatchNorm2d -> ReLU6 block; padding keeps the 'same' spatial size for odd kernels at stride 1."""

  def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
    super(ConvBNReLU, self).__init__()
    padding = (kernel_size - 1) // 2
    # bias is disabled because the following BatchNorm applies its own affine shift
    self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False)
    self.bn   = nn.BatchNorm2d(out_planes)
    self.relu = nn.ReLU6(inplace=True)

  def forward(self, x):
    out = self.conv(x)
    out = self.bn(out)
    out = self.relu(out)
    return out
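
# For reference, the two ways ConvBNReLU is used below (the channel counts here
# are illustrative): ConvBNReLU(32, 192, kernel_size=1) acts as a 1x1 pointwise
# expansion, while ConvBNReLU(192, 192, groups=192) is a 3x3 depthwise
# convolution, since groups equal to the channel count gives one filter per channel.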


class InvertedResidual(nn.Module):
  """The MobileNetV2 building block: 1x1 expansion -> 3x3 depthwise -> 1x1 linear projection, with an optional residual shortcut."""

  def __init__(self, inp, oup, stride, expand_ratio):
    super(InvertedResidual, self).__init__()
    self.stride = stride
    assert stride in [1, 2]

    hidden_dim = int(round(inp * expand_ratio))
    # the shortcut is only valid when the block preserves both resolution and width
    self.use_res_connect = self.stride == 1 and inp == oup

    layers = []
    if expand_ratio != 1:
      # 1x1 pointwise convolution that expands inp to hidden_dim channels
      layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
    layers.extend([
      # 3x3 depthwise convolution (groups == channels, one filter per channel)
      ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
      # 1x1 pointwise-linear projection back to oup channels; no activation here,
      # so the low-dimensional bottleneck stays linear (the "linear bottleneck" of the paper)
      nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
      nn.BatchNorm2d(oup),
    ])
    self.conv = nn.Sequential(*layers)

  def forward(self, x):
    if self.use_res_connect:
      return x + self.conv(x)
    else:
      return self.conv(x)
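
# Worked example of the block above: with inp=24, oup=24, stride=1, and
# expand_ratio=6, the block expands 24 -> 144 channels via a 1x1 conv, runs a
# 3x3 depthwise conv over the 144 channels, projects 144 -> 24 without a
# nonlinearity, and adds the input (use_res_connect is True since stride == 1
# and inp == oup).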


class MobileNetV2(nn.Module):
  """MobileNetV2 (Sandler et al., CVPR 2018) assembled from InvertedResidual blocks."""

  def __init__(self, num_classes, width_mult, input_channel, last_channel, block_name, dropout):
    super(MobileNetV2, self).__init__()
    if block_name == 'InvertedResidual':
      block = InvertedResidual
    else:
      raise ValueError('invalid block name : {:}'.format(block_name))
    # each row is [t, c, n, s]: expansion factor t, output channels c,
    # number of repeated blocks n, and stride s of the first block in the group
    inverted_residual_setting = [
      # t, c,  n, s
      [1, 16 , 1, 1],
      [6, 24 , 2, 2],
      [6, 32 , 3, 2],
      [6, 64 , 4, 2],
      [6, 96 , 3, 1],
      [6, 160, 3, 2],
      [6, 320, 1, 1],
    ]

    # building the first layer (the stem): a strided 3x3 convolution over the RGB input
    input_channel = int(input_channel * width_mult)
    self.last_channel = int(last_channel * max(1.0, width_mult))
    stem_channel = input_channel  # remember the stem width; input_channel is mutated below
    features = [ConvBNReLU(3, input_channel, stride=2)]
    # building inverted residual blocks
    for t, c, n, s in inverted_residual_setting:
      # note: unlike the torchvision reference, channels are not rounded to a multiple of 8
      output_channel = int(c * width_mult)
      for i in range(n):
        stride = s if i == 0 else 1  # only the first block in a group downsamples
        features.append(block(input_channel, output_channel, stride, expand_ratio=t))
        input_channel = output_channel
    # building the last feature layer: a 1x1 convolution up to last_channel
    features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
    # make it nn.Sequential
    self.features = nn.Sequential(*features)

    # building classifier
    self.classifier = nn.Sequential(
      nn.Dropout(dropout),
      nn.Linear(self.last_channel, num_classes),
    )
    self.message = 'MobileNetV2 : width_mult={:}, in-C={:}, last-C={:}, block={:}, dropout={:}'.format(width_mult, stem_channel, self.last_channel, block_name, dropout)

    # weight initialization
    self.apply( initialize_resnet )

  def get_message(self):
    return self.message

  def forward(self, inputs):
    features = self.features(inputs)
    vectors  = features.mean([2, 3]) # global average pooling over the spatial dimensions
    predicts = self.classifier(vectors)
    return features, predicts
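

# A minimal smoke test (a sketch, not part of the original training pipeline).
# The configuration below (width_mult=1.0, input_channel=32, last_channel=1280,
# dropout=0.2) is the standard ImageNet setting from the paper, assumed here
# rather than read from this repository's configs. Because of the relative
# import at the top, run this as a module (python -m <package>.<this_module>);
# the exact package path depends on the repository layout.
if __name__ == '__main__':
  import torch
  net = MobileNetV2(num_classes=1000, width_mult=1.0, input_channel=32,
                    last_channel=1280, block_name='InvertedResidual', dropout=0.2)
  print(net.get_message())
  feats, logits = net(torch.randn(2, 3, 224, 224))
  # for a 224x224 input the feature map is downsampled 32x, i.e. 7x7 spatially
  assert feats.shape == (2, 1280, 7, 7) and logits.shape == (2, 1000)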