initial commit

This commit is contained in:
Zach Teed 2020-03-26 23:19:08 -04:00
commit 36d7ad338e
24 changed files with 2101 additions and 0 deletions

7
.gitignore vendored Normal file
View File

@ -0,0 +1,7 @@
*.pyc
*.egg-info
dist
datasets
pytorch_env
models
build

94
README.md Normal file
View File

@ -0,0 +1,94 @@
# RAFT
This repository contains the source code for our paper:
[RAFT: Recurrent All Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)<br/>
Zachary Teed and Jia Deng<br/>
## Requirements
Our code was tested using PyTorch 1.3.1 and Python 3. The following additional packages need to be installed:
```Shell
pip install Pillow
pip install scipy
pip install opencv-python
```
## Demos
Pretrained models can be downloaded by running
```Shell
./scripts/download_models.sh
```
You can run the demos using one of the available models.
```Shell
python demo.py --model=models/chairs+things.pth
```
or using the small (1M parameter) model
```Shell
python demo.py --model=models/small.pth --small
```
Running the demos will display the two images and a visualization of the optical flow estimate. After the images display, press any key to continue.
## Training
To train RAFT, you will need to download the required datasets. The first stage of training requires the [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs) and [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) datasets. Finetuning and evaluation require the [Sintel](http://sintel.is.tue.mpg.de/) and [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) datasets. We organize the directory structure as follows. By default, `datasets.py` will search for the datasets in these locations:
```Shell
├── datasets
│ ├── Sintel
| | ├── test
| | ├── training
│ ├── KITTI
| | ├── testing
| | ├── training
| | ├── devkit
│ ├── FlyingChairs_release
| | ├── data
│ ├── FlyingThings3D
| | ├── frames_cleanpass
| | ├── frames_finalpass
| | ├── optical_flow
```
We used the following training schedule in our paper (note: we use 2 GPUs for training)
```Shell
python train.py --name=chairs --image_size 368 496 --dataset=chairs --num_steps=100000 --lr=0.0002 --batch_size=6
```
Next, finetune on the FlyingThings dataset
```Shell
python train.py --name=things --image_size 368 768 --dataset=things --num_steps=60000 --lr=0.00005 --batch_size=3 --restore_ckpt=checkpoints/chairs.pth
```
You can perform dataset specific finetuning
### Sintel
```Shell
python train.py --name=sintel_ft --image_size 368 768 --dataset=sintel --num_steps=60000 --lr=0.00005 --batch_size=4 --restore_ckpt=checkpoints/things.pth
```
### KITTI
```Shell
python train.py --name=kitti_ft --image_size 288 896 --dataset=kitti --num_steps=40000 --lr=0.0001 --batch_size=4 --restore_ckpt=checkpoints/things.pth
```
## Evaluation
You can evaluate a model on Sintel and KITTI by running
```Shell
python evaluate.py --model=checkpoints/chairs+things.pth
```
or the small model by including the `small` flag
```Shell
python evaluate.py --model=checkpoints/small.pth --small
```

0
core/__init__.py Normal file
View File

312
core/datasets.py Normal file
View File

@ -0,0 +1,312 @@
# Data loading based on https://github.com/NVIDIA/flownet2-pytorch
import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F
import os
import cv2
import math
import random
from glob import glob
import os.path as osp
from utils import frame_utils
from utils.augmentor import FlowAugmentor, FlowAugmentorKITTI
class CombinedDataset(data.Dataset):
    """Present several map-style datasets as one dataset with a flat index.

    Items are ordered dataset by dataset: the first ``len(datasets[0])``
    indices address the first member, the following ones the second member,
    and so on.
    """

    def __init__(self, datasets):
        # Stored by reference; ``+`` later appends to this same list.
        self.datasets = datasets

    def __len__(self):
        """Return the total number of items over all member datasets."""
        return sum(len(dataset) for dataset in self.datasets)

    def __getitem__(self, index):
        """Return the item stored at flat position ``index``.

        Negative indices count from the end. Raises ``IndexError`` when the
        index falls outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if index < 0:
            index += total
        if not 0 <= index < total:
            raise IndexError('CombinedDataset index out of range')

        offset = 0
        for dataset in self.datasets:
            size = len(dataset)
            # Strict comparison: position ``offset + size`` is the first item
            # of the *next* dataset, not one past the end of this one.
            if index < offset + size:
                return dataset[index - offset]
            offset += size
        raise IndexError('CombinedDataset index out of range')

    def __add__(self, other):
        """Append ``other`` to the member list in place and return ``self``."""
        self.datasets.append(other)
        return self


class FlowDataset(data.Dataset):
    """Base dataset yielding ``(img1, img2, flow, valid)`` training samples.

    Subclasses fill ``image_list`` (pairs of image paths) and ``flow_list``
    (one flow path per pair); this class handles loading, optional
    augmentation and conversion to channel-first float tensors.
    """

    def __init__(self, args, image_size=None, do_augument=False):
        self.image_size = image_size
        self.do_augument = do_augument
        # ``self.augumentor`` only exists when augmentation is enabled.
        if self.do_augument:
            self.augumentor = FlowAugmentor(self.image_size)

        self.flow_list = []
        self.image_list = []
        self.init_seed = False

    def __getitem__(self, index):
        # Seed the random generators once per DataLoader worker. The flag is
        # only set inside a worker so that an access from the main process
        # does not suppress seeding in workers created afterwards.
        if not self.init_seed:
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                torch.manual_seed(worker_info.id)
                np.random.seed(worker_info.id)
                random.seed(worker_info.id)
                self.init_seed = True

        # Indices wrap around the number of available pairs.
        index = index % len(self.image_list)
        flow = frame_utils.read_gen(self.flow_list[index])
        img1 = frame_utils.read_gen(self.image_list[index][0])
        img2 = frame_utils.read_gen(self.image_list[index][1])

        # Keep at most the first three channels of each image.
        img1 = np.array(img1).astype(np.uint8)[..., :3]
        img2 = np.array(img2).astype(np.uint8)[..., :3]
        flow = np.array(flow).astype(np.float32)

        if self.do_augument:
            img1, img2, flow = self.augumentor(img1, img2, flow)

        # Move the channel axis first and convert to float tensors.
        img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
        img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
        flow = torch.from_numpy(flow).permute(2, 0, 1).float()
        # Every pixel is marked valid for these datasets.
        valid = torch.ones_like(flow[0])

        return img1, img2, flow, valid

    def __len__(self):
        return len(self.image_list)

    def __add__(self, other):
        """Combine this dataset with ``other`` into a ``CombinedDataset``."""
        return CombinedDataset([self, other])
class MpiSintelTest(FlowDataset):
    """Sintel test split: consecutive frame pairs without ground-truth flow."""

    def __init__(self, args, root='datasets/Sintel/test', dstype='clean'):
        super(MpiSintelTest, self).__init__(args, image_size=None, do_augument=False)

        self.root = root
        self.dstype = dstype

        image_dir = osp.join(self.root, dstype)
        self.image_list = []
        # One sub-directory per sequence; pair each frame with its successor.
        for sequence in os.listdir(image_dir):
            frames = sorted(glob(osp.join(image_dir, sequence, '*.png')))
            for position, (first, second) in enumerate(zip(frames, frames[1:])):
                self.image_list.append([first, second, sequence, position])

    def __getitem__(self, index):
        """Return ``(img1, img2, sequence, frame)`` for the pair at ``index``."""
        path1, path2, sequence, frame = self.image_list[index]

        tensors = []
        for path in (path1, path2):
            # Keep at most three channels, then move the channel axis first.
            pixels = np.array(frame_utils.read_gen(path)).astype(np.uint8)[..., :3]
            tensors.append(torch.from_numpy(pixels).permute(2, 0, 1).float())

        return tensors[0], tensors[1], sequence, frame
class MpiSintel(FlowDataset):
    """Sintel training split with ground-truth ``.flo`` files."""

    def __init__(self, args, image_size=None, do_augument=True, root='datasets/Sintel/training', dstype='clean'):
        super(MpiSintel, self).__init__(args, image_size, do_augument)
        if do_augument:
            self.augumentor.min_scale = -0.2
            self.augumentor.max_scale = 0.7

        self.root = root
        self.dstype = dstype

        flow_root = osp.join(root, 'flow')
        image_root = osp.join(root, dstype)
        # Number of leading characters to strip to get a path relative to flow_root.
        strip = len(flow_root) + 1

        for flo in sorted(glob(osp.join(flow_root, '*/*.flo'))):
            relative = flo[strip:]
            # The last 8 characters are a 4-digit frame number plus '.flo'.
            stem = relative[:-8]
            number = int(relative[-8:-4])

            first = osp.join(image_root, stem + "%04d" % (number + 0) + '.png')
            second = osp.join(image_root, stem + "%04d" % (number + 1) + '.png')

            # Keep only samples for which both frames and the flow file exist.
            if osp.isfile(first) and osp.isfile(second) and osp.isfile(flo):
                self.image_list.append((first, second))
                self.flow_list.append(flo)
class FlyingChairs(FlowDataset):
    """FlyingChairs dataset: ``.ppm`` image pairs with one ``.flo`` file each."""

    def __init__(self, args, image_size=None, do_augument=True, root='datasets/FlyingChairs_release/data'):
        super(FlyingChairs, self).__init__(args, image_size, do_augument)
        self.root = root
        # The base class only creates ``self.augumentor`` when augmentation is
        # enabled, so the scale range may only be adjusted in that case.
        if do_augument:
            self.augumentor.min_scale = -0.2
            self.augumentor.max_scale = 1.0

        images = sorted(glob(osp.join(root, '*.ppm')))
        self.flow_list = sorted(glob(osp.join(root, '*.flo')))
        # Two images are expected per flow file.
        assert (len(images)//2 == len(self.flow_list))

        # After sorting, images 2*i and 2*i + 1 form the pair for flow i.
        self.image_list = [
            [images[2*i], images[2*i + 1]] for i in range(len(self.flow_list))
        ]
class SceneFlow(FlowDataset):
    """SceneFlow datasets; only the FlyingThings3D part is implemented."""

    def __init__(self, args, image_size, do_augument=True, root='datasets',
                 dstype='frames_cleanpass', use_flyingthings=True, use_monkaa=False, use_driving=False):
        super(SceneFlow, self).__init__(args, image_size, do_augument)
        self.root = root
        self.dstype = dstype
        # ``self.augumentor`` only exists when augmentation is enabled.
        if do_augument:
            self.augumentor.min_scale = -0.2
            self.augumentor.max_scale = 0.8

        if use_flyingthings:
            self.add_flyingthings()
        if use_monkaa:
            self.add_monkaa()
        if use_driving:
            self.add_driving()

    def add_flyingthings(self):
        """Append FlyingThings3D TRAIN pairs for both temporal directions."""
        root = osp.join(self.root, 'FlyingThings3D')
        # The directory listings do not depend on camera or direction, so
        # they are globbed once.
        scene_dirs = sorted(glob(osp.join(root, self.dstype, 'TRAIN/*/*')))
        flow_scene_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))

        for cam in ['left']:
            image_dirs = sorted([osp.join(f, cam) for f in scene_dirs])
            for direction in ['into_future', 'into_past']:
                flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_scene_dirs])

                for idir, fdir in zip(image_dirs, flow_dirs):
                    images = sorted(glob(osp.join(idir, '*.png')))
                    flows = sorted(glob(osp.join(fdir, '*.pfm')))
                    for i in range(len(flows)-1):
                        if direction == 'into_future':
                            # Flow file i is paired with frames i -> i + 1.
                            self.image_list += [[images[i], images[i+1]]]
                            self.flow_list += [flows[i]]
                        elif direction == 'into_past':
                            # Flow file i + 1 is paired with frames i + 1 -> i.
                            self.image_list += [[images[i+1], images[i]]]
                            self.flow_list += [flows[i+1]]

    def add_monkaa(self):
        pass  # we don't use monkaa

    def add_driving(self):
        pass  # we don't use driving
class KITTI(FlowDataset):
    """KITTI flow dataset with sparse ground truth and a validity mask.

    In test mode items are ``(img1, img2, frame_id)``; otherwise they are
    ``(img1, img2, flow, valid)``.
    """

    def __init__(self, args, image_size=None, do_augument=True, is_test=False, is_val=False, do_pad=False, split=True, root='datasets/KITTI'):
        super(KITTI, self).__init__(args, image_size, do_augument)
        self.root = root
        self.is_test = is_test
        self.is_val = is_val
        self.do_pad = do_pad

        if self.do_augument:
            # Use the imported ``FlowAugmentorKITTI`` and pass the scale range
            # by keyword only: its constructor takes
            # (crop_size, min_scale, max_scale), so an extra positional
            # argument would collide with ``min_scale``.
            self.augumentor = FlowAugmentorKITTI(self.image_size, min_scale=-0.2, max_scale=0.5)

        if self.is_test:
            images1 = sorted(glob(os.path.join(root, 'testing', 'image_2/*_10.png')))
            images2 = sorted(glob(os.path.join(root, 'testing', 'image_2/*_11.png')))
            for i in range(len(images1)):
                self.image_list += [[images1[i], images2[i]]]
        else:
            flows = sorted(glob(os.path.join(root, 'training', 'flow_occ/*_10.png')))
            images1 = sorted(glob(os.path.join(root, 'training', 'image_2/*_10.png')))
            images2 = sorted(glob(os.path.join(root, 'training', 'image_2/*_11.png')))
            for i in range(len(flows)):
                self.flow_list += [flows[i]]
                self.image_list += [[images1[i], images2[i]]]

    @staticmethod
    def _load_image(path):
        """Read an image as a uint8 array with at most three channels."""
        return np.array(frame_utils.read_gen(path)).astype(np.uint8)[..., :3]

    def __getitem__(self, index):
        if self.is_test:
            frame_id = self.image_list[index][0]
            frame_id = frame_id.split('/')[-1]

            img1 = self._load_image(self.image_list[index][0])
            img2 = self._load_image(self.image_list[index][1])

            img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
            img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
            return img1, img2, frame_id

        # Seed the random generators once per DataLoader worker (torch
        # included, consistent with FlowDataset).
        if not self.init_seed:
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                torch.manual_seed(worker_info.id)
                np.random.seed(worker_info.id)
                random.seed(worker_info.id)
                self.init_seed = True

        index = index % len(self.image_list)

        img1 = self._load_image(self.image_list[index][0])
        img2 = self._load_image(self.image_list[index][1])
        flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])

        if self.do_augument:
            img1, img2, flow, valid = self.augumentor(img1, img2, flow, valid)

        img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
        img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
        flow = torch.from_numpy(flow).permute(2, 0, 1).float()
        valid = torch.from_numpy(valid).float()

        if self.do_pad:
            ht, wd = img1.shape[1:]
            # Amount needed to reach the next multiple of 8 (0 if already one).
            pad_ht = (-ht) % 8
            pad_wd = (-wd) % 8
            # Width padding is split between both sides; height padding goes
            # entirely to the bottom. Order is [left, right, top, bottom].
            pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]

            # Padding operates on batched tensors, so add and remove a
            # leading axis around each call.
            img1 = torch.nn.functional.pad(img1.view(1, 3, ht, wd), pad, mode='replicate')
            img2 = torch.nn.functional.pad(img2.view(1, 3, ht, wd), pad, mode='replicate')
            flow = torch.nn.functional.pad(flow.view(1, 2, ht, wd), pad, mode='constant', value=0)
            # Padded pixels carry no ground truth (their flow is zero-filled),
            # so the mask is zero-padded as well rather than replicated.
            valid = torch.nn.functional.pad(valid.view(1, 1, ht, wd), pad, mode='constant', value=0)

            img1 = img1.view(3, ht+pad_ht, wd+pad_wd)
            img2 = img2.view(3, ht+pad_ht, wd+pad_wd)
            flow = flow.view(2, ht+pad_ht, wd+pad_wd)
            valid = valid.view(ht+pad_ht, wd+pad_wd)

        # NOTE(review): the original re-checked ``self.is_test`` here to also
        # return the frame id, which is unreachable on this path. Possibly
        # ``is_val`` was intended -- confirm against the callers.
        return img1, img2, flow, valid

0
core/modules/__init__.py Normal file
View File

53
core/modules/corr.py Normal file
View File

@ -0,0 +1,53 @@
import torch
import torch.nn.functional as F
from utils.utils import bilinear_sampler, coords_grid
class CorrBlock:
    """All-pairs correlation volume with a multi-level lookup.

    The constructor correlates every position of ``fmap1`` with every
    position of ``fmap2`` and builds ``num_levels`` resolutions of the result
    by repeated 2x average pooling. Calling the object samples a
    ``(2*radius + 1) x (2*radius + 1)`` window around the given coordinates
    at every level.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        # Fold every source position into the batch axis so pooling and
        # sampling act on the target-image axes only.
        corr = corr.view(batch*h1*w1, dim, h2, w2)

        self.corr_pyramid.append(corr)
        # Level 0 is the full-resolution volume, so only ``num_levels - 1``
        # pooling steps are needed; lookups never read beyond that.
        for _ in range(self.num_levels - 1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        """Sample the pyramid around ``coords``.

        ``coords`` has the two coordinate components on axis 1; the result
        has the sampled values of all levels concatenated on axis 1.
        """
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        # The window of integer offsets is identical for every level, so it
        # is built once, directly on the device of the coordinates.
        offsets = torch.linspace(-r, r, 2*r+1, device=coords.device)
        delta = torch.stack(torch.meshgrid(offsets, offsets), dim=-1)
        delta = delta.view(1, 2*r+1, 2*r+1, 2)

        out_pyramid = []
        for i in range(self.num_levels):
            # Coordinates shrink by 2 per level, matching the pooling.
            centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
            sampled = bilinear_sampler(self.corr_pyramid[i], centroid_lvl + delta)
            out_pyramid.append(sampled.view(batch, h1, w1, -1))

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2)

    @staticmethod
    def corr(fmap1, fmap2):
        """Return dot products between all feature pairs, scaled by 1/sqrt(dim).

        The result has shape ``(batch, ht, wd, 1, ht, wd)``.
        """
        batch, dim, ht, wd = fmap1.shape
        # ``reshape`` also accepts non-contiguous feature maps.
        fmap1 = fmap1.reshape(batch, dim, ht*wd)
        fmap2 = fmap2.reshape(batch, dim, ht*wd)

        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())

269
core/modules/extractor.py Normal file
View File

@ -0,0 +1,269 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with normalization and a residual connection.

    When ``stride != 1`` the skip path is a strided 1x1 convolution followed
    by its own normalization layer, so both branches have matching shapes.

    Raises:
        ValueError: if ``norm_fn`` is not one of 'group', 'batch',
            'instance' or 'none'.
    """

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        self.norm1 = self._make_norm(norm_fn, planes, num_groups)
        self.norm2 = self._make_norm(norm_fn, planes, num_groups)

        if stride == 1:
            self.downsample = None
        else:
            # ``norm3`` stays registered as an attribute as well as inside
            # ``downsample`` so that state-dict keys are unchanged.
            self.norm3 = self._make_norm(norm_fn, planes, num_groups)
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)

    @staticmethod
    def _make_norm(norm_fn, channels, num_groups):
        """Build the normalization layer selected by ``norm_fn``."""
        if norm_fn == 'group':
            return nn.GroupNorm(num_groups=num_groups, num_channels=channels)
        if norm_fn == 'batch':
            return nn.BatchNorm2d(channels)
        if norm_fn == 'instance':
            return nn.InstanceNorm2d(channels)
        if norm_fn == 'none':
            return nn.Sequential()
        raise ValueError('unknown norm_fn: %r' % (norm_fn,))

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)
class BottleneckBlock(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck with normalization and a residual link.

    The two inner convolutions work on ``planes // 4`` channels. When
    ``stride != 1`` the skip path is a strided 1x1 convolution followed by
    its own normalization layer.

    Raises:
        ValueError: if ``norm_fn`` is not one of 'group', 'batch',
            'instance' or 'none'.
    """

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(BottleneckBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
        self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        # The same group count is used for the narrow and the wide layers.
        num_groups = planes // 8

        self.norm1 = self._make_norm(norm_fn, planes//4, num_groups)
        self.norm2 = self._make_norm(norm_fn, planes//4, num_groups)
        self.norm3 = self._make_norm(norm_fn, planes, num_groups)

        if stride == 1:
            self.downsample = None
        else:
            # ``norm4`` stays registered as an attribute as well as inside
            # ``downsample`` so that state-dict keys are unchanged.
            self.norm4 = self._make_norm(norm_fn, planes, num_groups)
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)

    @staticmethod
    def _make_norm(norm_fn, channels, num_groups):
        """Build the normalization layer selected by ``norm_fn``."""
        if norm_fn == 'group':
            return nn.GroupNorm(num_groups=num_groups, num_channels=channels)
        if norm_fn == 'batch':
            return nn.BatchNorm2d(channels)
        if norm_fn == 'instance':
            return nn.InstanceNorm2d(channels)
        if norm_fn == 'none':
            return nn.Sequential()
        raise ValueError('unknown norm_fn: %r' % (norm_fn,))

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)
class BasicEncoder(nn.Module):
    """Convolutional encoder built from ``ResidualBlock`` stages.

    A stride-2 7x7 convolution is followed by three residual stages (two of
    them with stride 2) and a 1x1 output convolution, so the output has 1/8
    of the input resolution and ``output_dim`` channels.

    Raises:
        ValueError: if ``norm_fn`` is not one of 'group', 'batch',
            'instance' or 'none'.
    """

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)
        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64)
        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()
        else:
            raise ValueError('unknown norm_fn: %r' % (norm_fn,))

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        # ``in_planes`` tracks the channel count between stages.
        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                # Norm layers without affine parameters have weight/bias None.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Build one stage of two residual blocks and update ``in_planes``."""
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode a tensor, or a list/tuple of equally sized tensors.

        A list/tuple is concatenated along the batch axis, encoded in one
        pass and split back into one output per input.
        """
        # if input is list, combine batch dimension
        is_list = isinstance(x, (tuple, list))
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        if self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            # Chunks of ``batch_dim`` rows: one per input, for any number of
            # inputs (not only two).
            x = torch.split(x, batch_dim, dim=0)

        return x
class SmallEncoder(nn.Module):
    """Lightweight encoder built from ``BottleneckBlock`` stages.

    A stride-2 7x7 convolution is followed by three bottleneck stages (two of
    them with stride 2) and a 1x1 output convolution, so the output has 1/8
    of the input resolution and ``output_dim`` channels.

    Raises:
        ValueError: if ``norm_fn`` is not one of 'group', 'batch',
            'instance' or 'none'.
    """

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(32)
        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(32)
        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()
        else:
            raise ValueError('unknown norm_fn: %r' % (norm_fn,))

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        # ``in_planes`` tracks the channel count between stages.
        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None

        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                # Norm layers without affine parameters have weight/bias None.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Build one stage of two bottleneck blocks and update ``in_planes``."""
        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode a tensor, or a list/tuple of equally sized tensors.

        A list/tuple is concatenated along the batch axis, encoded in one
        pass and split back into one output per input.
        """
        # if input is list, combine batch dimension
        is_list = isinstance(x, (tuple, list))
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x)

        # Apply the dropout layer requested through the constructor; it is
        # ``None`` (no-op) for the default ``dropout=0.0``.
        if self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            # Chunks of ``batch_dim`` rows: one per input, for any number of
            # inputs (not only two).
            x = torch.split(x, batch_dim, dim=0)

        return x

169
core/modules/update.py Normal file
View File

@ -0,0 +1,169 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
# VariationalHidDropout from https://github.com/locuslab/trellisnet/tree/master/TrellisNet
class VariationalHidDropout(nn.Module):
    def __init__(self, dropout=0.0):
        """
        Dropout whose mask is sampled explicitly via ``reset_mask`` and then
        reused unchanged by every ``forward`` call until the next reset.

        :param dropout: The dropout rate (0 means no dropout is applied)
        """
        super().__init__()
        self.dropout = dropout
        self.mask = None

    def reset_mask(self, x):
        """Sample, store and return a new per-sample, per-channel mask for ``x``."""
        keep_prob = 1 - self.dropout

        # x must be 4-dimensional; the mask holds one value per
        # (sample, channel) and broadcasts over the last two axes.
        batch, channels, _, _ = x.shape
        drawn = x.data.new(batch, channels, 1, 1).bernoulli_(keep_prob)
        with torch.no_grad():
            # Rescale kept entries by 1 / keep_prob.
            self.mask = drawn / keep_prob
        return self.mask

    def forward(self, x):
        """Multiply ``x`` by the stored mask while training with dropout > 0."""
        if self.training and self.dropout != 0:
            assert self.mask is not None, "You need to reset mask before using VariationalHidDropout"
            return self.mask * x
        return x
class FlowHead(nn.Module):
    """Two 3x3 convolutions mapping features to a 2-channel output."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super().__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.conv1(x)
        hidden = self.relu(hidden)
        return self.conv2(hidden)
class ConvGRU(nn.Module):
    """Convolutional GRU cell using 3x3 convolutions for all gates."""

    def __init__(self, hidden_dim=128, input_dim=192+128):
        super().__init__()
        gate_in = hidden_dim + input_dim
        self.convz = nn.Conv2d(gate_in, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(gate_in, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(gate_in, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        """Return the updated hidden state for hidden state ``h`` and input ``x``."""
        stacked = torch.cat([h, x], dim=1)

        update = torch.sigmoid(self.convz(stacked))
        reset = torch.sigmoid(self.convr(stacked))
        # The candidate state sees the hidden state scaled by the reset gate.
        candidate = torch.tanh(self.convq(torch.cat([reset*h, x], dim=1)))

        return (1-update) * h + update * candidate
class SepConvGRU(nn.Module):
    """GRU cell applied twice: with 1x5 convolutions, then with 5x1 ones."""

    def __init__(self, hidden_dim=128, input_dim=192+128):
        super().__init__()
        gate_in = hidden_dim + input_dim

        self.convz1 = nn.Conv2d(gate_in, hidden_dim, (1,5), padding=(0,2))
        self.convr1 = nn.Conv2d(gate_in, hidden_dim, (1,5), padding=(0,2))
        self.convq1 = nn.Conv2d(gate_in, hidden_dim, (1,5), padding=(0,2))

        self.convz2 = nn.Conv2d(gate_in, hidden_dim, (5,1), padding=(2,0))
        self.convr2 = nn.Conv2d(gate_in, hidden_dim, (5,1), padding=(2,0))
        self.convq2 = nn.Conv2d(gate_in, hidden_dim, (5,1), padding=(2,0))

    @staticmethod
    def _gru_step(h, x, convz, convr, convq):
        """Run one GRU update of ``h`` with the given gate convolutions."""
        stacked = torch.cat([h, x], dim=1)
        update = torch.sigmoid(convz(stacked))
        reset = torch.sigmoid(convr(stacked))
        candidate = torch.tanh(convq(torch.cat([reset*h, x], dim=1)))
        return (1-update) * h + update * candidate

    def forward(self, h, x):
        # horizontal
        h = self._gru_step(h, x, self.convz1, self.convr1, self.convq1)
        # vertical
        h = self._gru_step(h, x, self.convz2, self.convr2, self.convq2)
        return h
class SmallMotionEncoder(nn.Module):
    """Fuse correlation features and the current flow into motion features.

    The output has 80 learned channels followed by the 2 raw flow channels.
    """

    def __init__(self, args):
        super().__init__()
        window = 2*args.corr_radius + 1
        cor_planes = args.corr_levels * window**2
        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv = nn.Conv2d(128, 80, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc1(corr))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # The raw flow is appended unchanged after the learned features.
        return torch.cat([fused, flow], dim=1)
class BasicMotionEncoder(nn.Module):
    """Fuse correlation features and the current flow into motion features.

    The output has 126 learned channels followed by the 2 raw flow channels.
    """

    def __init__(self, args):
        super().__init__()
        window = 2*args.corr_radius + 1
        cor_planes = args.corr_levels * window**2
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # The raw flow is appended unchanged after the learned features.
        return torch.cat([fused, flow], dim=1)
class SmallUpdateBlock(nn.Module):
    """Recurrent update operator of the small model.

    Each call encodes the current flow and correlation features, advances
    the hidden state with a ``ConvGRU`` and predicts a flow increment.
    """

    def __init__(self, args, hidden_dim=96):
        super(SmallUpdateBlock, self).__init__()
        self.encoder = SmallMotionEncoder(args)
        # GRU input: 82 motion-feature channels plus 64 further channels
        # supplied through ``inp``.
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def reset_mask(self, net, inp):
        """Do nothing; this block has no dropout masks.

        Provided so that this block offers the same ``reset_mask`` method as
        ``BasicUpdateBlock`` and both can be driven by the same calling code.
        """

    def forward(self, net, inp, corr, flow):
        """Return ``(net, delta_flow)``: new hidden state and flow increment."""
        motion_features = self.encoder(flow, corr)
        inp = torch.cat([inp, motion_features], dim=1)
        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        return net, delta_flow
class BasicUpdateBlock(nn.Module):
    """Recurrent update operator of the full model.

    Each call encodes the current flow and correlation features, advances
    the hidden state with a ``SepConvGRU`` and predicts a flow increment.
    Dropout masks for the hidden state and the context input are sampled by
    ``reset_mask`` and reused by every subsequent ``forward`` call.
    """

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        self.drop_inp = VariationalHidDropout(dropout=args.dropout)
        self.drop_net = VariationalHidDropout(dropout=args.dropout)

    def reset_mask(self, net, inp):
        """Sample new dropout masks shaped after ``net`` and ``inp``."""
        self.drop_inp.reset_mask(inp)
        self.drop_net.reset_mask(net)

    def forward(self, net, inp, corr, flow):
        """Return ``(net, delta_flow)``: new hidden state and flow increment."""
        motion_features = self.encoder(flow, corr)

        if self.training:
            net = self.drop_net(net)
            # The input mask was sampled for ``inp`` as passed to
            # ``reset_mask``, so it has to be applied before the motion
            # features are concatenated; afterwards the channel counts would
            # no longer match.
            inp = self.drop_inp(inp)

        inp = torch.cat([inp, motion_features], dim=1)

        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        return net, delta_flow

99
core/raft.py Normal file
View File

@ -0,0 +1,99 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from modules.update import BasicUpdateBlock, SmallUpdateBlock
from modules.extractor import BasicEncoder, SmallEncoder
from modules.corr import CorrBlock
from utils.utils import bilinear_sampler, coords_grid, upflow8
class RAFT(nn.Module):
    """RAFT optical-flow network.

    Combines a feature encoder, a context encoder, an all-pairs correlation
    lookup and a recurrent update block that iteratively refines the flow.
    ``args`` is read for ``small`` and ``dropout`` and is updated in place
    with ``corr_levels`` and ``corr_radius``.
    """

    def __init__(self, args):
        super(RAFT, self).__init__()
        self.args = args

        if args.small:
            self.hidden_dim = hdim = 96
            self.context_dim = cdim = 64
            args.corr_levels = 4
            args.corr_radius = 3
        else:
            self.hidden_dim = hdim = 128
            self.context_dim = cdim = 128
            args.corr_levels = 4
            args.corr_radius = 4

        # Default to no dropout only when the caller did not configure one.
        # (A membership test on ``args._get_kwargs()`` compares a string with
        # (name, value) tuples and therefore never matches.)
        # NOTE(review): a configured dropout > 0 now reaches the encoders and
        # the update block -- confirm they handle it during training.
        if not hasattr(args, 'dropout'):
            args.dropout = 0

        # feature network, context network, and update block
        if args.small:
            self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
            self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
        else:
            self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
            self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

    def freeze_bn(self):
        """Switch every ``BatchNorm2d`` submodule to evaluation mode."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, C, H, W = img.shape
        coords0 = coords_grid(N, H//8, W//8).to(img.device)
        coords1 = coords_grid(N, H//8, W//8).to(img.device)

        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1

    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True):
        """ Estimate optical flow between pair of frames

        Returns a list with one flow prediction per iteration, upsampled
        with ``upflow8`` when ``upsample`` is true. ``flow_init`` is accepted
        but not used by this implementation.
        """
        # Map pixel values from [0, 255] to [-1, 1].
        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        hdim = self.hidden_dim
        cdim = self.context_dim

        # run the feature network
        fmap1, fmap2 = self.fnet([image1, image2])
        corr_fn = CorrBlock(fmap1, fmap2, num_levels=self.args.corr_levels,
                            radius=self.args.corr_radius)

        # run the context network
        cnet = self.cnet(image1)
        net, inp = torch.split(cnet, [hdim, cdim], dim=1)
        net, inp = torch.tanh(net), torch.relu(inp)

        # if dropout is being used reset mask; update blocks that do not
        # provide ``reset_mask`` are simply skipped
        reset_mask = getattr(self.update_block, 'reset_mask', None)
        if reset_mask is not None:
            reset_mask(net, inp)

        coords0, coords1 = self.initialize_flow(image1)

        flow_predictions = []
        for itr in range(iters):
            # Cut the gradient through the coordinates between iterations.
            coords1 = coords1.detach()
            corr = corr_fn(coords1)  # index correlation volume

            flow = coords1 - coords0
            net, delta_flow = self.update_block(net, inp, corr, flow)

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow

            if upsample:
                flow_up = upflow8(coords1 - coords0)
                flow_predictions.append(flow_up)
            else:
                flow_predictions.append(coords1 - coords0)

        return flow_predictions

0
core/utils/__init__.py Normal file
View File

233
core/utils/augmentor.py Normal file
View File

@ -0,0 +1,233 @@
import numpy as np
import random
import math
import cv2
from PIL import Image
import torch
import torchvision
import torch.nn.functional as F
class FlowAugmentor:
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5):
self.crop_size = crop_size
self.augcolor = torchvision.transforms.ColorJitter(
brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.5/3.14)
self.asymmetric_color_aug_prob = 0.2
self.spatial_aug_prob = 0.8
self.eraser_aug_prob = 0.5
self.min_scale = min_scale
self.max_scale = max_scale
self.max_stretch = 0.2
self.stretch_prob = 0.8
self.margin = 20
def color_transform(self, img1, img2):
if np.random.rand() < self.asymmetric_color_aug_prob:
img1 = np.array(self.augcolor(Image.fromarray(img1)), dtype=np.uint8)
img2 = np.array(self.augcolor(Image.fromarray(img2)), dtype=np.uint8)
else:
image_stack = np.concatenate([img1, img2], axis=0)
image_stack = np.array(self.augcolor(Image.fromarray(image_stack)), dtype=np.uint8)
img1, img2 = np.split(image_stack, 2, axis=0)
return img1, img2
def eraser_transform(self, img1, img2, bounds=[50, 100]):
ht, wd = img1.shape[:2]
if np.random.rand() < self.eraser_aug_prob:
mean_color = np.mean(img2.reshape(-1, 3), axis=0)
for _ in range(np.random.randint(1, 3)):
x0 = np.random.randint(0, wd)
y0 = np.random.randint(0, ht)
dx = np.random.randint(bounds[0], bounds[1])
dy = np.random.randint(bounds[0], bounds[1])
img2[y0:y0+dy, x0:x0+dx, :] = mean_color
return img1, img2
def spatial_transform(self, img1, img2, flow):
# randomly sample scale
ht, wd = img1.shape[:2]
min_scale = np.maximum(
(self.crop_size[0] + 1) / float(ht),
(self.crop_size[1] + 1) / float(wd))
max_scale = self.max_scale
min_scale = max(min_scale, self.min_scale)
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
scale_x = scale
scale_y = scale
if np.random.rand() < self.stretch_prob:
scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
scale_x = np.clip(scale_x, min_scale, None)
scale_y = np.clip(scale_y, min_scale, None)
if np.random.rand() < self.spatial_aug_prob:
# rescale the images
img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
flow = flow * [scale_x, scale_y]
if np.random.rand() < 0.5: # h-flip
img1 = img1[:, ::-1]
img2 = img2[:, ::-1]
flow = flow[:, ::-1] * [-1.0, 1.0]
if np.random.rand() < 0.1: # v-flip
img1 = img1[::-1, :]
img2 = img2[::-1, :]
flow = flow[::-1, :] * [1.0, -1.0]
y0 = np.random.randint(-self.margin, img1.shape[0] - self.crop_size[0] + self.margin)
x0 = np.random.randint(-self.margin, img1.shape[1] - self.crop_size[1] + self.margin)
y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
return img1, img2, flow
def __call__(self, img1, img2, flow):
    """Run the full augmentation chain on one training sample.

    Order: color transform, eraser transform (both on the images only), then
    the spatial transform on images and flow.  The three results are returned
    as contiguous arrays, since the flips and crops may leave strided views.
    """
    img1, img2 = self.eraser_transform(*self.color_transform(img1, img2))
    img1, img2, flow = self.spatial_transform(img1, img2, flow)
    return tuple(np.ascontiguousarray(arr) for arr in (img1, img2, flow))
class FlowAugmentorKITTI:
    """Augmentation pipeline for image pairs whose flow carries a validity mask.

    Calling an instance applies, in order, a color jitter, a random "eraser"
    occlusion on the second image and a random scale / flip / crop to
    ``(img1, img2, flow, valid)``.
    """

    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5):
        """
        :param crop_size: (height, width) of the crop returned by the pipeline
        :param min_scale: lower bound of the log2 exponent of the resize factor
        :param max_scale: upper bound of the log2 exponent of the resize factor
        """
        self.crop_size = crop_size
        self.augcolor = torchvision.transforms.ColorJitter(
            brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)

        self.max_scale = max_scale
        self.min_scale = min_scale

        # probability of rescaling / of painting eraser rectangles
        self.spatial_aug_prob = 0.8
        self.eraser_aug_prob = 0.5

    def color_transform(self, img1, img2):
        """Color-jitter both images through a single ``augcolor`` call.

        The images are stacked along the row axis, jittered as one picture and
        split back into two halves.
        """
        image_stack = np.concatenate([img1, img2], axis=0)
        image_stack = np.array(self.augcolor(Image.fromarray(image_stack)), dtype=np.uint8)
        img1, img2 = np.split(image_stack, 2, axis=0)
        return img1, img2

    def eraser_transform(self, img1, img2):
        """With probability ``eraser_aug_prob`` paint 1 or 2 rectangles into img2.

        Each rectangle has sides drawn from [50, 100) pixels, starts at a random
        pixel and is filled with the mean color of ``img2``.  ``img2`` is
        modified in place; ``img1`` is returned unchanged.
        """
        ht, wd = img1.shape[:2]
        if np.random.rand() < self.eraser_aug_prob:
            mean_color = np.mean(img2.reshape(-1, 3), axis=0)
            for _ in range(np.random.randint(1, 3)):
                x0 = np.random.randint(0, wd)
                y0 = np.random.randint(0, ht)
                dx = np.random.randint(50, 100)
                dy = np.random.randint(50, 100)
                # slices running past the border are truncated by numpy
                img2[y0:y0+dy, x0:x0+dx, :] = mean_color

        return img1, img2

    def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
        """Resize a sparse flow field by scattering its valid samples.

        Every pixel with ``valid >= 1`` is moved to the rounded, scaled
        coordinate and its flow vector is multiplied by ``(fx, fy)``; all other
        output pixels stay zero / invalid.

        :param flow: array of shape [H, W, 2]
        :param valid: array of shape [H, W]; entries >= 1 mark usable flow
        :param fx: horizontal scale factor
        :param fy: vertical scale factor
        :return: ``(flow_img, valid_img)`` of shape [H', W', 2] float32 and
            [H', W'] int32 with ``H' = round(H * fy)``, ``W' = round(W * fx)``
        """
        ht, wd = flow.shape[:2]
        coords = np.meshgrid(np.arange(wd), np.arange(ht))
        coords = np.stack(coords, axis=-1)

        coords = coords.reshape(-1, 2).astype(np.float32)
        flow = flow.reshape(-1, 2).astype(np.float32)
        valid = valid.reshape(-1).astype(np.float32)

        coords0 = coords[valid >= 1]
        flow0 = flow[valid >= 1]

        ht1 = int(round(ht * fy))
        wd1 = int(round(wd * fx))

        coords1 = coords0 * [fx, fy]
        flow1 = flow0 * [fx, fy]

        xx = np.round(coords1[:, 0]).astype(np.int32)
        yy = np.round(coords1[:, 1]).astype(np.int32)

        # Row 0 and column 0 are legal targets; the bounds are inclusive at 0
        # so valid samples on the top / left border are not discarded.
        v = (xx >= 0) & (xx < wd1) & (yy >= 0) & (yy < ht1)
        xx = xx[v]
        yy = yy[v]
        flow1 = flow1[v]

        flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
        valid_img = np.zeros([ht1, wd1], dtype=np.int32)

        flow_img[yy, xx] = flow1
        valid_img[yy, xx] = 1

        return flow_img, valid_img

    def spatial_transform(self, img1, img2, flow, valid):
        """Randomly rescale, horizontally flip and crop one sample.

        The resize (probability ``spatial_aug_prob``) is isotropic; the images
        use bilinear interpolation while flow / valid go through
        ``resize_sparse_flow_map``.  The crop origin is drawn from a range that
        extends past the image by fixed margins and is then clamped.
        """
        # randomly sample scale
        ht, wd = img1.shape[:2]
        # smallest factor that keeps the resized image larger than the crop
        min_scale = np.maximum(
            (self.crop_size[0] + 1) / float(ht),
            (self.crop_size[1] + 1) / float(wd))

        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
        scale_x = np.clip(scale, min_scale, None)
        scale_y = np.clip(scale, min_scale, None)

        if np.random.rand() < self.spatial_aug_prob:
            # rescale the images
            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)

        if np.random.rand() < 0.5:  # h-flip
            img1 = img1[:, ::-1]
            img2 = img2[:, ::-1]
            flow = flow[:, ::-1] * [-1.0, 1.0]
            valid = valid[:, ::-1]

        margin_y = 20
        margin_x = 50

        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
        x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)

        y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
        x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])

        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        return img1, img2, flow, valid

    def __call__(self, img1, img2, flow, valid):
        """Apply color, eraser and spatial augmentation; return contiguous arrays."""
        img1, img2 = self.color_transform(img1, img2)
        img1, img2 = self.eraser_transform(img1, img2)
        img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)

        img1 = np.ascontiguousarray(img1)
        img2 = np.ascontiguousarray(img2)
        flow = np.ascontiguousarray(flow)
        valid = np.ascontiguousarray(valid)

        return img1, img2, flow, valid

275
core/utils/flow_viz.py Normal file
View File

@ -0,0 +1,275 @@
# MIT License
#
# Copyright (c) 2018 Tom Runia
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to conditions.
#
# Author: Tom Runia
# Date Created: 2018-08-03
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
def make_colorwheel():
    '''
    Generates a color wheel for optical flow visualization as presented in:
        Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
        URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
    According to the C++ source code of Daniel Scharstein
    According to the Matlab source code of Deqing Sun

    :return: float array of shape [55, 3]; each row is one color with
        channel values in 0..255
    '''
    # (length, channel held at 255, channel that ramps, ramp goes up?)
    segments = (
        (15, 0, 1, True),    # RY
        (6, 1, 0, False),    # YG
        (4, 1, 2, True),     # GC
        (11, 2, 1, False),   # CB
        (13, 2, 0, True),    # BM
        (6, 0, 2, False),    # MR
    )

    ncols = sum(length for length, _, _, _ in segments)
    colorwheel = np.zeros((ncols, 3))

    start = 0
    for length, full_ch, ramp_ch, rising in segments:
        ramp = np.floor(255 * np.arange(0, length) / length)
        rows = slice(start, start + length)
        colorwheel[rows, full_ch] = 255
        colorwheel[rows, ramp_ch] = ramp if rising else 255 - ramp
        start += length
    return colorwheel
def flow_compute_color(u, v, convert_to_bgr=False):
    '''
    Applies the flow color wheel to (possibly clipped) flow components u and v.
    According to the C++ source code of Daniel Scharstein
    According to the Matlab source code of Deqing Sun
    :param u: np.ndarray, input horizontal flow
    :param v: np.ndarray, input vertical flow
    :param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB
    :return: np.ndarray of dtype uint8 and shape [H, W, 3]
    '''
    flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)

    colorwheel = make_colorwheel()  # shape [55x3]
    ncols = colorwheel.shape[0]

    rad = np.sqrt(np.square(u) + np.square(v))
    a = np.arctan2(-v, -u) / np.pi   # angle mapped to [-1, 1]

    # Fractional position on the wheel, 0-based: fk lies in [0, ncols-1].
    # (The Matlab original adds 1 because it indexes from 1; with numpy's
    # 0-based indexing that offset shifts every color by one bin and runs
    # past the end of the wheel for a == 1.)
    fk = (a + 1) / 2 * (ncols - 1)
    k0 = np.floor(fk).astype(np.int32)
    k1 = k0 + 1
    k1[k1 == ncols] = 0              # the wheel is circular
    f = fk - k0                      # interpolation weight between k0 and k1

    for i in range(colorwheel.shape[1]):
        tmp = colorwheel[:, i]
        col0 = tmp[k0] / 255.0
        col1 = tmp[k1] / 255.0
        col = (1 - f) * col0 + f * col1

        idx = (rad <= 1)
        # inside the unit circle: blend towards white as the magnitude drops
        col[idx] = 1 - rad[idx] * (1 - col[idx])
        col[~idx] = col[~idx] * 0.75   # out of range: darken

        # Note the 2-i => BGR instead of RGB
        ch_idx = 2 - i if convert_to_bgr else i
        flow_image[:, :, ch_idx] = np.floor(255 * col)

    return flow_image
def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False):
    '''
    Expects a two dimensional flow image of shape [H,W,2]
    According to the C++ source code of Daniel Scharstein
    According to the Matlab source code of Deqing Sun
    :param flow_uv: np.ndarray of shape [H,W,2]
    :param clip_flow: float, maximum absolute value kept for each flow component
    :param convert_to_bgr: bool, whether to output BGR instead of RGB
    :return: np.ndarray of dtype uint8 and shape [H,W,3]
    '''
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'

    if clip_flow is not None:
        # Flow components are signed; clip symmetrically so that leftward /
        # upward motion is limited rather than zeroed.
        flow_uv = np.clip(flow_uv, -clip_flow, clip_flow)

    u = flow_uv[:, :, 0]
    v = flow_uv[:, :, 1]

    # normalize by the largest magnitude so every vector lies in the unit disc
    rad = np.sqrt(np.square(u) + np.square(v))
    rad_max = np.max(rad)

    epsilon = 1e-5   # avoids division by zero for an all-zero flow field
    u = u / (rad_max + epsilon)
    v = v / (rad_max + epsilon)

    return flow_compute_color(u, v, convert_to_bgr)
# flow_to_image treats any pixel whose |u| or |v| exceeds this value as
# unknown: its flow is zeroed and the pixel is painted black.
UNKNOWN_FLOW_THRESH = 1e7
# NOTE(review): SMALLFLOW and LARGEFLOW are not referenced in the visible part
# of this file -- presumably leftovers from the code this was adapted from;
# confirm before removing.
SMALLFLOW = 0.0
LARGEFLOW = 1e8
def make_color_wheel():
    """
    Generate color wheel according Middlebury color code
    :return: Color wheel, float array of shape [55, 3] with values in 0..255
    """
    RY, YG, GC, CB, BM, MR = 15, 6, 4, 11, 13, 6

    def ramp(n):
        # n evenly spaced integer levels from 0 up to (but excluding) 255
        return np.floor(255 * np.arange(0, n) / n)

    def segment(n, red, green, blue):
        # one arc of the wheel; each channel is either a constant or a ramp
        rows = np.zeros([n, 3])
        rows[:, 0] = red
        rows[:, 1] = green
        rows[:, 2] = blue
        return rows

    arcs = [
        segment(RY, 255, ramp(RY), 0),         # RY
        segment(YG, 255 - ramp(YG), 255, 0),   # YG
        segment(GC, 0, 255, ramp(GC)),         # GC
        segment(CB, 0, 255 - ramp(CB), 255),   # CB
        segment(BM, ramp(BM), 0, 255),         # BM
        segment(MR, 255, 0, 255 - ramp(MR)),   # MR
    ]
    return np.concatenate(arcs, axis=0)
def compute_color(u, v):
    """
    compute optical flow color map
    :param u: optical flow horizontal map, 2-D array (left unmodified)
    :param v: optical flow vertical map, 2-D array (left unmodified)
    :return: optical flow in color code, float array of shape [h, w, 3]
        holding integer values in 0..255
    """
    h, w = u.shape
    img = np.zeros([h, w, 3])

    # NaN pixels are computed as zero flow and blacked out at the end.
    # np.where builds new arrays, so the caller's u / v are not overwritten.
    nan_idx = np.isnan(u) | np.isnan(v)
    u = np.where(nan_idx, 0, u)
    v = np.where(nan_idx, 0, v)

    colorwheel = make_color_wheel()
    ncols = np.size(colorwheel, 0)

    rad = np.sqrt(u**2 + v**2)

    a = np.arctan2(-v, -u) / np.pi

    # 1-based position on the wheel (Matlab convention); the -1 below converts
    # to numpy indices
    fk = (a + 1) / 2 * (ncols - 1) + 1

    k0 = np.floor(fk).astype(int)

    k1 = k0 + 1
    k1[k1 == ncols + 1] = 1   # wrap around the circular wheel
    f = fk - k0

    for i in range(0, np.size(colorwheel, 1)):
        tmp = colorwheel[:, i]
        col0 = tmp[k0 - 1] / 255
        col1 = tmp[k1 - 1] / 255
        col = (1 - f) * col0 + f * col1

        idx = rad <= 1
        # small magnitudes fade towards white, large ones are darkened
        col[idx] = 1 - rad[idx] * (1 - col[idx])
        notidx = np.logical_not(idx)

        col[notidx] *= 0.75
        img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nan_idx)))

    return img
# from https://github.com/gengshan-y/VCN
def flow_to_image(flow):
    """
    Convert flow into middlebury color code image
    :param flow: optical flow map of shape [H, W, 2]; it is not modified
    :return: optical flow image in middlebury color, uint8 array [H, W, 3]
    """
    # pixels with an implausibly large component are treated as unknown
    idx_unknown = (abs(flow[:, :, 0]) > UNKNOWN_FLOW_THRESH) | \
                  (abs(flow[:, :, 1]) > UNKNOWN_FLOW_THRESH)

    # np.where returns fresh arrays; writing through flow[:, :, k] views
    # would zero those entries in the caller's array as well
    u = np.where(idx_unknown, 0, flow[:, :, 0])
    v = np.where(idx_unknown, 0, flow[:, :, 1])

    # scale so that the longest vector has (almost) unit length
    rad = np.sqrt(u ** 2 + v ** 2)
    maxrad = max(-1, np.max(rad))

    u = u / (maxrad + np.finfo(float).eps)
    v = v / (maxrad + np.finfo(float).eps)

    img = compute_color(u, v)

    # unknown pixels are painted black
    idx = np.repeat(idx_unknown[:, :, np.newaxis], 3, axis=2)
    img[idx] = 0

    return np.uint8(img)

124
core/utils/frame_utils.py Normal file
View File

@ -0,0 +1,124 @@
import numpy as np
from PIL import Image
from os.path import *
import re
import cv2
# Magic number that writeFlow emits as the first float32 of a .flo file;
# readFlow checks for the same value.
TAG_CHAR = np.array([202021.25], np.float32)
def readFlow(fn):
    """ Read .flo file in Middlebury format

    :param fn: path of the .flo file
    :return: float32 array of shape [h, w, 2], or None (after printing a
        message) when the file does not start with the .flo magic number
    """
    # Code adapted from:
    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy

    # WARNING: this will work on little-endian architectures (eg Intel x86) only!
    with open(fn, 'rb') as f:
        magic = np.fromfile(f, np.float32, count=1)
        # an empty file yields a zero-length array; treat it like a bad magic
        # number instead of testing the truth value of an empty comparison
        if magic.size != 1 or magic[0] != 202021.25:
            print('Magic number incorrect. Invalid .flo file')
            return None

        # index the 1-element arrays explicitly: calling int() on an array
        # with ndim > 0 is deprecated in recent NumPy
        w = int(np.fromfile(f, np.int32, count=1)[0])
        h = int(np.fromfile(f, np.int32, count=1)[0])
        data = np.fromfile(f, np.float32, count=2 * w * h)
        # Reshape data into 3D array (columns, rows, bands)
        # The reshape here is for visualization, the original code is (w,h,2)
        # NOTE(review): np.resize repeats the data instead of raising when the
        # payload is shorter than w*h*2 -- kept for compatibility.
        return np.resize(data, (h, w, 2))
def readPFM(file):
    """Read a PFM image.

    :param file: path of the .pfm file
    :return: float array of shape [height, width, 3] for a color file ('PF'
        header) or [height, width] for a single-channel file ('Pf' header),
        flipped vertically relative to the order stored in the file
    :raises Exception: if the type line or the dimension line is malformed
    """
    # the context manager closes the handle on every path, including the
    # raises below
    with open(file, 'rb') as fh:
        header = fh.readline().rstrip()
        if header == b'PF':
            color = True
        elif header == b'Pf':
            color = False
        else:
            raise Exception('Not a PFM file.')

        dim_match = re.match(rb'^(\d+)\s(\d+)\s$', fh.readline())
        if dim_match:
            width, height = map(int, dim_match.groups())
        else:
            raise Exception('Malformed PFM header.')

        # the sign of the scale line encodes the byte order of the payload
        scale = float(fh.readline().rstrip())
        if scale < 0:  # little-endian
            endian = '<'
        else:
            endian = '>'  # big-endian

        data = np.fromfile(fh, endian + 'f')

    shape = (height, width, 3) if color else (height, width)

    data = np.reshape(data, shape)
    data = np.flipud(data)
    return data
def writeFlow(filename, uv, v=None):
    """ Write optical flow to file.

    If v is None, uv is assumed to contain both u and v channels,
    stacked in depth.
    Original code by Deqing Sun, adapted from Daniel Scharstein.

    :param filename: path of the .flo file to create
    :param uv: [H, W, 2] flow array, or the [H, W] u component when v is given
    :param v: optional [H, W] v component
    """
    if v is None:
        assert(uv.ndim == 3)
        assert(uv.shape[2] == 2)
        u = uv[:, :, 0]
        v = uv[:, :, 1]
    else:
        u = uv

    assert(u.shape == v.shape)
    height, width = u.shape

    # arrange into matrix form: u and v interleaved per pixel, row by row
    interleaved = np.stack([u, v], axis=-1).astype(np.float32)

    # the context manager closes the file even if one of the writes fails
    with open(filename, 'wb') as f:
        # write the header
        f.write(TAG_CHAR)
        np.array(width).astype(np.int32).tofile(f)
        np.array(height).astype(np.int32).tofile(f)
        interleaved.tofile(f)
def readFlowKITTI(filename):
    """Decode a flow image written in the format of ``writeFlowKITTI``.

    The image is loaded with unchanged bit depth, its channel order is
    reversed, and the first two channels are decoded as
    ``(value - 2**15) / 64``.

    :return: ``(flow, valid)`` -- flow is float32 [H, W, 2]; valid is the
        third channel as float32 [H, W]
    """
    raw = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
    channels = raw[:, :, ::-1].astype(np.float32)
    valid = channels[:, :, 2]
    flow = (channels[:, :, :2] - 2**15) / 64.0
    return flow, valid
def writeFlowKITTI(filename, uv):
    """Encode a dense flow field as a 16-bit, 3-channel image.

    Each component is stored as ``64 * value + 2**15``; a third channel of
    ones is appended (every pixel flagged valid) and the channel order is
    reversed before the image is handed to ``cv2.imwrite``.

    :param filename: output image path
    :param uv: flow array of shape [H, W, 2]
    """
    encoded = 64.0 * uv + 2**15
    all_valid = np.ones([encoded.shape[0], encoded.shape[1], 1])
    packed = np.concatenate([encoded, all_valid], axis=-1).astype(np.uint16)
    cv2.imwrite(filename, packed[..., ::-1])
def read_gen(file_name, pil=False):
    """Load an image or flow file, dispatching on the file extension.

    :param file_name: path of the file to read
    :param pil: accepted but not used by this function
    :return: a PIL image for .png/.jpeg/.ppm/.jpg, the ``np.load`` result for
        .bin/.raw, a float32 array for .flo, a float32 array without its last
        channel for .pfm, and an empty list for any other extension
    """
    ext = splitext(file_name)[-1]

    if ext in ('.png', '.jpeg', '.ppm', '.jpg'):
        return Image.open(file_name)

    if ext in ('.bin', '.raw'):
        return np.load(file_name)

    if ext == '.flo':
        return readFlow(file_name).astype(np.float32)

    if ext == '.pfm':
        pfm = readPFM(file_name).astype(np.float32)
        # drop the last channel
        return pfm[:, :, :-1]

    return []

62
core/utils/utils.py Normal file
View File

@ -0,0 +1,62 @@
import torch
import torch.nn.functional as F
import numpy as np
from scipy import interpolate
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates

    :param img: tensor whose last two dimensions are (H, W)
    :param coords: tensor whose last dimension holds (x, y) in pixels
    :param mode: interpolation mode forwarded to ``F.grid_sample``
    :param mask: if True, also return a float mask that is 1 where the
        normalized sampling position lies strictly inside (-1, 1) on both axes
    :return: the sampled tensor, or ``(sampled, mask)`` when ``mask`` is True
    """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    # map pixel coordinates to [-1, 1]; with align_corners=True the values
    # -1 and 1 address the centers of the border pixels
    xgrid = 2*xgrid/(W-1) - 1
    ygrid = 2*ygrid/(H-1) - 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    # forward `mode`; it used to be accepted but silently ignored
    img = F.grid_sample(img, grid, mode=mode, align_corners=True)

    if mask:
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float()

    return img
def forward_interpolate(flow):
    """Forward-warp a flow field onto the regular pixel grid.

    Every pixel is moved along its own flow vector; targets that fall
    strictly inside (0, wd) x (0, ht) keep their vector, and the value at each
    grid position is taken from the nearest such target.

    :param flow: tensor indexed as [2, ht, wd] (x component first)
    :return: float tensor of the same shape, created on the CPU
    """
    flow_np = flow.detach().cpu().numpy()
    dx, dy = flow_np[0], flow_np[1]

    ht, wd = dx.shape
    grid_x, grid_y = np.meshgrid(np.arange(wd), np.arange(ht))

    # where each source pixel lands
    target_x = (grid_x + dx).reshape(-1)
    target_y = (grid_y + dy).reshape(-1)

    inside = (target_x > 0) & (target_x < wd) & (target_y > 0) & (target_y < ht)
    points = (target_x[inside], target_y[inside])

    warped = [
        interpolate.griddata(
            points, component.reshape(-1)[inside], (grid_x, grid_y), method='nearest')
        for component in (dx, dy)
    ]

    return torch.from_numpy(np.stack(warped, axis=0)).float()
def coords_grid(batch, ht, wd):
    """Return a [batch, 2, ht, wd] float tensor of pixel coordinates.

    Channel 0 holds the x (column) index and channel 1 the y (row) index of
    each position; the grid is repeated for every batch element.
    """
    rows, cols = torch.meshgrid(torch.arange(ht), torch.arange(wd))
    grid = torch.stack([cols, rows], dim=0).float()
    return grid.unsqueeze(0).repeat(batch, 1, 1, 1)
def upflow8(flow, mode='bilinear'):
    """Upsample a 4-D flow tensor 8x in its last two dimensions.

    The values are multiplied by 8 as well, so displacements are expressed in
    pixels of the enlarged grid.
    """
    height, width = flow.shape[2], flow.shape[3]
    enlarged = F.interpolate(flow, size=(8 * height, 8 * width), mode=mode, align_corners=True)
    return 8 * enlarged

90
demo.py Normal file
View File

@ -0,0 +1,90 @@
import sys
sys.path.append('core')
import argparse
import os
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import datasets
from utils import flow_viz
from raft import RAFT
# torch device that load_image() moves images to and demo() moves the model to
DEVICE = 'cuda'
def pad8(img):
    """pad image such that dimensions are divisible by 8

    :param img: 4-D tensor; the last two dimensions are padded
    :return: the tensor padded by edge replication, with the extra rows /
        columns split between both sides (the odd one goes to the far side)
    """
    ht, wd = img.shape[2:]
    extra_h = -ht % 8   # rows missing to reach the next multiple of 8
    extra_w = -wd % 8

    top, left = extra_h // 2, extra_w // 2
    # F.pad takes the last dimension first: [left, right, top, bottom]
    padding = [left, extra_w - left, top, extra_h - top]
    return F.pad(img, padding, mode='replicate')
def load_image(imfile):
    """Load an image file as a padded float tensor on ``DEVICE``.

    Only the first three channels are kept; the result has a leading batch
    dimension of 1, channels before height / width, and is padded by
    ``pad8``.
    """
    pixels = np.array(Image.open(imfile)).astype(np.uint8)[..., :3]
    chw = torch.from_numpy(pixels).permute(2, 0, 1).float()
    batched = chw[None]
    return pad8(batched).to(DEVICE)
def display(image1, image2, flow):
    """Show both images and the color-coded flow, then wait for a key press.

    :param image1: channels-first image tensor with values in 0..255
    :param image2: channels-first image tensor with values in 0..255
    :param flow: channels-first flow tensor
    """
    def to_hwc(tensor):
        # channels-last numpy array on the host
        return tensor.permute(1, 2, 0).cpu().numpy()

    frame1 = to_hwc(image1) / 255.0
    frame2 = to_hwc(image2) / 255.0

    flow_image = flow_viz.flow_to_image(to_hwc(flow))
    # bring the flow picture to the size of the first image
    flow_image = cv2.resize(flow_image, (frame1.shape[1], frame1.shape[0]))

    windows = (('image1', frame1), ('image2', frame2), ('flow', flow_image))
    for title, picture in windows:
        # reverse the channel order for cv2.imshow
        cv2.imshow(title, picture[..., ::-1])

    cv2.waitKey()
def demo(args):
    """Load the checkpoint ``args.model`` and visualize flow on three image pairs.

    The Sintel pair is run with ``args.iters`` iterations and
    ``upsample=False``; the KITTI and DAVIS pairs use 16 iterations.
    """
    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load(args.model))

    model.to(DEVICE)
    model.eval()

    # (first frame, second frame, keyword arguments for the forward pass)
    runs = [
        # sintel images
        ('images/sintel_0.png', 'images/sintel_1.png',
         {'iters': args.iters, 'upsample': False}),
        # kitti images
        ('images/kitti_0.png', 'images/kitti_1.png', {'iters': 16}),
        # davis images
        ('images/davis_0.jpg', 'images/davis_1.jpg', {'iters': 16}),
    ]

    with torch.no_grad():
        for first, second, forward_kwargs in runs:
            image1 = load_image(first)
            image2 = load_image(second)

            flow_predictions = model(image1, image2, **forward_kwargs)
            # show the last prediction of the sequence for the single batch item
            display(image1[0], image2[0], flow_predictions[-1][0])
if __name__ == '__main__':
    # command-line entry point of the demo
    cli = argparse.ArgumentParser()
    cli.add_argument('--model', help="restore checkpoint")
    cli.add_argument('--small', action='store_true', help='use small model')
    cli.add_argument('--iters', type=int, default=12)

    demo(cli.parse_args())

100
evaluate.py Normal file
View File

@ -0,0 +1,100 @@
import sys
sys.path.append('core')
from PIL import Image
import cv2
import argparse
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import datasets
from utils import flow_viz
from raft import RAFT
def validate_sintel(args, model, iters=50):
    """ Evaluate trained model on Sintel(train) clean + final passes

    Prints the mean end-point error per pass.  ``model`` must expose the
    wrapped network as ``model.module``; it is switched to eval mode.
    """
    model.eval()
    pad = 2
    # F.pad order is [left, right, top, bottom]: only the height is padded
    height_padding = [0, 0, pad, pad]

    for dstype in ['clean', 'final']:
        val_dataset = datasets.MpiSintel(args, do_augument=False, dstype=dstype)

        epe_list = []
        for index in range(len(val_dataset)):
            image1, image2, flow_gt, _ = val_dataset[index]
            image1 = F.pad(image1[None].cuda(), height_padding, mode='replicate')
            image2 = F.pad(image2[None].cuda(), height_padding, mode='replicate')

            with torch.no_grad():
                flow_predictions = model.module(image1, image2, iters=iters)
                # last prediction, with the padded rows cut off again
                flow_pr = flow_predictions[-1][0, :, pad:-pad]

            squared_error = torch.sum((flow_pr - flow_gt.cuda())**2, dim=0)
            epe_list.append(torch.sqrt(squared_error).mean().item())

        print("Validation (%s) EPE: %f" % (dstype, np.mean(epe_list)))
def validate_kitti(args, model, iters=32):
    """ Evaluate trained model on KITTI (train)

    Prints the mean of the per-image end-point errors and the percentage of
    valid pixels whose error exceeds both 3 px and 5% of the ground-truth
    magnitude.
    """
    model.eval()
    val_dataset = datasets.KITTI(args, do_augument=False, is_val=True, do_pad=True)

    with torch.no_grad():
        epe_list, out_list = [], []
        for index in range(len(val_dataset)):
            image1, image2, flow_gt, valid_gt = val_dataset[index]
            image1, image2 = image1[None].cuda(), image2[None].cuda()
            flow_gt, valid_gt = flow_gt.cuda(), valid_gt.cuda()

            flow_pr = model.module(image1, image2, iters=iters)[-1][0]

            epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt().view(-1)
            mag = torch.sum(flow_gt**2, dim=0).sqrt().view(-1)
            val = valid_gt.view(-1) >= 0.5

            # outlier: large in absolute AND in relative terms
            out = ((epe > 3.0) & ((epe/mag) > 0.05)).float()
            epe_list.append(epe[val].mean().item())
            out_list.append(out[val].cpu().numpy())

    epe_list = np.array(epe_list)
    out_list = np.concatenate(out_list)

    print("Validation KITTI: %f, %f" % (np.mean(epe_list), 100*np.mean(out_list)))
if __name__ == '__main__':
    # command-line entry point: evaluate one checkpoint on Sintel and KITTI
    cli = argparse.ArgumentParser()
    cli.add_argument('--model', help="restore checkpoint")
    cli.add_argument('--small', action='store_true', help='use small model')
    for flag, default in (('--sintel_iters', 50), ('--kitti_iters', 32)):
        cli.add_argument(flag, type=int, default=default)
    args = cli.parse_args()

    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load(args.model))

    model.to('cuda')
    model.eval()

    validate_sintel(args, model, args.sintel_iters)
    validate_kitti(args, model, args.kitti_iters)

BIN
images/davis_0.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 497 KiB

BIN
images/davis_1.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 514 KiB

BIN
images/kitti_0.png Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 829 KiB

BIN
images/kitti_1.png Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 822 KiB

BIN
images/sintel_0.png Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 396 KiB

BIN
images/sintel_1.png Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 388 KiB

3
scripts/download_models.sh Executable file
View File

@ -0,0 +1,3 @@
#!/bin/bash
# Fetch the pretrained-model archive into the current directory.
wget https://www.dropbox.com/s/a2acvmczgzm6f9n/models.zip
# Unpack it here (presumably creates the models/ directory the demo expects -- verify).
unzip models.zip

211
train.py Executable file
View File

@ -0,0 +1,211 @@
from __future__ import print_function, division
import sys
sys.path.append('core')
import argparse
import os
import cv2
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from raft import RAFT
from evaluate import validate_sintel, validate_kitti
import datasets
# exclude extremely large displacements: sequence_loss ignores pixels whose
# ground-truth flow has |u| + |v| >= MAX_FLOW
MAX_FLOW = 1000
# Logger prints the accumulated metrics once every SUM_FREQ training steps
SUM_FREQ = 1000
# train() writes an intermediate checkpoint once every VAL_FREQ training steps
VAL_FREQ = 5000
def count_parameters(model):
    """Return the number of scalar entries in the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def sequence_loss(flow_preds, flow_gt, valid):
    """ Loss function defined over sequence of flow predictions

    Each prediction contributes the mean of its masked L1 error, weighted by
    0.8 ** (number of predictions that follow it), so the last prediction has
    weight 1.

    :return: ``(flow_loss, metrics)``; metrics holds the mean end-point error
        of the last prediction and the fraction of kept pixels with an error
        below 1, 3 and 5
    """
    n_predictions = len(flow_preds)

    # exclude invalid pixels and extremely large displacements
    valid = (valid >= 0.5) & (flow_gt.abs().sum(dim=1) < MAX_FLOW)
    pixel_mask = valid[:, None]

    flow_loss = 0.0
    for step, prediction in enumerate(flow_preds):
        step_weight = 0.8**(n_predictions - step - 1)
        abs_error = (prediction - flow_gt).abs()
        flow_loss += step_weight * (pixel_mask * abs_error).mean()

    # end-point error of the final prediction, on the kept pixels only
    epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt()
    epe = epe.view(-1)[valid.view(-1)]

    def fraction_below(limit):
        return (epe < limit).float().mean().item()

    metrics = {
        'epe': epe.mean().item(),
        '1px': fraction_below(1),
        '3px': fraction_below(3),
        '5px': fraction_below(5),
    }

    return flow_loss, metrics
def fetch_dataloader(args):
    """ Create the data loader for the corresponding training set

    :param args: namespace providing ``dataset`` ('chairs', 'things', 'sintel'
        or 'kitti'), ``image_size`` and ``batch_size``
    :return: a shuffling ``DataLoader`` over the selected dataset; for
        'things' and 'sintel' the two rendering passes are combined
    :raises ValueError: if ``args.dataset`` is not one of the known names
    """
    if args.dataset == 'chairs':
        train_dataset = datasets.FlyingChairs(args, image_size=args.image_size)

    elif args.dataset == 'things':
        clean_dataset = datasets.SceneFlow(args, image_size=args.image_size, dstype='frames_cleanpass')
        final_dataset = datasets.SceneFlow(args, image_size=args.image_size, dstype='frames_finalpass')
        train_dataset = clean_dataset + final_dataset

    elif args.dataset == 'sintel':
        clean_dataset = datasets.MpiSintel(args, image_size=args.image_size, dstype='clean')
        final_dataset = datasets.MpiSintel(args, image_size=args.image_size, dstype='final')
        train_dataset = clean_dataset + final_dataset

    elif args.dataset == 'kitti':
        train_dataset = datasets.KITTI(args, image_size=args.image_size, is_val=False)

    else:
        # fail with a clear message instead of an UnboundLocalError below
        raise ValueError(
            "unknown dataset %r (expected 'chairs', 'things', 'sintel' or 'kitti')"
            % (args.dataset,))

    gpuargs = {'num_workers': 4, 'drop_last': True}
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
        pin_memory=True, shuffle=True, **gpuargs)

    print('Training with %d image pairs' % len(train_dataset))
    return train_loader
def fetch_optimizer(args, model):
    """ Create the optimizer and learning rate scheduler

    :param args: namespace providing ``lr``, ``wdecay``, ``epsilon`` and
        ``num_steps``
    :return: ``(optimizer, scheduler)`` -- AdamW over all model parameters and
        a linearly annealed one-cycle schedule that peaks at ``args.lr`` after
        20% of ``args.num_steps``
    """
    optimizer = optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.wdecay,
        eps=args.epsilon,
    )

    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=args.lr,
        total_steps=args.num_steps,
        pct_start=0.2,
        cycle_momentum=False,
        anneal_strategy='linear',
        final_div_factor=0.05,
    )

    return optimizer, scheduler
class Logger:
    """Accumulates training metrics and prints them every SUM_FREQ steps."""

    def __init__(self, model, scheduler):
        self.model = model
        self.scheduler = scheduler
        self.total_steps = 0      # number of push() calls so far
        self.running_loss = {}    # metric name -> sum since the last report

    def _print_training_status(self):
        """Print step, learning rate and the sums divided by SUM_FREQ, then zero the sums."""
        names = sorted(self.running_loss.keys())
        averages = [self.running_loss[name] / SUM_FREQ for name in names]

        training_str = "[{:6d}, {:10.7f}] ".format(
            self.total_steps + 1, self.scheduler.get_lr()[0])
        metrics_str = "".join("{:10.4f}, ".format(value) for value in averages)

        # print the training status
        print(training_str + metrics_str)

        for name in self.running_loss:
            self.running_loss[name] = 0.0

    def push(self, metrics):
        """Add one step's metrics; report when the step count hits SUM_FREQ - 1 modulo SUM_FREQ."""
        self.total_steps += 1

        for name, value in metrics.items():
            self.running_loss[name] = self.running_loss.get(name, 0.0) + value

        if self.total_steps % SUM_FREQ == SUM_FREQ - 1:
            self._print_training_status()
            self.running_loss = {}
def train(args):
    """Train RAFT for ``args.num_steps`` optimizer steps.

    Intermediate checkpoints go to ``checkpoints/<step>_<name>.pth`` and the
    final weights to ``checkpoints/<name>.pth``.

    :return: path of the final checkpoint
    """
    model = nn.DataParallel(RAFT(args))
    print("Parameter Count: %d" % count_parameters(model))

    if args.restore_ckpt is not None:
        model.load_state_dict(torch.load(args.restore_ckpt))

    model.cuda()
    model.train()

    # batch-norm layers are frozen for every dataset except 'chairs'
    if 'chairs' not in args.dataset:
        model.module.freeze_bn()

    train_loader = fetch_dataloader(args)
    optimizer, scheduler = fetch_optimizer(args, model)
    logger = Logger(model, scheduler)

    total_steps = 0
    while True:
        for data_blob in train_loader:
            image1, image2, flow, valid = [x.cuda() for x in data_blob]

            optimizer.zero_grad()
            flow_predictions = model(image1, image2, iters=args.iters)

            loss, metrics = sequence_loss(flow_predictions, flow, valid)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            scheduler.step()
            total_steps += 1

            logger.push(metrics)

            if total_steps % VAL_FREQ == VAL_FREQ - 1:
                snapshot_path = 'checkpoints/%d_%s.pth' % (total_steps+1, args.name)
                torch.save(model.state_dict(), snapshot_path)

            if total_steps == args.num_steps:
                break
        else:
            # loader exhausted before the step budget: start another pass
            continue
        break

    final_path = 'checkpoints/%s.pth' % args.name
    torch.save(model.state_dict(), final_path)

    return final_path
if __name__ == '__main__':
    # command-line entry point of the training script
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', default='bla', help="name your experiment")
    parser.add_argument('--dataset', help="which dataset to use for training")
    parser.add_argument('--restore_ckpt', help="restore checkpoint")
    parser.add_argument('--small', action='store_true', help='use small model')

    # plain typed options: (flag, type, default)
    for flag, kind, default in (
            ('--lr', float, 0.00002),
            ('--num_steps', int, 100000),
            ('--batch_size', int, 6)):
        parser.add_argument(flag, type=kind, default=default)

    parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512])

    for flag, kind, default in (
            ('--iters', int, 12),
            ('--wdecay', float, .00005),
            ('--epsilon', float, 1e-8),
            ('--clip', float, 1.0),
            ('--dropout', float, 0.0)):
        parser.add_argument(flag, type=kind, default=default)

    args = parser.parse_args()

    # fixed seeds for torch and numpy
    torch.manual_seed(1234)
    np.random.seed(1234)

    if not os.path.isdir('checkpoints'):
        os.mkdir('checkpoints')

    # scale learning rate and batch size by number of GPUs
    num_gpus = torch.cuda.device_count()
    args.batch_size *= num_gpus
    args.lr *= num_gpus

    train(args)