mps / cpu support

commit 5890fbd16d
parent c6878420f5
@@ -25,14 +25,14 @@ from cotracker.models.core.embeddings import (
 torch.manual_seed(0)
 
 
-def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)):
+def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device='cuda'):
     if grid_size == 1:
         return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2])[
             None, None
-        ].cuda()
+        ].to(device)
 
     grid_y, grid_x = meshgrid2d(
-        1, grid_size, grid_size, stack=False, norm=False, device="cuda"
+        1, grid_size, grid_size, stack=False, norm=False, device=device
     )
     step = interp_shape[1] // 64
     if grid_center[0] != 0 or grid_center[1] != 0:
@@ -47,7 +47,7 @@ def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)):
 
     grid_y = grid_y + grid_center[0]
     grid_x = grid_x + grid_center[1]
-    xy = torch.stack([grid_x, grid_y], dim=-1).cuda()
+    xy = torch.stack([grid_x, grid_y], dim=-1).to(device)
     return xy
 
 
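With the new device argument, callers choose where the grid tensor lives instead of always getting CUDA. A minimal usage sketch, assuming get_points_on_a_grid is imported from its module (torch.backends.mps needs PyTorch 1.12 or newer):

    import torch

    # Prefer CUDA, then Apple-silicon MPS, then CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    # Returns a (1, grid_size**2, 2) tensor of (x, y) points already on `device`.
    points = get_points_on_a_grid(10, (384, 512), device=device)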
@@ -17,7 +17,7 @@ from cotracker.models.build_cotracker import (
 
 class CoTrackerPredictor(torch.nn.Module):
     def __init__(
-        self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth"
+        self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth", device=None
     ):
         super().__init__()
         self.interp_shape = (384, 512)
@@ -25,7 +25,8 @@ class CoTrackerPredictor(torch.nn.Module):
         model = build_cotracker(checkpoint)
 
         self.model = model
-        self.model.to("cuda")
+        self.device = device or 'cuda'
+        self.model.to(self.device)
         self.model.eval()
 
     @torch.no_grad()
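device=None keeps the previous CUDA behaviour as the default; passing 'mps' or 'cpu' moves the model accordingly. A sketch of constructing the predictor off-GPU (the import path follows the repo's demo.py; pick_device is an illustrative helper, not part of this change):

    import torch
    from cotracker.predictor import CoTrackerPredictor

    def pick_device() -> str:
        # Illustrative helper: the commit itself just falls back to "cuda".
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    model = CoTrackerPredictor(
        checkpoint="./checkpoints/cotracker_stride_4_wind_8.pth",
        device=pick_device(),
    )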
@@ -72,7 +73,7 @@ class CoTrackerPredictor(torch.nn.Module):
         grid_width = W // grid_step
         grid_height = H // grid_step
         tracks = visibilities = None
-        grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to("cuda")
+        grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(self.device)
         grid_pts[0, :, 0] = grid_query_frame
         for offset in tqdm(range(grid_step * grid_step)):
             ox = offset % grid_step
@@ -107,10 +108,10 @@ class CoTrackerPredictor(torch.nn.Module):
         assert B == 1
 
         video = video.reshape(B * T, C, H, W)
-        video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").cuda()
+        video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").to(self.device)
         video = video.reshape(
             B, T, 3, self.interp_shape[0], self.interp_shape[1]
-        ).cuda()
+        ).to(self.device)
 
         if queries is not None:
             queries = queries.clone()
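.to(self.device) is a drop-in generalization of the former .cuda() calls: .cuda() is shorthand for .to("cuda"), while .to(...) also accepts "cpu" and "mps" and is a no-op when the tensor already lives on the target device. For example:

    import torch

    x = torch.zeros(2, 3)
    y = x.to("cpu")  # already on the CPU: returns x itself, no copy is made
    # x.cuda() raises on machines without CUDA; x.to(device) runs everywhere,
    # with `device` resolved once at predictor construction.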
@@ -119,7 +120,7 @@ class CoTrackerPredictor(torch.nn.Module):
             queries[:, :, 1] *= self.interp_shape[1] / W
             queries[:, :, 2] *= self.interp_shape[0] / H
         elif grid_size > 0:
-            grid_pts = get_points_on_a_grid(grid_size, self.interp_shape)
+            grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=self.device)
             if segm_mask is not None:
                 segm_mask = F.interpolate(
                     segm_mask, tuple(self.interp_shape), mode="nearest"
@@ -136,7 +137,7 @@ class CoTrackerPredictor(torch.nn.Module):
             )
 
         if add_support_grid:
-            grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape)
+            grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape, device=self.device)
             grid_pts = torch.cat(
                 [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
            )
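Every tensor the forward pass creates (the dense grid, the resized video, the support grid) now lands on self.device, so a caller only needs to supply a video tensor that can be moved there. A minimal inference sketch, with random frames standing in for real input (shapes follow interp_shape above):

    import torch
    from cotracker.predictor import CoTrackerPredictor

    device = "cpu"  # or "cuda" / "mps"
    model = CoTrackerPredictor(
        checkpoint="./checkpoints/cotracker_stride_4_wind_8.pth", device=device
    )

    video = torch.rand(1, 16, 3, 384, 512, device=device) * 255  # (B, T, C, H, W)
    pred_tracks, pred_visibility = model(video, grid_size=10)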
@@ -63,6 +63,7 @@ class Visualizer:
         self,
         video: torch.Tensor,  # (B,T,C,H,W)
         tracks: torch.Tensor,  # (B,T,N,2)
+        visibility: torch.Tensor,  # (B, T, N, 1) bool
         gt_tracks: torch.Tensor = None,  # (B,T,N,2)
         segm_mask: torch.Tensor = None,  # (B,1,H,W)
         filename: str = "video",
@@ -94,6 +95,7 @@ class Visualizer:
         res_video = self.draw_tracks_on_video(
             video=video,
             tracks=tracks,
+            visibility=visibility,
             segm_mask=segm_mask,
             gt_tracks=gt_tracks,
             query_frame=query_frame,
@@ -127,6 +129,7 @@ class Visualizer:
         self,
         video: torch.Tensor,
         tracks: torch.Tensor,
+        visibility: torch.Tensor,
         segm_mask: torch.Tensor = None,
         gt_tracks=None,
         query_frame: int = 0,
@@ -228,11 +231,13 @@ class Visualizer:
                     if not compensate_for_camera_motion or (
                         compensate_for_camera_motion and segm_mask[i] > 0
                     ):
 
                         cv2.circle(
                             res_video[t],
                             coord,
                             int(self.linewidth * 2),
                             vector_colors[t, i].tolist(),
-                            -1,
+                            thickness=-1 if visibility[0, t, i] else 2,
                         )
 
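The visibility tensor now controls how each point is drawn: OpenCV fills a circle when thickness is negative, so visible points render filled and occluded points render as outlines. The same idea in isolation (the frame and coordinates below are made up for illustration):

    import cv2
    import numpy as np

    frame = np.zeros((384, 512, 3), dtype=np.uint8)
    for (x, y), visible in [((100, 100), True), ((200, 150), False)]:
        # thickness=-1 fills the circle; a positive thickness draws the outline.
        cv2.circle(frame, (x, y), 6, (0, 255, 0), thickness=-1 if visible else 2)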
demo.py
@@ -32,6 +32,11 @@ if __name__ == "__main__":
         default="./checkpoints/cotracker_stride_4_wind_8.pth",
         help="cotracker model",
     )
+    parser.add_argument(
+        "--device",
+        default="cuda",
+        help="Device to use for inference",
+    )
     parser.add_argument("--grid_size", type=int, default=0, help="Regular grid size")
     parser.add_argument(
         "--grid_query_frame",
@@ -54,7 +59,7 @@ if __name__ == "__main__":
     segm_mask = np.array(Image.open(os.path.join(args.mask_path)))
     segm_mask = torch.from_numpy(segm_mask)[None, None]
 
-    model = CoTrackerPredictor(checkpoint=args.checkpoint)
+    model = CoTrackerPredictor(checkpoint=args.checkpoint, device=args.device)
 
     pred_tracks, pred_visibility = model(
         video,
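With the new flag, the demo runs off-GPU unchanged, e.g. python demo.py --device cpu, or python demo.py --device mps on Apple silicon; omitting --device keeps the previous CUDA behaviour.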