fixed problems with variational dropout

commit 3fac6470f4
parent dd91321527
Author: Zach Teed
Date:   2020-05-25 14:30:45 -04:00

5 changed files with 22 additions and 8 deletions

RAFT.png: new binary file, 199 KiB (the image referenced by the README change below).

@@ -4,6 +4,8 @@ This repository contains the source code for our paper:
 [RAFT: Recurrent All Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)<br/>
 Zachary Teed and Jia Deng<br/>
 
+<img src="RAFT.png">
+
 ## Requirements
 Our code was tested using PyTorch 1.3.1 and Python 3. The following additional packages need to be installed
@@ -84,11 +86,11 @@ python train.py --name=kitti_ft --image_size 288 896 --dataset=kitti --num_steps
 You can evaluate a model on Sintel and KITTI by running
 ```Shell
-python evaluate.py --model=checkpoints/chairs+things.pth
+python evaluate.py --model=models/chairs+things.pth
 ```
 or the small model by including the `small` flag
 ```Shell
-python evaluate.py --model=checkpoints/small.pth --small
+python evaluate.py --model=models/small.pth --small
 ```


@@ -133,8 +133,20 @@ class SmallUpdateBlock(nn.Module):
         self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
         self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
+
+        self.drop_inp = VariationalHidDropout(dropout=args.dropout)
+        self.drop_net = VariationalHidDropout(dropout=args.dropout)
+
+    def reset_mask(self, net, inp):
+        self.drop_inp.reset_mask(inp)
+        self.drop_net.reset_mask(net)
 
     def forward(self, net, inp, corr, flow):
         motion_features = self.encoder(flow, corr)
+
+        if self.training:
+            net = self.drop_net(net)
+            inp = self.drop_inp(inp)
+
         inp = torch.cat([inp, motion_features], dim=1)
         net = self.gru(net, inp)
         delta_flow = self.flow_head(net)

@@ -157,12 +169,12 @@ class BasicUpdateBlock(nn.Module):
     def forward(self, net, inp, corr, flow):
         motion_features = self.encoder(flow, corr)
-        inp = torch.cat([inp, motion_features], dim=1)
         if self.training:
             net = self.drop_net(net)
             inp = self.drop_inp(inp)
+        inp = torch.cat([inp, motion_features], dim=1)
         net = self.gru(net, inp)
         delta_flow = self.flow_head(net)
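
For context on this fix: variational ("locked") dropout samples one dropout mask and reuses it for every recurrent iteration instead of resampling at each step. That is why the update blocks gain a `reset_mask()` hook, and why the `torch.cat` in `forward()` now runs after dropout: the mask is created from the raw `net`/`inp` tensors, so it only has the right shape if dropout is applied before the motion features are concatenated. The repository's `VariationalHidDropout` implementation is not part of this diff; the sketch below is an illustrative stand-in (the class name, tensor shapes, and call sites here are assumptions, not the repo's code).

```python
import torch
import torch.nn as nn

class LockedDropout2d(nn.Module):
    """Illustrative variational ("locked") dropout for NCHW feature maps.

    reset_mask() samples a single Bernoulli mask; forward() then reuses that
    same mask on every call, so the same units are dropped at each GRU
    iteration rather than being resampled per step.
    """
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = dropout
        self.mask = None

    def reset_mask(self, x):
        keep = 1.0 - self.dropout
        # Inverted dropout: scale by 1/keep so the expected activation is unchanged.
        self.mask = torch.bernoulli(torch.full_like(x, keep)) / keep
        return self.mask

    def forward(self, x):
        if not self.training or self.dropout == 0:
            return x
        return self.mask * x

# Hypothetical usage mirroring the update-block pattern above: the mask is reset
# once per forward pass of the model (that call site is not shown in this diff),
# then reused across all iterative refinement steps.
drop_net = LockedDropout2d(dropout=0.2)
drop_net.train()
net = torch.randn(1, 96, 46, 62)
drop_net.reset_mask(net)
for _ in range(12):
    net = torch.tanh(net)   # stand-in for the GRU update
    net = drop_net(net)     # same mask applied at every iteration
```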


@@ -26,7 +26,7 @@ class RAFT(nn.Module):
         args.corr_levels = 4
         args.corr_radius = 4
-        if 'dropout' not in args._get_kwargs():
+        if not hasattr(args, 'dropout'):
             args.dropout = 0
         # feature network, context network, and update block
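
The motivation for this change appears to be that `argparse.Namespace._get_kwargs()` returns a list of `(name, value)` pairs, so testing a bare string against it never matches: the old branch always ran and silently reset `args.dropout` to 0 even when a dropout value was supplied. `hasattr()` checks the attribute itself and also works when `args` is not an `argparse.Namespace`. A quick illustration:

```python
import argparse

args = argparse.Namespace(dropout=0.1)

print(args._get_kwargs())                   # [('dropout', 0.1)] -- (name, value) pairs
print('dropout' not in args._get_kwargs())  # True, so dropout would be overwritten with 0
print(not hasattr(args, 'dropout'))         # False, so the supplied value is kept
```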


@@ -21,7 +21,7 @@ import datasets
 # exclude extremly large displacements
 MAX_FLOW = 1000
-SUM_FREQ = 100
+SUM_FREQ = 200
 VAL_FREQ = 5000

@@ -56,7 +56,7 @@ def sequence_loss(flow_preds, flow_gt, valid):
 def fetch_dataloader(args):
-    """ Create the data loader for the corresponding trainign set """
+    """ Create the data loader for the corresponding training set """
     if args.dataset == 'chairs':
         train_dataset = datasets.FlyingChairs(args, image_size=args.image_size)

@@ -86,7 +86,7 @@ def fetch_optimizer(args, model):
     optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)
     scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps,
-        pct_start=0.2, cycle_momentum=False, anneal_strategy='linear', final_div_factor=1.0)
+        pct_start=0.2, cycle_momentum=False, anneal_strategy='linear')
     return optimizer, scheduler
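
A note on the scheduler change (this is an observation about PyTorch defaults, not something stated in the commit): `OneCycleLR` ends its cycle at `initial_lr / final_div_factor`, where `initial_lr = max_lr / div_factor` (25 by default). With `final_div_factor=1.0` the linear anneal only returns to the warm-up learning rate, whereas the default `final_div_factor` of 1e4 lets the rate decay to effectively zero by the last step. A small self-contained check (learning rate and step count are illustrative, not the paper's settings):

```python
import torch
from torch import optim

model = torch.nn.Linear(8, 2)                        # toy model, just to build an optimizer
optimizer = optim.AdamW(model.parameters(), lr=4e-4)

# Same keyword arguments as the training script, but a tiny illustrative step count.
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=4e-4, total_steps=100,
    pct_start=0.2, cycle_momentum=False, anneal_strategy='linear')

lrs = []
for _ in range(100):
    optimizer.step()
    scheduler.step()
    lrs.append(optimizer.param_groups[0]['lr'])

# Peak reaches max_lr; with the default final_div_factor the last value is ~max_lr / (25 * 1e4).
print(f"peak lr = {max(lrs):.2e}, final lr = {lrs[-1]:.2e}")
```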