fixed bug with alternate_corr flag
This commit is contained in:
parent
01ad964d94
commit
13198c355d
30
core/corr.py
30
core/corr.py
@ -60,26 +60,6 @@ class CorrBlock:
|
||||
return corr / torch.sqrt(torch.tensor(dim).float())
|
||||
|
||||
|
||||
class CorrLayer(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, fmap1, fmap2, coords, r):
|
||||
fmap1 = fmap1.contiguous()
|
||||
fmap2 = fmap2.contiguous()
|
||||
coords = coords.contiguous()
|
||||
ctx.save_for_backward(fmap1, fmap2, coords)
|
||||
ctx.r = r
|
||||
corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r)
|
||||
return corr
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_corr):
|
||||
fmap1, fmap2, coords = ctx.saved_tensors
|
||||
grad_corr = grad_corr.contiguous()
|
||||
fmap1_grad, fmap2_grad, coords_grad = \
|
||||
correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r)
|
||||
return fmap1_grad, fmap2_grad, coords_grad, None
|
||||
|
||||
|
||||
class AlternateCorrBlock:
|
||||
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
|
||||
self.num_levels = num_levels
|
||||
@ -92,20 +72,20 @@ class AlternateCorrBlock:
|
||||
self.pyramid.append((fmap1, fmap2))
|
||||
|
||||
def __call__(self, coords):
|
||||
|
||||
coords = coords.permute(0, 2, 3, 1)
|
||||
B, H, W, _ = coords.shape
|
||||
dim = self.pyramid[0][0].shape[1]
|
||||
|
||||
corr_list = []
|
||||
for i in range(self.num_levels):
|
||||
r = self.radius
|
||||
fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1)
|
||||
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1)
|
||||
fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
|
||||
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
|
||||
|
||||
coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
|
||||
corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r)
|
||||
corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
|
||||
corr_list.append(corr.squeeze(1))
|
||||
|
||||
corr = torch.stack(corr_list, dim=1)
|
||||
corr = corr.reshape(B, -1, H, W)
|
||||
return corr / 16.0
|
||||
return corr / torch.sqrt(torch.tensor(dim).float())
|
||||
|
11
core/raft.py
11
core/raft.py
@ -38,11 +38,11 @@ class RAFT(nn.Module):
|
||||
args.corr_levels = 4
|
||||
args.corr_radius = 4
|
||||
|
||||
if 'dropout' not in args._get_kwargs():
|
||||
args.dropout = 0
|
||||
if 'dropout' not in self.args:
|
||||
self.args.dropout = 0
|
||||
|
||||
if 'alternate_corr' not in args._get_kwargs():
|
||||
args.alternate_corr = False
|
||||
if 'alternate_corr' not in self.args:
|
||||
self.args.alternate_corr = False
|
||||
|
||||
# feature network, context network, and update block
|
||||
if args.small:
|
||||
@ -55,7 +55,6 @@ class RAFT(nn.Module):
|
||||
self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
|
||||
self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
|
||||
|
||||
|
||||
def freeze_bn(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.BatchNorm2d):
|
||||
@ -103,7 +102,7 @@ class RAFT(nn.Module):
|
||||
fmap1 = fmap1.float()
|
||||
fmap2 = fmap2.float()
|
||||
if self.args.alternate_corr:
|
||||
corr_fn = CorrBlockAlternate(fmap1, fmap2, radius=self.args.corr_radius)
|
||||
corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
|
||||
else:
|
||||
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user