add some comments
This commit is contained in:
parent
9ed8669a50
commit
36d1566750
@ -17,6 +17,7 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
self.support_grid_size = 6
|
||||
model = build_cotracker(checkpoint)
|
||||
self.interp_shape = model.model_resolution
|
||||
print(self.interp_shape)
|
||||
self.model = model
|
||||
self.model.eval()
|
||||
|
||||
@ -103,12 +104,16 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
B, T, C, H, W = video.shape
|
||||
|
||||
video = video.reshape(B * T, C, H, W)
|
||||
# ? what is interpolate?
|
||||
# 将video插值成interp_shape?
|
||||
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
|
||||
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
|
||||
|
||||
if queries is not None:
|
||||
B, N, D = queries.shape
|
||||
B, N, D = queries.shape # batch_size, number of points, (t,x,y)
|
||||
assert D == 3
|
||||
# query 缩放到( interp_shape - 1 ) / (W - 1)
|
||||
# 插完值之后缩放
|
||||
queries = queries.clone()
|
||||
queries[:, :, 1:] *= queries.new_tensor(
|
||||
[
|
||||
@ -116,6 +121,7 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
(self.interp_shape[0] - 1) / (H - 1),
|
||||
]
|
||||
)
|
||||
# 生成grid
|
||||
elif grid_size > 0:
|
||||
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
|
||||
if segm_mask is not None:
|
||||
@ -131,6 +137,8 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
dim=2,
|
||||
).repeat(B, 1, 1)
|
||||
|
||||
# 添加支持点
|
||||
|
||||
if add_support_grid:
|
||||
grid_pts = get_points_on_a_grid(
|
||||
self.support_grid_size, self.interp_shape, device=video.device
|
||||
|
Loading…
Reference in New Issue
Block a user