import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .genotypes import STEPS
from .utils import mask2d, LockedDropout, embedded_dropout


INITRANGE = 0.04


def none_func(x):
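    """Zero op: returns a tensor of zeros shaped like the input (the 'none' activation)."""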
    return x * 0


class DARTSCell(nn.Module):
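    """A DARTS recurrent cell.

    The cell structure is given by ``genotype.recurrent`` (a list of
    ``(activation_name, predecessor_index)`` edges) and ``genotype.concat``
    (the node indices averaged to form the cell output). When ``genotype``
    is ``None`` the cell is used for architecture search and allocates
    ``STEPS`` intermediate nodes instead.
    """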

    def __init__(self, ninp, nhid, dropouth, dropoutx, genotype):
        super(DARTSCell, self).__init__()
        self.nhid = nhid
        self.dropouth = dropouth
        self.dropoutx = dropoutx
        self.genotype = genotype

        # genotype is None when doing arch search
        steps = len(self.genotype.recurrent) if self.genotype is not None else STEPS
        self._W0 = nn.Parameter(torch.Tensor(ninp+nhid, 2*nhid).uniform_(-INITRANGE, INITRANGE))
        self._Ws = nn.ParameterList([
            nn.Parameter(torch.Tensor(nhid, 2*nhid).uniform_(-INITRANGE, INITRANGE)) for i in range(steps)
        ])

    def forward(self, inputs, hidden, arch_probs):
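        """Run the cell over a ``(T, B, ninp)`` input sequence.

        ``hidden`` is the ``(1, B, nhid)`` initial state. Dropout masks are
        sampled once per batch (training mode only) and reused at every time
        step. Returns the stacked per-step states ``(T, B, nhid)`` and the
        final state ``(1, B, nhid)``.
        """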
        T, B = inputs.size(0), inputs.size(1)

        if self.training:
            x_mask = mask2d(B, inputs.size(2), keep_prob=1.-self.dropoutx)
            h_mask = mask2d(B, hidden.size(2), keep_prob=1.-self.dropouth)
        else:
            x_mask = h_mask = None

        hidden = hidden[0]
        hiddens = []
        for t in range(T):
            hidden = self.cell(inputs[t], hidden, x_mask, h_mask, arch_probs)
            hiddens.append(hidden)
        hiddens = torch.stack(hiddens)
        return hiddens, hiddens[-1].unsqueeze(0)

    def _compute_init_state(self, x, h_prev, x_mask, h_mask):
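        """Compute the first node s0 with a highway-style update: a sigmoid
        gate c0 interpolates between h_prev and a tanh candidate h0, both
        produced from the (optionally masked) concatenation [x, h_prev] @ W0.
        """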
        if self.training:
            xh_prev = torch.cat([x * x_mask, h_prev * h_mask], dim=-1)
        else:
            xh_prev = torch.cat([x, h_prev], dim=-1)
        c0, h0 = torch.split(xh_prev.mm(self._W0), self.nhid, dim=-1)
        c0 = c0.sigmoid()
        h0 = h0.tanh()
        s0 = h_prev + c0 * (h0-h_prev)
        return s0

    def _get_activation(self, name):
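        """Map an activation name from the genotype to its function."""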
        if name == 'tanh':
            f = torch.tanh
        elif name == 'relu':
            f = torch.relu
        elif name == 'sigmoid':
            f = torch.sigmoid
        elif name == 'identity':
            f = lambda x: x
        elif name == 'none':
            f = none_func
        else:
            raise NotImplementedError
        return f

    def cell(self, x, h_prev, x_mask, h_mask, _):
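        """One time step of the derived (fixed-genotype) cell; the trailing
        ``_`` accepts the ``arch_probs`` argument passed by ``forward`` and is
        ignored here. Each node applies a gated update of its predecessor
        using the edge's activation, and the output averages the ``concat``
        nodes.
        """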
        s0 = self._compute_init_state(x, h_prev, x_mask, h_mask)

        states = [s0]
        for i, (name, pred) in enumerate(self.genotype.recurrent):
            s_prev = states[pred]
            if self.training:
                ch = (s_prev * h_mask).mm(self._Ws[i])
            else:
                ch = s_prev.mm(self._Ws[i])
            c, h = torch.split(ch, self.nhid, dim=-1)
            c = c.sigmoid()
            fn = self._get_activation(name)
            h = fn(h)
            s = s_prev + c * (h-s_prev)
            states += [s]
        output = torch.mean(torch.stack([states[i] for i in self.genotype.concat], -1), -1)
        return output


class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, ntoken, ninp, nhid, nhidlast,
                 dropout=0.5, dropouth=0.5, dropoutx=0.5, dropouti=0.5, dropoute=0.1,
                 cell_cls=None, genotype=None):
        super(RNNModel, self).__init__()
        self.lockdrop = LockedDropout()
        self.encoder = nn.Embedding(ntoken, ninp)

        assert ninp == nhid == nhidlast
        if cell_cls == DARTSCell:
            assert genotype is not None
            rnns = [cell_cls(ninp, nhid, dropouth, dropoutx, genotype)]
        else:
            assert genotype is None
            rnns = [cell_cls(ninp, nhid, dropouth, dropoutx)]

        self.rnns = torch.nn.ModuleList(rnns)
        self.decoder = nn.Linear(ninp, ntoken)
        self.decoder.weight = self.encoder.weight
        self.init_weights()
        self.arch_weights = None

        self.ninp = ninp
        self.nhid = nhid
        self.nhidlast = nhidlast
        self.dropout = dropout
        self.dropouti = dropouti
        self.dropoute = dropoute
        self.ntoken = ntoken
        self.cell_cls = cell_cls
        # acceleration
        self.tau = None
        self.use_gumbel = False

    def set_gumbel(self, use_gumbel, set_check):
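        """Enable or disable Gumbel-softmax sampling of the architecture
        weights and pass ``set_check`` through to each cell.
        """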
        self.use_gumbel = use_gumbel
        for i, rnn in enumerate(self.rnns):
            rnn.set_check(set_check)

    def set_tau(self, tau):
        self.tau = tau

    def get_tau(self):
        return self.tau

    def init_weights(self):
        self.encoder.weight.data.uniform_(-INITRANGE, INITRANGE)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-INITRANGE, INITRANGE)

    def forward(self, input, hidden, return_h=False):
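        """Embed ``input`` (shape ``(T, B)``), run the recurrent cells, and decode.

        If architecture weights are set, they are turned into edge
        probabilities with a Gumbel-softmax (when enabled) or a plain softmax
        before being passed to each cell. Returns log-probabilities of shape
        ``(T, B, ntoken)`` and the new hidden states; with ``return_h=True``
        the raw and dropout-applied layer outputs are returned as well.
        """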
        batch_size = input.size(1)

        emb = embedded_dropout(self.encoder, input, dropout=self.dropoute if self.training else 0)
        emb = self.lockdrop(emb, self.dropouti)

        raw_output = emb
        new_hidden = []
        raw_outputs = []
        outputs = []
        if self.arch_weights is None:
            arch_probs = None
        else:
            if self.use_gumbel:
                arch_probs = F.gumbel_softmax(self.arch_weights, self.tau, False)
            else:
                arch_probs = F.softmax(self.arch_weights, dim=-1)

        for l, rnn in enumerate(self.rnns):
            current_input = raw_output
            raw_output, new_h = rnn(raw_output, hidden[l], arch_probs)
            new_hidden.append(new_h)
            raw_outputs.append(raw_output)
        hidden = new_hidden

        output = self.lockdrop(raw_output, self.dropout)
        outputs.append(output)

        logit = self.decoder(output.view(-1, self.ninp))
        log_prob = nn.functional.log_softmax(logit, dim=-1)
        model_output = log_prob
        model_output = model_output.view(-1, batch_size, self.ntoken)

        if return_h:
            return model_output, hidden, raw_outputs, outputs
        else:
            return model_output, hidden

    def init_hidden(self, bsz):
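        """Return the initial hidden state: a list with one zero tensor of
        shape ``(1, bsz, nhid)`` on the same device/dtype as the parameters.
        """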
        weight = next(self.parameters()).clone()
        return [weight.new(1, bsz, self.nhid).zero_()]