try to generate the tracking f1 video but too many points

This commit is contained in:
mhz 2024-08-11 13:42:03 +02:00
parent f208a962b9
commit 40e628ac73
3 changed files with 12 additions and 2 deletions

3
.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
__pycache__/
.vscode/
cotracker/__pycache__/

View File

@ -21,6 +21,9 @@ class CoTrackerPredictor(torch.nn.Module):
self.model = model
self.model.eval()
self.device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)
@torch.no_grad()
def forward(
self,

View File

@ -10,6 +10,8 @@ video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()
from cotracker.predictor import CoTrackerPredictor
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
model = CoTrackerPredictor(
checkpoint=os.path.join(
'./checkpoints/cotracker2.pth'
@ -24,11 +26,12 @@ model = CoTrackerPredictor(
grid_query_frame=20
import torch.nn.functional as F
video_interp = F.interpolate(video[0], [200, 360], mode="bilinear")[None]
# video_interp = F.interpolate(video[0], [200, 360], mode="bilinear")[None].to(device)
import time
start_time = time.time()
pred_tracks, pred_visibility = model(video_interp,
# pred_tracks, pred_visibility = model(video_interp,
pred_tracks, pred_visibility = model(video,
grid_query_frame=grid_query_frame, backward_tracking=True)
end_time = time.time()
@ -41,6 +44,7 @@ vis = Visualizer(
mode='optical_flow'
)
vis.visualize(
# video=video_interp,
video=video,
tracks=pred_tracks,
visibility=pred_visibility,