demo fixes

This commit is contained in:
Nikita Karaev 2023-12-28 17:05:45 +00:00
parent e5e09ec34f
commit 53682167fe
5 changed files with 178 additions and 83 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

After

Width:  |  Height:  |  Size: 14 KiB

View File

@@ -55,7 +55,9 @@ class CoTrackerPredictor(torch.nn.Module):
return tracks, visibilities
-def _compute_dense_tracks(self, video, grid_query_frame, grid_size=30, backward_tracking=False):
+def _compute_dense_tracks(
+self, video, grid_query_frame, grid_size=150, backward_tracking=False
+):
*_, H, W = video.shape
grid_step = W // grid_size
grid_width = W // grid_step
@@ -172,8 +174,9 @@ class CoTrackerPredictor(torch.nn.Module):
inv_tracks = inv_tracks.flip(1)
inv_visibilities = inv_visibilities.flip(1)
+arange = torch.arange(video.shape[1], device=queries.device)[None, :, None]
-mask = tracks == 0
+mask = (arange < queries[None, :, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)
tracks[mask] = inv_tracks[mask]
visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]

View File

@@ -226,7 +226,7 @@ class Visualizer:
# draw tracks
if self.tracks_leave_trace != 0:
-for t in range(1, T):
+for t in range(query_frame + 1, T):
first_ind = (
max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0
)
@@ -251,7 +251,7 @@ class Visualizer:
res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1])
# draw points
-for t in range(T):
+for t in range(query_frame, T):
img = Image.fromarray(np.uint8(res_video[t]))
for i in range(N):
coord = (tracks[t, i, 0], tracks[t, i, 1])

View File

@@ -72,7 +72,7 @@ if __name__ == "__main__":
model = torch.hub.load("facebookresearch/co-tracker", "cotracker2")
model = model.to(DEFAULT_DEVICE)
video = video.to(DEFAULT_DEVICE)
# video = video[:, :20]
pred_tracks, pred_visibility = model(
video,
grid_size=args.grid_size,
@@ -85,4 +85,9 @@ if __name__ == "__main__":
# save a video with predicted tracks
seq_name = args.video_path.split("/")[-1]
vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
-vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame)
+vis.visualize(
+video,
+pred_tracks,
+pred_visibility,
+query_frame=0 if args.backward_tracking else args.grid_query_frame,
+)

File diff suppressed because one or more lines are too long