comment the code

This commit is contained in:
Hanzhang ma 2024-07-09 10:54:29 +02:00
parent 19767a9d65
commit eeda4d3c98

View File

@ -23,7 +23,7 @@ class CoTrackerPredictor(torch.nn.Module):
@torch.no_grad() @torch.no_grad()
def forward( def forward(
self, self,
video, # (B, T, 3, H, W) video, # (B, T, 3, H, W) Batch_size, time, rgb, height, width
# input prompt types: # input prompt types:
# - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame. # - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
# *backward_tracking=True* will compute tracks in both directions. # *backward_tracking=True* will compute tracks in both directions.
@ -59,14 +59,23 @@ class CoTrackerPredictor(torch.nn.Module):
*_, H, W = video.shape *_, H, W = video.shape
grid_step = W // grid_size grid_step = W // grid_size
grid_width = W // grid_step grid_width = W // grid_step
grid_height = H // grid_step grid_height = H // grid_step # set the whole video to grid_size number of grids
tracks = visibilities = None tracks = visibilities = None
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device) grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
# grid_pts has shape (1, grid_width * grid_height, 3); the last dimension holds (t, x, y) for each query point
grid_pts[0, :, 0] = grid_query_frame grid_pts[0, :, 0] = grid_query_frame
# iterate over every pixel offset (ox, oy) inside one grid cell: grid_step * grid_step offsets in total
for offset in range(grid_step * grid_step): for offset in range(grid_step * grid_step):
print(f"step {offset} / {grid_step * grid_step}") print(f"step {offset} / {grid_step * grid_step}")
ox = offset % grid_step ox = offset % grid_step
oy = offset // grid_step oy = offset // grid_step
# fill in the x-coordinates of the query points for this offset
# for example:
# grid_width = 4, grid_height = 4, grid_step = 10, ox = 1
# torch.arange(grid_width) = [0,1,2,3]
# torch.arange(grid_width).repeat(grid_height) = [0,1,2,3,0,1,2,3,0,1,2,3,0,1,2,3]
# torch.arange(grid_width).repeat(grid_height) * grid_step = [0,10,20,30,0,10,20,30,0,10,20,30,0,10,20,30]
# adding ox gives the x-location in the image: [1,11,21,31,1,11,21,31,1,11,21,31,1,11,21,31]
grid_pts[0, :, 1] = torch.arange(grid_width).repeat(grid_height) * grid_step + ox grid_pts[0, :, 1] = torch.arange(grid_width).repeat(grid_height) * grid_step + ox
grid_pts[0, :, 2] = ( grid_pts[0, :, 2] = (
torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy