# Gumbel-softmax variant of differentiable architecture search: the
# architecture weights are sampled with a (hard) Gumbel-softmax during the
# forward pass instead of being relaxed with a plain softmax.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from .operations import OPS, FactorizedReduce, ReLUConvBN
from .genotypes import PRIMITIVES, Genotype


class MixedOp(nn.Module):

  def __init__(self, C, stride):
    super(MixedOp, self).__init__()
    self._ops = nn.ModuleList()
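    # One candidate operation per primitive; all share the channel width C.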
    for primitive in PRIMITIVES:
      op = OPS[primitive](C, stride, False)
      self._ops.append(op)

  def forward(self, x, weights, cpu_weights):
    # `cpu_weights` is a plain Python list copy of `weights`, so the zero
    # checks below do not trigger a per-element GPU synchronization.
    num_active = sum(abs(w) > 1e-10 for w in cpu_weights)
    if num_active > 3:
      # Most ops carry weight (e.g. a plain softmax): evaluate all of them.
      return sum(w * op(x) for w, op in zip(weights, self._ops))
    else:
      # Few ops carry weight (e.g. a hard Gumbel-softmax sample is one-hot):
      # only evaluate the ops whose weight is non-zero.
      clist = []
      for j, cpu_weight in enumerate(cpu_weights):
        if abs(cpu_weight) > 1e-10:
          clist.append(weights[j] * self._ops[j](x))
      assert len(clist) > 0, 'invalid length : {:}'.format(cpu_weights)
      return sum(clist)
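
# Note: with hard Gumbel-softmax sampling, each per-edge weight vector is
# one-hot (e.g. [0, 0, 1, 0, ...]), so the sparse branch above evaluates a
# single candidate op, which keeps the search-time forward pass cheap.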


class Cell(nn.Module):

  def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
    super(Cell, self).__init__()
    self.reduction = reduction

    # Preprocess the two inputs to a common width C; if the previous cell
    # reduced the spatial size, s0 must also be downsampled to match s1.
    if reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
    self._steps = steps
    self._multiplier = multiplier

    self._ops = nn.ModuleList()
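    # One MixedOp per edge of the cell DAG: node i receives an edge from each
    # of the two cell inputs and from every previous node. In a reduction
    # cell, only the edges leaving the two inputs (j < 2) are strided.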
    for i in range(self._steps):
      for j in range(2+i):
        stride = 2 if reduction and j < 2 else 1
        op = MixedOp(C, stride)
        self._ops.append(op)

  def forward(self, s0, s1, weights):
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)

    # One CPU copy of the weights lets each MixedOp check for zero entries
    # without a per-edge GPU-to-CPU synchronization.
    cpu_weights = weights.tolist()
    states = [s0, s1]
    offset = 0
    for i in range(self._steps):
      # Intermediate node i sums the MixedOp outputs over all incoming edges.
      clist = []
      for j, h in enumerate(states):
        x = self._ops[offset+j](h, weights[offset+j], cpu_weights[offset+j])
        clist.append(x)
      s = sum(clist)
      offset += len(states)
      states.append(s)

    # The cell output is the concatenation of the last `multiplier` nodes.
    return torch.cat(states[-self._multiplier:], dim=1)


class NetworkACC2(nn.Module):

  def __init__(self, C, num_classes, layers, steps=4, multiplier=4, stem_multiplier=3):
    super(NetworkACC2, self).__init__()
    self._C = C
    self._num_classes = num_classes
    self._layers = layers
    self._steps  = steps
    self._multiplier = multiplier

    # The stem lifts the 3-channel input to stem_multiplier*C channels.
    C_curr = stem_multiplier*C
    self.stem = nn.Sequential(
      nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
      nn.BatchNorm2d(C_curr)
    )
 
    C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
    reduction_prev, cells = False, []
    for i in range(layers):
      # Reduction cells sit at one and two thirds of the depth; they halve
      # the spatial size and double the channel count.
      if i in [layers//3, 2*layers//3]:
        C_curr *= 2
        reduction = True
      else:
        reduction = False
      cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
      reduction_prev = reduction
      cells.append(cell)
      # Each cell outputs multiplier*C_curr channels (a concat of `multiplier` nodes).
      C_prev_prev, C_prev = C_prev, multiplier*C_curr
    self.cells = nn.ModuleList(cells)

    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)
    # Gumbel-softmax temperature; meant to be annealed via set_tau() during search.
    self.tau        = 5
    self.use_gumbel = True

    # initialize architecture parameters
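    # One row of logits per edge of the cell DAG, one column per candidate op.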
    k = sum(1 for i in range(self._steps) for n in range(2+i))
    num_ops = len(PRIMITIVES)

    self.alphas_normal = Parameter(torch.Tensor(k, num_ops))
    self.alphas_reduce = Parameter(torch.Tensor(k, num_ops))
    nn.init.normal_(self.alphas_normal, 0, 0.001)
    nn.init.normal_(self.alphas_reduce, 0, 0.001)
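    # Near-zero init keeps the initial (Gumbel-)softmax close to uniform.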

  def set_gumbel(self, use_gumbel):
    self.use_gumbel = use_gumbel

  def set_tau(self, tau):
    self.tau = tau

  def get_tau(self):
    return self.tau

  def arch_parameters(self):
    # The architecture logits (alphas) only.
    return [self.alphas_normal, self.alphas_reduce]

  def base_parameters(self):
    # All network weights except the architecture logits.
    params = list(self.stem.parameters()) + list(self.cells.parameters())
    params += list(self.global_pooling.parameters())
    params += list(self.classifier.parameters())
    return params

  def forward(self, inputs):
    batch, C, H, W = inputs.size()
    s0 = s1 = self.stem(inputs)
    for cell in self.cells:
      # Relax the architecture logits for this cell type: a hard Gumbel-softmax
      # sample is one-hot per edge, with gradients flowing through the soft
      # sample (straight-through); otherwise use a plain softmax relaxation.
      alphas = self.alphas_reduce if cell.reduction else self.alphas_normal
      if self.use_gumbel:
        weights = F.gumbel_softmax(alphas, self.tau, True)
      else:
        weights = F.softmax(alphas, dim=-1)
      s0, s1 = s1, cell(s0, s1, weights)
    out = self.global_pooling(s1)
    out = out.view(batch, -1)
    logits = self.classifier(out)
    return logits

  def genotype(self):

    def _parse(weights):
      # For each intermediate node, keep the two strongest incoming edges
      # (ignoring the 'none' op), and on each kept edge record the best
      # non-'none' operation together with its weight.
      gene, n, start = [], 2, 0
      none_index = PRIMITIVES.index('none')
      for i in range(self._steps):
        end = start + n
        W = weights[start:end].copy()
        edges = sorted(range(i + 2),
                       key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != none_index))[:2]
        for j in edges:
          k_best = None
          for k in range(len(W[j])):
            if k != none_index:
              if k_best is None or W[j][k] > W[j][k_best]:
                k_best = k
          gene.append((PRIMITIVES[k_best], j, float(W[j][k_best])))
        start = end
        n += 1
      return gene

    with torch.no_grad():
      gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).cpu().numpy())
      gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).cpu().numpy())

      # Output nodes: concatenate the last `multiplier` intermediate nodes.
      concat = range(2+self._steps-self._multiplier, self._steps+2)
      genotype = Genotype(
        normal=gene_normal, normal_concat=concat,
        reduce=gene_reduce, reduce_concat=concat
      )
    return genotype
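

# A minimal smoke test (a sketch, not part of the original training pipeline);
# it assumes CIFAR-like 32x32 RGB inputs and that the module is run with its
# package on the path (e.g. `python -m package.this_module`) so the relative
# imports above resolve.
if __name__ == '__main__':
  net = NetworkACC2(C=16, num_classes=10, layers=8)
  x = torch.randn(2, 3, 32, 32)
  logits = net(x)          # hard Gumbel-softmax sampling (use_gumbel=True)
  print(logits.shape)      # expected: torch.Size([2, 10])
  net.set_gumbel(False)    # switch to the plain softmax relaxation
  net.set_tau(1.0)         # e.g. after annealing the temperature
  print(net(x).shape)
  print(net.genotype())    # derive the discrete cells from the current alphas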