diff --git a/RAFT.png b/RAFT.png
new file mode 100644
index 0000000..a387fe2
Binary files /dev/null and b/RAFT.png differ
diff --git a/README.md b/README.md
index 4dc4038..a7c85af 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,8 @@ This repository contains the source code for our paper:
 [RAFT: Recurrent All Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)
 Zachary Teed and Jia Deng
+
+
 
 ## Requirements
 Our code was tested using PyTorch 1.3.1 and Python 3. The following additional packages need to be installed
 
@@ -84,11 +86,11 @@ python train.py --name=kitti_ft --image_size 288 896 --dataset=kitti --num_steps
 
 You can evaluate a model on Sintel and KITTI by running
 ```Shell
-python evaluate.py --model=checkpoints/chairs+things.pth
+python evaluate.py --model=models/chairs+things.pth
 ```
 
 or the small model by including the `small` flag
 ```Shell
-python evaluate.py --model=checkpoints/small.pth --small
+python evaluate.py --model=models/small.pth --small
 ```
diff --git a/core/modules/update.py b/core/modules/update.py
index d9133dd..a1f362c 100644
--- a/core/modules/update.py
+++ b/core/modules/update.py
@@ -133,8 +133,20 @@ class SmallUpdateBlock(nn.Module):
         self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
         self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
 
+        self.drop_inp = VariationalHidDropout(dropout=args.dropout)
+        self.drop_net = VariationalHidDropout(dropout=args.dropout)
+
+    def reset_mask(self, net, inp):
+        self.drop_inp.reset_mask(inp)
+        self.drop_net.reset_mask(net)
+
     def forward(self, net, inp, corr, flow):
         motion_features = self.encoder(flow, corr)
+
+        if self.training:
+            net = self.drop_net(net)
+            inp = self.drop_inp(inp)
+
         inp = torch.cat([inp, motion_features], dim=1)
         net = self.gru(net, inp)
         delta_flow = self.flow_head(net)
@@ -157,12 +169,12 @@ class BasicUpdateBlock(nn.Module):
 
 
     def forward(self, net, inp, corr, flow):
         motion_features = self.encoder(flow, corr)
-        inp = torch.cat([inp, motion_features], dim=1)
 
         if self.training:
             net = self.drop_net(net)
             inp = self.drop_inp(inp)
-
+
+        inp = torch.cat([inp, motion_features], dim=1)
         net = self.gru(net, inp)
         delta_flow = self.flow_head(net)
diff --git a/core/raft.py b/core/raft.py
index 22a587d..e14a54a 100644
--- a/core/raft.py
+++ b/core/raft.py
@@ -26,7 +26,7 @@ class RAFT(nn.Module):
             args.corr_levels = 4
             args.corr_radius = 4
 
-        if 'dropout' not in args._get_kwargs():
+        if not hasattr(args, 'dropout'):
             args.dropout = 0
 
         # feature network, context network, and update block
diff --git a/train.py b/train.py
index 2767acf..a6d75ad 100755
--- a/train.py
+++ b/train.py
@@ -21,7 +21,7 @@ import datasets
 
 # exclude extremly large displacements
 MAX_FLOW = 1000
-SUM_FREQ = 100
+SUM_FREQ = 200
 VAL_FREQ = 5000
 
 
@@ -56,7 +56,7 @@ def sequence_loss(flow_preds, flow_gt, valid):
 
 
 def fetch_dataloader(args):
-    """ Create the data loader for the corresponding trainign set """
+    """ Create the data loader for the corresponding training set """
 
     if args.dataset == 'chairs':
         train_dataset = datasets.FlyingChairs(args, image_size=args.image_size)
@@ -86,7 +86,7 @@ def fetch_optimizer(args, model):
     optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)
 
     scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps,
-        pct_start=0.2, cycle_momentum=False, anneal_strategy='linear', final_div_factor=1.0)
+        pct_start=0.2, cycle_momentum=False, anneal_strategy='linear')
 
     return optimizer, scheduler