Merge branch 'main' of github.com:JunkyByte/co-tracker
@@ -185,7 +185,11 @@ class Evaluator:
                 if not all(gotit):
                     print("batch is None")
                     continue
-            dataclass_to_cuda_(sample)
+            if torch.cuda.is_available():
+                dataclass_to_cuda_(sample)
+                device = torch.device("cuda")
+            else:
+                device = torch.device("cpu")
 
             if (
                 not train_mode
@@ -205,7 +209,7 @@ class Evaluator:
                         queries[:, :, 1],
                     ],
                     dim=2,
-                )
+                ).to(device)
             else:
                 queries = torch.cat(
                     [
@@ -213,7 +217,7 @@ class Evaluator:
                         sample.trajectory[:, 0],
                     ],
                     dim=2,
-                )
+                ).to(device)
 
             pred_tracks = model(sample.video, queries)
             if "strided" in dataset_name:
@@ -102,6 +102,8 @@ def run_eval(cfg: DefaultConfig):
         single_point=cfg.single_point,
         n_iters=cfg.n_iters,
     )
+    if torch.cuda.is_available():
+        predictor.model = predictor.model.cuda()
 
     # Setting the random seeds
     torch.manual_seed(cfg.seed)
@@ -12,6 +12,8 @@ from cotracker.models.core.cotracker.cotracker import CoTracker
 def build_cotracker(
     checkpoint: str,
 ):
+    if checkpoint is None:
+        return build_cotracker_stride_4_wind_8()
     model_name = checkpoint.split("/")[-1].split(".")[0]
     if model_name == "cotracker_stride_4_wind_8":
         return build_cotracker_stride_4_wind_8(checkpoint=checkpoint)
@@ -25,11 +25,11 @@ from cotracker.models.core.embeddings import (
 torch.manual_seed(0)
 
 
-def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device='cuda'):
+def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device="cuda"):
     if grid_size == 1:
-        return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2])[
+        return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2], device=device)[
             None, None
-        ].to(device)
+        ]
 
     grid_y, grid_x = meshgrid2d(
         1, grid_size, grid_size, stack=False, norm=False, device=device
@@ -29,11 +29,10 @@ class EvaluationPredictor(torch.nn.Module):
         self.n_iters = n_iters
 
         self.model = cotracker_model
-        self.model.to("cuda")
         self.model.eval()
 
     def forward(self, video, queries):
-        queries = queries.clone().cuda()
+        queries = queries.clone()
         B, T, C, H, W = video.shape
         B, N, D = queries.shape
 
@@ -42,14 +41,16 @@ class EvaluationPredictor(torch.nn.Module):
 
         rgbs = video.reshape(B * T, C, H, W)
         rgbs = F.interpolate(rgbs, tuple(self.interp_shape), mode="bilinear")
-        rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]).cuda()
+        rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
+
+        device = rgbs.device
 
         queries[:, :, 1] *= self.interp_shape[1] / W
         queries[:, :, 2] *= self.interp_shape[0] / H
 
         if self.single_point:
-            traj_e = torch.zeros((B, T, N, 2)).cuda()
-            vis_e = torch.zeros((B, T, N)).cuda()
+            traj_e = torch.zeros((B, T, N, 2), device=device)
+            vis_e = torch.zeros((B, T, N), device=device)
             for pind in range((N)):
                 query = queries[:, pind : pind + 1]
 
@@ -60,8 +61,10 @@ class EvaluationPredictor(torch.nn.Module):
                 vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
         else:
             if self.grid_size > 0:
-                xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
-                xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda()  #
+                xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device)
+                xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(
+                    device
+                )  #
                 queries = torch.cat([queries, xy], dim=1)  #
 
             traj_e, __, vis_e, __ = self.model(
@@ -91,8 +94,8 @@ class EvaluationPredictor(torch.nn.Module):
         query = torch.cat([query, xy_target], dim=1).to(device)  #
 
         if self.grid_size > 0:
-            xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
-            xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda()  #
+            xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device)
+            xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device)  #
             query = torch.cat([query, xy], dim=1).to(device)  #
         # crop the video to start from the queried frame
         query[0, 0, 0] = 0
@@ -25,8 +25,6 @@ class CoTrackerPredictor(torch.nn.Module):
         model = build_cotracker(checkpoint)
 
         self.model = model
-        self.device = device or 'cuda'
-        self.model.to(self.device)
         self.model.eval()
 
     @torch.no_grad()
@@ -73,7 +71,7 @@ class CoTrackerPredictor(torch.nn.Module):
         grid_width = W // grid_step
         grid_height = H // grid_step
         tracks = visibilities = None
-        grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(self.device)
+        grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
         grid_pts[0, :, 0] = grid_query_frame
         for offset in tqdm(range(grid_step * grid_step)):
             ox = offset % grid_step
@@ -108,10 +106,8 @@ class CoTrackerPredictor(torch.nn.Module):
         assert B == 1
 
         video = video.reshape(B * T, C, H, W)
-        video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").to(self.device)
-        video = video.reshape(
-            B, T, 3, self.interp_shape[0], self.interp_shape[1]
-        ).to(self.device)
+        video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear")
+        video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
 
         if queries is not None:
             queries = queries.clone()
@@ -120,7 +116,7 @@ class CoTrackerPredictor(torch.nn.Module):
             queries[:, :, 1] *= self.interp_shape[1] / W
             queries[:, :, 2] *= self.interp_shape[0] / H
         elif grid_size > 0:
-            grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=self.device)
+            grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
         if segm_mask is not None:
             segm_mask = F.interpolate(
                 segm_mask, tuple(self.interp_shape), mode="nearest"
@@ -14,7 +14,6 @@ from matplotlib import cm
 import torch.nn.functional as F
 import torchvision.transforms as transforms
 from moviepy.editor import ImageSequenceClip
-from torch.utils.tensorboard import SummaryWriter
 import matplotlib.pyplot as plt
 
 
@@ -67,7 +66,7 @@ class Visualizer:
         gt_tracks: torch.Tensor = None,  # (B,T,N,2)
         segm_mask: torch.Tensor = None,  # (B,1,H,W)
         filename: str = "video",
-        writer: SummaryWriter = None,
+        writer=None,  # tensorboard Summary Writer, used for visualization during training
         step: int = 0,
         query_frame: int = 0,
         save_video: bool = True,