added cuda extension for efficent implementation
This commit is contained in:
10
core/raft.py
10
core/raft.py
@@ -5,7 +5,7 @@ import torch.nn.functional as F
|
||||
|
||||
from update import BasicUpdateBlock, SmallUpdateBlock
|
||||
from extractor import BasicEncoder, SmallEncoder
|
||||
from corr import CorrBlock
|
||||
from corr import CorrBlock, AlternateCorrBlock
|
||||
from utils.utils import bilinear_sampler, coords_grid, upflow8
|
||||
|
||||
try:
|
||||
@@ -41,6 +41,9 @@ class RAFT(nn.Module):
|
||||
if 'dropout' not in args._get_kwargs():
|
||||
args.dropout = 0
|
||||
|
||||
if 'alternate_corr' not in args._get_kwargs():
|
||||
args.alternate_corr = False
|
||||
|
||||
# feature network, context network, and update block
|
||||
if args.small:
|
||||
self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
|
||||
@@ -99,7 +102,10 @@ class RAFT(nn.Module):
|
||||
|
||||
fmap1 = fmap1.float()
|
||||
fmap2 = fmap2.float()
|
||||
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
|
||||
if self.args.alternate_corr:
|
||||
corr_fn = CorrBlockAlternate(fmap1, fmap2, radius=self.args.corr_radius)
|
||||
else:
|
||||
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
|
||||
|
||||
# run the context network
|
||||
with autocast(enabled=self.args.mixed_precision):
|
||||
|
||||
Reference in New Issue
Block a user