diff --git a/README.md b/README.md index 87dd170..330b256 100644 --- a/README.md +++ b/README.md @@ -8,11 +8,11 @@ Zachary Teed and Jia Deng
## Requirements -The code has been tested with PyTorch 1.5.1 and PyTorch Nightly. If you want to train with mixed precision, you will have to install the nightly build. +The code has been tested with PyTorch 1.6 and Cuda 10.1. ```Shell conda create --name raft conda activate raft -conda install pytorch torchvision cudatoolkit=10.1 -c pytorch-nightly +conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 -c pytorch conda install matplotlib conda install tensorboard conda install scipy @@ -67,8 +67,7 @@ python evaluate.py --model=models/raft-things.pth --dataset=sintel ``` ## Training -Training code will be made available in the next few days - +``` diff --git a/core/datasets.py b/core/datasets.py index c5f0a36..3411fda 100644 --- a/core/datasets.py +++ b/core/datasets.py @@ -200,7 +200,7 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): """ Create the data loader for the corresponding trainign set """ if args.stage == 'chairs': - aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 1.0, 'do_flip': True} + aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} train_dataset = FlyingChairs(aug_params, split='training') elif args.stage == 'things': @@ -210,14 +210,14 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): train_dataset = clean_dataset + final_dataset elif args.stage == 'sintel': - aug_params = {'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True} + aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} things = FlyingThings3D(aug_params, dstype='frames_cleanpass') sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') sintel_final = MpiSintel(aug_params, split='training', dstype='final') if TRAIN_DS == 'C+T+K+S+H': - kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True}) - hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.5, 'do_flip': True}) + kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) + hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things elif TRAIN_DS == 'C+T+K/S': @@ -225,7 +225,7 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): elif args.stage == 'kitti': aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} - train_dataset = KITTI(args, image_size=args.image_size, is_val=False) + train_dataset = KITTI(aug_params, split='training') train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, pin_memory=False, shuffle=True, num_workers=4, drop_last=True) diff --git a/train.py b/train.py index c2c58c5..1314141 100644 --- a/train.py +++ b/train.py @@ -39,7 +39,7 @@ except: # exclude extremly large displacements -MAX_FLOW = 500 +MAX_FLOW = 400 SUM_FREQ = 100 VAL_FREQ = 5000 @@ -181,13 +181,14 @@ def train(args): loss, metrics = sequence_loss(flow_predictions, flow, valid) scaler.scale(loss).backward() - - scaler.unscale_(optimizer) + scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) scaler.step(optimizer) scheduler.step() scaler.update() + + logger.push(metrics) if total_steps % VAL_FREQ == VAL_FREQ - 1: diff --git a/train_mixed.sh b/train_mixed.sh new file mode 100755 index 0000000..ae92aac --- /dev/null +++ b/train_mixed.sh @@ -0,0 +1,6 @@ +#!/bin/bash +mkdir -p checkpoints +python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision +python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision +python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --mixed_precision +python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --mixed_precision diff --git a/train_standard.sh b/train_standard.sh new file mode 100755 index 0000000..19b5809 --- /dev/null +++ b/train_standard.sh @@ -0,0 +1,6 @@ +#!/bin/bash +mkdir -p checkpoints +python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 12 --lr 0.0004 --image_size 368 496 --wdecay 0.0001 +python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001 +python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 +python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 \ No newline at end of file