91 lines
3.0 KiB
Python
91 lines
3.0 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import torch
|
|
import argparse
|
|
import imageio.v3 as iio
|
|
import numpy as np
|
|
|
|
from cotracker.utils.visualizer import Visualizer
|
|
from cotracker.predictor import CoTrackerOnlinePredictor
|
|
|
|
# Unfortunately, MPS acceleration does not support all the features we require,
# but we may be able to enable it in the future.
|
|
|
|
# Run on CUDA when torch can see a GPU; otherwise fall back to the CPU.
# MPS is intentionally not considered here. To experiment with it, extend the
# conditional with `"mps" if torch.backends.mps.is_available()` before "cpu".
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
if __name__ == "__main__":
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--video_path",
        default="./assets/apple.mp4",
        help="path to a video",
    )
    parser.add_argument(
        "--checkpoint",
        default=None,
        help="CoTracker model parameters",
    )
    parser.add_argument("--grid_size", type=int, default=10, help="Regular grid size")
    parser.add_argument(
        "--grid_query_frame",
        type=int,
        default=0,
        help="Compute dense and grid tracks starting from this frame",
    )

    args = parser.parse_args()

    # Load the online predictor from a local checkpoint if one was given,
    # otherwise fetch it through torch.hub.
    if args.checkpoint is not None:
        model = CoTrackerOnlinePredictor(checkpoint=args.checkpoint)
    else:
        model = torch.hub.load("facebookresearch/co-tracker", "cotracker2_online")
    model = model.to(DEFAULT_DEVICE)

    # Every decoded frame is kept. The model only ever receives a tail slice,
    # but the full sequence is stacked again for the final visualization.
    window_frames = []

    def _process_step(window_frames, is_first_step, grid_size):
        """Run one online-tracking step on the tail of ``window_frames``.

        Args:
            window_frames: frames decoded so far. Only the last
                ``2 * model.step`` of them are sent to the model. Each frame is
                assumed to be an (H, W, 3) array as yielded by imageio
                (TODO confirm for other plugins).
            is_first_step: forwarded to the model. The caller passes True only
                on its first call.
            grid_size: forwarded to the model as ``grid_size``.

        Returns:
            The model's output, which the caller unpacks as
            ``(pred_tracks, pred_visibility)``.
        """
        video_chunk = (
            torch.tensor(np.stack(window_frames[-model.step * 2 :]), device=DEFAULT_DEVICE)
            .float()
            .permute(0, 3, 1, 2)[None]
        )  # (1, T, 3, H, W)
        return model(video_chunk, is_first_step=is_first_step, grid_size=grid_size)

    # Iterate over the frames of --video_path. The model runs once every
    # model.step frames, just before frame i (a positive multiple of
    # model.step) is appended.
    is_first_step = True
    for i, frame in enumerate(iio.imiter(args.video_path, plugin="FFMPEG")):
        if i % model.step == 0 and i != 0:
            pred_tracks, pred_visibility = _process_step(
                window_frames, is_first_step, grid_size=args.grid_size
            )
            is_first_step = False
        window_frames.append(frame)

    # Without this guard an empty video would fail further down with an
    # unhelpful error (np.stack on an empty list).
    if not window_frames:
        raise ValueError(f"No frames could be read from {args.video_path!r}")

    # Process the final frames, which covers videos whose length is not a
    # multiple of model.step. The slice holds the
    # ``(num_frames - 1) % model.step + 1`` frames appended after the last
    # in-loop step, plus up to ``model.step`` earlier frames as context.
    num_frames = len(window_frames)
    pred_tracks, pred_visibility = _process_step(
        window_frames[-((num_frames - 1) % model.step) - model.step - 1 :],
        is_first_step,
        grid_size=args.grid_size,
    )

    print("Tracks are computed")

    # Save a video with the predicted tracks. The output is named after the
    # input file's stem, e.g. "./assets/apple.mp4" -> "apple".
    seq_name = os.path.splitext(os.path.basename(args.video_path))[0]
    video = torch.tensor(np.stack(window_frames), device=DEFAULT_DEVICE).permute(0, 3, 1, 2)[None]
    vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
    # NOTE(review): --grid_query_frame is forwarded only to the visualizer, not
    # to the model's query grid. Confirm whether the predictor should receive it.
    vis.visualize(
        video,
        pred_tracks,
        pred_visibility,
        query_frame=args.grid_query_frame,
        filename=seq_name,
    )
|