fixed bug with alternate_corr flag
This commit is contained in:
11
core/raft.py
11
core/raft.py
@@ -38,11 +38,11 @@ class RAFT(nn.Module):
|
||||
args.corr_levels = 4
|
||||
args.corr_radius = 4
|
||||
|
||||
if 'dropout' not in args._get_kwargs():
|
||||
args.dropout = 0
|
||||
if 'dropout' not in self.args:
|
||||
self.args.dropout = 0
|
||||
|
||||
if 'alternate_corr' not in args._get_kwargs():
|
||||
args.alternate_corr = False
|
||||
if 'alternate_corr' not in self.args:
|
||||
self.args.alternate_corr = False
|
||||
|
||||
# feature network, context network, and update block
|
||||
if args.small:
|
||||
@@ -55,7 +55,6 @@ class RAFT(nn.Module):
|
||||
self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
|
||||
self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
|
||||
|
||||
|
||||
def freeze_bn(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.BatchNorm2d):
|
||||
@@ -103,7 +102,7 @@ class RAFT(nn.Module):
|
||||
fmap1 = fmap1.float()
|
||||
fmap2 = fmap2.float()
|
||||
if self.args.alternate_corr:
|
||||
corr_fn = CorrBlockAlternate(fmap1, fmap2, radius=self.args.corr_radius)
|
||||
corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
|
||||
else:
|
||||
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user