Merge branch 'main' of https://github.com/facebookresearch/co-tracker into main
This commit is contained in:
commit
8d36403197
@ -25,7 +25,7 @@ from cotracker.models.core.embeddings import (
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device="cuda"):
|
||||
def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device="cpu"):
|
||||
if grid_size == 1:
|
||||
return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2], device=device)[
|
||||
None, None
|
||||
|
@ -34,7 +34,7 @@ def normalize(d):
|
||||
return out
|
||||
|
||||
|
||||
def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"):
|
||||
def meshgrid2d(B, Y, X, stack=False, norm=False, device="cpu"):
|
||||
# returns a meshgrid sized B x Y x X
|
||||
|
||||
grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device))
|
||||
|
Loading…
Reference in New Issue
Block a user