import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from .operations import OPS, FactorizedReduce, ReLUConvBN, Identity


def random_select(length, ratio):
  clist = []
  index = random.randint(0, length-1)
  for i in range(length):
    if i == index or random.random() < ratio:
      clist.append( 1 )
    else:
      clist.append( 0 )
  return clist


def all_select(length):
  return [1 for i in range(length)]


def drop_path(x, drop_prob):
  if drop_prob > 0.:
    keep_prob = 1. - drop_prob
    mask = x.new_zeros(x.size(0), 1, 1, 1)
    mask = mask.bernoulli_(keep_prob)
    x.div_(keep_prob)
    x.mul_(mask)
  return x


def return_alphas_str(basemodel):
  string = 'normal : {:}'.format( F.softmax(basemodel.alphas_normal, dim=-1) )
  if hasattr(basemodel, 'alphas_reduce'):
    string = string + '\nreduce : {:}'.format( F.softmax(basemodel.alphas_reduce, dim=-1) )
  return string


class Cell(nn.Module):

  def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
    super(Cell, self).__init__()
    print(C_prev_prev, C_prev, C)

    if reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
    
    if reduction:
      op_names, indices, values = zip(*genotype.reduce)
      concat = genotype.reduce_concat
    else:
      op_names, indices, values = zip(*genotype.normal)
      concat = genotype.normal_concat
    self._compile(C, op_names, indices, values, concat, reduction)

  def _compile(self, C, op_names, indices, values, concat, reduction):
    assert len(op_names) == len(indices)
    self._steps = len(op_names) // 2
    self._concat = concat
    self.multiplier = len(concat)

    self._ops = nn.ModuleList()
    for name, index in zip(op_names, indices):
      stride = 2 if reduction and index < 2 else 1
      op = OPS[name](C, stride, True)
      self._ops.append( op )
    self._indices = indices
    self._values  = values

  def forward(self, s0, s1, drop_prob):
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)

    states = [s0, s1]
    for i in range(self._steps):
      h1 = states[self._indices[2*i]]
      h2 = states[self._indices[2*i+1]]
      op1 = self._ops[2*i]
      op2 = self._ops[2*i+1]
      h1 = op1(h1)
      h2 = op2(h2)
      if self.training and drop_prob > 0.:
        if not isinstance(op1, Identity):
          h1 = drop_path(h1, drop_prob)
        if not isinstance(op2, Identity):
          h2 = drop_path(h2, drop_prob)

      s = h1 + h2

      states += [s]
    return torch.cat([states[i] for i in self._concat], dim=1)



class Transition(nn.Module):

  def __init__(self, C_prev_prev, C_prev, C, reduction_prev, multiplier=4):
    super(Transition, self).__init__()
    if reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
    self.multiplier  = multiplier

    self.reduction = True
    self.ops1 = nn.ModuleList(
                  [nn.Sequential(
                      nn.ReLU(inplace=False),
                      nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False),
                      nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False),
                      nn.BatchNorm2d(C, affine=True),
                      nn.ReLU(inplace=False),
                      nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False),
                      nn.BatchNorm2d(C, affine=True)),
                   nn.Sequential(
                      nn.ReLU(inplace=False),
                      nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False),
                      nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False),
                      nn.BatchNorm2d(C, affine=True),
                      nn.ReLU(inplace=False),
                      nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False),
                      nn.BatchNorm2d(C, affine=True))])

    self.ops2 = nn.ModuleList(
                  [nn.Sequential(
                      nn.MaxPool2d(3, stride=2, padding=1),
                      nn.BatchNorm2d(C, affine=True)),
                   nn.Sequential(
                      nn.MaxPool2d(3, stride=2, padding=1),
                      nn.BatchNorm2d(C, affine=True))])


  def forward(self, s0, s1, drop_prob = -1):
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)

    X0 = self.ops1[0] (s0)
    X1 = self.ops1[1] (s1)
    if self.training and drop_prob > 0.:
      X0, X1 = drop_path(X0, drop_prob), drop_path(X1, drop_prob)

    #X2 = self.ops2[0] (X0+X1)
    X2 = self.ops2[0] (s0)
    X3 = self.ops2[1] (s1)
    if self.training and drop_prob > 0.:
      X2, X3 = drop_path(X2, drop_prob), drop_path(X3, drop_prob)
    return torch.cat([X0, X1, X2, X3], dim=1)