105 lines
		
	
	
		
			3.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			105 lines
		
	
	
		
			3.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import copy, torch
 | |
| import torch.nn as nn
 | |
| import torch.nn.functional as F
 | |
| from collections import namedtuple
 | |
| from .genotypes import PRIMITIVES, STEPS, CONCAT, Genotype
 | |
| from .basemodel import DARTSCell, RNNModel
 | |
| 
 | |
| 
 | |
class DARTSCellSearch(DARTSCell):
  """Search-phase variant of ``DARTSCell``.

  Instead of following a fixed genotype, every intermediate step is a
  weighted sum over all candidate primitives (except ``'none'``) applied to
  all previously computed states. The weights are the ``arch_probs`` matrix
  passed in by the caller; this class does not own architecture parameters.
  """

  def __init__(self, ninp, nhid, dropouth, dropoutx):
    """Build the cell without a genotype and attach the search-only extras.

    Args:
      ninp, nhid, dropouth, dropoutx: forwarded unchanged to
        ``DARTSCell.__init__``; ``nhid`` is also the feature size of the
        batch-norm layer.
    """
    super(DARTSCellSearch, self).__init__(ninp, nhid, dropouth, dropoutx, genotype=None)
    # One non-affine batch norm, reused for the initial state and every step.
    self.bn = nn.BatchNorm1d(nhid, affine=False)
    # When True, `cell` skips edges whose architecture weight is not > 0.
    self.check_zero = False

  def set_check(self, check_zero):
    """Enable or disable skipping of zero-weighted edges in `cell`."""
    self.check_zero = check_zero

  def cell(self, x, h_prev, x_mask, h_mask, arch_probs):
    """Run one time step of the mixed (search) cell.

    Args:
      x, h_prev, x_mask, h_mask: forwarded to ``_compute_init_state``;
        ``h_mask`` is additionally multiplied onto the states in training
        mode (presumably a dropout mask -- confirm against the base class).
      arch_probs: 2-D tensor indexed as ``[edge, primitive]``. Rows are laid
        out step by step: step ``i`` owns rows ``offset .. offset + i``.

    Returns:
      The mean over the last ``CONCAT`` entries of the state stack.
    """
    s0 = self.bn(self._compute_init_state(x, h_prev, x_mask, h_mask))
    # A host-side copy lets the zero test below use plain Python comparisons.
    probs_host = arch_probs.cpu().tolist() if self.check_zero else None

    offset = 0
    states = s0.unsqueeze(0)
    for i in range(STEPS):
      num_prev = i + 1  # step i sees the initial state plus i earlier steps
      masked_states = states * h_mask.unsqueeze(0) if self.training else states
      # One matmul yields both the gate pre-activation and the candidate.
      ch = masked_states.view(-1, self.nhid).mm(self._Ws[i]).view(num_prev, -1, 2 * self.nhid)
      c, h = torch.split(ch, self.nhid, dim=-1)
      c = c.sigmoid()

      s = torch.zeros_like(s0)
      for k, name in enumerate(PRIMITIVES):
        if name == 'none':
          continue
        fn = self._get_activation(name)
        if self.check_zero:
          # Predecessors (local indices) whose weight for this op is positive.
          active = [j for j in range(num_prev) if probs_host[offset + j][k] > 0]
          if not active:
            # Nothing selects this op: skip before paying for the activation.
            continue
          rows = [offset + j for j in active]
          # Gated interpolation between each predecessor state and fn(h).
          unweighted = states + c * (fn(h) - states)
          s += torch.sum(arch_probs[rows, k].unsqueeze(-1).unsqueeze(-1) * unweighted[active, :, :], dim=0)
        else:
          unweighted = states + c * (fn(h) - states)
          s += torch.sum(arch_probs[offset:offset + num_prev, k].unsqueeze(-1).unsqueeze(-1) * unweighted, dim=0)
      s = self.bn(s)
      states = torch.cat([states, s.unsqueeze(0)], 0)
      offset += num_prev
    return torch.mean(states[-CONCAT:], dim=0)
 | |
| 
 | |
| 
 | |
class RNNModelSearch(RNNModel):
  """``RNNModel`` extended with learnable architecture weights.

  ``arch_weights`` has one row per (step, predecessor) edge and one column
  per entry of ``PRIMITIVES``.
  """

  def __init__(self, *args):
    """Build the base model and create the architecture parameter.

    Args:
      *args: forwarded unchanged to ``RNNModel.__init__``; a deep copy is
        kept on ``self._args``.
    """
    super(RNNModelSearch, self).__init__(*args)
    self._args = copy.deepcopy(args)

    # Step i (0-based) has i + 1 predecessors, so there are 1 + 2 + ... + STEPS edges.
    num_edges = sum(range(1, STEPS + 1))
    # Allocated uninitialised; the values are set by the normal init below.
    self.arch_weights = nn.Parameter(torch.empty(num_edges, len(PRIMITIVES)))
    nn.init.normal_(self.arch_weights, 0, 0.001)

  def base_parameters(self):
    """Return the parameters of lockdrop, encoder, rnns and decoder, in that order.

    The architecture weights are not included. The four attributes are
    presumably created by ``RNNModel.__init__`` -- confirm against the base class.
    """
    params = []
    for module in (self.lockdrop, self.encoder, self.rnns, self.decoder):
      params.extend(module.parameters())
    return params

  def arch_parameters(self):
    """Return the architecture parameters as a one-element list."""
    return [self.arch_weights]

  def genotype(self):
    """Derive a discrete ``Genotype`` from the current architecture weights.

    The weights are soft-maxed over the primitive axis. For each step, the
    predecessor holding the single largest probability is chosen, together
    with its most probable primitive.

    Returns:
      ``Genotype(recurrent=[(primitive_name, predecessor_index), ...],
      concat=<last CONCAT indices of range(STEPS + 1)>)``.
    """

    def _parse(probs):
      # `probs` is a 2-D NumPy array laid out like `arch_weights`.
      gene = []
      start = 0
      for i in range(STEPS):
        end = start + i + 1
        W = probs[start:end]
        # 'none' is deliberately eligible for both choices below.
        # argmax returns the first maximum, so ties go to the lowest index.
        j = int(W.max(axis=1).argmax())
        k_best = int(W[j].argmax())
        gene.append((PRIMITIVES[k_best], j))
        start = end
      return gene

    with torch.no_grad():
      gene = _parse(F.softmax(self.arch_weights, dim=-1).cpu().numpy())
    return Genotype(recurrent=gene, concat=list(range(STEPS + 1)[-CONCAT:]))
 |