fixed problems with variational dropout

This commit is contained in:
Zach Teed
2020-05-25 14:30:45 -04:00
parent dd91321527
commit 3fac6470f4
5 changed files with 22 additions and 8 deletions

View File

@@ -133,8 +133,20 @@ class SmallUpdateBlock(nn.Module):
self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
self.drop_inp = VariationalHidDropout(dropout=args.dropout)
self.drop_net = VariationalHidDropout(dropout=args.dropout)
def reset_mask(self, net, inp):
self.drop_inp.reset_mask(inp)
self.drop_net.reset_mask(net)
def forward(self, net, inp, corr, flow):
motion_features = self.encoder(flow, corr)
if self.training:
net = self.drop_net(net)
inp = self.drop_inp(inp)
inp = torch.cat([inp, motion_features], dim=1)
net = self.gru(net, inp)
delta_flow = self.flow_head(net)
@@ -157,12 +169,12 @@ class BasicUpdateBlock(nn.Module):
def forward(self, net, inp, corr, flow):
motion_features = self.encoder(flow, corr)
inp = torch.cat([inp, motion_features], dim=1)
if self.training:
net = self.drop_net(net)
inp = self.drop_inp(inp)
inp = torch.cat([inp, motion_features], dim=1)
net = self.gru(net, inp)
delta_flow = self.flow_head(net)