added training code
This commit is contained in:
parent
dc370f877b
commit
a1d8344039
@ -8,11 +8,11 @@ Zachary Teed and Jia Deng<br/>
|
|||||||
<img src="RAFT.png">
|
<img src="RAFT.png">
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
The code has been tested with PyTorch 1.5.1 and PyTorch Nightly. If you want to train with mixed precision, you will have to install the nightly build.
|
The code has been tested with PyTorch 1.6 and Cuda 10.1.
|
||||||
```Shell
|
```Shell
|
||||||
conda create --name raft
|
conda create --name raft
|
||||||
conda activate raft
|
conda activate raft
|
||||||
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch-nightly
|
conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 -c pytorch
|
||||||
conda install matplotlib
|
conda install matplotlib
|
||||||
conda install tensorboard
|
conda install tensorboard
|
||||||
conda install scipy
|
conda install scipy
|
||||||
@ -67,8 +67,7 @@ python evaluate.py --model=models/raft-things.pth --dataset=sintel
|
|||||||
```
|
```
|
||||||
|
|
||||||
## Training
|
## Training
|
||||||
Training code will be made available in the next few days
|
We used the following training schedule in our paper (2 GPUs). Training logs will be written to the `runs` which can be visualized using tensorboard
|
||||||
<!-- We used the following training schedule in our paper (note: we use 2 GPUs for training). Training logs will be written to the `runs` which can be visualized using tensorboard
|
|
||||||
```Shell
|
```Shell
|
||||||
./train_standard.sh
|
./train_standard.sh
|
||||||
```
|
```
|
||||||
@ -76,4 +75,4 @@ Training code will be made available in the next few days
|
|||||||
If you have a RTX GPU, training can be accelerated using mixed precision. You can expect similiar results in this setting (1 GPU)
|
If you have a RTX GPU, training can be accelerated using mixed precision. You can expect similiar results in this setting (1 GPU)
|
||||||
```Shell
|
```Shell
|
||||||
./train_mixed.sh
|
./train_mixed.sh
|
||||||
``` -->
|
```
|
||||||
|
@ -200,7 +200,7 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
|
|||||||
""" Create the data loader for the corresponding trainign set """
|
""" Create the data loader for the corresponding trainign set """
|
||||||
|
|
||||||
if args.stage == 'chairs':
|
if args.stage == 'chairs':
|
||||||
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 1.0, 'do_flip': True}
|
aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
|
||||||
train_dataset = FlyingChairs(aug_params, split='training')
|
train_dataset = FlyingChairs(aug_params, split='training')
|
||||||
|
|
||||||
elif args.stage == 'things':
|
elif args.stage == 'things':
|
||||||
@ -210,14 +210,14 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
|
|||||||
train_dataset = clean_dataset + final_dataset
|
train_dataset = clean_dataset + final_dataset
|
||||||
|
|
||||||
elif args.stage == 'sintel':
|
elif args.stage == 'sintel':
|
||||||
aug_params = {'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True}
|
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
|
||||||
things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
|
things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
|
||||||
sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
|
sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
|
||||||
sintel_final = MpiSintel(aug_params, split='training', dstype='final')
|
sintel_final = MpiSintel(aug_params, split='training', dstype='final')
|
||||||
|
|
||||||
if TRAIN_DS == 'C+T+K+S+H':
|
if TRAIN_DS == 'C+T+K+S+H':
|
||||||
kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True})
|
kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
|
||||||
hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.5, 'do_flip': True})
|
hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
|
||||||
train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
|
train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
|
||||||
|
|
||||||
elif TRAIN_DS == 'C+T+K/S':
|
elif TRAIN_DS == 'C+T+K/S':
|
||||||
@ -225,7 +225,7 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
|
|||||||
|
|
||||||
elif args.stage == 'kitti':
|
elif args.stage == 'kitti':
|
||||||
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
|
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
|
||||||
train_dataset = KITTI(args, image_size=args.image_size, is_val=False)
|
train_dataset = KITTI(aug_params, split='training')
|
||||||
|
|
||||||
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
|
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
|
||||||
pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
|
pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
|
||||||
|
7
train.py
7
train.py
@ -39,7 +39,7 @@ except:
|
|||||||
|
|
||||||
|
|
||||||
# exclude extremly large displacements
|
# exclude extremly large displacements
|
||||||
MAX_FLOW = 500
|
MAX_FLOW = 400
|
||||||
SUM_FREQ = 100
|
SUM_FREQ = 100
|
||||||
VAL_FREQ = 5000
|
VAL_FREQ = 5000
|
||||||
|
|
||||||
@ -181,13 +181,14 @@ def train(args):
|
|||||||
|
|
||||||
loss, metrics = sequence_loss(flow_predictions, flow, valid)
|
loss, metrics = sequence_loss(flow_predictions, flow, valid)
|
||||||
scaler.scale(loss).backward()
|
scaler.scale(loss).backward()
|
||||||
|
scaler.unscale_(optimizer)
|
||||||
scaler.unscale_(optimizer)
|
|
||||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
|
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
|
||||||
|
|
||||||
scaler.step(optimizer)
|
scaler.step(optimizer)
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
scaler.update()
|
scaler.update()
|
||||||
|
|
||||||
|
|
||||||
logger.push(metrics)
|
logger.push(metrics)
|
||||||
|
|
||||||
if total_steps % VAL_FREQ == VAL_FREQ - 1:
|
if total_steps % VAL_FREQ == VAL_FREQ - 1:
|
||||||
|
6
train_mixed.sh
Executable file
6
train_mixed.sh
Executable file
@ -0,0 +1,6 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
mkdir -p checkpoints
|
||||||
|
python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision
|
||||||
|
python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision
|
||||||
|
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --mixed_precision
|
||||||
|
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --mixed_precision
|
6
train_standard.sh
Executable file
6
train_standard.sh
Executable file
@ -0,0 +1,6 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
mkdir -p checkpoints
|
||||||
|
python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 12 --lr 0.0004 --image_size 368 496 --wdecay 0.0001
|
||||||
|
python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001
|
||||||
|
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001
|
||||||
|
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001
|
Loading…
Reference in New Issue
Block a user