comment the code
This commit is contained in:
parent
19767a9d65
commit
eeda4d3c98
@ -23,7 +23,7 @@ class CoTrackerPredictor(torch.nn.Module):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
video, # (B, T, 3, H, W)
|
video, # (B, T, 3, H, W) Batch_size, time, rgb, height, width
|
||||||
# input prompt types:
|
# input prompt types:
|
||||||
# - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
|
# - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
|
||||||
# *backward_tracking=True* will compute tracks in both directions.
|
# *backward_tracking=True* will compute tracks in both directions.
|
||||||
@ -59,14 +59,23 @@ class CoTrackerPredictor(torch.nn.Module):
|
|||||||
*_, H, W = video.shape
|
*_, H, W = video.shape
|
||||||
grid_step = W // grid_size
|
grid_step = W // grid_size
|
||||||
grid_width = W // grid_step
|
grid_width = W // grid_step
|
||||||
grid_height = H // grid_step
|
grid_height = H // grid_step # set the whole video to grid_size number of grids
|
||||||
tracks = visibilities = None
|
tracks = visibilities = None
|
||||||
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
|
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
|
||||||
|
# (batch_size, grid_number, t,x,y)
|
||||||
grid_pts[0, :, 0] = grid_query_frame
|
grid_pts[0, :, 0] = grid_query_frame
|
||||||
|
# iterate every grid
|
||||||
for offset in range(grid_step * grid_step):
|
for offset in range(grid_step * grid_step):
|
||||||
print(f"step {offset} / {grid_step * grid_step}")
|
print(f"step {offset} / {grid_step * grid_step}")
|
||||||
ox = offset % grid_step
|
ox = offset % grid_step
|
||||||
oy = offset // grid_step
|
oy = offset // grid_step
|
||||||
|
# initialize
|
||||||
|
# for example
|
||||||
|
# grid width = 4, grid height = 4, grid step = 10, ox = 1
|
||||||
|
# torch.arange(grid_width) = [0,1,2,3]
|
||||||
|
# torch.arange(grid_width).repeat(grid_height) = [0,1,2,3,0,1,2,3,0,1,2,3]
|
||||||
|
# torch.arange(grid_width).repeat(grid_height) * grid_step = [0,10,20,30,0,10,20,30,0,10,20,30]
|
||||||
|
# get the location in the image
|
||||||
grid_pts[0, :, 1] = torch.arange(grid_width).repeat(grid_height) * grid_step + ox
|
grid_pts[0, :, 1] = torch.arange(grid_width).repeat(grid_height) * grid_step + ox
|
||||||
grid_pts[0, :, 2] = (
|
grid_pts[0, :, 2] = (
|
||||||
torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
|
torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
|
||||||
|
Loading…
Reference in New Issue
Block a user