diff --git a/README.md b/README.md index aad688b..94afe44 100644 --- a/README.md +++ b/README.md @@ -24,21 +24,13 @@ Pretrained models can be downloaded by running ```Shell ./download_models.sh ``` -or downloaded from [google drive](https://drive.google.com/file/d/10-BYgHqRNPGvmNUWr8razjb1xHu55pyA/view?usp=sharing) +or downloaded from [google drive](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT?usp=sharing) You can demo a trained model on a sequence of frames ```Shell python demo.py --model=models/raft-things.pth --path=demo-frames ``` -## (Optional) Efficent Implementation -You can optionally use our alternate (efficent) implementation by compiling the provided cuda extension -```Shell -cd alt_cuda_corr && python setup.py install && cd .. -``` -and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag.Note, this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass. - - ## Required Data To evaluate/train RAFT, you will need to download the required datasets. * [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs) @@ -83,3 +75,10 @@ If you have a RTX GPU, training can be accelerated using mixed precision. You ca ```Shell ./train_mixed.sh ``` + +## (Optional) Efficient Implementation +You can optionally use our alternate (efficient) implementation by compiling the provided CUDA extension +```Shell +cd alt_cuda_corr && python setup.py install && cd .. +``` +and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag. Note that this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass. 
diff --git a/download_models.sh b/download_models.sh index 52a5eba..7b6ed7e 100755 --- a/download_models.sh +++ b/download_models.sh @@ -1,3 +1,3 @@ #!/bin/bash -wget https://www.dropbox.com/s/npt24nvhoojdr0n/models.zip +wget https://www.dropbox.com/s/4j4z58wuv8o0mfz/models.zip unzip models.zip diff --git a/train.py b/train.py index 1314141..3075730 100644 --- a/train.py +++ b/train.py @@ -44,7 +44,7 @@ SUM_FREQ = 100 VAL_FREQ = 5000 -def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW): +def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW): """ Loss function defined over sequence of flow predictions """ n_predictions = len(flow_preds) @@ -55,7 +55,7 @@ def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW): valid = (valid >= 0.5) & (mag < max_flow) for i in range(n_predictions): - i_weight = 0.8**(n_predictions - i - 1) + i_weight = gamma**(n_predictions - i - 1) i_loss = (flow_preds[i] - flow_gt).abs() flow_loss += i_weight * (valid[:, None] * i_loss).mean() @@ -71,16 +71,11 @@ def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW): return flow_loss, metrics -def show_image(img): - img = img.permute(1,2,0).cpu().numpy() - plt.imshow(img/255.0) - plt.show() - # cv2.imshow('image', img/255.0) - # cv2.waitKey() def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) + def fetch_optimizer(args, model): """ Create the optimizer and learning rate scheduler """ optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon) @@ -169,9 +164,6 @@ def train(args): optimizer.zero_grad() image1, image2, flow, valid = [x.cuda() for x in data_blob] - # show_image(image1[0]) - # show_image(image2[0]) - if args.add_noise: stdv = np.random.uniform(0.0, 5.0) image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0) @@ -179,7 +171,7 @@ def train(args): flow_predictions = model(image1, image2, iters=args.iters) - loss, metrics = 
sequence_loss(flow_predictions, flow, valid) + loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma) scaler.scale(loss).backward() scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) @@ -188,7 +180,6 @@ def train(args): scheduler.step() scaler.update() - logger.push(metrics) if total_steps % VAL_FREQ == VAL_FREQ - 1: @@ -243,6 +234,7 @@ if __name__ == '__main__': parser.add_argument('--epsilon', type=float, default=1e-8) parser.add_argument('--clip', type=float, default=1.0) parser.add_argument('--dropout', type=float, default=0.0) + parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting') parser.add_argument('--add_noise', action='store_true') args = parser.parse_args() diff --git a/train_mixed.sh b/train_mixed.sh index ae92aac..d9b979f 100755 --- a/train_mixed.sh +++ b/train_mixed.sh @@ -2,5 +2,5 @@ mkdir -p checkpoints python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision -python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --mixed_precision -python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --mixed_precision +python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 
--gamma=0.85 --mixed_precision +python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 --mixed_precision diff --git a/train_standard.sh b/train_standard.sh index 19b5809..b487c6d 100755 --- a/train_standard.sh +++ b/train_standard.sh @@ -2,5 +2,5 @@ mkdir -p checkpoints python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 12 --lr 0.0004 --image_size 368 496 --wdecay 0.0001 python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001 -python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 -python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 \ No newline at end of file +python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 +python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 \ No newline at end of file