add torch.hub support

This commit is contained in:
nikitakaraevv
2023-07-21 07:30:44 -07:00
parent d7d1e92742
commit 24054be360
5 changed files with 126 additions and 58 deletions

View File

@@ -12,6 +12,8 @@ from cotracker.models.core.cotracker.cotracker import CoTracker
def build_cotracker(
checkpoint: str,
):
if checkpoint is None:
return build_cotracker_stride_4_wind_8()
model_name = checkpoint.split("/")[-1].split(".")[0]
if model_name == "cotracker_stride_4_wind_8":
return build_cotracker_stride_4_wind_8(checkpoint=checkpoint)

View File

@@ -25,7 +25,6 @@ class CoTrackerPredictor(torch.nn.Module):
model = build_cotracker(checkpoint)
self.model = model
self.model.to("cuda")
self.model.eval()
@torch.no_grad()
@@ -72,7 +71,7 @@ class CoTrackerPredictor(torch.nn.Module):
grid_width = W // grid_step
grid_height = H // grid_step
tracks = visibilities = None
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to("cuda")
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
grid_pts[0, :, 0] = grid_query_frame
for offset in tqdm(range(grid_step * grid_step)):
ox = offset % grid_step
@@ -107,10 +106,8 @@ class CoTrackerPredictor(torch.nn.Module):
assert B == 1
video = video.reshape(B * T, C, H, W)
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").cuda()
video = video.reshape(
B, T, 3, self.interp_shape[0], self.interp_shape[1]
).cuda()
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear")
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
if queries is not None:
queries = queries.clone()