fix online demo

This commit is contained in:
Nikita Karaev 2023-12-29 16:12:42 +00:00
parent 721fcc237b
commit 3716e36249

View File

@@ -52,25 +52,33 @@ if __name__ == "__main__":
window_frames = []
def _process_step(window_frames, is_first_step, grid_size):
def _process_step(window_frames, is_first_step, grid_size, grid_query_frame):
video_chunk = (
torch.tensor(np.stack(window_frames[-model.step * 2 :]), device=DEFAULT_DEVICE)
.float()
.permute(0, 3, 1, 2)[None]
) # (1, T, 3, H, W)
return model(video_chunk, is_first_step=is_first_step, grid_size=grid_size)
return model(
video_chunk,
is_first_step=is_first_step,
grid_size=grid_size,
grid_query_frame=grid_query_frame,
)
# Iterating over video frames, processing one window at a time:
is_first_step = True
for i, frame in enumerate(
iio.imiter(
"https://github.com/facebookresearch/co-tracker/blob/main/assets/apple.mp4",
"./assets/apple.mp4",
plugin="FFMPEG",
)
):
if i % model.step == 0 and i != 0:
pred_tracks, pred_visibility = _process_step(
window_frames, is_first_step, grid_size=args.grid_size
window_frames,
is_first_step,
grid_size=args.grid_size,
grid_query_frame=args.grid_query_frame,
)
is_first_step = False
window_frames.append(frame)
@@ -79,6 +87,7 @@ if __name__ == "__main__":
window_frames[-(i % model.step) - model.step - 1 :],
is_first_step,
grid_size=args.grid_size,
grid_query_frame=args.grid_query_frame,
)
print("Tracks are computed")