added training code

This commit is contained in:
Zach Teed 2020-07-30 21:25:36 -06:00
parent dc370f877b
commit a1d8344039
5 changed files with 25 additions and 13 deletions

View File

@ -8,11 +8,11 @@ Zachary Teed and Jia Deng<br/>
<img src="RAFT.png"> <img src="RAFT.png">
## Requirements ## Requirements
The code has been tested with PyTorch 1.5.1 and PyTorch Nightly. If you want to train with mixed precision, you will have to install the nightly build. The code has been tested with PyTorch 1.6 and Cuda 10.1.
```Shell ```Shell
conda create --name raft conda create --name raft
conda activate raft conda activate raft
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch-nightly conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 -c pytorch
conda install matplotlib conda install matplotlib
conda install tensorboard conda install tensorboard
conda install scipy conda install scipy
@ -67,8 +67,7 @@ python evaluate.py --model=models/raft-things.pth --dataset=sintel
``` ```
## Training ## Training
Training code will be made available in the next few days We used the following training schedule in our paper (2 GPUs). Training logs will be written to the `runs` directory, which can be visualized using TensorBoard
<!-- We used the following training schedule in our paper (note: we use 2 GPUs for training). Training logs will be written to the `runs` which can be visualized using tensorboard
```Shell ```Shell
./train_standard.sh ./train_standard.sh
``` ```
@ -76,4 +75,4 @@ Training code will be made available in the next few days
If you have an RTX GPU, training can be accelerated using mixed precision. You can expect similar results in this setting (1 GPU) If you have an RTX GPU, training can be accelerated using mixed precision. You can expect similar results in this setting (1 GPU)
```Shell ```Shell
./train_mixed.sh ./train_mixed.sh
``` --> ```

View File

@ -200,7 +200,7 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
""" Create the data loader for the corresponding trainign set """ """ Create the data loader for the corresponding trainign set """
if args.stage == 'chairs': if args.stage == 'chairs':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 1.0, 'do_flip': True} aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
train_dataset = FlyingChairs(aug_params, split='training') train_dataset = FlyingChairs(aug_params, split='training')
elif args.stage == 'things': elif args.stage == 'things':
@ -210,14 +210,14 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
train_dataset = clean_dataset + final_dataset train_dataset = clean_dataset + final_dataset
elif args.stage == 'sintel': elif args.stage == 'sintel':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True} aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
things = FlyingThings3D(aug_params, dstype='frames_cleanpass') things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
sintel_final = MpiSintel(aug_params, split='training', dstype='final') sintel_final = MpiSintel(aug_params, split='training', dstype='final')
if TRAIN_DS == 'C+T+K+S+H': if TRAIN_DS == 'C+T+K+S+H':
kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True}) kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.5, 'do_flip': True}) hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
elif TRAIN_DS == 'C+T+K/S': elif TRAIN_DS == 'C+T+K/S':
@ -225,7 +225,7 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
elif args.stage == 'kitti': elif args.stage == 'kitti':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
train_dataset = KITTI(args, image_size=args.image_size, is_val=False) train_dataset = KITTI(aug_params, split='training')
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
pin_memory=False, shuffle=True, num_workers=4, drop_last=True) pin_memory=False, shuffle=True, num_workers=4, drop_last=True)

View File

@ -39,7 +39,7 @@ except:
# exclude extremly large displacements # exclude extremly large displacements
MAX_FLOW = 500 MAX_FLOW = 400
SUM_FREQ = 100 SUM_FREQ = 100
VAL_FREQ = 5000 VAL_FREQ = 5000
@ -181,13 +181,14 @@ def train(args):
loss, metrics = sequence_loss(flow_predictions, flow, valid) loss, metrics = sequence_loss(flow_predictions, flow, valid)
scaler.scale(loss).backward() scaler.scale(loss).backward()
scaler.unscale_(optimizer)
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
scaler.step(optimizer) scaler.step(optimizer)
scheduler.step() scheduler.step()
scaler.update() scaler.update()
logger.push(metrics) logger.push(metrics)
if total_steps % VAL_FREQ == VAL_FREQ - 1: if total_steps % VAL_FREQ == VAL_FREQ - 1:

6
train_mixed.sh Executable file
View File

@ -0,0 +1,6 @@
#!/bin/bash
# Mixed-precision training schedule on a single GPU (--gpus 0).
#
# Runs four train.py stages in order: chairs -> things -> sintel -> kitti.
# Every stage after the first passes --restore_ckpt with the checkpoint path
# named after the previous stage (checkpoints/<previous --name>.pth).

# Stop at the first failing command so a later stage never starts from a
# missing or stale checkpoint left over from an earlier run.
set -e

mkdir -p checkpoints
python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision
python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --mixed_precision
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --mixed_precision

6
train_standard.sh Executable file
View File

@ -0,0 +1,6 @@
#!/bin/bash
# Standard (full-precision) training schedule on two GPUs (--gpus 0 1).
#
# Runs four train.py stages in order: chairs -> things -> sintel -> kitti.
# Every stage after the first passes --restore_ckpt with the checkpoint path
# named after the previous stage (checkpoints/<previous --name>.pth).

# Stop at the first failing command so a later stage never starts from a
# missing or stale checkpoint left over from an earlier run.
set -e

mkdir -p checkpoints
python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 12 --lr 0.0004 --image_size 368 496 --wdecay 0.0001
python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001