add colab demo

This commit is contained in:
nikitakaraevv 2023-07-17 18:21:54 -07:00
parent 6d62d873fa
commit 6880e31b5b
4 changed files with 113 additions and 100 deletions

View File

@ -6,6 +6,10 @@
[[`Paper`]()] [[`Project`](https://co-tracker.github.io/)] [[`BibTeX`](#citing-cotracker)]
<a target="_blank" href="https://colab.research.google.com/github/facebookresearch/co-tracker/blob/main/notebooks/demo.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
![bmx-bumps](./assets/bmx-bumps.gif)
**CoTracker** is a fast transformer-based model that can track any point in a video. It brings to tracking some of the benefits of Optical Flow.
@ -15,7 +19,7 @@ CoTracker can track:
- Points sampled on a regular grid on any video frame
- Manually selected points
Try these tracking modes for yourself with our [Colab demo](https://github.com/facebookresearch/co-tracker/notebooks/demo.ipynb).
Try these tracking modes for yourself with our [Colab demo](https://colab.research.google.com/github/facebookresearch/co-tracker/blob/master/notebooks/demo.ipynb).
@ -43,7 +47,7 @@ cd ..
## Running the Demo:
Try our [Colab demo](https://github.com/facebookresearch/co-tracker/notebooks/demo.ipynb) or run a local demo with 10*10 points sampled on a grid on the first frame of a video:
Try our [Colab demo](https://colab.research.google.com/github/facebookresearch/co-tracker/blob/master/notebooks/demo.ipynb) or run a local demo with 10*10 points sampled on a grid on the first frame of a video:
```
python demo.py --grid_size 10
```

View File

@ -18,6 +18,22 @@ from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
def read_video_from_path(path):
cap = cv2.VideoCapture(path)
if not cap.isOpened():
print("Error opening video file")
else:
frames = []
while cap.isOpened():
ret, frame = cap.read()
if ret == True:
frames.append(np.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
else:
break
cap.release()
return np.stack(frames)
class Visualizer:
def __init__(
self,

View File

@ -5,13 +5,13 @@
# LICENSE file in the root directory of this source tree.
import os
import cv2
import torch
import argparse
import numpy as np
from torchvision.io import read_video
from PIL import Image
from cotracker.utils.visualizer import Visualizer
from cotracker.utils.visualizer import Visualizer, read_video_from_path
from cotracker.predictor import CoTrackerPredictor
@ -49,8 +49,8 @@ if __name__ == "__main__":
args = parser.parse_args()
# load the input video frame by frame
video = read_video(args.video_path)
video = video[0].permute(0, 3, 1, 2)[None].float()
video = read_video_from_path(args.video_path)
video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()
segm_mask = np.array(Image.open(os.path.join(args.mask_path)))
segm_mask = torch.from_numpy(segm_mask)[None, None]

File diff suppressed because one or more lines are too long