From 13198c355d11c3a0c45f09d1f15ead4b81a5043f Mon Sep 17 00:00:00 2001
From: Zach Teed
Date: Fri, 28 Aug 2020 17:18:41 -0600
Subject: [PATCH] fixed bug with alternate_corr flag

---
 core/corr.py | 30 +++++-------------------------
 core/raft.py | 11 +++++------
 2 files changed, 10 insertions(+), 31 deletions(-)

diff --git a/core/corr.py b/core/corr.py
index 632d7f7..645ca55 100644
--- a/core/corr.py
+++ b/core/corr.py
@@ -60,26 +60,6 @@ class CorrBlock:
         return corr / torch.sqrt(torch.tensor(dim).float())
 
 
-class CorrLayer(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, fmap1, fmap2, coords, r):
-        fmap1 = fmap1.contiguous()
-        fmap2 = fmap2.contiguous()
-        coords = coords.contiguous()
-        ctx.save_for_backward(fmap1, fmap2, coords)
-        ctx.r = r
-        corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r)
-        return corr
-
-    @staticmethod
-    def backward(ctx, grad_corr):
-        fmap1, fmap2, coords = ctx.saved_tensors
-        grad_corr = grad_corr.contiguous()
-        fmap1_grad, fmap2_grad, coords_grad = \
-            correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r)
-        return fmap1_grad, fmap2_grad, coords_grad, None
-
-
 class AlternateCorrBlock:
     def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
         self.num_levels = num_levels
@@ -92,20 +72,20 @@ class AlternateCorrBlock:
             self.pyramid.append((fmap1, fmap2))
 
     def __call__(self, coords):
-        coords = coords.permute(0, 2, 3, 1)
         B, H, W, _ = coords.shape
+        dim = self.pyramid[0][0].shape[1]
 
         corr_list = []
         for i in range(self.num_levels):
             r = self.radius
-            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1)
-            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1)
+            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
+            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
 
             coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
-            corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r)
+            corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
             corr_list.append(corr.squeeze(1))
 
         corr = torch.stack(corr_list, dim=1)
         corr = corr.reshape(B, -1, H, W)
-        return corr / 16.0
+        return corr / torch.sqrt(torch.tensor(dim).float())
 
diff --git a/core/raft.py b/core/raft.py
index ce5297e..e0519ed 100644
--- a/core/raft.py
+++ b/core/raft.py
@@ -38,11 +38,11 @@ class RAFT(nn.Module):
             args.corr_levels = 4
             args.corr_radius = 4
 
-        if 'dropout' not in args._get_kwargs():
-            args.dropout = 0
+        if 'dropout' not in self.args:
+            self.args.dropout = 0
 
-        if 'alternate_corr' not in args._get_kwargs():
-            args.alternate_corr = False
+        if 'alternate_corr' not in self.args:
+            self.args.alternate_corr = False
 
         # feature network, context network, and update block
         if args.small:
@@ -55,7 +55,6 @@ class RAFT(nn.Module):
             self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
             self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
 
-
     def freeze_bn(self):
         for m in self.modules():
             if isinstance(m, nn.BatchNorm2d):
@@ -103,7 +102,7 @@ class RAFT(nn.Module):
         fmap1 = fmap1.float()
         fmap2 = fmap2.float()
         if self.args.alternate_corr:
-            corr_fn = CorrBlockAlternate(fmap1, fmap2, radius=self.args.corr_radius)
+            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
         else:
             corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
 
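
Note on the corr.py call site: the old line treated alt_cuda_corr, a compiled
extension module, as if it were callable; modules are not, so the
alternate_corr path raised a TypeError before this fix. The patched line calls
the module's forward() binding and unpacks its single-tensor result (hence the
`corr, =`). Also, with the internal permute removed, __call__ now unpacks
`B, H, W, _ = coords.shape`, i.e. it expects channel-last (x, y) coordinates
from the caller.

For readers without the compiled extension, here is a plain-PyTorch sketch
(mine, not the CUDA kernel) of the quantity one pyramid level computes:
the correlation of each fmap1 feature with fmap2 sampled on a
(2r+1) x (2r+1) window around coords. It assumes the channel-last coords
layout described above.

    import torch
    import torch.nn.functional as F

    def local_corr_reference(fmap1, fmap2, coords, r):
        # fmap1, fmap2: (B, C, H, W); coords: (B, H, W, 2) as (x, y) pixels
        B, C, H, W = fmap1.shape
        d = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
        dy, dx = torch.meshgrid(d, d, indexing='ij')
        delta = torch.stack([dx, dy], dim=-1).view(-1, 2)    # ((2r+1)^2, 2)

        # absolute sample points per pixel, normalized to [-1, 1] for grid_sample
        pts = coords.view(B, H * W, 1, 2) + delta.view(1, 1, -1, 2)
        scale = torch.tensor([W - 1, H - 1], dtype=pts.dtype, device=pts.device)
        grid = 2 * pts / scale - 1

        fmap2_w = F.grid_sample(fmap2, grid, align_corners=True)  # (B, C, H*W, (2r+1)^2)
        corr = (fmap1.view(B, C, H * W, 1) * fmap2_w).sum(dim=1)  # dot over channels
        return corr.view(B, H, W, 2 * r + 1, 2 * r + 1)

Like the kernel, this returns raw dot products; the sqrt(dim) division from
the patch is applied on top of it.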
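
Note on the new return value: CorrBlock already scales its volume by
1/sqrt(dim), and the patch makes AlternateCorrBlock match instead of
hard-coding 16.0. That constant was right only for the 256-dim BasicEncoder
features (sqrt(256) = 16); the small model's 128-dim features need a
different divisor. A quick check:

    import torch

    for dim in (256, 128):
        print(dim, torch.sqrt(torch.tensor(dim).float()).item())
    # 256 16.0               -> the old constant, valid only here
    # 128 11.313708305358887 -> where corr / 16.0 was the wrong scale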
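
Note on the raft.py guards: argparse's Namespace._get_kwargs() returns a list
of (name, value) tuples, so testing a bare string against it is always False
and the old code unconditionally reset dropout and alternate_corr, clobbering
any user-supplied values. Namespace defines __contains__, so the new
`'dropout' not in self.args` form checks attribute names as intended. A
minimal sketch:

    from argparse import Namespace

    args = Namespace(dropout=0.5)
    print('dropout' in args._get_kwargs())  # False: the list holds tuples
    print('dropout' in args)                # True: Namespace.__contains__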