add some comments

This commit is contained in:
Hanzhang ma 2024-07-10 00:05:34 +02:00
parent 9ed8669a50
commit 36d1566750

View File

@ -17,6 +17,7 @@ class CoTrackerPredictor(torch.nn.Module):
self.support_grid_size = 6
model = build_cotracker(checkpoint)
self.interp_shape = model.model_resolution
print(self.interp_shape)
self.model = model
self.model.eval()
@ -103,12 +104,16 @@ class CoTrackerPredictor(torch.nn.Module):
B, T, C, H, W = video.shape
video = video.reshape(B * T, C, H, W)
# ? what is interpolate?
# 将video插值成interp_shape?
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
if queries is not None:
B, N, D = queries.shape
B, N, D = queries.shape # batch_size, number of points, (t,x,y)
assert D == 3
# query 缩放到( interp_shape - 1 ) / (W - 1)
# 插完值之后缩放
queries = queries.clone()
queries[:, :, 1:] *= queries.new_tensor(
[
@ -116,6 +121,7 @@ class CoTrackerPredictor(torch.nn.Module):
(self.interp_shape[0] - 1) / (H - 1),
]
)
# 生成grid
elif grid_size > 0:
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
if segm_mask is not None:
@ -131,6 +137,8 @@ class CoTrackerPredictor(torch.nn.Module):
dim=2,
).repeat(B, 1, 1)
# 添加支持点
if add_support_grid:
grid_pts = get_points_on_a_grid(
self.support_grid_size, self.interp_shape, device=video.device