fixed problems with variational dropout
This commit is contained in:
@@ -133,8 +133,20 @@ class SmallUpdateBlock(nn.Module):
|
||||
self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
|
||||
self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
|
||||
|
||||
self.drop_inp = VariationalHidDropout(dropout=args.dropout)
|
||||
self.drop_net = VariationalHidDropout(dropout=args.dropout)
|
||||
|
||||
def reset_mask(self, net, inp):
|
||||
self.drop_inp.reset_mask(inp)
|
||||
self.drop_net.reset_mask(net)
|
||||
|
||||
def forward(self, net, inp, corr, flow):
|
||||
motion_features = self.encoder(flow, corr)
|
||||
|
||||
if self.training:
|
||||
net = self.drop_net(net)
|
||||
inp = self.drop_inp(inp)
|
||||
|
||||
inp = torch.cat([inp, motion_features], dim=1)
|
||||
net = self.gru(net, inp)
|
||||
delta_flow = self.flow_head(net)
|
||||
@@ -157,12 +169,12 @@ class BasicUpdateBlock(nn.Module):
|
||||
|
||||
def forward(self, net, inp, corr, flow):
|
||||
motion_features = self.encoder(flow, corr)
|
||||
inp = torch.cat([inp, motion_features], dim=1)
|
||||
|
||||
if self.training:
|
||||
net = self.drop_net(net)
|
||||
inp = self.drop_inp(inp)
|
||||
|
||||
|
||||
inp = torch.cat([inp, motion_features], dim=1)
|
||||
net = self.gru(net, inp)
|
||||
delta_flow = self.flow_head(net)
|
||||
|
||||
|
@@ -26,7 +26,7 @@ class RAFT(nn.Module):
|
||||
args.corr_levels = 4
|
||||
args.corr_radius = 4
|
||||
|
||||
if 'dropout' not in args._get_kwargs():
|
||||
if not hasattr(args, 'dropout'):
|
||||
args.dropout = 0
|
||||
|
||||
# feature network, context network, and update block
|
||||
|
Reference in New Issue
Block a user