demo fixes
This commit is contained in:
parent
e5e09ec34f
commit
53682167fe
Binary file not shown.
Before Width: | Height: | Size: 14 KiB After Width: | Height: | Size: 14 KiB |
@ -55,7 +55,9 @@ class CoTrackerPredictor(torch.nn.Module):
|
|||||||
|
|
||||||
return tracks, visibilities
|
return tracks, visibilities
|
||||||
|
|
||||||
def _compute_dense_tracks(self, video, grid_query_frame, grid_size=30, backward_tracking=False):
|
def _compute_dense_tracks(
|
||||||
|
self, video, grid_query_frame, grid_size=150, backward_tracking=False
|
||||||
|
):
|
||||||
*_, H, W = video.shape
|
*_, H, W = video.shape
|
||||||
grid_step = W // grid_size
|
grid_step = W // grid_size
|
||||||
grid_width = W // grid_step
|
grid_width = W // grid_step
|
||||||
@ -172,8 +174,9 @@ class CoTrackerPredictor(torch.nn.Module):
|
|||||||
|
|
||||||
inv_tracks = inv_tracks.flip(1)
|
inv_tracks = inv_tracks.flip(1)
|
||||||
inv_visibilities = inv_visibilities.flip(1)
|
inv_visibilities = inv_visibilities.flip(1)
|
||||||
|
arange = torch.arange(video.shape[1], device=queries.device)[None, :, None]
|
||||||
|
|
||||||
mask = tracks == 0
|
mask = (arange < queries[None, :, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)
|
||||||
|
|
||||||
tracks[mask] = inv_tracks[mask]
|
tracks[mask] = inv_tracks[mask]
|
||||||
visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
|
visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
|
||||||
|
@ -226,7 +226,7 @@ class Visualizer:
|
|||||||
|
|
||||||
# draw tracks
|
# draw tracks
|
||||||
if self.tracks_leave_trace != 0:
|
if self.tracks_leave_trace != 0:
|
||||||
for t in range(1, T):
|
for t in range(query_frame + 1, T):
|
||||||
first_ind = (
|
first_ind = (
|
||||||
max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0
|
max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0
|
||||||
)
|
)
|
||||||
@ -251,7 +251,7 @@ class Visualizer:
|
|||||||
res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1])
|
res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1])
|
||||||
|
|
||||||
# draw points
|
# draw points
|
||||||
for t in range(T):
|
for t in range(query_frame, T):
|
||||||
img = Image.fromarray(np.uint8(res_video[t]))
|
img = Image.fromarray(np.uint8(res_video[t]))
|
||||||
for i in range(N):
|
for i in range(N):
|
||||||
coord = (tracks[t, i, 0], tracks[t, i, 1])
|
coord = (tracks[t, i, 0], tracks[t, i, 1])
|
||||||
|
9
demo.py
9
demo.py
@ -72,7 +72,7 @@ if __name__ == "__main__":
|
|||||||
model = torch.hub.load("facebookresearch/co-tracker", "cotracker2")
|
model = torch.hub.load("facebookresearch/co-tracker", "cotracker2")
|
||||||
model = model.to(DEFAULT_DEVICE)
|
model = model.to(DEFAULT_DEVICE)
|
||||||
video = video.to(DEFAULT_DEVICE)
|
video = video.to(DEFAULT_DEVICE)
|
||||||
|
# video = video[:, :20]
|
||||||
pred_tracks, pred_visibility = model(
|
pred_tracks, pred_visibility = model(
|
||||||
video,
|
video,
|
||||||
grid_size=args.grid_size,
|
grid_size=args.grid_size,
|
||||||
@ -85,4 +85,9 @@ if __name__ == "__main__":
|
|||||||
# save a video with predicted tracks
|
# save a video with predicted tracks
|
||||||
seq_name = args.video_path.split("/")[-1]
|
seq_name = args.video_path.split("/")[-1]
|
||||||
vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
|
vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
|
||||||
vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame)
|
vis.visualize(
|
||||||
|
video,
|
||||||
|
pred_tracks,
|
||||||
|
pred_visibility,
|
||||||
|
query_frame=0 if args.backward_tracking else args.grid_query_frame,
|
||||||
|
)
|
||||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user