fixed problems with variational dropout

commit 3fac6470f4 (parent dd91321527)
README.md
@@ -4,6 +4,8 @@ This repository contains the source code for our paper:
 [RAFT: Recurrent All Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)<br/>
 Zachary Teed and Jia Deng<br/>
 
+<img src="RAFT.png">
+
 ## Requirements
 Our code was tested using PyTorch 1.3.1 and Python 3. The following additional packages need to be installed
 
@@ -84,11 +86,11 @@ python train.py --name=kitti_ft --image_size 288 896 --dataset=kitti --num_steps
 You can evaluate a model on Sintel and KITTI by running
 
 ```Shell
-python evaluate.py --model=checkpoints/chairs+things.pth
+python evaluate.py --model=models/chairs+things.pth
 ```
 
 or the small model by including the `small` flag
 
 ```Shell
-python evaluate.py --model=checkpoints/small.pth --small
+python evaluate.py --model=models/small.pth --small
 ```
 
@@ -133,8 +133,20 @@ class SmallUpdateBlock(nn.Module):
         self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
         self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
 
+        self.drop_inp = VariationalHidDropout(dropout=args.dropout)
+        self.drop_net = VariationalHidDropout(dropout=args.dropout)
+
+    def reset_mask(self, net, inp):
+        self.drop_inp.reset_mask(inp)
+        self.drop_net.reset_mask(net)
+
     def forward(self, net, inp, corr, flow):
         motion_features = self.encoder(flow, corr)
+
+        if self.training:
+            net = self.drop_net(net)
+            inp = self.drop_inp(inp)
+
         inp = torch.cat([inp, motion_features], dim=1)
         net = self.gru(net, inp)
         delta_flow = self.flow_head(net)
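`VariationalHidDropout` is not defined anywhere in this diff. As a rough sketch of the technique (the class name and `reset_mask` interface come from the diff; the body below is an assumption, not the repository's implementation): variational dropout samples one mask when `reset_mask` is called and reuses that same mask on every recurrent iteration, rather than resampling per call the way `nn.Dropout` does.

```python
import torch
import torch.nn as nn

class VariationalHidDropout(nn.Module):
    """Sketch of variational (locked) dropout: one mask per sequence,
    reused across all recurrent iterations until reset_mask is called."""
    def __init__(self, dropout=0.0):
        super().__init__()
        self.dropout = dropout
        self.mask = None

    def reset_mask(self, x):
        # Sample a fresh Bernoulli keep-mask shaped like x, scaled so the
        # expected activation is unchanged (inverted dropout).
        keep = 1.0 - self.dropout
        self.mask = torch.zeros_like(x).bernoulli_(keep) / keep

    def forward(self, x):
        if not self.training or self.dropout == 0:
            return x
        # Same mask every call: the GRU sees a consistent dropped
        # subnetwork across all update iterations.
        return self.mask * x
```

Reusing one mask across the iterative updates is the standard recipe for dropout in recurrent networks; resampling per iteration would inject fresh noise into every GRU step.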
@@ -157,12 +169,12 @@ class BasicUpdateBlock(nn.Module):
 
     def forward(self, net, inp, corr, flow):
         motion_features = self.encoder(flow, corr)
-        inp = torch.cat([inp, motion_features], dim=1)
 
         if self.training:
             net = self.drop_net(net)
             inp = self.drop_inp(inp)
 
+        inp = torch.cat([inp, motion_features], dim=1)
         net = self.gru(net, inp)
         delta_flow = self.flow_head(net)
 
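This `BasicUpdateBlock` reordering is the substance of the fix: the mask sampled by `reset_mask` is shaped like the raw context features `inp`, so dropout must run before `motion_features` are concatenated. A small illustration, reusing the `VariationalHidDropout` sketch above (all shapes hypothetical, chosen only for the example):

```python
import torch

inp = torch.randn(1, 128, 46, 62)   # context features (hypothetical shape)
mot = torch.randn(1, 128, 46, 62)   # motion features

drop = VariationalHidDropout(dropout=0.1)
drop.reset_mask(inp)                 # mask is (1, 128, 46, 62)

# Old order: concatenate first, then drop.
cat = torch.cat([inp, mot], dim=1)   # (1, 256, 46, 62)
# drop(cat) would multiply a 128-channel mask into a 256-channel tensor,
# which is not broadcastable and raises a RuntimeError.

# New order: drop first, then concatenate.
inp = drop(inp)                      # mask and tensor shapes match
cat = torch.cat([inp, mot], dim=1)
```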
@@ -26,7 +26,7 @@ class RAFT(nn.Module):
             args.corr_levels = 4
             args.corr_radius = 4
 
-        if 'dropout' not in args._get_kwargs():
+        if not hasattr(args, 'dropout'):
            args.dropout = 0
 
        # feature network, context network, and update block
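The `hasattr` change fixes a real bug, not style: `argparse.Namespace._get_kwargs()` returns a list of `(name, value)` tuples, so the string `'dropout'` is never a member of it and `args.dropout` was unconditionally reset to 0, silently disabling dropout. The attribute check does what was intended (and also works when `args` is not a `Namespace`):

```python
from argparse import Namespace

args = Namespace(dropout=0.1)

print(args._get_kwargs())                    # [('dropout', 0.1)] -- tuples, not names
print('dropout' not in args._get_kwargs())   # True: a str never equals a tuple,
                                             # so dropout was always zeroed out
print(not hasattr(args, 'dropout'))          # False: the attribute check works
```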
train.py
@@ -21,7 +21,7 @@ import datasets
 
 # exclude extremly large displacements
 MAX_FLOW = 1000
-SUM_FREQ = 100
+SUM_FREQ = 200
 VAL_FREQ = 5000
 
 
@@ -56,7 +56,7 @@ def sequence_loss(flow_preds, flow_gt, valid):
 
 
 def fetch_dataloader(args):
-    """ Create the data loader for the corresponding trainign set """
+    """ Create the data loader for the corresponding training set """
 
     if args.dataset == 'chairs':
         train_dataset = datasets.FlyingChairs(args, image_size=args.image_size)
@@ -86,7 +86,7 @@ def fetch_optimizer(args, model):
     optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)
 
     scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps,
-        pct_start=0.2, cycle_momentum=False, anneal_strategy='linear', final_div_factor=1.0)
+        pct_start=0.2, cycle_momentum=False, anneal_strategy='linear')
 
     return optimizer, scheduler
 
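Dropping `final_div_factor=1.0` changes where the schedule ends. In `OneCycleLR`, the starting rate is `max_lr / div_factor` (default 25) and the final rate is the starting rate divided by `final_div_factor` (default 1e4): with `final_div_factor=1.0` the rate only annealed back down to `max_lr / 25`, while the default lets it decay to effectively zero by the last step. A runnable sketch with made-up numbers (the toy model and values are illustrative, not the paper's settings):

```python
import torch
from torch import optim

model = torch.nn.Linear(2, 2)                  # toy model, illustrative only
opt = optim.AdamW(model.parameters(), lr=4e-4)
sched = optim.lr_scheduler.OneCycleLR(
    opt, max_lr=4e-4, total_steps=100,
    pct_start=0.2, cycle_momentum=False, anneal_strategy='linear')

for _ in range(100):
    opt.step()
    sched.step()

# starting lr = 4e-4 / 25; final lr = starting lr / 1e4 ~= 1.6e-9
print(opt.param_groups[0]['lr'])
```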