77 lines
		
	
	
		
			2.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			77 lines
		
	
	
		
			2.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| 
 | |
| # This source code is licensed under the license found in the
 | |
| # LICENSE file in the root directory of this source tree.
 | |
| 
 | |
| import os
 | |
| import cv2
 | |
| import torch
 | |
| import argparse
 | |
| import numpy as np
 | |
| 
 | |
| from PIL import Image
 | |
| from cotracker.utils.visualizer import Visualizer, read_video_from_path
 | |
| from cotracker.predictor import CoTrackerPredictor
 | |
| 
 | |
| DEFAULT_DEVICE = ('cuda' if torch.cuda.is_available() else
 | |
|                   'mps' if torch.backends.mps.is_available() else
 | |
|                   'cpu')
 | |
| 
 | |
| if __name__ == "__main__":
 | |
|     parser = argparse.ArgumentParser()
 | |
|     parser.add_argument(
 | |
|         "--video_path",
 | |
|         default="./assets/apple.mp4",
 | |
|         help="path to a video",
 | |
|     )
 | |
|     parser.add_argument(
 | |
|         "--mask_path",
 | |
|         default="./assets/apple_mask.png",
 | |
|         help="path to a segmentation mask",
 | |
|     )
 | |
|     parser.add_argument(
 | |
|         "--checkpoint",
 | |
|         default="./checkpoints/cotracker_stride_4_wind_8.pth",
 | |
|         help="cotracker model",
 | |
|     )
 | |
|     parser.add_argument("--grid_size", type=int, default=0, help="Regular grid size")
 | |
|     parser.add_argument(
 | |
|         "--grid_query_frame",
 | |
|         type=int,
 | |
|         default=0,
 | |
|         help="Compute dense and grid tracks starting from this frame ",
 | |
|     )
 | |
| 
 | |
|     parser.add_argument(
 | |
|         "--backward_tracking",
 | |
|         action="store_true",
 | |
|         help="Compute tracks in both directions, not only forward",
 | |
|     )
 | |
| 
 | |
|     args = parser.parse_args()
 | |
| 
 | |
|     # load the input video frame by frame
 | |
|     video = read_video_from_path(args.video_path)
 | |
|     video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()
 | |
|     segm_mask = np.array(Image.open(os.path.join(args.mask_path)))
 | |
|     segm_mask = torch.from_numpy(segm_mask)[None, None]
 | |
| 
 | |
|     model = CoTrackerPredictor(checkpoint=args.checkpoint)
 | |
|     model = model.to(DEFAULT_DEVICE)
 | |
|     video = video.to(DEFAULT_DEVICE)
 | |
| 
 | |
|     pred_tracks, pred_visibility = model(
 | |
|         video,
 | |
|         grid_size=args.grid_size,
 | |
|         grid_query_frame=args.grid_query_frame,
 | |
|         backward_tracking=args.backward_tracking,
 | |
|         # segm_mask=segm_mask
 | |
|     )
 | |
|     print("computed")
 | |
| 
 | |
|     # save a video with predicted tracks
 | |
|     seq_name = args.video_path.split("/")[-1]
 | |
|     vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
 | |
|     vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame)
 |