fixed bug with alternate_corr flag

This commit is contained in:
Zach Teed
2020-08-28 17:18:41 -06:00
parent 01ad964d94
commit 13198c355d
2 changed files with 10 additions and 31 deletions

View File

@@ -38,11 +38,11 @@ class RAFT(nn.Module):
args.corr_levels = 4
args.corr_radius = 4
if 'dropout' not in args._get_kwargs():
args.dropout = 0
if 'dropout' not in self.args:
self.args.dropout = 0
if 'alternate_corr' not in args._get_kwargs():
args.alternate_corr = False
if 'alternate_corr' not in self.args:
self.args.alternate_corr = False
# feature network, context network, and update block
if args.small:
@@ -55,7 +55,6 @@ class RAFT(nn.Module):
self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
def freeze_bn(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
@@ -103,7 +102,7 @@ class RAFT(nn.Module):
fmap1 = fmap1.float()
fmap2 = fmap2.float()
if self.args.alternate_corr:
corr_fn = CorrBlockAlternate(fmap1, fmap2, radius=self.args.corr_radius)
corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
else:
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)