demo fixes
This commit is contained in:
parent
e5e09ec34f
commit
53682167fe
Binary file not shown.
Before Width: | Height: | Size: 14 KiB After Width: | Height: | Size: 14 KiB |
@ -55,7 +55,9 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
|
||||
return tracks, visibilities
|
||||
|
||||
def _compute_dense_tracks(self, video, grid_query_frame, grid_size=30, backward_tracking=False):
|
||||
def _compute_dense_tracks(
|
||||
self, video, grid_query_frame, grid_size=150, backward_tracking=False
|
||||
):
|
||||
*_, H, W = video.shape
|
||||
grid_step = W // grid_size
|
||||
grid_width = W // grid_step
|
||||
@ -172,8 +174,9 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
|
||||
inv_tracks = inv_tracks.flip(1)
|
||||
inv_visibilities = inv_visibilities.flip(1)
|
||||
arange = torch.arange(video.shape[1], device=queries.device)[None, :, None]
|
||||
|
||||
mask = tracks == 0
|
||||
mask = (arange < queries[None, :, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)
|
||||
|
||||
tracks[mask] = inv_tracks[mask]
|
||||
visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
|
||||
|
@ -226,7 +226,7 @@ class Visualizer:
|
||||
|
||||
# draw tracks
|
||||
if self.tracks_leave_trace != 0:
|
||||
for t in range(1, T):
|
||||
for t in range(query_frame + 1, T):
|
||||
first_ind = (
|
||||
max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0
|
||||
)
|
||||
@ -251,7 +251,7 @@ class Visualizer:
|
||||
res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1])
|
||||
|
||||
# draw points
|
||||
for t in range(T):
|
||||
for t in range(query_frame, T):
|
||||
img = Image.fromarray(np.uint8(res_video[t]))
|
||||
for i in range(N):
|
||||
coord = (tracks[t, i, 0], tracks[t, i, 1])
|
||||
|
9
demo.py
9
demo.py
@ -72,7 +72,7 @@ if __name__ == "__main__":
|
||||
model = torch.hub.load("facebookresearch/co-tracker", "cotracker2")
|
||||
model = model.to(DEFAULT_DEVICE)
|
||||
video = video.to(DEFAULT_DEVICE)
|
||||
|
||||
# video = video[:, :20]
|
||||
pred_tracks, pred_visibility = model(
|
||||
video,
|
||||
grid_size=args.grid_size,
|
||||
@ -85,4 +85,9 @@ if __name__ == "__main__":
|
||||
# save a video with predicted tracks
|
||||
seq_name = args.video_path.split("/")[-1]
|
||||
vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
|
||||
vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame)
|
||||
vis.visualize(
|
||||
video,
|
||||
pred_tracks,
|
||||
pred_visibility,
|
||||
query_frame=0 if args.backward_tracking else args.grid_query_frame,
|
||||
)
|
||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user