Copyright (c) Meta Platforms, Inc. and affiliates.

# CoTracker: It is Better to Track Together
This is a demo for CoTracker, a model that can track any point in a video.

Don't forget to turn on GPU support if you're running this demo in Colab. 

**Runtime** -> **Change runtime type** -> **Hardware accelerator** -> **GPU**

Let's install dependencies for Colab:

In [None]:
# !git clone https://github.com/facebookresearch/co-tracker
# %cd co-tracker
# !pip install -e .
# !pip install opencv-python einops timm matplotlib moviepy flow_vis
# !mkdir checkpoints
# %cd checkpoints
# !wget https://huggingface.co/facebook/cotracker/resolve/main/cotracker2.pth

In [2]:
%cd ..
import os
import torch

from base64 import b64encode
from cotracker.utils.visualizer import Visualizer, read_video_from_path
from IPython.display import HTML


In [3]:
if torch.cuda.is_available():
 print('CUDA available')

CUDA available


Read the demo video and convert it to a batched float tensor:

In [4]:
video = read_video_from_path('./assets/F1_shorts.mp4')
video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()
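The two lines above reorder the decoded frames into a batched, channels-first layout. The sketch below traces only the shapes in plain Python. It assumes the decoded video is laid out as (frames, height, width, channels); the frame count and resolution are made-up values, and `permute_shape` is a hypothetical helper, not part of PyTorch or CoTracker.

```python
def permute_shape(shape, order):
    """Return the shape a tensor would have after permute(*order)."""
    return tuple(shape[i] for i in order)


# Assumed layout of the decoded video: (frames, height, width, channels).
# The numbers are illustrative, not read from the demo file.
decoded = (50, 720, 1280, 3)

# permute(0, 3, 1, 2) moves channels in front of height and width.
channels_first = permute_shape(decoded, (0, 3, 1, 2))

# Indexing with [None] prepends a batch dimension of size 1.
batched = (1,) + channels_first

print(channels_first)  # (50, 3, 720, 1280)
print(batched)         # (1, 50, 3, 720, 1280)
```

With this layout, `video.shape[3]` is the frame height and `video.shape[4]` is the frame width.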

In [5]:
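def show_video(video_path):
    # Read the file in binary mode; the context manager closes it afterwards.
    with open(video_path, "rb") as f:
        video_file = f.read()
    # Embed the clip inline as a base64 data URL.
    video_url = f"data:video/mp4;base64,{b64encode(video_file).decode()}"
    return HTML(f"""<video width="640" height="480" autoplay loop controls><source src="{video_url}"></video>""")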
 
show_video("./assets/F1_shorts.mp4")
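`show_video` embeds the clip in the notebook as a base64 data URL. As a minimal sketch of that encoding step on its own, the snippet below builds a data URL from a few stand-in bytes and decodes it back; `to_data_url` is a hypothetical helper introduced here for illustration.

```python
from base64 import b64decode, b64encode


def to_data_url(payload: bytes, mime: str = "video/mp4") -> str:
    """Encode raw bytes as a base64 data URL."""
    return f"data:{mime};base64,{b64encode(payload).decode()}"


# Stand-in bytes; a real call would pass the contents of an .mp4 file.
fake_video = b"\x00\x00\x00\x18ftypmp42"
url = to_data_url(fake_video)

# The part after the comma decodes back to the original bytes.
header, encoded = url.split(",", 1)
print(header)                            # data:video/mp4;base64
print(b64decode(encoded) == fake_video)  # True
print(len(fake_video), len(encoded))     # 12 16
```

Base64 turns every 3 bytes into 4 characters, so an embedded video makes the notebook roughly a third larger than the file itself.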

Import CoTrackerPredictor and create an instance of it. We'll use this object to estimate tracks:

In [7]:
from cotracker.predictor import CoTrackerPredictor

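model = CoTrackerPredictor(
    checkpoint=os.path.join('./checkpoints', 'cotracker2.pth')
)

# Move the model and the video to the GPU when one is available,
# so they share a device with the queries created later.
if torch.cuda.is_available():
    model = model.cuda()
    video = video.cuda()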


Track points sampled on a regular 30 × 30 grid on the first frame:

In [8]:
pred_tracks, pred_visibility = model(video, grid_size=30)
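With `grid_size=30` the model follows 30 × 30 = 900 points. The following is a rough sketch of how such a grid of (x, y) positions could be laid out over a frame. The `margin` parameter and the frame size are assumptions made for illustration; the library's own sampling may place the points differently.

```python
def regular_grid(grid_size, height, width, margin=0.0):
    """Return grid_size * grid_size (x, y) points spread evenly over a frame.

    `margin` is a hypothetical parameter (in pixels) that keeps points away
    from the border.
    """
    def linspace(start, stop, n):
        if n == 1:
            return [(start + stop) / 2]
        step = (stop - start) / (n - 1)
        return [start + i * step for i in range(n)]

    xs = linspace(margin, width - 1 - margin, grid_size)
    ys = linspace(margin, height - 1 - margin, grid_size)
    # Row-major order: all x positions for the first y, then the next y, ...
    return [(x, y) for y in ys for x in xs]


points = regular_grid(30, height=720, width=1280, margin=8)
print(len(points))  # 900
print(points[0])    # (8.0, 8.0)
```

The number of tracks grows with the square of `grid_size`, which is why larger grids cost noticeably more time and memory.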

Visualize and save the result: 

In [7]:
vis = Visualizer(save_dir='./videos', pad_value=100)
vis.visualize(video=video, tracks=pred_tracks, visibility=pred_visibility, filename='teaser');

In [1]:
show_video("./videos/teaser.mp4")

## Tracking manually selected points

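We will start by tracking points that we select manually.
Each queried point is a triple [time, x coord, y coord]: the index of the frame where tracking starts, followed by the pixel coordinates of the point in that frame.

The code below defines four points with different x and y coordinates, queried on frames 0, 10, 20, and 30: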

In [None]:
queries = torch.tensor([
 [0., 400., 350.], # point tracked from the first frame
 [10., 600., 500.], # frame number 10
 [20., 750., 600.], # ...
 [30., 900., 200.]
])
# Keep the queries on the same device as the video.
queries = queries.to(video.device)
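Since each query mixes a frame index with pixel coordinates, it is easy to pass a point that lies outside the video. Here is a small sanity check in plain Python, under the [time, x, y] convention described above. `check_queries` is a hypothetical helper, and the video dimensions are assumed values rather than properties of the demo clip.

```python
def check_queries(queries, num_frames, height, width):
    """Return the indices of queries that fall outside the video.

    Each query is (t, x, y): frame index, then pixel column and row.
    """
    bad = []
    for i, (t, x, y) in enumerate(queries):
        in_time = 0 <= t < num_frames
        in_frame = 0 <= x < width and 0 <= y < height
        if not (in_time and in_frame):
            bad.append(i)
    return bad


queries = [
    (0.0, 400.0, 350.0),
    (10.0, 600.0, 500.0),
    (20.0, 750.0, 600.0),
    (30.0, 900.0, 200.0),
]

# Video sizes below are assumed for illustration only.
print(check_queries(queries, num_frames=50, height=720, width=1280))  # []
print(check_queries(queries, num_frames=25, height=720, width=1280))  # [3]
```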

That's what our queried points look like:

In [None]:
import matplotlib.pyplot as plt
# Create a list of frame numbers corresponding to each point
frame_numbers = queries[:,0].int().tolist()

fig, axs = plt.subplots(2, 2)
axs = axs.flatten()

for i, (query, frame_number) in enumerate(zip(queries, frame_numbers)):
 ax = axs[i]
 ax.plot(query[1].item(), query[2].item(), 'ro') 
 
 ax.set_title("Frame {}".format(frame_number))
 ax.set_xlim(0, video.shape[4])
 ax.set_ylim(0, video.shape[3])
 ax.invert_yaxis()
 
plt.tight_layout()
plt.show()

We pass these points as input to the model and track them:

In [None]:
pred_tracks, pred_visibility = model(video, queries=queries[None])

Finally, we visualize the results with tracks leaving traces from the frame where the tracking starts.
Color encodes time:

In [None]:
vis = Visualizer(
 save_dir='./videos',
 linewidth=6,
 mode='cool',
 tracks_leave_trace=-1
)
vis.visualize(
 video=video,
 tracks=pred_tracks,
 visibility=pred_visibility,
 filename='queries');

In [None]:
show_video("./videos/queries.mp4")

Notice that the points queried at frames 10, 20, and 30 are tracked **incorrectly** before their query frames. CoTracker is an online algorithm, so by default it tracks points only forward in time. We can also run it backward from each query frame to track in both directions. Let's correct this:

In [None]:
pred_tracks, pred_visibility = model(video, queries=queries[None], backward_tracking=True)
vis.visualize(
 video=video,
 tracks=pred_tracks,
 visibility=pred_visibility,
 filename='queries_backward');

In [None]:
show_video("./videos/queries_backward.mp4")
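One way to picture what `backward_tracking=True` achieves is as two passes whose results are stitched together at the query frame. The toy sketch below illustrates that idea for a single point. It is an assumption about the general approach, not the library's implementation, and `merge_tracks` is a hypothetical helper.

```python
def merge_tracks(forward, backward, query_frame):
    """Combine two per-frame position lists for one point.

    `forward` is trusted from the query frame onward, `backward` before it.
    Both are lists of (x, y), one entry per frame.
    """
    return [
        backward[t] if t < query_frame else forward[t]
        for t in range(len(forward))
    ]


# Toy data: a point moving right by 1 px per frame, queried at frame 2.
forward = [(0, 0), (0, 0), (2, 5), (3, 5), (4, 5)]   # frames 0-1 are unreliable
backward = [(0, 5), (1, 5), (2, 5), (0, 0), (0, 0)]  # frames 3-4 are unreliable

merged = merge_tracks(forward, backward, query_frame=2)
print(merged)  # [(0, 5), (1, 5), (2, 5), (3, 5), (4, 5)]
```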

## Points on a regular grid

### Tracking forward from the frame number x

Let's now sample points on a regular 30 × 30 grid and start tracking from frame 20.

In [None]:
grid_size = 30
grid_query_frame = 20

In [None]:
pred_tracks, pred_visibility = model(video, grid_size=grid_size, grid_query_frame=grid_query_frame)

In [None]:
vis = Visualizer(save_dir='./videos', pad_value=100)
vis.visualize(
 video=video,
 tracks=pred_tracks,
 visibility=pred_visibility,
 filename='grid_query_20',
 query_frame=grid_query_frame);

Note that the points are now sampled on a frame in the middle of the video and are tracked only from that frame onward. In the first example, the grid was sampled on the first frame:

In [None]:
show_video("./videos/grid_query_20.mp4")

### Tracking forward **and backward** from the frame number x

In [None]:
grid_size = 30
grid_query_frame = 20

Let's activate backward tracking:

In [None]:
pred_tracks, pred_visibility = model(video, grid_size=grid_size, grid_query_frame=grid_query_frame, backward_tracking=True)
vis.visualize(
 video=video,
 tracks=pred_tracks,
 visibility=pred_visibility,
 filename='grid_query_20_backward');

As you can see, the points queried in the middle of the video are now tracked from the first frame:

In [None]:
show_video("./videos/grid_query_20_backward.mp4")

## Regular grid + Segmentation mask

Let's now sample points on a grid and filter them with a segmentation mask.
Because points outside the mask are discarded, we use less GPU memory and can afford a much denser grid on the object itself.

In [None]:
import numpy as np
from PIL import Image
grid_size = 100

In [None]:
input_mask = './assets/apple_mask.png'
segm_mask = np.array(Image.open(input_mask))

That's a segmentation mask for the first frame:

In [None]:
plt.imshow((segm_mask[...,None]/255.*video[0,0].permute(1,2,0).cpu().numpy()/255.))

In [None]:
pred_tracks, pred_visibility = model(video, grid_size=grid_size, segm_mask=torch.from_numpy(segm_mask)[None, None])
vis = Visualizer(
 save_dir='./videos',
 pad_value=100,
 linewidth=2,
)
vis.visualize(
 video=video,
 tracks=pred_tracks,
 visibility=pred_visibility,
 filename='segm_grid');

We are now tracking only points on the object (and around it):

In [None]:
show_video("./videos/segm_grid.mp4")
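The filtering step can be sketched in plain Python as keeping only the grid points that land on a non-zero mask pixel. This is a simplified illustration with a toy mask: `filter_by_mask` is a hypothetical helper, and the predictor may treat the mask differently (for example, by resizing or interpolating it), which would explain points appearing slightly around the object.

```python
def filter_by_mask(points, mask):
    """Keep (x, y) points whose nearest mask pixel is non-zero.

    `mask` is a list of rows, indexed as mask[y][x].
    """
    kept = []
    for x, y in points:
        row, col = int(round(y)), int(round(x))
        inside = 0 <= row < len(mask) and 0 <= col < len(mask[0])
        if inside and mask[row][col]:
            kept.append((x, y))
    return kept


# 4x4 toy mask with a 2x2 object in the centre (255 = object).
mask = [
    [0,   0,   0, 0],
    [0, 255, 255, 0],
    [0, 255, 255, 0],
    [0,   0,   0, 0],
]
grid = [(x, y) for y in range(4) for x in range(4)]

kept = filter_by_mask(grid, mask)
print(len(grid), len(kept))  # 16 4
print(kept)                  # [(1, 1), (2, 1), (1, 2), (2, 2)]
```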

## Dense Tracks

### Tracking forward **and backward** from the frame number x

CoTracker can also track **every pixel** in a video in a **dense** manner, but this is much slower than the sparse examples above. Let's downsample the video to speed it up:

In [None]:
video.shape

In [None]:
import torch.nn.functional as F
video_interp = F.interpolate(video[0], [200, 360], mode="bilinear")[None]

The video now has a much lower resolution:

In [None]:
video_interp.shape
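A quick count shows why downsampling helps: tracking every pixel means one track per pixel of the query frame. The arithmetic below compares the sparse grid used earlier with dense tracking at the downsampled size and at an assumed original size (1280 × 720 is a made-up figure for illustration, not read from the demo clip).

```python
def num_dense_tracks(height, width):
    """Number of tracks when every pixel of the query frame is followed."""
    return height * width


grid_tracks = 30 * 30                      # sparse grid used earlier
dense_small = num_dense_tracks(200, 360)   # after downsampling
dense_large = num_dense_tracks(720, 1280)  # assumed original size

print(grid_tracks)                # 900
print(dense_small)                # 72000
print(dense_large)                # 921600
print(dense_large / dense_small)  # 12.8
```

Even after downsampling, dense mode follows 80 times more points than the 30 × 30 grid.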

In [None]:
grid_query_frame = 20

Again, let's track points in both directions. This will only take a couple of minutes:

In [None]:
import time
start_time = time.time()
pred_tracks, pred_visibility = model(video_interp, grid_query_frame=grid_query_frame, backward_tracking=True)
end_time = time.time() 
print("Time taken: {:.2f} seconds".format(end_time - start_time))

Visualization with an optical flow color encoding:

In [None]:
vis = Visualizer(
 save_dir='./videos',
 pad_value=20,
 linewidth=1,
 mode='optical_flow'
)
vis.visualize(
 video=video_interp,
 tracks=pred_tracks,
 visibility=pred_visibility,
 filename='dense');

In [None]:
show_video("./videos/dense.mp4")
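An optical flow color encoding maps each point's motion to a color, typically with direction as hue and speed as saturation. Below is a simplified HSV stand-in for that idea using only the standard library. It is not the palette used by the `Visualizer`, and `flow_to_rgb` is a hypothetical helper.

```python
import colorsys
import math


def flow_to_rgb(dx, dy, max_magnitude):
    """Map a displacement to RGB: hue from direction, saturation from length."""
    angle = math.atan2(dy, dx)                    # direction in radians
    hue = ((angle + math.pi) / (2 * math.pi)) % 1.0
    magnitude = math.hypot(dx, dy)
    if max_magnitude > 0:
        saturation = min(magnitude / max_magnitude, 1.0)
    else:
        saturation = 0.0
    return colorsys.hsv_to_rgb(hue, saturation, 1.0)


still = flow_to_rgb(0, 0, 10)
right = flow_to_rgb(10, 0, 10)
left = flow_to_rgb(-10, 0, 10)

print(still)          # (1.0, 1.0, 1.0): no motion maps to white
print(right != left)  # True: opposite directions get different colors
```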

That's all! Now you can use CoTracker in your own projects.