Merge pull request #14 from JunkyByte/main
minor fixes / mps default device when available / occlusion visualization
This commit is contained in:
commit
e84ca71ba5
@ -133,7 +133,7 @@ class CoTrackerPredictor(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if add_support_grid:
|
if add_support_grid:
|
||||||
grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape)
|
grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape, device=video.device)
|
||||||
grid_pts = torch.cat(
|
grid_pts = torch.cat(
|
||||||
[torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
|
[torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
|
||||||
)
|
)
|
||||||
|
@ -62,6 +62,7 @@ class Visualizer:
|
|||||||
self,
|
self,
|
||||||
video: torch.Tensor, # (B,T,C,H,W)
|
video: torch.Tensor, # (B,T,C,H,W)
|
||||||
tracks: torch.Tensor, # (B,T,N,2)
|
tracks: torch.Tensor, # (B,T,N,2)
|
||||||
|
visibility: torch.Tensor = None, # (B, T, N, 1) bool
|
||||||
gt_tracks: torch.Tensor = None, # (B,T,N,2)
|
gt_tracks: torch.Tensor = None, # (B,T,N,2)
|
||||||
segm_mask: torch.Tensor = None, # (B,1,H,W)
|
segm_mask: torch.Tensor = None, # (B,1,H,W)
|
||||||
filename: str = "video",
|
filename: str = "video",
|
||||||
@ -93,6 +94,7 @@ class Visualizer:
|
|||||||
res_video = self.draw_tracks_on_video(
|
res_video = self.draw_tracks_on_video(
|
||||||
video=video,
|
video=video,
|
||||||
tracks=tracks,
|
tracks=tracks,
|
||||||
|
visibility=visibility,
|
||||||
segm_mask=segm_mask,
|
segm_mask=segm_mask,
|
||||||
gt_tracks=gt_tracks,
|
gt_tracks=gt_tracks,
|
||||||
query_frame=query_frame,
|
query_frame=query_frame,
|
||||||
@ -126,6 +128,7 @@ class Visualizer:
|
|||||||
self,
|
self,
|
||||||
video: torch.Tensor,
|
video: torch.Tensor,
|
||||||
tracks: torch.Tensor,
|
tracks: torch.Tensor,
|
||||||
|
visibility: torch.Tensor = None,
|
||||||
segm_mask: torch.Tensor = None,
|
segm_mask: torch.Tensor = None,
|
||||||
gt_tracks=None,
|
gt_tracks=None,
|
||||||
query_frame: int = 0,
|
query_frame: int = 0,
|
||||||
@ -227,11 +230,13 @@ class Visualizer:
|
|||||||
if not compensate_for_camera_motion or (
|
if not compensate_for_camera_motion or (
|
||||||
compensate_for_camera_motion and segm_mask[i] > 0
|
compensate_for_camera_motion and segm_mask[i] > 0
|
||||||
):
|
):
|
||||||
|
|
||||||
cv2.circle(
|
cv2.circle(
|
||||||
res_video[t],
|
res_video[t],
|
||||||
coord,
|
coord,
|
||||||
int(self.linewidth * 2),
|
int(self.linewidth * 2),
|
||||||
vector_colors[t, i].tolist(),
|
vector_colors[t, i].tolist(),
|
||||||
|
thickness=-1 if visibility[0, t, i] else 2
|
||||||
-1,
|
-1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
12
demo.py
12
demo.py
@ -14,6 +14,9 @@ from PIL import Image
|
|||||||
from cotracker.utils.visualizer import Visualizer, read_video_from_path
|
from cotracker.utils.visualizer import Visualizer, read_video_from_path
|
||||||
from cotracker.predictor import CoTrackerPredictor
|
from cotracker.predictor import CoTrackerPredictor
|
||||||
|
|
||||||
|
DEFAULT_DEVICE = ('cuda' if torch.cuda.is_available() else
|
||||||
|
'mps' if torch.backends.mps.is_available() else
|
||||||
|
'cpu')
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@ -55,11 +58,8 @@ if __name__ == "__main__":
|
|||||||
segm_mask = torch.from_numpy(segm_mask)[None, None]
|
segm_mask = torch.from_numpy(segm_mask)[None, None]
|
||||||
|
|
||||||
model = CoTrackerPredictor(checkpoint=args.checkpoint)
|
model = CoTrackerPredictor(checkpoint=args.checkpoint)
|
||||||
if torch.cuda.is_available():
|
model = model.to(DEFAULT_DEVICE)
|
||||||
model = model.cuda()
|
video = video.to(DEFAULT_DEVICE)
|
||||||
video = video.cuda()
|
|
||||||
else:
|
|
||||||
print("CUDA is not available!")
|
|
||||||
|
|
||||||
pred_tracks, pred_visibility = model(
|
pred_tracks, pred_visibility = model(
|
||||||
video,
|
video,
|
||||||
@ -73,4 +73,4 @@ if __name__ == "__main__":
|
|||||||
# save a video with predicted tracks
|
# save a video with predicted tracks
|
||||||
seq_name = args.video_path.split("/")[-1]
|
seq_name = args.video_path.split("/")[-1]
|
||||||
vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
|
vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
|
||||||
vis.visualize(video, pred_tracks, query_frame=args.grid_query_frame)
|
vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame)
|
||||||
|
Loading…
Reference in New Issue
Block a user