try to generate the tracking f1 video but too many points
This commit is contained in:
parent
f208a962b9
commit
40e628ac73
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
__pycache__/
|
||||||
|
.vscode/
|
||||||
|
cotracker/__pycache__/
|
@ -21,6 +21,9 @@ class CoTrackerPredictor(torch.nn.Module):
|
|||||||
self.model = model
|
self.model = model
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
|
self.device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
|
||||||
|
self.model.to(self.device)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
8
demo1.py
8
demo1.py
@ -10,6 +10,8 @@ video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()
|
|||||||
|
|
||||||
from cotracker.predictor import CoTrackerPredictor
|
from cotracker.predictor import CoTrackerPredictor
|
||||||
|
|
||||||
|
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
model = CoTrackerPredictor(
|
model = CoTrackerPredictor(
|
||||||
checkpoint=os.path.join(
|
checkpoint=os.path.join(
|
||||||
'./checkpoints/cotracker2.pth'
|
'./checkpoints/cotracker2.pth'
|
||||||
@ -24,11 +26,12 @@ model = CoTrackerPredictor(
|
|||||||
grid_query_frame=20
|
grid_query_frame=20
|
||||||
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
video_interp = F.interpolate(video[0], [200, 360], mode="bilinear")[None]
|
# video_interp = F.interpolate(video[0], [200, 360], mode="bilinear")[None].to(device)
|
||||||
|
|
||||||
import time
|
import time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
pred_tracks, pred_visibility = model(video_interp,
|
# pred_tracks, pred_visibility = model(video_interp,
|
||||||
|
pred_tracks, pred_visibility = model(video,
|
||||||
grid_query_frame=grid_query_frame, backward_tracking=True)
|
grid_query_frame=grid_query_frame, backward_tracking=True)
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
@ -41,6 +44,7 @@ vis = Visualizer(
|
|||||||
mode='optical_flow'
|
mode='optical_flow'
|
||||||
)
|
)
|
||||||
vis.visualize(
|
vis.visualize(
|
||||||
|
# video=video_interp,
|
||||||
video=video,
|
video=video,
|
||||||
tracks=pred_tracks,
|
tracks=pred_tracks,
|
||||||
visibility=pred_visibility,
|
visibility=pred_visibility,
|
||||||
|
Loading…
Reference in New Issue
Block a user