added upsampling module

This commit is contained in:
Zach Teed 2020-07-25 17:36:17 -06:00
parent dc1220825d
commit a2408eab78
32 changed files with 23559 additions and 619 deletions

111
README.md
View File

@@ -1,7 +1,4 @@
# RAFT
**7/22/2020: We have updated our method to predict flow at full resolution leading to improved results on public benchmarks. This repository will be updated to reflect these changes within the next few days.**
This repository contains the source code for our paper:
[RAFT: Recurrent All Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)<br/>
@@ -11,90 +8,72 @@ Zachary Teed and Jia Deng<br/>
<img src="RAFT.png">
## Requirements
Our code was tested using PyTorch 1.3.1 and Python 3. The following additional packages need to be installed
```Shell
pip install Pillow
pip install scipy
pip install opencv-python
```
The code has been tested with PyTorch 1.5.1 and PyTorch Nightly. If you want to train with mixed precision, you will have to install the nightly build.
```Shell
conda create --name raft
conda activate raft
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch-nightly
conda install matplotlib
conda install tensorboard
conda install scipy
conda install opencv
```
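Mixed precision itself is switched on at train time through the `--mixed_precision` flag that this commit adds to `train.py` (`--stage` is the new name of the old `--dataset` flag). An illustrative invocation, with a hypothetical experiment name:
```Shell
python train.py --name=raft-chairs --stage=chairs --mixed_precision
```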
## Demos
Pretrained models can be downloaded by running
```Shell
./scripts/download_models.sh
```
or downloaded from [google drive](https://drive.google.com/file/d/10-BYgHqRNPGvmNUWr8razjb1xHu55pyA/view?usp=sharing)
You can run the demos using one of the available models.
You can demo a trained model on a sequence of frames
```Shell
python demo.py --model=models/chairs+things.pth
python demo.py --model=models/raft-things.pth --path=demo-frames
```
or using the small (1M parameter) model
```Shell
python demo.py --model=models/small.pth --small
```
## Required Data
To evaluate/train RAFT, you will need to download the required datasets.
* [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs)
* [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
* [Sintel](http://sintel.is.tue.mpg.de/)
* [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow)
* [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) (optional)
Running the demos will display the two images and a visualization of the optical flow estimate. After the images are displayed, press any key to continue.
## Training
To train RAFT, you will need to download the required datasets. The first stage of training requires the [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs) and [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) datasets. Finetuning and evaluation require the [Sintel](http://sintel.is.tue.mpg.de/) and [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) datasets. We organize the directory structure as follows. By default `datasets.py` will search for the datasets in these locations
By default `datasets.py` will search for the datasets in these locations. You can create symbolic links in the `datasets` folder pointing to wherever the datasets were downloaded; see the example after the listing.
```Shell
├── datasets
├── Sintel
| | ├── test
| | ├── training
├── KITTI
| | ├── testing
| | ├── training
| | ├── devkit
├── FlyingChairs_release
| | ├── data
├── FlyingThings3D
| | ├── frames_cleanpass
| | ├── frames_finalpass
| | ├── optical_flow
├── Sintel
├── test
├── training
├── KITTI
├── testing
├── training
├── devkit
├── FlyingChairs_release
├── data
├── FlyingThings3D
├── frames_cleanpass
├── frames_finalpass
├── optical_flow
```
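For example, symbolic links can be created like so (a sketch; the source paths are placeholders for wherever you downloaded the data):
```Shell
mkdir -p datasets
ln -s /path/to/Sintel datasets/Sintel
ln -s /path/to/KITTI datasets/KITTI
ln -s /path/to/FlyingChairs_release datasets/FlyingChairs_release
ln -s /path/to/FlyingThings3D datasets/FlyingThings3D
```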
We used the following training schedule in our paper (note: we use 2 GPUs for training)
```Shell
python train.py --name=chairs --image_size 368 496 --dataset=chairs --num_steps=100000 --lr=0.0002 --batch_size=6
```
Next, finetune on the FlyingThings dataset
```Shell
python train.py --name=things --image_size 368 768 --dataset=things --num_steps=60000 --lr=0.00005 --batch_size=3 --restore_ckpt=checkpoints/chairs.pth
```
You can perform dataset specific finetuning
### Sintel
```Shell
python train.py --name=sintel_ft --image_size 368 768 --dataset=sintel --num_steps=60000 --lr=0.00005 --batch_size=4 --restore_ckpt=checkpoints/things.pth
```
### KITTI
```Shell
python train.py --name=kitti_ft --image_size 288 896 --dataset=kitti --num_steps=40000 --lr=0.0001 --batch_size=4 --restore_ckpt=checkpoints/things.pth
```
## Evaluation
You can evaluate a model on Sintel and KITTI by running
You can evaluate a trained model using `evaluate.py`
```Shell
python evaluate.py --model=models/chairs+things.pth
python evaluate.py --model=models/raft-things.pth --dataset=sintel
```
or the small model by including the `small` flag
## Training
Training code will be made available in the next few days
<!-- We used the following training schedule in our paper (note: we use 2 GPUs for training). Training logs will be written to the `runs` directory, which can be visualized using tensorboard
```Shell
python evaluate.py --model=models/small.pth --small
./train_standard.sh
```
If you have an RTX GPU, training can be accelerated using mixed precision. You can expect similar results in this setting (1 GPU)
```Shell
./train_mixed.sh
``` -->

22872
chairs_split.txt Normal file

File diff suppressed because it is too large

View File

@@ -2,6 +2,7 @@ import torch
import torch.nn.functional as F
from utils.utils import bilinear_sampler, coords_grid
class CorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
self.num_levels = num_levels
@@ -12,10 +13,10 @@ class CorrBlock:
corr = CorrBlock.corr(fmap1, fmap2)
batch, h1, w1, dim, h2, w2 = corr.shape
corr = corr.view(batch*h1*w1, dim, h2, w2)
corr = corr.reshape(batch*h1*w1, dim, h2, w2)
self.corr_pyramid.append(corr)
for i in range(self.num_levels):
for i in range(self.num_levels-1):
corr = F.avg_pool2d(corr, 2, stride=2)
self.corr_pyramid.append(corr)
@@ -40,7 +41,8 @@ class CorrBlock:
out_pyramid.append(corr)
out = torch.cat(out_pyramid, dim=-1)
return out.permute(0, 3, 1, 2)
return out.permute(0, 3, 1, 2).contiguous().float()
@staticmethod
def corr(fmap1, fmap2):
@@ -50,4 +52,5 @@ class CorrBlock:
corr = torch.matmul(fmap1.transpose(1,2), fmap2)
corr = corr.view(batch, ht, wd, 1, ht, wd)
return corr / torch.sqrt(torch.tensor(dim).float())
return corr / torch.sqrt(torch.tensor(dim).float())
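For reference, a self-contained sketch of the all-pairs correlation that `CorrBlock.corr` computes; the tensor sizes here are illustrative, not taken from the commit:
```Python
import torch

# hypothetical 1/8-resolution feature maps: batch=1, 256 channels, 46x62 grid
fmap1 = torch.randn(1, 256, 46, 62)
fmap2 = torch.randn(1, 256, 46, 62)

batch, dim, ht, wd = fmap1.shape
f1 = fmap1.view(batch, dim, ht*wd)
f2 = fmap2.view(batch, dim, ht*wd)

# dot product between every pair of feature vectors, normalized by sqrt(dim)
corr = torch.matmul(f1.transpose(1, 2), f2)
corr = corr.view(batch, ht, wd, 1, ht, wd) / torch.sqrt(torch.tensor(dim).float())
print(corr.shape)  # torch.Size([1, 46, 62, 1, 46, 62])
```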

View File

@@ -6,53 +6,42 @@ import torch.utils.data as data
import torch.nn.functional as F
import os
import cv2
import math
import random
from glob import glob
import os.path as osp
from utils import frame_utils
from utils.augmentor import FlowAugmentor, FlowAugmentorKITTI
from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
class CombinedDataset(data.Dataset):
def __init__(self, datasets):
self.datasets = datasets
def __len__(self):
length = 0
for i in range(len(self.datasets)):
length += len(self.datasets[i])
return length
def __getitem__(self, index):
i = 0
for j in range(len(self.datasets)):
if i + len(self.datasets[j]) >= index:
yield self.datasets[j][index-i]
break
i += len(self.datasets[j])
def __add__(self, other):
self.datasets.append(other)
return self
class FlowDataset(data.Dataset):
def __init__(self, args, image_size=None, do_augument=False):
self.image_size = image_size
self.do_augument = do_augument
if self.do_augument:
self.augumentor = FlowAugmentor(self.image_size)
def __init__(self, aug_params=None, sparse=False):
self.augmentor = None
self.sparse = sparse
if aug_params is not None:
if sparse:
self.augmentor = SparseFlowAugmentor(**aug_params)
else:
self.augmentor = FlowAugmentor(**aug_params)
self.is_test = False
self.init_seed = False
self.flow_list = []
self.image_list = []
self.init_seed = False
self.extra_info = []
def __getitem__(self, index):
if self.is_test:
img1 = frame_utils.read_gen(self.image_list[index][0])
img2 = frame_utils.read_gen(self.image_list[index][1])
img1 = np.array(img1).astype(np.uint8)[..., :3]
img2 = np.array(img2).astype(np.uint8)[..., :3]
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
return img1, img2, self.extra_info[index]
if not self.init_seed:
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
@@ -62,133 +51,96 @@ class FlowDataset(data.Dataset):
self.init_seed = True
index = index % len(self.image_list)
flow = frame_utils.read_gen(self.flow_list[index])
valid = None
if self.sparse:
flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
else:
flow = frame_utils.read_gen(self.flow_list[index])
img1 = frame_utils.read_gen(self.image_list[index][0])
img2 = frame_utils.read_gen(self.image_list[index][1])
img1 = np.array(img1).astype(np.uint8)[..., :3]
img2 = np.array(img2).astype(np.uint8)[..., :3]
flow = np.array(flow).astype(np.float32)
img1 = np.array(img1).astype(np.uint8)
img2 = np.array(img2).astype(np.uint8)
if self.do_augument:
img1, img2, flow = self.augumentor(img1, img2, flow)
# grayscale images
if len(img1.shape) == 2:
img1 = np.tile(img1[...,None], (1, 1, 3))
img2 = np.tile(img2[...,None], (1, 1, 3))
else:
img1 = img1[..., :3]
img2 = img2[..., :3]
if self.augmentor is not None:
if self.sparse:
img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
else:
img1, img2, flow = self.augmentor(img1, img2, flow)
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
flow = torch.from_numpy(flow).permute(2, 0, 1).float()
valid = torch.ones_like(flow[0])
return img1, img2, flow, valid
if valid is not None:
valid = torch.from_numpy(valid)
else:
valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
return img1, img2, flow, valid.float()
def __rmul__(self, v):
self.flow_list = v * self.flow_list
self.image_list = v * self.image_list
return self
def __len__(self):
return len(self.image_list)
def __add(self, other):
return CombinedDataset([self, other])
class MpiSintelTest(FlowDataset):
def __init__(self, args, root='datasets/Sintel/test', dstype='clean'):
super(MpiSintelTest, self).__init__(args, image_size=None, do_augument=False)
self.root = root
self.dstype = dstype
image_dir = osp.join(self.root, dstype)
all_sequences = os.listdir(image_dir)
self.image_list = []
for sequence in all_sequences:
frames = sorted(glob(osp.join(image_dir, sequence, '*.png')))
for i in range(len(frames)-1):
self.image_list += [[frames[i], frames[i+1], sequence, i]]
def __getitem__(self, index):
img1 = frame_utils.read_gen(self.image_list[index][0])
img2 = frame_utils.read_gen(self.image_list[index][1])
sequence = self.image_list[index][2]
frame = self.image_list[index][3]
img1 = np.array(img1).astype(np.uint8)[..., :3]
img2 = np.array(img2).astype(np.uint8)[..., :3]
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
return img1, img2, sequence, frame
class MpiSintel(FlowDataset):
def __init__(self, args, image_size=None, do_augument=True, root='datasets/Sintel/training', dstype='clean'):
super(MpiSintel, self).__init__(args, image_size, do_augument)
if do_augument:
self.augumentor.min_scale = -0.2
self.augumentor.max_scale = 0.7
def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
super(MpiSintel, self).__init__(aug_params)
flow_root = osp.join(root, split, 'flow')
image_root = osp.join(root, split, dstype)
self.root = root
self.dstype = dstype
if split == 'test':
self.is_test = True
flow_root = osp.join(root, 'flow')
image_root = osp.join(root, dstype)
for scene in os.listdir(image_root):
image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
for i in range(len(image_list)-1):
self.image_list += [ [image_list[i], image_list[i+1]] ]
self.extra_info += [ (scene, i) ] # scene and frame_id
file_list = sorted(glob(osp.join(flow_root, '*/*.flo')))
for flo in file_list:
fbase = flo[len(flow_root)+1:]
fprefix = fbase[:-8]
fnum = int(fbase[-8:-4])
img1 = osp.join(image_root, fprefix + "%04d"%(fnum+0) + '.png')
img2 = osp.join(image_root, fprefix + "%04d"%(fnum+1) + '.png')
if not osp.isfile(img1) or not osp.isfile(img2) or not osp.isfile(flo):
continue
self.image_list.append((img1, img2))
self.flow_list.append(flo)
if split != 'test':
self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
class FlyingChairs(FlowDataset):
def __init__(self, args, image_size=None, do_augument=True, root='datasets/FlyingChairs_release/data'):
super(FlyingChairs, self).__init__(args, image_size, do_augument)
self.root = root
self.augumentor.min_scale = -0.2
self.augumentor.max_scale = 1.0
def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
super(FlyingChairs, self).__init__(aug_params)
images = sorted(glob(osp.join(root, '*.ppm')))
self.flow_list = sorted(glob(osp.join(root, '*.flo')))
assert (len(images)//2 == len(self.flow_list))
flows = sorted(glob(osp.join(root, '*.flo')))
assert (len(images)//2 == len(flows))
self.image_list = []
for i in range(len(self.flow_list)):
im1 = images[2*i]
im2 = images[2*i + 1]
self.image_list.append([im1, im2])
split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
for i in range(len(flows)):
xid = split_list[i]
if (split=='training' and xid==1) or (split=='validation' and xid==2):
self.flow_list += [ flows[i] ]
self.image_list += [ [images[2*i], images[2*i+1]] ]
class SceneFlow(FlowDataset):
def __init__(self, args, image_size, do_augument=True, root='datasets',
dstype='frames_cleanpass', use_flyingthings=True, use_monkaa=False, use_driving=False):
super(SceneFlow, self).__init__(args, image_size, do_augument)
self.root = root
self.dstype = dstype
self.augumentor.min_scale = -0.2
self.augumentor.max_scale = 0.8
if use_flyingthings:
self.add_flyingthings()
if use_monkaa:
self.add_monkaa()
if use_driving:
self.add_driving()
def add_flyingthings(self):
root = osp.join(self.root, 'FlyingThings3D')
class FlyingThings3D(FlowDataset):
def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
super(FlyingThings3D, self).__init__(aug_params)
for cam in ['left']:
for direction in ['into_future', 'into_past']:
image_dirs = sorted(glob(osp.join(root, self.dstype, 'TRAIN/*/*')))
image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
@@ -199,114 +151,85 @@ class SceneFlow(FlowDataset):
flows = sorted(glob(osp.join(fdir, '*.pfm')) )
for i in range(len(flows)-1):
if direction == 'into_future':
self.image_list += [[images[i], images[i+1]]]
self.flow_list += [flows[i]]
self.image_list += [ [images[i], images[i+1]] ]
self.flow_list += [ flows[i] ]
elif direction == 'into_past':
self.image_list += [[images[i+1], images[i]]]
self.flow_list += [flows[i+1]]
def add_monkaa(self):
pass # we don't use monkaa
def add_driving(self):
pass # we don't use driving
self.image_list += [ [images[i+1], images[i]] ]
self.flow_list += [ flows[i+1] ]
class KITTI(FlowDataset):
def __init__(self, args, image_size=None, do_augument=True, is_test=False, is_val=False, do_pad=False, split=True, root='datasets/KITTI'):
super(KITTI, self).__init__(args, image_size, do_augument)
self.root = root
self.is_test = is_test
self.is_val = is_val
self.do_pad = do_pad
def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
super(KITTI, self).__init__(aug_params, sparse=True)
if split == 'testing':
self.is_test = True
if self.do_augument:
self.augumentor = FlowAugmentorKITTI(self.image_size, min_scale=-0.2, max_scale=0.5)
root = osp.join(root, split)
images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
if self.is_test:
images1 = sorted(glob(os.path.join(root, 'testing', 'image_2/*_10.png')))
images2 = sorted(glob(os.path.join(root, 'testing', 'image_2/*_11.png')))
for i in range(len(images1)):
self.image_list += [[images1[i], images2[i]]]
for img1, img2 in zip(images1, images2):
frame_id = img1.split('/')[-1]
self.extra_info += [ [frame_id] ]
self.image_list += [ [img1, img2] ]
else:
flows = sorted(glob(os.path.join(root, 'training', 'flow_occ/*_10.png')))
images1 = sorted(glob(os.path.join(root, 'training', 'image_2/*_10.png')))
images2 = sorted(glob(os.path.join(root, 'training', 'image_2/*_11.png')))
if split == 'training':
self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
for i in range(len(flows)):
class HD1K(FlowDataset):
def __init__(self, aug_params=None, root='datasets/HD1k'):
super(HD1K, self).__init__(aug_params, sparse=True)
seq_ix = 0
while 1:
flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
if len(flows) == 0:
break
for i in range(len(flows)-1):
self.flow_list += [flows[i]]
self.image_list += [[images1[i], images2[i]]]
self.image_list += [ [images[i], images[i+1]] ]
seq_ix += 1
def __getitem__(self, index):
def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
""" Create the data loader for the corresponding trainign set """
if self.is_test:
frame_id = self.image_list[index][0]
frame_id = frame_id.split('/')[-1]
if args.stage == 'chairs':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 1.0, 'do_flip': True}
train_dataset = FlyingChairs(aug_params, split='training')
img1 = frame_utils.read_gen(self.image_list[index][0])
img2 = frame_utils.read_gen(self.image_list[index][1])
elif args.stage == 'things':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
train_dataset = clean_dataset + final_dataset
img1 = np.array(img1).astype(np.uint8)[..., :3]
img2 = np.array(img2).astype(np.uint8)[..., :3]
elif args.stage == 'sintel':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True}
things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
sintel_final = MpiSintel(aug_params, split='training', dstype='final')
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
return img1, img2, frame_id
if TRAIN_DS == 'C+T+K+S+H':
kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True})
hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.5, 'do_flip': True})
train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
elif TRAIN_DS == 'C+T+K/S':
train_dataset = 100*sintel_clean + 100*sintel_final + things
else:
if not self.init_seed:
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
np.random.seed(worker_info.id)
random.seed(worker_info.id)
self.init_seed = True
elif args.stage == 'kitti':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
train_dataset = KITTI(args, image_size=args.image_size, is_val=False)
index = index % len(self.image_list)
frame_id = self.image_list[index][0]
frame_id = frame_id.split('/')[-1]
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
img1 = frame_utils.read_gen(self.image_list[index][0])
img2 = frame_utils.read_gen(self.image_list[index][1])
flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
print('Training with %d image pairs' % len(train_dataset))
return train_loader
img1 = np.array(img1).astype(np.uint8)[..., :3]
img2 = np.array(img2).astype(np.uint8)[..., :3]
if self.do_augument:
img1, img2, flow, valid = self.augumentor(img1, img2, flow, valid)
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
flow = torch.from_numpy(flow).permute(2, 0, 1).float()
valid = torch.from_numpy(valid).float()
if self.do_pad:
ht, wd = img1.shape[1:]
pad_ht = (((ht // 8) + 1) * 8 - ht) % 8
pad_wd = (((wd // 8) + 1) * 8 - wd) % 8
pad_ht1 = [0, pad_ht]
pad_wd1 = [pad_wd//2, pad_wd - pad_wd//2]
pad = pad_wd1 + pad_ht1
img1 = img1.view(1, 3, ht, wd)
img2 = img2.view(1, 3, ht, wd)
flow = flow.view(1, 2, ht, wd)
valid = valid.view(1, 1, ht, wd)
img1 = torch.nn.functional.pad(img1, pad, mode='replicate')
img2 = torch.nn.functional.pad(img2, pad, mode='replicate')
flow = torch.nn.functional.pad(flow, pad, mode='constant', value=0)
valid = torch.nn.functional.pad(valid, pad, mode='replicate', value=0)
img1 = img1.view(3, ht+pad_ht, wd+pad_wd)
img2 = img2.view(3, ht+pad_ht, wd+pad_wd)
flow = flow.view(2, ht+pad_ht, wd+pad_wd)
valid = valid.view(ht+pad_ht, wd+pad_wd)
if self.is_test:
return img1, img2, flow, valid, frame_id
return img1, img2, flow, valid
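The stage mixtures in the new `fetch_dataloader` (e.g. `100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things`) work because `FlowDataset.__rmul__` oversamples a dataset by repeating its sample lists, while `+` falls back to `torch.utils.data.Dataset.__add__`, which returns a `ConcatDataset`. A toy sketch of the mechanism, with a hypothetical stand-in class:
```Python
import torch.utils.data as data

class ToyFlowDataset(data.Dataset):
    """Hypothetical stand-in for FlowDataset, keeping only the list logic."""
    def __init__(self, n, name):
        self.image_list = [(name, i) for i in range(n)]
    def __rmul__(self, v):
        self.image_list = v * self.image_list   # oversample by repetition
        return self
    def __len__(self):
        return len(self.image_list)
    def __getitem__(self, index):
        return self.image_list[index]

mixed = 100 * ToyFlowDataset(10, 'sintel') + ToyFlowDataset(500, 'things')
print(len(mixed))  # 1500: `+` yields a torch ConcatDataset
```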

View File

@@ -143,10 +143,9 @@ class BasicEncoder(nn.Module):
# output convolution
self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
else:
self.dropout = None
for m in self.modules():
if isinstance(m, nn.Conv2d):
@@ -184,7 +183,7 @@ class BasicEncoder(nn.Module):
x = self.conv2(x)
if self.dropout is not None:
if self.training and self.dropout is not None:
x = self.dropout(x)
if is_list:
@@ -218,10 +217,9 @@ class SmallEncoder(nn.Module):
self.layer2 = self._make_layer(64, stride=2)
self.layer3 = self._make_layer(96, stride=2)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
else:
self.dropout = None
self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
@@ -260,8 +258,8 @@ class SmallEncoder(nn.Module):
x = self.layer3(x)
x = self.conv2(x)
# if self.dropout is not None:
# x = self.dropout(x)
if self.training and self.dropout is not None:
x = self.dropout(x)
if is_list:
x = torch.split(x, [batch_dim, batch_dim], dim=0)

View File

@@ -3,11 +3,23 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from modules.update import BasicUpdateBlock, SmallUpdateBlock
from modules.extractor import BasicEncoder, SmallEncoder
from modules.corr import CorrBlock
from update import BasicUpdateBlock, SmallUpdateBlock
from extractor import BasicEncoder, SmallEncoder
from corr import CorrBlock
from utils.utils import bilinear_sampler, coords_grid, upflow8
try:
autocast = torch.cuda.amp.autocast
except:
# dummy autocast for PyTorch < 1.6
class autocast:
def __init__(self, enabled):
pass
def __enter__(self):
pass
def __exit__(self, *args):
pass
class RAFT(nn.Module):
def __init__(self, args):
@@ -26,7 +38,7 @@ class RAFT(nn.Module):
args.corr_levels = 4
args.corr_radius = 4
if not hasattr(args, 'dropout'):
if 'dropout' not in args._get_kwargs():
args.dropout = 0
# feature network, context network, and update block
@@ -40,6 +52,7 @@ class RAFT(nn.Module):
self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
def freeze_bn(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
@@ -54,46 +67,73 @@ class RAFT(nn.Module):
# optical flow computed as difference: flow = coords1 - coords0
return coords0, coords1
def forward(self, image1, image2, iters=12, flow_init=None, upsample=True):
def upsample_flow(self, flow, mask):
""" Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
N, _, H, W = flow.shape
mask = mask.view(N, 1, 9, 8, 8, H, W)
mask = torch.softmax(mask, dim=2)
up_flow = F.unfold(8 * flow, [3,3], padding=1)
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
up_flow = torch.sum(mask * up_flow, dim=2)
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
return up_flow.reshape(N, 2, 8*H, 8*W)
def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
""" Estimate optical flow between pair of frames """
image1 = 2 * (image1 / 255.0) - 1.0
image2 = 2 * (image2 / 255.0) - 1.0
image1 = image1.contiguous()
image2 = image2.contiguous()
hdim = self.hidden_dim
cdim = self.context_dim
# run the feature network
fmap1, fmap2 = self.fnet([image1, image2])
with autocast(enabled=self.args.mixed_precision):
fmap1, fmap2 = self.fnet([image1, image2])
fmap1 = fmap1.float()
fmap2 = fmap2.float()
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
# run the context network
cnet = self.cnet(image1)
net, inp = torch.split(cnet, [hdim, cdim], dim=1)
net, inp = torch.tanh(net), torch.relu(inp)
with autocast(enabled=self.args.mixed_precision):
cnet = self.cnet(image1)
net, inp = torch.split(cnet, [hdim, cdim], dim=1)
net = torch.tanh(net)
inp = torch.relu(inp)
# if dropout is being used reset mask
self.update_block.reset_mask(net, inp)
coords0, coords1 = self.initialize_flow(image1)
if flow_init is not None:
coords1 = coords1 + flow_init
flow_predictions = []
for itr in range(iters):
coords1 = coords1.detach()
corr = corr_fn(coords1) # index correlation volume
flow = coords1 - coords0
net, delta_flow = self.update_block(net, inp, corr, flow)
with autocast(enabled=self.args.mixed_precision):
net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
# F(t+1) = F(t) + \Delta(t)
coords1 = coords1 + delta_flow
if upsample:
# upsample predictions
if up_mask is None:
flow_up = upflow8(coords1 - coords0)
flow_predictions.append(flow_up)
else:
flow_predictions.append(coords1 - coords0)
flow_up = self.upsample_flow(coords1 - coords0, up_mask)
flow_predictions.append(flow_up)
if test_mode:
return coords1 - coords0, flow_up
return flow_predictions
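The new `upsample_flow` is the upsampling module this commit adds: each full-resolution flow vector is a convex combination of its 3x3 coarse-grid neighbors, weighted by a softmax over the 64*9-channel mask predicted by the update block. A shape-level sketch with random inputs (sizes are illustrative only):
```Python
import torch
import torch.nn.functional as F

N, H, W = 1, 46, 62
flow = torch.randn(N, 2, H, W)       # coarse flow at 1/8 resolution
mask = torch.randn(N, 64*9, H, W)    # as predicted by the update block

mask = mask.view(N, 1, 9, 8, 8, H, W)
mask = torch.softmax(mask, dim=2)    # convex weights over the 9 neighbors

up_flow = F.unfold(8 * flow, [3, 3], padding=1)  # gather 3x3 neighborhoods
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
up_flow = torch.sum(mask * up_flow, dim=2)       # weighted average
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)      # (N, 2, H, 8, W, 8)
print(up_flow.reshape(N, 2, 8*H, 8*W).shape)     # torch.Size([1, 2, 368, 496])
```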

View File

@@ -2,34 +2,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
# VariationalHidDropout from https://github.com/locuslab/trellisnet/tree/master/TrellisNet
class VariationalHidDropout(nn.Module):
def __init__(self, dropout=0.0):
"""
Hidden-to-hidden (VD-based) dropout that applies the same mask at every time step and every layer of TrellisNet
:param dropout: The dropout rate (0 means no dropout is applied)
"""
super(VariationalHidDropout, self).__init__()
self.dropout = dropout
self.mask = None
def reset_mask(self, x):
dropout = self.dropout
# Dimension (N, C, L)
n, c, h, w = x.shape
m = x.data.new(n, c, 1, 1).bernoulli_(1 - dropout)
with torch.no_grad():
mask = m / (1 - dropout)
self.mask = mask
return mask
def forward(self, x):
if not self.training or self.dropout == 0:
return x
assert self.mask is not None, "You need to reset mask before using VariationalHidDropout"
return self.mask * x
class FlowHead(nn.Module):
def __init__(self, input_dim=128, hidden_dim=256):
@@ -41,7 +13,6 @@ class FlowHead(nn.Module):
def forward(self, x):
return self.conv2(self.relu(self.conv1(x)))
class ConvGRU(nn.Module):
def __init__(self, hidden_dim=128, input_dim=192+128):
super(ConvGRU, self).__init__()
@@ -59,7 +30,6 @@ class ConvGRU(nn.Module):
h = (1-z) * h + z * q
return h
class SepConvGRU(nn.Module):
def __init__(self, hidden_dim=128, input_dim=192+128):
super(SepConvGRU, self).__init__()
@@ -133,49 +103,37 @@ class SmallUpdateBlock(nn.Module):
self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
self.drop_inp = VariationalHidDropout(dropout=args.dropout)
self.drop_net = VariationalHidDropout(dropout=args.dropout)
def reset_mask(self, net, inp):
self.drop_inp.reset_mask(inp)
self.drop_net.reset_mask(net)
def forward(self, net, inp, corr, flow):
motion_features = self.encoder(flow, corr)
if self.training:
net = self.drop_net(net)
inp = self.drop_inp(inp)
inp = torch.cat([inp, motion_features], dim=1)
net = self.gru(net, inp)
delta_flow = self.flow_head(net)
return net, delta_flow
return net, None, delta_flow
class BasicUpdateBlock(nn.Module):
def __init__(self, args, hidden_dim=128, input_dim=128):
super(BasicUpdateBlock, self).__init__()
self.args = args
self.encoder = BasicMotionEncoder(args)
self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
self.drop_inp = VariationalHidDropout(dropout=args.dropout)
self.drop_net = VariationalHidDropout(dropout=args.dropout)
self.mask = nn.Sequential(
nn.Conv2d(128, 256, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 64*9, 1, padding=0))
def reset_mask(self, net, inp):
self.drop_inp.reset_mask(inp)
self.drop_net.reset_mask(net)
def forward(self, net, inp, corr, flow):
def forward(self, net, inp, corr, flow, upsample=True):
motion_features = self.encoder(flow, corr)
if self.training:
net = self.drop_net(net)
inp = self.drop_inp(inp)
inp = torch.cat([inp, motion_features], dim=1)
net = self.gru(net, inp)
delta_flow = self.flow_head(net)
return net, delta_flow
# scale mask to balance gradients
mask = .25 * self.mask(net)
return net, mask, delta_flow

View File

@@ -1,46 +1,55 @@
import numpy as np
import random
import math
import cv2
from PIL import Image
import cv2
import torch
import torchvision
from torchvision.transforms import ColorJitter
import torch.nn.functional as F
class FlowAugmentor:
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5):
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
# spatial augmentation params
self.crop_size = crop_size
self.augcolor = torchvision.transforms.ColorJitter(
brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.5/3.14)
self.asymmetric_color_aug_prob = 0.2
self.spatial_aug_prob = 0.8
self.eraser_aug_prob = 0.5
self.min_scale = min_scale
self.max_scale = max_scale
self.max_stretch = 0.2
self.spatial_aug_prob = 0.8
self.stretch_prob = 0.8
self.margin = 20
self.max_stretch = 0.2
# flip augmentation params
self.do_flip = do_flip
self.h_flip_prob = 0.5
self.v_flip_prob = 0.1
# photometric augmentation params
self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
self.asymmetric_color_aug_prob = 0.2
self.eraser_aug_prob = 0.5
def color_transform(self, img1, img2):
""" Photometric augmentation """
# asymmetric
if np.random.rand() < self.asymmetric_color_aug_prob:
img1 = np.array(self.augcolor(Image.fromarray(img1)), dtype=np.uint8)
img2 = np.array(self.augcolor(Image.fromarray(img2)), dtype=np.uint8)
img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
# symmetric
else:
image_stack = np.concatenate([img1, img2], axis=0)
image_stack = np.array(self.augcolor(Image.fromarray(image_stack)), dtype=np.uint8)
image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
img1, img2 = np.split(image_stack, 2, axis=0)
return img1, img2
def eraser_transform(self, img1, img2, bounds=[50, 100]):
""" Occlusion augmentation """
ht, wd = img1.shape[:2]
if np.random.rand() < self.eraser_aug_prob:
mean_color = np.mean(img2.reshape(-1, 3), axis=0)
@@ -55,14 +64,10 @@ class FlowAugmentor:
def spatial_transform(self, img1, img2, flow):
# randomly sample scale
ht, wd = img1.shape[:2]
min_scale = np.maximum(
(self.crop_size[0] + 1) / float(ht),
(self.crop_size[1] + 1) / float(wd))
max_scale = self.max_scale
min_scale = max(min_scale, self.min_scale)
(self.crop_size[0] + 8) / float(ht),
(self.crop_size[1] + 8) / float(wd))
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
scale_x = scale
@@ -81,21 +86,19 @@ class FlowAugmentor:
flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
flow = flow * [scale_x, scale_y]
if np.random.rand() < 0.5: # h-flip
img1 = img1[:, ::-1]
img2 = img2[:, ::-1]
flow = flow[:, ::-1] * [-1.0, 1.0]
if self.do_flip:
if np.random.rand() < self.h_flip_prob: # h-flip
img1 = img1[:, ::-1]
img2 = img2[:, ::-1]
flow = flow[:, ::-1] * [-1.0, 1.0]
if np.random.rand() < 0.1: # v-flip
img1 = img1[::-1, :]
img2 = img2[::-1, :]
flow = flow[::-1, :] * [1.0, -1.0]
if np.random.rand() < self.v_flip_prob: # v-flip
img1 = img1[::-1, :]
img2 = img2[::-1, :]
flow = flow[::-1, :] * [1.0, -1.0]
y0 = np.random.randint(-self.margin, img1.shape[0] - self.crop_size[0] + self.margin)
x0 = np.random.randint(-self.margin, img1.shape[1] - self.crop_size[1] + self.margin)
y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
@@ -114,22 +117,29 @@ class FlowAugmentor:
return img1, img2, flow
class FlowAugmentorKITTI:
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5):
class SparseFlowAugmentor:
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
# spatial augmentation params
self.crop_size = crop_size
self.augcolor = torchvision.transforms.ColorJitter(
brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
self.max_scale = max_scale
self.min_scale = min_scale
self.max_scale = max_scale
self.spatial_aug_prob = 0.8
self.stretch_prob = 0.8
self.max_stretch = 0.2
# flip augmentation params
self.do_flip = do_flip
self.h_flip_prob = 0.5
self.v_flip_prob = 0.1
# photometric augmentation params
self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
self.asymmetric_color_aug_prob = 0.2
self.eraser_aug_prob = 0.5
def color_transform(self, img1, img2):
image_stack = np.concatenate([img1, img2], axis=0)
image_stack = np.array(self.augcolor(Image.fromarray(image_stack)), dtype=np.uint8)
image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
img1, img2 = np.split(image_stack, 2, axis=0)
return img1, img2
@@ -198,11 +208,12 @@ class FlowAugmentorKITTI:
img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
if np.random.rand() < 0.5: # h-flip
img1 = img1[:, ::-1]
img2 = img2[:, ::-1]
flow = flow[:, ::-1] * [-1.0, 1.0]
valid = valid[:, ::-1]
if self.do_flip:
if np.random.rand() < 0.5: # h-flip
img1 = img1[:, ::-1]
img2 = img2[:, ::-1]
flow = flow[:, ::-1] * [-1.0, 1.0]
valid = valid[:, ::-1]
margin_y = 20
margin_x = 50

View File

@@ -103,6 +103,13 @@ def readFlowKITTI(filename):
flow = (flow - 2**15) / 64.0
return flow, valid
def readDispKITTI(filename):
disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
valid = disp > 0.0
flow = np.stack([-disp, np.zeros_like(disp)], -1)
return flow, valid
def writeFlowKITTI(filename, uv):
uv = 64.0 * uv + 2**15
valid = np.ones([uv.shape[0], uv.shape[1], 1])
@@ -120,5 +127,8 @@ def read_gen(file_name, pil=False):
return readFlow(file_name).astype(np.float32)
elif ext == '.pfm':
flow = readPFM(file_name).astype(np.float32)
return flow[:, :, :-1]
if len(flow.shape) == 2:
return flow
else:
return flow[:, :, :-1]
return []
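`readFlowKITTI`/`writeFlowKITTI` (and the new `readDispKITTI`) rely on KITTI's 16-bit PNG encoding, flow_png = 64*flow + 2**15. A round-trip sketch of just the arithmetic (random data; the real files additionally cast to uint16, quantizing to 1/64-pixel steps):
```Python
import numpy as np

flow = np.random.uniform(-100, 100, size=(375, 1242, 2)).astype(np.float32)
encoded = 64.0 * flow + 2**15          # forward mapping, as in writeFlowKITTI
decoded = (encoded - 2**15) / 64.0     # inverse mapping, as in readFlowKITTI
assert np.allclose(flow, decoded)      # exact here, since we skip the uint16 cast
```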

View File

@@ -4,21 +4,21 @@ import numpy as np
from scipy import interpolate
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
""" Wrapper for grid_sample, uses pixel coordinates """
H, W = img.shape[-2:]
xgrid, ygrid = coords.split([1,1], dim=-1)
xgrid = 2*xgrid/(W-1) - 1
ygrid = 2*ygrid/(H-1) - 1
class InputPadder:
""" Pads images such that dimensions are divisible by 8 """
def __init__(self, dims):
self.ht, self.wd = dims[-2:]
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
grid = torch.cat([xgrid, ygrid], dim=-1)
img = F.grid_sample(img, grid, align_corners=True)
def pad(self, *inputs):
return [F.pad(x, self._pad, mode='replicate') for x in inputs]
if mask:
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
return img, mask.float()
return img
def unpad(self,x):
ht, wd = x.shape[-2:]
c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
return x[..., c[0]:c[1], c[2]:c[3]]
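`InputPadder` centralizes the pad-to-multiple-of-8 logic that `demo.py` and `evaluate.py` previously inlined. A minimal usage sketch with dummy tensors (it assumes `core` is on `sys.path`, as `demo.py` arranges):
```Python
import sys
sys.path.append('core')  # as demo.py does, so that utils.utils resolves

import torch
from utils.utils import InputPadder

image1 = torch.randn(1, 3, 436, 1024)   # Sintel-sized frames: 436 is not divisible by 8
image2 = torch.randn(1, 3, 436, 1024)

padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)
print(image1.shape)                      # torch.Size([1, 3, 440, 1024])

flow = torch.randn(1, 2, 440, 1024)      # e.g. model output at the padded size
print(padder.unpad(flow).shape)          # torch.Size([1, 2, 436, 1024])
```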
def forward_interpolate(flow):
flow = flow.detach().cpu().numpy()
@@ -42,15 +42,33 @@ def forward_interpolate(flow):
dy = dy[valid]
flow_x = interpolate.griddata(
(x1, y1), dx, (x0, y0), method='nearest')
(x1, y1), dx, (x0, y0), method='cubic', fill_value=0)
flow_y = interpolate.griddata(
(x1, y1), dy, (x0, y0), method='nearest')
(x1, y1), dy, (x0, y0), method='cubic', fill_value=0)
flow = np.stack([flow_x, flow_y], axis=0)
return torch.from_numpy(flow).float()
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
""" Wrapper for grid_sample, uses pixel coordinates """
H, W = img.shape[-2:]
xgrid, ygrid = coords.split([1,1], dim=-1)
xgrid = 2*xgrid/(W-1) - 1
ygrid = 2*ygrid/(H-1) - 1
grid = torch.cat([xgrid, ygrid], dim=-1)
img = F.grid_sample(img, grid, align_corners=True)
if mask:
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
return img, mask.float()
return img
def coords_grid(batch, ht, wd):
coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
coords = torch.stack(coords[::-1], dim=0).float()

BIN demo-frames/frame_0016.png Executable file (new, 652 KiB)
BIN demo-frames/frame_0017.png Executable file (new, 652 KiB)
BIN demo-frames/frame_0018.png Executable file (new, 652 KiB)
BIN demo-frames/frame_0019.png Executable file (new, 653 KiB)
BIN demo-frames/frame_0020.png Executable file (new, 655 KiB)
BIN demo-frames/frame_0021.png Executable file (new, 657 KiB)
BIN demo-frames/frame_0022.png Executable file (new, 658 KiB)
BIN demo-frames/frame_0023.png Executable file (new, 659 KiB)
BIN demo-frames/frame_0024.png Executable file (new, 660 KiB)
BIN demo-frames/frame_0025.png Executable file (new, 660 KiB)

87
demo.py
View File

@@ -4,87 +4,76 @@ sys.path.append('core')
import argparse
import os
import cv2
import glob
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import datasets
from utils import flow_viz
from raft import RAFT
from utils import flow_viz
from utils.utils import InputPadder
DEVICE = 'cuda'
def pad8(img):
"""pad image such that dimensions are divisible by 8"""
ht, wd = img.shape[2:]
pad_ht = (((ht // 8) + 1) * 8 - ht) % 8
pad_wd = (((wd // 8) + 1) * 8 - wd) % 8
pad_ht1 = [pad_ht//2, pad_ht-pad_ht//2]
pad_wd1 = [pad_wd//2, pad_wd-pad_wd//2]
img = F.pad(img, pad_wd1 + pad_ht1, mode='replicate')
def load_image(imfile):
img = np.array(Image.open(imfile)).astype(np.uint8)
img = torch.from_numpy(img).permute(2, 0, 1).float()
return img
def load_image(imfile):
img = np.array(Image.open(imfile)).astype(np.uint8)[..., :3]
img = torch.from_numpy(img).permute(2, 0, 1).float()
return pad8(img[None]).to(DEVICE)
def load_image_list(image_files):
images = []
for imfile in sorted(image_files):
images.append(load_image(imfile))
images = torch.stack(images, dim=0)
images = images.to(DEVICE)
padder = InputPadder(images.shape)
return padder.pad(images)[0]
def display(image1, image2, flow):
image1 = image1.permute(1, 2, 0).cpu().numpy() / 255.0
image2 = image2.permute(1, 2, 0).cpu().numpy() / 255.0
def viz(img, flo):
img = img[0].permute(1,2,0).cpu().numpy()
flo = flo[0].permute(1,2,0).cpu().numpy()
flow = flow.permute(1, 2, 0).cpu().numpy()
flow_image = flow_viz.flow_to_image(flow)
flow_image = cv2.resize(flow_image, (image1.shape[1], image1.shape[0]))
# map flow to rgb image
flo = flow_viz.flow_to_image(flo)
img_flo = np.concatenate([img, flo], axis=0)
cv2.imshow('image1', image1[..., ::-1])
cv2.imshow('image2', image2[..., ::-1])
cv2.imshow('flow', flow_image[..., ::-1])
cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
cv2.waitKey()
def demo(args):
model = RAFT(args)
model = torch.nn.DataParallel(model)
model = torch.nn.DataParallel(RAFT(args))
model.load_state_dict(torch.load(args.model))
model = model.module
model.to(DEVICE)
model.eval()
with torch.no_grad():
images = glob.glob(os.path.join(args.path, '*.png')) + \
glob.glob(os.path.join(args.path, '*.jpg'))
# sintel images
image1 = load_image('images/sintel_0.png')
image2 = load_image('images/sintel_1.png')
images = load_image_list(images)
for i in range(images.shape[0]-1):
image1 = images[i,None]
image2 = images[i+1,None]
flow_predictions = model(image1, image2, iters=args.iters, upsample=False)
display(image1[0], image2[0], flow_predictions[-1][0])
# kitti images
image1 = load_image('images/kitti_0.png')
image2 = load_image('images/kitti_1.png')
flow_predictions = model(image1, image2, iters=16)
display(image1[0], image2[0], flow_predictions[-1][0])
# davis images
image1 = load_image('images/davis_0.jpg')
image2 = load_image('images/davis_1.jpg')
flow_predictions = model(image1, image2, iters=16)
display(image1[0], image2[0], flow_predictions[-1][0])
flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
viz(image1, flow_up)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model', help="restore checkpoint")
parser.add_argument('--path', help="dataset for evaluation")
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--iters', type=int, default=12)
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
args = parser.parse_args()
demo(args)

3
download_models.sh Executable file
View File

@@ -0,0 +1,3 @@
#!/bin/bash
wget https://www.dropbox.com/s/npt24nvhoojdr0n/models.zip
unzip models.zip

View File

@@ -2,7 +2,6 @@ import sys
sys.path.append('core')
from PIL import Image
import cv2
import argparse
import os
import time
@@ -13,88 +12,185 @@ import matplotlib.pyplot as plt
import datasets
from utils import flow_viz
from utils import frame_utils
from raft import RAFT
from utils.utils import InputPadder, forward_interpolate
def validate_sintel(args, model, iters=50):
""" Evaluate trained model on Sintel(train) clean + final passes """
@torch.no_grad()
def create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission'):
""" Create submission for the Sintel leaderboard """
model.eval()
pad = 2
for dstype in ['clean', 'final']:
val_dataset = datasets.MpiSintel(args, do_augument=False, dstype=dstype)
test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype)
epe_list = []
for i in range(len(val_dataset)):
image1, image2, flow_gt, _ = val_dataset[i]
image1 = image1[None].cuda()
image2 = image2[None].cuda()
image1 = F.pad(image1, [0, 0, pad, pad], mode='replicate')
image2 = F.pad(image2, [0, 0, pad, pad], mode='replicate')
flow_prev, sequence_prev = None, None
for test_id in range(len(test_dataset)):
image1, image2, (sequence, frame) = test_dataset[test_id]
if sequence != sequence_prev:
flow_prev = None
with torch.no_grad():
flow_predictions = model.module(image1, image2, iters=iters)
flow_pr = flow_predictions[-1][0,:,pad:-pad]
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
epe = torch.sum((flow_pr - flow_gt.cuda())**2, dim=0)
epe = torch.sqrt(epe).mean()
epe_list.append(epe.item())
flow_low, flow_pr = model(image1, image2, iters=iters, flow_init=flow_prev, test_mode=True)
flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
print("Validation (%s) EPE: %f" % (dstype, np.mean(epe_list)))
if warm_start:
flow_prev = forward_interpolate(flow_low[0])[None].cuda()
output_dir = os.path.join(output_path, dstype, sequence)
output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
frame_utils.writeFlow(output_file, flow)
sequence_prev = sequence
def validate_kitti(args, model, iters=32):
""" Evaluate trained model on KITTI (train) """
@torch.no_grad()
def create_kitti_submission(model, iters=24, output_path='kitti_submission'):
""" Create submission for the Sintel leaderboard """
model.eval()
val_dataset = datasets.KITTI(args, do_augument=False, is_val=True, do_pad=True)
test_dataset = datasets.KITTI(split='testing', aug_params=None)
with torch.no_grad():
epe_list, out_list = [], []
for i in range(len(val_dataset)):
image1, image2, flow_gt, valid_gt = val_dataset[i]
if not os.path.exists(output_path):
os.makedirs(output_path)
for test_id in range(len(test_dataset)):
image1, image2, (frame_id, ) = test_dataset[test_id]
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
_, flow_pr = model(image1, image2, iters=iters, test_mode=True)
flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
output_filename = os.path.join(output_path, frame_id)
frame_utils.writeFlowKITTI(output_filename, flow)
@torch.no_grad()
def validate_chairs(model, iters=24):
""" Perform evaluation on the FlyingChairs (test) split """
model.eval()
epe_list = []
val_dataset = datasets.FlyingChairs(split='validation')
for val_id in range(len(val_dataset)):
image1, image2, flow_gt, _ = val_dataset[val_id]
image1 = image1[None].cuda()
image2 = image2[None].cuda()
_, flow_pr = model(image1, image2, iters=iters, test_mode=True)
epe = torch.sum((flow_pr[0].cpu() - flow_gt)**2, dim=0).sqrt()
epe_list.append(epe.view(-1).numpy())
epe = np.mean(np.concatenate(epe_list))
print("Validation Chairs EPE: %f" % epe)
return {'chairs': epe}
@torch.no_grad()
def validate_sintel(model, iters=32):
""" Peform validation using the Sintel (train) split """
model.eval()
results = {}
for dstype in ['clean', 'final']:
val_dataset = datasets.MpiSintel(split='training', dstype=dstype)
epe_list = []
for val_id in range(len(val_dataset)):
image1, image2, flow_gt, _ = val_dataset[val_id]
image1 = image1[None].cuda()
image2 = image2[None].cuda()
flow_gt = flow_gt.cuda()
valid_gt = valid_gt.cuda()
flow_predictions = model.module(image1, image2, iters=iters)
flow_pr = flow_predictions[-1][0]
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)
epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
mag = torch.sum(flow_gt**2, dim=0).sqrt()
flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
flow = padder.unpad(flow_pr[0]).cpu()
epe = epe.view(-1)
mag = mag.view(-1)
val = valid_gt.view(-1) >= 0.5
epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
epe_list.append(epe.view(-1).numpy())
out = ((epe > 3.0) & ((epe/mag) > 0.05)).float()
epe_list.append(epe[val].mean().item())
out_list.append(out[val].cpu().numpy())
epe_all = np.concatenate(epe_list)
epe = np.mean(epe_all)
px1 = np.mean(epe_all<1)
px3 = np.mean(epe_all<3)
px5 = np.mean(epe_all<5)
print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
results[dstype] = np.mean(epe_list)
return results
@torch.no_grad()
def validate_kitti(model, iters=24):
""" Peform validation using the KITTI-2015 (train) split """
model.eval()
val_dataset = datasets.KITTI(split='training')
out_list, epe_list = [], []
for val_id in range(len(val_dataset)):
image1, image2, flow_gt, valid_gt = val_dataset[val_id]
image1 = image1[None].cuda()
image2 = image2[None].cuda()
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)
flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
flow = padder.unpad(flow_pr[0]).cpu()
epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
mag = torch.sum(flow_gt**2, dim=0).sqrt()
epe = epe.view(-1)
mag = mag.view(-1)
val = valid_gt.view(-1) >= 0.5
out = ((epe > 3.0) & ((epe/mag) > 0.05)).float()
epe_list.append(epe[val].mean().item())
out_list.append(out[val].cpu().numpy())
epe_list = np.array(epe_list)
out_list = np.concatenate(out_list)
epe = np.mean(epe_list)
f1 = 100 * np.mean(out_list)
print("Validation KITTI: %f, %f" % (np.mean(epe_list), 100*np.mean(out_list)))
print("Validation KITTI: %f, %f" % (epe, f1))
return {'kitti-epe': epe, 'kitti-f1': f1}
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model', help="restore checkpoint")
parser.add_argument('--dataset', help="dataset for evaluation")
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--sintel_iters', type=int, default=50)
parser.add_argument('--kitti_iters', type=int, default=32)
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
args = parser.parse_args()
model = RAFT(args)
model = torch.nn.DataParallel(model)
model = torch.nn.DataParallel(RAFT(args))
model.load_state_dict(torch.load(args.model))
model.to('cuda')
model.cuda()
model.eval()
validate_sintel(args, model, args.sintel_iters)
validate_kitti(args, model, args.kitti_iters)
# create_sintel_submission(model.module, warm_start=True)
# create_kitti_submission(model.module)
with torch.no_grad():
if args.dataset == 'chairs':
validate_chairs(model.module)
elif args.dataset == 'sintel':
validate_sintel(model.module)
elif args.dataset == 'kitti':
validate_kitti(model.module)

BIN 6 binary image files deleted (not shown; sizes 388 to 829 KiB), presumably the old `images/` demo pictures referenced by the previous `demo.py`

View File

@@ -1,3 +0,0 @@
#!/bin/bash
wget https://www.dropbox.com/s/a2acvmczgzm6f9n/models.zip
unzip models.zip

159
train.py Executable file → Normal file
View File

@@ -16,26 +16,43 @@ import torch.nn.functional as F
from torch.utils.data import DataLoader
from raft import RAFT
from evaluate import validate_sintel, validate_kitti
import evaluate
import datasets
from torch.utils.tensorboard import SummaryWriter
try:
from torch.cuda.amp import GradScaler
except:
# dummy GradScaler for PyTorch < 1.6
class GradScaler:
def __init__(self):
pass
def scale(self, loss):
return loss
def unscale_(self, optimizer):
pass
def step(self, optimizer):
optimizer.step()
def update(self):
pass
# exclude extremely large displacements
MAX_FLOW = 1000
SUM_FREQ = 200
MAX_FLOW = 500
SUM_FREQ = 100
VAL_FREQ = 5000
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def sequence_loss(flow_preds, flow_gt, valid):
def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW):
""" Loss function defined over sequence of flow predictions """
n_predictions = len(flow_preds)
flow_loss = 0.0
# exclude invalid pixels and extremely large displacements
valid = (valid >= 0.5) & (flow_gt.abs().sum(dim=1) < MAX_FLOW)
mag = torch.sum(flow_gt**2, dim=1).sqrt()
valid = (valid >= 0.5) & (mag < max_flow)
for i in range(n_predictions):
i_weight = 0.8**(n_predictions - i - 1)
@@ -54,39 +71,22 @@ def sequence_loss(flow_preds, flow_gt, valid):
return flow_loss, metrics
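For reference, the weighting scheme of `sequence_loss` in one runnable piece. The per-iteration term is elided from the hunk above, so this sketch assumes the paper's L1 penalty; only the `0.8**(N-i-1)` weighting and the validity mask are taken directly from the diff:
```Python
import torch

def sequence_loss_sketch(flow_preds, flow_gt, valid, gamma=0.8, max_flow=500):
    """Illustrative version of sequence_loss: later iterations receive
    exponentially larger weight, gamma**(N - i - 1)."""
    n_predictions = len(flow_preds)
    # exclude invalid pixels and extremely large displacements
    mag = torch.sum(flow_gt**2, dim=1).sqrt()
    valid = ((valid >= 0.5) & (mag < max_flow)).float()
    flow_loss = 0.0
    for i in range(n_predictions):
        i_weight = gamma ** (n_predictions - i - 1)
        i_loss = (flow_preds[i] - flow_gt).abs()          # assumed L1 term
        flow_loss += i_weight * (valid[:, None] * i_loss).mean()
    return flow_loss
```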
def show_image(img):
img = img.permute(1,2,0).cpu().numpy()
plt.imshow(img/255.0)
plt.show()
# cv2.imshow('image', img/255.0)
# cv2.waitKey()
def fetch_dataloader(args):
""" Create the data loader for the corresponding training set """
if args.dataset == 'chairs':
train_dataset = datasets.FlyingChairs(args, image_size=args.image_size)
elif args.dataset == 'things':
clean_dataset = datasets.SceneFlow(args, image_size=args.image_size, dstype='frames_cleanpass')
final_dataset = datasets.SceneFlow(args, image_size=args.image_size, dstype='frames_finalpass')
train_dataset = clean_dataset + final_dataset
elif args.dataset == 'sintel':
clean_dataset = datasets.MpiSintel(args, image_size=args.image_size, dstype='clean')
final_dataset = datasets.MpiSintel(args, image_size=args.image_size, dstype='final')
train_dataset = clean_dataset + final_dataset
elif args.dataset == 'kitti':
train_dataset = datasets.KITTI(args, image_size=args.image_size, is_val=False)
gpuargs = {'num_workers': 4, 'drop_last' : True}
train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
pin_memory=True, shuffle=True, **gpuargs)
print('Training with %d image pairs' % len(train_dataset))
return train_loader
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def fetch_optimizer(args, model):
""" Create the optimizer and learning rate scheduler """
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps,
pct_start=0.2, cycle_momentum=False, anneal_strategy='linear')
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100,
pct_start=0.05, cycle_momentum=False, anneal_strategy='linear')
return optimizer, scheduler
@@ -97,17 +97,22 @@ class Logger:
self.scheduler = scheduler
self.total_steps = 0
self.running_loss = {}
self.writer = None
def _print_training_status(self):
metrics_data = [self.running_loss[k]/SUM_FREQ for k in sorted(self.running_loss.keys())]
training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_lr()[0])
training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0])
metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data)
# print the training status
print(training_str + metrics_str)
for key in self.running_loss:
self.running_loss[key] = 0.0
if self.writer is None:
self.writer = SummaryWriter()
for k in self.running_loss:
self.writer.add_scalar(k, self.running_loss[k]/SUM_FREQ, self.total_steps)
self.running_loss[k] = 0.0
def push(self, metrics):
self.total_steps += 1
@@ -122,56 +127,95 @@ class Logger:
self._print_training_status()
self.running_loss = {}
def write_dict(self, results):
if self.writer is None:
self.writer = SummaryWriter()
for key in results:
self.writer.add_scalar(key, results[key], self.total_steps)
def close(self):
self.writer.close()
def train(args):
model = RAFT(args)
model = nn.DataParallel(model)
model = nn.DataParallel(RAFT(args), device_ids=args.gpus)
print("Parameter Count: %d" % count_parameters(model))
if args.restore_ckpt is not None:
model.load_state_dict(torch.load(args.restore_ckpt))
model.load_state_dict(torch.load(args.restore_ckpt), strict=False)
model.cuda()
model.train()
if 'chairs' not in args.dataset:
if args.stage != 'chairs':
model.module.freeze_bn()
train_loader = fetch_dataloader(args)
train_loader = datasets.fetch_dataloader(args)
optimizer, scheduler = fetch_optimizer(args, model)
total_steps = 0
scaler = GradScaler(enabled=args.mixed_precision)
logger = Logger(model, scheduler)
VAL_FREQ = 5000
add_noise = True
should_keep_training = True
while should_keep_training:
for i_batch, data_blob in enumerate(train_loader):
optimizer.zero_grad()
image1, image2, flow, valid = [x.cuda() for x in data_blob]
optimizer.zero_grad()
# show_image(image1[0])
# show_image(image2[0])
if args.add_noise:
stdv = np.random.uniform(0.0, 5.0)
image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0)
flow_predictions = model(image1, image2, iters=args.iters)
loss, metrics = sequence_loss(flow_predictions, flow, valid)
loss.backward()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
optimizer.step()
scheduler.step()
total_steps += 1
scaler.step(optimizer)
scheduler.step()
scaler.update()
logger.push(metrics)
if total_steps % VAL_FREQ == VAL_FREQ-1:
if total_steps % VAL_FREQ == VAL_FREQ - 1:
PATH = 'checkpoints/%d_%s.pth' % (total_steps+1, args.name)
torch.save(model.state_dict(), PATH)
if total_steps == args.num_steps:
results = {}
for val_dataset in args.validation:
if val_dataset == 'chairs':
results.update(evaluate.validate_chairs(model.module))
elif val_dataset == 'sintel':
results.update(evaluate.validate_sintel(model.module))
elif val_dataset == 'kitti':
results.update(evaluate.validate_kitti(model.module))
logger.write_dict(results)
model.train()
if args.stage != 'chairs':
model.module.freeze_bn()
total_steps += 1
if total_steps > args.num_steps:
should_keep_training = False
break
logger.close()
PATH = 'checkpoints/%s.pth' % args.name
torch.save(model.state_dict(), PATH)
@@ -180,21 +224,25 @@ def train(args):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--name', default='bla', help="name your experiment")
parser.add_argument('--dataset', help="which dataset to use for training")
parser.add_argument('--name', default='raft', help="name your experiment")
parser.add_argument('--stage', help="determines which dataset to use for training")
parser.add_argument('--restore_ckpt', help="restore checkpoint")
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--validation', type=str, nargs='+')
parser.add_argument('--lr', type=float, default=0.00002)
parser.add_argument('--num_steps', type=int, default=100000)
parser.add_argument('--batch_size', type=int, default=6)
parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512])
parser.add_argument('--gpus', type=int, nargs='+', default=[0,1])
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
parser.add_argument('--iters', type=int, default=12)
parser.add_argument('--wdecay', type=float, default=.00005)
parser.add_argument('--epsilon', type=float, default=1e-8)
parser.add_argument('--clip', type=float, default=1.0)
parser.add_argument('--dropout', type=float, default=0.0)
parser.add_argument('--add_noise', action='store_true')
args = parser.parse_args()
torch.manual_seed(1234)
@@ -203,9 +251,4 @@ if __name__ == '__main__':
if not os.path.isdir('checkpoints'):
os.mkdir('checkpoints')
# scale learning rate and batch size by number of GPUs
num_gpus = torch.cuda.device_count()
args.batch_size = args.batch_size * num_gpus
args.lr = args.lr * num_gpus
train(args)