add colab demo

nikitakaraevv 2023-07-17 18:21:54 -07:00
parent 6d62d873fa
commit 6880e31b5b
4 changed files with 113 additions and 100 deletions

README.md

@@ -6,6 +6,10 @@
 [[`Paper`]()] [[`Project`](https://co-tracker.github.io/)] [[`BibTeX`](#citing-cotracker)]
+<a target="_blank" href="https://colab.research.google.com/github/facebookresearch/co-tracker/blob/main/notebooks/demo.ipynb">
+  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
+</a>
+
 ![bmx-bumps](./assets/bmx-bumps.gif)
 **CoTracker** is a fast transformer-based model that can track any point in a video. It brings to tracking some of the benefits of Optical Flow.
@@ -15,7 +19,7 @@ CoTracker can track:
 - Points sampled on a regular grid on any video frame
 - Manually selected points
-Try these tracking modes for yourself with our [Colab demo](https://github.com/facebookresearch/co-tracker/notebooks/demo.ipynb).
+Try these tracking modes for yourself with our [Colab demo](https://colab.research.google.com/github/facebookresearch/co-tracker/blob/master/notebooks/demo.ipynb).
@@ -43,7 +47,7 @@ cd ..
 ## Running the Demo:
-Try our [Colab demo](https://github.com/facebookresearch/co-tracker/notebooks/demo.ipynb) or run a local demo with 10*10 points sampled on a grid on the first frame of a video:
+Try our [Colab demo](https://colab.research.google.com/github/facebookresearch/co-tracker/blob/master/notebooks/demo.ipynb) or run a local demo with 10*10 points sampled on a grid on the first frame of a video:
 ```
 python demo.py --grid_size 10
 ```
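For reference, a minimal sketch of what the demo does programmatically, pieced together from the imports and calls visible in demo.py below. The input clip, checkpoint path, and the predictor's exact call signature are assumptions, not part of this diff:

```python
# Sketch of the demo pipeline, assuming the API shown in demo.py below.
import torch

from cotracker.predictor import CoTrackerPredictor
from cotracker.utils.visualizer import read_video_from_path

video = read_video_from_path("./assets/bmx-bumps.mp4")  # hypothetical input clip
video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()  # (1, T, 3, H, W)

model = CoTrackerPredictor(
    checkpoint="./checkpoints/cotracker_stride_4_wind_8.pth"  # assumed checkpoint path
)
# Track points sampled on a 10x10 grid; assumed call signature.
pred_tracks, pred_visibility = model(video, grid_size=10)
```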

cotracker/utils/visualizer.py

@@ -18,6 +18,22 @@ from torch.utils.tensorboard import SummaryWriter
 import matplotlib.pyplot as plt
+
+def read_video_from_path(path):
+    cap = cv2.VideoCapture(path)
+    if not cap.isOpened():
+        print("Error opening video file")
+    else:
+        frames = []
+        while cap.isOpened():
+            ret, frame = cap.read()
+            if ret == True:
+                frames.append(np.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
+            else:
+                break
+        cap.release()
+        return np.stack(frames)
+
 class Visualizer:
     def __init__(
         self,
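The new helper decodes a whole video into a single NumPy array, converting each frame from OpenCV's BGR order to RGB before stacking. A minimal usage sketch (the file path and dimensions are hypothetical):

```python
from cotracker.utils.visualizer import read_video_from_path

# Returns a uint8 array of shape (T, H, W, 3) in RGB order,
# one entry per decoded frame.
video = read_video_from_path("./assets/video.mp4")  # hypothetical path
print(video.shape, video.dtype)  # e.g. (48, 720, 1280, 3) uint8
```

One caveat visible in the diff: if the file cannot be opened, the helper prints an error and implicitly returns `None`, so a caller like demo.py would fail later at `torch.from_numpy` rather than at the read itself.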

demo.py

@@ -5,13 +5,13 @@
 # LICENSE file in the root directory of this source tree.

 import os
+import cv2
 import torch
 import argparse
 import numpy as np
-from torchvision.io import read_video

 from PIL import Image
-from cotracker.utils.visualizer import Visualizer
+from cotracker.utils.visualizer import Visualizer, read_video_from_path
 from cotracker.predictor import CoTrackerPredictor
@@ -49,8 +49,8 @@ if __name__ == "__main__":
     args = parser.parse_args()

     # load the input video frame by frame
-    video = read_video(args.video_path)
-    video = video[0].permute(0, 3, 1, 2)[None].float()
+    video = read_video_from_path(args.video_path)
+    video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()
     segm_mask = np.array(Image.open(os.path.join(args.mask_path)))
     segm_mask = torch.from_numpy(segm_mask)[None, None]
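The rewritten lines swap torchvision's `read_video` (which returns a `(frames, audio, info)` tuple, hence the old `video[0]`) for the new NumPy-based helper; the resulting tensor layout is unchanged. A small sketch of the conversions, with assumed dimensions for illustration:

```python
import numpy as np
import torch

# Stand-in for read_video_from_path output: T frames of H x W RGB pixels.
video_np = np.zeros((48, 720, 1280, 3), dtype=np.uint8)  # assumed dimensions

# (T, H, W, C) -> (T, C, H, W), then [None] prepends a batch axis.
video = torch.from_numpy(video_np).permute(0, 3, 1, 2)[None].float()
assert video.shape == (1, 48, 3, 720, 1280)

# The segmentation mask gets batch and channel axes the same way.
segm_mask = torch.from_numpy(np.zeros((720, 1280), dtype=np.uint8))[None, None]
assert segm_mask.shape == (1, 1, 720, 1280)
```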

notebooks/demo.ipynb — file diff suppressed because one or more lines are too long