add torch.hub support

This commit is contained in:
nikitakaraevv 2023-07-21 07:30:44 -07:00
parent d7d1e92742
commit 24054be360
5 changed files with 126 additions and 58 deletions

View File

@ -12,6 +12,8 @@ from cotracker.models.core.cotracker.cotracker import CoTracker
def build_cotracker( def build_cotracker(
checkpoint: str, checkpoint: str,
): ):
if checkpoint is None:
return build_cotracker_stride_4_wind_8()
model_name = checkpoint.split("/")[-1].split(".")[0] model_name = checkpoint.split("/")[-1].split(".")[0]
if model_name == "cotracker_stride_4_wind_8": if model_name == "cotracker_stride_4_wind_8":
return build_cotracker_stride_4_wind_8(checkpoint=checkpoint) return build_cotracker_stride_4_wind_8(checkpoint=checkpoint)

View File

@ -25,7 +25,6 @@ class CoTrackerPredictor(torch.nn.Module):
model = build_cotracker(checkpoint) model = build_cotracker(checkpoint)
self.model = model self.model = model
self.model.to("cuda")
self.model.eval() self.model.eval()
@torch.no_grad() @torch.no_grad()
@ -72,7 +71,7 @@ class CoTrackerPredictor(torch.nn.Module):
grid_width = W // grid_step grid_width = W // grid_step
grid_height = H // grid_step grid_height = H // grid_step
tracks = visibilities = None tracks = visibilities = None
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to("cuda") grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
grid_pts[0, :, 0] = grid_query_frame grid_pts[0, :, 0] = grid_query_frame
for offset in tqdm(range(grid_step * grid_step)): for offset in tqdm(range(grid_step * grid_step)):
ox = offset % grid_step ox = offset % grid_step
@ -107,10 +106,8 @@ class CoTrackerPredictor(torch.nn.Module):
assert B == 1 assert B == 1
video = video.reshape(B * T, C, H, W) video = video.reshape(B * T, C, H, W)
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").cuda() video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear")
video = video.reshape( video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
B, T, 3, self.interp_shape[0], self.interp_shape[1]
).cuda()
if queries is not None: if queries is not None:
queries = queries.clone() queries = queries.clone()

View File

@ -55,6 +55,11 @@ if __name__ == "__main__":
segm_mask = torch.from_numpy(segm_mask)[None, None] segm_mask = torch.from_numpy(segm_mask)[None, None]
model = CoTrackerPredictor(checkpoint=args.checkpoint) model = CoTrackerPredictor(checkpoint=args.checkpoint)
if torch.cuda.is_available():
model = model.cuda()
video = video.cuda()
else:
print("CUDA is not available!")
pred_tracks, pred_visibility = model( pred_tracks, pred_visibility = model(
video, video,

32
hubconf.py Normal file
View File

@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
dependencies = ["torch", "einops", "timm", "tqdm"]
_COTRACKER_URL = (
"https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth"
)
def _make_cotracker_predictor(*, pretrained: bool = True, **kwargs):
from cotracker.predictor import CoTrackerPredictor
predictor = CoTrackerPredictor(checkpoint=None)
if pretrained:
state_dict = torch.hub.load_state_dict_from_url(
_COTRACKER_URL, map_location="cpu"
)
predictor.model.load_state_dict(state_dict)
return predictor
def cotracker_w8(*, pretrained: bool = True, **kwargs):
"""
CoTracker model with stride 4 and window length 8. (The main model from the paper)
"""
return _make_cotracker_predictor(pretrained=pretrained, **kwargs)

File diff suppressed because one or more lines are too long