This commit is contained in:
JunkyByte 2023-07-25 16:31:44 +02:00
parent 51175e006a
commit 4a9286e17f

View File

@ -133,7 +133,7 @@ class CoTrackerPredictor(torch.nn.Module):
) )
if add_support_grid: if add_support_grid:
grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape, device=self.device) grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape, device=video.device)
grid_pts = torch.cat( grid_pts = torch.cat(
[torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2 [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
) )