Merge branch 'main' of github.com:JunkyByte/co-tracker

This commit is contained in:
JunkyByte 2023-07-25 16:23:37 +02:00
commit 03f3c41e07
12 changed files with 236 additions and 138 deletions

View File

@ -1,6 +1,6 @@
# CoTracker: It is Better to Track Together
**[Meta AI Research, FAIR](https://ai.facebook.com/research/)**; **[University of Oxford, VGG](https://www.robots.ox.ac.uk/~vgg/)**
**[Meta AI Research, GenAI](https://ai.facebook.com/research/)**; **[University of Oxford, VGG](https://www.robots.ox.ac.uk/~vgg/)**
[Nikita Karaev](https://nikitakaraevv.github.io/), [Ignacio Rocco](https://www.irocco.info/), [Benjamin Graham](https://ai.facebook.com/people/benjamin-graham/), [Natalia Neverova](https://nneverova.github.io/), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/), [Christian Rupprecht](https://chrirupp.github.io/)
@ -15,7 +15,7 @@
**CoTracker** is a fast transformer-based model that can track any point in a video. It brings to tracking some of the benefits of Optical Flow.
CoTracker can track:
- **Every pixel** within a video
- **Every pixel** in a video
- Points sampled on a regular grid on any video frame
- Manually selected points
@ -26,16 +26,30 @@ Try these tracking modes for yourself with our [Colab demo](https://colab.resear
## Installation Instructions
Ensure you have both PyTorch and TorchVision installed on your system. Follow the instructions [here](https://pytorch.org/get-started/locally/) for the installation. We strongly recommend installing both PyTorch and TorchVision with CUDA support.
## Steps to Install CoTracker and its dependencies:
### Pretrained models via PyTorch Hub
The easiest way to use CoTracker is to load a pretrained model from torch.hub:
```
pip install einops timm tqdm
```
```
import torch
import timm
import einops
import tqdm
cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker_w8")
```
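The hub entrypoint returns a ready-to-use predictor. A minimal usage sketch (the tensor shape and the `grid_size` argument follow the demo in this repo; the random tensor is only a stand-in for a real video):
```
import torch

cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker_w8")
if torch.cuda.is_available():
    cotracker = cotracker.cuda()

# Placeholder clip of shape (B, T, C, H, W); replace it with a real video tensor.
video = torch.randn(1, 24, 3, 384, 512)
if torch.cuda.is_available():
    video = video.cuda()

# Track points sampled on a regular grid on the first frame.
pred_tracks, pred_visibility = cotracker(video, grid_size=10)
```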
Another option is to install CoTracker from this GitHub repo. That's the best way if you need to run our demo or evaluate / train CoTracker:
### Steps to Install CoTracker and its dependencies:
```
git clone https://github.com/facebookresearch/co-tracker
cd co-tracker
pip install -e .
pip install opencv-python einops timm matplotlib moviepy flow_vis
pip install opencv-python einops timm matplotlib moviepy flow_vis
```
## Model Weights Download:
### Download Model Weights:
```
mkdir checkpoints
cd checkpoints
@ -60,24 +74,26 @@ To reproduce the results presented in the paper, download the following datasets
And install the necessary dependencies:
```
pip install hydra-core==1.1.0 mediapy tensorboard
pip install hydra-core==1.1.0 mediapy
```
Then, execute the following command to evaluate on BADJA:
```
python ./cotracker/evaluation/evaluate.py --config-name eval_badja exp_dir=./eval_outputs dataset_root=your/badja/path
```
By default, evaluation will be slow since it is done for one target point at a time, which ensures robustness and fairness, as described in the paper.
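For intuition, this protocol amounts to querying the model separately for every target point. A rough sketch of the idea (it assumes a predictor that maps `(video, queries)` to a `(B, T, N, 2)` track tensor, as in the evaluator code further down; the helper itself is only illustrative):
```
import torch

def track_points_one_by_one(model, video, queries):
    # queries: (B, N, 3) with (start frame, x, y) for each point.
    # Tracking each point in isolation keeps points from sharing
    # information, which is slower but makes the comparison fair.
    per_point_tracks = []
    for i in range(queries.shape[1]):
        single_query = queries[:, i : i + 1]                 # (B, 1, 3)
        per_point_tracks.append(model(video, single_query))  # (B, T, 1, 2)
    return torch.cat(per_point_tracks, dim=2)                # (B, T, N, 2)
```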
## Training
To train CoTracker as described in our paper, you first need to generate annotations for the [Google Kubric](https://github.com/google-research/kubric) MOVi-f dataset. Instructions for annotation generation can be found [here](https://github.com/deepmind/tapnet).
Once you have the annotated dataset, make sure you have followed the evaluation setup steps above and then install the training dependencies:
```
pip install pytorch_lightning==1.6.0
pip install pytorch_lightning==1.6.0 tensorboard
```
launch training on Kubric. Our model was trained using 32 GPUs, and you can adjust the parameters to best suit your hardware setup.
Now you can launch training on Kubric. Our model was trained for 50000 iterations on 32 GPUs (4 nodes with 8 GPUs each).
Modify *dataset_root* and *ckpt_path* accordingly before running this command:
```
python train.py --batch_size 1 --num_workers 28 \
--num_steps 50000 --ckpt_path ./ --model_name cotracker \
--num_steps 50000 --ckpt_path ./ --dataset_root ./datasets --model_name cotracker \
--save_freq 200 --sequence_len 24 --eval_datasets tapvid_davis_first badja \
--traj_per_sample 256 --sliding_window_len 8 --updateformer_space_depth 6 --updateformer_time_depth 6 \
--save_every_n_epoch 10 --evaluate_every_n_epoch 10 --model_stride 4
@ -86,13 +102,16 @@ python train.py --batch_size 1 --num_workers 28 \
## License
The majority of CoTracker is licensed under CC-BY-NC; however, portions of the project are available under separate license terms: Particle Video Revisited is licensed under the MIT license, and TAP-Vid is licensed under the Apache 2.0 license.
## Acknowledgments
We would like to thank [PIPs](https://github.com/aharley/pips) and [TAP-Vid](https://github.com/deepmind/tapnet) for publicly releasing their code and data. We also want to thank [Luke Melas-Kyriazi](https://lukemelas.github.io/) for proofreading the paper, and [Jianyuan Wang](https://jytime.github.io/), [Roman Shapovalov](https://shapovalov.ro/) and [Adam W. Harley](https://adamharley.com/) for insightful discussions.
## Citing CoTracker
If you find our repository useful, please consider giving it a star ⭐ and citing our paper in your work:
```
@article{karaev2023cotracker,
title={CoTracker: It is Better to Track Together},
author={Nikita Karaev and Ignacio Rocco and Benjamin Graham and Natalia Neverova and Andrea Vedaldi and Christian Rupprecht},
journal={arxiv},
journal={arXiv:2307.07635},
year={2023}
}
```
```

View File

@ -185,7 +185,11 @@ class Evaluator:
if not all(gotit):
print("batch is None")
continue
dataclass_to_cuda_(sample)
if torch.cuda.is_available():
dataclass_to_cuda_(sample)
device = torch.device("cuda")
else:
device = torch.device("cpu")
if (
not train_mode
@ -205,7 +209,7 @@ class Evaluator:
queries[:, :, 1],
],
dim=2,
)
).to(device)
else:
queries = torch.cat(
[
@ -213,7 +217,7 @@ class Evaluator:
sample.trajectory[:, 0],
],
dim=2,
)
).to(device)
pred_tracks = model(sample.video, queries)
if "strided" in dataset_name:

View File

@ -102,6 +102,8 @@ def run_eval(cfg: DefaultConfig):
single_point=cfg.single_point,
n_iters=cfg.n_iters,
)
if torch.cuda.is_available():
predictor.model = predictor.model.cuda()
# Setting the random seeds
torch.manual_seed(cfg.seed)

View File

@ -12,6 +12,8 @@ from cotracker.models.core.cotracker.cotracker import CoTracker
def build_cotracker(
checkpoint: str,
):
if checkpoint is None:
return build_cotracker_stride_4_wind_8()
model_name = checkpoint.split("/")[-1].split(".")[0]
if model_name == "cotracker_stride_4_wind_8":
return build_cotracker_stride_4_wind_8(checkpoint=checkpoint)

View File

@ -25,11 +25,11 @@ from cotracker.models.core.embeddings import (
torch.manual_seed(0)
def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device='cuda'):
def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device="cuda"):
if grid_size == 1:
return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2])[
return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2], device=device)[
None, None
].to(device)
]
grid_y, grid_x = meshgrid2d(
1, grid_size, grid_size, stack=False, norm=False, device=device

View File

@ -29,11 +29,10 @@ class EvaluationPredictor(torch.nn.Module):
self.n_iters = n_iters
self.model = cotracker_model
self.model.to("cuda")
self.model.eval()
def forward(self, video, queries):
queries = queries.clone().cuda()
queries = queries.clone()
B, T, C, H, W = video.shape
B, N, D = queries.shape
@ -42,14 +41,16 @@ class EvaluationPredictor(torch.nn.Module):
rgbs = video.reshape(B * T, C, H, W)
rgbs = F.interpolate(rgbs, tuple(self.interp_shape), mode="bilinear")
rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]).cuda()
rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
device = rgbs.device
queries[:, :, 1] *= self.interp_shape[1] / W
queries[:, :, 2] *= self.interp_shape[0] / H
if self.single_point:
traj_e = torch.zeros((B, T, N, 2)).cuda()
vis_e = torch.zeros((B, T, N)).cuda()
traj_e = torch.zeros((B, T, N, 2), device=device)
vis_e = torch.zeros((B, T, N), device=device)
for pind in range((N)):
query = queries[:, pind : pind + 1]
@ -60,8 +61,10 @@ class EvaluationPredictor(torch.nn.Module):
vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
else:
if self.grid_size > 0:
xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda() #
xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device)
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(
device
) #
queries = torch.cat([queries, xy], dim=1) #
traj_e, __, vis_e, __ = self.model(
@ -91,8 +94,8 @@ class EvaluationPredictor(torch.nn.Module):
query = torch.cat([query, xy_target], dim=1).to(device) #
if self.grid_size > 0:
xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda() #
xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device)
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
query = torch.cat([query, xy], dim=1).to(device) #
# crop the video to start from the queried frame
query[0, 0, 0] = 0

View File

@ -25,8 +25,6 @@ class CoTrackerPredictor(torch.nn.Module):
model = build_cotracker(checkpoint)
self.model = model
self.device = device or 'cuda'
self.model.to(self.device)
self.model.eval()
@torch.no_grad()
@ -73,7 +71,7 @@ class CoTrackerPredictor(torch.nn.Module):
grid_width = W // grid_step
grid_height = H // grid_step
tracks = visibilities = None
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(self.device)
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
grid_pts[0, :, 0] = grid_query_frame
for offset in tqdm(range(grid_step * grid_step)):
ox = offset % grid_step
@ -108,10 +106,8 @@ class CoTrackerPredictor(torch.nn.Module):
assert B == 1
video = video.reshape(B * T, C, H, W)
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").to(self.device)
video = video.reshape(
B, T, 3, self.interp_shape[0], self.interp_shape[1]
).to(self.device)
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear")
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
if queries is not None:
queries = queries.clone()
@ -120,7 +116,7 @@ class CoTrackerPredictor(torch.nn.Module):
queries[:, :, 1] *= self.interp_shape[1] / W
queries[:, :, 2] *= self.interp_shape[0] / H
elif grid_size > 0:
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=self.device)
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
if segm_mask is not None:
segm_mask = F.interpolate(
segm_mask, tuple(self.interp_shape), mode="nearest"

View File

@ -14,7 +14,6 @@ from matplotlib import cm
import torch.nn.functional as F
import torchvision.transforms as transforms
from moviepy.editor import ImageSequenceClip
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
@ -67,7 +66,7 @@ class Visualizer:
gt_tracks: torch.Tensor = None, # (B,T,N,2)
segm_mask: torch.Tensor = None, # (B,1,H,W)
filename: str = "video",
writer: SummaryWriter = None,
writer=None, # tensorboard Summary Writer, used for visualization during training
step: int = 0,
query_frame: int = 0,
save_video: bool = True,

demo.py (12 changed lines)
View File

@ -32,11 +32,6 @@ if __name__ == "__main__":
default="./checkpoints/cotracker_stride_4_wind_8.pth",
help="cotracker model",
)
parser.add_argument(
"--device",
default="cuda",
help="Device to use for inference",
)
parser.add_argument("--grid_size", type=int, default=0, help="Regular grid size")
parser.add_argument(
"--grid_query_frame",
@ -59,7 +54,12 @@ if __name__ == "__main__":
segm_mask = np.array(Image.open(os.path.join(args.mask_path)))
segm_mask = torch.from_numpy(segm_mask)[None, None]
model = CoTrackerPredictor(checkpoint=args.checkpoint, device=args.device)
model = CoTrackerPredictor(checkpoint=args.checkpoint)
if torch.cuda.is_available():
model = model.cuda()
video = video.cuda()
else:
print("CUDA is not available!")
pred_tracks, pred_visibility = model(
video,

hubconf.py (new file, 32 lines)
View File

@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
dependencies = ["torch", "einops", "timm", "tqdm"]
_COTRACKER_URL = (
"https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth"
)
def _make_cotracker_predictor(*, pretrained: bool = True, **kwargs):
from cotracker.predictor import CoTrackerPredictor
predictor = CoTrackerPredictor(checkpoint=None)
if pretrained:
state_dict = torch.hub.load_state_dict_from_url(
_COTRACKER_URL, map_location="cpu"
)
predictor.model.load_state_dict(state_dict)
return predictor
def cotracker_w8(*, pretrained: bool = True, **kwargs):
"""
CoTracker model with stride 4 and window length 8. (The main model from the paper)
"""
return _make_cotracker_predictor(pretrained=pretrained, **kwargs)
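With this entrypoint in place, the predictor can also be loaded from a local clone of the repository rather than through GitHub. A small sketch (assuming it is run from the repository root):
```
import torch

# Load the predictor defined by this hubconf.py from the local checkout.
predictor = torch.hub.load(".", "cotracker_w8", source="local")
if torch.cuda.is_available():
    predictor = predictor.cuda()
```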

File diff suppressed because one or more lines are too long

train.py (111 changed lines)
View File

@ -36,21 +36,6 @@ from cotracker.datasets.utils import collate_fn, collate_fn_train, dataclass_to_
from cotracker.models.core.cotracker.losses import sequence_loss, balanced_ce_loss
# define the handler function
# for training on a slurm cluster
def sig_handler(signum, frame):
print("caught signal", signum)
print(socket.gethostname(), "USR1 signal caught.")
# do other stuff to cleanup here
print("requeuing job " + os.environ["SLURM_JOB_ID"])
os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"])
sys.exit(-1)
def term_handler(signum, frame):
print("bypassing sigterm", flush=True)
def fetch_optimizer(args, model):
"""Create the optimizer and learning rate scheduler"""
optimizer = optim.AdamW(
@ -153,6 +138,8 @@ def run_test_eval(evaluator, model, dataloaders, writer, step):
single_point=False,
n_iters=6,
)
if torch.cuda.is_available():
predictor.model = predictor.model.cuda()
metrics = evaluator.evaluate_sequence(
model=predictor,
@ -302,9 +289,7 @@ class Lite(LightningLite):
eval_dataloaders.append(("fastcapture", eval_dataloader_fastcapture))
if "tapvid_davis_first" in args.eval_datasets:
data_root = os.path.join(
args.dataset_root, "/tapvid_davis/tapvid_davis.pkl"
)
data_root = os.path.join(args.dataset_root, "tapvid_davis/tapvid_davis.pkl")
eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root)
eval_dataloader_tapvid_davis = torch.utils.data.DataLoader(
eval_dataset,
@ -551,17 +536,15 @@ class Lite(LightningLite):
if __name__ == "__main__":
signal.signal(signal.SIGUSR1, sig_handler)
signal.signal(signal.SIGTERM, term_handler)
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", default="cotracker", help="model name")
parser.add_argument("--restore_ckpt", help="restore checkpoint")
parser.add_argument("--ckpt_path", help="restore checkpoint")
parser.add_argument("--restore_ckpt", help="path to restore a checkpoint")
parser.add_argument("--ckpt_path", help="path to save checkpoints")
parser.add_argument(
"--batch_size", type=int, default=4, help="batch size used during training."
)
parser.add_argument(
"--num_workers", type=int, default=6, help="left right consistency loss"
"--num_workers", type=int, default=6, help="number of dataloader workers"
)
parser.add_argument(
@ -578,20 +561,34 @@ if __name__ == "__main__":
"--evaluate_every_n_epoch",
type=int,
default=1,
help="number of flow-field updates during validation forward pass",
help="evaluate during training after every n epochs, after every epoch by default",
)
parser.add_argument(
"--save_every_n_epoch",
type=int,
default=1,
help="number of flow-field updates during validation forward pass",
help="save checkpoints during training after every n epochs, after every epoch by default",
)
parser.add_argument(
"--validate_at_start", action="store_true", help="use mixed precision"
"--validate_at_start",
action="store_true",
help="whether to run evaluation before training starts",
)
parser.add_argument(
"--save_freq",
type=int,
default=100,
help="frequency of trajectory visualization during training",
)
parser.add_argument(
"--traj_per_sample",
type=int,
default=768,
help="the number of trajectories to sample for training",
)
parser.add_argument(
"--dataset_root", type=str, help="path lo all the datasets (train and eval)"
)
parser.add_argument("--save_freq", type=int, default=100, help="save_freq")
parser.add_argument("--traj_per_sample", type=int, default=768, help="save_freq")
parser.add_argument("--dataset_root", type=str, help="path lo all the datasets")
parser.add_argument(
"--train_iters",
@ -605,49 +602,75 @@ if __name__ == "__main__":
parser.add_argument(
"--eval_datasets",
nargs="+",
default=["things", "badja", "fastcapture"],
help="eval datasets.",
default=["things", "badja"],
help="what datasets to use for evaluation",
)
parser.add_argument(
"--remove_space_attn", action="store_true", help="use mixed precision"
"--remove_space_attn",
action="store_true",
help="remove space attention from CoTracker",
)
parser.add_argument(
"--dont_use_augs", action="store_true", help="use mixed precision"
"--dont_use_augs",
action="store_true",
help="don't apply augmentations during training",
)
parser.add_argument(
"--sample_vis_1st_frame", action="store_true", help="use mixed precision"
"--sample_vis_1st_frame",
action="store_true",
help="only sample trajectories with points visible on the first frame",
)
parser.add_argument(
"--sliding_window_len", type=int, default=8, help="use mixed precision"
"--sliding_window_len",
type=int,
default=8,
help="length of the CoTracker sliding window",
)
parser.add_argument(
"--updateformer_hidden_size", type=int, default=384, help="use mixed precision"
"--updateformer_hidden_size",
type=int,
default=384,
help="hidden dimension of the CoTracker transformer model",
)
parser.add_argument(
"--updateformer_num_heads", type=int, default=8, help="use mixed precision"
"--updateformer_num_heads",
type=int,
default=8,
help="number of heads of the CoTracker transformer model",
)
parser.add_argument(
"--updateformer_space_depth", type=int, default=12, help="use mixed precision"
"--updateformer_space_depth",
type=int,
default=12,
help="number of group attention layers in the CoTracker transformer model",
)
parser.add_argument(
"--updateformer_time_depth", type=int, default=12, help="use mixed precision"
"--updateformer_time_depth",
type=int,
default=12,
help="number of time attention layers in the CoTracker transformer model",
)
parser.add_argument(
"--model_stride", type=int, default=8, help="use mixed precision"
"--model_stride",
type=int,
default=8,
help="stride of the CoTracker feature network",
)
parser.add_argument(
"--crop_size",
type=int,
nargs="+",
default=[384, 512],
help="use mixed precision",
help="crop videos to this resolution during training",
)
parser.add_argument(
"--eval_max_seq_len", type=int, default=1000, help="use mixed precision"
"--eval_max_seq_len",
type=int,
default=1000,
help="maximum length of evaluation videos",
)
args = parser.parse_args()
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
@ -661,5 +684,5 @@ if __name__ == "__main__":
devices="auto",
accelerator="gpu",
precision=32,
num_nodes=4,
# num_nodes=4,
).run(args)