add cpu-only mode
This commit is contained in:
parent
32aedaf9b6
commit
ab0ce3c977
@ -185,7 +185,11 @@ class Evaluator:
|
|||||||
if not all(gotit):
|
if not all(gotit):
|
||||||
print("batch is None")
|
print("batch is None")
|
||||||
continue
|
continue
|
||||||
dataclass_to_cuda_(sample)
|
if torch.cuda.is_available():
|
||||||
|
dataclass_to_cuda_(sample)
|
||||||
|
device = torch.device("cuda")
|
||||||
|
else:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not train_mode
|
not train_mode
|
||||||
@ -205,7 +209,7 @@ class Evaluator:
|
|||||||
queries[:, :, 1],
|
queries[:, :, 1],
|
||||||
],
|
],
|
||||||
dim=2,
|
dim=2,
|
||||||
)
|
).to(device)
|
||||||
else:
|
else:
|
||||||
queries = torch.cat(
|
queries = torch.cat(
|
||||||
[
|
[
|
||||||
@ -213,7 +217,7 @@ class Evaluator:
|
|||||||
sample.trajectory[:, 0],
|
sample.trajectory[:, 0],
|
||||||
],
|
],
|
||||||
dim=2,
|
dim=2,
|
||||||
)
|
).to(device)
|
||||||
|
|
||||||
pred_tracks = model(sample.video, queries)
|
pred_tracks = model(sample.video, queries)
|
||||||
if "strided" in dataset_name:
|
if "strided" in dataset_name:
|
||||||
|
@ -102,6 +102,8 @@ def run_eval(cfg: DefaultConfig):
|
|||||||
single_point=cfg.single_point,
|
single_point=cfg.single_point,
|
||||||
n_iters=cfg.n_iters,
|
n_iters=cfg.n_iters,
|
||||||
)
|
)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
predictor.model = predictor.model.cuda()
|
||||||
|
|
||||||
# Setting the random seeds
|
# Setting the random seeds
|
||||||
torch.manual_seed(cfg.seed)
|
torch.manual_seed(cfg.seed)
|
||||||
|
@ -25,14 +25,14 @@ from cotracker.models.core.embeddings import (
|
|||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
|
|
||||||
|
|
||||||
def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)):
|
def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device="cuda"):
|
||||||
if grid_size == 1:
|
if grid_size == 1:
|
||||||
return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2])[
|
return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2], device=device)[
|
||||||
None, None
|
None, None
|
||||||
].cuda()
|
]
|
||||||
|
|
||||||
grid_y, grid_x = meshgrid2d(
|
grid_y, grid_x = meshgrid2d(
|
||||||
1, grid_size, grid_size, stack=False, norm=False, device="cuda"
|
1, grid_size, grid_size, stack=False, norm=False, device=device
|
||||||
)
|
)
|
||||||
step = interp_shape[1] // 64
|
step = interp_shape[1] // 64
|
||||||
if grid_center[0] != 0 or grid_center[1] != 0:
|
if grid_center[0] != 0 or grid_center[1] != 0:
|
||||||
@ -47,7 +47,7 @@ def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)):
|
|||||||
|
|
||||||
grid_y = grid_y + grid_center[0]
|
grid_y = grid_y + grid_center[0]
|
||||||
grid_x = grid_x + grid_center[1]
|
grid_x = grid_x + grid_center[1]
|
||||||
xy = torch.stack([grid_x, grid_y], dim=-1).cuda()
|
xy = torch.stack([grid_x, grid_y], dim=-1).to(device)
|
||||||
return xy
|
return xy
|
||||||
|
|
||||||
|
|
||||||
|
@ -29,11 +29,10 @@ class EvaluationPredictor(torch.nn.Module):
|
|||||||
self.n_iters = n_iters
|
self.n_iters = n_iters
|
||||||
|
|
||||||
self.model = cotracker_model
|
self.model = cotracker_model
|
||||||
self.model.to("cuda")
|
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
def forward(self, video, queries):
|
def forward(self, video, queries):
|
||||||
queries = queries.clone().cuda()
|
queries = queries.clone()
|
||||||
B, T, C, H, W = video.shape
|
B, T, C, H, W = video.shape
|
||||||
B, N, D = queries.shape
|
B, N, D = queries.shape
|
||||||
|
|
||||||
@ -42,14 +41,16 @@ class EvaluationPredictor(torch.nn.Module):
|
|||||||
|
|
||||||
rgbs = video.reshape(B * T, C, H, W)
|
rgbs = video.reshape(B * T, C, H, W)
|
||||||
rgbs = F.interpolate(rgbs, tuple(self.interp_shape), mode="bilinear")
|
rgbs = F.interpolate(rgbs, tuple(self.interp_shape), mode="bilinear")
|
||||||
rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]).cuda()
|
rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
|
||||||
|
|
||||||
|
device = rgbs.device
|
||||||
|
|
||||||
queries[:, :, 1] *= self.interp_shape[1] / W
|
queries[:, :, 1] *= self.interp_shape[1] / W
|
||||||
queries[:, :, 2] *= self.interp_shape[0] / H
|
queries[:, :, 2] *= self.interp_shape[0] / H
|
||||||
|
|
||||||
if self.single_point:
|
if self.single_point:
|
||||||
traj_e = torch.zeros((B, T, N, 2)).cuda()
|
traj_e = torch.zeros((B, T, N, 2), device=device)
|
||||||
vis_e = torch.zeros((B, T, N)).cuda()
|
vis_e = torch.zeros((B, T, N), device=device)
|
||||||
for pind in range((N)):
|
for pind in range((N)):
|
||||||
query = queries[:, pind : pind + 1]
|
query = queries[:, pind : pind + 1]
|
||||||
|
|
||||||
@ -60,8 +61,10 @@ class EvaluationPredictor(torch.nn.Module):
|
|||||||
vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
|
vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
|
||||||
else:
|
else:
|
||||||
if self.grid_size > 0:
|
if self.grid_size > 0:
|
||||||
xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
|
xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device)
|
||||||
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda() #
|
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(
|
||||||
|
device
|
||||||
|
) #
|
||||||
queries = torch.cat([queries, xy], dim=1) #
|
queries = torch.cat([queries, xy], dim=1) #
|
||||||
|
|
||||||
traj_e, __, vis_e, __ = self.model(
|
traj_e, __, vis_e, __ = self.model(
|
||||||
@ -91,8 +94,8 @@ class EvaluationPredictor(torch.nn.Module):
|
|||||||
query = torch.cat([query, xy_target], dim=1).to(device) #
|
query = torch.cat([query, xy_target], dim=1).to(device) #
|
||||||
|
|
||||||
if self.grid_size > 0:
|
if self.grid_size > 0:
|
||||||
xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
|
xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device)
|
||||||
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda() #
|
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
|
||||||
query = torch.cat([query, xy], dim=1).to(device) #
|
query = torch.cat([query, xy], dim=1).to(device) #
|
||||||
# crop the video to start from the queried frame
|
# crop the video to start from the queried frame
|
||||||
query[0, 0, 0] = 0
|
query[0, 0, 0] = 0
|
||||||
|
@ -116,7 +116,7 @@ class CoTrackerPredictor(torch.nn.Module):
|
|||||||
queries[:, :, 1] *= self.interp_shape[1] / W
|
queries[:, :, 1] *= self.interp_shape[1] / W
|
||||||
queries[:, :, 2] *= self.interp_shape[0] / H
|
queries[:, :, 2] *= self.interp_shape[0] / H
|
||||||
elif grid_size > 0:
|
elif grid_size > 0:
|
||||||
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape)
|
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
|
||||||
if segm_mask is not None:
|
if segm_mask is not None:
|
||||||
segm_mask = F.interpolate(
|
segm_mask = F.interpolate(
|
||||||
segm_mask, tuple(self.interp_shape), mode="nearest"
|
segm_mask, tuple(self.interp_shape), mode="nearest"
|
||||||
|
@ -65,26 +65,10 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": 2,
|
||||||
"id": "1745a859-71d4-4ec3-8ef3-027cabe786d4",
|
"id": "1745a859-71d4-4ec3-8ef3-027cabe786d4",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"/private/home/nikitakaraev/dev/neurips_2023/co-tracker\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"/private/home/nikitakaraev/.conda/envs/stereoformer/lib/python3.8/site-packages/requests/__init__.py:109: RequestsDependencyWarning: urllib3 (1.26.14) or chardet (None)/charset_normalizer (3.2.0) doesn't match a supported version!\n",
|
|
||||||
" warnings.warn(\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"%cd ..\n",
|
"%cd ..\n",
|
||||||
"import os\n",
|
"import os\n",
|
||||||
@ -105,7 +89,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 3,
|
||||||
"id": "f1f9ca4d-951e-49d2-8844-91f7bcadfecd",
|
"id": "f1f9ca4d-951e-49d2-8844-91f7bcadfecd",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -116,7 +100,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 4,
|
||||||
"id": "fb4c2e9d-0e85-4c10-81a2-827d0759bf87",
|
"id": "fb4c2e9d-0e85-4c10-81a2-827d0759bf87",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -129,7 +113,7 @@
|
|||||||
"<IPython.core.display.HTML object>"
|
"<IPython.core.display.HTML object>"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 3,
|
"execution_count": 4,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -175,8 +159,8 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"if torch.cuda.is_available():\n",
|
"if torch.cuda.is_available():\n",
|
||||||
" model=model.cuda()\n",
|
" model = model.cuda()\n",
|
||||||
" video=video.cuda()"
|
" video = video.cuda()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -282,7 +266,9 @@
|
|||||||
" [10., 600., 500.], # frame number 10\n",
|
" [10., 600., 500.], # frame number 10\n",
|
||||||
" [20., 750., 600.], # ...\n",
|
" [20., 750., 600.], # ...\n",
|
||||||
" [30., 900., 200.]\n",
|
" [30., 900., 200.]\n",
|
||||||
"]).cuda()"
|
"])\n",
|
||||||
|
"if torch.cuda.is_available():\n",
|
||||||
|
" queries = queries.cuda()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
2
train.py
2
train.py
@ -138,6 +138,8 @@ def run_test_eval(evaluator, model, dataloaders, writer, step):
|
|||||||
single_point=False,
|
single_point=False,
|
||||||
n_iters=6,
|
n_iters=6,
|
||||||
)
|
)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
predictor.model = predictor.model.cuda()
|
||||||
|
|
||||||
metrics = evaluator.evaluate_sequence(
|
metrics = evaluator.evaluate_sequence(
|
||||||
model=predictor,
|
model=predictor,
|
||||||
|
Loading…
Reference in New Issue
Block a user