add colab demo
This commit is contained in:
		| @@ -6,6 +6,10 @@ | |||||||
|  |  | ||||||
| [[`Paper`]()] [[`Project`](https://co-tracker.github.io/)] [[`BibTeX`](#citing-cotracker)] | [[`Paper`]()] [[`Project`](https://co-tracker.github.io/)] [[`BibTeX`](#citing-cotracker)] | ||||||
|  |  | ||||||
|  | <a target="_blank" href="https://colab.research.google.com/github/facebookresearch/co-tracker/blob/main/notebooks/demo.ipynb"> | ||||||
|  |   <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> | ||||||
|  | </a> | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| **CoTracker** is a fast transformer-based model that can track any point in a video. It brings to tracking some of the benefits of Optical Flow. | **CoTracker** is a fast transformer-based model that can track any point in a video. It brings to tracking some of the benefits of Optical Flow. | ||||||
| @@ -15,7 +19,7 @@ CoTracker can track: | |||||||
| - Points sampled on a regular grid on any video frame  | - Points sampled on a regular grid on any video frame  | ||||||
| - Manually selected points | - Manually selected points | ||||||
|  |  | ||||||
| Try these tracking modes for yourself with our [Colab demo](https://github.com/facebookresearch/co-tracker/notebooks/demo.ipynb). | Try these tracking modes for yourself with our [Colab demo](https://colab.research.google.com/github/facebookresearch/co-tracker/blob/master/notebooks/demo.ipynb). | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -43,7 +47,7 @@ cd .. | |||||||
|  |  | ||||||
|  |  | ||||||
| ## Running the Demo: | ## Running the Demo: | ||||||
| Try our [Colab demo](https://github.com/facebookresearch/co-tracker/notebooks/demo.ipynb) or run a local demo with 10*10 points sampled on a grid on the first frame of a video: | Try our [Colab demo](https://colab.research.google.com/github/facebookresearch/co-tracker/blob/master/notebooks/demo.ipynb) or run a local demo with 10*10 points sampled on a grid on the first frame of a video: | ||||||
| ``` | ``` | ||||||
| python demo.py --grid_size 10 | python demo.py --grid_size 10 | ||||||
| ``` | ``` | ||||||
|   | |||||||
| @@ -18,6 +18,22 @@ from torch.utils.tensorboard import SummaryWriter | |||||||
| import matplotlib.pyplot as plt | import matplotlib.pyplot as plt | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def read_video_from_path(path): | ||||||
|  |     cap = cv2.VideoCapture(path) | ||||||
|  |     if not cap.isOpened(): | ||||||
|  |         print("Error opening video file") | ||||||
|  |     else: | ||||||
|  |         frames = [] | ||||||
|  |         while cap.isOpened(): | ||||||
|  |             ret, frame = cap.read() | ||||||
|  |             if ret == True: | ||||||
|  |                 frames.append(np.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))) | ||||||
|  |             else: | ||||||
|  |                 break | ||||||
|  |         cap.release() | ||||||
|  |     return np.stack(frames) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Visualizer: | class Visualizer: | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|   | |||||||
							
								
								
									
										8
									
								
								demo.py
									
									
									
									
									
								
							
							
						
						
									
										8
									
								
								demo.py
									
									
									
									
									
								
							| @@ -5,13 +5,13 @@ | |||||||
| # LICENSE file in the root directory of this source tree. | # LICENSE file in the root directory of this source tree. | ||||||
|  |  | ||||||
| import os | import os | ||||||
|  | import cv2 | ||||||
| import torch | import torch | ||||||
| import argparse | import argparse | ||||||
| import numpy as np | import numpy as np | ||||||
|  |  | ||||||
| from torchvision.io import read_video |  | ||||||
| from PIL import Image | from PIL import Image | ||||||
| from cotracker.utils.visualizer import Visualizer | from cotracker.utils.visualizer import Visualizer, read_video_from_path | ||||||
| from cotracker.predictor import CoTrackerPredictor | from cotracker.predictor import CoTrackerPredictor | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -49,8 +49,8 @@ if __name__ == "__main__": | |||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|  |  | ||||||
|     # load the input video frame by frame |     # load the input video frame by frame | ||||||
|     video = read_video(args.video_path) |     video = read_video_from_path(args.video_path) | ||||||
|     video = video[0].permute(0, 3, 1, 2)[None].float() |     video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float() | ||||||
|     segm_mask = np.array(Image.open(os.path.join(args.mask_path))) |     segm_mask = np.array(Image.open(os.path.join(args.mask_path))) | ||||||
|     segm_mask = torch.from_numpy(segm_mask)[None, None] |     segm_mask = torch.from_numpy(segm_mask)[None, None] | ||||||
|  |  | ||||||
|   | |||||||
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
		Reference in New Issue
	
	Block a user