Compare commits
10 Commits
c86b3dc8f3
...
3fa0bb0a9c
Author | SHA1 | Date | |
---|---|---|---|
|
3fa0bb0a9c | ||
|
aac9dd5472 | ||
|
0d123fda29 | ||
|
e6e53c4e23 | ||
|
224320502d | ||
|
071a9c063c | ||
|
d3f3840186 | ||
|
25eb2ac723 | ||
|
13198c355d | ||
|
01ad964d94 |
23
README.md
23
README.md
@ -12,11 +12,7 @@ The code has been tested with PyTorch 1.6 and Cuda 10.1.
|
||||
```Shell
|
||||
conda create --name raft
|
||||
conda activate raft
|
||||
conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 -c pytorch
|
||||
conda install matplotlib
|
||||
conda install tensorboard
|
||||
conda install scipy
|
||||
conda install opencv
|
||||
conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch
|
||||
```
|
||||
|
||||
## Demos
|
||||
@ -24,21 +20,13 @@ Pretrained models can be downloaded by running
|
||||
```Shell
|
||||
./download_models.sh
|
||||
```
|
||||
or downloaded from [google drive](https://drive.google.com/file/d/10-BYgHqRNPGvmNUWr8razjb1xHu55pyA/view?usp=sharing)
|
||||
or downloaded from [google drive](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT?usp=sharing)
|
||||
|
||||
You can demo a trained model on a sequence of frames
|
||||
```Shell
|
||||
python demo.py --model=models/raft-things.pth --path=demo-frames
|
||||
```
|
||||
|
||||
## (Optional) Efficent Implementation
|
||||
You can optionally use our alternate (efficent) implementation by compiling the provided cuda extension
|
||||
```Shell
|
||||
cd alt_cuda_corr && python setup.py install && cd ..
|
||||
```
|
||||
and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag.Note, this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.
|
||||
|
||||
|
||||
## Required Data
|
||||
To evaluate/train RAFT, you will need to download the required datasets.
|
||||
* [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs)
|
||||
@ -83,3 +71,10 @@ If you have a RTX GPU, training can be accelerated using mixed precision. You ca
|
||||
```Shell
|
||||
./train_mixed.sh
|
||||
```
|
||||
|
||||
## (Optional) Efficent Implementation
|
||||
You can optionally use our alternate (efficent) implementation by compiling the provided cuda extension
|
||||
```Shell
|
||||
cd alt_cuda_corr && python setup.py install && cd ..
|
||||
```
|
||||
and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag Note, this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.
|
||||
|
36
core/corr.py
36
core/corr.py
@ -34,9 +34,9 @@ class CorrBlock:
|
||||
out_pyramid = []
|
||||
for i in range(self.num_levels):
|
||||
corr = self.corr_pyramid[i]
|
||||
dx = torch.linspace(-r, r, 2*r+1)
|
||||
dy = torch.linspace(-r, r, 2*r+1)
|
||||
delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
|
||||
dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
|
||||
dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
|
||||
delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
|
||||
|
||||
centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
|
||||
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
|
||||
@ -60,26 +60,6 @@ class CorrBlock:
|
||||
return corr / torch.sqrt(torch.tensor(dim).float())
|
||||
|
||||
|
||||
class CorrLayer(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, fmap1, fmap2, coords, r):
|
||||
fmap1 = fmap1.contiguous()
|
||||
fmap2 = fmap2.contiguous()
|
||||
coords = coords.contiguous()
|
||||
ctx.save_for_backward(fmap1, fmap2, coords)
|
||||
ctx.r = r
|
||||
corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r)
|
||||
return corr
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_corr):
|
||||
fmap1, fmap2, coords = ctx.saved_tensors
|
||||
grad_corr = grad_corr.contiguous()
|
||||
fmap1_grad, fmap2_grad, coords_grad = \
|
||||
correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r)
|
||||
return fmap1_grad, fmap2_grad, coords_grad, None
|
||||
|
||||
|
||||
class AlternateCorrBlock:
|
||||
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
|
||||
self.num_levels = num_levels
|
||||
@ -92,20 +72,20 @@ class AlternateCorrBlock:
|
||||
self.pyramid.append((fmap1, fmap2))
|
||||
|
||||
def __call__(self, coords):
|
||||
|
||||
coords = coords.permute(0, 2, 3, 1)
|
||||
B, H, W, _ = coords.shape
|
||||
dim = self.pyramid[0][0].shape[1]
|
||||
|
||||
corr_list = []
|
||||
for i in range(self.num_levels):
|
||||
r = self.radius
|
||||
fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1)
|
||||
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1)
|
||||
fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
|
||||
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
|
||||
|
||||
coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
|
||||
corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r)
|
||||
corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
|
||||
corr_list.append(corr.squeeze(1))
|
||||
|
||||
corr = torch.stack(corr_list, dim=1)
|
||||
corr = corr.reshape(B, -1, H, W)
|
||||
return corr / 16.0
|
||||
return corr / torch.sqrt(torch.tensor(dim).float())
|
||||
|
15
core/raft.py
15
core/raft.py
@ -38,11 +38,11 @@ class RAFT(nn.Module):
|
||||
args.corr_levels = 4
|
||||
args.corr_radius = 4
|
||||
|
||||
if 'dropout' not in args._get_kwargs():
|
||||
args.dropout = 0
|
||||
if 'dropout' not in self.args:
|
||||
self.args.dropout = 0
|
||||
|
||||
if 'alternate_corr' not in args._get_kwargs():
|
||||
args.alternate_corr = False
|
||||
if 'alternate_corr' not in self.args:
|
||||
self.args.alternate_corr = False
|
||||
|
||||
# feature network, context network, and update block
|
||||
if args.small:
|
||||
@ -55,7 +55,6 @@ class RAFT(nn.Module):
|
||||
self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
|
||||
self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
|
||||
|
||||
|
||||
def freeze_bn(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.BatchNorm2d):
|
||||
@ -64,8 +63,8 @@ class RAFT(nn.Module):
|
||||
def initialize_flow(self, img):
|
||||
""" Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
|
||||
N, C, H, W = img.shape
|
||||
coords0 = coords_grid(N, H//8, W//8).to(img.device)
|
||||
coords1 = coords_grid(N, H//8, W//8).to(img.device)
|
||||
coords0 = coords_grid(N, H//8, W//8, device=img.device)
|
||||
coords1 = coords_grid(N, H//8, W//8, device=img.device)
|
||||
|
||||
# optical flow computed as difference: flow = coords1 - coords0
|
||||
return coords0, coords1
|
||||
@ -103,7 +102,7 @@ class RAFT(nn.Module):
|
||||
fmap1 = fmap1.float()
|
||||
fmap2 = fmap2.float()
|
||||
if self.args.alternate_corr:
|
||||
corr_fn = CorrBlockAlternate(fmap1, fmap2, radius=self.args.corr_radius)
|
||||
corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
|
||||
else:
|
||||
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
|
||||
|
||||
|
@ -71,8 +71,8 @@ def bilinear_sampler(img, coords, mode='bilinear', mask=False):
|
||||
return img
|
||||
|
||||
|
||||
def coords_grid(batch, ht, wd):
|
||||
coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
|
||||
def coords_grid(batch, ht, wd, device):
|
||||
coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
|
||||
coords = torch.stack(coords[::-1], dim=0).float()
|
||||
return coords[None].repeat(batch, 1, 1, 1)
|
||||
|
||||
|
29
demo.py
29
demo.py
@ -20,21 +20,9 @@ DEVICE = 'cuda'
|
||||
def load_image(imfile):
|
||||
img = np.array(Image.open(imfile)).astype(np.uint8)
|
||||
img = torch.from_numpy(img).permute(2, 0, 1).float()
|
||||
return img
|
||||
return img[None].to(DEVICE)
|
||||
|
||||
|
||||
def load_image_list(image_files):
|
||||
images = []
|
||||
for imfile in sorted(image_files):
|
||||
images.append(load_image(imfile))
|
||||
|
||||
images = torch.stack(images, dim=0)
|
||||
images = images.to(DEVICE)
|
||||
|
||||
padder = InputPadder(images.shape)
|
||||
return padder.pad(images)[0]
|
||||
|
||||
|
||||
def viz(img, flo):
|
||||
img = img[0].permute(1,2,0).cpu().numpy()
|
||||
flo = flo[0].permute(1,2,0).cpu().numpy()
|
||||
@ -43,6 +31,10 @@ def viz(img, flo):
|
||||
flo = flow_viz.flow_to_image(flo)
|
||||
img_flo = np.concatenate([img, flo], axis=0)
|
||||
|
||||
# import matplotlib.pyplot as plt
|
||||
# plt.imshow(img_flo / 255.0)
|
||||
# plt.show()
|
||||
|
||||
cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
|
||||
cv2.waitKey()
|
||||
|
||||
@ -58,11 +50,14 @@ def demo(args):
|
||||
with torch.no_grad():
|
||||
images = glob.glob(os.path.join(args.path, '*.png')) + \
|
||||
glob.glob(os.path.join(args.path, '*.jpg'))
|
||||
|
||||
images = sorted(images)
|
||||
for imfile1, imfile2 in zip(images[:-1], images[1:]):
|
||||
image1 = load_image(imfile1)
|
||||
image2 = load_image(imfile2)
|
||||
|
||||
images = load_image_list(images)
|
||||
for i in range(images.shape[0]-1):
|
||||
image1 = images[i,None]
|
||||
image2 = images[i+1,None]
|
||||
padder = InputPadder(image1.shape)
|
||||
image1, image2 = padder.pad(image1, image2)
|
||||
|
||||
flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
|
||||
viz(image1, flow_up)
|
||||
|
@ -1,3 +1,3 @@
|
||||
#!/bin/bash
|
||||
wget https://www.dropbox.com/s/npt24nvhoojdr0n/models.zip
|
||||
wget https://dl.dropboxusercontent.com/s/4j4z58wuv8o0mfz/models.zip
|
||||
unzip models.zip
|
||||
|
18
train.py
18
train.py
@ -44,7 +44,7 @@ SUM_FREQ = 100
|
||||
VAL_FREQ = 5000
|
||||
|
||||
|
||||
def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW):
|
||||
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):
|
||||
""" Loss function defined over sequence of flow predictions """
|
||||
|
||||
n_predictions = len(flow_preds)
|
||||
@ -55,7 +55,7 @@ def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW):
|
||||
valid = (valid >= 0.5) & (mag < max_flow)
|
||||
|
||||
for i in range(n_predictions):
|
||||
i_weight = 0.8**(n_predictions - i - 1)
|
||||
i_weight = gamma**(n_predictions - i - 1)
|
||||
i_loss = (flow_preds[i] - flow_gt).abs()
|
||||
flow_loss += i_weight * (valid[:, None] * i_loss).mean()
|
||||
|
||||
@ -71,16 +71,11 @@ def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW):
|
||||
|
||||
return flow_loss, metrics
|
||||
|
||||
def show_image(img):
|
||||
img = img.permute(1,2,0).cpu().numpy()
|
||||
plt.imshow(img/255.0)
|
||||
plt.show()
|
||||
# cv2.imshow('image', img/255.0)
|
||||
# cv2.waitKey()
|
||||
|
||||
def count_parameters(model):
|
||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
|
||||
def fetch_optimizer(args, model):
|
||||
""" Create the optimizer and learning rate scheduler """
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)
|
||||
@ -169,9 +164,6 @@ def train(args):
|
||||
optimizer.zero_grad()
|
||||
image1, image2, flow, valid = [x.cuda() for x in data_blob]
|
||||
|
||||
# show_image(image1[0])
|
||||
# show_image(image2[0])
|
||||
|
||||
if args.add_noise:
|
||||
stdv = np.random.uniform(0.0, 5.0)
|
||||
image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
|
||||
@ -179,7 +171,7 @@ def train(args):
|
||||
|
||||
flow_predictions = model(image1, image2, iters=args.iters)
|
||||
|
||||
loss, metrics = sequence_loss(flow_predictions, flow, valid)
|
||||
loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma)
|
||||
scaler.scale(loss).backward()
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
|
||||
@ -188,7 +180,6 @@ def train(args):
|
||||
scheduler.step()
|
||||
scaler.update()
|
||||
|
||||
|
||||
logger.push(metrics)
|
||||
|
||||
if total_steps % VAL_FREQ == VAL_FREQ - 1:
|
||||
@ -243,6 +234,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--epsilon', type=float, default=1e-8)
|
||||
parser.add_argument('--clip', type=float, default=1.0)
|
||||
parser.add_argument('--dropout', type=float, default=0.0)
|
||||
parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting')
|
||||
parser.add_argument('--add_noise', action='store_true')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -2,5 +2,5 @@
|
||||
mkdir -p checkpoints
|
||||
python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision
|
||||
python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision
|
||||
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --mixed_precision
|
||||
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --mixed_precision
|
||||
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 --mixed_precision
|
||||
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 --mixed_precision
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
mkdir -p checkpoints
|
||||
python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 12 --lr 0.0004 --image_size 368 496 --wdecay 0.0001
|
||||
python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 10 --lr 0.0004 --image_size 368 496 --wdecay 0.0001
|
||||
python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001
|
||||
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001
|
||||
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001
|
||||
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 --gamma=0.85
|
||||
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85
|
||||
|
Loading…
Reference in New Issue
Block a user