# cotracker/demo1.py
# (source-viewer metadata: Python, 51 lines, 1.4 KiB)

import os
import torch
from base64 import b64encode
from cotracker.utils.visualizer import Visualizer, read_video_from_path
import time

import torch.nn.functional as F

from cotracker.predictor import CoTrackerPredictor

# Dense CoTracker demo: load a clip, run the predictor over it, report the
# inference time, and render the tracks (optical-flow colouring) into
# ./videos.

# --- Configuration ---------------------------------------------------------
VIDEO_PATH = './assets/F1_shorts.mp4'
CHECKPOINT_PATH = os.path.join('.', 'checkpoints', 'cotracker2.pth')
# Preferred GPU on a multi-GPU host; _select_device falls back if it is absent.
GPU_INDEX = 2
# Frame on which the query grid is placed. backward_tracking=True presumably
# also extends the tracks to the frames before it.
grid_query_frame = 20
# Optional (height, width) to downsample frames before tracking, e.g.
# (200, 360). None keeps the native resolution.
RESIZE_TO = None


def _select_device(gpu_index):
    """Return the torch device to run on.

    Uses ``cuda:<gpu_index>`` when that GPU exists, otherwise the current
    CUDA device (with a printed warning), otherwise the CPU.
    """
    if not torch.cuda.is_available():
        return torch.device('cpu')
    if gpu_index < torch.cuda.device_count():
        return torch.device(f'cuda:{gpu_index}')
    print(f"cuda:{gpu_index} not found; falling back to the current CUDA device.")
    return torch.device('cuda')


def _synchronize(dev):
    """Block until queued CUDA work on *dev* has finished (no-op on CPU).

    CUDA kernels launch asynchronously, so wall-clock timings are only
    meaningful when taken after a synchronize.
    """
    if dev.type == 'cuda':
        torch.cuda.synchronize(dev)


video = read_video_from_path(VIDEO_PATH)
if video is None:
    # NOTE(review): cotracker's reader may return None instead of raising
    # when the file can't be opened; fail here with a clear message rather
    # than an opaque torch.from_numpy error.
    raise FileNotFoundError(f"Could not read video: {VIDEO_PATH}")
# Reorder to (1, T, C, H, W); assumes the reader yields channels-last
# (T, H, W, C) frames — TODO confirm.
video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()

device = _select_device(GPU_INDEX)
# Model and clip must live on the same device for inference to run there.
model = CoTrackerPredictor(checkpoint=CHECKPOINT_PATH).to(device)
video = video.to(device)

if RESIZE_TO is not None:
    # video[0] drops the leading batch axis; its first axis (frames) acts as
    # the interpolation batch, so only the last two (spatial) axes change.
    # Tracks come out in resized coordinates, so the resized clip is also
    # what gets visualized below.
    video = F.interpolate(video[0], size=tuple(RESIZE_TO), mode='bilinear')[None]

# No queries or grid_size are passed. This is presumably the predictor's
# dense mode, which matches filename='dense' below.
_synchronize(device)
start_time = time.perf_counter()
with torch.no_grad():  # inference only; don't build an autograd graph
    pred_tracks, pred_visibility = model(
        video,
        grid_query_frame=grid_query_frame,
        backward_tracking=True,
    )
_synchronize(device)
end_time = time.perf_counter()
print("Time taken: {:.2f} seconds".format(end_time - start_time))

vis = Visualizer(
    save_dir='./videos',
    pad_value=20,
    linewidth=1,
    mode='optical_flow',
)
vis.visualize(
    video=video,
    tracks=pred_tracks,
    visibility=pred_visibility,
    filename='dense',
)