# cotracker/online_demo.py
# Snapshot: 2023-12-27 12:54:02 +00:00 (91 lines, 3.0 KiB, Python)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import argparse
import imageio.v3 as iio
import numpy as np
from cotracker.utils.visualizer import Visualizer
from cotracker.predictor import CoTrackerOnlinePredictor
# Device used for the model and all tensors in this demo: CUDA when a GPU is
# present, otherwise the CPU. Apple's MPS backend is intentionally skipped
# because it does not yet support every feature CoTracker requires; it may be
# enabled here in the future.
if torch.cuda.is_available():
    DEFAULT_DEVICE = "cuda"
else:
    DEFAULT_DEVICE = "cpu"
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--video_path",
        default="./assets/apple.mp4",
        help="path to a video",
    )
    parser.add_argument(
        "--checkpoint",
        default=None,
        help="CoTracker model parameters",
    )
    parser.add_argument("--grid_size", type=int, default=10, help="Regular grid size")
    parser.add_argument(
        "--grid_query_frame",
        type=int,
        default=0,
        help="Compute dense and grid tracks starting from this frame",
    )
    args = parser.parse_args()

    # Use a local checkpoint when one is given, otherwise fetch the released
    # online model through torch.hub.
    if args.checkpoint is not None:
        model = CoTrackerOnlinePredictor(checkpoint=args.checkpoint)
    else:
        model = torch.hub.load("facebookresearch/co-tracker", "cotracker2_online")
    model = model.to(DEFAULT_DEVICE)

    # Every decoded frame is kept: the model only looks at a sliding window,
    # but the full sequence is needed for the visualization at the end.
    window_frames = []

    def _process_step(window_frames, is_first_step, grid_size):
        """Run the online model on the most recent frames.

        Only the last ``2 * model.step`` entries of ``window_frames`` are used.
        They are stacked into a ``(1, T, 3, H, W)`` float tensor on
        ``DEFAULT_DEVICE``. ``is_first_step`` and ``grid_size`` are forwarded
        to ``model`` unchanged. Returns the model output, which callers unpack
        as ``(pred_tracks, pred_visibility)``.
        """
        video_chunk = (
            torch.tensor(np.stack(window_frames[-model.step * 2 :]), device=DEFAULT_DEVICE)
            .float()
            .permute(0, 3, 1, 2)[None]
        )  # (1, T, 3, H, W)
        return model(video_chunk, is_first_step=is_first_step, grid_size=grid_size)

    # Stream frames from the requested video, running the model once every
    # model.step frames (never on frame 0, since no window exists yet).
    is_first_step = True
    for i, frame in enumerate(iio.imiter(args.video_path, plugin="FFMPEG")):
        if i % model.step == 0 and i != 0:
            pred_tracks, pred_visibility = _process_step(
                window_frames, is_first_step, grid_size=args.grid_size
            )
            is_first_step = False
        window_frames.append(frame)

    if not window_frames:
        # Without any frames, `i` below would be unbound; fail with a clear message.
        raise ValueError(f"No frames could be read from {args.video_path!r}")

    # Process the final frames, since the video length need not be a multiple
    # of model.step.
    pred_tracks, pred_visibility = _process_step(
        window_frames[-(i % model.step) - model.step - 1 :],
        is_first_step,
        grid_size=args.grid_size,
    )
    print("Tracks are computed")

    # Save a video with the predicted tracks drawn on top of the input frames.
    video = torch.tensor(np.stack(window_frames), device=DEFAULT_DEVICE).permute(0, 3, 1, 2)[None]
    vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
    # NOTE(review): --grid_query_frame is only passed to the visualizer; the
    # model is never told which frame the grid queries start from. Confirm
    # whether the online predictor accepts a grid_query_frame argument and
    # forward it in _process_step too.
    vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame)