added upsampling module

README.md (111 lines changed)

@@ -1,7 +1,4 @@
 # RAFT
-
-**7/22/2020: We have updated our method to predict flow at full resolution leading to improved results on public benchmarks. This repository will be updated to reflect these changes within the next few days.**
-
 This repository contains the source code for our paper:

 [RAFT: Recurrent All Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)<br/>
@@ -11,90 +8,72 @@ Zachary Teed and Jia Deng<br/>
 <img src="RAFT.png">

 ## Requirements
-Our code was tested using PyTorch 1.3.1 and Python 3. The following additional packages need to be installed
+The code has been tested with PyTorch 1.5.1 and PyTorch Nightly. If you want to train with mixed precision, you will have to install the nightly build.
 ```Shell
-pip install Pillow
-pip install scipy
-pip install opencv-python
+conda create --name raft
+conda activate raft
+conda install pytorch torchvision cudatoolkit=10.1 -c pytorch-nightly
+conda install matplotlib
+conda install tensorboard
+conda install scipy
+conda install opencv
 ```

 ## Demos
 Pretrained models can be downloaded by running
 ```Shell
 ./scripts/download_models.sh
 ```
+or downloaded from [google drive](https://drive.google.com/file/d/10-BYgHqRNPGvmNUWr8razjb1xHu55pyA/view?usp=sharing)

-You can run the demos using one of the available models.
+You can demo a trained model on a sequence of frames
 ```Shell
-python demo.py --model=models/chairs+things.pth
+python demo.py --model=models/raft-things.pth --path=demo-frames
 ```

-or using the small (1M parameter) model
-```Shell
-python demo.py --model=models/small.pth --small
-```
-
-Running the demos will display the two images and a visualization of the optical flow estimate. After the images display, press any key to continue.
-
-## Training
-To train RAFT, you will need to download the required datasets. The first stage of training requires the [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs) and [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) datasets. Finetuning and evaluation require the [Sintel](http://sintel.is.tue.mpg.de/) and [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) datasets. We organize the directory structure as follows. By default `datasets.py` will search for the datasets in these locations
+## Required Data
+To evaluate/train RAFT, you will need to download the required datasets.
+* [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs)
+* [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
+* [Sintel](http://sintel.is.tue.mpg.de/)
+* [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow)
+* [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) (optional)
+
+By default `datasets.py` will search for the datasets in these locations. You can create symbolic links to wherever the datasets were downloaded in the `datasets` folder

 ```Shell
 ├── datasets
     ├── Sintel
         ├── test
         ├── training
     ├── KITTI
         ├── testing
         ├── training
         ├── devkit
     ├── FlyingChairs_release
         ├── data
     ├── FlyingThings3D
         ├── frames_cleanpass
         ├── frames_finalpass
         ├── optical_flow
 ```

-We used the following training schedule in our paper (note: we use 2 GPUs for training)
-
-```Shell
-python train.py --name=chairs --image_size 368 496 --dataset=chairs --num_steps=100000 --lr=0.0002 --batch_size=6
-```
-
-Next, finetune on the FlyingThings dataset
-
-```Shell
-python train.py --name=things --image_size 368 768 --dataset=things --num_steps=60000 --lr=0.00005 --batch_size=3 --restore_ckpt=checkpoints/chairs.pth
-```
-
-You can perform dataset specific finetuning
-
-### Sintel
-
-```Shell
-python train.py --name=sintel_ft --image_size 368 768 --dataset=sintel --num_steps=60000 --lr=0.00005 --batch_size=4 --restore_ckpt=checkpoints/things.pth
-```
-
-### KITTI
-
-```Shell
-python train.py --name=kitti_ft --image_size 288 896 --dataset=kitti --num_steps=40000 --lr=0.0001 --batch_size=4 --restore_ckpt=checkpoints/things.pth
-```
-
 ## Evaluation
-You can evaluate a model on Sintel and KITTI by running
+You can evaluate a trained model using `evaluate.py`
 ```Shell
-python evaluate.py --model=models/chairs+things.pth
+python evaluate.py --model=models/raft-things.pth --dataset=sintel
 ```

-or the small model by including the `small` flag
-```Shell
-python evaluate.py --model=models/small.pth --small
-```
+## Training
+Training code will be made available in the next few days
+<!-- We used the following training schedule in our paper (note: we use 2 GPUs for training). Training logs will be written to the `runs` which can be visualized using tensorboard
+```Shell
+./train_standard.sh
+```
+
+If you have a RTX GPU, training can be accelerated using mixed precision. You can expect similar results in this setting (1 GPU)
+```Shell
+./train_mixed.sh
+``` -->
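To make the directory layout above concrete, here is one way to create the symbolic links the README mentions. This is a minimal sketch, assuming the datasets were unpacked under a hypothetical `/data` directory:

```Shell
mkdir -p datasets
ln -s /data/Sintel datasets/Sintel
ln -s /data/KITTI datasets/KITTI
ln -s /data/FlyingChairs_release datasets/FlyingChairs_release
ln -s /data/FlyingThings3D datasets/FlyingThings3D
```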
chairs_split.txt (new file, 22,872 lines)
core/corr.py

@@ -2,6 +2,7 @@ import torch
 import torch.nn.functional as F
 from utils.utils import bilinear_sampler, coords_grid
 
 
 class CorrBlock:
     def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
         self.num_levels = num_levels
@@ -12,10 +13,10 @@ class CorrBlock:
         corr = CorrBlock.corr(fmap1, fmap2)
 
         batch, h1, w1, dim, h2, w2 = corr.shape
-        corr = corr.view(batch*h1*w1, dim, h2, w2)
+        corr = corr.reshape(batch*h1*w1, dim, h2, w2)
 
         self.corr_pyramid.append(corr)
-        for i in range(self.num_levels):
+        for i in range(self.num_levels-1):
             corr = F.avg_pool2d(corr, 2, stride=2)
             self.corr_pyramid.append(corr)
 
@@ -40,14 +41,16 @@ class CorrBlock:
             out_pyramid.append(corr)
 
         out = torch.cat(out_pyramid, dim=-1)
-        return out.permute(0, 3, 1, 2)
+        return out.permute(0, 3, 1, 2).contiguous().float()
 
 
     @staticmethod
     def corr(fmap1, fmap2):
         batch, dim, ht, wd = fmap1.shape
         fmap1 = fmap1.view(batch, dim, ht*wd)
         fmap2 = fmap2.view(batch, dim, ht*wd)
 
         corr = torch.matmul(fmap1.transpose(1,2), fmap2)
         corr = corr.view(batch, ht, wd, 1, ht, wd)
         return corr / torch.sqrt(torch.tensor(dim).float())
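The `num_levels-1` change above fixes an off-by-one: the full-resolution correlation volume is appended before the pooling loop, so looping `num_levels` more times produced five pyramid levels instead of four. A small standalone sketch, with made-up feature-map sizes, showing the corrected level count:

```Python
import torch
import torch.nn.functional as F

num_levels = 4
corr = torch.randn(2 * 32 * 32, 1, 32, 32)  # (batch*h1*w1, dim, h2, w2)
pyramid = [corr]                            # level 0: full resolution
for i in range(num_levels - 1):             # pool num_levels-1 more times
    corr = F.avg_pool2d(corr, 2, stride=2)
    pyramid.append(corr)

print([p.shape[-1] for p in pyramid])       # [32, 16, 8, 4] -> exactly 4 levels
```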
core/datasets.py (367 lines changed)

@@ -6,53 +6,42 @@ import torch.utils.data as data
 import torch.nn.functional as F
 
 import os
-import cv2
 import math
 import random
 from glob import glob
 import os.path as osp
 
 from utils import frame_utils
-from utils.augmentor import FlowAugmentor, FlowAugmentorKITTI
+from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
 
 
-class CombinedDataset(data.Dataset):
-    def __init__(self, datasets):
-        self.datasets = datasets
-
-    def __len__(self):
-        length = 0
-        for i in range(len(self.datasets)):
-            length += len(self.datsaets[i])
-        return length
-
-    def __getitem__(self, index):
-        i = 0
-        for j in range(len(self.datasets)):
-            if i + len(self.datasets[j]) >= index:
-                yield self.datasets[j][index-i]
-                break
-            i += len(self.datasets[j])
-
-    def __add__(self, other):
-        self.datasets.append(other)
-        return self
-
 class FlowDataset(data.Dataset):
-    def __init__(self, args, image_size=None, do_augument=False):
-        self.image_size = image_size
-        self.do_augument = do_augument
-
-        if self.do_augument:
-            self.augumentor = FlowAugmentor(self.image_size)
+    def __init__(self, aug_params=None, sparse=False):
+        self.augmentor = None
+        self.sparse = sparse
+        if aug_params is not None:
+            if sparse:
+                self.augmentor = SparseFlowAugmentor(**aug_params)
+            else:
+                self.augmentor = FlowAugmentor(**aug_params)
 
+        self.is_test = False
+        self.init_seed = False
         self.flow_list = []
         self.image_list = []
-        self.init_seed = False
+        self.extra_info = []
 
     def __getitem__(self, index):
+        if self.is_test:
+            img1 = frame_utils.read_gen(self.image_list[index][0])
+            img2 = frame_utils.read_gen(self.image_list[index][1])
+            img1 = np.array(img1).astype(np.uint8)[..., :3]
+            img2 = np.array(img2).astype(np.uint8)[..., :3]
+            img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
+            img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
+            return img1, img2, self.extra_info[index]
 
         if not self.init_seed:
             worker_info = torch.utils.data.get_worker_info()
             if worker_info is not None:
@@ -62,133 +51,96 @@ class FlowDataset(data.Dataset):
                 self.init_seed = True
 
         index = index % len(self.image_list)
-        flow = frame_utils.read_gen(self.flow_list[index])
+        valid = None
+        if self.sparse:
+            flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
+        else:
+            flow = frame_utils.read_gen(self.flow_list[index])
 
         img1 = frame_utils.read_gen(self.image_list[index][0])
         img2 = frame_utils.read_gen(self.image_list[index][1])
 
-        img1 = np.array(img1).astype(np.uint8)[..., :3]
-        img2 = np.array(img2).astype(np.uint8)[..., :3]
         flow = np.array(flow).astype(np.float32)
+        img1 = np.array(img1).astype(np.uint8)
+        img2 = np.array(img2).astype(np.uint8)
 
-        if self.do_augument:
-            img1, img2, flow = self.augumentor(img1, img2, flow)
+        # grayscale images
+        if len(img1.shape) == 2:
+            img1 = np.tile(img1[...,None], (1, 1, 3))
+            img2 = np.tile(img2[...,None], (1, 1, 3))
+        else:
+            img1 = img1[..., :3]
+            img2 = img2[..., :3]
+
+        if self.augmentor is not None:
+            if self.sparse:
+                img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
+            else:
+                img1, img2, flow = self.augmentor(img1, img2, flow)
 
         img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
         img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
         flow = torch.from_numpy(flow).permute(2, 0, 1).float()
-        valid = torch.ones_like(flow[0])
 
-        return img1, img2, flow, valid
+        if valid is not None:
+            valid = torch.from_numpy(valid)
+        else:
+            valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
+
+        return img1, img2, flow, valid.float()
+
+    def __rmul__(self, v):
+        self.flow_list = v * self.flow_list
+        self.image_list = v * self.image_list
+        return self
 
     def __len__(self):
         return len(self.image_list)
 
-    def __add(self, other):
-        return CombinedDataset([self, other])
-
-
-class MpiSintelTest(FlowDataset):
-    def __init__(self, args, root='datasets/Sintel/test', dstype='clean'):
-        super(MpiSintelTest, self).__init__(args, image_size=None, do_augument=False)
-
-        self.root = root
-        self.dstype = dstype
-
-        image_dir = osp.join(self.root, dstype)
-        all_sequences = os.listdir(image_dir)
-
-        self.image_list = []
-        for sequence in all_sequences:
-            frames = sorted(glob(osp.join(image_dir, sequence, '*.png')))
-            for i in range(len(frames)-1):
-                self.image_list += [[frames[i], frames[i+1], sequence, i]]
-
-    def __getitem__(self, index):
-        img1 = frame_utils.read_gen(self.image_list[index][0])
-        img2 = frame_utils.read_gen(self.image_list[index][1])
-        sequence = self.image_list[index][2]
-        frame = self.image_list[index][3]
-
-        img1 = np.array(img1).astype(np.uint8)[..., :3]
-        img2 = np.array(img2).astype(np.uint8)[..., :3]
-        img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
-        img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
-        return img1, img2, sequence, frame
-
-
 class MpiSintel(FlowDataset):
-    def __init__(self, args, image_size=None, do_augument=True, root='datasets/Sintel/training', dstype='clean'):
-        super(MpiSintel, self).__init__(args, image_size, do_augument)
-        if do_augument:
-            self.augumentor.min_scale = -0.2
-            self.augumentor.max_scale = 0.7
-
-        self.root = root
-        self.dstype = dstype
-
-        flow_root = osp.join(root, 'flow')
-        image_root = osp.join(root, dstype)
-
-        file_list = sorted(glob(osp.join(flow_root, '*/*.flo')))
-        for flo in file_list:
-            fbase = flo[len(flow_root)+1:]
-            fprefix = fbase[:-8]
-            fnum = int(fbase[-8:-4])
-
-            img1 = osp.join(image_root, fprefix + "%04d"%(fnum+0) + '.png')
-            img2 = osp.join(image_root, fprefix + "%04d"%(fnum+1) + '.png')
-
-            if not osp.isfile(img1) or not osp.isfile(img2) or not osp.isfile(flo):
-                continue
-
-            self.image_list.append((img1, img2))
-            self.flow_list.append(flo)
+    def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
+        super(MpiSintel, self).__init__(aug_params)
+        flow_root = osp.join(root, split, 'flow')
+        image_root = osp.join(root, split, dstype)
+
+        if split == 'test':
+            self.is_test = True
+
+        for scene in os.listdir(image_root):
+            image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
+            for i in range(len(image_list)-1):
+                self.image_list += [ [image_list[i], image_list[i+1]] ]
+                self.extra_info += [ (scene, i) ] # scene and frame_id
+
+            if split != 'test':
+                self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
 
 class FlyingChairs(FlowDataset):
-    def __init__(self, args, image_size=None, do_augument=True, root='datasets/FlyingChairs_release/data'):
-        super(FlyingChairs, self).__init__(args, image_size, do_augument)
-        self.root = root
-        self.augumentor.min_scale = -0.2
-        self.augumentor.max_scale = 1.0
+    def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
+        super(FlyingChairs, self).__init__(aug_params)
 
         images = sorted(glob(osp.join(root, '*.ppm')))
-        self.flow_list = sorted(glob(osp.join(root, '*.flo')))
-        assert (len(images)//2 == len(self.flow_list))
+        flows = sorted(glob(osp.join(root, '*.flo')))
+        assert (len(images)//2 == len(flows))
 
-        self.image_list = []
-        for i in range(len(self.flow_list)):
-            im1 = images[2*i]
-            im2 = images[2*i + 1]
-            self.image_list.append([im1, im2])
+        split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
+        for i in range(len(flows)):
+            xid = split_list[i]
+            if (split=='training' and xid==1) or (split=='validation' and xid==2):
+                self.flow_list += [ flows[i] ]
+                self.image_list += [ [images[2*i], images[2*i+1]] ]
 
 
-class SceneFlow(FlowDataset):
-    def __init__(self, args, image_size, do_augument=True, root='datasets',
-                 dstype='frames_cleanpass', use_flyingthings=True, use_monkaa=False, use_driving=False):
-        super(SceneFlow, self).__init__(args, image_size, do_augument)
-        self.root = root
-        self.dstype = dstype
-
-        self.augumentor.min_scale = -0.2
-        self.augumentor.max_scale = 0.8
-
-        if use_flyingthings:
-            self.add_flyingthings()
-        if use_monkaa:
-            self.add_monkaa()
-        if use_driving:
-            self.add_driving()
-
-    def add_flyingthings(self):
-        root = osp.join(self.root, 'FlyingThings3D')
+class FlyingThings3D(FlowDataset):
+    def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
+        super(FlyingThings3D, self).__init__(aug_params)
 
         for cam in ['left']:
             for direction in ['into_future', 'into_past']:
-                image_dirs = sorted(glob(osp.join(root, self.dstype, 'TRAIN/*/*')))
+                image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
                 image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
 
                 flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
@@ -199,114 +151,85 @@ class SceneFlow(FlowDataset):
                     flows = sorted(glob(osp.join(fdir, '*.pfm')) )
                     for i in range(len(flows)-1):
                         if direction == 'into_future':
-                            self.image_list += [[images[i], images[i+1]]]
-                            self.flow_list += [flows[i]]
+                            self.image_list += [ [images[i], images[i+1]] ]
+                            self.flow_list += [ flows[i] ]
                         elif direction == 'into_past':
-                            self.image_list += [[images[i+1], images[i]]]
-                            self.flow_list += [flows[i+1]]
+                            self.image_list += [ [images[i+1], images[i]] ]
+                            self.flow_list += [ flows[i+1] ]
 
-    def add_monkaa(self):
-        pass # we don't use monkaa
-
-    def add_driving(self):
-        pass # we don't use driving
-
 
 class KITTI(FlowDataset):
-    def __init__(self, args, image_size=None, do_augument=True, is_test=False, is_val=False, do_pad=False, split=True, root='datasets/KITTI'):
-        super(KITTI, self).__init__(args, image_size, do_augument)
-        self.root = root
-        self.is_test = is_test
-        self.is_val = is_val
-        self.do_pad = do_pad
-
-        if self.do_augument:
-            self.augumentor = FlowAugmentorKITTI(self.image_size, min_scale=-0.2, max_scale=0.5)
-
-        if self.is_test:
-            images1 = sorted(glob(os.path.join(root, 'testing', 'image_2/*_10.png')))
-            images2 = sorted(glob(os.path.join(root, 'testing', 'image_2/*_11.png')))
-            for i in range(len(images1)):
-                self.image_list += [[images1[i], images2[i]]]
-
-        else:
-            flows = sorted(glob(os.path.join(root, 'training', 'flow_occ/*_10.png')))
-            images1 = sorted(glob(os.path.join(root, 'training', 'image_2/*_10.png')))
-            images2 = sorted(glob(os.path.join(root, 'training', 'image_2/*_11.png')))
-
-            for i in range(len(flows)):
-                self.flow_list += [flows[i]]
-                self.image_list += [[images1[i], images2[i]]]
-
-    def __getitem__(self, index):
-        if self.is_test:
-            frame_id = self.image_list[index][0]
-            frame_id = frame_id.split('/')[-1]
-
-            img1 = frame_utils.read_gen(self.image_list[index][0])
-            img2 = frame_utils.read_gen(self.image_list[index][1])
-
-            img1 = np.array(img1).astype(np.uint8)[..., :3]
-            img2 = np.array(img2).astype(np.uint8)[..., :3]
-            img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
-            img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
-            return img1, img2, frame_id
-
-        else:
-            if not self.init_seed:
-                worker_info = torch.utils.data.get_worker_info()
-                if worker_info is not None:
-                    np.random.seed(worker_info.id)
-                    random.seed(worker_info.id)
-                self.init_seed = True
-
-            index = index % len(self.image_list)
-            frame_id = self.image_list[index][0]
-            frame_id = frame_id.split('/')[-1]
-
-            img1 = frame_utils.read_gen(self.image_list[index][0])
-            img2 = frame_utils.read_gen(self.image_list[index][1])
-            flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
-
-            img1 = np.array(img1).astype(np.uint8)[..., :3]
-            img2 = np.array(img2).astype(np.uint8)[..., :3]
-
-            if self.do_augument:
-                img1, img2, flow, valid = self.augumentor(img1, img2, flow, valid)
-
-            img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
-            img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
-            flow = torch.from_numpy(flow).permute(2, 0, 1).float()
-            valid = torch.from_numpy(valid).float()
-
-            if self.do_pad:
-                ht, wd = img1.shape[1:]
-                pad_ht = (((ht // 8) + 1) * 8 - ht) % 8
-                pad_wd = (((wd // 8) + 1) * 8 - wd) % 8
-                pad_ht1 = [0, pad_ht]
-                pad_wd1 = [pad_wd//2, pad_wd - pad_wd//2]
-                pad = pad_wd1 + pad_ht1
-
-                img1 = img1.view(1, 3, ht, wd)
-                img2 = img2.view(1, 3, ht, wd)
-                flow = flow.view(1, 2, ht, wd)
-                valid = valid.view(1, 1, ht, wd)
-
-                img1 = torch.nn.functional.pad(img1, pad, mode='replicate')
-                img2 = torch.nn.functional.pad(img2, pad, mode='replicate')
-                flow = torch.nn.functional.pad(flow, pad, mode='constant', value=0)
-                valid = torch.nn.functional.pad(valid, pad, mode='replicate', value=0)
-
-                img1 = img1.view(3, ht+pad_ht, wd+pad_wd)
-                img2 = img2.view(3, ht+pad_ht, wd+pad_wd)
-                flow = flow.view(2, ht+pad_ht, wd+pad_wd)
-                valid = valid.view(ht+pad_ht, wd+pad_wd)
-
-            if self.is_test:
-                return img1, img2, flow, valid, frame_id
-
-            return img1, img2, flow, valid
+    def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
+        super(KITTI, self).__init__(aug_params, sparse=True)
+        if split == 'testing':
+            self.is_test = True
+
+        root = osp.join(root, split)
+        images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
+        images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
+
+        for img1, img2 in zip(images1, images2):
+            frame_id = img1.split('/')[-1]
+            self.extra_info += [ [frame_id] ]
+            self.image_list += [ [img1, img2] ]
+
+        if split == 'training':
+            self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
+
+
+class HD1K(FlowDataset):
+    def __init__(self, aug_params=None, root='datasets/HD1k'):
+        super(HD1K, self).__init__(aug_params, sparse=True)
+
+        seq_ix = 0
+        while 1:
+            flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
+            images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
+
+            if len(flows) == 0:
+                break
+
+            for i in range(len(flows)-1):
+                self.flow_list += [flows[i]]
+                self.image_list += [ [images[i], images[i+1]] ]
+
+            seq_ix += 1
+
+
+def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
+    """ Create the data loader for the corresponding training set """
+
+    if args.stage == 'chairs':
+        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 1.0, 'do_flip': True}
+        train_dataset = FlyingChairs(aug_params, split='training')
+
+    elif args.stage == 'things':
+        aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
+        clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
+        final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
+        train_dataset = clean_dataset + final_dataset
+
+    elif args.stage == 'sintel':
+        aug_params = {'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True}
+        things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
+        sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
+        sintel_final = MpiSintel(aug_params, split='training', dstype='final')
+
+        if TRAIN_DS == 'C+T+K+S+H':
+            kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True})
+            hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.5, 'do_flip': True})
+            train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
+
+        elif TRAIN_DS == 'C+T+K/S':
+            train_dataset = 100*sintel_clean + 100*sintel_final + things
+
+    elif args.stage == 'kitti':
+        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
+        train_dataset = KITTI(args, image_size=args.image_size, is_val=False)
+
+    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
+        pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
+
+    print('Training with %d image pairs' % len(train_dataset))
+    return train_loader
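The new `__rmul__` on `FlowDataset` is what makes the weighted mixtures in `fetch_dataloader` work: `100*sintel_clean` repeats the underlying file lists 100 times, and `+` falls back to PyTorch's built-in `Dataset.__add__`, which returns a `ConcatDataset` (the old custom `CombinedDataset` is no longer needed). A toy sketch of the same mechanism, using a hypothetical `ToyDataset` that is not part of the diff:

```Python
import torch.utils.data as data

class ToyDataset(data.Dataset):
    def __init__(self, image_list):
        self.image_list = image_list
    def __rmul__(self, v):
        # repeat the sample list v times, as FlowDataset.__rmul__ does
        self.image_list = v * self.image_list
        return self
    def __len__(self):
        return len(self.image_list)
    def __getitem__(self, i):
        return self.image_list[i]

a, b = ToyDataset(['x'] * 3), ToyDataset(['y'] * 5)
mix = 2 * a + b      # __rmul__, then Dataset.__add__ (ConcatDataset)
print(len(mix))      # 2*3 + 5 = 11
```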
core/extractor.py

@@ -143,10 +143,9 @@ class BasicEncoder(nn.Module):
         # output convolution
         self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
 
+        self.dropout = None
         if dropout > 0:
             self.dropout = nn.Dropout2d(p=dropout)
-        else:
-            self.dropout = None
 
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
@@ -184,7 +183,7 @@ class BasicEncoder(nn.Module):
 
         x = self.conv2(x)
 
-        if self.dropout is not None:
+        if self.training and self.dropout is not None:
             x = self.dropout(x)
 
         if is_list:
@@ -218,10 +217,9 @@ class SmallEncoder(nn.Module):
         self.layer2 = self._make_layer(64, stride=2)
         self.layer3 = self._make_layer(96, stride=2)
 
+        self.dropout = None
         if dropout > 0:
             self.dropout = nn.Dropout2d(p=dropout)
-        else:
-            self.dropout = None
 
         self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
 
@@ -260,8 +258,8 @@ class SmallEncoder(nn.Module):
         x = self.layer3(x)
         x = self.conv2(x)
 
-        # if self.dropout is not None:
-        #     x = self.dropout(x)
+        if self.training and self.dropout is not None:
+            x = self.dropout(x)
 
         if is_list:
             x = torch.split(x, [batch_dim, batch_dim], dim=0)
core/raft.py (78 lines changed)

@@ -3,11 +3,23 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from modules.update import BasicUpdateBlock, SmallUpdateBlock
-from modules.extractor import BasicEncoder, SmallEncoder
-from modules.corr import CorrBlock
+from update import BasicUpdateBlock, SmallUpdateBlock
+from extractor import BasicEncoder, SmallEncoder
+from corr import CorrBlock
 from utils.utils import bilinear_sampler, coords_grid, upflow8
 
+try:
+    autocast = torch.cuda.amp.autocast
+except:
+    # dummy autocast for PyTorch < 1.6
+    class autocast:
+        def __init__(self, enabled):
+            pass
+        def __enter__(self):
+            pass
+        def __exit__(self, *args):
+            pass
+
 
 class RAFT(nn.Module):
     def __init__(self, args):
@@ -26,7 +38,7 @@ class RAFT(nn.Module):
             args.corr_levels = 4
             args.corr_radius = 4
 
-        if not hasattr(args, 'dropout'):
+        if 'dropout' not in args._get_kwargs():
             args.dropout = 0
 
         # feature network, context network, and update block
@@ -40,6 +52,7 @@ class RAFT(nn.Module):
             self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
             self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
 
+
     def freeze_bn(self):
         for m in self.modules():
             if isinstance(m, nn.BatchNorm2d):
@@ -54,46 +67,73 @@ class RAFT(nn.Module):
         # optical flow computed as difference: flow = coords1 - coords0
         return coords0, coords1
 
-    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True):
+    def upsample_flow(self, flow, mask):
+        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
+        N, _, H, W = flow.shape
+        mask = mask.view(N, 1, 9, 8, 8, H, W)
+        mask = torch.softmax(mask, dim=2)
+
+        up_flow = F.unfold(8 * flow, [3,3], padding=1)
+        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
+
+        up_flow = torch.sum(mask * up_flow, dim=2)
+        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
+        return up_flow.reshape(N, 2, 8*H, 8*W)
+
+
+    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
         """ Estimate optical flow between pair of frames """
 
         image1 = 2 * (image1 / 255.0) - 1.0
         image2 = 2 * (image2 / 255.0) - 1.0
+
+        image1 = image1.contiguous()
+        image2 = image2.contiguous()
 
         hdim = self.hidden_dim
         cdim = self.context_dim
 
         # run the feature network
-        fmap1, fmap2 = self.fnet([image1, image2])
+        with autocast(enabled=self.args.mixed_precision):
+            fmap1, fmap2 = self.fnet([image1, image2])
+
+        fmap1 = fmap1.float()
+        fmap2 = fmap2.float()
         corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
 
         # run the context network
-        cnet = self.cnet(image1)
-        net, inp = torch.split(cnet, [hdim, cdim], dim=1)
-        net, inp = torch.tanh(net), torch.relu(inp)
-
-        # if dropout is being used reset mask
-        self.update_block.reset_mask(net, inp)
+        with autocast(enabled=self.args.mixed_precision):
+            cnet = self.cnet(image1)
+            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
+            net = torch.tanh(net)
+            inp = torch.relu(inp)
+
         coords0, coords1 = self.initialize_flow(image1)
 
+        if flow_init is not None:
+            coords1 = coords1 + flow_init
+
         flow_predictions = []
         for itr in range(iters):
             coords1 = coords1.detach()
             corr = corr_fn(coords1) # index correlation volume
 
             flow = coords1 - coords0
-            net, delta_flow = self.update_block(net, inp, corr, flow)
+            with autocast(enabled=self.args.mixed_precision):
+                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
 
             # F(t+1) = F(t) + \Delta(t)
             coords1 = coords1 + delta_flow
 
-            if upsample:
+            # upsample predictions
+            if up_mask is None:
                 flow_up = upflow8(coords1 - coords0)
-                flow_predictions.append(flow_up)
-
             else:
-                flow_predictions.append(coords1 - coords0)
+                flow_up = self.upsample_flow(coords1 - coords0, up_mask)
+
+            flow_predictions.append(flow_up)
+
+        if test_mode:
+            return coords1 - coords0, flow_up
 
         return flow_predictions
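To see what the new `upsample_flow` computes, here is the convex upsampling step in isolation, a self-contained sketch with arbitrary shapes (a 368x496 input, so a 46x62 coarse grid). The mask supplies, for each of the 8x8 fine pixels in a coarse cell, softmax weights over the 9 coarse neighbors, so every fine-resolution flow vector is a convex combination of coarse vectors; the flow is multiplied by 8 because it was measured in 1/8-resolution pixel units:

```Python
import torch
import torch.nn.functional as F

def upsample_flow(flow, mask):
    # mirrors RAFT.upsample_flow from the diff above
    N, _, H, W = flow.shape
    mask = torch.softmax(mask.view(N, 1, 9, 8, 8, H, W), dim=2)
    up_flow = F.unfold(8 * flow, [3, 3], padding=1).view(N, 2, 9, 1, 1, H, W)
    up_flow = torch.sum(mask * up_flow, dim=2)       # convex combination
    return up_flow.permute(0, 1, 4, 2, 5, 3).reshape(N, 2, 8 * H, 8 * W)

flow = torch.randn(1, 2, 46, 62)        # flow at 1/8 resolution
mask = torch.randn(1, 64 * 9, 46, 62)   # (8*8 sub-pixels) x (9 neighbors)
print(upsample_flow(flow, mask).shape)  # torch.Size([1, 2, 368, 496])
```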
core/update.py

@@ -2,34 +2,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-# VariationalHidDropout from https://github.com/locuslab/trellisnet/tree/master/TrellisNet
-class VariationalHidDropout(nn.Module):
-    def __init__(self, dropout=0.0):
-        """
-        Hidden-to-hidden (VD-based) dropout that applies the same mask at every time step and every layer of TrellisNet
-        :param dropout: The dropout rate (0 means no dropout is applied)
-        """
-        super(VariationalHidDropout, self).__init__()
-        self.dropout = dropout
-        self.mask = None
-
-    def reset_mask(self, x):
-        dropout = self.dropout
-
-        # Dimension (N, C, L)
-        n, c, h, w = x.shape
-        m = x.data.new(n, c, 1, 1).bernoulli_(1 - dropout)
-        with torch.no_grad():
-            mask = m / (1 - dropout)
-            self.mask = mask
-        return mask
-
-    def forward(self, x):
-        if not self.training or self.dropout == 0:
-            return x
-        assert self.mask is not None, "You need to reset mask before using VariationalHidDropout"
-        return self.mask * x
-
-
 class FlowHead(nn.Module):
     def __init__(self, input_dim=128, hidden_dim=256):
@@ -41,7 +13,6 @@ class FlowHead(nn.Module):
     def forward(self, x):
         return self.conv2(self.relu(self.conv1(x)))
 
-
 class ConvGRU(nn.Module):
     def __init__(self, hidden_dim=128, input_dim=192+128):
         super(ConvGRU, self).__init__()
@@ -59,7 +30,6 @@ class ConvGRU(nn.Module):
         h = (1-z) * h + z * q
         return h
 
-
 class SepConvGRU(nn.Module):
     def __init__(self, hidden_dim=128, input_dim=192+128):
         super(SepConvGRU, self).__init__()
@@ -133,49 +103,37 @@ class SmallUpdateBlock(nn.Module):
         self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
         self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
 
-        self.drop_inp = VariationalHidDropout(dropout=args.dropout)
-        self.drop_net = VariationalHidDropout(dropout=args.dropout)
-
-    def reset_mask(self, net, inp):
-        self.drop_inp.reset_mask(inp)
-        self.drop_net.reset_mask(net)
-
     def forward(self, net, inp, corr, flow):
         motion_features = self.encoder(flow, corr)
-
-        if self.training:
-            net = self.drop_net(net)
-            inp = self.drop_inp(inp)
-
         inp = torch.cat([inp, motion_features], dim=1)
         net = self.gru(net, inp)
         delta_flow = self.flow_head(net)
 
-        return net, delta_flow
+        return net, None, delta_flow
 
 class BasicUpdateBlock(nn.Module):
     def __init__(self, args, hidden_dim=128, input_dim=128):
         super(BasicUpdateBlock, self).__init__()
+        self.args = args
         self.encoder = BasicMotionEncoder(args)
         self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
         self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
 
-        self.drop_inp = VariationalHidDropout(dropout=args.dropout)
-        self.drop_net = VariationalHidDropout(dropout=args.dropout)
-
-    def reset_mask(self, net, inp):
-        self.drop_inp.reset_mask(inp)
-        self.drop_net.reset_mask(net)
-
-    def forward(self, net, inp, corr, flow):
+        self.mask = nn.Sequential(
+            nn.Conv2d(128, 256, 3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(256, 64*9, 1, padding=0))
+
+    def forward(self, net, inp, corr, flow, upsample=True):
         motion_features = self.encoder(flow, corr)
-
-        if self.training:
-            net = self.drop_net(net)
-            inp = self.drop_inp(inp)
-
         inp = torch.cat([inp, motion_features], dim=1)
+
         net = self.gru(net, inp)
         delta_flow = self.flow_head(net)
 
-        return net, delta_flow
+        # scale mask to balance gradients
+        mask = .25 * self.mask(net)
+        return net, mask, delta_flow
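The `64*9` output channels of the new mask head factor into (8x8 sub-pixel positions) x (9 coarse neighbors), exactly the `view(N, 1, 9, 8, 8, H, W)` reshape that `RAFT.upsample_flow` performs; per the code comment, the `.25` scaling shrinks the softmax logits to balance gradients. A quick shape check under those assumptions:

```Python
import torch
import torch.nn as nn

mask_head = nn.Sequential(
    nn.Conv2d(128, 256, 3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(256, 64 * 9, 1, padding=0))

net = torch.randn(1, 128, 46, 62)   # GRU hidden state at 1/8 resolution
print(mask_head(net).shape)         # torch.Size([1, 576, 46, 62]); 576 = 64*9
```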
core/utils/augmentor.py

@@ -1,46 +1,55 @@
 import numpy as np
 import random
 import math
-import cv2
 from PIL import Image
 
+import cv2
 import torch
-import torchvision
+from torchvision.transforms import ColorJitter
 import torch.nn.functional as F
 
 
 class FlowAugmentor:
-    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5):
+    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
+
+        # spatial augmentation params
         self.crop_size = crop_size
-        self.augcolor = torchvision.transforms.ColorJitter(
-            brightness=0.4,
-            contrast=0.4,
-            saturation=0.4,
-            hue=0.5/3.14)
-
-        self.asymmetric_color_aug_prob = 0.2
-        self.spatial_aug_prob = 0.8
-        self.eraser_aug_prob = 0.5
-
         self.min_scale = min_scale
         self.max_scale = max_scale
-        self.max_stretch = 0.2
+        self.spatial_aug_prob = 0.8
         self.stretch_prob = 0.8
-        self.margin = 20
+        self.max_stretch = 0.2
+
+        # flip augmentation params
+        self.do_flip = do_flip
+        self.h_flip_prob = 0.5
+        self.v_flip_prob = 0.1
+
+        # photometric augmentation params
+        self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
+        self.asymmetric_color_aug_prob = 0.2
+        self.eraser_aug_prob = 0.5
 
     def color_transform(self, img1, img2):
+        """ Photometric augmentation """
+
+        # asymmetric
         if np.random.rand() < self.asymmetric_color_aug_prob:
-            img1 = np.array(self.augcolor(Image.fromarray(img1)), dtype=np.uint8)
-            img2 = np.array(self.augcolor(Image.fromarray(img2)), dtype=np.uint8)
+            img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
+            img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
+
+        # symmetric
         else:
             image_stack = np.concatenate([img1, img2], axis=0)
-            image_stack = np.array(self.augcolor(Image.fromarray(image_stack)), dtype=np.uint8)
+            image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
             img1, img2 = np.split(image_stack, 2, axis=0)
 
         return img1, img2
 
     def eraser_transform(self, img1, img2, bounds=[50, 100]):
+        """ Occlusion augmentation """
+
         ht, wd = img1.shape[:2]
         if np.random.rand() < self.eraser_aug_prob:
             mean_color = np.mean(img2.reshape(-1, 3), axis=0)
@@ -55,22 +64,18 @@ class FlowAugmentor:
 
     def spatial_transform(self, img1, img2, flow):
         # randomly sample scale
         ht, wd = img1.shape[:2]
         min_scale = np.maximum(
-            (self.crop_size[0] + 1) / float(ht),
-            (self.crop_size[1] + 1) / float(wd))
+            (self.crop_size[0] + 8) / float(ht),
+            (self.crop_size[1] + 8) / float(wd))
 
-        max_scale = self.max_scale
-        min_scale = max(min_scale, self.min_scale)
-
         scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
         scale_x = scale
         scale_y = scale
         if np.random.rand() < self.stretch_prob:
             scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
             scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
 
         scale_x = np.clip(scale_x, min_scale, None)
         scale_y = np.clip(scale_y, min_scale, None)
@@ -81,22 +86,20 @@ class FlowAugmentor:
             flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
             flow = flow * [scale_x, scale_y]
 
-        if np.random.rand() < 0.5: # h-flip
-            img1 = img1[:, ::-1]
-            img2 = img2[:, ::-1]
-            flow = flow[:, ::-1] * [-1.0, 1.0]
-
-        if np.random.rand() < 0.1: # v-flip
-            img1 = img1[::-1, :]
-            img2 = img2[::-1, :]
-            flow = flow[::-1, :] * [1.0, -1.0]
+        if self.do_flip:
+            if np.random.rand() < self.h_flip_prob: # h-flip
+                img1 = img1[:, ::-1]
+                img2 = img2[:, ::-1]
+                flow = flow[:, ::-1] * [-1.0, 1.0]
+
+            if np.random.rand() < self.v_flip_prob: # v-flip
+                img1 = img1[::-1, :]
+                img2 = img2[::-1, :]
+                flow = flow[::-1, :] * [1.0, -1.0]
 
-        y0 = np.random.randint(-self.margin, img1.shape[0] - self.crop_size[0] + self.margin)
-        x0 = np.random.randint(-self.margin, img1.shape[1] - self.crop_size[1] + self.margin)
-
-        y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
-        x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
+        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
+        x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
 
         img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
         img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
         flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
@@ -114,22 +117,29 @@ class FlowAugmentor:
 
         return img1, img2, flow
 
-class FlowAugmentorKITTI:
-    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5):
+class SparseFlowAugmentor:
+    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
+        # spatial augmentation params
         self.crop_size = crop_size
-        self.augcolor = torchvision.transforms.ColorJitter(
-            brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
-
-        self.max_scale = max_scale
         self.min_scale = min_scale
+        self.max_scale = max_scale
         self.spatial_aug_prob = 0.8
+        self.stretch_prob = 0.8
+        self.max_stretch = 0.2
+
+        # flip augmentation params
+        self.do_flip = do_flip
+        self.h_flip_prob = 0.5
+        self.v_flip_prob = 0.1
+
+        # photometric augmentation params
+        self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
+        self.asymmetric_color_aug_prob = 0.2
         self.eraser_aug_prob = 0.5
 
     def color_transform(self, img1, img2):
         image_stack = np.concatenate([img1, img2], axis=0)
-        image_stack = np.array(self.augcolor(Image.fromarray(image_stack)), dtype=np.uint8)
+        image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
         img1, img2 = np.split(image_stack, 2, axis=0)
         return img1, img2
 
@@ -198,11 +208,12 @@ class SparseFlowAugmentor:
             img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
             flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
 
-        if np.random.rand() < 0.5: # h-flip
-            img1 = img1[:, ::-1]
-            img2 = img2[:, ::-1]
-            flow = flow[:, ::-1] * [-1.0, 1.0]
-            valid = valid[:, ::-1]
+        if self.do_flip:
+            if np.random.rand() < 0.5: # h-flip
+                img1 = img1[:, ::-1]
+                img2 = img2[:, ::-1]
+                flow = flow[:, ::-1] * [-1.0, 1.0]
+                valid = valid[:, ::-1]
 
         margin_y = 20
         margin_x = 50
core/utils/frame_utils.py

@@ -103,6 +103,13 @@ def readFlowKITTI(filename):
     flow = (flow - 2**15) / 64.0
     return flow, valid
 
+def readDispKITTI(filename):
+    disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
+    valid = disp > 0.0
+    flow = np.stack([-disp, np.zeros_like(disp)], -1)
+    return flow, valid
+
+
 def writeFlowKITTI(filename, uv):
     uv = 64.0 * uv + 2**15
     valid = np.ones([uv.shape[0], uv.shape[1], 1])
@@ -120,5 +127,8 @@ def read_gen(file_name, pil=False):
         return readFlow(file_name).astype(np.float32)
     elif ext == '.pfm':
         flow = readPFM(file_name).astype(np.float32)
-        return flow[:, :, :-1]
+        if len(flow.shape) == 2:
+            return flow
+        else:
+            return flow[:, :, :-1]
     return []
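The new `readDispKITTI` follows the KITTI convention that disparity is stored as `disparity * 256` in a uint16 PNG, with zero marking invalid pixels; a disparity `d` corresponds to a purely horizontal displacement of `-d` pixels between the stereo views, hence the `(-disp, 0)` flow. A sketch on synthetic data, so no real KITTI file is needed:

```Python
import numpy as np

# fake uint16 disparity PNG contents: disparity * 256, 0 = invalid
disp_png = (np.array([[2.0, 0.0], [4.5, 1.0]]) * 256).astype(np.uint16)

disp = disp_png / 256.0
valid = disp > 0.0
flow = np.stack([-disp, np.zeros_like(disp)], -1)
print(flow[0, 0], valid[0, 1])   # [-2.  0.] False
```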
core/utils/utils.py

@@ -4,21 +4,21 @@ import numpy as np
 from scipy import interpolate
 
 
-def bilinear_sampler(img, coords, mode='bilinear', mask=False):
-    """ Wrapper for grid_sample, uses pixel coordinates """
-    H, W = img.shape[-2:]
-    xgrid, ygrid = coords.split([1,1], dim=-1)
-    xgrid = 2*xgrid/(W-1) - 1
-    ygrid = 2*ygrid/(H-1) - 1
-
-    grid = torch.cat([xgrid, ygrid], dim=-1)
-    img = F.grid_sample(img, grid, align_corners=True)
-
-    if mask:
-        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
-        return img, mask.float()
-
-    return img
+class InputPadder:
+    """ Pads images such that dimensions are divisible by 8 """
+    def __init__(self, dims):
+        self.ht, self.wd = dims[-2:]
+        pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
+        pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
+        self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
+
+    def pad(self, *inputs):
+        return [F.pad(x, self._pad, mode='replicate') for x in inputs]
+
+    def unpad(self,x):
+        ht, wd = x.shape[-2:]
+        c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
+        return x[..., c[0]:c[1], c[2]:c[3]]
 
 def forward_interpolate(flow):
     flow = flow.detach().cpu().numpy()
@@ -42,15 +42,33 @@ def forward_interpolate(flow):
     dy = dy[valid]
 
     flow_x = interpolate.griddata(
-        (x1, y1), dx, (x0, y0), method='nearest')
+        (x1, y1), dx, (x0, y0), method='cubic', fill_value=0)
 
     flow_y = interpolate.griddata(
-        (x1, y1), dy, (x0, y0), method='nearest')
+        (x1, y1), dy, (x0, y0), method='cubic', fill_value=0)
 
     flow = np.stack([flow_x, flow_y], axis=0)
     return torch.from_numpy(flow).float()
 
 
+def bilinear_sampler(img, coords, mode='bilinear', mask=False):
+    """ Wrapper for grid_sample, uses pixel coordinates """
+    H, W = img.shape[-2:]
+    xgrid, ygrid = coords.split([1,1], dim=-1)
+    xgrid = 2*xgrid/(W-1) - 1
+    ygrid = 2*ygrid/(H-1) - 1
+
+    grid = torch.cat([xgrid, ygrid], dim=-1)
+    img = F.grid_sample(img, grid, align_corners=True)
+
+    if mask:
+        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
+        return img, mask.float()
+
+    return img
+
+
 def coords_grid(batch, ht, wd):
     coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
     coords = torch.stack(coords[::-1], dim=0).float()
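`InputPadder` centralizes the divisible-by-8 padding that was previously duplicated in `demo.py` (`pad8`) and the KITTI dataset's `do_pad` branch. A usage sketch, assuming `core/` is on `sys.path` as the repo's scripts arrange; padding goes on the bottom for height and is split left/right for width, so `unpad` recovers the original size after inference:

```Python
import torch
from utils.utils import InputPadder

image = torch.randn(1, 3, 436, 1024)   # Sintel-sized frame; 436 % 8 != 0
padder = InputPadder(image.shape)
padded, = padder.pad(image)            # pad() accepts any number of tensors
print(padded.shape)                    # torch.Size([1, 3, 440, 1024])
print(padder.unpad(padded).shape)      # torch.Size([1, 3, 436, 1024])
```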
demo-frames/frame_0016.png through frame_0025.png (10 new binary executable files, 652-660 KiB each)
demo.py (89 lines changed)
@@ -4,87 +4,76 @@ sys.path.append('core')

 import argparse
 import os
 import cv2
+import glob
 import numpy as np
 import torch
-import torch.nn.functional as F
 from PIL import Image

-import datasets
-from utils import flow_viz
 from raft import RAFT
+from utils import flow_viz
+from utils.utils import InputPadder


 DEVICE = 'cuda'

-def pad8(img):
-    """pad image such that dimensions are divisible by 8"""
-    ht, wd = img.shape[2:]
-    pad_ht = (((ht // 8) + 1) * 8 - ht) % 8
-    pad_wd = (((wd // 8) + 1) * 8 - wd) % 8
-    pad_ht1 = [pad_ht//2, pad_ht-pad_ht//2]
-    pad_wd1 = [pad_wd//2, pad_wd-pad_wd//2]
-
-    img = F.pad(img, pad_wd1 + pad_ht1, mode='replicate')
+def load_image(imfile):
+    img = np.array(Image.open(imfile)).astype(np.uint8)
+    img = torch.from_numpy(img).permute(2, 0, 1).float()
     return img

-def load_image(imfile):
-    img = np.array(Image.open(imfile)).astype(np.uint8)[..., :3]
-    img = torch.from_numpy(img).permute(2, 0, 1).float()
-    return pad8(img[None]).to(DEVICE)
-

+def load_image_list(image_files):
+    images = []
+    for imfile in sorted(image_files):
+        images.append(load_image(imfile))
+
+    images = torch.stack(images, dim=0)
+    images = images.to(DEVICE)
+
+    padder = InputPadder(images.shape)
+    return padder.pad(images)[0]

-def display(image1, image2, flow):
-    image1 = image1.permute(1, 2, 0).cpu().numpy() / 255.0
-    image2 = image2.permute(1, 2, 0).cpu().numpy() / 255.0
-
-    flow = flow.permute(1, 2, 0).cpu().numpy()
-    flow_image = flow_viz.flow_to_image(flow)
-    flow_image = cv2.resize(flow_image, (image1.shape[1], image1.shape[0]))
+
+def viz(img, flo):
+    img = img[0].permute(1,2,0).cpu().numpy()
+    flo = flo[0].permute(1,2,0).cpu().numpy()
+
+    # map flow to rgb image
+    flo = flow_viz.flow_to_image(flo)
+    img_flo = np.concatenate([img, flo], axis=0)

-    cv2.imshow('image1', image1[..., ::-1])
-    cv2.imshow('image2', image2[..., ::-1])
-    cv2.imshow('flow', flow_image[..., ::-1])
+    cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
     cv2.waitKey()


 def demo(args):
-    model = RAFT(args)
-    model = torch.nn.DataParallel(model)
+    model = torch.nn.DataParallel(RAFT(args))
     model.load_state_dict(torch.load(args.model))

+    model = model.module
     model.to(DEVICE)
     model.eval()

     with torch.no_grad():
+        images = glob.glob(os.path.join(args.path, '*.png')) + \
+                 glob.glob(os.path.join(args.path, '*.jpg'))

-        # sintel images
-        image1 = load_image('images/sintel_0.png')
-        image2 = load_image('images/sintel_1.png')
-
-        flow_predictions = model(image1, image2, iters=args.iters, upsample=False)
-        display(image1[0], image2[0], flow_predictions[-1][0])
-
-        # kitti images
-        image1 = load_image('images/kitti_0.png')
-        image2 = load_image('images/kitti_1.png')
-
-        flow_predictions = model(image1, image2, iters=16)
-        display(image1[0], image2[0], flow_predictions[-1][0])
-
-        # davis images
-        image1 = load_image('images/davis_0.jpg')
-        image2 = load_image('images/davis_1.jpg')
-
-        flow_predictions = model(image1, image2, iters=16)
-        display(image1[0], image2[0], flow_predictions[-1][0])
+        images = load_image_list(images)
+        for i in range(images.shape[0]-1):
+            image1 = images[i,None]
+            image2 = images[i+1,None]
+
+            flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
+            viz(image1, flow_up)


 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--model', help="restore checkpoint")
+    parser.add_argument('--path', help="dataset for evaluation")
     parser.add_argument('--small', action='store_true', help='use small model')
-    parser.add_argument('--iters', type=int, default=12)
+    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')

     args = parser.parse_args()
     demo(args)
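The demo now drives the model through its new `test_mode` interface. A minimal sketch of that call in isolation (the tensor names are illustrative, and the reading of `flow_low` as the coarse internal estimate is my gloss on how evaluate.py uses it below):

```Python
# assumes image1/image2 are (1, 3, H, W) float tensors on DEVICE with
# H and W already padded to multiples of 8 (see InputPadder above)
with torch.no_grad():
    flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)

# flow_low: coarse field at the model's internal (1/8) resolution; this is
#           what evaluate.py feeds to forward_interpolate() for warm starts
# flow_up:  flow upsampled to the full input resolution, used for display
```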
download_models.sh (new executable file)
@@ -0,0 +1,3 @@
+#!/bin/bash
+wget https://www.dropbox.com/s/npt24nvhoojdr0n/models.zip
+unzip models.zip
evaluate.py (218 lines changed)
@@ -2,7 +2,6 @@ import sys
 sys.path.append('core')

 from PIL import Image
-import cv2
 import argparse
 import os
 import time
@@ -13,88 +12,185 @@ import matplotlib.pyplot as plt

 import datasets
 from utils import flow_viz
+from utils import frame_utils

 from raft import RAFT
+from utils.utils import InputPadder, forward_interpolate


-def validate_sintel(args, model, iters=50):
-    """ Evaluate trained model on Sintel(train) clean + final passes """
+@torch.no_grad()
+def create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission'):
+    """ Create submission for the Sintel leaderboard """
     model.eval()
-    pad = 2

     for dstype in ['clean', 'final']:
-        val_dataset = datasets.MpiSintel(args, do_augument=False, dstype=dstype)
-
-        epe_list = []
-        for i in range(len(val_dataset)):
-            image1, image2, flow_gt, _ = val_dataset[i]
-            image1 = image1[None].cuda()
-            image2 = image2[None].cuda()
-            image1 = F.pad(image1, [0, 0, pad, pad], mode='replicate')
-            image2 = F.pad(image2, [0, 0, pad, pad], mode='replicate')
-
-            with torch.no_grad():
-                flow_predictions = model.module(image1, image2, iters=iters)
-                flow_pr = flow_predictions[-1][0,:,pad:-pad]
-
-            epe = torch.sum((flow_pr - flow_gt.cuda())**2, dim=0)
-            epe = torch.sqrt(epe).mean()
-            epe_list.append(epe.item())
-
-        print("Validation (%s) EPE: %f" % (dstype, np.mean(epe_list)))
-
-
-def validate_kitti(args, model, iters=32):
-    """ Evaluate trained model on KITTI (train) """
-
-    model.eval()
-    val_dataset = datasets.KITTI(args, do_augument=False, is_val=True, do_pad=True)
-
-    with torch.no_grad():
-        epe_list, out_list = [], []
-        for i in range(len(val_dataset)):
-            image1, image2, flow_gt, valid_gt = val_dataset[i]
-            image1 = image1[None].cuda()
-            image2 = image2[None].cuda()
-            flow_gt = flow_gt.cuda()
-            valid_gt = valid_gt.cuda()
-
-            flow_predictions = model.module(image1, image2, iters=iters)
-            flow_pr = flow_predictions[-1][0]
-
-            epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
-            mag = torch.sum(flow_gt**2, dim=0).sqrt()
-
-            epe = epe.view(-1)
-            mag = mag.view(-1)
-            val = valid_gt.view(-1) >= 0.5
-
-            out = ((epe > 3.0) & ((epe/mag) > 0.05)).float()
-            epe_list.append(epe[val].mean().item())
-            out_list.append(out[val].cpu().numpy())
+        test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype)
+
+        flow_prev, sequence_prev = None, None
+        for test_id in range(len(test_dataset)):
+            image1, image2, (sequence, frame) = test_dataset[test_id]
+            if sequence != sequence_prev:
+                flow_prev = None
+
+            padder = InputPadder(image1.shape)
+            image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
+
+            flow_low, flow_pr = model(image1, image2, iters=iters, flow_init=flow_prev, test_mode=True)
+            flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
+
+            if warm_start:
+                flow_prev = forward_interpolate(flow_low[0])[None].cuda()
+
+            output_dir = os.path.join(output_path, dstype, sequence)
+            output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1))
+
+            if not os.path.exists(output_dir):
+                os.makedirs(output_dir)
+
+            frame_utils.writeFlow(output_file, flow)
+            sequence_prev = sequence

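`create_sintel_submission` introduces the warm-start pattern: the coarse flow from one frame pair, forward-projected by `forward_interpolate`, initializes the estimate for the next pair. A stripped-down sketch of just that loop (the `frames` list is hypothetical):

```Python
# sketch of the warm-start loop, assuming `frames` is a list of consecutive
# (1, 3, H, W) GPU tensors from a single video sequence
flow_prev = None
for image1, image2 in zip(frames[:-1], frames[1:]):
    padder = InputPadder(image1.shape)
    image1, image2 = padder.pad(image1, image2)

    flow_low, flow_pr = model(image1, image2, iters=32,
                              flow_init=flow_prev, test_mode=True)

    # push the coarse flow forward so it aligns with the next frame pair
    flow_prev = forward_interpolate(flow_low[0])[None].cuda()
```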
+@torch.no_grad()
+def create_kitti_submission(model, iters=24, output_path='kitti_submission'):
+    """ Create submission for the KITTI leaderboard """
+    model.eval()
+    test_dataset = datasets.KITTI(split='testing', aug_params=None)
+
+    if not os.path.exists(output_path):
+        os.makedirs(output_path)
+
+    for test_id in range(len(test_dataset)):
+        image1, image2, (frame_id, ) = test_dataset[test_id]
+        padder = InputPadder(image1.shape)
+        image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
+
+        _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
+        flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
+
+        output_filename = os.path.join(output_path, frame_id)
+        frame_utils.writeFlowKITTI(output_filename, flow)
+
+
+@torch.no_grad()
+def validate_chairs(model, iters=24):
+    """ Perform evaluation on the FlyingChairs (test) split """
+    model.eval()
+    epe_list = []
+
+    val_dataset = datasets.FlyingChairs(split='validation')
+    for val_id in range(len(val_dataset)):
+        image1, image2, flow_gt, _ = val_dataset[val_id]
+        image1 = image1[None].cuda()
+        image2 = image2[None].cuda()
+
+        _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
+        epe = torch.sum((flow_pr[0].cpu() - flow_gt)**2, dim=0).sqrt()
+        epe_list.append(epe.view(-1).numpy())
+
+    epe = np.mean(np.concatenate(epe_list))
+    print("Validation Chairs EPE: %f" % epe)
+    return {'chairs': epe}
+
+
+@torch.no_grad()
+def validate_sintel(model, iters=32):
+    """ Perform validation using the Sintel (train) split """
+    model.eval()
+    results = {}
+    for dstype in ['clean', 'final']:
+        val_dataset = datasets.MpiSintel(split='training', dstype=dstype)
+        epe_list = []
+
+        for val_id in range(len(val_dataset)):
+            image1, image2, flow_gt, _ = val_dataset[val_id]
+            image1 = image1[None].cuda()
+            image2 = image2[None].cuda()
+
+            padder = InputPadder(image1.shape)
+            image1, image2 = padder.pad(image1, image2)
+
+            flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
+            flow = padder.unpad(flow_pr[0]).cpu()
+
+            epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
+            epe_list.append(epe.view(-1).numpy())
+
+        epe_all = np.concatenate(epe_list)
+        epe = np.mean(epe_all)
+        px1 = np.mean(epe_all<1)
+        px3 = np.mean(epe_all<3)
+        px5 = np.mean(epe_all<5)
+
+        print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
+        results[dstype] = np.mean(epe_list)
+
+    return results
+
+
+@torch.no_grad()
+def validate_kitti(model, iters=24):
+    """ Perform validation using the KITTI-2015 (train) split """
+    model.eval()
+    val_dataset = datasets.KITTI(split='training')
+
+    out_list, epe_list = [], []
+    for val_id in range(len(val_dataset)):
+        image1, image2, flow_gt, valid_gt = val_dataset[val_id]
+        image1 = image1[None].cuda()
+        image2 = image2[None].cuda()
+
+        padder = InputPadder(image1.shape)
+        image1, image2 = padder.pad(image1, image2)
+
+        flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
+        flow = padder.unpad(flow_pr[0]).cpu()
+
+        epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
+        mag = torch.sum(flow_gt**2, dim=0).sqrt()
+
+        epe = epe.view(-1)
+        mag = mag.view(-1)
+        val = valid_gt.view(-1) >= 0.5
+
+        out = ((epe > 3.0) & ((epe/mag) > 0.05)).float()
+        epe_list.append(epe[val].mean().item())
+        out_list.append(out[val].cpu().numpy())

     epe_list = np.array(epe_list)
     out_list = np.concatenate(out_list)

-    print("Validation KITTI: %f, %f" % (np.mean(epe_list), 100*np.mean(out_list)))
+    epe = np.mean(epe_list)
+    f1 = 100 * np.mean(out_list)
+
+    print("Validation KITTI: %f, %f" % (epe, f1))
+    return {'kitti-epe': epe, 'kitti-f1': f1}


 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--model', help="restore checkpoint")
+    parser.add_argument('--dataset', help="dataset for evaluation")
     parser.add_argument('--small', action='store_true', help='use small model')
-    parser.add_argument('--sintel_iters', type=int, default=50)
-    parser.add_argument('--kitti_iters', type=int, default=32)
+    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')

     args = parser.parse_args()

-    model = RAFT(args)
-    model = torch.nn.DataParallel(model)
+    model = torch.nn.DataParallel(RAFT(args))
     model.load_state_dict(torch.load(args.model))

-    model.to('cuda')
+    model.cuda()
     model.eval()

-    validate_sintel(args, model, args.sintel_iters)
-    validate_kitti(args, model, args.kitti_iters)
+    # create_sintel_submission(model.module, warm_start=True)
+    # create_kitti_submission(model.module)
+
+    with torch.no_grad():
+        if args.dataset == 'chairs':
+            validate_chairs(model.module)
+
+        elif args.dataset == 'sintel':
+            validate_sintel(model.module)
+
+        elif args.dataset == 'kitti':
+            validate_kitti(model.module)

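In the new `validate_kitti`, `kitti-f1` is the percentage of outlier pixels (often written Fl on the KITTI leaderboard), not a harmonic-mean F1 score. A tiny worked example of the outlier test at a single pixel:

```Python
# per-pixel end-point error (EPE) and the KITTI outlier test:
# a pixel is an outlier when EPE > 3px AND EPE > 5% of the ground-truth magnitude
import torch

flow_pr = torch.tensor([[[4.0]], [[0.0]]])   # predicted (u, v) at one pixel
flow_gt = torch.tensor([[[0.0]], [[3.0]]])   # ground-truth (u, v)

epe = torch.sum((flow_pr - flow_gt) ** 2, dim=0).sqrt()   # -> 5.0
mag = torch.sum(flow_gt ** 2, dim=0).sqrt()               # -> 3.0
out = (epe > 3.0) & ((epe / mag) > 0.05)                  # -> True (an outlier)
```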
BIN  six binary demo images removed (497, 514, 829, 822, 396, and 388 KiB)
(old model download script, removed)
@@ -1,3 +0,0 @@
-#!/bin/bash
-wget https://www.dropbox.com/s/a2acvmczgzm6f9n/models.zip
-unzip models.zip
train.py (167 lines changed; file mode changed from executable to normal)
@@ -16,26 +16,43 @@ import torch.nn.functional as F

 from torch.utils.data import DataLoader
 from raft import RAFT
-from evaluate import validate_sintel, validate_kitti
+import evaluate
 import datasets

+from torch.utils.tensorboard import SummaryWriter
+
+try:
+    from torch.cuda.amp import GradScaler
+except:
+    # dummy GradScaler for PyTorch < 1.6
+    class GradScaler:
+        def __init__(self):
+            pass
+        def scale(self, loss):
+            return loss
+        def unscale_(self, optimizer):
+            pass
+        def step(self, optimizer):
+            optimizer.step()
+        def update(self):
+            pass
+

 # exclude extremely large displacements
-MAX_FLOW = 1000
-SUM_FREQ = 200
+MAX_FLOW = 500
+SUM_FREQ = 100
 VAL_FREQ = 5000


-def count_parameters(model):
-    return sum(p.numel() for p in model.parameters() if p.requires_grad)
-
-def sequence_loss(flow_preds, flow_gt, valid):
+def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW):
     """ Loss function defined over sequence of flow predictions """

     n_predictions = len(flow_preds)
     flow_loss = 0.0

     # exclude invalid pixels and extremely large displacements
-    valid = (valid >= 0.5) & (flow_gt.abs().sum(dim=1) < MAX_FLOW)
+    mag = torch.sum(flow_gt**2, dim=1).sqrt()
+    valid = (valid >= 0.5) & (mag < max_flow)

     for i in range(n_predictions):
         i_weight = 0.8**(n_predictions - i - 1)
@@ -54,39 +71,22 @@ def sequence_loss(flow_preds, flow_gt, valid):

     return flow_loss, metrics

+def show_image(img):
+    img = img.permute(1,2,0).cpu().numpy()
+    plt.imshow(img/255.0)
+    plt.show()
+    # cv2.imshow('image', img/255.0)
+    # cv2.waitKey()

-def fetch_dataloader(args):
-    """ Create the data loader for the corresponding training set """
-
-    if args.dataset == 'chairs':
-        train_dataset = datasets.FlyingChairs(args, image_size=args.image_size)
-
-    elif args.dataset == 'things':
-        clean_dataset = datasets.SceneFlow(args, image_size=args.image_size, dstype='frames_cleanpass')
-        final_dataset = datasets.SceneFlow(args, image_size=args.image_size, dstype='frames_finalpass')
-        train_dataset = clean_dataset + final_dataset
-
-    elif args.dataset == 'sintel':
-        clean_dataset = datasets.MpiSintel(args, image_size=args.image_size, dstype='clean')
-        final_dataset = datasets.MpiSintel(args, image_size=args.image_size, dstype='final')
-        train_dataset = clean_dataset + final_dataset
-
-    elif args.dataset == 'kitti':
-        train_dataset = datasets.KITTI(args, image_size=args.image_size, is_val=False)
-
-    gpuargs = {'num_workers': 4, 'drop_last' : True}
-    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
-        pin_memory=True, shuffle=True, **gpuargs)
-
-    print('Training with %d image pairs' % len(train_dataset))
-    return train_loader
+def count_parameters(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)


 def fetch_optimizer(args, model):
     """ Create the optimizer and learning rate scheduler """
     optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)

-    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps,
-        pct_start=0.2, cycle_momentum=False, anneal_strategy='linear')
+    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100,
+        pct_start=0.05, cycle_momentum=False, anneal_strategy='linear')

     return optimizer, scheduler

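The sequence loss keeps the exponential weighting from before: prediction i of N is weighted 0.8**(N - i - 1), so the final refinement counts fully and early ones count much less. A small numeric illustration with the default iters=12 (the `GradScaler` fallback above exists only so the scaler-based loop below still runs on PyTorch < 1.6):

```Python
# illustrative weighting for the default 12 refinement iterations
n_predictions = 12
weights = [0.8 ** (n_predictions - i - 1) for i in range(n_predictions)]
print(weights[0], weights[-1])   # ~0.0859 for the first estimate, 1.0 for the last
```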
@@ -97,17 +97,22 @@ class Logger:
         self.scheduler = scheduler
         self.total_steps = 0
         self.running_loss = {}
+        self.writer = None

     def _print_training_status(self):
         metrics_data = [self.running_loss[k]/SUM_FREQ for k in sorted(self.running_loss.keys())]
-        training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_lr()[0])
+        training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0])
         metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data)

         # print the training status
         print(training_str + metrics_str)

-        for key in self.running_loss:
-            self.running_loss[key] = 0.0
+        if self.writer is None:
+            self.writer = SummaryWriter()
+
+        for k in self.running_loss:
+            self.writer.add_scalar(k, self.running_loss[k]/SUM_FREQ, self.total_steps)
+            self.running_loss[k] = 0.0

     def push(self, metrics):
         self.total_steps += 1
@@ -122,56 +127,95 @@ class Logger:
             self._print_training_status()
             self.running_loss = {}

+    def write_dict(self, results):
+        if self.writer is None:
+            self.writer = SummaryWriter()
+
+        for key in results:
+            self.writer.add_scalar(key, results[key], self.total_steps)
+
+    def close(self):
+        self.writer.close()
+

 def train(args):

-    model = RAFT(args)
-    model = nn.DataParallel(model)
+    model = nn.DataParallel(RAFT(args), device_ids=args.gpus)
     print("Parameter Count: %d" % count_parameters(model))

     if args.restore_ckpt is not None:
-        model.load_state_dict(torch.load(args.restore_ckpt))
+        model.load_state_dict(torch.load(args.restore_ckpt), strict=False)

     model.cuda()
     model.train()

-    if 'chairs' not in args.dataset:
+    if args.stage != 'chairs':
         model.module.freeze_bn()

-    train_loader = fetch_dataloader(args)
+    train_loader = datasets.fetch_dataloader(args)
     optimizer, scheduler = fetch_optimizer(args, model)

     total_steps = 0
+    scaler = GradScaler(enabled=args.mixed_precision)
     logger = Logger(model, scheduler)

+    VAL_FREQ = 5000
+    add_noise = True
+
     should_keep_training = True
     while should_keep_training:

         for i_batch, data_blob in enumerate(train_loader):
+            optimizer.zero_grad()
             image1, image2, flow, valid = [x.cuda() for x in data_blob]

-            optimizer.zero_grad()
-            flow_predictions = model(image1, image2, iters=args.iters)
+            # show_image(image1[0])
+            # show_image(image2[0])
+
+            if args.add_noise:
+                stdv = np.random.uniform(0.0, 5.0)
+                image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
+                image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0)
+
+            flow_predictions = model(image1, image2, iters=args.iters)

             loss, metrics = sequence_loss(flow_predictions, flow, valid)
-            loss.backward()
+            scaler.scale(loss).backward()
+
+            scaler.unscale_(optimizer)
             torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
-            optimizer.step()
+
+            scaler.step(optimizer)
             scheduler.step()
-            total_steps += 1
+            scaler.update()

             logger.push(metrics)

             if total_steps % VAL_FREQ == VAL_FREQ - 1:
                 PATH = 'checkpoints/%d_%s.pth' % (total_steps+1, args.name)
                 torch.save(model.state_dict(), PATH)

-            if total_steps == args.num_steps:
+                results = {}
+                for val_dataset in args.validation:
+                    if val_dataset == 'chairs':
+                        results.update(evaluate.validate_chairs(model.module))
+                    elif val_dataset == 'sintel':
+                        results.update(evaluate.validate_sintel(model.module))
+                    elif val_dataset == 'kitti':
+                        results.update(evaluate.validate_kitti(model.module))
+
+                logger.write_dict(results)
+
+                model.train()
+                if args.stage != 'chairs':
+                    model.module.freeze_bn()
+
+            total_steps += 1
+
+            if total_steps > args.num_steps:
                 should_keep_training = False
                 break

+    logger.close()
     PATH = 'checkpoints/%s.pth' % args.name
     torch.save(model.state_dict(), PATH)

@@ -180,21 +224,25 @@ def train(args):

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--name', default='bla', help="name your experiment")
-    parser.add_argument('--dataset', help="which dataset to use for training")
+    parser.add_argument('--name', default='raft', help="name your experiment")
+    parser.add_argument('--stage', help="determines which dataset to use for training")
     parser.add_argument('--restore_ckpt', help="restore checkpoint")
     parser.add_argument('--small', action='store_true', help='use small model')
+    parser.add_argument('--validation', type=str, nargs='+')

     parser.add_argument('--lr', type=float, default=0.00002)
     parser.add_argument('--num_steps', type=int, default=100000)
     parser.add_argument('--batch_size', type=int, default=6)
     parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512])
+    parser.add_argument('--gpus', type=int, nargs='+', default=[0,1])
+    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')

     parser.add_argument('--iters', type=int, default=12)
     parser.add_argument('--wdecay', type=float, default=.00005)
     parser.add_argument('--epsilon', type=float, default=1e-8)
     parser.add_argument('--clip', type=float, default=1.0)
     parser.add_argument('--dropout', type=float, default=0.0)
+    parser.add_argument('--add_noise', action='store_true')
     args = parser.parse_args()

     torch.manual_seed(1234)
@@ -203,9 +251,4 @@ if __name__ == '__main__':
     if not os.path.isdir('checkpoints'):
         os.mkdir('checkpoints')

-    # scale learning rate and batch size by number of GPUs
-    num_gpus = torch.cuda.device_count()
-    args.batch_size = args.batch_size * num_gpus
-    args.lr = args.lr * num_gpus
-
     train(args)
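The rewritten training step follows the standard AMP recipe, and the ordering matters: scale the loss before backward, unscale the gradients before clipping, then step and update the scaler (which skips steps whose gradients overflowed). A self-contained sketch of the same ordering on a toy model; everything here is illustrative, and forward-pass autocasting is presumably handled inside RAFT itself behind --mixed_precision:

```Python
import torch
from torch.cuda.amp import GradScaler

# toy model and optimizer, purely illustrative (not repo code)
model = torch.nn.Linear(8, 2).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler(enabled=True)

x = torch.randn(4, 8).cuda()
target = torch.randn(4, 2).cuda()

optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(x), target)

scaler.scale(loss).backward()              # backward on the scaled loss
scaler.unscale_(optimizer)                 # undo scaling so clipping sees true grads
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)                     # skips the step if grads overflowed
scaler.update()                            # adjust the scale factor
```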