153 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			153 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import random
 | |
| import torch
 | |
| import torch.nn as nn
 | |
| import torch.nn.functional as F
 | |
| from .operations import OPS, FactorizedReduce, ReLUConvBN, Identity
 | |
| 
 | |
| 
 | |
def random_select(length, ratio):
  """Build a random 0/1 selection list of size `length`.

  One position, drawn uniformly with `random.randint`, is always selected;
  every other position is selected when `random.random() < ratio`.

  Args:
    length: number of entries in the returned list (must be >= 1, since
      `random.randint(0, length-1)` is called unconditionally).
    ratio: threshold compared against `random.random()` for the
      non-forced positions.

  Returns:
    A list of `length` ints, each 0 or 1, containing at least one 1.
  """
  forced = random.randint(0, length - 1)
  # The `or` short-circuits on the forced slot, so no random number is
  # consumed for that position.
  return [1 if (pos == forced or random.random() < ratio) else 0
          for pos in range(length)]
 | |
| 
 | |
| 
 | |
def all_select(length):
  """Return a list of `length` ones, i.e. a selection with every entry chosen."""
  return [1] * len(range(length))
 | |
| 
 | |
| 
 | |
def drop_path(x, drop_prob):
  """Randomly zero whole samples of `x` and rescale the surviving ones.

  A Bernoulli(1 - drop_prob) value is drawn once per entry along dim 0 and
  broadcast over every remaining dimension. Kept samples are divided by
  the keep probability.

  Unlike the previous in-place version, the input tensor is never
  modified: in-place `div_`/`mul_` raise on leaf tensors that require grad
  and silently alter activations the caller may still hold.

  Args:
    x: tensor whose first dimension indexes samples (any rank >= 1).
    drop_prob: probability of dropping a sample. Values <= 0 disable the
      operation. A value of exactly 1 drops everything (zeros, not NaN).

  Returns:
    `x` itself when `drop_prob <= 0`, otherwise a new tensor of the same
    shape as `x`.
  """
  if drop_prob > 0.:
    keep_prob = 1. - drop_prob
    # One draw per sample; trailing singleton dims let it broadcast over x.
    mask_shape = (x.size(0),) + (1,) * (x.dim() - 1)
    mask = x.new_zeros(mask_shape).bernoulli_(keep_prob)
    # Skip the rescale when nothing can survive, to avoid inf * 0 = NaN.
    if keep_prob > 0.:
      x = x.div(keep_prob)
    x = x.mul(mask)
  return x
 | |
| 
 | |
| 
 | |
def return_alphas_str(basemodel):
  """Format the softmax of a model's architecture parameters as text.

  Args:
    basemodel: object exposing `alphas_normal` and, optionally,
      `alphas_reduce`; both are softmaxed along their last dimension.

  Returns:
    A string with a 'normal : ...' line, followed by a 'reduce : ...' line
    when `basemodel` has an `alphas_reduce` attribute.
  """
  normal_probs = F.softmax(basemodel.alphas_normal, dim=-1)
  report = 'normal : {:}'.format(normal_probs)
  if not hasattr(basemodel, 'alphas_reduce'):
    return report
  reduce_probs = F.softmax(basemodel.alphas_reduce, dim=-1)
  return report + '\nreduce : {:}'.format(reduce_probs)
 | |
| 
 | |
| 
 | |
class Cell(nn.Module):
  """Cell assembled from a fixed genotype.

  The two inputs are first passed through `preprocess0` / `preprocess1`.
  Each step then applies two ops to two previously computed states and
  sums the results into a new state. The output concatenates the states
  listed in the genotype's concat spec along dim 1.
  """

  def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
    """Build the preprocessing layers and compile the genotype's ops.

    Args:
      genotype: object with `normal` / `reduce` (iterables of 3-tuples
        unpacked as op name, input index, value) and `normal_concat` /
        `reduce_concat`.
      C_prev_prev: channel count passed to `preprocess0`.
      C_prev: channel count passed to `preprocess1`.
      C: channel count used for every op in this cell.
      reduction: use the `reduce` genotype and stride-2 ops on the two
        cell inputs.
      reduction_prev: when true, `preprocess0` is a `FactorizedReduce`
        instead of a `ReLUConvBN`.
    """
    super(Cell, self).__init__()
    print(C_prev_prev, C_prev, C)

    self.preprocess0 = (FactorizedReduce(C_prev_prev, C) if reduction_prev
                        else ReLUConvBN(C_prev_prev, C, 1, 1, 0))
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)

    op_names, indices, values = zip(*(genotype.reduce if reduction else genotype.normal))
    concat = genotype.reduce_concat if reduction else genotype.normal_concat
    self._compile(C, op_names, indices, values, concat, reduction)

  def _compile(self, C, op_names, indices, values, concat, reduction):
    """Instantiate one op per genotype edge and record the wiring."""
    assert len(op_names) == len(indices)
    # Two edges feed every intermediate state.
    self._steps = len(op_names) // 2
    self._concat = concat
    self.multiplier = len(concat)

    # Only edges reading one of the two cell inputs (index 0 or 1) are
    # strided, and only in a reduction cell.
    self._ops = nn.ModuleList([
      OPS[name](C, 2 if (reduction and index < 2) else 1, True)
      for name, index in zip(op_names, indices)
    ])
    self._indices = indices
    self._values  = values

  def forward(self, s0, s1, drop_prob):
    """Run the cell on inputs `s0`, `s1`.

    While training with `drop_prob > 0`, `drop_path` is applied to the
    output of every op that is not an `Identity`.
    """
    states = [self.preprocess0(s0), self.preprocess1(s1)]
    for step in range(self._steps):
      first, second = 2 * step, 2 * step + 1
      op_a, op_b = self._ops[first], self._ops[second]
      out_a = op_a(states[self._indices[first]])
      out_b = op_b(states[self._indices[second]])
      if self.training and drop_prob > 0.:
        if not isinstance(op_a, Identity):
          out_a = drop_path(out_a, drop_prob)
        if not isinstance(op_b, Identity):
          out_b = drop_path(out_b, drop_prob)
      states.append(out_a + out_b)
    return torch.cat([states[k] for k in self._concat], dim=1)
 | |
| 
 | |
| 
 | |
| 
 | |
class Transition(nn.Module):
  """Fixed (non-searched) down-sampling cell with four parallel branches.

  After preprocessing, each input goes through one strided conv stack
  (`ops1`) and one strided max-pool + BatchNorm stack (`ops2`). The four
  results are concatenated along dim 1.
  """

  def __init__(self, C_prev_prev, C_prev, C, reduction_prev, multiplier=4):
    """Create the preprocessing layers and the four branches.

    Args:
      C_prev_prev: channel count passed to `preprocess0`.
      C_prev: channel count passed to `preprocess1`.
      C: channel count of every branch; the grouped convs use `groups=8`,
        so `C` must be divisible by 8.
      reduction_prev: when true, `preprocess0` is a `FactorizedReduce`
        instead of a `ReLUConvBN`.
      multiplier: stored on `self.multiplier`; `forward` always
        concatenates exactly four tensors regardless of this value.
    """
    super(Transition, self).__init__()
    self.preprocess0 = (FactorizedReduce(C_prev_prev, C) if reduction_prev
                        else ReLUConvBN(C_prev_prev, C, 1, 1, 0))
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
    self.multiplier  = multiplier
    self.reduction = True

    def conv_branch():
      # Grouped 1x3 then 3x1 convs, each striding one spatial axis,
      # followed by a 1x1 conv.
      return nn.Sequential(
        nn.ReLU(inplace=False),
        nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False),
        nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False),
        nn.BatchNorm2d(C, affine=True),
        nn.ReLU(inplace=False),
        nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(C, affine=True))

    def pool_branch():
      # 3x3 max-pool with stride 2, then BatchNorm.
      return nn.Sequential(
        nn.MaxPool2d(3, stride=2, padding=1),
        nn.BatchNorm2d(C, affine=True))

    # One independent branch per input (index 0 -> s0, index 1 -> s1).
    self.ops1 = nn.ModuleList([conv_branch() for _ in range(2)])
    self.ops2 = nn.ModuleList([pool_branch() for _ in range(2)])

  def forward(self, s0, s1, drop_prob = -1):
    """Apply all four branches and concatenate their outputs on dim 1.

    While training with `drop_prob > 0`, `drop_path` is applied to each
    branch output. The default of -1 disables it.
    """
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)

    conv_outs = [self.ops1[0](s0), self.ops1[1](s1)]
    if self.training and drop_prob > 0.:
      conv_outs = [drop_path(t, drop_prob) for t in conv_outs]

    # The pooling branches read the preprocessed inputs directly. An
    # earlier variant, left commented out in the original, pooled the sum
    # of the two conv outputs instead.
    pool_outs = [self.ops2[0](s0), self.ops2[1](s1)]
    if self.training and drop_prob > 0.:
      pool_outs = [drop_path(t, drop_prob) for t in pool_outs]
    return torch.cat(conv_outs + pool_outs, dim=1)
 |