readme.md update, demo flexible save path (#83)

This commit is contained in:
Iurii Makarov 2024-05-11 15:34:09 +01:00 committed by GitHub
parent 0f9d32869a
commit e29e938311
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 10 additions and 4 deletions

View File

@ -119,7 +119,7 @@ We strongly recommend installing both PyTorch and TorchVision with CUDA support,
git clone https://github.com/facebookresearch/co-tracker
cd co-tracker
pip install -e .
pip install matplotlib flow_vis tqdm tensorboard
pip install matplotlib flow_vis tqdm tensorboard imageio[ffmpeg]
```
You can manually download the CoTracker2 checkpoint from the links below and place it in the `checkpoints` folder as follows:
@ -132,6 +132,11 @@ cd ..
```
For old checkpoints, see [this section](#previous-version).
After installation, this is how you could run the model on `./assets/apple.mp4` (results will be saved to `./saved_videos/apple.mp4`):
```bash
python demo.py --checkpoint checkpoints/cotracker2.pth
```
## Evaluation
To reproduce the results presented in the paper, download the following datasets:

View File

@ -83,11 +83,12 @@ if __name__ == "__main__":
print("computed")
# save a video with predicted tracks
seq_name = args.video_path.split("/")[-1]
seq_name = os.path.splitext(args.video_path.split("/")[-1])[0]
vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
vis.visualize(
video,
pred_tracks,
pred_visibility,
query_frame=0 if args.backward_tracking else args.grid_query_frame,
filename=seq_name,
)

View File

@ -97,7 +97,7 @@ if __name__ == "__main__":
print("Tracks are computed")
# save a video with predicted tracks
seq_name = args.video_path.split("/")[-1]
seq_name = os.path.splitext(args.video_path.split("/")[-1])[0]
video = torch.tensor(np.stack(window_frames), device=DEFAULT_DEVICE).permute(0, 3, 1, 2)[None]
vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame)
vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame, filename=seq_name)