diff --git a/README.md b/README.md index c132d81..4f483f2 100644 --- a/README.md +++ b/README.md @@ -119,7 +119,7 @@ We strongly recommend installing both PyTorch and TorchVision with CUDA support, git clone https://github.com/facebookresearch/co-tracker cd co-tracker pip install -e . -pip install matplotlib flow_vis tqdm tensorboard +pip install matplotlib flow_vis tqdm tensorboard imageio[ffmpeg] ``` You can manually download the CoTracker2 checkpoint from the links below and place it in the `checkpoints` folder as follows: @@ -132,6 +132,11 @@ cd .. ``` For old checkpoints, see [this section](#previous-version). +After installation, this is how you could run the model on `./assets/apple.mp4` (results will be saved to `./saved_videos/apple.mp4`): +```bash +python demo.py --checkpoint checkpoints/cotracker2.pth +``` + ## Evaluation To reproduce the results presented in the paper, download the following datasets: diff --git a/demo.py b/demo.py index 94a87f7..2f26587 100644 --- a/demo.py +++ b/demo.py @@ -83,11 +83,12 @@ if __name__ == "__main__": print("computed") # save a video with predicted tracks - seq_name = args.video_path.split("/")[-1] + seq_name = os.path.splitext(args.video_path.split("/")[-1])[0] vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3) vis.visualize( video, pred_tracks, pred_visibility, query_frame=0 if args.backward_tracking else args.grid_query_frame, + filename=seq_name, ) diff --git a/online_demo.py b/online_demo.py index 7aad145..d1f4321 100644 --- a/online_demo.py +++ b/online_demo.py @@ -97,7 +97,7 @@ if __name__ == "__main__": print("Tracks are computed") # save a video with predicted tracks - seq_name = args.video_path.split("/")[-1] + seq_name = os.path.splitext(args.video_path.split("/")[-1])[0] video = torch.tensor(np.stack(window_frames), device=DEFAULT_DEVICE).permute(0, 3, 1, 2)[None] vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3) - vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame) + vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame, filename=seq_name)