added upsampling module

README.md (107 changed lines)
@@ -1,7 +1,4 @@
# RAFT

**7/22/2020: We have updated our method to predict flow at full resolution, leading to improved results on public benchmarks. This repository will be updated to reflect these changes within the next few days.**

This repository contains the source code for our paper:

[RAFT: Recurrent All Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)<br/>
@@ -11,12 +8,15 @@ Zachary Teed and Jia Deng<br/>

<img src="RAFT.png">

## Requirements
Our code was tested using PyTorch 1.3.1 and Python 3. The following additional packages need to be installed

The code has been tested with PyTorch 1.5.1 and PyTorch Nightly. If you want to train with mixed precision, you will have to install the nightly build.
```Shell
pip install Pillow
pip install scipy
pip install opencv-python
conda create --name raft
conda activate raft
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch-nightly
conda install matplotlib
conda install tensorboard
conda install scipy
conda install opencv
```

## Demos
@@ -24,77 +24,56 @@ Pretrained models can be downloaded by running
```Shell
./scripts/download_models.sh
```
or downloaded from [google drive](https://drive.google.com/file/d/10-BYgHqRNPGvmNUWr8razjb1xHu55pyA/view?usp=sharing)

You can run the demos using one of the available models.

You can demo a trained model on a sequence of frames
```Shell
python demo.py --model=models/chairs+things.pth
python demo.py --model=models/raft-things.pth --path=demo-frames
```

or using the small (1M parameter) model

```Shell
python demo.py --model=models/small.pth --small
```
## Required Data
To evaluate/train RAFT, you will need to download the required datasets.
* [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs)
* [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
* [Sintel](http://sintel.is.tue.mpg.de/)
* [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow)
* [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) (optional)

Running the demos will display the two images and a visualization of the optical flow estimate. After the images display, press any key to continue.

## Training
To train RAFT, you will need to download the required datasets. The first stage of training requires the [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs) and [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) datasets. Finetuning and evaluation require the [Sintel](http://sintel.is.tue.mpg.de/) and [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) datasets. We organize the directory structure as follows. By default `datasets.py` will search for the datasets in these locations.
By default `datasets.py` will search for the datasets in these locations. You can create symbolic links to wherever the datasets were downloaded in the `datasets` folder (see the sketch after this block).

```Shell
├── datasets
│   ├── Sintel
│   │   ├── test
│   │   ├── training
│   ├── KITTI
│   │   ├── testing
│   │   ├── training
│   │   ├── devkit
│   ├── FlyingChairs_release
│   │   ├── data
│   ├── FlyingThings3D
│   │   ├── frames_cleanpass
│   │   ├── frames_finalpass
│   │   ├── optical_flow
├── Sintel
    ├── test
    ├── training
├── KITTI
    ├── testing
    ├── training
    ├── devkit
├── FlyingChairs_release
    ├── data
├── FlyingThings3D
    ├── frames_cleanpass
    ├── frames_finalpass
    ├── optical_flow
```
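
One way to set up those symbolic links without moving the downloaded data, sketched in Python (the source paths under `/data/downloads` are placeholders for wherever you actually downloaded the datasets):

```Python
import os

# Placeholder download locations; adjust to your machine.
downloads = {
    'datasets/Sintel': '/data/downloads/Sintel',
    'datasets/KITTI': '/data/downloads/KITTI',
    'datasets/FlyingChairs_release': '/data/downloads/FlyingChairs_release',
    'datasets/FlyingThings3D': '/data/downloads/FlyingThings3D',
}

os.makedirs('datasets', exist_ok=True)
for link, target in downloads.items():
    if not os.path.exists(link):
        os.symlink(target, link)  # datasets/<name> -> download location
```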

We used the following training schedule in our paper (note: we use 2 GPUs for training)

```Shell
python train.py --name=chairs --image_size 368 496 --dataset=chairs --num_steps=100000 --lr=0.0002 --batch_size=6
```

Next, finetune on the FlyingThings dataset

```Shell
python train.py --name=things --image_size 368 768 --dataset=things --num_steps=60000 --lr=0.00005 --batch_size=3 --restore_ckpt=checkpoints/chairs.pth
```

You can perform dataset-specific finetuning

### Sintel

```Shell
python train.py --name=sintel_ft --image_size 368 768 --dataset=sintel --num_steps=60000 --lr=0.00005 --batch_size=4 --restore_ckpt=checkpoints/things.pth
```

### KITTI

```Shell
python train.py --name=kitti_ft --image_size 288 896 --dataset=kitti --num_steps=40000 --lr=0.0001 --batch_size=4 --restore_ckpt=checkpoints/things.pth
```

## Evaluation
You can evaluate a model on Sintel and KITTI by running

You can evaluate a trained model using `evaluate.py`
```Shell
python evaluate.py --model=models/chairs+things.pth
python evaluate.py --model=models/raft-things.pth --dataset=sintel
```

or the small model by including the `small` flag

## Training
Training code will be made available in the next few days.
<!-- We used the following training schedule in our paper (note: we use 2 GPUs for training). Training logs will be written to the `runs` directory, which can be visualized using tensorboard
```Shell
python evaluate.py --model=models/small.pth --small
./train_standard.sh
```

If you have an RTX GPU, training can be accelerated using mixed precision. You can expect similar results in this setting (1 GPU)
```Shell
./train_mixed.sh
``` -->

chairs_split.txt (new file, 22872 lines)
core/corr.py
@@ -2,6 +2,7 @@ import torch
import torch.nn.functional as F
from utils.utils import bilinear_sampler, coords_grid


class CorrBlock:
    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
@@ -12,10 +13,10 @@ class CorrBlock:
        corr = CorrBlock.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.view(batch*h1*w1, dim, h2, w2)
        corr = corr.reshape(batch*h1*w1, dim, h2, w2)

        self.corr_pyramid.append(corr)
        for i in range(self.num_levels):
        for i in range(self.num_levels-1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

@@ -40,7 +41,8 @@ class CorrBlock:
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
@@ -51,3 +53,4 @@ class CorrBlock:
        corr = torch.matmul(fmap1.transpose(1,2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
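
For reference, the all-pairs correlation in `CorrBlock.corr` is just a batched dot product between every pair of feature vectors from the two frames. A minimal shape check, using the same operations as the hunk above (tensor sizes are illustrative, not from the diff):

```Python
import torch

# Toy feature maps: batch=1, dim=4 channels, 6x8 resolution (illustrative sizes).
fmap1 = torch.randn(1, 4, 6, 8)
fmap2 = torch.randn(1, 4, 6, 8)

batch, dim, ht, wd = fmap1.shape
f1 = fmap1.view(batch, dim, ht * wd)
f2 = fmap2.view(batch, dim, ht * wd)

# Dot product between every pixel of fmap1 and every pixel of fmap2.
corr = torch.matmul(f1.transpose(1, 2), f2)
corr = corr.view(batch, ht, wd, 1, ht, wd) / torch.sqrt(torch.tensor(dim).float())
print(corr.shape)  # torch.Size([1, 6, 8, 1, 6, 8])
```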

core/datasets.py (353 changed lines)
@@ -6,53 +6,42 @@ import torch.utils.data as data
import torch.nn.functional as F

import os
import cv2
import math
import random
from glob import glob
import os.path as osp

from utils import frame_utils
from utils.augmentor import FlowAugmentor, FlowAugmentorKITTI
from utils.augmentor import FlowAugmentor, SparseFlowAugmentor


class CombinedDataset(data.Dataset):
    def __init__(self, datasets):
        self.datasets = datasets

    def __len__(self):
        length = 0
        for i in range(len(self.datasets)):
            length += len(self.datasets[i])
        return length

    def __getitem__(self, index):
        i = 0
        for j in range(len(self.datasets)):
            if i + len(self.datasets[j]) >= index:
                yield self.datasets[j][index-i]
                break
            i += len(self.datasets[j])

    def __add__(self, other):
        self.datasets.append(other)
        return self

class FlowDataset(data.Dataset):
    def __init__(self, args, image_size=None, do_augument=False):
        self.image_size = image_size
        self.do_augument = do_augument

        if self.do_augument:
            self.augumentor = FlowAugmentor(self.image_size)
    def __init__(self, aug_params=None, sparse=False):
        self.augmentor = None
        self.sparse = sparse
        if aug_params is not None:
            if sparse:
                self.augmentor = SparseFlowAugmentor(**aug_params)
            else:
                self.augmentor = FlowAugmentor(**aug_params)

        self.is_test = False
        self.init_seed = False
        self.flow_list = []
        self.image_list = []

        self.init_seed = False
        self.extra_info = []

    def __getitem__(self, index):

        if self.is_test:
            img1 = frame_utils.read_gen(self.image_list[index][0])
            img2 = frame_utils.read_gen(self.image_list[index][1])
            img1 = np.array(img1).astype(np.uint8)[..., :3]
            img2 = np.array(img2).astype(np.uint8)[..., :3]
            img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
            img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
            return img1, img2, self.extra_info[index]

        if not self.init_seed:
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
@@ -62,133 +51,96 @@ class FlowDataset(data.Dataset):
                self.init_seed = True

        index = index % len(self.image_list)
        valid = None
        if self.sparse:
            flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
        else:
            flow = frame_utils.read_gen(self.flow_list[index])

        img1 = frame_utils.read_gen(self.image_list[index][0])
        img2 = frame_utils.read_gen(self.image_list[index][1])

        img1 = np.array(img1).astype(np.uint8)[..., :3]
        img2 = np.array(img2).astype(np.uint8)[..., :3]
        flow = np.array(flow).astype(np.float32)
        img1 = np.array(img1).astype(np.uint8)
        img2 = np.array(img2).astype(np.uint8)

        if self.do_augument:
            img1, img2, flow = self.augumentor(img1, img2, flow)
        # grayscale images
        if len(img1.shape) == 2:
            img1 = np.tile(img1[...,None], (1, 1, 3))
            img2 = np.tile(img2[...,None], (1, 1, 3))
        else:
            img1 = img1[..., :3]
            img2 = img2[..., :3]

        if self.augmentor is not None:
            if self.sparse:
                img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
            else:
                img1, img2, flow = self.augmentor(img1, img2, flow)

        img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
        img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
        flow = torch.from_numpy(flow).permute(2, 0, 1).float()
        valid = torch.ones_like(flow[0])

        return img1, img2, flow, valid
        if valid is not None:
            valid = torch.from_numpy(valid)
        else:
            valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)

        return img1, img2, flow, valid.float()


    def __rmul__(self, v):
        self.flow_list = v * self.flow_list
        self.image_list = v * self.image_list
        return self

    def __len__(self):
        return len(self.image_list)

    def __add__(self, other):
        return CombinedDataset([self, other])


class MpiSintelTest(FlowDataset):
    def __init__(self, args, root='datasets/Sintel/test', dstype='clean'):
        super(MpiSintelTest, self).__init__(args, image_size=None, do_augument=False)

        self.root = root
        self.dstype = dstype

        image_dir = osp.join(self.root, dstype)
        all_sequences = os.listdir(image_dir)

        self.image_list = []
        for sequence in all_sequences:
            frames = sorted(glob(osp.join(image_dir, sequence, '*.png')))
            for i in range(len(frames)-1):
                self.image_list += [[frames[i], frames[i+1], sequence, i]]

    def __getitem__(self, index):
        img1 = frame_utils.read_gen(self.image_list[index][0])
        img2 = frame_utils.read_gen(self.image_list[index][1])
        sequence = self.image_list[index][2]
        frame = self.image_list[index][3]

        img1 = np.array(img1).astype(np.uint8)[..., :3]
        img2 = np.array(img2).astype(np.uint8)[..., :3]
        img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
        img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
        return img1, img2, sequence, frame


class MpiSintel(FlowDataset):
    def __init__(self, args, image_size=None, do_augument=True, root='datasets/Sintel/training', dstype='clean'):
        super(MpiSintel, self).__init__(args, image_size, do_augument)
        if do_augument:
            self.augumentor.min_scale = -0.2
            self.augumentor.max_scale = 0.7
    def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
        super(MpiSintel, self).__init__(aug_params)
        flow_root = osp.join(root, split, 'flow')
        image_root = osp.join(root, split, dstype)

        self.root = root
        self.dstype = dstype
        if split == 'test':
            self.is_test = True

        flow_root = osp.join(root, 'flow')
        image_root = osp.join(root, dstype)
        for scene in os.listdir(image_root):
            image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
            for i in range(len(image_list)-1):
                self.image_list += [ [image_list[i], image_list[i+1]] ]
                self.extra_info += [ (scene, i) ] # scene and frame_id

        file_list = sorted(glob(osp.join(flow_root, '*/*.flo')))
        for flo in file_list:
            fbase = flo[len(flow_root)+1:]
            fprefix = fbase[:-8]
            fnum = int(fbase[-8:-4])

            img1 = osp.join(image_root, fprefix + "%04d"%(fnum+0) + '.png')
            img2 = osp.join(image_root, fprefix + "%04d"%(fnum+1) + '.png')

            if not osp.isfile(img1) or not osp.isfile(img2) or not osp.isfile(flo):
                continue

            self.image_list.append((img1, img2))
            self.flow_list.append(flo)
            if split != 'test':
                self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))


class FlyingChairs(FlowDataset):
    def __init__(self, args, image_size=None, do_augument=True, root='datasets/FlyingChairs_release/data'):
        super(FlyingChairs, self).__init__(args, image_size, do_augument)
        self.root = root
        self.augumentor.min_scale = -0.2
        self.augumentor.max_scale = 1.0
    def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
        super(FlyingChairs, self).__init__(aug_params)

        images = sorted(glob(osp.join(root, '*.ppm')))
        self.flow_list = sorted(glob(osp.join(root, '*.flo')))
        assert (len(images)//2 == len(self.flow_list))
        flows = sorted(glob(osp.join(root, '*.flo')))
        assert (len(images)//2 == len(flows))

        self.image_list = []
        for i in range(len(self.flow_list)):
            im1 = images[2*i]
            im2 = images[2*i + 1]
            self.image_list.append([im1, im2])
        split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
        for i in range(len(flows)):
            xid = split_list[i]
            if (split=='training' and xid==1) or (split=='validation' and xid==2):
                self.flow_list += [ flows[i] ]
                self.image_list += [ [images[2*i], images[2*i+1]] ]


class SceneFlow(FlowDataset):
    def __init__(self, args, image_size, do_augument=True, root='datasets',
                 dstype='frames_cleanpass', use_flyingthings=True, use_monkaa=False, use_driving=False):

        super(SceneFlow, self).__init__(args, image_size, do_augument)
        self.root = root
        self.dstype = dstype

        self.augumentor.min_scale = -0.2
        self.augumentor.max_scale = 0.8

        if use_flyingthings:
            self.add_flyingthings()

        if use_monkaa:
            self.add_monkaa()

        if use_driving:
            self.add_driving()

    def add_flyingthings(self):
        root = osp.join(self.root, 'FlyingThings3D')
class FlyingThings3D(FlowDataset):
    def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
        super(FlyingThings3D, self).__init__(aug_params)

        for cam in ['left']:
            for direction in ['into_future', 'into_past']:
                image_dirs = sorted(glob(osp.join(root, self.dstype, 'TRAIN/*/*')))
                image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
                image_dirs = sorted([osp.join(f, cam) for f in image_dirs])

                flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))

@@ -205,108 +157,79 @@ class SceneFlow(FlowDataset):
                    self.image_list += [ [images[i+1], images[i]] ]
                    self.flow_list += [ flows[i+1] ]

    def add_monkaa(self):
        pass # we don't use monkaa

    def add_driving(self):
        pass # we don't use driving


class KITTI(FlowDataset):
    def __init__(self, args, image_size=None, do_augument=True, is_test=False, is_val=False, do_pad=False, split=True, root='datasets/KITTI'):
        super(KITTI, self).__init__(args, image_size, do_augument)
        self.root = root
        self.is_test = is_test
        self.is_val = is_val
        self.do_pad = do_pad
    def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
        super(KITTI, self).__init__(aug_params, sparse=True)
        if split == 'testing':
            self.is_test = True

        if self.do_augument:
            self.augumentor = FlowAugmentorKITTI(self.image_size, min_scale=-0.2, max_scale=0.5)
        root = osp.join(root, split)
        images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
        images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))

        if self.is_test:
            images1 = sorted(glob(os.path.join(root, 'testing', 'image_2/*_10.png')))
            images2 = sorted(glob(os.path.join(root, 'testing', 'image_2/*_11.png')))
            for i in range(len(images1)):
                self.image_list += [[images1[i], images2[i]]]
        for img1, img2 in zip(images1, images2):
            frame_id = img1.split('/')[-1]
            self.extra_info += [ [frame_id] ]
            self.image_list += [ [img1, img2] ]

        else:
            flows = sorted(glob(os.path.join(root, 'training', 'flow_occ/*_10.png')))
            images1 = sorted(glob(os.path.join(root, 'training', 'image_2/*_10.png')))
            images2 = sorted(glob(os.path.join(root, 'training', 'image_2/*_11.png')))
        if split == 'training':
            self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))

            for i in range(len(flows)):


class HD1K(FlowDataset):
    def __init__(self, aug_params=None, root='datasets/HD1k'):
        super(HD1K, self).__init__(aug_params, sparse=True)

        seq_ix = 0
        while 1:
            flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
            images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))

            if len(flows) == 0:
                break

            for i in range(len(flows)-1):
                self.flow_list += [flows[i]]
                self.image_list += [[images1[i], images2[i]]]
                self.image_list += [ [images[i], images[i+1]] ]

            seq_ix += 1


    def __getitem__(self, index):
def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
    """ Create the data loader for the corresponding training set """

        if self.is_test:
            frame_id = self.image_list[index][0]
            frame_id = frame_id.split('/')[-1]
    if args.stage == 'chairs':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 1.0, 'do_flip': True}
        train_dataset = FlyingChairs(aug_params, split='training')

            img1 = frame_utils.read_gen(self.image_list[index][0])
            img2 = frame_utils.read_gen(self.image_list[index][1])
    elif args.stage == 'things':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
        clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
        final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
        train_dataset = clean_dataset + final_dataset

            img1 = np.array(img1).astype(np.uint8)[..., :3]
            img2 = np.array(img2).astype(np.uint8)[..., :3]
    elif args.stage == 'sintel':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True}
        things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
        sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
        sintel_final = MpiSintel(aug_params, split='training', dstype='final')

            img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
            img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
            return img1, img2, frame_id
        if TRAIN_DS == 'C+T+K+S+H':
            kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True})
            hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.5, 'do_flip': True})
            train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things

        elif TRAIN_DS == 'C+T+K/S':
            train_dataset = 100*sintel_clean + 100*sintel_final + things

        else:
            if not self.init_seed:
                worker_info = torch.utils.data.get_worker_info()
                if worker_info is not None:
                    np.random.seed(worker_info.id)
                    random.seed(worker_info.id)
                    self.init_seed = True
    elif args.stage == 'kitti':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
        train_dataset = KITTI(args, image_size=args.image_size, is_val=False)

            index = index % len(self.image_list)
            frame_id = self.image_list[index][0]
            frame_id = frame_id.split('/')[-1]
    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
        pin_memory=False, shuffle=True, num_workers=4, drop_last=True)

            img1 = frame_utils.read_gen(self.image_list[index][0])
            img2 = frame_utils.read_gen(self.image_list[index][1])
            flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
    print('Training with %d image pairs' % len(train_dataset))
    return train_loader
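
A note on the `100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things` weighting above: `FlowDataset.__rmul__` replicates the sample lists, so multiplying a dataset by an integer raises its sampling frequency in the concatenated training set. A minimal sketch of the same idea (the toy class below is illustrative, not from the diff):

```Python
import torch.utils.data as data

class ToyDataset(data.Dataset):
    """Illustrative stand-in for FlowDataset's __rmul__ weighting."""
    def __init__(self, items):
        self.items = list(items)

    def __rmul__(self, v):
        # Replicating the index list makes this dataset v times more
        # likely to be sampled after concatenation.
        self.items = v * self.items
        return self

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        return self.items[i]

mixed = 3 * ToyDataset('ab') + ToyDataset('xy')  # Dataset.__add__ -> ConcatDataset
print(len(mixed))  # 8: 'ab' repeated 3 times (6 samples) + 'xy' (2 samples)
```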

            img1 = np.array(img1).astype(np.uint8)[..., :3]
            img2 = np.array(img2).astype(np.uint8)[..., :3]

            if self.do_augument:
                img1, img2, flow, valid = self.augumentor(img1, img2, flow, valid)

            img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
            img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
            flow = torch.from_numpy(flow).permute(2, 0, 1).float()
            valid = torch.from_numpy(valid).float()

            if self.do_pad:
                ht, wd = img1.shape[1:]
                pad_ht = (((ht // 8) + 1) * 8 - ht) % 8
                pad_wd = (((wd // 8) + 1) * 8 - wd) % 8
                pad_ht1 = [0, pad_ht]
                pad_wd1 = [pad_wd//2, pad_wd - pad_wd//2]
                pad = pad_wd1 + pad_ht1

                img1 = img1.view(1, 3, ht, wd)
                img2 = img2.view(1, 3, ht, wd)
                flow = flow.view(1, 2, ht, wd)
                valid = valid.view(1, 1, ht, wd)

                img1 = torch.nn.functional.pad(img1, pad, mode='replicate')
                img2 = torch.nn.functional.pad(img2, pad, mode='replicate')
                flow = torch.nn.functional.pad(flow, pad, mode='constant', value=0)
                valid = torch.nn.functional.pad(valid, pad, mode='replicate', value=0)

                img1 = img1.view(3, ht+pad_ht, wd+pad_wd)
                img2 = img2.view(3, ht+pad_ht, wd+pad_wd)
                flow = flow.view(2, ht+pad_ht, wd+pad_wd)
                valid = valid.view(ht+pad_ht, wd+pad_wd)

            if self.is_test:
                return img1, img2, flow, valid, frame_id

            return img1, img2, flow, valid

core/extractor.py
@@ -143,10 +143,9 @@ class BasicEncoder(nn.Module):
        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)
        else:
            self.dropout = None

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
@@ -184,7 +183,7 @@ class BasicEncoder(nn.Module):

        x = self.conv2(x)

        if self.dropout is not None:
        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
@@ -218,10 +217,9 @@ class SmallEncoder(nn.Module):
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)
        else:
            self.dropout = None

        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

@@ -260,8 +258,8 @@ class SmallEncoder(nn.Module):
        x = self.layer3(x)
        x = self.conv2(x)

        # if self.dropout is not None:
        #     x = self.dropout(x)
        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)
core/raft.py (68 changed lines)
@@ -3,11 +3,23 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

from modules.update import BasicUpdateBlock, SmallUpdateBlock
from modules.extractor import BasicEncoder, SmallEncoder
from modules.corr import CorrBlock
from update import BasicUpdateBlock, SmallUpdateBlock
from extractor import BasicEncoder, SmallEncoder
from corr import CorrBlock
from utils.utils import bilinear_sampler, coords_grid, upflow8

try:
    autocast = torch.cuda.amp.autocast
except:
    # dummy autocast for PyTorch < 1.6
    class autocast:
        def __init__(self, enabled):
            pass
        def __enter__(self):
            pass
        def __exit__(self, *args):
            pass


class RAFT(nn.Module):
    def __init__(self, args):
@@ -26,7 +38,7 @@ class RAFT(nn.Module):
        args.corr_levels = 4
        args.corr_radius = 4

        if not hasattr(args, 'dropout'):
        if 'dropout' not in args._get_kwargs():
            args.dropout = 0

        # feature network, context network, and update block
@@ -40,6 +52,7 @@ class RAFT(nn.Module):
            self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)


    def freeze_bn(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
@@ -54,46 +67,73 @@ class RAFT(nn.Module):
        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1

    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True):
    def upsample_flow(self, flow, mask):
        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
        N, _, H, W = flow.shape
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        up_flow = F.unfold(8 * flow, [3,3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8*H, 8*W)
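
The convex upsampling above predicts, for each fine-resolution pixel, a softmax-weighted combination of its 3x3 coarse-resolution neighborhood. A minimal shape sanity check of the same logic as a free function (tensor sizes are illustrative):

```Python
import torch
import torch.nn.functional as F

def upsample_flow(flow, mask):
    # flow: (N, 2, H, W) at 1/8 resolution; mask: (N, 64*9, H, W) of logits.
    N, _, H, W = flow.shape
    mask = mask.view(N, 1, 9, 8, 8, H, W)
    mask = torch.softmax(mask, dim=2)  # convex weights over the 3x3 neighborhood

    up_flow = F.unfold(8 * flow, [3, 3], padding=1)  # scale flow by 8, gather 3x3 neighbors
    up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

    up_flow = torch.sum(mask * up_flow, dim=2)       # convex combination
    up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
    return up_flow.reshape(N, 2, 8 * H, 8 * W)

flow = torch.randn(1, 2, 4, 5)        # coarse flow
mask = torch.randn(1, 64 * 9, 4, 5)   # 9 logits per output pixel in each 8x8 block
print(upsample_flow(flow, mask).shape)  # torch.Size([1, 2, 32, 40])
```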

    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        """ Estimate optical flow between pair of frames """

        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # run the feature network
        with autocast(enabled=self.args.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])

        fmap1 = fmap1.float()
        fmap2 = fmap2.float()
        corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

        # run the context network
        with autocast(enabled=self.args.mixed_precision):
            cnet = self.cnet(image1)
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
            net, inp = torch.tanh(net), torch.relu(inp)
            net = torch.tanh(net)
            inp = torch.relu(inp)

        # if dropout is being used reset mask
        self.update_block.reset_mask(net, inp)
        coords0, coords1 = self.initialize_flow(image1)

        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []
        for itr in range(iters):
            coords1 = coords1.detach()
            corr = corr_fn(coords1) # index correlation volume

            flow = coords1 - coords0
            net, delta_flow = self.update_block(net, inp, corr, flow)
            with autocast(enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow

            if upsample:
                # upsample predictions
                if up_mask is None:
                    flow_up = upflow8(coords1 - coords0)
                else:
                    flow_up = self.upsample_flow(coords1 - coords0, up_mask)

                flow_predictions.append(flow_up)

            else:
                flow_predictions.append(coords1 - coords0)
        if test_mode:
            return coords1 - coords0, flow_up

        return flow_predictions
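
For orientation, a hedged sketch of how this forward pass is driven at inference time (the iteration count follows demo.py later in this diff; the Namespace fields below are assumptions based on what the constructor reads, and the tensor sizes are illustrative):

```Python
import torch
from argparse import Namespace

# Hypothetical args; RAFT's constructor reads small / mixed_precision / dropout.
args = Namespace(small=False, mixed_precision=False, dropout=0)
model = RAFT(args).cuda().eval()

image1 = torch.randint(0, 255, (1, 3, 440, 1024)).float().cuda()  # H, W divisible by 8
image2 = torch.randint(0, 255, (1, 3, 440, 1024)).float().cuda()

with torch.no_grad():
    # test_mode returns (1/8-resolution flow, full-resolution flow)
    flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
print(flow_low.shape, flow_up.shape)  # (1, 2, 55, 128) (1, 2, 440, 1024)
```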


core/update.py
@@ -2,34 +2,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

# VariationalHidDropout from https://github.com/locuslab/trellisnet/tree/master/TrellisNet
class VariationalHidDropout(nn.Module):
    def __init__(self, dropout=0.0):
        """
        Hidden-to-hidden (VD-based) dropout that applies the same mask at every time step and every layer of TrellisNet
        :param dropout: The dropout rate (0 means no dropout is applied)
        """
        super(VariationalHidDropout, self).__init__()
        self.dropout = dropout
        self.mask = None

    def reset_mask(self, x):
        dropout = self.dropout

        # Dimension (N, C, L)
        n, c, h, w = x.shape
        m = x.data.new(n, c, 1, 1).bernoulli_(1 - dropout)
        with torch.no_grad():
            mask = m / (1 - dropout)
            self.mask = mask
        return mask

    def forward(self, x):
        if not self.training or self.dropout == 0:
            return x
        assert self.mask is not None, "You need to reset mask before using VariationalHidDropout"
        return self.mask * x


class FlowHead(nn.Module):
    def __init__(self, input_dim=128, hidden_dim=256):
@@ -41,7 +13,6 @@ class FlowHead(nn.Module):
    def forward(self, x):
        return self.conv2(self.relu(self.conv1(x)))


class ConvGRU(nn.Module):
    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(ConvGRU, self).__init__()
@@ -59,7 +30,6 @@ class ConvGRU(nn.Module):
        h = (1-z) * h + z * q
        return h


class SepConvGRU(nn.Module):
    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(SepConvGRU, self).__init__()
@@ -133,49 +103,37 @@ class SmallUpdateBlock(nn.Module):
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

        self.drop_inp = VariationalHidDropout(dropout=args.dropout)
        self.drop_net = VariationalHidDropout(dropout=args.dropout)

    def reset_mask(self, net, inp):
        self.drop_inp.reset_mask(inp)
        self.drop_net.reset_mask(net)

    def forward(self, net, inp, corr, flow):
        motion_features = self.encoder(flow, corr)

        if self.training:
            net = self.drop_net(net)
            inp = self.drop_inp(inp)

        inp = torch.cat([inp, motion_features], dim=1)
        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        return net, delta_flow
        return net, None, delta_flow

class BasicUpdateBlock(nn.Module):
    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        self.drop_inp = VariationalHidDropout(dropout=args.dropout)
        self.drop_net = VariationalHidDropout(dropout=args.dropout)
        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64*9, 1, padding=0))

    def reset_mask(self, net, inp):
        self.drop_inp.reset_mask(inp)
        self.drop_net.reset_mask(net)

    def forward(self, net, inp, corr, flow):
    def forward(self, net, inp, corr, flow, upsample=True):
        motion_features = self.encoder(flow, corr)

        if self.training:
            net = self.drop_net(net)
            inp = self.drop_inp(inp)

        inp = torch.cat([inp, motion_features], dim=1)

        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        return net, delta_flow
        # scale mask to balance gradients
        mask = .25 * self.mask(net)
        return net, mask, delta_flow
core/utils/augmentor.py
@@ -1,46 +1,55 @@
import numpy as np
import random
import math
import cv2
from PIL import Image

import cv2
import torch
import torchvision

from torchvision.transforms import ColorJitter
import torch.nn.functional as F


class FlowAugmentor:
    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5):
    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):

        # spatial augmentation params
        self.crop_size = crop_size
        self.augcolor = torchvision.transforms.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
            hue=0.5/3.14)

        self.asymmetric_color_aug_prob = 0.2
        self.spatial_aug_prob = 0.8
        self.eraser_aug_prob = 0.5

        self.min_scale = min_scale
        self.max_scale = max_scale
        self.max_stretch = 0.2
        self.spatial_aug_prob = 0.8
        self.stretch_prob = 0.8
        self.margin = 20
        self.max_stretch = 0.2

        # flip augmentation params
        self.do_flip = do_flip
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.1

        # photometric augmentation params
        self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
        self.asymmetric_color_aug_prob = 0.2
        self.eraser_aug_prob = 0.5

    def color_transform(self, img1, img2):
        """ Photometric augmentation """

        # asymmetric
        if np.random.rand() < self.asymmetric_color_aug_prob:
            img1 = np.array(self.augcolor(Image.fromarray(img1)), dtype=np.uint8)
            img2 = np.array(self.augcolor(Image.fromarray(img2)), dtype=np.uint8)
            img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
            img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)

        # symmetric
        else:
            image_stack = np.concatenate([img1, img2], axis=0)
            image_stack = np.array(self.augcolor(Image.fromarray(image_stack)), dtype=np.uint8)
            image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
            img1, img2 = np.split(image_stack, 2, axis=0)

        return img1, img2

    def eraser_transform(self, img1, img2, bounds=[50, 100]):
        """ Occlusion augmentation """

        ht, wd = img1.shape[:2]
        if np.random.rand() < self.eraser_aug_prob:
            mean_color = np.mean(img2.reshape(-1, 3), axis=0)
@@ -55,14 +64,10 @@ class FlowAugmentor:

    def spatial_transform(self, img1, img2, flow):
        # randomly sample scale

        ht, wd = img1.shape[:2]
        min_scale = np.maximum(
            (self.crop_size[0] + 1) / float(ht),
            (self.crop_size[1] + 1) / float(wd))

        max_scale = self.max_scale
        min_scale = max(min_scale, self.min_scale)
            (self.crop_size[0] + 8) / float(ht),
            (self.crop_size[1] + 8) / float(wd))

        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
        scale_x = scale
@@ -81,21 +86,19 @@ class FlowAugmentor:
            flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow = flow * [scale_x, scale_y]

        if np.random.rand() < 0.5: # h-flip
        if self.do_flip:
            if np.random.rand() < self.h_flip_prob: # h-flip
                img1 = img1[:, ::-1]
                img2 = img2[:, ::-1]
                flow = flow[:, ::-1] * [-1.0, 1.0]

            if np.random.rand() < 0.1: # v-flip
            if np.random.rand() < self.v_flip_prob: # v-flip
                img1 = img1[::-1, :]
                img2 = img2[::-1, :]
                flow = flow[::-1, :] * [1.0, -1.0]

        y0 = np.random.randint(-self.margin, img1.shape[0] - self.crop_size[0] + self.margin)
        x0 = np.random.randint(-self.margin, img1.shape[1] - self.crop_size[1] + self.margin)

        y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
        x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
        x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])

        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
@@ -114,22 +117,29 @@ class FlowAugmentor:

        return img1, img2, flow


class FlowAugmentorKITTI:
    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5):
class SparseFlowAugmentor:
    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
        # spatial augmentation params
        self.crop_size = crop_size
        self.augcolor = torchvision.transforms.ColorJitter(
            brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)

        self.max_scale = max_scale
        self.min_scale = min_scale

        self.max_scale = max_scale
        self.spatial_aug_prob = 0.8
        self.stretch_prob = 0.8
        self.max_stretch = 0.2

        # flip augmentation params
        self.do_flip = do_flip
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.1

        # photometric augmentation params
        self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
        self.asymmetric_color_aug_prob = 0.2
        self.eraser_aug_prob = 0.5

    def color_transform(self, img1, img2):
        image_stack = np.concatenate([img1, img2], axis=0)
        image_stack = np.array(self.augcolor(Image.fromarray(image_stack)), dtype=np.uint8)
        image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
        img1, img2 = np.split(image_stack, 2, axis=0)
        return img1, img2

@@ -198,6 +208,7 @@ class FlowAugmentorKITTI:
        img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
        flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)

        if self.do_flip:
            if np.random.rand() < 0.5: # h-flip
                img1 = img1[:, ::-1]
                img2 = img2[:, ::-1]

core/utils/frame_utils.py
@@ -103,6 +103,13 @@ def readFlowKITTI(filename):
    flow = (flow - 2**15) / 64.0
    return flow, valid

def readDispKITTI(filename):
    disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
    valid = disp > 0.0
    flow = np.stack([-disp, np.zeros_like(disp)], -1)
    return flow, valid
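
KITTI stores flow in 16-bit PNGs; `readFlowKITTI` inverts the `uv_png = 64 * uv + 2**15` encoding used by `writeFlowKITTI` below. A quick round-trip check of that arithmetic (the flow values are illustrative):

```Python
import numpy as np

uv = np.array([[-3.25, 0.0, 10.5]], dtype=np.float32)  # flow values in pixels

encoded = 64.0 * uv + 2**15          # what writeFlowKITTI stores (uint16 range)
decoded = (encoded - 2**15) / 64.0   # what readFlowKITTI recovers

assert np.allclose(decoded, uv)      # exact for multiples of 1/64 pixel
```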

def writeFlowKITTI(filename, uv):
    uv = 64.0 * uv + 2**15
    valid = np.ones([uv.shape[0], uv.shape[1], 1])
@@ -120,5 +127,8 @@ def read_gen(file_name, pil=False):
        return readFlow(file_name).astype(np.float32)
    elif ext == '.pfm':
        flow = readPFM(file_name).astype(np.float32)
        if len(flow.shape) == 2:
            return flow
        else:
            return flow[:, :, :-1]
    return []
core/utils/utils.py
@@ -4,21 +4,21 @@ import numpy as np
from scipy import interpolate


def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1,1], dim=-1)
    xgrid = 2*xgrid/(W-1) - 1
    ygrid = 2*ygrid/(H-1) - 1
class InputPadder:
    """ Pads images such that dimensions are divisible by 8 """
    def __init__(self, dims):
        self.ht, self.wd = dims[-2:]
        pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
        pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
        self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]

    grid = torch.cat([xgrid, ygrid], dim=-1)
    img = F.grid_sample(img, grid, align_corners=True)
    def pad(self, *inputs):
        return [F.pad(x, self._pad, mode='replicate') for x in inputs]

    if mask:
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float()

    return img
    def unpad(self, x):
        ht, wd = x.shape[-2:]
        c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
        return x[..., c[0]:c[1], c[2]:c[3]]
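
`InputPadder` replaces the ad-hoc `pad8`/`do_pad` logic elsewhere in this diff: pad both frames before inference, then unpad the predicted flow. A minimal round-trip sketch (the Sintel-like sizes are illustrative):

```Python
import torch
import torch.nn.functional as F

image1 = torch.randn(1, 3, 436, 1024)  # 436 is not divisible by 8
image2 = torch.randn(1, 3, 436, 1024)

padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)
print(image1.shape)  # torch.Size([1, 3, 440, 1024])

flow_up = torch.randn(1, 2, 440, 1024)  # stand-in for the model's output
flow = padder.unpad(flow_up)
print(flow.shape)    # torch.Size([1, 2, 436, 1024])
```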

def forward_interpolate(flow):
    flow = flow.detach().cpu().numpy()
@@ -42,15 +42,33 @@ def forward_interpolate(flow):
    dy = dy[valid]

    flow_x = interpolate.griddata(
        (x1, y1), dx, (x0, y0), method='nearest')
        (x1, y1), dx, (x0, y0), method='cubic', fill_value=0)

    flow_y = interpolate.griddata(
        (x1, y1), dy, (x0, y0), method='nearest')
        (x1, y1), dy, (x0, y0), method='cubic', fill_value=0)

    flow = np.stack([flow_x, flow_y], axis=0)
    return torch.from_numpy(flow).float()


def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1,1], dim=-1)
    xgrid = 2*xgrid/(W-1) - 1
    ygrid = 2*ygrid/(H-1) - 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    img = F.grid_sample(img, grid, align_corners=True)

    if mask:
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float()

    return img


def coords_grid(batch, ht, wd):
    coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
    coords = torch.stack(coords[::-1], dim=0).float()
New binary files (executable):
demo-frames/frame_0016.png (652 KiB)
demo-frames/frame_0017.png (652 KiB)
demo-frames/frame_0018.png (652 KiB)
demo-frames/frame_0019.png (653 KiB)
demo-frames/frame_0020.png (655 KiB)
demo-frames/frame_0021.png (657 KiB)
demo-frames/frame_0022.png (658 KiB)
demo-frames/frame_0023.png (659 KiB)
demo-frames/frame_0024.png (660 KiB)
demo-frames/frame_0025.png (660 KiB)
demo.py (87 changed lines)
@@ -4,87 +4,76 @@ sys.path.append('core')
import argparse
import os
import cv2
import glob
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

import datasets
from utils import flow_viz
from raft import RAFT
from utils import flow_viz
from utils.utils import InputPadder


DEVICE = 'cuda'

def pad8(img):
    """pad image such that dimensions are divisible by 8"""
    ht, wd = img.shape[2:]
    pad_ht = (((ht // 8) + 1) * 8 - ht) % 8
    pad_wd = (((wd // 8) + 1) * 8 - wd) % 8
    pad_ht1 = [pad_ht//2, pad_ht-pad_ht//2]
    pad_wd1 = [pad_wd//2, pad_wd-pad_wd//2]

    img = F.pad(img, pad_wd1 + pad_ht1, mode='replicate')
def load_image(imfile):
    img = np.array(Image.open(imfile)).astype(np.uint8)
    img = torch.from_numpy(img).permute(2, 0, 1).float()
    return img

def load_image(imfile):
    img = np.array(Image.open(imfile)).astype(np.uint8)[..., :3]
    img = torch.from_numpy(img).permute(2, 0, 1).float()
    return pad8(img[None]).to(DEVICE)

def load_image_list(image_files):
    images = []
    for imfile in sorted(image_files):
        images.append(load_image(imfile))

    images = torch.stack(images, dim=0)
    images = images.to(DEVICE)

    padder = InputPadder(images.shape)
    return padder.pad(images)[0]


def display(image1, image2, flow):
    image1 = image1.permute(1, 2, 0).cpu().numpy() / 255.0
    image2 = image2.permute(1, 2, 0).cpu().numpy() / 255.0
def viz(img, flo):
    img = img[0].permute(1,2,0).cpu().numpy()
    flo = flo[0].permute(1,2,0).cpu().numpy()

    flow = flow.permute(1, 2, 0).cpu().numpy()
    flow_image = flow_viz.flow_to_image(flow)
    flow_image = cv2.resize(flow_image, (image1.shape[1], image1.shape[0]))
    # map flow to rgb image
    flo = flow_viz.flow_to_image(flo)
    img_flo = np.concatenate([img, flo], axis=0)

    cv2.imshow('image1', image1[..., ::-1])
    cv2.imshow('image2', image2[..., ::-1])
    cv2.imshow('flow', flow_image[..., ::-1])
    cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
    cv2.waitKey()


def demo(args):
    model = RAFT(args)
    model = torch.nn.DataParallel(model)
    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load(args.model))

    model = model.module
    model.to(DEVICE)
    model.eval()

    with torch.no_grad():
        images = glob.glob(os.path.join(args.path, '*.png')) + \
                 glob.glob(os.path.join(args.path, '*.jpg'))

        # sintel images
        image1 = load_image('images/sintel_0.png')
        image2 = load_image('images/sintel_1.png')
        images = load_image_list(images)
        for i in range(images.shape[0]-1):
            image1 = images[i,None]
            image2 = images[i+1,None]

            flow_predictions = model(image1, image2, iters=args.iters, upsample=False)
            display(image1[0], image2[0], flow_predictions[-1][0])

        # kitti images
        image1 = load_image('images/kitti_0.png')
        image2 = load_image('images/kitti_1.png')

        flow_predictions = model(image1, image2, iters=16)
        display(image1[0], image2[0], flow_predictions[-1][0])

        # davis images
        image1 = load_image('images/davis_0.jpg')
        image2 = load_image('images/davis_1.jpg')

        flow_predictions = model(image1, image2, iters=16)
        display(image1[0], image2[0], flow_predictions[-1][0])
            flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
            viz(image1, flow_up)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help="restore checkpoint")
    parser.add_argument('--path', help="dataset for evaluation")
    parser.add_argument('--small', action='store_true', help='use small model')
    parser.add_argument('--iters', type=int, default=12)

    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
    args = parser.parse_args()

    demo(args)
download_models.sh (new executable file, 3 lines)
@@ -0,0 +1,3 @@
#!/bin/bash
wget https://www.dropbox.com/s/npt24nvhoojdr0n/models.zip
unzip models.zip
182
evaluate.py
@ -2,7 +2,6 @@ import sys
|
||||
sys.path.append('core')
|
||||
|
||||
from PIL import Image
|
||||
import cv2
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
@ -13,56 +12,140 @@ import matplotlib.pyplot as plt
|
||||
|
||||
import datasets
|
||||
from utils import flow_viz
|
||||
from utils import frame_utils
|
||||
|
||||
from raft import RAFT
|
||||
from utils.utils import InputPadder, forward_interpolate
|
||||
|
||||
|
||||
|
||||
def validate_sintel(args, model, iters=50):
|
||||
""" Evaluate trained model on Sintel(train) clean + final passes """
|
||||
@torch.no_grad()
|
||||
def create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission'):
|
||||
""" Create submission for the Sintel leaderboard """
|
||||
model.eval()
|
||||
pad = 2
|
||||
|
||||
for dstype in ['clean', 'final']:
|
||||
val_dataset = datasets.MpiSintel(args, do_augument=False, dstype=dstype)
|
||||
test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype)
|
||||
|
||||
epe_list = []
|
||||
for i in range(len(val_dataset)):
|
||||
image1, image2, flow_gt, _ = val_dataset[i]
|
||||
image1 = image1[None].cuda()
|
||||
image2 = image2[None].cuda()
|
||||
image1 = F.pad(image1, [0, 0, pad, pad], mode='replicate')
|
||||
image2 = F.pad(image2, [0, 0, pad, pad], mode='replicate')
|
||||
flow_prev, sequence_prev = None, None
|
||||
for test_id in range(len(test_dataset)):
|
||||
image1, image2, (sequence, frame) = test_dataset[test_id]
|
||||
if sequence != sequence_prev:
|
||||
flow_prev = None
|
||||
|
||||
with torch.no_grad():
|
||||
flow_predictions = model.module(image1, image2, iters=iters)
|
||||
flow_pr = flow_predictions[-1][0,:,pad:-pad]
|
||||
padder = InputPadder(image1.shape)
|
||||
image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
|
||||
|
||||
epe = torch.sum((flow_pr - flow_gt.cuda())**2, dim=0)
|
||||
epe = torch.sqrt(epe).mean()
|
||||
epe_list.append(epe.item())
|
||||
flow_low, flow_pr = model(image1, image2, iters=iters, flow_init=flow_prev, test_mode=True)
|
||||
flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
|
||||
|
||||
print("Validation (%s) EPE: %f" % (dstype, np.mean(epe_list)))
|
||||
if warm_start:
|
||||
flow_prev = forward_interpolate(flow_low[0])[None].cuda()
|
||||
|
||||
output_dir = os.path.join(output_path, dstype, sequence)
|
||||
output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1))
|
||||
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
frame_utils.writeFlow(output_file, flow)
|
||||
sequence_prev = sequence
|
||||
|
||||
|
||||
def validate_kitti(args, model, iters=32):
|
||||
""" Evaluate trained model on KITTI (train) """
|
||||
|
||||
@torch.no_grad()
|
||||
def create_kitti_submission(model, iters=24, output_path='kitti_submission'):
|
||||
""" Create submission for the Sintel leaderboard """
|
||||
model.eval()
|
||||
val_dataset = datasets.KITTI(args, do_augument=False, is_val=True, do_pad=True)
|
||||
test_dataset = datasets.KITTI(split='testing', aug_params=None)
|
||||
|
||||
with torch.no_grad():
|
||||
epe_list, out_list = [], []
|
||||
for i in range(len(val_dataset)):
|
||||
image1, image2, flow_gt, valid_gt = val_dataset[i]
|
||||
if not os.path.exists(output_path):
|
||||
os.makedirs(output_path)
|
||||
|
||||
for test_id in range(len(test_dataset)):
|
||||
image1, image2, (frame_id, ) = test_dataset[test_id]
|
||||
padder = InputPadder(image1.shape)
|
||||
image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
|
||||
|
||||
_, flow_pr = model(image1, image2, iters=iters, test_mode=True)
|
||||
flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
|
||||
|
||||
output_filename = os.path.join(output_path, frame_id)
|
||||
frame_utils.writeFlowKITTI(output_filename, flow)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def validate_chairs(model, iters=24):
|
||||
""" Perform evaluation on the FlyingChairs (test) split """
|
||||
model.eval()
|
||||
epe_list = []
|
||||
|
||||
val_dataset = datasets.FlyingChairs(split='validation')
|
||||
for val_id in range(len(val_dataset)):
|
||||
image1, image2, flow_gt, _ = val_dataset[val_id]
|
||||
image1 = image1[None].cuda()
|
||||
image2 = image2[None].cuda()
|
||||
flow_gt = flow_gt.cuda()
|
||||
valid_gt = valid_gt.cuda()
|
||||
|
||||
flow_predictions = model.module(image1, image2, iters=iters)
|
||||
flow_pr = flow_predictions[-1][0]
|
||||
_, flow_pr = model(image1, image2, iters=iters, test_mode=True)
|
||||
epe = torch.sum((flow_pr[0].cpu() - flow_gt)**2, dim=0).sqrt()
|
||||
epe_list.append(epe.view(-1).numpy())
|
||||
|
||||
epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
|
||||
epe = np.mean(np.concatenate(epe_list))
|
||||
print("Validation Chairs EPE: %f" % epe)
|
||||
return {'chairs': epe}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def validate_sintel(model, iters=32):
|
||||
""" Peform validation using the Sintel (train) split """
|
||||
model.eval()
|
||||
results = {}
|
||||
for dstype in ['clean', 'final']:
|
||||
val_dataset = datasets.MpiSintel(split='training', dstype=dstype)
|
||||
epe_list = []
|
||||
|
||||
for val_id in range(len(val_dataset)):
|
||||
image1, image2, flow_gt, _ = val_dataset[val_id]
|
||||
image1 = image1[None].cuda()
|
||||
image2 = image2[None].cuda()
|
||||
|
||||
padder = InputPadder(image1.shape)
|
||||
image1, image2 = padder.pad(image1, image2)
|
||||
|
||||
flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
|
||||
flow = padder.unpad(flow_pr[0]).cpu()
|
||||
|
||||
epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
|
||||
epe_list.append(epe.view(-1).numpy())
|
||||
|
||||
epe_all = np.concatenate(epe_list)
|
||||
epe = np.mean(epe_all)
|
||||
px1 = np.mean(epe_all<1)
|
||||
px3 = np.mean(epe_all<3)
|
||||
px5 = np.mean(epe_all<5)
|
||||
|
||||
print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
|
||||
results[dstype] = np.mean(epe_list)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def validate_kitti(model, iters=24):
|
||||
""" Peform validation using the KITTI-2015 (train) split """
|
||||
model.eval()
|
||||
val_dataset = datasets.KITTI(split='training')
|
||||
|
||||
out_list, epe_list = [], []
|
||||
for val_id in range(len(val_dataset)):
|
||||
image1, image2, flow_gt, valid_gt = val_dataset[val_id]
|
||||
image1 = image1[None].cuda()
|
||||
image2 = image2[None].cuda()
|
||||
|
||||
padder = InputPadder(image1.shape)
|
||||
image1, image2 = padder.pad(image1, image2)
|
||||
|
||||
flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
|
||||
flow = padder.unpad(flow_pr[0]).cpu()
|
||||
|
||||
epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
|
||||
mag = torch.sum(flow_gt**2, dim=0).sqrt()
|
||||
|
||||
epe = epe.view(-1)
|
||||
@ -76,25 +159,38 @@ def validate_kitti(args, model, iters=32):
|
||||
epe_list = np.array(epe_list)
|
||||
out_list = np.concatenate(out_list)
|
||||
|
||||
epe = np.mean(epe_list)
|
||||
f1 = 100 * np.mean(out_list)
|
||||
|
||||
print("Validation KITTI: %f, %f" % (np.mean(epe_list), 100*np.mean(out_list)))
|
||||
print("Validation KITTI: %f, %f" % (epe, f1))
|
||||
return {'kitti-epe': epe, 'kitti-f1': f1}
|
||||
|
||||
|
||||
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help="restore checkpoint")
    parser.add_argument('--dataset', help="dataset for evaluation")
    parser.add_argument('--small', action='store_true', help='use small model')
    parser.add_argument('--sintel_iters', type=int, default=50)
    parser.add_argument('--kitti_iters', type=int, default=32)

    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
    args = parser.parse_args()

    model = RAFT(args)
    model = torch.nn.DataParallel(model)
    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load(args.model))

    model.to('cuda')
    model.cuda()
    model.eval()

    validate_sintel(args, model, args.sintel_iters)
    validate_kitti(args, model, args.kitti_iters)
    # create_sintel_submission(model.module, warm_start=True)
    # create_kitti_submission(model.module)

    with torch.no_grad():
        if args.dataset == 'chairs':
            validate_chairs(model.module)

        elif args.dataset == 'sintel':
            validate_sintel(model.module)

        elif args.dataset == 'kitti':
            validate_kitti(model.module)
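Evaluation is now driven by a single --dataset flag rather than the per-benchmark iteration flags that were removed. A typical invocation (checkpoint path illustrative):

```Shell
python evaluate.py --model=models/raft-things.pth --dataset=sintel --mixed_precision
```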
(6 image files deleted; the diff viewer reports only their former sizes: 497 KiB, 514 KiB, 829 KiB, 822 KiB, 396 KiB, 388 KiB)
@@ -1,3 +0,0 @@
#!/bin/bash
wget https://www.dropbox.com/s/a2acvmczgzm6f9n/models.zip
unzip models.zip
157
train.py
Executable file → Normal file
@@ -16,26 +16,43 @@ import torch.nn.functional as F
from torch.utils.data import DataLoader
from raft import RAFT
from evaluate import validate_sintel, validate_kitti
import evaluate
import datasets

from torch.utils.tensorboard import SummaryWriter
try:
    from torch.cuda.amp import GradScaler
except:
    # dummy GradScaler for PyTorch < 1.6
    class GradScaler:
        def __init__(self):
            pass
        def scale(self, loss):
            return loss
        def unscale_(self, optimizer):
            pass
        def step(self, optimizer):
            optimizer.step()
        def update(self):
            pass
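The fallback class above mirrors just enough of the torch.cuda.amp.GradScaler interface for the training loop below to run unchanged on PyTorch < 1.6; it simply leaves the loss unscaled and steps the optimizer directly. For reference, the canonical pattern the real scaler implements looks like this (a sketch; in this file the forward pass itself is switched to mixed precision via the --mixed_precision flag rather than an explicit autocast block):

```Python
import torch
from torch.cuda.amp import GradScaler, autocast

def amp_step(model, criterion, optimizer, scaler, data, target, clip=1.0):
    """ One mixed-precision optimization step (canonical GradScaler usage). """
    optimizer.zero_grad()
    with autocast():                  # forward pass runs in fp16 where safe
        loss = criterion(model(data), target)
    scaler.scale(loss).backward()     # scale loss to avoid fp16 gradient underflow
    scaler.unscale_(optimizer)        # unscale so clipping sees true gradient norms
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
    scaler.step(optimizer)            # skips the update if gradients overflowed
    scaler.update()                   # adapt the loss scale for the next step
    return loss.item()
```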
# exclude extremely large displacements
MAX_FLOW = 1000
SUM_FREQ = 200
MAX_FLOW = 500
SUM_FREQ = 100
VAL_FREQ = 5000
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def sequence_loss(flow_preds, flow_gt, valid):
def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW):
    """ Loss function defined over sequence of flow predictions """

    n_predictions = len(flow_preds)
    flow_loss = 0.0

    # exclude invalid pixels and extremely large displacements
    valid = (valid >= 0.5) & (flow_gt.abs().sum(dim=1) < MAX_FLOW)
    mag = torch.sum(flow_gt**2, dim=1).sqrt()
    valid = (valid >= 0.5) & (mag < max_flow)

    for i in range(n_predictions):
        i_weight = 0.8**(n_predictions - i - 1)
@@ -54,39 +71,22 @@ def sequence_loss(flow_preds, flow_gt, valid):
    return flow_loss, metrics
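The hunk header above skips the body of the weighted sum and the metrics dictionary. From the surviving i_weight line, each of the N predictions is penalized with weight 0.8^(N-i-1), so later refinement iterates dominate the loss. A hedged reconstruction of the elided middle:

```Python
    for i in range(n_predictions):
        i_weight = 0.8**(n_predictions - i - 1)
        i_loss = (flow_preds[i] - flow_gt).abs()
        flow_loss += i_weight * (valid[:, None] * i_loss).mean()

    # metrics are reported on the final (most refined) prediction
    epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt()
    epe = epe.view(-1)[valid.view(-1)]
    metrics = {
        'epe': epe.mean().item(),
        '1px': (epe < 1).float().mean().item(),
        '3px': (epe < 3).float().mean().item(),
        '5px': (epe < 5).float().mean().item(),
    }
```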
def show_image(img):
    img = img.permute(1,2,0).cpu().numpy()
    plt.imshow(img/255.0)
    plt.show()
    # cv2.imshow('image', img/255.0)
    # cv2.waitKey()

def fetch_dataloader(args):
    """ Create the data loader for the corresponding training set """

    if args.dataset == 'chairs':
        train_dataset = datasets.FlyingChairs(args, image_size=args.image_size)

    elif args.dataset == 'things':
        clean_dataset = datasets.SceneFlow(args, image_size=args.image_size, dstype='frames_cleanpass')
        final_dataset = datasets.SceneFlow(args, image_size=args.image_size, dstype='frames_finalpass')
        train_dataset = clean_dataset + final_dataset

    elif args.dataset == 'sintel':
        clean_dataset = datasets.MpiSintel(args, image_size=args.image_size, dstype='clean')
        final_dataset = datasets.MpiSintel(args, image_size=args.image_size, dstype='final')
        train_dataset = clean_dataset + final_dataset

    elif args.dataset == 'kitti':
        train_dataset = datasets.KITTI(args, image_size=args.image_size, is_val=False)

    gpuargs = {'num_workers': 4, 'drop_last' : True}
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
        pin_memory=True, shuffle=True, **gpuargs)

    print('Training with %d image pairs' % len(train_dataset))
    return train_loader

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
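The fetch_dataloader removed here is superseded by the datasets.fetch_dataloader call inside train() below, with the selection keyed off args.stage instead of args.dataset. A hedged sketch of the relocated factory's shape, written as it might appear inside datasets.py (constructor arguments are illustrative and omit whatever augmentation parameters the real code passes):

```Python
# sketch of the relocated factory in datasets.py (details illustrative)
def fetch_dataloader(args):
    if args.stage == 'chairs':
        train_dataset = FlyingChairs(split='training')
    elif args.stage == 'things':
        train_dataset = FlyingThings3D(dstype='frames_cleanpass') \
                      + FlyingThings3D(dstype='frames_finalpass')
    elif args.stage == 'sintel':
        train_dataset = MpiSintel(split='training', dstype='clean') \
                      + MpiSintel(split='training', dstype='final')
    elif args.stage == 'kitti':
        train_dataset = KITTI(split='training')

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              pin_memory=True, shuffle=True,
                              num_workers=4, drop_last=True)
    print('Training with %d image pairs' % len(train_dataset))
    return train_loader
```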
def fetch_optimizer(args, model):
    """ Create the optimizer and learning rate scheduler """
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)

    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps,
        pct_start=0.2, cycle_momentum=False, anneal_strategy='linear')
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100,
        pct_start=0.05, cycle_momentum=False, anneal_strategy='linear')

    return optimizer, scheduler
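Two scheduler details are easy to miss: num_steps+100 gives OneCycleLR headroom so the final real optimization step never runs past the end of the schedule (which raises an error), and pct_start drops the linear warmup from 20% to 5% of the run. A quick standalone way to inspect the resulting ramp (values illustrative):

```Python
import torch
import torch.optim as optim

params = [torch.nn.Parameter(torch.zeros(1))]
opt = optim.AdamW(params, lr=4e-4)
sched = optim.lr_scheduler.OneCycleLR(opt, 4e-4, total_steps=1000 + 100,
                                      pct_start=0.05, cycle_momentum=False,
                                      anneal_strategy='linear')
lrs = []
for _ in range(1000):
    opt.step()
    sched.step()
    lrs.append(sched.get_last_lr()[0])

# LR rises for the first ~55 steps (5% of 1100), then decays linearly
print(max(lrs), lrs[-1])
```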
@@ -97,17 +97,22 @@ class Logger:
        self.scheduler = scheduler
        self.total_steps = 0
        self.running_loss = {}
        self.writer = None

    def _print_training_status(self):
        metrics_data = [self.running_loss[k]/SUM_FREQ for k in sorted(self.running_loss.keys())]
        training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_lr()[0])
        training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0])
        metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data)

        # print the training status
        print(training_str + metrics_str)

        for key in self.running_loss:
            self.running_loss[key] = 0.0
        if self.writer is None:
            self.writer = SummaryWriter()

        for k in self.running_loss:
            self.writer.add_scalar(k, self.running_loss[k]/SUM_FREQ, self.total_steps)
            self.running_loss[k] = 0.0

    def push(self, metrics):
        self.total_steps += 1
@@ -122,56 +127,95 @@ class Logger:
            self._print_training_status()
            self.running_loss = {}

    def write_dict(self, results):
        if self.writer is None:
            self.writer = SummaryWriter()

        for key in results:
            self.writer.add_scalar(key, results[key], self.total_steps)

    def close(self):
        self.writer.close()
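Because SummaryWriter() is constructed without arguments, the logged scalars land in the default ./runs directory, so training curves can be watched with:

```Shell
tensorboard --logdir runs
```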
def train(args):

    model = RAFT(args)
    model = nn.DataParallel(model)
    model = nn.DataParallel(RAFT(args), device_ids=args.gpus)
    print("Parameter Count: %d" % count_parameters(model))

    if args.restore_ckpt is not None:
        model.load_state_dict(torch.load(args.restore_ckpt))
        model.load_state_dict(torch.load(args.restore_ckpt), strict=False)

    model.cuda()
    model.train()

    if 'chairs' not in args.dataset:
    if args.stage != 'chairs':
        model.module.freeze_bn()

    train_loader = fetch_dataloader(args)
    train_loader = datasets.fetch_dataloader(args)
    optimizer, scheduler = fetch_optimizer(args, model)

    total_steps = 0
    scaler = GradScaler(enabled=args.mixed_precision)
    logger = Logger(model, scheduler)

    VAL_FREQ = 5000
    add_noise = True

    should_keep_training = True
    while should_keep_training:

        for i_batch, data_blob in enumerate(train_loader):
            optimizer.zero_grad()
            image1, image2, flow, valid = [x.cuda() for x in data_blob]

            optimizer.zero_grad()
            # show_image(image1[0])
            # show_image(image2[0])

            if args.add_noise:
                stdv = np.random.uniform(0.0, 5.0)
                image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
                image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0)

            flow_predictions = model(image1, image2, iters=args.iters)

            loss, metrics = sequence_loss(flow_predictions, flow, valid)
            loss.backward()
            scaler.scale(loss).backward()

            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            scheduler.step()
            total_steps += 1

            scaler.step(optimizer)
            scheduler.step()
            scaler.update()
            logger.push(metrics)

            if total_steps % VAL_FREQ == VAL_FREQ - 1:
                PATH = 'checkpoints/%d_%s.pth' % (total_steps+1, args.name)
                torch.save(model.state_dict(), PATH)

            if total_steps == args.num_steps:
                results = {}
                for val_dataset in args.validation:
                    if val_dataset == 'chairs':
                        results.update(evaluate.validate_chairs(model.module))
                    elif val_dataset == 'sintel':
                        results.update(evaluate.validate_sintel(model.module))
                    elif val_dataset == 'kitti':
                        results.update(evaluate.validate_kitti(model.module))

                logger.write_dict(results)

                model.train()
                if args.stage != 'chairs':
                    model.module.freeze_bn()

            total_steps += 1

            if total_steps > args.num_steps:
                should_keep_training = False
                break

    logger.close()
    PATH = 'checkpoints/%s.pth' % args.name
    torch.save(model.state_dict(), PATH)
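Because state_dict() is taken from the DataParallel wrapper, every key in the saved checkpoint carries a module. prefix, which is why evaluate.py also wraps RAFT in DataParallel before calling load_state_dict. If a checkpoint ever needs to go into a bare RAFT instance, the prefix can be stripped first (a generic sketch, not repository code):

```Python
import torch

def load_into_bare_model(model, ckpt_path):
    """ Load a DataParallel checkpoint into an unwrapped model. """
    state = torch.load(ckpt_path, map_location='cpu')
    state = {k.replace('module.', '', 1): v for k, v in state.items()}
    model.load_state_dict(state)
    return model
```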
@@ -180,21 +224,25 @@ def train(args):
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', default='bla', help="name your experiment")
    parser.add_argument('--dataset', help="which dataset to use for training")
    parser.add_argument('--name', default='raft', help="name your experiment")
    parser.add_argument('--stage', help="determines which dataset to use for training")
    parser.add_argument('--restore_ckpt', help="restore checkpoint")
    parser.add_argument('--small', action='store_true', help='use small model')
    parser.add_argument('--validation', type=str, nargs='+')

    parser.add_argument('--lr', type=float, default=0.00002)
    parser.add_argument('--num_steps', type=int, default=100000)
    parser.add_argument('--batch_size', type=int, default=6)
    parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512])
    parser.add_argument('--gpus', type=int, nargs='+', default=[0,1])
    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')

    parser.add_argument('--iters', type=int, default=12)
    parser.add_argument('--wdecay', type=float, default=.00005)
    parser.add_argument('--epsilon', type=float, default=1e-8)
    parser.add_argument('--clip', type=float, default=1.0)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--add_noise', action='store_true')
    args = parser.parse_args()

    torch.manual_seed(1234)
@@ -203,9 +251,4 @@ if __name__ == '__main__':
    if not os.path.isdir('checkpoints'):
        os.mkdir('checkpoints')

    # scale learning rate and batch size by number of GPUs
    num_gpus = torch.cuda.device_count()
    args.batch_size = args.batch_size * num_gpus
    args.lr = args.lr * num_gpus

    train(args)
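Putting the new flags together, a first-stage training run looks like the following (hyperparameter values are illustrative, not the paper's exact schedule):

```Shell
python train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 8 --lr 0.0004 --image_size 368 496 --wdecay 0.0001 --mixed_precision
```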