fixed bug with alternate_corr flag

This commit is contained in:
Zach Teed 2020-08-28 17:18:41 -06:00
parent 01ad964d94
commit 13198c355d
2 changed files with 10 additions and 31 deletions

View File

@ -60,26 +60,6 @@ class CorrBlock:
return corr / torch.sqrt(torch.tensor(dim).float())
class CorrLayer(torch.autograd.Function):
@staticmethod
def forward(ctx, fmap1, fmap2, coords, r):
fmap1 = fmap1.contiguous()
fmap2 = fmap2.contiguous()
coords = coords.contiguous()
ctx.save_for_backward(fmap1, fmap2, coords)
ctx.r = r
corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r)
return corr
@staticmethod
def backward(ctx, grad_corr):
fmap1, fmap2, coords = ctx.saved_tensors
grad_corr = grad_corr.contiguous()
fmap1_grad, fmap2_grad, coords_grad = \
correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r)
return fmap1_grad, fmap2_grad, coords_grad, None
class AlternateCorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
self.num_levels = num_levels
@ -92,20 +72,20 @@ class AlternateCorrBlock:
self.pyramid.append((fmap1, fmap2))
def __call__(self, coords):
coords = coords.permute(0, 2, 3, 1)
B, H, W, _ = coords.shape
dim = self.pyramid[0][0].shape[1]
corr_list = []
for i in range(self.num_levels):
r = self.radius
fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1)
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1)
fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r)
corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
corr_list.append(corr.squeeze(1))
corr = torch.stack(corr_list, dim=1)
corr = corr.reshape(B, -1, H, W)
return corr / 16.0
return corr / torch.sqrt(torch.tensor(dim).float())

View File

@ -38,11 +38,11 @@ class RAFT(nn.Module):
args.corr_levels = 4
args.corr_radius = 4
if 'dropout' not in args._get_kwargs():
args.dropout = 0
if 'dropout' not in self.args:
self.args.dropout = 0
if 'alternate_corr' not in args._get_kwargs():
args.alternate_corr = False
if 'alternate_corr' not in self.args:
self.args.alternate_corr = False
# feature network, context network, and update block
if args.small:
@ -55,7 +55,6 @@ class RAFT(nn.Module):
self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
def freeze_bn(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
@ -103,7 +102,7 @@ class RAFT(nn.Module):
fmap1 = fmap1.float()
fmap2 = fmap2.float()
if self.args.alternate_corr:
corr_fn = CorrBlockAlternate(fmap1, fmap2, radius=self.args.corr_radius)
corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
else:
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)