added upsampling module

Zach Teed 2020-07-25 17:36:17 -06:00
parent dc1220825d
commit a2408eab78
32 changed files with 23559 additions and 619 deletions

README.md

@@ -1,7 +1,4 @@
 # RAFT
-**7/22/2020: We have updated our method to predict flow at full resolution leading to improved results on public benchmarks. This repository will be updated to reflect these changes within the next few days.**
 This repository contains the source code for our paper:

 [RAFT: Recurrent All Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)<br/>
@@ -11,90 +8,72 @@ Zachary Teed and Jia Deng<br/>
 <img src="RAFT.png">

 ## Requirements
-Our code was tested using PyTorch 1.3.1 and Python 3. The following additional packages need to be installed
+The code has been tested with PyTorch 1.5.1 and PyTorch Nightly. If you want to train with mixed precision, you will have to install the nightly build.
 ```Shell
-pip install Pillow
-pip install scipy
-pip install opencv-python
+conda create --name raft
+conda activate raft
+conda install pytorch torchvision cudatoolkit=10.1 -c pytorch-nightly
+conda install matplotlib
+conda install tensorboard
+conda install scipy
+conda install opencv
 ```

 ## Demos
 Pretrained models can be downloaded by running
 ```Shell
 ./scripts/download_models.sh
 ```
+or downloaded from [google drive](https://drive.google.com/file/d/10-BYgHqRNPGvmNUWr8razjb1xHu55pyA/view?usp=sharing)

-You can run the demos using one of the available models.
+You can demo a trained model on a sequence of frames
 ```Shell
-python demo.py --model=models/chairs+things.pth
+python demo.py --model=models/raft-things.pth --path=demo-frames
 ```
-or using the small (1M parameter) model
-```Shell
-python demo.py --model=models/small.pth --small
-```
-Running the demos will display the two images and a visualization of the optical flow estimate. After the images display, press any key to continue.

-## Training
-To train RAFT, you will need to download the required datasets. The first stage of training requires the [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs) and [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) datasets. Finetuning and evaluation require the [Sintel](http://sintel.is.tue.mpg.de/) and [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) datasets. We organize the directory structure as follows. By default `datasets.py` will search for the datasets in these locations
+## Required Data
+To evaluate/train RAFT, you will need to download the required datasets.
+* [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs)
+* [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
+* [Sintel](http://sintel.is.tue.mpg.de/)
+* [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow)
+* [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) (optional)
+
+By default `datasets.py` will search for the datasets in these locations. You can create symbolic links to wherever the datasets were downloaded in the `datasets` folder
 ```Shell
 ├── datasets
-    ├── Sintel
-    |   |   ├── test
-    |   |   ├── training
-    ├── KITTI
-    |   |   ├── testing
-    |   |   ├── training
-    |   |   ├── devkit
-    ├── FlyingChairs_release
-    |   |   ├── data
-    ├── FlyingThings3D
-    |   |   ├── frames_cleanpass
-    |   |   ├── frames_finalpass
-    |   |   ├── optical_flow
+    ├── Sintel
+        ├── test
+        ├── training
+    ├── KITTI
+        ├── testing
+        ├── training
+        ├── devkit
+    ├── FlyingChairs_release
+        ├── data
+    ├── FlyingThings3D
+        ├── frames_cleanpass
+        ├── frames_finalpass
+        ├── optical_flow
 ```

-We used the following training schedule in our paper (note: we use 2 GPUs for training)
-```Shell
-python train.py --name=chairs --image_size 368 496 --dataset=chairs --num_steps=100000 --lr=0.0002 --batch_size=6
-```
-
-Next, finetune on the FlyingThings dataset
-```Shell
-python train.py --name=things --image_size 368 768 --dataset=things --num_steps=60000 --lr=0.00005 --batch_size=3 --restore_ckpt=checkpoints/chairs.pth
-```
-
-You can perform dataset specific finetuning
-
-### Sintel
-```Shell
-python train.py --name=sintel_ft --image_size 368 768 --dataset=sintel --num_steps=60000 --lr=0.00005 --batch_size=4 --restore_ckpt=checkpoints/things.pth
-```
-
-### KITTI
-```Shell
-python train.py --name=kitti_ft --image_size 288 896 --dataset=kitti --num_steps=40000 --lr=0.0001 --batch_size=4 --restore_ckpt=checkpoints/things.pth
-```
-
 ## Evaluation
-You can evaluate a model on Sintel and KITTI by running
+You can evaluate a trained model using `evaluate.py`
 ```Shell
-python evaluate.py --model=models/chairs+things.pth
+python evaluate.py --model=models/raft-things.pth --dataset=sintel
 ```
-or the small model by including the `small` flag
-```Shell
-python evaluate.py --model=models/small.pth --small
-```

+## Training
+Training code will be made available in the next few days
+<!-- We used the following training schedule in our paper (note: we use 2 GPUs for training). Training logs will be written to the `runs` directory, which can be visualized using tensorboard
+```Shell
+./train_standard.sh
+```
+
+If you have an RTX GPU, training can be accelerated using mixed precision. You can expect similar results in this setting (1 GPU)
+```Shell
+./train_mixed.sh
+``` -->

chairs_split.txt (new file, 22872 lines)

File diff suppressed because it is too large

core/corr.py

@@ -2,6 +2,7 @@ import torch
 import torch.nn.functional as F
 from utils.utils import bilinear_sampler, coords_grid

 class CorrBlock:
     def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
         self.num_levels = num_levels
@@ -12,10 +13,10 @@ class CorrBlock:
         corr = CorrBlock.corr(fmap1, fmap2)

         batch, h1, w1, dim, h2, w2 = corr.shape
-        corr = corr.view(batch*h1*w1, dim, h2, w2)
+        corr = corr.reshape(batch*h1*w1, dim, h2, w2)

         self.corr_pyramid.append(corr)
-        for i in range(self.num_levels):
+        for i in range(self.num_levels-1):
             corr = F.avg_pool2d(corr, 2, stride=2)
             self.corr_pyramid.append(corr)
@@ -40,14 +41,16 @@ class CorrBlock:
             out_pyramid.append(corr)

         out = torch.cat(out_pyramid, dim=-1)
-        return out.permute(0, 3, 1, 2)
+        return out.permute(0, 3, 1, 2).contiguous().float()

     @staticmethod
     def corr(fmap1, fmap2):
         batch, dim, ht, wd = fmap1.shape
         fmap1 = fmap1.view(batch, dim, ht*wd)
         fmap2 = fmap2.view(batch, dim, ht*wd)

         corr = torch.matmul(fmap1.transpose(1,2), fmap2)
         corr = corr.view(batch, ht, wd, 1, ht, wd)
         return corr / torch.sqrt(torch.tensor(dim).float())
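For reference, a minimal sketch of the all-pairs volume that `CorrBlock.corr` builds (dummy shapes, not part of the commit): every pair of feature vectors from the two 1/8-resolution maps is dotted and scaled by sqrt(dim). Note also that with the `range(self.num_levels-1)` fix above, the pyramid ends up with exactly `num_levels` levels, since the unpooled volume is appended first.

```python
import torch

# dummy 1/8-resolution feature maps (batch 1, 256 channels, 6x8 grid)
fmap1 = torch.randn(1, 256, 6, 8)
fmap2 = torch.randn(1, 256, 6, 8)

batch, dim, ht, wd = fmap1.shape
# dot product between every pair of feature vectors
corr = torch.matmul(fmap1.view(batch, dim, ht*wd).transpose(1, 2),
                    fmap2.view(batch, dim, ht*wd))
corr = corr.view(batch, ht, wd, 1, ht, wd) / torch.sqrt(torch.tensor(dim).float())
print(corr.shape)  # torch.Size([1, 6, 8, 1, 6, 8])
```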

core/datasets.py

@@ -6,53 +6,42 @@ import torch.utils.data as data
 import torch.nn.functional as F

 import os
-import cv2
 import math
 import random
 from glob import glob
 import os.path as osp

 from utils import frame_utils
-from utils.augmentor import FlowAugmentor, FlowAugmentorKITTI
+from utils.augmentor import FlowAugmentor, SparseFlowAugmentor

-class CombinedDataset(data.Dataset):
-    def __init__(self, datasets):
-        self.datasets = datasets
-
-    def __len__(self):
-        length = 0
-        for i in range(len(self.datasets)):
-            length += len(self.datsaets[i])
-        return length
-
-    def __getitem__(self, index):
-        i = 0
-        for j in range(len(self.datasets)):
-            if i + len(self.datasets[j]) >= index:
-                yield self.datasets[j][index-i]
-                break
-            i += len(self.datasets[j])
-
-    def __add__(self, other):
-        self.datasets.append(other)
-        return self

 class FlowDataset(data.Dataset):
-    def __init__(self, args, image_size=None, do_augument=False):
-        self.image_size = image_size
-        self.do_augument = do_augument
-
-        if self.do_augument:
-            self.augumentor = FlowAugmentor(self.image_size)
-
-        self.flow_list = []
-        self.image_list = []
-
-        self.init_seed = False
+    def __init__(self, aug_params=None, sparse=False):
+        self.augmentor = None
+        self.sparse = sparse
+        if aug_params is not None:
+            if sparse:
+                self.augmentor = SparseFlowAugmentor(**aug_params)
+            else:
+                self.augmentor = FlowAugmentor(**aug_params)
+
+        self.is_test = False
+        self.init_seed = False
+        self.flow_list = []
+        self.image_list = []
+        self.extra_info = []

     def __getitem__(self, index):
+        if self.is_test:
+            img1 = frame_utils.read_gen(self.image_list[index][0])
+            img2 = frame_utils.read_gen(self.image_list[index][1])
+            img1 = np.array(img1).astype(np.uint8)[..., :3]
+            img2 = np.array(img2).astype(np.uint8)[..., :3]
+            img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
+            img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
+            return img1, img2, self.extra_info[index]
+
         if not self.init_seed:
             worker_info = torch.utils.data.get_worker_info()
             if worker_info is not None:
@@ -62,133 +51,96 @@ class FlowDataset(data.Dataset):
                 self.init_seed = True

         index = index % len(self.image_list)
-        flow = frame_utils.read_gen(self.flow_list[index])
+        valid = None
+        if self.sparse:
+            flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
+        else:
+            flow = frame_utils.read_gen(self.flow_list[index])

         img1 = frame_utils.read_gen(self.image_list[index][0])
         img2 = frame_utils.read_gen(self.image_list[index][1])

-        img1 = np.array(img1).astype(np.uint8)[..., :3]
-        img2 = np.array(img2).astype(np.uint8)[..., :3]
         flow = np.array(flow).astype(np.float32)
+        img1 = np.array(img1).astype(np.uint8)
+        img2 = np.array(img2).astype(np.uint8)

-        if self.do_augument:
-            img1, img2, flow = self.augumentor(img1, img2, flow)
+        # grayscale images
+        if len(img1.shape) == 2:
+            img1 = np.tile(img1[...,None], (1, 1, 3))
+            img2 = np.tile(img2[...,None], (1, 1, 3))
+        else:
+            img1 = img1[..., :3]
+            img2 = img2[..., :3]
+
+        if self.augmentor is not None:
+            if self.sparse:
+                img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
+            else:
+                img1, img2, flow = self.augmentor(img1, img2, flow)

         img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
         img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
         flow = torch.from_numpy(flow).permute(2, 0, 1).float()
-        valid = torch.ones_like(flow[0])

-        return img1, img2, flow, valid
+        if valid is not None:
+            valid = torch.from_numpy(valid)
+        else:
+            valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
+
+        return img1, img2, flow, valid.float()
+
+    def __rmul__(self, v):
+        self.flow_list = v * self.flow_list
+        self.image_list = v * self.image_list
+        return self

     def __len__(self):
         return len(self.image_list)

-    def __add(self, other):
-        return CombinedDataset([self, other])
-
-
-class MpiSintelTest(FlowDataset):
-    def __init__(self, args, root='datasets/Sintel/test', dstype='clean'):
-        super(MpiSintelTest, self).__init__(args, image_size=None, do_augument=False)
-        self.root = root
-        self.dstype = dstype
-
-        image_dir = osp.join(self.root, dstype)
-        all_sequences = os.listdir(image_dir)
-
-        self.image_list = []
-        for sequence in all_sequences:
-            frames = sorted(glob(osp.join(image_dir, sequence, '*.png')))
-            for i in range(len(frames)-1):
-                self.image_list += [[frames[i], frames[i+1], sequence, i]]
-
-    def __getitem__(self, index):
-        img1 = frame_utils.read_gen(self.image_list[index][0])
-        img2 = frame_utils.read_gen(self.image_list[index][1])
-        sequence = self.image_list[index][2]
-        frame = self.image_list[index][3]
-
-        img1 = np.array(img1).astype(np.uint8)[..., :3]
-        img2 = np.array(img2).astype(np.uint8)[..., :3]
-        img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
-        img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
-        return img1, img2, sequence, frame


 class MpiSintel(FlowDataset):
-    def __init__(self, args, image_size=None, do_augument=True, root='datasets/Sintel/training', dstype='clean'):
-        super(MpiSintel, self).__init__(args, image_size, do_augument)
-        if do_augument:
-            self.augumentor.min_scale = -0.2
-            self.augumentor.max_scale = 0.7
-
-        self.root = root
-        self.dstype = dstype
-
-        flow_root = osp.join(root, 'flow')
-        image_root = osp.join(root, dstype)
-
-        file_list = sorted(glob(osp.join(flow_root, '*/*.flo')))
-        for flo in file_list:
-            fbase = flo[len(flow_root)+1:]
-            fprefix = fbase[:-8]
-            fnum = int(fbase[-8:-4])
-
-            img1 = osp.join(image_root, fprefix + "%04d"%(fnum+0) + '.png')
-            img2 = osp.join(image_root, fprefix + "%04d"%(fnum+1) + '.png')
-
-            if not osp.isfile(img1) or not osp.isfile(img2) or not osp.isfile(flo):
-                continue
-
-            self.image_list.append((img1, img2))
-            self.flow_list.append(flo)
+    def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
+        super(MpiSintel, self).__init__(aug_params)
+        flow_root = osp.join(root, split, 'flow')
+        image_root = osp.join(root, split, dstype)
+
+        if split == 'test':
+            self.is_test = True
+
+        for scene in os.listdir(image_root):
+            image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
+            for i in range(len(image_list)-1):
+                self.image_list += [ [image_list[i], image_list[i+1]] ]
+                self.extra_info += [ (scene, i) ] # scene and frame_id
+
+            if split != 'test':
+                self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))


 class FlyingChairs(FlowDataset):
-    def __init__(self, args, image_size=None, do_augument=True, root='datasets/FlyingChairs_release/data'):
-        super(FlyingChairs, self).__init__(args, image_size, do_augument)
-        self.root = root
-        self.augumentor.min_scale = -0.2
-        self.augumentor.max_scale = 1.0
+    def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
+        super(FlyingChairs, self).__init__(aug_params)

         images = sorted(glob(osp.join(root, '*.ppm')))
-        self.flow_list = sorted(glob(osp.join(root, '*.flo')))
-        assert (len(images)//2 == len(self.flow_list))
-
-        self.image_list = []
-        for i in range(len(self.flow_list)):
-            im1 = images[2*i]
-            im2 = images[2*i + 1]
-            self.image_list.append([im1, im2])
+        flows = sorted(glob(osp.join(root, '*.flo')))
+        assert (len(images)//2 == len(flows))
+
+        split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
+        for i in range(len(flows)):
+            xid = split_list[i]
+            if (split=='training' and xid==1) or (split=='validation' and xid==2):
+                self.flow_list += [ flows[i] ]
+                self.image_list += [ [images[2*i], images[2*i+1]] ]


-class SceneFlow(FlowDataset):
-    def __init__(self, args, image_size, do_augument=True, root='datasets',
-            dstype='frames_cleanpass', use_flyingthings=True, use_monkaa=False, use_driving=False):
-        super(SceneFlow, self).__init__(args, image_size, do_augument)
-        self.root = root
-        self.dstype = dstype
-
-        self.augumentor.min_scale = -0.2
-        self.augumentor.max_scale = 0.8
-
-        if use_flyingthings:
-            self.add_flyingthings()
-        if use_monkaa:
-            self.add_monkaa()
-        if use_driving:
-            self.add_driving()
-
-    def add_flyingthings(self):
-        root = osp.join(self.root, 'FlyingThings3D')
+class FlyingThings3D(FlowDataset):
+    def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
+        super(FlyingThings3D, self).__init__(aug_params)

         for cam in ['left']:
             for direction in ['into_future', 'into_past']:
-                image_dirs = sorted(glob(osp.join(root, self.dstype, 'TRAIN/*/*')))
+                image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
                 image_dirs = sorted([osp.join(f, cam) for f in image_dirs])

                 flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
@@ -199,114 +151,85 @@ class SceneFlow(FlowDataset):
                     flows = sorted(glob(osp.join(fdir, '*.pfm')) )
                     for i in range(len(flows)-1):
                         if direction == 'into_future':
-                            self.image_list += [[images[i], images[i+1]]]
-                            self.flow_list += [flows[i]]
+                            self.image_list += [ [images[i], images[i+1]] ]
+                            self.flow_list += [ flows[i] ]
                         elif direction == 'into_past':
-                            self.image_list += [[images[i+1], images[i]]]
-                            self.flow_list += [flows[i+1]]
+                            self.image_list += [ [images[i+1], images[i]] ]
+                            self.flow_list += [ flows[i+1] ]

-    def add_monkaa(self):
-        pass # we don't use monkaa
-
-    def add_driving(self):
-        pass # we don't use driving


 class KITTI(FlowDataset):
-    def __init__(self, args, image_size=None, do_augument=True, is_test=False, is_val=False, do_pad=False, split=True, root='datasets/KITTI'):
-        super(KITTI, self).__init__(args, image_size, do_augument)
-        self.root = root
-        self.is_test = is_test
-        self.is_val = is_val
-        self.do_pad = do_pad
-
-        if self.do_augument:
-            self.augumentor = FlowAugmentorKITTI(self.image_size, min_scale=-0.2, max_scale=0.5)
-
-        if self.is_test:
-            images1 = sorted(glob(os.path.join(root, 'testing', 'image_2/*_10.png')))
-            images2 = sorted(glob(os.path.join(root, 'testing', 'image_2/*_11.png')))
-            for i in range(len(images1)):
-                self.image_list += [[images1[i], images2[i]]]
-        else:
-            flows = sorted(glob(os.path.join(root, 'training', 'flow_occ/*_10.png')))
-            images1 = sorted(glob(os.path.join(root, 'training', 'image_2/*_10.png')))
-            images2 = sorted(glob(os.path.join(root, 'training', 'image_2/*_11.png')))
-            for i in range(len(flows)):
-                self.flow_list += [flows[i]]
-                self.image_list += [[images1[i], images2[i]]]
+    def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
+        super(KITTI, self).__init__(aug_params, sparse=True)
+        if split == 'testing':
+            self.is_test = True
+
+        root = osp.join(root, split)
+        images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
+        images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
+
+        for img1, img2 in zip(images1, images2):
+            frame_id = img1.split('/')[-1]
+            self.extra_info += [ [frame_id] ]
+            self.image_list += [ [img1, img2] ]
+
+        if split == 'training':
+            self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
+
+
+class HD1K(FlowDataset):
+    def __init__(self, aug_params=None, root='datasets/HD1k'):
+        super(HD1K, self).__init__(aug_params, sparse=True)
+
+        seq_ix = 0
+        while 1:
+            flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
+            images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
+
+            if len(flows) == 0:
+                break
+
+            for i in range(len(flows)-1):
+                self.flow_list += [flows[i]]
+                self.image_list += [ [images[i], images[i+1]] ]
+
+            seq_ix += 1

-    def __getitem__(self, index):
-
-        if self.is_test:
-            frame_id = self.image_list[index][0]
-            frame_id = frame_id.split('/')[-1]
-
-            img1 = frame_utils.read_gen(self.image_list[index][0])
-            img2 = frame_utils.read_gen(self.image_list[index][1])
-
-            img1 = np.array(img1).astype(np.uint8)[..., :3]
-            img2 = np.array(img2).astype(np.uint8)[..., :3]
-
-            img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
-            img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
-            return img1, img2, frame_id
-
-        else:
-            if not self.init_seed:
-                worker_info = torch.utils.data.get_worker_info()
-                if worker_info is not None:
-                    np.random.seed(worker_info.id)
-                    random.seed(worker_info.id)
-                    self.init_seed = True
-
-            index = index % len(self.image_list)
-            frame_id = self.image_list[index][0]
-            frame_id = frame_id.split('/')[-1]
-
-            img1 = frame_utils.read_gen(self.image_list[index][0])
-            img2 = frame_utils.read_gen(self.image_list[index][1])
-            flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
-
-            img1 = np.array(img1).astype(np.uint8)[..., :3]
-            img2 = np.array(img2).astype(np.uint8)[..., :3]
-
-            if self.do_augument:
-                img1, img2, flow, valid = self.augumentor(img1, img2, flow, valid)
-
-            img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
-            img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
-            flow = torch.from_numpy(flow).permute(2, 0, 1).float()
-            valid = torch.from_numpy(valid).float()
-
-            if self.do_pad:
-                ht, wd = img1.shape[1:]
-                pad_ht = (((ht // 8) + 1) * 8 - ht) % 8
-                pad_wd = (((wd // 8) + 1) * 8 - wd) % 8
-                pad_ht1 = [0, pad_ht]
-                pad_wd1 = [pad_wd//2, pad_wd - pad_wd//2]
-                pad = pad_wd1 + pad_ht1
-
-                img1 = img1.view(1, 3, ht, wd)
-                img2 = img2.view(1, 3, ht, wd)
-                flow = flow.view(1, 2, ht, wd)
-                valid = valid.view(1, 1, ht, wd)
-
-                img1 = torch.nn.functional.pad(img1, pad, mode='replicate')
-                img2 = torch.nn.functional.pad(img2, pad, mode='replicate')
-                flow = torch.nn.functional.pad(flow, pad, mode='constant', value=0)
-                valid = torch.nn.functional.pad(valid, pad, mode='replicate', value=0)
-
-                img1 = img1.view(3, ht+pad_ht, wd+pad_wd)
-                img2 = img2.view(3, ht+pad_ht, wd+pad_wd)
-                flow = flow.view(2, ht+pad_ht, wd+pad_wd)
-                valid = valid.view(ht+pad_ht, wd+pad_wd)
-
-            if self.is_test:
-                return img1, img2, flow, valid, frame_id
-
-            return img1, img2, flow, valid
+
+def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
+    """ Create the data loader for the corresponding training set """
+
+    if args.stage == 'chairs':
+        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 1.0, 'do_flip': True}
+        train_dataset = FlyingChairs(aug_params, split='training')
+
+    elif args.stage == 'things':
+        aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
+        clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
+        final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
+        train_dataset = clean_dataset + final_dataset

+    elif args.stage == 'sintel':
+        aug_params = {'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True}
+        things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
+        sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
+        sintel_final = MpiSintel(aug_params, split='training', dstype='final')

+        if TRAIN_DS == 'C+T+K+S+H':
+            kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True})
+            hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.5, 'do_flip': True})
+            train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
+
+        elif TRAIN_DS == 'C+T+K/S':
+            train_dataset = 100*sintel_clean + 100*sintel_final + things

+    elif args.stage == 'kitti':
+        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
+        train_dataset = KITTI(args, image_size=args.image_size, is_val=False)

+    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
+        pin_memory=False, shuffle=True, num_workers=4, drop_last=True)

+    print('Training with %d image pairs' % len(train_dataset))
+    return train_loader
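The weighted sums in `fetch_dataloader` (e.g. `100*sintel_clean + ... + things`) work because `FlowDataset.__rmul__` repeats the file lists, while the standard `torch.utils.data.Dataset.__add__` concatenates datasets into a `ConcatDataset`. A small sketch of the mechanism, using a hypothetical stand-in dataset:

```python
import torch.utils.data as data

class TinyFlowDataset(data.Dataset):
    """Hypothetical stand-in with the same list-based layout as FlowDataset."""
    def __init__(self, n):
        self.image_list = [['im%d_0.png' % i, 'im%d_1.png' % i] for i in range(n)]
        self.flow_list = ['flow%d.flo' % i for i in range(n)]

    def __rmul__(self, v):
        # repeating the lists makes this dataset v times more likely to be sampled
        self.image_list = v * self.image_list
        self.flow_list = v * self.flow_list
        return self

    def __getitem__(self, index):
        return self.image_list[index], self.flow_list[index]

    def __len__(self):
        return len(self.image_list)

combined = 100 * TinyFlowDataset(3) + TinyFlowDataset(40)  # a ConcatDataset
print(len(combined))  # 340
```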

core/extractor.py

@@ -143,10 +143,9 @@ class BasicEncoder(nn.Module):
         # output convolution
         self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

+        self.dropout = None
         if dropout > 0:
             self.dropout = nn.Dropout2d(p=dropout)
-        else:
-            self.dropout = None

         for m in self.modules():
             if isinstance(m, nn.Conv2d):
@@ -184,7 +183,7 @@ class BasicEncoder(nn.Module):
         x = self.conv2(x)

-        if self.dropout is not None:
+        if self.training and self.dropout is not None:
             x = self.dropout(x)

         if is_list:
@@ -218,10 +217,9 @@ class SmallEncoder(nn.Module):
         self.layer2 = self._make_layer(64, stride=2)
         self.layer3 = self._make_layer(96, stride=2)

+        self.dropout = None
         if dropout > 0:
             self.dropout = nn.Dropout2d(p=dropout)
-        else:
-            self.dropout = None

         self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
@@ -260,8 +258,8 @@ class SmallEncoder(nn.Module):
         x = self.layer3(x)
         x = self.conv2(x)

-        # if self.dropout is not None:
-        #     x = self.dropout(x)
+        if self.training and self.dropout is not None:
+            x = self.dropout(x)

         if is_list:
             x = torch.split(x, [batch_dim, batch_dim], dim=0)
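The new `self.training and self.dropout is not None` guard is mostly belt-and-braces: `nn.Dropout2d` already becomes a no-op under `model.eval()`, so the check makes the train/eval asymmetry explicit (and re-enables the previously commented-out dropout in `SmallEncoder`). A quick illustration of the underlying module behaviour:

```python
import torch
import torch.nn as nn

drop = nn.Dropout2d(p=0.5)
x = torch.ones(1, 4, 2, 2)

drop.train()
y = drop(x)          # whole channels zeroed, survivors scaled by 1/(1-p)
print(y.unique())    # subset of {0., 2.}

drop.eval()
print(torch.equal(drop(x), x))  # True: identity in eval mode
```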

core/raft.py

@@ -3,11 +3,23 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from modules.update import BasicUpdateBlock, SmallUpdateBlock
-from modules.extractor import BasicEncoder, SmallEncoder
-from modules.corr import CorrBlock
+from update import BasicUpdateBlock, SmallUpdateBlock
+from extractor import BasicEncoder, SmallEncoder
+from corr import CorrBlock
 from utils.utils import bilinear_sampler, coords_grid, upflow8

+try:
+    autocast = torch.cuda.amp.autocast
+except:
+    # dummy autocast for PyTorch < 1.6
+    class autocast:
+        def __init__(self, enabled):
+            pass
+        def __enter__(self):
+            pass
+        def __exit__(self, *args):
+            pass


 class RAFT(nn.Module):
     def __init__(self, args):
@@ -26,7 +38,7 @@ class RAFT(nn.Module):
             args.corr_levels = 4
             args.corr_radius = 4

-        if not hasattr(args, 'dropout'):
+        if 'dropout' not in args._get_kwargs():
             args.dropout = 0

         # feature network, context network, and update block
@@ -40,6 +52,7 @@ class RAFT(nn.Module):
             self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
             self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

     def freeze_bn(self):
         for m in self.modules():
             if isinstance(m, nn.BatchNorm2d):
@@ -54,46 +67,73 @@ class RAFT(nn.Module):
         # optical flow computed as difference: flow = coords1 - coords0
         return coords0, coords1

-    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True):
+    def upsample_flow(self, flow, mask):
+        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
+        N, _, H, W = flow.shape
+        mask = mask.view(N, 1, 9, 8, 8, H, W)
+        mask = torch.softmax(mask, dim=2)
+
+        up_flow = F.unfold(8 * flow, [3,3], padding=1)
+        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
+
+        up_flow = torch.sum(mask * up_flow, dim=2)
+        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
+        return up_flow.reshape(N, 2, 8*H, 8*W)
+
+    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
         """ Estimate optical flow between pair of frames """

         image1 = 2 * (image1 / 255.0) - 1.0
         image2 = 2 * (image2 / 255.0) - 1.0

+        image1 = image1.contiguous()
+        image2 = image2.contiguous()
+
         hdim = self.hidden_dim
         cdim = self.context_dim

         # run the feature network
-        fmap1, fmap2 = self.fnet([image1, image2])
+        with autocast(enabled=self.args.mixed_precision):
+            fmap1, fmap2 = self.fnet([image1, image2])
+
+        fmap1 = fmap1.float()
+        fmap2 = fmap2.float()
         corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

         # run the context network
-        cnet = self.cnet(image1)
-        net, inp = torch.split(cnet, [hdim, cdim], dim=1)
-        net, inp = torch.tanh(net), torch.relu(inp)
-
-        # if dropout is being used reset mask
-        self.update_block.reset_mask(net, inp)
+        with autocast(enabled=self.args.mixed_precision):
+            cnet = self.cnet(image1)
+            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
+            net = torch.tanh(net)
+            inp = torch.relu(inp)

         coords0, coords1 = self.initialize_flow(image1)

+        if flow_init is not None:
+            coords1 = coords1 + flow_init
+
         flow_predictions = []
         for itr in range(iters):
             coords1 = coords1.detach()
             corr = corr_fn(coords1) # index correlation volume

             flow = coords1 - coords0
-            net, delta_flow = self.update_block(net, inp, corr, flow)
+            with autocast(enabled=self.args.mixed_precision):
+                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

             # F(t+1) = F(t) + \Delta(t)
             coords1 = coords1 + delta_flow

-            if upsample:
+            # upsample predictions
+            if up_mask is None:
                 flow_up = upflow8(coords1 - coords0)
-                flow_predictions.append(flow_up)
-
             else:
-                flow_predictions.append(coords1 - coords0)
+                flow_up = self.upsample_flow(coords1 - coords0, up_mask)
+
+            flow_predictions.append(flow_up)
+
+        if test_mode:
+            return coords1 - coords0, flow_up
+
         return flow_predictions
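This is the upsampling module named in the commit message. Instead of bilinear `upflow8`, each full-resolution pixel is produced as a convex combination (softmax-normalized weights) over the 3x3 coarse neighborhood around it; the 64*9 mask channels supply 9 weights for each of the 8x8 sub-pixel positions. A standalone shape-check of the `upsample_flow` computation with random inputs (dummy sizes, not from the commit):

```python
import torch
import torch.nn.functional as F

N, H, W = 1, 6, 8                   # coarse (1/8-resolution) grid
flow = torch.randn(N, 2, H, W)      # coarse flow field
mask = torch.randn(N, 64*9, H, W)   # as predicted by BasicUpdateBlock.mask

mask = torch.softmax(mask.view(N, 1, 9, 8, 8, H, W), dim=2)  # convex weights
up_flow = F.unfold(8 * flow, [3, 3], padding=1)  # 8x: coarse-pixel -> fine-pixel units
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)      # 3x3 neighborhood per location

up_flow = torch.sum(mask * up_flow, dim=2)       # weighted combination
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)      # interleave the 8x8 sub-pixels
print(up_flow.reshape(N, 2, 8*H, 8*W).shape)     # torch.Size([1, 2, 48, 64])
```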

core/update.py

@@ -2,34 +2,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-# VariationalHidDropout from https://github.com/locuslab/trellisnet/tree/master/TrellisNet
-class VariationalHidDropout(nn.Module):
-    def __init__(self, dropout=0.0):
-        """
-        Hidden-to-hidden (VD-based) dropout that applies the same mask at every time step and every layer of TrellisNet
-        :param dropout: The dropout rate (0 means no dropout is applied)
-        """
-        super(VariationalHidDropout, self).__init__()
-        self.dropout = dropout
-        self.mask = None
-
-    def reset_mask(self, x):
-        dropout = self.dropout
-
-        # Dimension (N, C, L)
-        n, c, h, w = x.shape
-        m = x.data.new(n, c, 1, 1).bernoulli_(1 - dropout)
-        with torch.no_grad():
-            mask = m / (1 - dropout)
-            self.mask = mask
-        return mask
-
-    def forward(self, x):
-        if not self.training or self.dropout == 0:
-            return x
-        assert self.mask is not None, "You need to reset mask before using VariationalHidDropout"
-        return self.mask * x

 class FlowHead(nn.Module):
     def __init__(self, input_dim=128, hidden_dim=256):
@@ -41,7 +13,6 @@ class FlowHead(nn.Module):
     def forward(self, x):
         return self.conv2(self.relu(self.conv1(x)))

 class ConvGRU(nn.Module):
     def __init__(self, hidden_dim=128, input_dim=192+128):
         super(ConvGRU, self).__init__()
@@ -59,7 +30,6 @@ class ConvGRU(nn.Module):
         h = (1-z) * h + z * q
         return h

 class SepConvGRU(nn.Module):
     def __init__(self, hidden_dim=128, input_dim=192+128):
         super(SepConvGRU, self).__init__()
@@ -133,49 +103,37 @@ class SmallUpdateBlock(nn.Module):
         self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
         self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

-        self.drop_inp = VariationalHidDropout(dropout=args.dropout)
-        self.drop_net = VariationalHidDropout(dropout=args.dropout)
-
-    def reset_mask(self, net, inp):
-        self.drop_inp.reset_mask(inp)
-        self.drop_net.reset_mask(net)
-
     def forward(self, net, inp, corr, flow):
         motion_features = self.encoder(flow, corr)
-
-        if self.training:
-            net = self.drop_net(net)
-            inp = self.drop_inp(inp)
-
         inp = torch.cat([inp, motion_features], dim=1)
         net = self.gru(net, inp)
         delta_flow = self.flow_head(net)

-        return net, delta_flow
+        return net, None, delta_flow

 class BasicUpdateBlock(nn.Module):
     def __init__(self, args, hidden_dim=128, input_dim=128):
         super(BasicUpdateBlock, self).__init__()
+        self.args = args
         self.encoder = BasicMotionEncoder(args)
         self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
         self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

-        self.drop_inp = VariationalHidDropout(dropout=args.dropout)
-        self.drop_net = VariationalHidDropout(dropout=args.dropout)
+        self.mask = nn.Sequential(
+            nn.Conv2d(128, 256, 3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(256, 64*9, 1, padding=0))

-    def reset_mask(self, net, inp):
-        self.drop_inp.reset_mask(inp)
-        self.drop_net.reset_mask(net)
-
-    def forward(self, net, inp, corr, flow):
+    def forward(self, net, inp, corr, flow, upsample=True):
         motion_features = self.encoder(flow, corr)
-
-        if self.training:
-            net = self.drop_net(net)
-            inp = self.drop_inp(inp)
-
         inp = torch.cat([inp, motion_features], dim=1)

         net = self.gru(net, inp)
         delta_flow = self.flow_head(net)

-        return net, delta_flow
+        # scale mask to balance gradients
+        mask = .25 * self.mask(net)
+        return net, mask, delta_flow

core/utils/augmentor.py

@@ -1,46 +1,55 @@
 import numpy as np
 import random
 import math
-import cv2
 from PIL import Image

+import cv2
 import torch
-import torchvision
+from torchvision.transforms import ColorJitter
 import torch.nn.functional as F


 class FlowAugmentor:
-    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5):
+    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
+
+        # spatial augmentation params
         self.crop_size = crop_size
-        self.augcolor = torchvision.transforms.ColorJitter(
-            brightness=0.4,
-            contrast=0.4,
-            saturation=0.4,
-            hue=0.5/3.14)
-
-        self.asymmetric_color_aug_prob = 0.2
-        self.spatial_aug_prob = 0.8
-        self.eraser_aug_prob = 0.5
-
         self.min_scale = min_scale
         self.max_scale = max_scale
-        self.max_stretch = 0.2
+        self.spatial_aug_prob = 0.8
         self.stretch_prob = 0.8
-        self.margin = 20
+        self.max_stretch = 0.2
+
+        # flip augmentation params
+        self.do_flip = do_flip
+        self.h_flip_prob = 0.5
+        self.v_flip_prob = 0.1
+
+        # photometric augmentation params
+        self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
+        self.asymmetric_color_aug_prob = 0.2
+        self.eraser_aug_prob = 0.5

     def color_transform(self, img1, img2):
+        """ Photometric augmentation """
+
+        # asymmetric
         if np.random.rand() < self.asymmetric_color_aug_prob:
-            img1 = np.array(self.augcolor(Image.fromarray(img1)), dtype=np.uint8)
-            img2 = np.array(self.augcolor(Image.fromarray(img2)), dtype=np.uint8)
+            img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
+            img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)

+        # symmetric
         else:
             image_stack = np.concatenate([img1, img2], axis=0)
-            image_stack = np.array(self.augcolor(Image.fromarray(image_stack)), dtype=np.uint8)
+            image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
             img1, img2 = np.split(image_stack, 2, axis=0)

         return img1, img2

     def eraser_transform(self, img1, img2, bounds=[50, 100]):
+        """ Occlusion augmentation """
+
         ht, wd = img1.shape[:2]
         if np.random.rand() < self.eraser_aug_prob:
             mean_color = np.mean(img2.reshape(-1, 3), axis=0)
@@ -55,22 +64,18 @@ class FlowAugmentor:
     def spatial_transform(self, img1, img2, flow):
         # randomly sample scale
         ht, wd = img1.shape[:2]
         min_scale = np.maximum(
-            (self.crop_size[0] + 1) / float(ht),
-            (self.crop_size[1] + 1) / float(wd))
+            (self.crop_size[0] + 8) / float(ht),
+            (self.crop_size[1] + 8) / float(wd))

-        max_scale = self.max_scale
-        min_scale = max(min_scale, self.min_scale)
-
         scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
         scale_x = scale
         scale_y = scale
         if np.random.rand() < self.stretch_prob:
             scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
             scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)

         scale_x = np.clip(scale_x, min_scale, None)
         scale_y = np.clip(scale_y, min_scale, None)
@@ -81,22 +86,20 @@ class FlowAugmentor:
             flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
             flow = flow * [scale_x, scale_y]

-        if np.random.rand() < 0.5: # h-flip
-            img1 = img1[:, ::-1]
-            img2 = img2[:, ::-1]
-            flow = flow[:, ::-1] * [-1.0, 1.0]
-
-        if np.random.rand() < 0.1: # v-flip
-            img1 = img1[::-1, :]
-            img2 = img2[::-1, :]
-            flow = flow[::-1, :] * [1.0, -1.0]
+        if self.do_flip:
+            if np.random.rand() < self.h_flip_prob: # h-flip
+                img1 = img1[:, ::-1]
+                img2 = img2[:, ::-1]
+                flow = flow[:, ::-1] * [-1.0, 1.0]
+
+            if np.random.rand() < self.v_flip_prob: # v-flip
+                img1 = img1[::-1, :]
+                img2 = img2[::-1, :]
+                flow = flow[::-1, :] * [1.0, -1.0]

-        y0 = np.random.randint(-self.margin, img1.shape[0] - self.crop_size[0] + self.margin)
-        x0 = np.random.randint(-self.margin, img1.shape[1] - self.crop_size[1] + self.margin)
-
-        y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
-        x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
+        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
+        x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])

         img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
         img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
         flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
@@ -114,22 +117,29 @@ class FlowAugmentor:
         return img1, img2, flow


-class FlowAugmentorKITTI:
-    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5):
+class SparseFlowAugmentor:
+    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
+        # spatial augmentation params
         self.crop_size = crop_size
-        self.augcolor = torchvision.transforms.ColorJitter(
-            brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
-
-        self.max_scale = max_scale
         self.min_scale = min_scale
+        self.max_scale = max_scale
         self.spatial_aug_prob = 0.8
+        self.stretch_prob = 0.8
+        self.max_stretch = 0.2
+
+        # flip augmentation params
+        self.do_flip = do_flip
+        self.h_flip_prob = 0.5
+        self.v_flip_prob = 0.1
+
+        # photometric augmentation params
+        self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
+        self.asymmetric_color_aug_prob = 0.2
         self.eraser_aug_prob = 0.5

     def color_transform(self, img1, img2):
         image_stack = np.concatenate([img1, img2], axis=0)
-        image_stack = np.array(self.augcolor(Image.fromarray(image_stack)), dtype=np.uint8)
+        image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
         img1, img2 = np.split(image_stack, 2, axis=0)
         return img1, img2
@@ -198,11 +208,12 @@ class FlowAugmentorKITTI:
         img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
         flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)

-        if np.random.rand() < 0.5: # h-flip
-            img1 = img1[:, ::-1]
-            img2 = img2[:, ::-1]
-            flow = flow[:, ::-1] * [-1.0, 1.0]
-            valid = valid[:, ::-1]
+        if self.do_flip:
+            if np.random.rand() < 0.5: # h-flip
+                img1 = img1[:, ::-1]
+                img2 = img2[:, ::-1]
+                flow = flow[:, ::-1] * [-1.0, 1.0]
+                valid = valid[:, ::-1]

         margin_y = 20
         margin_x = 50
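Both augmentors sample scales log-uniformly, `scale = 2 ** uniform(min_scale, max_scale)`, so the defaults `min_scale=-0.2`, `max_scale=0.5` correspond to roughly a 0.87x to 1.41x resize. A quick check:

```python
import numpy as np

min_scale, max_scale = -0.2, 0.5
print(2 ** min_scale, 2 ** max_scale)  # ~0.871  ~1.414

scales = 2 ** np.random.uniform(min_scale, max_scale, size=5)
print(scales)  # e.g. [1.02 0.93 1.31 1.18 0.95]
```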

core/utils/frame_utils.py

@@ -103,6 +103,13 @@ def readFlowKITTI(filename):
     flow = (flow - 2**15) / 64.0
     return flow, valid

+def readDispKITTI(filename):
+    disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
+    valid = disp > 0.0
+    flow = np.stack([-disp, np.zeros_like(disp)], -1)
+    return flow, valid
+
+
 def writeFlowKITTI(filename, uv):
     uv = 64.0 * uv + 2**15
     valid = np.ones([uv.shape[0], uv.shape[1], 1])
@@ -120,5 +127,8 @@ def read_gen(file_name, pil=False):
         return readFlow(file_name).astype(np.float32)
     elif ext == '.pfm':
         flow = readPFM(file_name).astype(np.float32)
-        return flow[:, :, :-1]
+        if len(flow.shape) == 2:
+            return flow
+        else:
+            return flow[:, :, :-1]
     return []
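KITTI ground truth is stored as 16-bit PNGs; `writeFlowKITTI`/`readFlowKITTI` map a flow component u to `64*u + 2**15` and back, i.e. 1/64-pixel quantization centered at 2**15. The encoding round-trips exactly for values on that grid (numpy only, file I/O omitted):

```python
import numpy as np

uv = np.array([[1.5, -3.25]], dtype=np.float32)    # example flow vector

encoded = (64.0 * uv + 2**15).astype(np.uint16)    # as stored in the PNG
decoded = (encoded.astype(np.float32) - 2**15) / 64.0

print(encoded)  # [[32864 32560]]
print(decoded)  # [[ 1.5  -3.25]]
```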

core/utils/utils.py

@@ -4,21 +4,21 @@ import numpy as np
 from scipy import interpolate

-def bilinear_sampler(img, coords, mode='bilinear', mask=False):
-    """ Wrapper for grid_sample, uses pixel coordinates """
-    H, W = img.shape[-2:]
-    xgrid, ygrid = coords.split([1,1], dim=-1)
-    xgrid = 2*xgrid/(W-1) - 1
-    ygrid = 2*ygrid/(H-1) - 1
+class InputPadder:
+    """ Pads images such that dimensions are divisible by 8 """
+    def __init__(self, dims):
+        self.ht, self.wd = dims[-2:]
+        pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
+        pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
+        self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]

-    grid = torch.cat([xgrid, ygrid], dim=-1)
-    img = F.grid_sample(img, grid, align_corners=True)
+    def pad(self, *inputs):
+        return [F.pad(x, self._pad, mode='replicate') for x in inputs]

-    if mask:
-        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
-        return img, mask.float()
-
-    return img
+    def unpad(self,x):
+        ht, wd = x.shape[-2:]
+        c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
+        return x[..., c[0]:c[1], c[2]:c[3]]

 def forward_interpolate(flow):
     flow = flow.detach().cpu().numpy()
@@ -42,15 +42,33 @@ def forward_interpolate(flow):
     dy = dy[valid]

     flow_x = interpolate.griddata(
-        (x1, y1), dx, (x0, y0), method='nearest')
+        (x1, y1), dx, (x0, y0), method='cubic', fill_value=0)

     flow_y = interpolate.griddata(
-        (x1, y1), dy, (x0, y0), method='nearest')
+        (x1, y1), dy, (x0, y0), method='cubic', fill_value=0)

     flow = np.stack([flow_x, flow_y], axis=0)
     return torch.from_numpy(flow).float()

+
+def bilinear_sampler(img, coords, mode='bilinear', mask=False):
+    """ Wrapper for grid_sample, uses pixel coordinates """
+    H, W = img.shape[-2:]
+    xgrid, ygrid = coords.split([1,1], dim=-1)
+    xgrid = 2*xgrid/(W-1) - 1
+    ygrid = 2*ygrid/(H-1) - 1
+
+    grid = torch.cat([xgrid, ygrid], dim=-1)
+    img = F.grid_sample(img, grid, align_corners=True)
+
+    if mask:
+        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
+        return img, mask.float()
+
+    return img
+

 def coords_grid(batch, ht, wd):
     coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
     coords = torch.stack(coords[::-1], dim=0).float()
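A minimal usage sketch of the new `InputPadder`, assuming the class above is in scope (the 375x1242 KITTI-like size is just an example of dimensions not divisible by 8):

```python
import torch

x = torch.randn(1, 3, 375, 1242)        # H, W not divisible by 8

padder = InputPadder(x.shape)
(x_pad,) = padder.pad(x)
print(x_pad.shape)                      # torch.Size([1, 3, 376, 1248])
print(padder.unpad(x_pad).shape)        # torch.Size([1, 3, 375, 1242])
```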

demo-frames/frame_0016.png (new binary file, 652 KiB)
demo-frames/frame_0017.png (new binary file, 652 KiB)
demo-frames/frame_0018.png (new binary file, 652 KiB)
demo-frames/frame_0019.png (new binary file, 653 KiB)
demo-frames/frame_0020.png (new binary file, 655 KiB)
demo-frames/frame_0021.png (new binary file, 657 KiB)
demo-frames/frame_0022.png (new binary file, 658 KiB)
demo-frames/frame_0023.png (new binary file, 659 KiB)
demo-frames/frame_0024.png (new binary file, 660 KiB)
demo-frames/frame_0025.png (new binary file, 660 KiB)
demo.py

@@ -4,87 +4,76 @@ sys.path.append('core')
 import argparse
 import os
 import cv2
+import glob
 import numpy as np
 import torch
-import torch.nn.functional as F
 from PIL import Image

-import datasets
-from utils import flow_viz
 from raft import RAFT
+from utils import flow_viz
+from utils.utils import InputPadder


 DEVICE = 'cuda'

-def pad8(img):
-    """pad image such that dimensions are divisible by 8"""
-    ht, wd = img.shape[2:]
-    pad_ht = (((ht // 8) + 1) * 8 - ht) % 8
-    pad_wd = (((wd // 8) + 1) * 8 - wd) % 8
-    pad_ht1 = [pad_ht//2, pad_ht-pad_ht//2]
-    pad_wd1 = [pad_wd//2, pad_wd-pad_wd//2]
-
-    img = F.pad(img, pad_wd1 + pad_ht1, mode='replicate')
-    return img
-
-def load_image(imfile):
-    img = np.array(Image.open(imfile)).astype(np.uint8)[..., :3]
-    img = torch.from_numpy(img).permute(2, 0, 1).float()
-    return pad8(img[None]).to(DEVICE)
+def load_image(imfile):
+    img = np.array(Image.open(imfile)).astype(np.uint8)
+    img = torch.from_numpy(img).permute(2, 0, 1).float()
+    return img
+
+
+def load_image_list(image_files):
+    images = []
+    for imfile in sorted(image_files):
+        images.append(load_image(imfile))
+
+    images = torch.stack(images, dim=0)
+    images = images.to(DEVICE)
+
+    padder = InputPadder(images.shape)
+    return padder.pad(images)[0]

-def display(image1, image2, flow):
-    image1 = image1.permute(1, 2, 0).cpu().numpy() / 255.0
-    image2 = image2.permute(1, 2, 0).cpu().numpy() / 255.0
-
-    flow = flow.permute(1, 2, 0).cpu().numpy()
-    flow_image = flow_viz.flow_to_image(flow)
-    flow_image = cv2.resize(flow_image, (image1.shape[1], image1.shape[0]))
-
-    cv2.imshow('image1', image1[..., ::-1])
-    cv2.imshow('image2', image2[..., ::-1])
-    cv2.imshow('flow', flow_image[..., ::-1])
+
+def viz(img, flo):
+    img = img[0].permute(1,2,0).cpu().numpy()
+    flo = flo[0].permute(1,2,0).cpu().numpy()
+
+    # map flow to rgb image
+    flo = flow_viz.flow_to_image(flo)
+    img_flo = np.concatenate([img, flo], axis=0)
+
+    cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
     cv2.waitKey()


 def demo(args):
-    model = RAFT(args)
-    model = torch.nn.DataParallel(model)
+    model = torch.nn.DataParallel(RAFT(args))
     model.load_state_dict(torch.load(args.model))

+    model = model.module
     model.to(DEVICE)
     model.eval()

     with torch.no_grad():
+        images = glob.glob(os.path.join(args.path, '*.png')) + \
+                 glob.glob(os.path.join(args.path, '*.jpg'))

-        # sintel images
-        image1 = load_image('images/sintel_0.png')
-        image2 = load_image('images/sintel_1.png')
-
-        flow_predictions = model(image1, image2, iters=args.iters, upsample=False)
-        display(image1[0], image2[0], flow_predictions[-1][0])
-
-        # kitti images
-        image1 = load_image('images/kitti_0.png')
-        image2 = load_image('images/kitti_1.png')
-
-        flow_predictions = model(image1, image2, iters=16)
-        display(image1[0], image2[0], flow_predictions[-1][0])
-
-        # davis images
-        image1 = load_image('images/davis_0.jpg')
-        image2 = load_image('images/davis_1.jpg')
-
-        flow_predictions = model(image1, image2, iters=16)
-        display(image1[0], image2[0], flow_predictions[-1][0])
+        images = load_image_list(images)
+        for i in range(images.shape[0]-1):
+            image1 = images[i,None]
+            image2 = images[i+1,None]
+
+            flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
+            viz(image1, flow_up)


 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--model', help="restore checkpoint")
+    parser.add_argument('--path', help="dataset for evaluation")
     parser.add_argument('--small', action='store_true', help='use small model')
-    parser.add_argument('--iters', type=int, default=12)
+    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
     args = parser.parse_args()

     demo(args)

download_models.sh (new executable file)

@@ -0,0 +1,3 @@
+#!/bin/bash
+wget https://www.dropbox.com/s/npt24nvhoojdr0n/models.zip
+unzip models.zip

evaluate.py

@@ -2,7 +2,6 @@ import sys
 sys.path.append('core')

 from PIL import Image
-import cv2
 import argparse
 import os
 import time
@ -13,88 +12,185 @@ import matplotlib.pyplot as plt
import datasets import datasets
from utils import flow_viz from utils import flow_viz
from utils import frame_utils
from raft import RAFT from raft import RAFT
from utils.utils import InputPadder, forward_interpolate
def validate_sintel(args, model, iters=50): @torch.no_grad()
""" Evaluate trained model on Sintel(train) clean + final passes """ def create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission'):
""" Create submission for the Sintel leaderboard """
model.eval() model.eval()
pad = 2
for dstype in ['clean', 'final']: for dstype in ['clean', 'final']:
val_dataset = datasets.MpiSintel(args, do_augument=False, dstype=dstype) test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype)
epe_list = [] flow_prev, sequence_prev = None, None
for i in range(len(val_dataset)): for test_id in range(len(test_dataset)):
image1, image2, flow_gt, _ = val_dataset[i] image1, image2, (sequence, frame) = test_dataset[test_id]
image1 = image1[None].cuda() if sequence != sequence_prev:
image2 = image2[None].cuda() flow_prev = None
image1 = F.pad(image1, [0, 0, pad, pad], mode='replicate')
image2 = F.pad(image2, [0, 0, pad, pad], mode='replicate')
with torch.no_grad():
flow_predictions = model.module(image1, image2, iters=iters)
flow_pr = flow_predictions[-1][0,:,pad:-pad]
epe = torch.sum((flow_pr - flow_gt.cuda())**2, dim=0)
epe = torch.sqrt(epe).mean()
epe_list.append(epe.item())
print("Validation (%s) EPE: %f" % (dstype, np.mean(epe_list)))
def validate_kitti(args, model, iters=32):
""" Evaluate trained model on KITTI (train) """
model.eval()
val_dataset = datasets.KITTI(args, do_augument=False, is_val=True, do_pad=True)
with torch.no_grad():
epe_list, out_list = [], []
for i in range(len(val_dataset)):
image1, image2, flow_gt, valid_gt = val_dataset[i]
image1 = image1[None].cuda()
image2 = image2[None].cuda()
flow_gt = flow_gt.cuda()
valid_gt = valid_gt.cuda()
flow_predictions = model.module(image1, image2, iters=iters)
flow_pr = flow_predictions[-1][0]
epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
mag = torch.sum(flow_gt**2, dim=0).sqrt()
epe = epe.view(-1) padder = InputPadder(image1.shape)
mag = mag.view(-1) image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
val = valid_gt.view(-1) >= 0.5
out = ((epe > 3.0) & ((epe/mag) > 0.05)).float() flow_low, flow_pr = model(image1, image2, iters=iters, flow_init=flow_prev, test_mode=True)
epe_list.append(epe[val].mean().item()) flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
out_list.append(out[val].cpu().numpy())
if warm_start:
flow_prev = forward_interpolate(flow_low[0])[None].cuda()
output_dir = os.path.join(output_path, dstype, sequence)
output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
frame_utils.writeFlow(output_file, flow)
sequence_prev = sequence
@torch.no_grad()
def create_kitti_submission(model, iters=24, output_path='kitti_submission'):
""" Create submission for the Sintel leaderboard """
model.eval()
test_dataset = datasets.KITTI(split='testing', aug_params=None)
if not os.path.exists(output_path):
os.makedirs(output_path)
for test_id in range(len(test_dataset)):
image1, image2, (frame_id, ) = test_dataset[test_id]
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
_, flow_pr = model(image1, image2, iters=iters, test_mode=True)
flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
output_filename = os.path.join(output_path, frame_id)
frame_utils.writeFlowKITTI(output_filename, flow)
@torch.no_grad()
def validate_chairs(model, iters=24):
""" Perform evaluation on the FlyingChairs (test) split """
model.eval()
epe_list = []
val_dataset = datasets.FlyingChairs(split='validation')
for val_id in range(len(val_dataset)):
image1, image2, flow_gt, _ = val_dataset[val_id]
image1 = image1[None].cuda()
image2 = image2[None].cuda()
_, flow_pr = model(image1, image2, iters=iters, test_mode=True)
epe = torch.sum((flow_pr[0].cpu() - flow_gt)**2, dim=0).sqrt()
epe_list.append(epe.view(-1).numpy())
epe = np.mean(np.concatenate(epe_list))
print("Validation Chairs EPE: %f" % epe)
return {'chairs': epe}
@torch.no_grad()
def validate_sintel(model, iters=32):
""" Peform validation using the Sintel (train) split """
model.eval()
results = {}
for dstype in ['clean', 'final']:
val_dataset = datasets.MpiSintel(split='training', dstype=dstype)
epe_list = []
for val_id in range(len(val_dataset)):
image1, image2, flow_gt, _ = val_dataset[val_id]
image1 = image1[None].cuda()
image2 = image2[None].cuda()
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)
flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
flow = padder.unpad(flow_pr[0]).cpu()
epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
epe_list.append(epe.view(-1).numpy())
epe_all = np.concatenate(epe_list)
epe = np.mean(epe_all)
px1 = np.mean(epe_all<1)
px3 = np.mean(epe_all<3)
px5 = np.mean(epe_all<5)
print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
results[dstype] = np.mean(epe_list)
return results
@torch.no_grad()
def validate_kitti(model, iters=24):
""" Peform validation using the KITTI-2015 (train) split """
model.eval()
val_dataset = datasets.KITTI(split='training')
out_list, epe_list = [], []
for val_id in range(len(val_dataset)):
image1, image2, flow_gt, valid_gt = val_dataset[val_id]
image1 = image1[None].cuda()
image2 = image2[None].cuda()
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)
flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
flow = padder.unpad(flow_pr[0]).cpu()
epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
mag = torch.sum(flow_gt**2, dim=0).sqrt()
epe = epe.view(-1)
mag = mag.view(-1)
val = valid_gt.view(-1) >= 0.5
out = ((epe > 3.0) & ((epe/mag) > 0.05)).float()
epe_list.append(epe[val].mean().item())
out_list.append(out[val].cpu().numpy())
epe_list = np.array(epe_list)
out_list = np.concatenate(out_list)
epe = np.mean(epe_list)
f1 = 100 * np.mean(out_list)
print("Validation KITTI: %f, %f" % (epe, f1))
return {'kitti-epe': epe, 'kitti-f1': f1}
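`kitti-f1` follows KITTI's `Fl` outlier definition: a pixel counts as an outlier when its EPE exceeds both 3 pixels and 5% of the ground-truth flow magnitude, and `f1` is the outlier percentage. For example:

```python
import torch

epe = torch.tensor([2.0, 4.0, 4.0])
mag = torch.tensor([10.0, 100.0, 10.0])  # ground-truth flow magnitudes

# only the third pixel is an outlier: 4 > 3 and 4/10 > 0.05
out = ((epe > 3.0) & ((epe / mag) > 0.05)).float()
print(100 * out.mean().item())  # 33.33...
```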
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model', help="restore checkpoint")
parser.add_argument('--dataset', help="dataset for evaluation")
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
args = parser.parse_args()
model = torch.nn.DataParallel(RAFT(args))
model.load_state_dict(torch.load(args.model))
model.cuda()
model.eval()
# create_sintel_submission(model.module, warm_start=True)
# create_kitti_submission(model.module)
with torch.no_grad():
if args.dataset == 'chairs':
validate_chairs(model.module)
elif args.dataset == 'sintel':
validate_sintel(model.module)
elif args.dataset == 'kitti':
validate_kitti(model.module)
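With this interface, evaluation is invoked as, for example, `python evaluate.py --model=models/raft-things.pth --dataset=sintel`.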

Six binary image files changed (sizes 497, 514, 829, 822, 396, and 388 KiB); contents not shown.
@@ -1,3 +0,0 @@
-#!/bin/bash
-wget https://www.dropbox.com/s/a2acvmczgzm6f9n/models.zip
-unzip models.zip

train.py Executable file → Normal file

@@ -16,26 +16,43 @@ import torch.nn.functional as F
from torch.utils.data import DataLoader
from raft import RAFT
-from evaluate import validate_sintel, validate_kitti
+import evaluate
import datasets
+
+from torch.utils.tensorboard import SummaryWriter
+
+try:
+    from torch.cuda.amp import GradScaler
+except:
+    # dummy GradScaler for PyTorch < 1.6
+    class GradScaler:
+        def __init__(self):
+            pass
+        def scale(self, loss):
+            return loss
+        def unscale_(self, optimizer):
+            pass
+        def step(self, optimizer):
+            optimizer.step()
+        def update(self):
+            pass
+
# exclude extremely large displacements
-MAX_FLOW = 1000
-SUM_FREQ = 200
+MAX_FLOW = 500
+SUM_FREQ = 100
VAL_FREQ = 5000

-def count_parameters(model):
-    return sum(p.numel() for p in model.parameters() if p.requires_grad)
-
-def sequence_loss(flow_preds, flow_gt, valid):
+def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW):
    """ Loss function defined over sequence of flow predictions """

    n_predictions = len(flow_preds)
    flow_loss = 0.0

    # exclude invalid pixels and extremely large displacements
-    valid = (valid >= 0.5) & (flow_gt.abs().sum(dim=1) < MAX_FLOW)
+    mag = torch.sum(flow_gt**2, dim=1).sqrt()
+    valid = (valid >= 0.5) & (mag < max_flow)

    for i in range(n_predictions):
        i_weight = 0.8**(n_predictions - i - 1)
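`i_weight` implements the exponentially weighted sequence loss: with N iterative predictions and γ = 0.8, the i-th prediction gets weight γ^(N−1−i), so later refinements dominate. A quick check with N = 4:

```python
n_predictions = 4
weights = [0.8**(n_predictions - i - 1) for i in range(n_predictions)]
print([round(w, 3) for w in weights])  # [0.512, 0.64, 0.8, 1.0]
```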
@@ -54,39 +71,22 @@ def sequence_loss(flow_preds, flow_gt, valid):
    return flow_loss, metrics

+def show_image(img):
+    img = img.permute(1,2,0).cpu().numpy()
+    plt.imshow(img/255.0)
+    plt.show()
+    # cv2.imshow('image', img/255.0)
+    # cv2.waitKey()
+
-def fetch_dataloader(args):
-    """ Create the data loader for the corresponding training set """
-    if args.dataset == 'chairs':
-        train_dataset = datasets.FlyingChairs(args, image_size=args.image_size)
-    elif args.dataset == 'things':
-        clean_dataset = datasets.SceneFlow(args, image_size=args.image_size, dstype='frames_cleanpass')
-        final_dataset = datasets.SceneFlow(args, image_size=args.image_size, dstype='frames_finalpass')
-        train_dataset = clean_dataset + final_dataset
-    elif args.dataset == 'sintel':
-        clean_dataset = datasets.MpiSintel(args, image_size=args.image_size, dstype='clean')
-        final_dataset = datasets.MpiSintel(args, image_size=args.image_size, dstype='final')
-        train_dataset = clean_dataset + final_dataset
-    elif args.dataset == 'kitti':
-        train_dataset = datasets.KITTI(args, image_size=args.image_size, is_val=False)
-
-    gpuargs = {'num_workers': 4, 'drop_last' : True}
-    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
-        pin_memory=True, shuffle=True, **gpuargs)
-
-    print('Training with %d image pairs' % len(train_dataset))
-    return train_loader
+def count_parameters(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def fetch_optimizer(args, model):
    """ Create the optimizer and learning rate scheduler """
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)

-    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps,
-        pct_start=0.2, cycle_momentum=False, anneal_strategy='linear')
+    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100,
+        pct_start=0.05, cycle_momentum=False, anneal_strategy='linear')

    return optimizer, scheduler
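The new schedule ramps up over the first 5% of steps and then decays linearly; `args.num_steps+100` is presumably a small safety margin so the scheduler is never stepped past its total step count. A self-contained check of the schedule's shape (the numbers here are illustrative):

```python
import torch
from torch import optim

param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.AdamW([param], lr=4e-4)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, 4e-4, 1000,
    pct_start=0.05, cycle_momentum=False, anneal_strategy='linear')

lrs = []
for _ in range(1000):
    optimizer.step()
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])

# the learning rate peaks near 4e-4 around step 50 (5% of 1000), then decays
print(lrs.index(max(lrs)))
```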
@@ -97,17 +97,22 @@ class Logger:
        self.scheduler = scheduler
        self.total_steps = 0
        self.running_loss = {}
+        self.writer = None

    def _print_training_status(self):
        metrics_data = [self.running_loss[k]/SUM_FREQ for k in sorted(self.running_loss.keys())]
-        training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_lr()[0])
+        training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0])
        metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data)

        # print the training status
        print(training_str + metrics_str)

-        for key in self.running_loss:
-            self.running_loss[key] = 0.0
+        if self.writer is None:
+            self.writer = SummaryWriter()
+
+        for k in self.running_loss:
+            self.writer.add_scalar(k, self.running_loss[k]/SUM_FREQ, self.total_steps)
+            self.running_loss[k] = 0.0

    def push(self, metrics):
        self.total_steps += 1
@@ -122,56 +127,95 @@ class Logger:
            self._print_training_status()
            self.running_loss = {}

+    def write_dict(self, results):
+        if self.writer is None:
+            self.writer = SummaryWriter()
+
+        for key in results:
+            self.writer.add_scalar(key, results[key], self.total_steps)
+
+    def close(self):
+        self.writer.close()
def train(args):
-    model = RAFT(args)
-    model = nn.DataParallel(model)
+    model = nn.DataParallel(RAFT(args), device_ids=args.gpus)
    print("Parameter Count: %d" % count_parameters(model))

    if args.restore_ckpt is not None:
-        model.load_state_dict(torch.load(args.restore_ckpt))
+        model.load_state_dict(torch.load(args.restore_ckpt), strict=False)

    model.cuda()
    model.train()

-    if 'chairs' not in args.dataset:
+    if args.stage != 'chairs':
        model.module.freeze_bn()

-    train_loader = fetch_dataloader(args)
+    train_loader = datasets.fetch_dataloader(args)
    optimizer, scheduler = fetch_optimizer(args, model)

    total_steps = 0
+    scaler = GradScaler(enabled=args.mixed_precision)
    logger = Logger(model, scheduler)

+    VAL_FREQ = 5000
+    add_noise = True
    should_keep_training = True
    while should_keep_training:

        for i_batch, data_blob in enumerate(train_loader):
+            optimizer.zero_grad()
            image1, image2, flow, valid = [x.cuda() for x in data_blob]
-            optimizer.zero_grad()

-            flow_predictions = model(image1, image2, iters=args.iters)
+            # show_image(image1[0])
+            # show_image(image2[0])
+
+            if args.add_noise:
+                stdv = np.random.uniform(0.0, 5.0)
+                image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
+                image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0)
+
+            flow_predictions = model(image1, image2, iters=args.iters)

            loss, metrics = sequence_loss(flow_predictions, flow, valid)
-            loss.backward()
+            scaler.scale(loss).backward()
+            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

-            optimizer.step()
+            scaler.step(optimizer)
            scheduler.step()
-            total_steps += 1
+            scaler.update()

            logger.push(metrics)

            if total_steps % VAL_FREQ == VAL_FREQ - 1:
                PATH = 'checkpoints/%d_%s.pth' % (total_steps+1, args.name)
                torch.save(model.state_dict(), PATH)

-                if total_steps == args.num_steps:
+                results = {}
+                for val_dataset in args.validation:
+                    if val_dataset == 'chairs':
+                        results.update(evaluate.validate_chairs(model.module))
+                    elif val_dataset == 'sintel':
+                        results.update(evaluate.validate_sintel(model.module))
+                    elif val_dataset == 'kitti':
+                        results.update(evaluate.validate_kitti(model.module))
+
+                logger.write_dict(results)
+
+                model.train()
+                if args.stage != 'chairs':
+                    model.module.freeze_bn()
+
+            total_steps += 1
+
+            if total_steps > args.num_steps:
                should_keep_training = False
                break

+    logger.close()
    PATH = 'checkpoints/%s.pth' % args.name
    torch.save(model.state_dict(), PATH)
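The training step above is the standard AMP recipe: scale the loss before `backward()`, unscale before gradient clipping so the clip threshold applies to true gradient values, then step and update through the scaler (which skips any step whose gradients overflowed). `train.py` never calls `autocast` itself, so the half-precision forward presumably happens inside the model when `--mixed_precision` is set; the self-contained toy below wraps the forward pass directly instead, and assumes PyTorch >= 1.6 and a CUDA device:

```python
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast

model = nn.Linear(8, 2).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler(enabled=True)

for _ in range(10):
    x = torch.randn(4, 8).cuda()
    target = torch.randn(4, 2).cuda()

    optimizer.zero_grad()
    with autocast():
        loss = (model(x) - target).abs().mean()

    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.unscale_(optimizer)     # unscale so clipping sees unscaled gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)         # skipped automatically if gradients contain inf/nan
    scaler.update()
```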
@@ -180,21 +224,25 @@ def train(args):
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
-    parser.add_argument('--name', default='bla', help="name your experiment")
-    parser.add_argument('--dataset', help="which dataset to use for training")
+    parser.add_argument('--name', default='raft', help="name your experiment")
+    parser.add_argument('--stage', help="determines which dataset to use for training")
    parser.add_argument('--restore_ckpt', help="restore checkpoint")
    parser.add_argument('--small', action='store_true', help='use small model')
+    parser.add_argument('--validation', type=str, nargs='+')
    parser.add_argument('--lr', type=float, default=0.00002)
    parser.add_argument('--num_steps', type=int, default=100000)
    parser.add_argument('--batch_size', type=int, default=6)
    parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512])
+    parser.add_argument('--gpus', type=int, nargs='+', default=[0,1])
+    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
    parser.add_argument('--iters', type=int, default=12)
    parser.add_argument('--wdecay', type=float, default=.00005)
    parser.add_argument('--epsilon', type=float, default=1e-8)
    parser.add_argument('--clip', type=float, default=1.0)
    parser.add_argument('--dropout', type=float, default=0.0)
+    parser.add_argument('--add_noise', action='store_true')
    args = parser.parse_args()

    torch.manual_seed(1234)
@@ -203,9 +251,4 @@ if __name__ == '__main__':
    if not os.path.isdir('checkpoints'):
        os.mkdir('checkpoints')

-    # scale learning rate and batch size by number of GPUs
-    num_gpus = torch.cuda.device_count()
-    args.batch_size = args.batch_size * num_gpus
-    args.lr = args.lr * num_gpus
-
    train(args)
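With the new interface, a run is launched via the `--stage` and `--validation` flags, e.g. `python train.py --stage chairs --validation chairs --gpus 0 1 --mixed_precision`, with the remaining options taking the defaults above.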