# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import argparse
import imageio.v3 as iio
import numpy as np

from cotracker.utils.visualizer import Visualizer
from cotracker.predictor import CoTrackerOnlinePredictor

# Unfortunately MPS acceleration does not support all the features we require,
# but we may be able to enable it in the future
DEFAULT_DEVICE = (
    # "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    "cuda" if torch.cuda.is_available() else "cpu"
)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--video_path",
        default="./assets/apple.mp4",
        help="path to a video",
    )
    parser.add_argument(
        "--checkpoint",
        default=None,
        help="CoTracker model parameters",
    )
    parser.add_argument("--grid_size", type=int, default=10, help="Regular grid size")
    parser.add_argument(
        "--grid_query_frame",
        type=int,
        default=0,
        help="Compute dense and grid tracks starting from this frame",
    )

    args = parser.parse_args()

    # Load the model from a local checkpoint if one is given, otherwise from torch.hub
    if args.checkpoint is not None:
        model = CoTrackerOnlinePredictor(checkpoint=args.checkpoint)
    else:
        model = torch.hub.load("facebookresearch/co-tracker", "cotracker2_online")
    model = model.to(DEFAULT_DEVICE)

    window_frames = []

    def _process_step(window_frames, is_first_step, grid_size):
        # The online predictor consumes a sliding window of the last
        # 2 * model.step frames, stacked into a (1, T, 3, H, W) float tensor
        video_chunk = (
            torch.tensor(
                np.stack(window_frames[-model.step * 2 :]), device=DEFAULT_DEVICE
            )
            .float()
            .permute(0, 3, 1, 2)[None]
        )  # (1, T, 3, H, W)
        return model(video_chunk, is_first_step=is_first_step, grid_size=grid_size)

    # Iterate over video frames, processing one window at a time
    # (the sample video is also hosted at
    # https://github.com/facebookresearch/co-tracker/blob/main/assets/apple.mp4)
    is_first_step = True
    for i, frame in enumerate(iio.imiter(args.video_path, plugin="FFMPEG")):
        if i % model.step == 0 and i != 0:
            pred_tracks, pred_visibility = _process_step(
                window_frames, is_first_step, grid_size=args.grid_size
            )
            is_first_step = False
        window_frames.append(frame)
    # Process the final frames in case the video length is not a multiple of model.step
    pred_tracks, pred_visibility = _process_step(
        window_frames[-(i % model.step) - model.step - 1 :],
        is_first_step,
        grid_size=args.grid_size,
    )

    print("Tracks are computed")

    # Save a video with the predicted tracks overlaid, named after the input file
    seq_name = args.video_path.split("/")[-1]
    video = torch.tensor(np.stack(window_frames), device=DEFAULT_DEVICE).permute(
        0, 3, 1, 2
    )[None]
    vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
    vis.visualize(
        video,
        pred_tracks,
        pred_visibility,
        query_frame=args.grid_query_frame,
        filename=seq_name,
    )
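
# Example invocations, as a sketch: the flags come from the argparse definitions
# above, but the script filename and the checkpoint path are assumptions, not
# part of this file.
#
#   python online_demo.py --video_path ./assets/apple.mp4 --grid_size 10
#   python online_demo.py --checkpoint ./checkpoints/cotracker2.pth --grid_query_frame 0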