# cotracker/demo.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
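
"""Demo: track points in a video with CoTracker and save a visualization video."""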

import argparse

import numpy as np
import torch
from PIL import Image
from torchvision.io import read_video

from cotracker.predictor import CoTrackerPredictor
from cotracker.utils.visualizer import Visualizer

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--video_path",
        default="./assets/apple.mp4",
        help="path to the input video",
    )
    parser.add_argument(
        "--mask_path",
        default="./assets/apple_mask.png",
        help="path to a segmentation mask",
    )
    parser.add_argument(
        "--checkpoint",
        default="./checkpoints/cotracker_stride_4_wind_8.pth",
        help="path to the CoTracker model checkpoint",
    )
    parser.add_argument(
        "--grid_size",
        type=int,
        default=0,
        help="regular grid size (0 computes dense tracks instead)",
    )
    parser.add_argument(
        "--grid_query_frame",
        type=int,
        default=0,
        help="compute dense and grid tracks starting from this frame",
    )
    parser.add_argument(
        "--backward_tracking",
        action="store_true",
        help="compute tracks in both directions, not only forward",
    )
    args = parser.parse_args()

    # Load the input video: read_video returns (frames, audio, info) with frames
    # as a (T, H, W, C) uint8 tensor; convert to the (B, T, C, H, W) float layout
    # that CoTrackerPredictor expects.
    video = read_video(args.video_path)[0]
    video = video.permute(0, 3, 1, 2)[None].float()

    # Load the segmentation mask and add batch and channel dimensions: (1, 1, H, W).
    segm_mask = np.array(Image.open(args.mask_path))
    segm_mask = torch.from_numpy(segm_mask)[None, None]

    model = CoTrackerPredictor(checkpoint=args.checkpoint)
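
    # Optional speed-up: move the model and the video tensor to the GPU when one
    # is available (both must live on the same device for the forward pass).
    if torch.cuda.is_available():
        model = model.cuda()
        video = video.cuda()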

    pred_tracks, pred_visibility = model(
        video,
        grid_size=args.grid_size,
        grid_query_frame=args.grid_query_frame,
        backward_tracking=args.backward_tracking,
        # Uncomment to restrict the sampled grid points to the segmentation mask:
        # segm_mask=segm_mask,
    )
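    # pred_tracks holds the predicted (x, y) location of every tracked point in
    # every frame; pred_visibility flags whether each point is visible per frame.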
    print("computed")

    # Save a video with the predicted tracks drawn on top.
    seq_name = args.video_path.split("/")[-1]
    vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
    vis.visualize(video, pred_tracks, query_frame=args.grid_query_frame)
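
    # Example invocation, assuming the default asset and checkpoint paths exist
    # and the script is run from the repository root:
    #   python cotracker/demo.py --grid_size 10 --backward_tracking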