fixed bug with alternate_corr flag
parent 01ad964d94
commit 13198c355d

core/corr.py: 30 lines changed
core/raft.py: 11 lines changed

--- a/core/corr.py
+++ b/core/corr.py
@@ -60,26 +60,6 @@ class CorrBlock:
         return corr / torch.sqrt(torch.tensor(dim).float())
-
-
-class CorrLayer(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, fmap1, fmap2, coords, r):
-        fmap1 = fmap1.contiguous()
-        fmap2 = fmap2.contiguous()
-        coords = coords.contiguous()
-        ctx.save_for_backward(fmap1, fmap2, coords)
-        ctx.r = r
-        corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r)
-        return corr
-
-    @staticmethod
-    def backward(ctx, grad_corr):
-        fmap1, fmap2, coords = ctx.saved_tensors
-        grad_corr = grad_corr.contiguous()
-        fmap1_grad, fmap2_grad, coords_grad = \
-            correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r)
-        return fmap1_grad, fmap2_grad, coords_grad, None


 class AlternateCorrBlock:
     def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
         self.num_levels = num_levels
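Side note on the deletion: CorrLayer was a torch.autograd.Function wrapper around a compiled extension (correlation_cudaz). Such wrappers are invoked through .apply() so their custom backward is registered with autograd; calling an extension's forward directly, as AlternateCorrBlock does below, yields outputs with no grad_fn. A minimal sketch of the pattern, using a stand-in pure-PyTorch "kernel" so it runs without any compiled extension:

    import torch

    class ScaledDot(torch.autograd.Function):
        # Stand-in for a compiled correlation kernel: just a scaled dot product,
        # so this example runs without alt_cuda_corr or any CUDA build.
        @staticmethod
        def forward(ctx, a, b, scale):
            ctx.save_for_backward(a, b)
            ctx.scale = scale
            return (a * b).sum(dim=-1) * scale

        @staticmethod
        def backward(ctx, grad_out):
            a, b = ctx.saved_tensors
            g = (grad_out * ctx.scale).unsqueeze(-1)
            return g * b, g * a, None   # one gradient per forward input

    a = torch.randn(4, 8, requires_grad=True)
    b = torch.randn(4, 8, requires_grad=True)
    out = ScaledDot.apply(a, b, 0.125)   # invoked via .apply, never a direct call
    out.sum().backward()                 # gradients come from backward() above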
@@ -92,20 +72,20 @@ class AlternateCorrBlock:
             self.pyramid.append((fmap1, fmap2))

     def __call__(self, coords):
         coords = coords.permute(0, 2, 3, 1)
         B, H, W, _ = coords.shape
-
+        dim = self.pyramid[0][0].shape[1]

         corr_list = []
         for i in range(self.num_levels):
             r = self.radius
-            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1)
-            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1)
+            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
+            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()

             coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
-            corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r)
+            corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
             corr_list.append(corr.squeeze(1))

         corr = torch.stack(corr_list, dim=1)
         corr = corr.reshape(B, -1, H, W)
-        return corr / 16.0
+        return corr / torch.sqrt(torch.tensor(dim).float())
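Two things changed in this hunk besides the call site. First, permute() returns a non-contiguous view, and a raw CUDA kernel that computes flat offsets from the tensor's shape assumes contiguous memory, so .contiguous() must force a real copy before the call. Second, the hard-coded / 16.0 was only correct for 256-dim features (sqrt(256) = 16); dividing by sqrt(dim) matches the normalization CorrBlock already uses above. A quick self-contained check of both points (shapes are illustrative, not taken from the repo):

    import math
    import torch

    fmap = torch.randn(1, 256, 46, 62)        # (B, dim, H, W); dim = 256 is an assumption
    view = fmap.permute(0, 2, 3, 1)           # same storage, strides reordered
    assert not view.is_contiguous()           # a pointer-arithmetic kernel would misread this
    assert view.contiguous().is_contiguous()  # .contiguous() materializes the copy

    dim = fmap.shape[1]
    assert math.sqrt(dim) == 16.0             # so / 16.0 silently assumed dim == 256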
--- a/core/raft.py
+++ b/core/raft.py
@@ -38,11 +38,11 @@ class RAFT(nn.Module):
             args.corr_levels = 4
             args.corr_radius = 4

-        if 'dropout' not in args._get_kwargs():
-            args.dropout = 0
+        if 'dropout' not in self.args:
+            self.args.dropout = 0

-        if 'alternate_corr' not in args._get_kwargs():
-            args.alternate_corr = False
+        if 'alternate_corr' not in self.args:
+            self.args.alternate_corr = False

         # feature network, context network, and update block
         if args.small:
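This guard was the bug named in the commit title. Namespace._get_kwargs() returns a list of (name, value) tuples, so a bare attribute name is never an element of it; the condition was always true, and alternate_corr (and dropout) were unconditionally overwritten. Testing membership on the Namespace itself goes through Namespace.__contains__, which checks attribute names. A short demonstration:

    from argparse import Namespace

    args = Namespace(alternate_corr=True, dropout=0.1)

    print(args._get_kwargs())                      # [('alternate_corr', True), ('dropout', 0.1)]
    print('alternate_corr' in args._get_kwargs())  # False: elements are tuples, not names
    print('alternate_corr' in args)                # True: __contains__ checks attribute names

    # The old guard therefore fired even when the flag was set:
    if 'alternate_corr' not in args._get_kwargs():
        args.alternate_corr = False
    print(args.alternate_corr)                     # False -- the flag was silently clobbered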
@@ -55,7 +55,6 @@ class RAFT(nn.Module):
             self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
             self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

-
     def freeze_bn(self):
         for m in self.modules():
             if isinstance(m, nn.BatchNorm2d):
@@ -103,7 +102,7 @@ class RAFT(nn.Module):
         fmap1 = fmap1.float()
         fmap2 = fmap2.float()
         if self.args.alternate_corr:
-            corr_fn = CorrBlockAlternate(fmap1, fmap2, radius=self.args.corr_radius)
+            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
         else:
             corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

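The other half of the fix: CorrBlockAlternate was never defined (the class in core/corr.py is AlternateCorrBlock), so once the flag survived initialization, this branch raised a NameError. With both fixes the flag is usable end to end. A hedged usage sketch; the flag names follow the repo's scripts, but the argparse wiring below is illustrative rather than copied from them:

    import argparse
    from raft import RAFT   # core/raft.py, with core/ on PYTHONPATH

    parser = argparse.ArgumentParser()
    parser.add_argument('--small', action='store_true')
    parser.add_argument('--mixed_precision', action='store_true')
    parser.add_argument('--alternate_corr', action='store_true',
                        help='memory-efficient correlation (needs alt_cuda_corr built)')
    args = parser.parse_args(['--alternate_corr'])

    model = RAFT(args)      # forward() now constructs AlternateCorrBlock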