This commit is contained in:
Nikita Karaev 2023-10-30 11:36:31 +00:00
commit 8d36403197
2 changed files with 2 additions and 2 deletions

View File

@ -25,7 +25,7 @@ from cotracker.models.core.embeddings import (
torch.manual_seed(0) torch.manual_seed(0)
def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device="cuda"): def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device="cpu"):
if grid_size == 1: if grid_size == 1:
return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2], device=device)[ return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2], device=device)[
None, None None, None

View File

@ -34,7 +34,7 @@ def normalize(d):
return out return out
def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"): def meshgrid2d(B, Y, X, stack=False, norm=False, device="cpu"):
# returns a meshgrid sized B x Y x X # returns a meshgrid sized B x Y x X
grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device)) grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device))