Merge branch 'main' of github.com:JunkyByte/co-tracker

commit 03f3c41e07

README.md (41 changed lines)
@@ -1,6 +1,6 @@
 # CoTracker: It is Better to Track Together
 
-**[Meta AI Research, FAIR](https://ai.facebook.com/research/)**; **[University of Oxford, VGG](https://www.robots.ox.ac.uk/~vgg/)**
+**[Meta AI Research, GenAI](https://ai.facebook.com/research/)**; **[University of Oxford, VGG](https://www.robots.ox.ac.uk/~vgg/)**
 
 [Nikita Karaev](https://nikitakaraevv.github.io/), [Ignacio Rocco](https://www.irocco.info/), [Benjamin Graham](https://ai.facebook.com/people/benjamin-graham/), [Natalia Neverova](https://nneverova.github.io/), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/), [Christian Rupprecht](https://chrirupp.github.io/)
 
@@ -15,7 +15,7 @@
 **CoTracker** is a fast transformer-based model that can track any point in a video. It brings to tracking some of the benefits of Optical Flow.
 
 CoTracker can track:
-- **Every pixel** within a video
+- **Every pixel** in a video
 - Points sampled on a regular grid on any video frame
 - Manually selected points
 
@@ -26,16 +26,30 @@ Try these tracking modes for yourself with our [Colab demo](https://colab.resear
 ## Installation Instructions
 Ensure you have both PyTorch and TorchVision installed on your system. Follow the instructions [here](https://pytorch.org/get-started/locally/) for the installation. We strongly recommend installing both PyTorch and TorchVision with CUDA support.
 
-## Steps to Install CoTracker and its dependencies:
+### Pretrained models via PyTorch Hub
+The easiest way to use CoTracker is to load a pretrained model from torch.hub:
+```
+pip install einops timm tqdm
+```
+```
+import torch
+import timm
+import einops
+import tqdm
+
+cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker_w8")
+```
+Another option is to install it from this GitHub repo. That's the best way if you need to run our demo or evaluate / train CoTracker:
+### Steps to Install CoTracker and its dependencies:
 ```
 git clone https://github.com/facebookresearch/co-tracker
 cd co-tracker
 pip install -e .
 pip install opencv-python einops timm matplotlib moviepy flow_vis
 ```
 
 
-## Model Weights Download:
+### Download Model Weights:
 ```
 mkdir checkpoints
 cd checkpoints
@@ -60,24 +74,26 @@ To reproduce the results presented in the paper, download the following datasets
 
 And install the necessary dependencies:
 ```
-pip install hydra-core==1.1.0 mediapy tensorboard
+pip install hydra-core==1.1.0 mediapy
 ```
 Then, execute the following command to evaluate on BADJA:
 ```
 python ./cotracker/evaluation/evaluate.py --config-name eval_badja exp_dir=./eval_outputs dataset_root=your/badja/path
 ```
+By default, evaluation will be slow since it is done for one target point at a time, which ensures robustness and fairness, as described in the paper.
 
 ## Training
 To train the CoTracker as described in our paper, you first need to generate annotations for [Google Kubric](https://github.com/google-research/kubric) MOVI-f dataset. Instructions for annotation generation can be found [here](https://github.com/deepmind/tapnet).
 
 Once you have the annotated dataset, you need to make sure you followed the steps for evaluation setup and install the training dependencies:
 ```
-pip install pytorch_lightning==1.6.0
+pip install pytorch_lightning==1.6.0 tensorboard
 ```
-launch training on Kubric. Our model was trained using 32 GPUs, and you can adjust the parameters to best suit your hardware setup.
+Now you can launch training on Kubric. Our model was trained for 50000 iterations on 32 GPUs (4 nodes with 8 GPUs).
+Modify *dataset_root* and *ckpt_path* accordingly before running this command:
 ```
 python train.py --batch_size 1 --num_workers 28 \
---num_steps 50000 --ckpt_path ./ --model_name cotracker \
+--num_steps 50000 --ckpt_path ./ --dataset_root ./datasets --model_name cotracker \
 --save_freq 200 --sequence_len 24 --eval_datasets tapvid_davis_first badja \
 --traj_per_sample 256 --sliding_window_len 8 --updateformer_space_depth 6 --updateformer_time_depth 6 \
 --save_every_n_epoch 10 --evaluate_every_n_epoch 10 --model_stride 4
@@ -86,13 +102,16 @@ python train.py --batch_size 1 --num_workers 28 \
 ## License
 The majority of CoTracker is licensed under CC-BY-NC, however portions of the project are available under separate license terms: Particle Video Revisited is licensed under the MIT license, TAP-Vid is licensed under the Apache 2.0 license.
 
+## Acknowledgments
+We would like to thank [PIPs](https://github.com/aharley/pips) and [TAP-Vid](https://github.com/deepmind/tapnet) for publicly releasing their code and data. We also want to thank [Luke Melas-Kyriazi](https://lukemelas.github.io/) for proofreading the paper, [Jianyuan Wang](https://jytime.github.io/), [Roman Shapovalov](https://shapovalov.ro/) and [Adam W. Harley](https://adamharley.com/) for the insightful discussions.
+
 ## Citing CoTracker
 If you find our repository useful, please consider giving it a star ⭐ and citing our paper in your work:
 ```
 @article{karaev2023cotracker,
   title={CoTracker: It is Better to Track Together},
   author={Nikita Karaev and Ignacio Rocco and Benjamin Graham and Natalia Neverova and Andrea Vedaldi and Christian Rupprecht},
-  journal={arxiv},
+  journal={arXiv:2307.07635},
   year={2023}
 }
 ```
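For orientation, the torch.hub snippet added to the README above returns a CoTrackerPredictor module. Below is a minimal usage sketch, not taken from the repository: the (B, T, 3, H, W) video layout and the (pred_tracks, pred_visibility) return pair follow demo.py and the predictor hunks later in this diff, while the random video, the grid_size value and the printed shape are illustrative assumptions only.

```
# Illustrative sketch only; tensor values and keyword usage are assumptions, not repo code.
import torch

cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker_w8")

# Placeholder clip: batch of 1, 24 frames, 3 channels, 384x512 pixels.
video = torch.randn(1, 24, 3, 384, 512)

# Mirror the CUDA fallback this commit adds to demo.py.
if torch.cuda.is_available():
    cotracker = cotracker.cuda()
    video = video.cuda()

# grid_size > 0 asks the predictor to track a regular grid of points (here 10x10).
pred_tracks, pred_visibility = cotracker(video, grid_size=10)
print(pred_tracks.shape)  # expected (1, T, N, 2): per-frame (x, y) for each tracked point
```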
@@ -185,7 +185,11 @@ class Evaluator:
             if not all(gotit):
                 print("batch is None")
                 continue
-            dataclass_to_cuda_(sample)
+            if torch.cuda.is_available():
+                dataclass_to_cuda_(sample)
+                device = torch.device("cuda")
+            else:
+                device = torch.device("cpu")
 
             if (
                 not train_mode
@@ -205,7 +209,7 @@ class Evaluator:
                     queries[:, :, 1],
                 ],
                 dim=2,
-            )
+            ).to(device)
         else:
             queries = torch.cat(
                 [
@@ -213,7 +217,7 @@ class Evaluator:
                     sample.trajectory[:, 0],
                 ],
                 dim=2,
-            )
+            ).to(device)
 
         pred_tracks = model(sample.video, queries)
         if "strided" in dataset_name:
@@ -102,6 +102,8 @@ def run_eval(cfg: DefaultConfig):
         single_point=cfg.single_point,
         n_iters=cfg.n_iters,
     )
+    if torch.cuda.is_available():
+        predictor.model = predictor.model.cuda()
 
     # Setting the random seeds
     torch.manual_seed(cfg.seed)
@@ -12,6 +12,8 @@ from cotracker.models.core.cotracker.cotracker import CoTracker
 def build_cotracker(
     checkpoint: str,
 ):
+    if checkpoint is None:
+        return build_cotracker_stride_4_wind_8()
     model_name = checkpoint.split("/")[-1].split(".")[0]
     if model_name == "cotracker_stride_4_wind_8":
         return build_cotracker_stride_4_wind_8(checkpoint=checkpoint)
@@ -25,11 +25,11 @@ from cotracker.models.core.embeddings import (
 torch.manual_seed(0)
 
 
-def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device='cuda'):
+def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device="cuda"):
     if grid_size == 1:
-        return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2])[
+        return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2], device=device)[
             None, None
-        ].to(device)
+        ]
 
     grid_y, grid_x = meshgrid2d(
         1, grid_size, grid_size, stack=False, norm=False, device=device
@@ -29,11 +29,10 @@ class EvaluationPredictor(torch.nn.Module):
         self.n_iters = n_iters
 
         self.model = cotracker_model
-        self.model.to("cuda")
         self.model.eval()
 
     def forward(self, video, queries):
-        queries = queries.clone().cuda()
+        queries = queries.clone()
         B, T, C, H, W = video.shape
         B, N, D = queries.shape
 
@@ -42,14 +41,16 @@ class EvaluationPredictor(torch.nn.Module):
 
         rgbs = video.reshape(B * T, C, H, W)
         rgbs = F.interpolate(rgbs, tuple(self.interp_shape), mode="bilinear")
-        rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]).cuda()
+        rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
 
+        device = rgbs.device
+
         queries[:, :, 1] *= self.interp_shape[1] / W
         queries[:, :, 2] *= self.interp_shape[0] / H
 
         if self.single_point:
-            traj_e = torch.zeros((B, T, N, 2)).cuda()
-            vis_e = torch.zeros((B, T, N)).cuda()
+            traj_e = torch.zeros((B, T, N, 2), device=device)
+            vis_e = torch.zeros((B, T, N), device=device)
             for pind in range((N)):
                 query = queries[:, pind : pind + 1]
 
@@ -60,8 +61,10 @@ class EvaluationPredictor(torch.nn.Module):
                 vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
         else:
             if self.grid_size > 0:
-                xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
-                xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda() #
+                xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device)
+                xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(
+                    device
+                ) #
                 queries = torch.cat([queries, xy], dim=1) #
 
             traj_e, __, vis_e, __ = self.model(
@@ -91,8 +94,8 @@ class EvaluationPredictor(torch.nn.Module):
             query = torch.cat([query, xy_target], dim=1).to(device) #
 
             if self.grid_size > 0:
-                xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
-                xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda() #
+                xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device)
+                xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
                 query = torch.cat([query, xy], dim=1).to(device) #
             # crop the video to start from the queried frame
             query[0, 0, 0] = 0
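The predictor hunks above all apply the same refactor: rather than hard-coding `.cuda()`, buffers are allocated on whatever device the input tensors already live on, so evaluation also runs on CPU-only machines. A generic sketch of that pattern with hypothetical names (not code from this commit):

```
import torch

def allocate_track_buffer(video: torch.Tensor, n_points: int) -> torch.Tensor:
    # Hypothetical helper: follow the caller's device instead of assuming CUDA.
    device = video.device
    B, T, _, _, _ = video.shape
    # Same (B, T, N, 2) layout as the traj_e buffer in the hunk above.
    return torch.zeros((B, T, n_points, 2), device=device)

# Works unchanged on CPU and, when available, on GPU:
cpu_tracks = allocate_track_buffer(torch.zeros(1, 8, 3, 64, 64), n_points=16)
if torch.cuda.is_available():
    gpu_tracks = allocate_track_buffer(torch.zeros(1, 8, 3, 64, 64, device="cuda"), n_points=16)
```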
@@ -25,8 +25,6 @@ class CoTrackerPredictor(torch.nn.Module):
         model = build_cotracker(checkpoint)
 
         self.model = model
-        self.device = device or 'cuda'
-        self.model.to(self.device)
         self.model.eval()
 
     @torch.no_grad()
@@ -73,7 +71,7 @@ class CoTrackerPredictor(torch.nn.Module):
         grid_width = W // grid_step
         grid_height = H // grid_step
         tracks = visibilities = None
-        grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(self.device)
+        grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
         grid_pts[0, :, 0] = grid_query_frame
         for offset in tqdm(range(grid_step * grid_step)):
             ox = offset % grid_step
@@ -108,10 +106,8 @@ class CoTrackerPredictor(torch.nn.Module):
         assert B == 1
 
         video = video.reshape(B * T, C, H, W)
-        video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").to(self.device)
-        video = video.reshape(
-            B, T, 3, self.interp_shape[0], self.interp_shape[1]
-        ).to(self.device)
+        video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear")
+        video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
 
         if queries is not None:
             queries = queries.clone()
@@ -120,7 +116,7 @@ class CoTrackerPredictor(torch.nn.Module):
             queries[:, :, 1] *= self.interp_shape[1] / W
             queries[:, :, 2] *= self.interp_shape[0] / H
         elif grid_size > 0:
-            grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=self.device)
+            grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
             if segm_mask is not None:
                 segm_mask = F.interpolate(
                     segm_mask, tuple(self.interp_shape), mode="nearest"
@@ -14,7 +14,6 @@ from matplotlib import cm
 import torch.nn.functional as F
 import torchvision.transforms as transforms
 from moviepy.editor import ImageSequenceClip
-from torch.utils.tensorboard import SummaryWriter
 import matplotlib.pyplot as plt
 
 
@@ -67,7 +66,7 @@ class Visualizer:
         gt_tracks: torch.Tensor = None,  # (B,T,N,2)
         segm_mask: torch.Tensor = None,  # (B,1,H,W)
         filename: str = "video",
-        writer: SummaryWriter = None,
+        writer=None,  # tensorboard Summary Writer, used for visualization during training
         step: int = 0,
         query_frame: int = 0,
         save_video: bool = True,

demo.py (12 changed lines)
@@ -32,11 +32,6 @@ if __name__ == "__main__":
         default="./checkpoints/cotracker_stride_4_wind_8.pth",
         help="cotracker model",
     )
-    parser.add_argument(
-        "--device",
-        default="cuda",
-        help="Device to use for inference",
-    )
     parser.add_argument("--grid_size", type=int, default=0, help="Regular grid size")
     parser.add_argument(
         "--grid_query_frame",
@@ -59,7 +54,12 @@ if __name__ == "__main__":
         segm_mask = np.array(Image.open(os.path.join(args.mask_path)))
         segm_mask = torch.from_numpy(segm_mask)[None, None]
 
-    model = CoTrackerPredictor(checkpoint=args.checkpoint, device=args.device)
+    model = CoTrackerPredictor(checkpoint=args.checkpoint)
+    if torch.cuda.is_available():
+        model = model.cuda()
+        video = video.cuda()
+    else:
+        print("CUDA is not available!")
 
     pred_tracks, pred_visibility = model(
         video,

hubconf.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+dependencies = ["torch", "einops", "timm", "tqdm"]
+
+_COTRACKER_URL = (
+    "https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth"
+)
+
+
+def _make_cotracker_predictor(*, pretrained: bool = True, **kwargs):
+    from cotracker.predictor import CoTrackerPredictor
+
+    predictor = CoTrackerPredictor(checkpoint=None)
+    if pretrained:
+        state_dict = torch.hub.load_state_dict_from_url(
+            _COTRACKER_URL, map_location="cpu"
+        )
+        predictor.model.load_state_dict(state_dict)
+    return predictor
+
+
+def cotracker_w8(*, pretrained: bool = True, **kwargs):
+    """
+    CoTracker model with stride 4 and window length 8. (The main model from the paper)
+    """
+    return _make_cotracker_predictor(pretrained=pretrained, **kwargs)
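The new hubconf.py is what makes the `torch.hub.load("facebookresearch/co-tracker", "cotracker_w8")` call in the README work. With the repository checked out, the same entrypoint can also be exercised locally; a small sketch, assuming it is run from the repo root (`source="local"` is standard torch.hub, while `pretrained` is the keyword defined by cotracker_w8 above):

```
import torch

# Load the entrypoint defined in ./hubconf.py instead of fetching the repo from GitHub.
predictor = torch.hub.load(".", "cotracker_w8", source="local", pretrained=True)

# pretrained=False skips the checkpoint download and keeps randomly initialized weights.
untrained = torch.hub.load(".", "cotracker_w8", source="local", pretrained=False)
print(type(predictor).__name__)  # CoTrackerPredictor
```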
File diff suppressed because one or more lines are too long

train.py (111 changed lines)
@@ -36,21 +36,6 @@ from cotracker.datasets.utils import collate_fn, collate_fn_train, dataclass_to_
 from cotracker.models.core.cotracker.losses import sequence_loss, balanced_ce_loss
 
 
-# define the handler function
-# for training on a slurm cluster
-def sig_handler(signum, frame):
-    print("caught signal", signum)
-    print(socket.gethostname(), "USR1 signal caught.")
-    # do other stuff to cleanup here
-    print("requeuing job " + os.environ["SLURM_JOB_ID"])
-    os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"])
-    sys.exit(-1)
-
-
-def term_handler(signum, frame):
-    print("bypassing sigterm", flush=True)
-
-
 def fetch_optimizer(args, model):
     """Create the optimizer and learning rate scheduler"""
     optimizer = optim.AdamW(
@@ -153,6 +138,8 @@ def run_test_eval(evaluator, model, dataloaders, writer, step):
         single_point=False,
         n_iters=6,
     )
+    if torch.cuda.is_available():
+        predictor.model = predictor.model.cuda()
 
     metrics = evaluator.evaluate_sequence(
         model=predictor,
@@ -302,9 +289,7 @@ class Lite(LightningLite):
             eval_dataloaders.append(("fastcapture", eval_dataloader_fastcapture))
 
         if "tapvid_davis_first" in args.eval_datasets:
-            data_root = os.path.join(
-                args.dataset_root, "/tapvid_davis/tapvid_davis.pkl"
-            )
+            data_root = os.path.join(args.dataset_root, "tapvid_davis/tapvid_davis.pkl")
             eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root)
             eval_dataloader_tapvid_davis = torch.utils.data.DataLoader(
                 eval_dataset,
@@ -551,17 +536,15 @@ class Lite(LightningLite):
 
 
 if __name__ == "__main__":
-    signal.signal(signal.SIGUSR1, sig_handler)
-    signal.signal(signal.SIGTERM, term_handler)
     parser = argparse.ArgumentParser()
     parser.add_argument("--model_name", default="cotracker", help="model name")
-    parser.add_argument("--restore_ckpt", help="restore checkpoint")
-    parser.add_argument("--ckpt_path", help="restore checkpoint")
+    parser.add_argument("--restore_ckpt", help="path to restore a checkpoint")
+    parser.add_argument("--ckpt_path", help="path to save checkpoints")
     parser.add_argument(
         "--batch_size", type=int, default=4, help="batch size used during training."
     )
     parser.add_argument(
-        "--num_workers", type=int, default=6, help="left right consistency loss"
+        "--num_workers", type=int, default=6, help="number of dataloader workers"
     )
 
     parser.add_argument(
@@ -578,20 +561,34 @@ if __name__ == "__main__":
         "--evaluate_every_n_epoch",
         type=int,
         default=1,
-        help="number of flow-field updates during validation forward pass",
+        help="evaluate during training after every n epochs, after every epoch by default",
     )
     parser.add_argument(
         "--save_every_n_epoch",
         type=int,
         default=1,
-        help="number of flow-field updates during validation forward pass",
+        help="save checkpoints during training after every n epochs, after every epoch by default",
     )
     parser.add_argument(
-        "--validate_at_start", action="store_true", help="use mixed precision"
+        "--validate_at_start",
+        action="store_true",
+        help="whether to run evaluation before training starts",
+    )
+    parser.add_argument(
+        "--save_freq",
+        type=int,
+        default=100,
+        help="frequency of trajectory visualization during training",
+    )
+    parser.add_argument(
+        "--traj_per_sample",
+        type=int,
+        default=768,
+        help="the number of trajectories to sample for training",
+    )
+    parser.add_argument(
+        "--dataset_root", type=str, help="path to all the datasets (train and eval)"
     )
-    parser.add_argument("--save_freq", type=int, default=100, help="save_freq")
-    parser.add_argument("--traj_per_sample", type=int, default=768, help="save_freq")
-    parser.add_argument("--dataset_root", type=str, help="path lo all the datasets")
 
     parser.add_argument(
         "--train_iters",
@@ -605,49 +602,75 @@ if __name__ == "__main__":
     parser.add_argument(
         "--eval_datasets",
         nargs="+",
-        default=["things", "badja", "fastcapture"],
-        help="eval datasets.",
+        default=["things", "badja"],
+        help="what datasets to use for evaluation",
     )
 
     parser.add_argument(
-        "--remove_space_attn", action="store_true", help="use mixed precision"
+        "--remove_space_attn",
+        action="store_true",
+        help="remove space attention from CoTracker",
     )
     parser.add_argument(
-        "--dont_use_augs", action="store_true", help="use mixed precision"
+        "--dont_use_augs",
+        action="store_true",
+        help="don't apply augmentations during training",
     )
     parser.add_argument(
-        "--sample_vis_1st_frame", action="store_true", help="use mixed precision"
+        "--sample_vis_1st_frame",
+        action="store_true",
+        help="only sample trajectories with points visible on the first frame",
    )
     parser.add_argument(
-        "--sliding_window_len", type=int, default=8, help="use mixed precision"
+        "--sliding_window_len",
+        type=int,
+        default=8,
+        help="length of the CoTracker sliding window",
     )
     parser.add_argument(
-        "--updateformer_hidden_size", type=int, default=384, help="use mixed precision"
+        "--updateformer_hidden_size",
+        type=int,
+        default=384,
+        help="hidden dimension of the CoTracker transformer model",
     )
     parser.add_argument(
-        "--updateformer_num_heads", type=int, default=8, help="use mixed precision"
+        "--updateformer_num_heads",
+        type=int,
+        default=8,
+        help="number of heads of the CoTracker transformer model",
     )
     parser.add_argument(
-        "--updateformer_space_depth", type=int, default=12, help="use mixed precision"
+        "--updateformer_space_depth",
+        type=int,
+        default=12,
+        help="number of group attention layers in the CoTracker transformer model",
     )
     parser.add_argument(
-        "--updateformer_time_depth", type=int, default=12, help="use mixed precision"
+        "--updateformer_time_depth",
+        type=int,
+        default=12,
+        help="number of time attention layers in the CoTracker transformer model",
     )
     parser.add_argument(
-        "--model_stride", type=int, default=8, help="use mixed precision"
+        "--model_stride",
+        type=int,
+        default=8,
+        help="stride of the CoTracker feature network",
     )
     parser.add_argument(
         "--crop_size",
         type=int,
         nargs="+",
         default=[384, 512],
-        help="use mixed precision",
+        help="crop videos to this resolution during training",
     )
     parser.add_argument(
-        "--eval_max_seq_len", type=int, default=1000, help="use mixed precision"
+        "--eval_max_seq_len",
+        type=int,
+        default=1000,
+        help="maximum length of evaluation videos",
     )
     args = parser.parse_args()
 
     logging.basicConfig(
         level=logging.INFO,
         format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
@@ -661,5 +684,5 @@ if __name__ == "__main__":
         devices="auto",
         accelerator="gpu",
         precision=32,
-        num_nodes=4,
+        # num_nodes=4,
     ).run(args)