Compare commits

..

No commits in common. "3fa0bb0a9c633ea0a9bb8a79c576b6785d4e6a02" and "c86b3dc8f360345827d4c25ff919228f26b6be61" have entirely different histories.

9 changed files with 88 additions and 49 deletions

View File

@ -12,7 +12,11 @@ The code has been tested with PyTorch 1.6 and Cuda 10.1.
```Shell
conda create --name raft
conda activate raft
conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch
conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 -c pytorch
conda install matplotlib
conda install tensorboard
conda install scipy
conda install opencv
```
## Demos
@ -20,13 +24,21 @@ Pretrained models can be downloaded by running
```Shell
./download_models.sh
```
or downloaded from [google drive](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT?usp=sharing)
or downloaded from [google drive](https://drive.google.com/file/d/10-BYgHqRNPGvmNUWr8razjb1xHu55pyA/view?usp=sharing)
You can demo a trained model on a sequence of frames
```Shell
python demo.py --model=models/raft-things.pth --path=demo-frames
```
## (Optional) Efficent Implementation
You can optionally use our alternate (efficent) implementation by compiling the provided cuda extension
```Shell
cd alt_cuda_corr && python setup.py install && cd ..
```
and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag.Note, this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.
## Required Data
To evaluate/train RAFT, you will need to download the required datasets.
* [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs)
@ -71,10 +83,3 @@ If you have a RTX GPU, training can be accelerated using mixed precision. You ca
```Shell
./train_mixed.sh
```
## (Optional) Efficent Implementation
You can optionally use our alternate (efficent) implementation by compiling the provided cuda extension
```Shell
cd alt_cuda_corr && python setup.py install && cd ..
```
and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag Note, this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.

View File

@ -34,9 +34,9 @@ class CorrBlock:
out_pyramid = []
for i in range(self.num_levels):
corr = self.corr_pyramid[i]
dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
dx = torch.linspace(-r, r, 2*r+1)
dy = torch.linspace(-r, r, 2*r+1)
delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
@ -60,6 +60,26 @@ class CorrBlock:
return corr / torch.sqrt(torch.tensor(dim).float())
class CorrLayer(torch.autograd.Function):
@staticmethod
def forward(ctx, fmap1, fmap2, coords, r):
fmap1 = fmap1.contiguous()
fmap2 = fmap2.contiguous()
coords = coords.contiguous()
ctx.save_for_backward(fmap1, fmap2, coords)
ctx.r = r
corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r)
return corr
@staticmethod
def backward(ctx, grad_corr):
fmap1, fmap2, coords = ctx.saved_tensors
grad_corr = grad_corr.contiguous()
fmap1_grad, fmap2_grad, coords_grad = \
correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r)
return fmap1_grad, fmap2_grad, coords_grad, None
class AlternateCorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
self.num_levels = num_levels
@ -72,20 +92,20 @@ class AlternateCorrBlock:
self.pyramid.append((fmap1, fmap2))
def __call__(self, coords):
coords = coords.permute(0, 2, 3, 1)
B, H, W, _ = coords.shape
dim = self.pyramid[0][0].shape[1]
corr_list = []
for i in range(self.num_levels):
r = self.radius
fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1)
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1)
coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r)
corr_list.append(corr.squeeze(1))
corr = torch.stack(corr_list, dim=1)
corr = corr.reshape(B, -1, H, W)
return corr / torch.sqrt(torch.tensor(dim).float())
return corr / 16.0

View File

@ -38,11 +38,11 @@ class RAFT(nn.Module):
args.corr_levels = 4
args.corr_radius = 4
if 'dropout' not in self.args:
self.args.dropout = 0
if 'dropout' not in args._get_kwargs():
args.dropout = 0
if 'alternate_corr' not in self.args:
self.args.alternate_corr = False
if 'alternate_corr' not in args._get_kwargs():
args.alternate_corr = False
# feature network, context network, and update block
if args.small:
@ -55,6 +55,7 @@ class RAFT(nn.Module):
self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
def freeze_bn(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
@ -63,8 +64,8 @@ class RAFT(nn.Module):
def initialize_flow(self, img):
""" Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
N, C, H, W = img.shape
coords0 = coords_grid(N, H//8, W//8, device=img.device)
coords1 = coords_grid(N, H//8, W//8, device=img.device)
coords0 = coords_grid(N, H//8, W//8).to(img.device)
coords1 = coords_grid(N, H//8, W//8).to(img.device)
# optical flow computed as difference: flow = coords1 - coords0
return coords0, coords1
@ -102,7 +103,7 @@ class RAFT(nn.Module):
fmap1 = fmap1.float()
fmap2 = fmap2.float()
if self.args.alternate_corr:
corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
corr_fn = CorrBlockAlternate(fmap1, fmap2, radius=self.args.corr_radius)
else:
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

View File

@ -71,8 +71,8 @@ def bilinear_sampler(img, coords, mode='bilinear', mask=False):
return img
def coords_grid(batch, ht, wd, device):
coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
def coords_grid(batch, ht, wd):
coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
coords = torch.stack(coords[::-1], dim=0).float()
return coords[None].repeat(batch, 1, 1, 1)

29
demo.py
View File

@ -20,9 +20,21 @@ DEVICE = 'cuda'
def load_image(imfile):
img = np.array(Image.open(imfile)).astype(np.uint8)
img = torch.from_numpy(img).permute(2, 0, 1).float()
return img[None].to(DEVICE)
return img
def load_image_list(image_files):
images = []
for imfile in sorted(image_files):
images.append(load_image(imfile))
images = torch.stack(images, dim=0)
images = images.to(DEVICE)
padder = InputPadder(images.shape)
return padder.pad(images)[0]
def viz(img, flo):
img = img[0].permute(1,2,0).cpu().numpy()
flo = flo[0].permute(1,2,0).cpu().numpy()
@ -31,10 +43,6 @@ def viz(img, flo):
flo = flow_viz.flow_to_image(flo)
img_flo = np.concatenate([img, flo], axis=0)
# import matplotlib.pyplot as plt
# plt.imshow(img_flo / 255.0)
# plt.show()
cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
cv2.waitKey()
@ -50,14 +58,11 @@ def demo(args):
with torch.no_grad():
images = glob.glob(os.path.join(args.path, '*.png')) + \
glob.glob(os.path.join(args.path, '*.jpg'))
images = sorted(images)
for imfile1, imfile2 in zip(images[:-1], images[1:]):
image1 = load_image(imfile1)
image2 = load_image(imfile2)
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)
images = load_image_list(images)
for i in range(images.shape[0]-1):
image1 = images[i,None]
image2 = images[i+1,None]
flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
viz(image1, flow_up)

View File

@ -1,3 +1,3 @@
#!/bin/bash
wget https://dl.dropboxusercontent.com/s/4j4z58wuv8o0mfz/models.zip
wget https://www.dropbox.com/s/npt24nvhoojdr0n/models.zip
unzip models.zip

View File

@ -44,7 +44,7 @@ SUM_FREQ = 100
VAL_FREQ = 5000
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):
def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW):
""" Loss function defined over sequence of flow predictions """
n_predictions = len(flow_preds)
@ -55,7 +55,7 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):
valid = (valid >= 0.5) & (mag < max_flow)
for i in range(n_predictions):
i_weight = gamma**(n_predictions - i - 1)
i_weight = 0.8**(n_predictions - i - 1)
i_loss = (flow_preds[i] - flow_gt).abs()
flow_loss += i_weight * (valid[:, None] * i_loss).mean()
@ -71,11 +71,16 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):
return flow_loss, metrics
def show_image(img):
img = img.permute(1,2,0).cpu().numpy()
plt.imshow(img/255.0)
plt.show()
# cv2.imshow('image', img/255.0)
# cv2.waitKey()
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def fetch_optimizer(args, model):
""" Create the optimizer and learning rate scheduler """
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)
@ -164,6 +169,9 @@ def train(args):
optimizer.zero_grad()
image1, image2, flow, valid = [x.cuda() for x in data_blob]
# show_image(image1[0])
# show_image(image2[0])
if args.add_noise:
stdv = np.random.uniform(0.0, 5.0)
image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
@ -171,7 +179,7 @@ def train(args):
flow_predictions = model(image1, image2, iters=args.iters)
loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma)
loss, metrics = sequence_loss(flow_predictions, flow, valid)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
@ -180,6 +188,7 @@ def train(args):
scheduler.step()
scaler.update()
logger.push(metrics)
if total_steps % VAL_FREQ == VAL_FREQ - 1:
@ -234,7 +243,6 @@ if __name__ == '__main__':
parser.add_argument('--epsilon', type=float, default=1e-8)
parser.add_argument('--clip', type=float, default=1.0)
parser.add_argument('--dropout', type=float, default=0.0)
parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting')
parser.add_argument('--add_noise', action='store_true')
args = parser.parse_args()

View File

@ -2,5 +2,5 @@
mkdir -p checkpoints
python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision
python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 --mixed_precision
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 --mixed_precision
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --mixed_precision
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --mixed_precision

View File

@ -1,6 +1,6 @@
#!/bin/bash
mkdir -p checkpoints
python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 10 --lr 0.0004 --image_size 368 496 --wdecay 0.0001
python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 12 --lr 0.0004 --image_size 368 496 --wdecay 0.0001
python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 --gamma=0.85
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001