add colab demo
This commit is contained in:
parent
6d62d873fa
commit
6880e31b5b
@ -6,6 +6,10 @@
|
|||||||
|
|
||||||
[[`Paper`]()] [[`Project`](https://co-tracker.github.io/)] [[`BibTeX`](#citing-cotracker)]
|
[[`Paper`]()] [[`Project`](https://co-tracker.github.io/)] [[`BibTeX`](#citing-cotracker)]
|
||||||
|
|
||||||
|
<a target="_blank" href="https://colab.research.google.com/github/facebookresearch/co-tracker/blob/main/notebooks/demo.ipynb">
|
||||||
|
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
||||||
|
</a>
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
**CoTracker** is a fast transformer-based model that can track any point in a video. It brings to tracking some of the benefits of Optical Flow.
|
**CoTracker** is a fast transformer-based model that can track any point in a video. It brings to tracking some of the benefits of Optical Flow.
|
||||||
@ -15,7 +19,7 @@ CoTracker can track:
|
|||||||
- Points sampled on a regular grid on any video frame
|
- Points sampled on a regular grid on any video frame
|
||||||
- Manually selected points
|
- Manually selected points
|
||||||
|
|
||||||
Try these tracking modes for yourself with our [Colab demo](https://github.com/facebookresearch/co-tracker/notebooks/demo.ipynb).
|
Try these tracking modes for yourself with our [Colab demo](https://colab.research.google.com/github/facebookresearch/co-tracker/blob/master/notebooks/demo.ipynb).
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -43,7 +47,7 @@ cd ..
|
|||||||
|
|
||||||
|
|
||||||
## Running the Demo:
|
## Running the Demo:
|
||||||
Try our [Colab demo](https://github.com/facebookresearch/co-tracker/notebooks/demo.ipynb) or run a local demo with 10*10 points sampled on a grid on the first frame of a video:
|
Try our [Colab demo](https://colab.research.google.com/github/facebookresearch/co-tracker/blob/master/notebooks/demo.ipynb) or run a local demo with 10*10 points sampled on a grid on the first frame of a video:
|
||||||
```
|
```
|
||||||
python demo.py --grid_size 10
|
python demo.py --grid_size 10
|
||||||
```
|
```
|
||||||
|
@ -18,6 +18,22 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
def read_video_from_path(path):
|
||||||
|
cap = cv2.VideoCapture(path)
|
||||||
|
if not cap.isOpened():
|
||||||
|
print("Error opening video file")
|
||||||
|
else:
|
||||||
|
frames = []
|
||||||
|
while cap.isOpened():
|
||||||
|
ret, frame = cap.read()
|
||||||
|
if ret == True:
|
||||||
|
frames.append(np.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
cap.release()
|
||||||
|
return np.stack(frames)
|
||||||
|
|
||||||
|
|
||||||
class Visualizer:
|
class Visualizer:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
8
demo.py
8
demo.py
@ -5,13 +5,13 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import cv2
|
||||||
import torch
|
import torch
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from torchvision.io import read_video
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from cotracker.utils.visualizer import Visualizer
|
from cotracker.utils.visualizer import Visualizer, read_video_from_path
|
||||||
from cotracker.predictor import CoTrackerPredictor
|
from cotracker.predictor import CoTrackerPredictor
|
||||||
|
|
||||||
|
|
||||||
@ -49,8 +49,8 @@ if __name__ == "__main__":
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# load the input video frame by frame
|
# load the input video frame by frame
|
||||||
video = read_video(args.video_path)
|
video = read_video_from_path(args.video_path)
|
||||||
video = video[0].permute(0, 3, 1, 2)[None].float()
|
video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()
|
||||||
segm_mask = np.array(Image.open(os.path.join(args.mask_path)))
|
segm_mask = np.array(Image.open(os.path.join(args.mask_path)))
|
||||||
segm_mask = torch.from_numpy(segm_mask)[None, None]
|
segm_mask = torch.from_numpy(segm_mask)[None, None]
|
||||||
|
|
||||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user