added cuda extension for efficent implementation

This commit is contained in:
Zach Teed
2020-08-22 18:49:24 -06:00
parent 5b1f510d6b
commit c86b3dc8f3
13 changed files with 519 additions and 191 deletions

View File

@@ -5,7 +5,7 @@ import torch.nn.functional as F
from update import BasicUpdateBlock, SmallUpdateBlock
from extractor import BasicEncoder, SmallEncoder
from corr import CorrBlock
from corr import CorrBlock, AlternateCorrBlock
from utils.utils import bilinear_sampler, coords_grid, upflow8
try:
@@ -41,6 +41,9 @@ class RAFT(nn.Module):
if 'dropout' not in args._get_kwargs():
args.dropout = 0
if 'alternate_corr' not in args._get_kwargs():
args.alternate_corr = False
# feature network, context network, and update block
if args.small:
self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
@@ -99,7 +102,10 @@ class RAFT(nn.Module):
fmap1 = fmap1.float()
fmap2 = fmap2.float()
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
if self.args.alternate_corr:
corr_fn = CorrBlockAlternate(fmap1, fmap2, radius=self.args.corr_radius)
else:
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
# run the context network
with autocast(enabled=self.args.mixed_precision):