commit 5890fbd16d: mps / cpu support
Author: JunkyByte
Date:   2023-07-25 16:17:29 +02:00
Parent: c6878420f5

4 changed files with 23 additions and 12 deletions

File 1 of 4

@@ -25,14 +25,14 @@ from cotracker.models.core.embeddings import (

 torch.manual_seed(0)


-def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)):
+def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device='cuda'):
     if grid_size == 1:
         return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2])[
             None, None
-        ].cuda()
+        ].to(device)
     grid_y, grid_x = meshgrid2d(
-        1, grid_size, grid_size, stack=False, norm=False, device="cuda"
+        1, grid_size, grid_size, stack=False, norm=False, device=device
     )
     step = interp_shape[1] // 64
     if grid_center[0] != 0 or grid_center[1] != 0:
@@ -47,7 +47,7 @@ def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)):
     grid_y = grid_y + grid_center[0]
     grid_x = grid_x + grid_center[1]

-    xy = torch.stack([grid_x, grid_y], dim=-1).cuda()
+    xy = torch.stack([grid_x, grid_y], dim=-1).to(device)
     return xy
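
The grid helper now takes a device argument instead of assuming CUDA, so the query points land on the same backend as the rest of the pipeline. A minimal sketch of a call on a non-CUDA backend (the import path and the MPS availability check are assumptions, not part of the diff):

    import torch
    from cotracker.models.core.model_utils import get_points_on_a_grid  # import path assumed

    # Fall back to CPU when Apple's MPS backend is not available.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    points = get_points_on_a_grid(10, (384, 512), device=device)
    # a (1, 100, 2) tensor of pixel coordinates on the chosen device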

File 2 of 4

@@ -17,7 +17,7 @@ from cotracker.models.build_cotracker import (

 class CoTrackerPredictor(torch.nn.Module):
     def __init__(
-        self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth"
+        self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth", device=None
     ):
         super().__init__()
         self.interp_shape = (384, 512)
@@ -25,7 +25,8 @@ class CoTrackerPredictor(torch.nn.Module):
         model = build_cotracker(checkpoint)

         self.model = model
-        self.model.to("cuda")
+        self.device = device or 'cuda'
+        self.model.to(self.device)
         self.model.eval()

     @torch.no_grad()
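
With the constructor change, callers pick the backend once and the predictor keeps it in self.device for every later allocation. A minimal sketch of device selection (the fallback chain and the import path are assumptions, not part of the commit):

    import torch
    from cotracker.predictor import CoTrackerPredictor  # import path assumed

    # Prefer CUDA, then Apple-silicon MPS, then plain CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    model = CoTrackerPredictor(
        checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth",
        device=device,  # None keeps the old behaviour and defaults to 'cuda'
    )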
@@ -72,7 +73,7 @@ class CoTrackerPredictor(torch.nn.Module):
         grid_width = W // grid_step
         grid_height = H // grid_step
         tracks = visibilities = None
-        grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to("cuda")
+        grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(self.device)
         grid_pts[0, :, 0] = grid_query_frame
         for offset in tqdm(range(grid_step * grid_step)):
             ox = offset % grid_step
@@ -107,10 +108,10 @@ class CoTrackerPredictor(torch.nn.Module):
         assert B == 1

         video = video.reshape(B * T, C, H, W)
-        video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").cuda()
+        video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").to(self.device)
         video = video.reshape(
             B, T, 3, self.interp_shape[0], self.interp_shape[1]
-        ).cuda()
+        ).to(self.device)

         if queries is not None:
             queries = queries.clone()
@@ -119,7 +120,7 @@ class CoTrackerPredictor(torch.nn.Module):
             queries[:, :, 1] *= self.interp_shape[1] / W
             queries[:, :, 2] *= self.interp_shape[0] / H
         elif grid_size > 0:
-            grid_pts = get_points_on_a_grid(grid_size, self.interp_shape)
+            grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=self.device)
             if segm_mask is not None:
                 segm_mask = F.interpolate(
                     segm_mask, tuple(self.interp_shape), mode="nearest"
@@ -136,7 +137,7 @@ class CoTrackerPredictor(torch.nn.Module):
             )

         if add_support_grid:
-            grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape)
+            grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape, device=self.device)
             grid_pts = torch.cat(
                 [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
             )
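
Replacing the hard-coded .cuda() calls with .to(self.device) is what actually enables CPU and MPS inference: .cuda() fails outright on a machine without an NVIDIA GPU, while .to() moves the tensor to whatever backend was requested. A minimal illustration (not from the diff):

    import torch

    x = torch.zeros(2, 3)
    y = x.to("cpu")   # always succeeds; a no-op when x already lives on the CPU
    # x.cuda()        # raises on builds without CUDA support
    # x.to("mps")     # succeeds only on Apple silicon with the MPS backend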

File 3 of 4

@@ -63,6 +63,7 @@ class Visualizer:
         self,
         video: torch.Tensor,  # (B,T,C,H,W)
         tracks: torch.Tensor,  # (B,T,N,2)
+        visibility: torch.Tensor,  # (B, T, N, 1) bool
         gt_tracks: torch.Tensor = None,  # (B,T,N,2)
         segm_mask: torch.Tensor = None,  # (B,1,H,W)
         filename: str = "video",
@@ -94,6 +95,7 @@ class Visualizer:
         res_video = self.draw_tracks_on_video(
             video=video,
             tracks=tracks,
+            visibility=visibility,
             segm_mask=segm_mask,
             gt_tracks=gt_tracks,
             query_frame=query_frame,
@@ -127,6 +129,7 @@ class Visualizer:
         self,
         video: torch.Tensor,
         tracks: torch.Tensor,
+        visibility: torch.Tensor,
         segm_mask: torch.Tensor = None,
         gt_tracks=None,
         query_frame: int = 0,
@@ -228,11 +231,13 @@ class Visualizer:
                 if not compensate_for_camera_motion or (
                     compensate_for_camera_motion and segm_mask[i] > 0
                 ):
                     cv2.circle(
                         res_video[t],
                         coord,
                         int(self.linewidth * 2),
                         vector_colors[t, i].tolist(),
+                        thickness=-1 if visibility[0, t, i] else 2
                         -1,
                     )
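
The visibility flag changes how points are rendered: a visible point is drawn as a filled disc (OpenCV's thickness=-1) and an occluded one as a thin ring. Note that the stray "-1," left over from the old positional thickness argument continues the ternary expression, so the occluded branch actually evaluates to 2 - 1 = 1. A standalone sketch of the drawing convention (dummy data, not the Visualizer API):

    import numpy as np
    import cv2

    frame = np.zeros((64, 64, 3), dtype=np.uint8)
    color = (0, 255, 0)
    cv2.circle(frame, (20, 32), 6, color, thickness=-1)  # visible: filled disc
    cv2.circle(frame, (44, 32), 6, color, thickness=1)   # occluded: outline only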

File 4 of 4

@@ -32,6 +32,11 @@ if __name__ == "__main__":
         default="./checkpoints/cotracker_stride_4_wind_8.pth",
         help="cotracker model",
     )
+    parser.add_argument(
+        "--device",
+        default="cuda",
+        help="Device to use for inference",
+    )
     parser.add_argument("--grid_size", type=int, default=0, help="Regular grid size")
     parser.add_argument(
         "--grid_query_frame",
@@ -54,7 +59,7 @@ if __name__ == "__main__":
     segm_mask = np.array(Image.open(os.path.join(args.mask_path)))
     segm_mask = torch.from_numpy(segm_mask)[None, None]

-    model = CoTrackerPredictor(checkpoint=args.checkpoint)
+    model = CoTrackerPredictor(checkpoint=args.checkpoint, device=args.device)

     pred_tracks, pred_visibility = model(
         video,
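
With the new flag the demo can be pointed at any backend from the command line. Example invocations (the script name demo.py is assumed from context; --device mps requires an Apple-silicon build of PyTorch):

    # NVIDIA GPU, the previous hard-coded behaviour
    python demo.py --grid_size 10 --device cuda

    # Apple silicon
    python demo.py --grid_size 10 --device mps

    # CPU-only fallback
    python demo.py --grid_size 10 --device cpu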