add torch.hub support
This commit is contained in:
parent
d7d1e92742
commit
24054be360
@ -12,6 +12,8 @@ from cotracker.models.core.cotracker.cotracker import CoTracker
|
||||
def build_cotracker(
|
||||
checkpoint: str,
|
||||
):
|
||||
if checkpoint is None:
|
||||
return build_cotracker_stride_4_wind_8()
|
||||
model_name = checkpoint.split("/")[-1].split(".")[0]
|
||||
if model_name == "cotracker_stride_4_wind_8":
|
||||
return build_cotracker_stride_4_wind_8(checkpoint=checkpoint)
|
||||
|
@ -25,7 +25,6 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
model = build_cotracker(checkpoint)
|
||||
|
||||
self.model = model
|
||||
self.model.to("cuda")
|
||||
self.model.eval()
|
||||
|
||||
@torch.no_grad()
|
||||
@ -72,7 +71,7 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
grid_width = W // grid_step
|
||||
grid_height = H // grid_step
|
||||
tracks = visibilities = None
|
||||
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to("cuda")
|
||||
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
|
||||
grid_pts[0, :, 0] = grid_query_frame
|
||||
for offset in tqdm(range(grid_step * grid_step)):
|
||||
ox = offset % grid_step
|
||||
@ -107,10 +106,8 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
assert B == 1
|
||||
|
||||
video = video.reshape(B * T, C, H, W)
|
||||
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").cuda()
|
||||
video = video.reshape(
|
||||
B, T, 3, self.interp_shape[0], self.interp_shape[1]
|
||||
).cuda()
|
||||
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear")
|
||||
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
|
||||
|
||||
if queries is not None:
|
||||
queries = queries.clone()
|
||||
|
5
demo.py
5
demo.py
@ -55,6 +55,11 @@ if __name__ == "__main__":
|
||||
segm_mask = torch.from_numpy(segm_mask)[None, None]
|
||||
|
||||
model = CoTrackerPredictor(checkpoint=args.checkpoint)
|
||||
if torch.cuda.is_available():
|
||||
model = model.cuda()
|
||||
video = video.cuda()
|
||||
else:
|
||||
print("CUDA is not available!")
|
||||
|
||||
pred_tracks, pred_visibility = model(
|
||||
video,
|
||||
|
32
hubconf.py
Normal file
32
hubconf.py
Normal file
@ -0,0 +1,32 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
dependencies = ["torch", "einops", "timm", "tqdm"]
|
||||
|
||||
_COTRACKER_URL = (
|
||||
"https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth"
|
||||
)
|
||||
|
||||
|
||||
def _make_cotracker_predictor(*, pretrained: bool = True, **kwargs):
|
||||
from cotracker.predictor import CoTrackerPredictor
|
||||
|
||||
predictor = CoTrackerPredictor(checkpoint=None)
|
||||
if pretrained:
|
||||
state_dict = torch.hub.load_state_dict_from_url(
|
||||
_COTRACKER_URL, map_location="cpu"
|
||||
)
|
||||
predictor.model.load_state_dict(state_dict)
|
||||
return predictor
|
||||
|
||||
|
||||
def cotracker_w8(*, pretrained: bool = True, **kwargs):
|
||||
"""
|
||||
CoTracker model with stride 4 and window length 8. (The main model from the paper)
|
||||
"""
|
||||
return _make_cotracker_predictor(pretrained=pretrained, **kwargs)
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user