import math
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
from .construct_utils import drop_path
from ..operations import OPS, Identity, FactorizedReduce, ReLUConvBN


class MixedOp(nn.Module):
  """All candidate operations for a single edge of a search cell.

  One module ``OPS[p](C, C, stride, False)`` is built per primitive name ``p``
  (the trailing ``False`` is presumably ``affine`` — confirm against ``OPS``),
  and ``name2idx`` maps each primitive name to its position in ``self._ops``.

  Raises:
    ValueError: if ``PRIMITIVES`` contains the same name twice.
  """

  def __init__(self, C, stride, PRIMITIVES):
    super(MixedOp, self).__init__()
    self._ops = nn.ModuleList()
    self.name2idx = {}
    for idx, primitive in enumerate(PRIMITIVES):
      # Explicit check (not ``assert``) so it still fires under ``python -O``;
      # a duplicate would otherwise silently shadow the earlier index.
      if primitive in self.name2idx:
        raise ValueError('{:} has already been added'.format(primitive))
      self._ops.append(OPS[primitive](C, C, stride, False))
      self.name2idx[primitive] = idx

  def forward(self, x, weights, op_name):
    """Apply the candidate operations to ``x``.

    Args:
      x: input tensor fed to every candidate.
      weights: per-candidate mixing weights aligned with ``self._ops``, or None.
      op_name: if not None, run only this primitive (``weights`` is ignored).

    Returns:
      - ``op_name`` given: the output of that single primitive;
      - ``weights`` is None: a list with every candidate's output, in order;
      - otherwise: the weighted sum ``sum(w * op(x))`` over candidates.

    Raises:
      KeyError: if ``op_name`` is not one of the primitives.
    """
    if op_name is not None:
      return self._ops[self.name2idx[op_name]](x)
    if weights is None:
      return [op(x) for op in self._ops]
    return sum(w * op(x) for w, op in zip(weights, self._ops))



class SearchCell(nn.Module):
  """Search-phase cell whose every edge is a ``MixedOp`` over ``PRIMITIVES``.

  Intermediate node ``i`` (``0 <= i < steps``) has one edge from each earlier
  state (the two preprocessed inputs plus nodes ``0..i-1``), so ``self._ops``
  stores ``sum(2 + i for i in range(steps))`` edges laid out node by node:
  the edge from state ``j`` into node ``i`` is ``self._ops[offset_i + j]``.
  """

  def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, PRIMITIVES, use_residual):
    super(SearchCell, self).__init__()
    self.reduction  = reduction
    # Private copy so later mutation of the caller's list cannot affect this cell.
    self.PRIMITIVES = deepcopy(PRIMITIVES)

    # After a reduction cell, S0 goes through FactorizedReduce (2 is presumably
    # the stride) instead of a 1x1 ReLUConvBN.
    if reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2, affine=False)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
    self._steps        = steps
    self._multiplier   = multiplier
    self._use_residual = use_residual

    self._ops = nn.ModuleList()
    for i in range(self._steps):
      for j in range(2+i):
        # In a reduction cell only edges leaving the two cell inputs use stride 2.
        stride = 2 if reduction and j < 2 else 1
        op = MixedOp(C, stride, self.PRIMITIVES)
        self._ops.append(op)

  def extra_repr(self):
    return ('{name}(residual={_use_residual}, steps={_steps}, multiplier={_multiplier})'.format(name=self.__class__.__name__, **self.__dict__))

  def forward(self, S0, S1, weights, connect, adjacency, drop_prob, modes):
    """Run the cell.

    Args:
      S0, S1: the two cell inputs (each goes through its own preprocess op).
      weights: per-edge ``MixedOp`` weights indexed like ``self._ops``
        (``'normal'`` mode only).
      connect, adjacency: node-connection parameters; node ``i`` weights its
        incoming edges by ``torch.mm(connect[str(i)], adjacency[i])``
        (``'normal'`` mode only).
      drop_prob: drop-path probability, applied only while training.
      modes: ``modes[0]`` is a genotype or None. If a genotype is given, only
        the operations it selects are evaluated. Otherwise ``modes[1]`` must
        be ``'normal'`` (weighted mixture) or ``'only_W'`` (every candidate
        output, averaged and scaled by 2).

    Returns:
      Node outputs concatenated along dim 1, plus ``S1`` as a residual when
      ``use_residual`` is set and the shapes match.

    Raises:
      ValueError: if ``modes[1]`` is not a known search mode, or the genotype's
        number of steps does not match this cell.
    """
    genotype = modes[0]
    if genotype is not None:
      output = self.__forwardGenotype(S0, S1, genotype)
    elif modes[1] == 'normal':
      output = self.__forwardBoth(S0, S1, weights, connect, adjacency, drop_prob)
    elif modes[1] == 'only_W':
      output = self.__forwardOnlyW(S0, S1, drop_prob)
    else:
      # Without this branch an unknown mode left ``output`` unbound.
      raise ValueError('invalid search mode : {:}'.format(modes[1]))
    if self._use_residual and S1.size() == output.size():
      return S1 + output
    return output

  def __forwardGenotype(self, S0, S1, genotype):
    """Evaluate only the ``(op_name, input_index)`` pairs chosen by ``genotype``."""
    if self.reduction: operations, concats = genotype.reduce, genotype.reduce_concat
    else             : operations, concats = genotype.normal, genotype.normal_concat
    # Validate before running any preprocessing (explicit raise survives -O).
    if self._steps != len(operations):
      raise ValueError('{:} vs. {:}'.format(self._steps, len(operations)))
    states, offset = [self.preprocess0(S0), self.preprocess1(S1)], 0
    for opA, opB in operations:
      # Each entry is (op_name, input_index); the edge from state j into the
      # current node lives at self._ops[offset + j].
      A = self._ops[offset + opA[1]](states[opA[1]], None, opA[0])
      B = self._ops[offset + opB[1]](states[opB[1]], None, opB[0])
      offset += len(states)
      states.append(A + B)
    return torch.cat([states[i] for i in concats], dim=1)

  def __forwardBoth(self, S0, S1, weights, connect, adjacency, drop_prob):
    """Weighted search forward: ``MixedOp`` weights per edge, connection weights per node."""
    s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
    states, offset = [s0, s1], 0
    for i in range(self._steps):
      clist = []
      for j, h in enumerate(states):
        x = self._ops[offset+j](h, weights[offset+j], None)
        if self.training and drop_prob > 0.:
          # Per-edge rate drop_prob ** (1/len(states)); presumably chosen so the
          # chance of dropping every incoming edge equals drop_prob — confirm.
          x = drop_path(x, math.pow(drop_prob, 1./len(states)))
        clist.append(x)
      connection = torch.mm(connect['{:}'.format(i)], adjacency[i]).squeeze(0)
      state = sum(w * node for w, node in zip(connection, clist))
      offset += len(states)
      states.append(state)
    return torch.cat(states[-self._multiplier:], dim=1)

  def __forwardOnlyW(self, S0, S1, drop_prob):
    """Unweighted search forward: each node is 2x the mean of all candidate outputs."""
    s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
    states, offset = [s0, s1], 0
    for i in range(self._steps):
      clist = []
      for j, h in enumerate(states):
        # With weights=None and op_name=None, MixedOp returns a list of outputs.
        xs = self._ops[offset+j](h, None, None)
        clist += xs
      if self.training and drop_prob > 0.:
        xlist = [drop_path(x, math.pow(drop_prob, 1./len(states))) for x in clist]
      else: xlist = clist
      state = sum(xlist) * 2 / len(xlist)
      offset += len(states)
      states.append(state)
    return torch.cat(states[-self._multiplier:], dim=1)



class InferCell(nn.Module):
  """Fixed cell built from a genotype (one concrete op per selected edge).

  Each genotype step is a pair of ``(op_name, input_index)`` entries. The
  step's output is the sum of the two ops applied to the chosen earlier
  states, and the cell output concatenates the states listed in the
  genotype's ``*_concat`` along dim 1.

  Raises:
    ValueError: if a genotype step does not contain exactly two entries.
  """

  def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
    super(InferCell, self).__init__()

    # reduction_prev is tri-state: None -> S0 used as-is (Identity),
    # truthy -> FactorizedReduce (2 presumably the stride), falsy -> 1x1 ReLUConvBN.
    if reduction_prev is None:
      self.preprocess0 = Identity()
    elif reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
    self.preprocess1   = ReLUConvBN(C_prev, C, 1, 1, 0)

    if reduction: step_ops, concat = genotype.reduce, genotype.reduce_concat
    else        : step_ops, concat = genotype.normal, genotype.normal_concat
    self._steps        = len(step_ops)
    self._concat       = concat
    self._multiplier   = len(concat)
    self._ops          = nn.ModuleList()
    self._indices      = []
    for step, operations in enumerate(step_ops):
      # forward() consumes ops/indices in pairs (2*i, 2*i+1); any other count
      # would silently shift every later step onto the wrong ops.
      if len(operations) != 2:
        raise ValueError('step {:} has {:} operations, expected 2'.format(step, len(operations)))
      for name, index in operations:
        # In a reduction cell only edges leaving the two cell inputs use stride 2.
        stride = 2 if reduction and index < 2 else 1
        # S0 is not preprocessed when reduction_prev is None, so ops reading
        # state 0 must take C_prev_prev input channels.
        C_in = C_prev_prev if (reduction_prev is None and index == 0) else C
        self._ops.append(OPS[name](C_in, C, stride, True))
        self._indices.append(index)

  def extra_repr(self):
    return ('{name}(steps={_steps}, concat={_concat})'.format(name=self.__class__.__name__, **self.__dict__))

  def forward(self, S0, S1, drop_prob):
    """Run the fixed cell.

    Args:
      S0, S1: the two cell inputs.
      drop_prob: drop-path probability, applied only while training and never
        to ``Identity`` edges.

    Returns:
      The states selected by the genotype's concat list, concatenated on dim 1.
    """
    states = [self.preprocess0(S0), self.preprocess1(S1)]
    for i in range(self._steps):
      op1, op2 = self._ops[2*i], self._ops[2*i+1]
      h1 = op1(states[self._indices[2*i]])
      h2 = op2(states[self._indices[2*i+1]])
      if self.training and drop_prob > 0.:
        if not isinstance(op1, Identity):
          h1 = drop_path(h1, drop_prob)
        if not isinstance(op2, Identity):
          h2 = drop_path(h2, drop_prob)
      states.append(h1 + h2)
    return torch.cat([states[i] for i in self._concat], dim=1)