added small 1M paramter model
This commit is contained in:
parent
c86b3dc8f3
commit
01ad964d94
17
README.md
17
README.md
@ -24,21 +24,13 @@ Pretrained models can be downloaded by running
|
||||
```Shell
|
||||
./download_models.sh
|
||||
```
|
||||
or downloaded from [google drive](https://drive.google.com/file/d/10-BYgHqRNPGvmNUWr8razjb1xHu55pyA/view?usp=sharing)
|
||||
or downloaded from [google drive](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT?usp=sharing)
|
||||
|
||||
You can demo a trained model on a sequence of frames
|
||||
```Shell
|
||||
python demo.py --model=models/raft-things.pth --path=demo-frames
|
||||
```
|
||||
|
||||
## (Optional) Efficent Implementation
|
||||
You can optionally use our alternate (efficent) implementation by compiling the provided cuda extension
|
||||
```Shell
|
||||
cd alt_cuda_corr && python setup.py install && cd ..
|
||||
```
|
||||
and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag.Note, this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.
|
||||
|
||||
|
||||
## Required Data
|
||||
To evaluate/train RAFT, you will need to download the required datasets.
|
||||
* [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs)
|
||||
@ -83,3 +75,10 @@ If you have a RTX GPU, training can be accelerated using mixed precision. You ca
|
||||
```Shell
|
||||
./train_mixed.sh
|
||||
```
|
||||
|
||||
## (Optional) Efficent Implementation
|
||||
You can optionally use our alternate (efficent) implementation by compiling the provided cuda extension
|
||||
```Shell
|
||||
cd alt_cuda_corr && python setup.py install && cd ..
|
||||
```
|
||||
and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag Note, this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.
|
||||
|
@ -1,3 +1,3 @@
|
||||
#!/bin/bash
|
||||
wget https://www.dropbox.com/s/npt24nvhoojdr0n/models.zip
|
||||
wget https://www.dropbox.com/s/4j4z58wuv8o0mfz/models.zip
|
||||
unzip models.zip
|
||||
|
18
train.py
18
train.py
@ -44,7 +44,7 @@ SUM_FREQ = 100
|
||||
VAL_FREQ = 5000
|
||||
|
||||
|
||||
def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW):
|
||||
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):
|
||||
""" Loss function defined over sequence of flow predictions """
|
||||
|
||||
n_predictions = len(flow_preds)
|
||||
@ -55,7 +55,7 @@ def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW):
|
||||
valid = (valid >= 0.5) & (mag < max_flow)
|
||||
|
||||
for i in range(n_predictions):
|
||||
i_weight = 0.8**(n_predictions - i - 1)
|
||||
i_weight = gamma**(n_predictions - i - 1)
|
||||
i_loss = (flow_preds[i] - flow_gt).abs()
|
||||
flow_loss += i_weight * (valid[:, None] * i_loss).mean()
|
||||
|
||||
@ -71,16 +71,11 @@ def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW):
|
||||
|
||||
return flow_loss, metrics
|
||||
|
||||
def show_image(img):
|
||||
img = img.permute(1,2,0).cpu().numpy()
|
||||
plt.imshow(img/255.0)
|
||||
plt.show()
|
||||
# cv2.imshow('image', img/255.0)
|
||||
# cv2.waitKey()
|
||||
|
||||
def count_parameters(model):
|
||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
|
||||
def fetch_optimizer(args, model):
|
||||
""" Create the optimizer and learning rate scheduler """
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)
|
||||
@ -169,9 +164,6 @@ def train(args):
|
||||
optimizer.zero_grad()
|
||||
image1, image2, flow, valid = [x.cuda() for x in data_blob]
|
||||
|
||||
# show_image(image1[0])
|
||||
# show_image(image2[0])
|
||||
|
||||
if args.add_noise:
|
||||
stdv = np.random.uniform(0.0, 5.0)
|
||||
image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
|
||||
@ -179,7 +171,7 @@ def train(args):
|
||||
|
||||
flow_predictions = model(image1, image2, iters=args.iters)
|
||||
|
||||
loss, metrics = sequence_loss(flow_predictions, flow, valid)
|
||||
loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma)
|
||||
scaler.scale(loss).backward()
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
|
||||
@ -188,7 +180,6 @@ def train(args):
|
||||
scheduler.step()
|
||||
scaler.update()
|
||||
|
||||
|
||||
logger.push(metrics)
|
||||
|
||||
if total_steps % VAL_FREQ == VAL_FREQ - 1:
|
||||
@ -243,6 +234,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--epsilon', type=float, default=1e-8)
|
||||
parser.add_argument('--clip', type=float, default=1.0)
|
||||
parser.add_argument('--dropout', type=float, default=0.0)
|
||||
parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting')
|
||||
parser.add_argument('--add_noise', action='store_true')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -2,5 +2,5 @@
|
||||
mkdir -p checkpoints
|
||||
python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision
|
||||
python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision
|
||||
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --mixed_precision
|
||||
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --mixed_precision
|
||||
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 --mixed_precision
|
||||
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 --mixed_precision
|
||||
|
@ -2,5 +2,5 @@
|
||||
mkdir -p checkpoints
|
||||
python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 12 --lr 0.0004 --image_size 368 496 --wdecay 0.0001
|
||||
python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001
|
||||
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001
|
||||
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001
|
||||
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 --gamma=0.85
|
||||
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85
|
Loading…
Reference in New Issue
Block a user