add colab demo
commit 6880e31b5b (parent 6d62d873fa)
README.md
@@ -6,6 +6,10 @@
 [[`Paper`]()] [[`Project`](https://co-tracker.github.io/)] [[`BibTeX`](#citing-cotracker)]
 
+<a target="_blank" href="https://colab.research.google.com/github/facebookresearch/co-tracker/blob/main/notebooks/demo.ipynb">
+  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
+</a>
+
 **CoTracker** is a fast transformer-based model that can track any point in a video. It brings to tracking some of the benefits of Optical Flow.
@@ -15,7 +19,7 @@ CoTracker can track:
 - Points sampled on a regular grid on any video frame
 - Manually selected points
 
-Try these tracking modes for yourself with our [Colab demo](https://github.com/facebookresearch/co-tracker/notebooks/demo.ipynb).
+Try these tracking modes for yourself with our [Colab demo](https://colab.research.google.com/github/facebookresearch/co-tracker/blob/master/notebooks/demo.ipynb).
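Editorial note on the two modes listed above: grid tracking is what `demo.py` below exercises via `--grid_size`, while manually selected points go in as explicit queries. A minimal sketch of the manual mode, assuming the `CoTrackerPredictor` used in `demo.py` also accepts a `queries` tensor of `(frame_index, x, y)` rows; that keyword and the checkpoint path are assumptions, as neither appears in the visible hunks:

```python
import torch

from cotracker.predictor import CoTrackerPredictor

# Placeholder video: (B, T, C, H, W) float tensor, as demo.py constructs further down.
video = torch.zeros(1, 48, 3, 480, 640)

# Assumed checkpoint path; not shown in this commit.
model = CoTrackerPredictor(checkpoint="./checkpoints/cotracker_stride_4_wind_8.pth")

# One row per manually selected point: (frame index to start from, x, y) in pixels.
queries = torch.tensor([[[0.0, 400.0, 350.0],   # track this point from frame 0
                         [8.0, 600.0, 500.0]]]) # track this point from frame 8
pred_tracks, pred_visibility = model(video, queries=queries)  # assumed keyword argument
```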
@@ -43,7 +47,7 @@ cd ..
 
 
 ## Running the Demo:
-Try our [Colab demo](https://github.com/facebookresearch/co-tracker/notebooks/demo.ipynb) or run a local demo with 10*10 points sampled on a grid on the first frame of a video:
+Try our [Colab demo](https://colab.research.google.com/github/facebookresearch/co-tracker/blob/master/notebooks/demo.ipynb) or run a local demo with 10*10 points sampled on a grid on the first frame of a video:
 ```
 python demo.py --grid_size 10
 ```
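For orientation, that command stitches together the pieces changed in this commit: decode the video with the new `read_video_from_path` helper, reshape it into a batched channels-first tensor, and track a 10×10 grid from the first frame. A hedged end-to-end sketch; the input and checkpoint paths and the `grid_size` keyword on the predictor are assumptions (only the `--grid_size` CLI flag is shown in the diff):

```python
import torch

from cotracker.predictor import CoTrackerPredictor
from cotracker.utils.visualizer import read_video_from_path

# Decode to a (T, H, W, 3) uint8 RGB array, then to (1, T, 3, H, W) float,
# mirroring the demo.py hunk further down. The input path is hypothetical.
video = read_video_from_path("./assets/apple.mp4")
video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()

# Assumed checkpoint path; not shown in this commit.
model = CoTrackerPredictor(checkpoint="./checkpoints/cotracker_stride_4_wind_8.pth")
pred_tracks, pred_visibility = model(video, grid_size=10)  # 10x10 grid on frame 0
```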
cotracker/utils/visualizer.py
@@ -18,6 +18,22 @@ from torch.utils.tensorboard import SummaryWriter
 import matplotlib.pyplot as plt
 
 
+def read_video_from_path(path):
+    cap = cv2.VideoCapture(path)
+    if not cap.isOpened():
+        print("Error opening video file")
+    else:
+        frames = []
+        while cap.isOpened():
+            ret, frame = cap.read()
+            if ret == True:
+                frames.append(np.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
+            else:
+                break
+        cap.release()
+        return np.stack(frames)
+
+
 class Visualizer:
     def __init__(
         self,
demo.py (8 changed lines)
@@ -5,13 +5,13 @@
 # LICENSE file in the root directory of this source tree.
 
 import os
+import cv2
 import torch
 import argparse
 import numpy as np
 
-from torchvision.io import read_video
 from PIL import Image
-from cotracker.utils.visualizer import Visualizer
+from cotracker.utils.visualizer import Visualizer, read_video_from_path
 from cotracker.predictor import CoTrackerPredictor
 
@@ -49,8 +49,8 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     # load the input video frame by frame
-    video = read_video(args.video_path)
-    video = video[0].permute(0, 3, 1, 2)[None].float()
+    video = read_video_from_path(args.video_path)
+    video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()
     segm_mask = np.array(Image.open(os.path.join(args.mask_path)))
     segm_mask = torch.from_numpy(segm_mask)[None, None]
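To make the tensor bookkeeping in the replaced lines easier to follow, here is a shape walkthrough with dummy inputs (the 48-frame 480×640 dimensions are illustrative, not from the commit):

```python
import numpy as np
import torch

# Stand-ins with the same shapes the demo produces.
video_np = np.zeros((48, 480, 640, 3), dtype=np.uint8)  # (T, H, W, 3) from read_video_from_path
mask_np = np.zeros((480, 640), dtype=np.uint8)          # (H, W) from PIL

video = torch.from_numpy(video_np).permute(0, 3, 1, 2)[None].float()
segm_mask = torch.from_numpy(mask_np)[None, None]

assert video.shape == (1, 48, 3, 480, 640)  # batch dim added, channels-first frames
assert segm_mask.shape == (1, 1, 480, 640)  # (B, 1, H, W)
```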
notebooks/demo.ipynb
File diff suppressed because one or more lines are too long