add cpu-only mode

Author: nikitakaraevv
Date: 2023-07-21 13:41:52 -07:00
Parent: 32aedaf9b6
Commit: ab0ce3c977
7 changed files with 39 additions and 42 deletions
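
Every change below applies the same device-agnostic pattern: choose a torch.device once from torch.cuda.is_available(), then create tensors with device=... or move them with .to(device) instead of hard-coding .cuda(). A minimal, self-contained sketch of that pattern (illustrative code, not taken from the diff):

    import torch

    # Pick the device once; everything downstream follows it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    x = torch.zeros((2, 3), device=device)  # allocated directly on the chosen device
    y = torch.ones(2, 3).to(device)         # or moved there after creation
    print(x.device, y.device)               # both report the same device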


@@ -185,7 +185,11 @@ class Evaluator:
                 if not all(gotit):
                     print("batch is None")
                     continue
-            dataclass_to_cuda_(sample)
+            if torch.cuda.is_available():
+                dataclass_to_cuda_(sample)
+                device = torch.device("cuda")
+            else:
+                device = torch.device("cpu")
 
             if (
                 not train_mode
@@ -205,7 +209,7 @@ class Evaluator:
                         queries[:, :, 1],
                     ],
                     dim=2,
-                )
+                ).to(device)
             else:
                 queries = torch.cat(
                     [
@@ -213,7 +217,7 @@ class Evaluator:
                         sample.trajectory[:, 0],
                     ],
                     dim=2,
-                )
+                ).to(device)
             pred_tracks = model(sample.video, queries)
             if "strided" in dataset_name:
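
The Evaluator now guards the dataclass_to_cuda_(sample) call and records the chosen device so the concatenated queries end up next to the video tensor. The repo's dataclass_to_cuda_ implementation is not shown in this diff; below is a hypothetical device-generic equivalent (the helper name and Sample fields are invented for illustration):

    import dataclasses
    import torch

    @dataclasses.dataclass
    class Sample:  # stand-in for the repo's sample dataclass
        video: torch.Tensor
        trajectory: torch.Tensor

    def dataclass_to_device_(obj, device):
        # Hypothetical helper: move every tensor field to `device`, in place.
        for f in dataclasses.fields(obj):
            v = getattr(obj, f.name)
            if torch.is_tensor(v):
                setattr(obj, f.name, v.to(device))
        return obj

    sample = Sample(torch.zeros(1, 8, 3, 64, 64), torch.zeros(1, 8, 4, 2))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataclass_to_device_(sample, device)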


@@ -102,6 +102,8 @@ def run_eval(cfg: DefaultConfig):
         single_point=cfg.single_point,
         n_iters=cfg.n_iters,
     )
+    if torch.cuda.is_available():
+        predictor.model = predictor.model.cuda()
 
     # Setting the random seeds
     torch.manual_seed(cfg.seed)
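
run_eval now moves the predictor's model to the GPU only when one exists; on a CPU-only host the model simply stays where it was built. The same guard in isolation, with a placeholder module:

    import torch

    model = torch.nn.Linear(4, 2)  # placeholder for the CoTracker predictor
    if torch.cuda.is_available():
        model = model.cuda()       # skipped entirely on CPU-only machines
    print(next(model.parameters()).device)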


@@ -25,14 +25,14 @@ from cotracker.models.core.embeddings import (
 torch.manual_seed(0)
 
 
-def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)):
+def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device="cuda"):
     if grid_size == 1:
-        return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2])[
+        return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2], device=device)[
             None, None
-        ].cuda()
+        ]
 
     grid_y, grid_x = meshgrid2d(
-        1, grid_size, grid_size, stack=False, norm=False, device="cuda"
+        1, grid_size, grid_size, stack=False, norm=False, device=device
     )
     step = interp_shape[1] // 64
     if grid_center[0] != 0 or grid_center[1] != 0:
@@ -47,7 +47,7 @@ def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)):
     grid_y = grid_y + grid_center[0]
     grid_x = grid_x + grid_center[1]
 
-    xy = torch.stack([grid_x, grid_y], dim=-1).cuda()
+    xy = torch.stack([grid_x, grid_y], dim=-1).to(device)
     return xy
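
get_points_on_a_grid gains a device argument so the caller decides placement; the default stays "cuda" to keep old call sites working, and tensors are built directly on the target device instead of being moved with .cuda() afterwards. A simplified, self-contained analogue of the same idea (points_on_grid is an invented name; assumes PyTorch >= 1.10 for the indexing argument of torch.meshgrid):

    import torch

    def points_on_grid(grid_size, shape, device="cpu"):
        # Build the grid on `device` from the start - no .cuda() needed.
        ys = torch.linspace(0, shape[0] - 1, grid_size, device=device)
        xs = torch.linspace(0, shape[1] - 1, grid_size, device=device)
        gy, gx = torch.meshgrid(ys, xs, indexing="ij")
        return torch.stack([gx, gy], dim=-1).reshape(1, -1, 2)

    pts = points_on_grid(3, (256, 256))            # CPU by default
    # pts = points_on_grid(3, (256, 256), "cuda")  # identical call on a GPU host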


@@ -29,11 +29,10 @@ class EvaluationPredictor(torch.nn.Module):
         self.n_iters = n_iters
 
         self.model = cotracker_model
-        self.model.to("cuda")
         self.model.eval()
 
     def forward(self, video, queries):
-        queries = queries.clone().cuda()
+        queries = queries.clone()
         B, T, C, H, W = video.shape
         B, N, D = queries.shape
@@ -42,14 +41,16 @@ class EvaluationPredictor(torch.nn.Module):
         rgbs = video.reshape(B * T, C, H, W)
         rgbs = F.interpolate(rgbs, tuple(self.interp_shape), mode="bilinear")
-        rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]).cuda()
+        rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
+
+        device = rgbs.device
 
         queries[:, :, 1] *= self.interp_shape[1] / W
         queries[:, :, 2] *= self.interp_shape[0] / H
 
         if self.single_point:
-            traj_e = torch.zeros((B, T, N, 2)).cuda()
-            vis_e = torch.zeros((B, T, N)).cuda()
+            traj_e = torch.zeros((B, T, N, 2), device=device)
+            vis_e = torch.zeros((B, T, N), device=device)
             for pind in range((N)):
                 query = queries[:, pind : pind + 1]
@@ -60,8 +61,10 @@ class EvaluationPredictor(torch.nn.Module):
                     vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
         else:
             if self.grid_size > 0:
-                xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
-                xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda()  #
+                xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device)
+                xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(
+                    device
+                )  #
                 queries = torch.cat([queries, xy], dim=1)  #
 
             traj_e, __, vis_e, __ = self.model(
@@ -91,8 +94,8 @@ class EvaluationPredictor(torch.nn.Module):
             query = torch.cat([query, xy_target], dim=1).to(device)  #
 
         if self.grid_size > 0:
-            xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
-            xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda()  #
+            xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device)
+            xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device)  #
             query = torch.cat([query, xy], dim=1).to(device)  #
         # crop the video to start from the queried frame
         query[0, 0, 0] = 0
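
EvaluationPredictor no longer pushes its model and inputs onto CUDA; it reads device = rgbs.device and allocates the trajectory and visibility buffers there, so the whole forward pass follows wherever the caller placed the video. The same idea in a reduced sketch (alloc_buffers and the shapes are illustrative):

    import torch

    def alloc_buffers(video, n_points=8):
        # Follow the input's placement, as EvaluationPredictor.forward now does.
        device = video.device
        B, T = video.shape[:2]
        traj_e = torch.zeros((B, T, n_points, 2), device=device)
        vis_e = torch.zeros((B, T, n_points), device=device)
        return traj_e, vis_e

    traj, vis = alloc_buffers(torch.zeros(1, 5, 3, 64, 64))  # CPU in, CPU out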


@@ -116,7 +116,7 @@ class CoTrackerPredictor(torch.nn.Module):
             queries[:, :, 1] *= self.interp_shape[1] / W
             queries[:, :, 2] *= self.interp_shape[0] / H
         elif grid_size > 0:
-            grid_pts = get_points_on_a_grid(grid_size, self.interp_shape)
+            grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
             if segm_mask is not None:
                 segm_mask = F.interpolate(
                     segm_mask, tuple(self.interp_shape), mode="nearest"
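
CoTrackerPredictor likewise builds its query grid on video.device, so the grid points automatically match the input's placement. In miniature:

    import torch

    video = torch.zeros(1, 8, 3, 384, 512)  # stays on CPU unless the caller moves it
    grid_pts = torch.rand(1, 25, 2, device=video.device)  # aux points follow the video
    print(grid_pts.device)  # cpu here; cuda:0 if `video` had been moved to a GPU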


@@ -65,26 +65,10 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 1,
+    "execution_count": 2,
     "id": "1745a859-71d4-4ec3-8ef3-027cabe786d4",
     "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "/private/home/nikitakaraev/dev/neurips_2023/co-tracker\n"
-      ]
-     },
-     {
-      "name": "stderr",
-      "output_type": "stream",
-      "text": [
-       "/private/home/nikitakaraev/.conda/envs/stereoformer/lib/python3.8/site-packages/requests/__init__.py:109: RequestsDependencyWarning: urllib3 (1.26.14) or chardet (None)/charset_normalizer (3.2.0) doesn't match a supported version!\n",
-       " warnings.warn(\n"
-      ]
-     }
-    ],
+    "outputs": [],
     "source": [
      "%cd ..\n",
      "import os\n",
@@ -105,7 +89,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 2,
+    "execution_count": 3,
     "id": "f1f9ca4d-951e-49d2-8844-91f7bcadfecd",
     "metadata": {},
     "outputs": [],
@@ -116,7 +100,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 3,
+    "execution_count": 4,
     "id": "fb4c2e9d-0e85-4c10-81a2-827d0759bf87",
     "metadata": {},
     "outputs": [
@@ -129,7 +113,7 @@
       "<IPython.core.display.HTML object>"
      ]
     },
-    "execution_count": 3,
+    "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -175,8 +159,8 @@
     "outputs": [],
     "source": [
      "if torch.cuda.is_available():\n",
-     "    model=model.cuda()\n",
-     "    video=video.cuda()"
+     "    model = model.cuda()\n",
+     "    video = video.cuda()"
    ]
   },
   {
@@ -282,7 +266,9 @@
     "    [10., 600., 500.],  # frame number 10\n",
     "    [20., 750., 600.],  # ...\n",
     "    [30., 900., 200.]\n",
-    "]).cuda()"
+    "])\n",
+    "if torch.cuda.is_available():\n",
+    "    queries = queries.cuda()"
    ]
   },
   {
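
The demo notebook now guards its .cuda() calls with the same availability check (and drops stale cell outputs, which is why the execution counts shift). The guarded query construction from the last hunk, as plain Python (rows copied from the hunk; earlier rows of the tensor are not visible in this diff):

    import torch

    queries = torch.tensor([
        [10., 600., 500.],  # frame number 10, then x, y
        [20., 750., 600.],
        [30., 900., 200.],
    ])
    if torch.cuda.is_available():
        queries = queries.cuda()  # built on CPU, moved only when a GPU exists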


@@ -138,6 +138,8 @@ def run_test_eval(evaluator, model, dataloaders, writer, step):
             single_point=False,
             n_iters=6,
         )
+        if torch.cuda.is_available():
+            predictor.model = predictor.model.cuda()
         metrics = evaluator.evaluate_sequence(
             model=predictor,
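
train.py repeats the guard from evaluate.py around its test-time predictor. Since the same two lines now appear in both files, they could be factored into a small helper; a hypothetical sketch (to_cuda_if_available is not part of the repo):

    import torch

    def to_cuda_if_available(module: torch.nn.Module) -> torch.nn.Module:
        # Invented helper capturing the guard repeated in evaluate.py and train.py.
        return module.cuda() if torch.cuda.is_available() else module

    predictor_model = to_cuda_if_available(torch.nn.Identity())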