release cotracker 2.0

This commit is contained in:
Nikita Karaev 2023-12-27 12:54:02 +00:00
parent 3df96621ed
commit f8fab323c4
38 changed files with 2238 additions and 1910 deletions

README.md
View File

@ -13,111 +13,218 @@
<img alt="Spaces" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue">
</a>
<img width="500" src="./assets/bmx-bumps.gif" />
<img width="1100" src="./assets/teaser.png" />
**CoTracker** is a fast transformer-based model that can track any point in a video. It brings to tracking some of the benefits of Optical Flow.
CoTracker can track:
- **Every pixel** in a video
- Points sampled on a regular grid on any video frame
- Manually selected points
Try these tracking modes for yourself with our [Colab demo](https://colab.research.google.com/github/facebookresearch/co-tracker/blob/master/notebooks/demo.ipynb) or in the [Hugging Face Space](https://huggingface.co/spaces/facebook/cotracker).
- **Any pixel** in a video
- A **quasi-dense** set of pixels together
- Points can be manually selected or sampled on a grid in any video frame
Try these tracking modes for yourself with our [Colab demo](https://colab.research.google.com/github/facebookresearch/co-tracker/blob/master/notebooks/demo.ipynb) or in the [Hugging Face Space 🤗](https://huggingface.co/spaces/facebook/cotracker).
**Updates:**
- [December 27, 2023] 📣 CoTracker2 is now available! It can now track many more points jointly (up to **265×265**!) and has a cleaner and more memory-efficient implementation. It also supports online processing. See the [updated paper](https://arxiv.org/abs/2307.07635) for more details. The old version remains available [here](https://github.com/facebookresearch/co-tracker/tree/8d364031971f6b3efec945dd15c468a183e58212).
- [September 5, 2023] 📣 You can now run our Gradio demo [locally](./gradio_demo/app.py)!
## Quick start
The easiest way to use CoTracker is to load a pretrained model from `torch.hub`:
### Offline mode:
```pip install imageio[ffmpeg]```, then:
```python
import torch
import imageio.v3 as iio

# Download the demo video (use the raw file URL so ffmpeg receives the mp4, not the HTML blob page)
url = 'https://github.com/facebookresearch/co-tracker/raw/main/assets/apple.mp4'
frames = iio.imread(url, plugin="FFMPEG")  # plugin="pyav"

device = 'cuda'
grid_size = 10
video = torch.tensor(frames).permute(0, 3, 1, 2)[None].float().to(device)  # B T C H W

# Run Offline CoTracker:
cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker2").to(device)
pred_tracks, pred_visibility = cotracker(video, grid_size=grid_size)  # B T N 2,  B T N 1
```
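Besides a regular grid, the offline model also accepts manually selected points. The sketch below continues from the snippet above; the `queries` format (frame index, x, y) follows the notebook demo, and the specific coordinates are only placeholders:

```python
# Each query is (frame index, x, y) in pixel coordinates of the original video.
queries = torch.tensor([
    [0., 400., 350.],   # start tracking this point from frame 0
    [10., 600., 500.],  # and this one from frame 10
])[None].to(device)     # B N 3

pred_tracks, pred_visibility = cotracker(video, queries=queries)
```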
### Online mode:
```python
cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker2_online").to(device)
# Run Online CoTracker, the same model with a different API:
# Initialize online processing
cotracker(video_chunk=video, is_first_step=True, grid_size=grid_size)
# Process the video
for ind in range(0, video.shape[1] - cotracker.step, cotracker.step):
    pred_tracks, pred_visibility = cotracker(
        video_chunk=video[:, ind : ind + cotracker.step * 2]
    )  # B T N 2,  B T N 1
```
Online processing is more memory-efficient and makes it possible to handle longer videos. Note, however, that in the example above the video length is known in advance. See [the online demo](./online_demo.py) for tracking an online stream of unknown length; a condensed sketch of that pattern follows.
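The sketch below mirrors the buffering logic of [`online_demo.py`](./online_demo.py); the input path and `grid_size` are placeholders, and frames are read with `imageio`:

```python
import torch
import numpy as np
import imageio.v3 as iio

device = 'cuda'
grid_size = 10
cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker2_online").to(device)

def process_step(window_frames, is_first_step):
    # Feed the last 2 * step frames as a B T C H W float tensor.
    video_chunk = (
        torch.tensor(np.stack(window_frames[-cotracker.step * 2 :]), device=device)
        .float()
        .permute(0, 3, 1, 2)[None]
    )
    return cotracker(video_chunk=video_chunk, is_first_step=is_first_step, grid_size=grid_size)

window_frames, is_first_step = [], True
for i, frame in enumerate(iio.imiter("./assets/apple.mp4", plugin="FFMPEG")):
    if i % cotracker.step == 0 and i != 0:
        pred_tracks, pred_visibility = process_step(window_frames, is_first_step)
        is_first_step = False
    window_frames.append(frame)

# Process the remaining frames once the stream ends.
pred_tracks, pred_visibility = process_step(window_frames, is_first_step)
```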
### Visualize predicted tracks:
```pip install matplotlib```, then:
```python
from cotracker.utils.visualizer import Visualizer
vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
vis.visualize(video, pred_tracks, pred_visibility)
```
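If you visualize several results, you can pass a `filename` so each video gets its own name; this is the same argument the evaluator in this commit uses when saving visualizations, and the name below is just an example:

```python
vis.visualize(video, pred_tracks, pred_visibility, filename="grid_10x10_tracks")
```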
We offer a number of other ways to interact with CoTracker:
1. Interactive Gradio demo:
- A demo is available in the [`facebook/cotracker` Hugging Face Space 🤗](https://huggingface.co/spaces/facebook/cotracker).
- You can use the gradio demo locally by running [`python -m gradio_demo.app`](./gradio_demo/app.py) after installing the required packages: `pip install -r gradio_demo/requirements.txt`.
2. Jupyter notebook:
- You can run the notebook in
[Google Colab](https://colab.research.google.com/github/facebookresearch/co-tracker/blob/master/notebooks/demo.ipynb).
- Or explore the notebook located at [`notebooks/demo.ipynb`](./notebooks/demo.ipynb).
3. You can [install](#installation-instructions) CoTracker _locally_ and then:
- Run an *offline* demo with 10 ⨉ 10 points sampled on a grid on the first frame of a video (results will be saved to `./saved_videos/demo.mp4`):
```bash
python demo.py --grid_size 10
```
- Run an *online* demo:
```bash
python online_demo.py
```
A GPU is strongly recommended for using CoTracker locally.
<img width="500" src="./assets/bmx-bumps.gif" />
### Update: September 5, 2023
📣 You can now run our Gradio demo [locally](./gradio_demo/app.py)!
## Installation Instructions
Ensure you have both PyTorch and TorchVision installed on your system. Follow the instructions [here](https://pytorch.org/get-started/locally/) for the installation. We strongly recommend installing both PyTorch and TorchVision with CUDA support.
You can use a pretrained model via PyTorch Hub, as described above, or install CoTracker from this GitHub repo.
This is the best way if you need to run our local demo or evaluate/train CoTracker.
### Pretrained models via PyTorch Hub
The easiest way to use CoTracker is to load a pretrained model from torch.hub:
```
pip install einops timm tqdm
```
```
import torch
import timm
import einops
import tqdm
cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker_w8")
```
Ensure you have both _PyTorch_ and _TorchVision_ installed on your system. Follow the instructions [here](https://pytorch.org/get-started/locally/) for the installation.
We strongly recommend installing both PyTorch and TorchVision with CUDA support, although for small tasks CoTracker can be run on CPU.
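As a minimal illustration of the CPU fallback (reusing the `cotracker2` hub entry point from the quick start above):

```python
import torch

# Use the GPU when available; CPU is fine for short clips with few points.
device = "cuda" if torch.cuda.is_available() else "cpu"
cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker2").to(device)
```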
Another option is to install it from this GitHub repo. That's the best way if you need to run our demo or evaluate/train CoTracker:
### Steps to Install CoTracker and its dependencies:
### Install a Development Version
```bash
git clone https://github.com/facebookresearch/co-tracker
cd co-tracker
pip install -e .
pip install opencv-python einops timm matplotlib moviepy flow_vis
pip install matplotlib flow_vis tqdm tensorboard
```
You can manually download the CoTracker2 checkpoint from the link below and place it in the `checkpoints` folder as follows:
### Download Model Weights:
```bash
mkdir -p checkpoints
cd checkpoints
wget https://huggingface.co/facebook/cotracker/resolve/main/cotracker2.pth
cd ..
```
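With the repo installed as above, the downloaded checkpoint can also be loaded directly instead of via `torch.hub`. A minimal sketch, using the `CoTrackerPredictor` class that the local demo scripts rely on:

```python
import torch
from cotracker.predictor import CoTrackerPredictor

# Load the manually downloaded weights from the checkpoints folder.
model = CoTrackerPredictor(checkpoint="./checkpoints/cotracker2.pth")
if torch.cuda.is_available():
    model = model.cuda()
```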
For old checkpoints, see [this section](#previous-version).
## Evaluation
To reproduce the results presented in the paper, download the following datasets:
- [TAP-Vid](https://github.com/deepmind/tapnet)
- [Dynamic Replica](https://dynamic-stereo.github.io/)
And install the necessary dependencies:
```bash
pip install hydra-core==1.1.0 mediapy
```
Then, execute the following command to evaluate on TAP-Vid DAVIS:
```bash
python ./cotracker/evaluation/evaluate.py --config-name eval_tapvid_davis_first exp_dir=./eval_outputs dataset_root=your/tapvid/path
```
By default, evaluation will be slow since it is done for one target point at a time, which ensures robustness and fairness, as described in the paper.
We have fixed some bugs and retrained the model after updating the paper. These are the numbers that you should be able to reproduce using the released checkpoint and the current version of the codebase:
| | DAVIS First, AJ | DAVIS First, $\delta_\text{avg}^\text{vis}$ | DAVIS First, OA | DAVIS Strided, AJ | DAVIS Strided, $\delta_\text{avg}^\text{vis}$ | DAVIS Strided, OA | DR, $\delta_\text{avg}$| DR, $\delta_\text{avg}^\text{vis}$| DR, $\delta_\text{avg}^\text{occ}$|
| :---: |:---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| CoTracker2, 27.12.23 | 60.9 | 75.4 | 88.4 | 65.1 | 79.0 | 89.4 | 61.4 | 68.4 | 38.2
## Training
To train CoTracker as described in our paper, you first need to generate annotations for the [Google Kubric](https://github.com/google-research/kubric) MOVi-f dataset.
Instructions for annotation generation can be found [here](https://github.com/deepmind/tapnet).
You can also find a discussion on dataset generation in [this issue](https://github.com/facebookresearch/co-tracker/issues/8).
Once you have the annotated dataset, make sure you have followed the steps for the evaluation setup and install the training dependencies:
```bash
pip install pytorch_lightning==1.6.0 tensorboard
```
Now you can launch training on Kubric.
Our model was trained for 50000 iterations on 32 GPUs (4 nodes with 8 GPUs each).
Modify _dataset_root_ and _ckpt_path_ accordingly before running this command. For training on 4 nodes, add `--num_nodes 4`.
```bash
python train.py --batch_size 1 \
--num_steps 50000 --ckpt_path ./ --dataset_root ./datasets --model_name cotracker \
--save_freq 200 --sequence_len 24 --eval_datasets dynamic_replica tapvid_davis_first \
--traj_per_sample 768 --sliding_window_len 8 \
--num_virtual_tracks 64 --model_stride 4
```
## Development
### Building the documentation
To build CoTracker documentation, first install the dependencies:
```bash
pip install sphinx
pip install sphinxcontrib-bibtex
```
Then you can use this command to generate the documentation in the `docs/_build/html` folder:
```bash
make -C docs html
```
## Previous version
The old version of the code is available [here](https://github.com/facebookresearch/co-tracker/tree/8d364031971f6b3efec945dd15c468a183e58212).
You can also download the corresponding checkpoints:
```bash
wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth
wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_12.pth
wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_8_wind_16.pth
cd ..
```
## Usage:
We offer a number of ways to interact with CoTracker:
1. A demo is available in the [`facebook/cotracker` Hugging Face Space](https://huggingface.co/spaces/facebook/cotracker).
2. You can run the extended demo in Colab:
[Colab notebook](https://colab.research.google.com/github/facebookresearch/co-tracker/blob/master/notebooks/demo.ipynb)
3. You can use the gradio demo locally by running [`python -m gradio_demo.app`](./gradio_demo/app.py) after installing the required packages: ```pip install -r gradio_demo/requirements.txt```.
4. You can play with CoTracker by running the Jupyter notebook located at [`notebooks/demo.ipynb`](./notebooks/demo.ipynb) locally (if you have a GPU).
5. Finally, you can run a local demo with 10*10 points sampled on a grid on the first frame of a video:
```
python demo.py --grid_size 10
```
## Evaluation
To reproduce the results presented in the paper, download the following datasets:
- [TAP-Vid](https://github.com/deepmind/tapnet)
- [BADJA](https://github.com/benjiebob/BADJA)
- [ZJU-Mocap (FastCapture)](https://arxiv.org/abs/2303.11898)
And install the necessary dependencies:
```
pip install hydra-core==1.1.0 mediapy
```
Then, execute the following command to evaluate on BADJA:
```
python ./cotracker/evaluation/evaluate.py --config-name eval_badja exp_dir=./eval_outputs dataset_root=your/badja/path
```
By default, evaluation will be slow since it is done for one target point at a time, which ensures robustness and fairness, as described in the paper.
## Training
To train the CoTracker as described in our paper, you first need to generate annotations for [Google Kubric](https://github.com/google-research/kubric) MOVI-f dataset. Instructions for annotation generation can be found [here](https://github.com/deepmind/tapnet).
Once you have the annotated dataset, you need to make sure you followed the steps for evaluation setup and install the training dependencies:
```
pip install pytorch_lightning==1.6.0 tensorboard
```
Now you can launch training on Kubric. Our model was trained for 50000 iterations on 32 GPUs (4 nodes with 8 GPUs).
Modify *dataset_root* and *ckpt_path* accordingly before running this command:
```
python train.py --batch_size 1 --num_workers 28 \
--num_steps 50000 --ckpt_path ./ --dataset_root ./datasets --model_name cotracker \
--save_freq 200 --sequence_len 24 --eval_datasets tapvid_davis_first badja \
--traj_per_sample 256 --sliding_window_len 8 --updateformer_space_depth 6 --updateformer_time_depth 6 \
--save_every_n_epoch 10 --evaluate_every_n_epoch 10 --model_stride 4
```
## License
The majority of CoTracker is licensed under CC-BY-NC; however, portions of the project are available under separate license terms: Particle Video Revisited is licensed under the MIT license, and TAP-Vid is licensed under the Apache 2.0 license.
## Acknowledgments
We would like to thank [PIPs](https://github.com/aharley/pips) and [TAP-Vid](https://github.com/deepmind/tapnet) for publicly releasing their code and data. We also want to thank [Luke Melas-Kyriazi](https://lukemelas.github.io/) for proofreading the paper, [Jianyuan Wang](https://jytime.github.io/), [Roman Shapovalov](https://shapovalov.ro/) and [Adam W. Harley](https://adamharley.com/) for the insightful discussions.
## Citing CoTracker
If you find our repository useful, please consider giving it a star ⭐ and citing our paper in your work:
```bibtex
@article{karaev2023cotracker,
title={CoTracker: It is Better to Track Together},
author={Nikita Karaev and Ignacio Rocco and Benjamin Graham and Natalia Neverova and Andrea Vedaldi and Christian Rupprecht},

Binary file not shown.

BIN  assets/teaser.png  (new file, 7.0 MiB; binary file not shown)

View File

@ -1,390 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import numpy as np
import os
import json
import imageio
import cv2
from enum import Enum
from cotracker.datasets.utils import CoTrackerData, resize_sample
IGNORE_ANIMALS = [
# "bear.json",
# "camel.json",
"cat_jump.json"
# "cows.json",
# "dog.json",
# "dog-agility.json",
# "horsejump-high.json",
# "horsejump-low.json",
# "impala0.json",
# "rs_dog.json"
"tiger.json"
]
class SMALJointCatalog(Enum):
# body_0 = 0
# body_1 = 1
# body_2 = 2
# body_3 = 3
# body_4 = 4
# body_5 = 5
# body_6 = 6
# upper_right_0 = 7
upper_right_1 = 8
upper_right_2 = 9
upper_right_3 = 10
# upper_left_0 = 11
upper_left_1 = 12
upper_left_2 = 13
upper_left_3 = 14
neck_lower = 15
# neck_upper = 16
# lower_right_0 = 17
lower_right_1 = 18
lower_right_2 = 19
lower_right_3 = 20
# lower_left_0 = 21
lower_left_1 = 22
lower_left_2 = 23
lower_left_3 = 24
tail_0 = 25
# tail_1 = 26
# tail_2 = 27
tail_3 = 28
# tail_4 = 29
# tail_5 = 30
tail_6 = 31
jaw = 32
nose = 33 # ADDED JOINT FOR VERTEX 1863
# chin = 34 # ADDED JOINT FOR VERTEX 26
right_ear = 35 # ADDED JOINT FOR VERTEX 149
left_ear = 36 # ADDED JOINT FOR VERTEX 2124
class SMALJointInfo:
def __init__(self):
# These are the
self.annotated_classes = np.array(
[
8,
9,
10, # upper_right
12,
13,
14, # upper_left
15, # neck
18,
19,
20, # lower_right
22,
23,
24, # lower_left
25,
28,
31, # tail
32,
33, # head
35, # right_ear
36,
]
) # left_ear
self.annotated_markers = np.array(
[
cv2.MARKER_CROSS,
cv2.MARKER_STAR,
cv2.MARKER_TRIANGLE_DOWN,
cv2.MARKER_CROSS,
cv2.MARKER_STAR,
cv2.MARKER_TRIANGLE_DOWN,
cv2.MARKER_CROSS,
cv2.MARKER_CROSS,
cv2.MARKER_STAR,
cv2.MARKER_TRIANGLE_DOWN,
cv2.MARKER_CROSS,
cv2.MARKER_STAR,
cv2.MARKER_TRIANGLE_DOWN,
cv2.MARKER_CROSS,
cv2.MARKER_STAR,
cv2.MARKER_TRIANGLE_DOWN,
cv2.MARKER_CROSS,
cv2.MARKER_STAR,
cv2.MARKER_CROSS,
cv2.MARKER_CROSS,
]
)
self.joint_regions = np.array(
[
0,
0,
0,
0,
0,
0,
0,
1,
1,
1,
1,
2,
2,
2,
2,
3,
3,
4,
4,
4,
4,
5,
5,
5,
5,
6,
6,
6,
6,
6,
6,
6,
7,
7,
7,
8,
9,
]
)
self.annotated_joint_region = self.joint_regions[self.annotated_classes]
self.region_colors = np.array(
[
[250, 190, 190], # body, light pink
[60, 180, 75], # upper_right, green
[230, 25, 75], # upper_left, red
[128, 0, 0], # neck, maroon
[0, 130, 200], # lower_right, blue
[255, 255, 25], # lower_left, yellow
[240, 50, 230],  # tail, magenta
[245, 130, 48], # jaw / nose / chin, orange
[29, 98, 115], # right_ear, turquoise
[255, 153, 204],
]
) # left_ear, pink
self.joint_colors = np.array(self.region_colors)[self.annotated_joint_region]
class BADJAData:
def __init__(self, data_root, complete=False):
annotations_path = os.path.join(data_root, "joint_annotations")
self.animal_dict = {}
self.animal_count = 0
self.smal_joint_info = SMALJointInfo()
for __, animal_json in enumerate(sorted(os.listdir(annotations_path))):
if animal_json not in IGNORE_ANIMALS:
json_path = os.path.join(annotations_path, animal_json)
with open(json_path) as json_data:
animal_joint_data = json.load(json_data)
filenames = []
segnames = []
joints = []
visible = []
first_path = animal_joint_data[0]["segmentation_path"]
last_path = animal_joint_data[-1]["segmentation_path"]
first_frame = first_path.split("/")[-1]
last_frame = last_path.split("/")[-1]
if not "extra_videos" in first_path:
animal = first_path.split("/")[-2]
first_frame_int = int(first_frame.split(".")[0])
last_frame_int = int(last_frame.split(".")[0])
for fr in range(first_frame_int, last_frame_int + 1):
ref_file_name = os.path.join(
data_root,
"DAVIS/JPEGImages/Full-Resolution/%s/%05d.jpg"
% (animal, fr),
)
ref_seg_name = os.path.join(
data_root,
"DAVIS/Annotations/Full-Resolution/%s/%05d.png"
% (animal, fr),
)
foundit = False
for ind, image_annotation in enumerate(animal_joint_data):
file_name = os.path.join(
data_root, image_annotation["image_path"]
)
seg_name = os.path.join(
data_root, image_annotation["segmentation_path"]
)
if file_name == ref_file_name:
foundit = True
label_ind = ind
if foundit:
image_annotation = animal_joint_data[label_ind]
file_name = os.path.join(
data_root, image_annotation["image_path"]
)
seg_name = os.path.join(
data_root, image_annotation["segmentation_path"]
)
joint = np.array(image_annotation["joints"])
vis = np.array(image_annotation["visibility"])
else:
file_name = ref_file_name
seg_name = ref_seg_name
joint = None
vis = None
filenames.append(file_name)
segnames.append(seg_name)
joints.append(joint)
visible.append(vis)
if len(filenames):
self.animal_dict[self.animal_count] = (
filenames,
segnames,
joints,
visible,
)
self.animal_count += 1
print("Loaded BADJA dataset")
def get_loader(self):
for __ in range(int(1e6)):
animal_id = np.random.choice(len(self.animal_dict.keys()))
filenames, segnames, joints, visible = self.animal_dict[animal_id]
image_id = np.random.randint(0, len(filenames))
seg_file = segnames[image_id]
image_file = filenames[image_id]
joints = joints[image_id].copy()
joints = joints[self.smal_joint_info.annotated_classes]
visible = visible[image_id][self.smal_joint_info.annotated_classes]
rgb_img = imageio.imread(image_file) # , mode='RGB')
sil_img = imageio.imread(seg_file) # , mode='RGB')
rgb_h, rgb_w, _ = rgb_img.shape
sil_img = cv2.resize(sil_img, (rgb_w, rgb_h), cv2.INTER_NEAREST)
yield rgb_img, sil_img, joints, visible, image_file
def get_video(self, animal_id):
filenames, segnames, joint, visible = self.animal_dict[animal_id]
rgbs = []
segs = []
joints = []
visibles = []
for s in range(len(filenames)):
image_file = filenames[s]
rgb_img = imageio.imread(image_file) # , mode='RGB')
rgb_h, rgb_w, _ = rgb_img.shape
seg_file = segnames[s]
sil_img = imageio.imread(seg_file) # , mode='RGB')
sil_img = cv2.resize(sil_img, (rgb_w, rgb_h), cv2.INTER_NEAREST)
jo = joint[s]
if jo is not None:
joi = joint[s].copy()
joi = joi[self.smal_joint_info.annotated_classes]
vis = visible[s][self.smal_joint_info.annotated_classes]
else:
joi = None
vis = None
rgbs.append(rgb_img)
segs.append(sil_img)
joints.append(joi)
visibles.append(vis)
return rgbs, segs, joints, visibles, filenames[0]
class BadjaDataset(torch.utils.data.Dataset):
def __init__(
self, data_root, max_seq_len=1000, dataset_resolution=(384, 512)
):
self.data_root = data_root
self.badja_data = BADJAData(data_root)
self.max_seq_len = max_seq_len
self.dataset_resolution = dataset_resolution
print(
"found %d unique videos in %s"
% (self.badja_data.animal_count, self.data_root)
)
def __getitem__(self, index):
rgbs, segs, joints, visibles, filename = self.badja_data.get_video(index)
S = len(rgbs)
H, W, __ = rgbs[0].shape
H, W, __ = segs[0].shape
N, __ = joints[0].shape
# let's eliminate the Nones
# note the first one is guaranteed present
for s in range(1, S):
if joints[s] is None:
joints[s] = np.zeros_like(joints[0])
visibles[s] = np.zeros_like(visibles[0])
# eliminate the mystery dim
segs = [seg[:, :, 0] for seg in segs]
rgbs = np.stack(rgbs, 0)
segs = np.stack(segs, 0)
trajs = np.stack(joints, 0)
visibles = np.stack(visibles, 0)
rgbs = torch.from_numpy(rgbs).reshape(S, H, W, 3).permute(0, 3, 1, 2).float()
segs = torch.from_numpy(segs).reshape(S, 1, H, W).float()
trajs = torch.from_numpy(trajs).reshape(S, N, 2).float()
visibles = torch.from_numpy(visibles).reshape(S, N)
rgbs = rgbs[: self.max_seq_len]
segs = segs[: self.max_seq_len]
trajs = trajs[: self.max_seq_len]
visibles = visibles[: self.max_seq_len]
# apparently the coords are in yx order
trajs = torch.flip(trajs, [2])
if "extra_videos" in filename:
seq_name = filename.split("/")[-3]
else:
seq_name = filename.split("/")[-2]
rgbs, trajs, segs = resize_sample(rgbs, trajs, segs, self.dataset_resolution)
return CoTrackerData(rgbs, segs, trajs, visibles, seq_name=seq_name)
def __len__(self):
return self.badja_data.animal_count

View File

@ -0,0 +1,166 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
import dataclasses
import numpy as np
from dataclasses import Field, MISSING
from typing import IO, TypeVar, Type, get_args, get_origin, Union, Any, Tuple
_X = TypeVar("_X")
def load_dataclass(f: IO, cls: Type[_X], binary: bool = False) -> _X:
"""
Loads to a @dataclass or collection hierarchy including dataclasses
from a json recursively.
Call it like load_dataclass(f, typing.List[FrameAnnotationAnnotation]).
raises KeyError if json has keys not mapping to the dataclass fields.
Args:
f: Either a path to a file, or a file opened for writing.
cls: The class of the loaded dataclass.
binary: Set to True if `f` is a file handle, else False.
"""
if binary:
asdict = json.loads(f.read().decode("utf8"))
else:
asdict = json.load(f)
# in the list case, run a faster "vectorized" version
cls = get_args(cls)[0]
res = list(_dataclass_list_from_dict_list(asdict, cls))
return res
def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
"""Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
if get_origin(type_) is Union:
args = get_args(type_)
if len(args) == 2 and args[1] == type(None): # noqa E721
return True, args[0]
if type_ is Any:
return True, Any
return False, type_
def _unwrap_type(tp):
# strips Optional wrapper, if any
if get_origin(tp) is Union:
args = get_args(tp)
if len(args) == 2 and any(a is type(None) for a in args): # noqa: E721
# this is typing.Optional
return args[0] if args[1] is type(None) else args[1] # noqa: E721
return tp
def _get_dataclass_field_default(field: Field) -> Any:
if field.default_factory is not MISSING:
# pyre-fixme[29]: `Union[dataclasses._MISSING_TYPE,
# dataclasses._DefaultFactory[typing.Any]]` is not a function.
return field.default_factory()
elif field.default is not MISSING:
return field.default
else:
return None
def _dataclass_list_from_dict_list(dlist, typeannot):
"""
Vectorised version of `_dataclass_from_dict`.
The output should be equivalent to
`[_dataclass_from_dict(d, typeannot) for d in dlist]`.
Args:
dlist: list of objects to convert.
typeannot: type of each of those objects.
Returns:
iterator or list over converted objects of the same length as `dlist`.
Raises:
ValueError: it assumes the objects have None's in consistent places across
objects, otherwise it would ignore some values. This generally holds for
auto-generated annotations, but otherwise use `_dataclass_from_dict`.
"""
cls = get_origin(typeannot) or typeannot
if typeannot is Any:
return dlist
if all(obj is None for obj in dlist): # 1st recursion base: all None nodes
return dlist
if any(obj is None for obj in dlist):
# filter out Nones and recurse on the resulting list
idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None]
idx, notnone = zip(*idx_notnone)
converted = _dataclass_list_from_dict_list(notnone, typeannot)
res = [None] * len(dlist)
for i, obj in zip(idx, converted):
res[i] = obj
return res
is_optional, contained_type = _resolve_optional(typeannot)
if is_optional:
return _dataclass_list_from_dict_list(dlist, contained_type)
# otherwise, we dispatch by the type of the provided annotation to convert to
if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
# For namedtuple, call the function recursively on the lists of corresponding keys
types = cls.__annotations__.values()
dlist_T = zip(*dlist)
res_T = [
_dataclass_list_from_dict_list(key_list, tp) for key_list, tp in zip(dlist_T, types)
]
return [cls(*converted_as_tuple) for converted_as_tuple in zip(*res_T)]
elif issubclass(cls, (list, tuple)):
# For list/tuple, call the function recursively on the lists of corresponding positions
types = get_args(typeannot)
if len(types) == 1: # probably List; replicate for all items
types = types * len(dlist[0])
dlist_T = zip(*dlist)
res_T = (
_dataclass_list_from_dict_list(pos_list, tp) for pos_list, tp in zip(dlist_T, types)
)
if issubclass(cls, tuple):
return list(zip(*res_T))
else:
return [cls(converted_as_tuple) for converted_as_tuple in zip(*res_T)]
elif issubclass(cls, dict):
# For the dictionary, call the function recursively on concatenated keys and vertices
key_t, val_t = get_args(typeannot)
all_keys_res = _dataclass_list_from_dict_list(
[k for obj in dlist for k in obj.keys()], key_t
)
all_vals_res = _dataclass_list_from_dict_list(
[k for obj in dlist for k in obj.values()], val_t
)
indices = np.cumsum([len(obj) for obj in dlist])
assert indices[-1] == len(all_keys_res)
keys = np.split(list(all_keys_res), indices[:-1])
all_vals_res_iter = iter(all_vals_res)
return [cls(zip(k, all_vals_res_iter)) for k in keys]
elif not dataclasses.is_dataclass(typeannot):
return dlist
# dataclass node: 2nd recursion base; call the function recursively on the lists
# of the corresponding fields
assert dataclasses.is_dataclass(cls)
fieldtypes = {
f.name: (_unwrap_type(f.type), _get_dataclass_field_default(f))
for f in dataclasses.fields(typeannot)
}
# NOTE the default object is shared here
key_lists = (
_dataclass_list_from_dict_list([obj.get(k, default) for obj in dlist], type_)
for k, (type_, default) in fieldtypes.items()
)
transposed = zip(*key_lists)
return [cls(*vals_as_tuple) for vals_as_tuple in transposed]

View File

@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import gzip
import torch
import numpy as np
import torch.utils.data as data
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Any, Dict, Tuple
from cotracker.datasets.utils import CoTrackerData
from cotracker.datasets.dataclass_utils import load_dataclass
@dataclass
class ImageAnnotation:
# path to jpg file, relative w.r.t. dataset_root
path: str
# H x W
size: Tuple[int, int]
@dataclass
class DynamicReplicaFrameAnnotation:
"""A dataclass used to load annotations from json."""
# can be used to join with `SequenceAnnotation`
sequence_name: str
# 0-based, continuous frame number within sequence
frame_number: int
# timestamp in seconds from the video start
frame_timestamp: float
image: ImageAnnotation
meta: Optional[Dict[str, Any]] = None
camera_name: Optional[str] = None
trajectories: Optional[str] = None
class DynamicReplicaDataset(data.Dataset):
def __init__(
self,
root,
split="valid",
traj_per_sample=256,
crop_size=None,
sample_len=-1,
only_first_n_samples=-1,
rgbd_input=False,
):
super(DynamicReplicaDataset, self).__init__()
self.root = root
self.sample_len = sample_len
self.split = split
self.traj_per_sample = traj_per_sample
self.rgbd_input = rgbd_input
self.crop_size = crop_size
frame_annotations_file = f"frame_annotations_{split}.jgz"
self.sample_list = []
with gzip.open(
os.path.join(root, split, frame_annotations_file), "rt", encoding="utf8"
) as zipfile:
frame_annots_list = load_dataclass(zipfile, List[DynamicReplicaFrameAnnotation])
seq_annot = defaultdict(list)
for frame_annot in frame_annots_list:
if frame_annot.camera_name == "left":
seq_annot[frame_annot.sequence_name].append(frame_annot)
for seq_name in seq_annot.keys():
seq_len = len(seq_annot[seq_name])
step = self.sample_len if self.sample_len > 0 else seq_len
counter = 0
for ref_idx in range(0, seq_len, step):
sample = seq_annot[seq_name][ref_idx : ref_idx + step]
self.sample_list.append(sample)
counter += 1
if only_first_n_samples > 0 and counter >= only_first_n_samples:
break
def __len__(self):
return len(self.sample_list)
def crop(self, rgbs, trajs):
T, N, _ = trajs.shape
S = len(rgbs)
H, W = rgbs[0].shape[:2]
assert S == T
H_new = H
W_new = W
# simple random crop
y0 = 0 if self.crop_size[0] >= H_new else (H_new - self.crop_size[0]) // 2
x0 = 0 if self.crop_size[1] >= W_new else (W_new - self.crop_size[1]) // 2
rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]
trajs[:, :, 0] -= x0
trajs[:, :, 1] -= y0
return rgbs, trajs
def __getitem__(self, index):
sample = self.sample_list[index]
T = len(sample)
rgbs, visibilities, traj_2d = [], [], []
H, W = sample[0].image.size
image_size = (H, W)
for i in range(T):
traj_path = os.path.join(self.root, self.split, sample[i].trajectories["path"])
traj = torch.load(traj_path)
visibilities.append(traj["verts_inds_vis"].numpy())
rgbs.append(traj["img"].numpy())
traj_2d.append(traj["traj_2d"].numpy()[..., :2])
traj_2d = np.stack(traj_2d)
visibility = np.stack(visibilities)
T, N, D = traj_2d.shape
# subsample trajectories for augmentations
visible_inds_sampled = torch.randperm(N)[: self.traj_per_sample]
traj_2d = traj_2d[:, visible_inds_sampled]
visibility = visibility[:, visible_inds_sampled]
if self.crop_size is not None:
rgbs, traj_2d = self.crop(rgbs, traj_2d)
H, W, _ = rgbs[0].shape
image_size = self.crop_size
visibility[traj_2d[:, :, 0] > image_size[1] - 1] = False
visibility[traj_2d[:, :, 0] < 0] = False
visibility[traj_2d[:, :, 1] > image_size[0] - 1] = False
visibility[traj_2d[:, :, 1] < 0] = False
# filter out points that're visible for less than 10 frames
visible_inds_resampled = visibility.sum(0) > 10
traj_2d = torch.from_numpy(traj_2d[:, visible_inds_resampled])
visibility = torch.from_numpy(visibility[:, visible_inds_resampled])
rgbs = np.stack(rgbs, 0)
video = torch.from_numpy(rgbs).reshape(T, H, W, 3).permute(0, 3, 1, 2).float()
return CoTrackerData(
video=video,
trajectory=traj_2d,
visibility=visibility,
valid=torch.ones(T, N),
seq_name=sample[0].sequence_name,
)

View File

@ -1,72 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import torch
# from PIL import Image
import imageio
import numpy as np
from cotracker.datasets.utils import CoTrackerData, resize_sample
class FastCaptureDataset(torch.utils.data.Dataset):
def __init__(
self,
data_root,
max_seq_len=50,
max_num_points=20,
dataset_resolution=(384, 512),
):
self.data_root = data_root
self.seq_names = os.listdir(os.path.join(data_root, "renders_local_rm"))
self.pth_dir = os.path.join(data_root, "zju_tracking")
self.max_seq_len = max_seq_len
self.max_num_points = max_num_points
self.dataset_resolution = dataset_resolution
print("found %d unique videos in %s" % (len(self.seq_names), self.data_root))
def __getitem__(self, index):
seq_name = self.seq_names[index]
spath = os.path.join(self.data_root, "renders_local_rm", seq_name)
pthpath = os.path.join(self.pth_dir, seq_name + ".pth")
rgbs = []
img_paths = sorted(os.listdir(spath))
for i, img_path in enumerate(img_paths):
if i < self.max_seq_len:
rgbs.append(imageio.imread(os.path.join(spath, img_path)))
annot_dict = torch.load(pthpath)
traj_2d = annot_dict["traj_2d"][:, :, : self.max_seq_len]
visibility = annot_dict["visibility"][:, : self.max_seq_len]
S = len(rgbs)
H, W, __ = rgbs[0].shape
*_, S = traj_2d.shape
visibile_pts_first_frame_inds = (visibility[:, 0] > 0).nonzero(as_tuple=False)[
:, 0
]
torch.manual_seed(0)
point_inds = torch.randperm(len(visibile_pts_first_frame_inds))[
: self.max_num_points
]
visible_inds_sampled = visibile_pts_first_frame_inds[point_inds]
rgbs = np.stack(rgbs, 0)
rgbs = torch.from_numpy(rgbs).reshape(S, H, W, 3).permute(0, 3, 1, 2).float()
segs = torch.ones(S, 1, H, W).float()
trajs = traj_2d[visible_inds_sampled].permute(2, 0, 1).float()
visibles = visibility[visible_inds_sampled].permute(1, 0)
rgbs, trajs, segs = resize_sample(rgbs, trajs, segs, self.dataset_resolution)
return CoTrackerData(rgbs, segs, trajs, visibles, seq_name=seq_name)
def __len__(self):
return len(self.seq_names)

View File

@ -6,6 +6,7 @@
import os
import torch
import cv2
import imageio
import numpy as np
@ -13,7 +14,6 @@ import numpy as np
from cotracker.datasets.utils import CoTrackerData
from torchvision.transforms import ColorJitter, GaussianBlur
from PIL import Image
import cv2
class CoTrackerDataset(torch.utils.data.Dataset):
@ -37,9 +37,7 @@ class CoTrackerDataset(torch.utils.data.Dataset):
self.crop_size = crop_size
# photometric augmentation
self.photo_aug = ColorJitter(
brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14
)
self.photo_aug = ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14)
self.blur_aug = GaussianBlur(11, sigma=(0.1, 2.0))
self.blur_aug_prob = 0.25
@ -77,12 +75,7 @@ class CoTrackerDataset(torch.utils.data.Dataset):
print("warning: sampling failed")
# fake sample, so we can still collate
sample = CoTrackerData(
video=torch.zeros(
(self.seq_len, 3, self.crop_size[0], self.crop_size[1])
),
segmentation=torch.zeros(
(self.seq_len, 1, self.crop_size[0], self.crop_size[1])
),
video=torch.zeros((self.seq_len, 3, self.crop_size[0], self.crop_size[1])),
trajectory=torch.zeros((self.seq_len, self.traj_per_sample, 2)),
visibility=torch.zeros((self.seq_len, self.traj_per_sample)),
valid=torch.zeros((self.seq_len, self.traj_per_sample)),
@ -105,23 +98,16 @@ class CoTrackerDataset(torch.utils.data.Dataset):
for _ in range(
np.random.randint(1, self.eraser_max + 1)
): # number of times to occlude
xc = np.random.randint(0, W)
yc = np.random.randint(0, H)
dx = np.random.randint(
self.eraser_bounds[0], self.eraser_bounds[1]
)
dy = np.random.randint(
self.eraser_bounds[0], self.eraser_bounds[1]
)
dx = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
dy = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
mean_color = np.mean(
rgbs[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0
)
mean_color = np.mean(rgbs[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0)
rgbs[i][y0:y1, x0:x1, :] = mean_color
occ_inds = np.logical_and(
@ -132,14 +118,11 @@ class CoTrackerDataset(torch.utils.data.Dataset):
rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
if replace:
rgbs_alt = [
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
for rgb in rgbs
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs
]
rgbs_alt = [
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
for rgb in rgbs_alt
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs_alt
]
############ replace transform (per image after the first) ############
@ -152,12 +135,8 @@ class CoTrackerDataset(torch.utils.data.Dataset):
): # number of times to occlude
xc = np.random.randint(0, W)
yc = np.random.randint(0, H)
dx = np.random.randint(
self.replace_bounds[0], self.replace_bounds[1]
)
dy = np.random.randint(
self.replace_bounds[0], self.replace_bounds[1]
)
dx = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
dy = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
@ -181,17 +160,11 @@ class CoTrackerDataset(torch.utils.data.Dataset):
############ photometric augmentation ############
if np.random.rand() < self.color_aug_prob:
# random per-frame amount of aug
rgbs = [
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
for rgb in rgbs
]
rgbs = [np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]
if np.random.rand() < self.blur_aug_prob:
# random per-frame amount of blur
rgbs = [
np.array(self.blur_aug(Image.fromarray(rgb)), dtype=np.uint8)
for rgb in rgbs
]
rgbs = [np.array(self.blur_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]
return rgbs, trajs, visibles
@ -212,9 +185,7 @@ class CoTrackerDataset(torch.utils.data.Dataset):
pad_y0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
rgbs = [
np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs
]
rgbs = [np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs]
trajs[:, :, 0] += pad_x0
trajs[:, :, 1] += pad_y0
H, W = rgbs[0].shape[:2]
@ -263,12 +234,9 @@ class CoTrackerDataset(torch.utils.data.Dataset):
H_new = np.clip(H_new, self.crop_size[0] + 10, None)
W_new = np.clip(W_new, self.crop_size[1] + 10, None)
# recompute scale in case we clipped
scale_x = W_new / float(W)
scale_y = H_new / float(H)
rgbs_scaled.append(
cv2.resize(rgbs[s], (W_new, H_new), interpolation=cv2.INTER_LINEAR)
)
scale_x = (W_new - 1) / float(W - 1)
scale_y = (H_new - 1) / float(H - 1)
rgbs_scaled.append(cv2.resize(rgbs[s], (W_new, H_new), interpolation=cv2.INTER_LINEAR))
trajs[s, :, 0] *= scale_x
trajs[s, :, 1] *= scale_y
rgbs = rgbs_scaled
@ -292,22 +260,16 @@ class CoTrackerDataset(torch.utils.data.Dataset):
for s in range(S):
# on each frame, shift a bit more
if s == 1:
offset_x = np.random.randint(
-self.max_crop_offset, self.max_crop_offset
)
offset_y = np.random.randint(
-self.max_crop_offset, self.max_crop_offset
)
offset_x = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
offset_y = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
elif s > 1:
offset_x = int(
offset_x * 0.8
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1)
* 0.2
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
)
offset_y = int(
offset_y * 0.8
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1)
* 0.2
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
)
x0 = x0 + offset_x
y0 = y0 + offset_y
@ -362,20 +324,9 @@ class CoTrackerDataset(torch.utils.data.Dataset):
W_new = W
# simple random crop
y0 = (
0
if self.crop_size[0] >= H_new
else np.random.randint(0, H_new - self.crop_size[0])
)
x0 = (
0
if self.crop_size[1] >= W_new
else np.random.randint(0, W_new - self.crop_size[1])
)
rgbs = [
rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
for rgb in rgbs
]
y0 = 0 if self.crop_size[0] >= H_new else np.random.randint(0, H_new - self.crop_size[0])
x0 = 0 if self.crop_size[1] >= W_new else np.random.randint(0, W_new - self.crop_size[1])
rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]
trajs[:, :, 0] -= x0
trajs[:, :, 1] -= y0
@ -442,9 +393,7 @@ class KubricMovifDataset(CoTrackerDataset):
traj_2d = np.transpose(traj_2d, (1, 0, 2))
visibility = np.transpose(np.logical_not(visibility), (1, 0))
if self.use_augs:
rgbs, traj_2d, visibility = self.add_photometric_augs(
rgbs, traj_2d, visibility
)
rgbs, traj_2d, visibility = self.add_photometric_augs(rgbs, traj_2d, visibility)
rgbs, traj_2d = self.add_spatial_augs(rgbs, traj_2d, visibility)
else:
rgbs, traj_2d = self.crop(rgbs, traj_2d)
@ -462,9 +411,9 @@ class KubricMovifDataset(CoTrackerDataset):
if self.sample_vis_1st_frame:
visibile_pts_inds = visibile_pts_first_frame_inds
else:
visibile_pts_mid_frame_inds = (visibility[self.seq_len // 2]).nonzero(
as_tuple=False
)[:, 0]
visibile_pts_mid_frame_inds = (visibility[self.seq_len // 2]).nonzero(as_tuple=False)[
:, 0
]
visibile_pts_inds = torch.cat(
(visibile_pts_first_frame_inds, visibile_pts_mid_frame_inds), dim=0
)
@ -479,10 +428,8 @@ class KubricMovifDataset(CoTrackerDataset):
valids = torch.ones((self.seq_len, self.traj_per_sample))
rgbs = torch.from_numpy(np.stack(rgbs)).permute(0, 3, 1, 2).float()
segs = torch.ones((self.seq_len, 1, self.crop_size[0], self.crop_size[1]))
sample = CoTrackerData(
video=rgbs,
segmentation=segs,
trajectory=trajs,
visibility=visibles,
valid=valids,

View File

@ -179,12 +179,9 @@ class TapVidDataset(torch.utils.data.Dataset):
target_points = self.points_dataset[video_name]["points"]
if self.resize_to_256:
frames = resize_video(frames, [256, 256])
target_points *= np.array([256, 256])
target_points *= np.array([255, 255]) # 1 should be mapped to 256-1
else:
target_points *= np.array([frames.shape[2], frames.shape[1]])
T, H, W, C = frames.shape
N, T, D = target_points.shape
target_points *= np.array([frames.shape[2] - 1, frames.shape[1] - 1])
target_occ = self.points_dataset[video_name]["occluded"]
if self.queried_first:
@ -193,21 +190,15 @@ class TapVidDataset(torch.utils.data.Dataset):
converted = sample_queries_strided(target_occ, target_points, frames)
assert converted["target_points"].shape[1] == converted["query_points"].shape[1]
trajs = (
torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float()
) # T, N, D
trajs = torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float() # T, N, D
rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float()
segs = torch.ones(T, 1, H, W).float()
visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[
0
].permute(
visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[0].permute(
1, 0
) # T, N
query_points = torch.from_numpy(converted["query_points"])[0] # T, N
return CoTrackerData(
rgbs,
segs,
trajs,
visibles,
seq_name=str(video_name),

View File

@ -19,11 +19,11 @@ class CoTrackerData:
"""
video: torch.Tensor # B, S, C, H, W
segmentation: torch.Tensor # B, S, 1, H, W
trajectory: torch.Tensor # B, S, N, 2
visibility: torch.Tensor # B, S, N
# optional data
valid: Optional[torch.Tensor] = None # B, S, N
segmentation: Optional[torch.Tensor] = None # B, S, 1, H, W
seq_name: Optional[str] = None
query_points: Optional[torch.Tensor] = None # TapVID evaluation format
@ -33,19 +33,20 @@ def collate_fn(batch):
Collate function for video tracks data.
"""
video = torch.stack([b.video for b in batch], dim=0)
segmentation = torch.stack([b.segmentation for b in batch], dim=0)
trajectory = torch.stack([b.trajectory for b in batch], dim=0)
visibility = torch.stack([b.visibility for b in batch], dim=0)
query_points = None
query_points = segmentation = None
if batch[0].query_points is not None:
query_points = torch.stack([b.query_points for b in batch], dim=0)
if batch[0].segmentation is not None:
segmentation = torch.stack([b.segmentation for b in batch], dim=0)
seq_name = [b.seq_name for b in batch]
return CoTrackerData(
video,
segmentation,
trajectory,
visibility,
video=video,
trajectory=trajectory,
visibility=visibility,
segmentation=segmentation,
seq_name=seq_name,
query_points=query_points,
)
@ -57,13 +58,18 @@ def collate_fn_train(batch):
"""
gotit = [gotit for _, gotit in batch]
video = torch.stack([b.video for b, _ in batch], dim=0)
segmentation = torch.stack([b.segmentation for b, _ in batch], dim=0)
trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0)
visibility = torch.stack([b.visibility for b, _ in batch], dim=0)
valid = torch.stack([b.valid for b, _ in batch], dim=0)
seq_name = [b.seq_name for b, _ in batch]
return (
CoTrackerData(video, segmentation, trajectory, visibility, valid, seq_name),
CoTrackerData(
video=video,
trajectory=trajectory,
visibility=visibility,
valid=valid,
seq_name=seq_name,
),
gotit,
)
@ -98,17 +104,3 @@ def dataclass_to_cuda_(obj):
for f in dataclasses.fields(obj):
setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
return obj
def resize_sample(rgbs, trajs_g, segs, interp_shape):
S, C, H, W = rgbs.shape
S, N, D = trajs_g.shape
assert D == 2
rgbs = F.interpolate(rgbs, interp_shape, mode="bilinear")
segs = F.interpolate(segs, interp_shape, mode="nearest")
trajs_g[:, :, 0] *= interp_shape[1] / W
trajs_g[:, :, 1] *= interp_shape[0] / H
return rgbs, trajs_g, segs

View File

@ -1,6 +1,6 @@
defaults:
- default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: badja
dataset_name: dynamic_replica

View File

@ -1,6 +0,0 @@
defaults:
- default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: fastcapture

View File

@ -37,57 +37,7 @@ class Evaluator:
pred_trajectory, pred_visibility = pred_trajectory
else:
pred_visibility = None
if dataset_name == "badja":
sample.segmentation = (sample.segmentation > 0).float()
*_, N, _ = sample.trajectory.shape
accs = []
accs_3px = []
for s1 in range(1, sample.video.shape[1]): # target frame
for n in range(N):
vis = sample.visibility[0, s1, n]
if vis > 0:
coord_e = pred_trajectory[0, s1, n] # 2
coord_g = sample.trajectory[0, s1, n] # 2
dist = torch.sqrt(torch.sum((coord_e - coord_g) ** 2, dim=0))
area = torch.sum(sample.segmentation[0, s1])
# print_('0.2*sqrt(area)', 0.2*torch.sqrt(area))
thr = 0.2 * torch.sqrt(area)
# correct =
accs.append((dist < thr).float())
# print('thr',thr)
accs_3px.append((dist < 3.0).float())
res = torch.mean(torch.stack(accs)) * 100.0
res_3px = torch.mean(torch.stack(accs_3px)) * 100.0
metrics[sample.seq_name[0]] = res.item()
metrics[sample.seq_name[0] + "_accuracy"] = res_3px.item()
print(metrics)
print(
"avg", np.mean([v for k, v in metrics.items() if "accuracy" not in k])
)
print(
"avg acc 3px",
np.mean([v for k, v in metrics.items() if "accuracy" in k]),
)
elif dataset_name == "fastcapture" or ("kubric" in dataset_name):
*_, N, _ = sample.trajectory.shape
accs = []
for s1 in range(1, sample.video.shape[1]): # target frame
for n in range(N):
vis = sample.visibility[0, s1, n]
if vis > 0:
coord_e = pred_trajectory[0, s1, n] # 2
coord_g = sample.trajectory[0, s1, n] # 2
dist = torch.sqrt(torch.sum((coord_e - coord_g) ** 2, dim=0))
thr = 3
correct = (dist < thr).float()
accs.append(correct)
res = torch.mean(torch.stack(accs)) * 100.0
metrics[sample.seq_name[0] + "_accuracy"] = res.item()
print(metrics)
print("avg", np.mean([v for v in metrics.values()]))
elif "tapvid" in dataset_name:
if "tapvid" in dataset_name:
B, T, N, D = sample.trajectory.shape
traj = sample.trajectory.clone()
thr = 0.9
@ -99,7 +49,6 @@ class Evaluator:
if not pred_visibility.dtype == torch.bool:
pred_visibility = pred_visibility > thr
# pred_trajectory
query_points = sample.query_points.clone().cpu().numpy()
pred_visibility = pred_visibility[:, :, :N]
@ -107,15 +56,11 @@ class Evaluator:
gt_tracks = traj.permute(0, 2, 1, 3).cpu().numpy()
gt_occluded = (
torch.logical_not(sample.visibility.clone().permute(0, 2, 1))
.cpu()
.numpy()
torch.logical_not(sample.visibility.clone().permute(0, 2, 1)).cpu().numpy()
)
pred_occluded = (
torch.logical_not(pred_visibility.clone().permute(0, 2, 1))
.cpu()
.numpy()
torch.logical_not(pred_visibility.clone().permute(0, 2, 1)).cpu().numpy()
)
pred_tracks = pred_trajectory.permute(0, 2, 1, 3).cpu().numpy()
@ -140,27 +85,79 @@ class Evaluator:
logging.info(f"avg: {metrics['avg']}")
print("metrics", out_metrics)
print("avg", metrics["avg"])
else:
rgbs = sample.video
trajs_g = sample.trajectory
valids = sample.valid
vis_g = sample.visibility
elif dataset_name == "dynamic_replica" or dataset_name == "pointodyssey":
*_, N, _ = sample.trajectory.shape
B, T, N = sample.visibility.shape
H, W = sample.video.shape[-2:]
device = sample.video.device
B, S, C, H, W = rgbs.shape
assert C == 3
B, S, N, D = trajs_g.shape
out_metrics = {}
assert torch.sum(valids) == B * S * N
d_vis_sum = d_occ_sum = d_sum_all = 0.0
thrs = [1, 2, 4, 8, 16]
sx_ = (W - 1) / 255.0
sy_ = (H - 1) / 255.0
sc_py = np.array([sx_, sy_]).reshape([1, 1, 2])
sc_pt = torch.from_numpy(sc_py).float().to(device)
__, first_visible_inds = torch.max(sample.visibility, dim=1)
vis_g = (torch.sum(vis_g, dim=1, keepdim=True) >= 4).float().repeat(1, S, 1)
frame_ids_tensor = torch.arange(T, device=device)[None, :, None].repeat(B, 1, N)
start_tracking_mask = frame_ids_tensor > (first_visible_inds.unsqueeze(1))
ate = torch.norm(pred_trajectory - trajs_g, dim=-1) # B, S, N
for thr in thrs:
d_ = (
torch.norm(
pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt,
dim=-1,
)
< thr
).float() # B,S-1,N
d_occ = (
reduce_masked_mean(d_, (1 - sample.visibility) * start_tracking_mask).item()
* 100.0
)
d_occ_sum += d_occ
out_metrics[f"accuracy_occ_{thr}"] = d_occ
metrics["things_all"] = reduce_masked_mean(ate, valids).item()
metrics["things_vis"] = reduce_masked_mean(ate, valids * vis_g).item()
metrics["things_occ"] = reduce_masked_mean(
ate, valids * (1.0 - vis_g)
).item()
d_vis = (
reduce_masked_mean(d_, sample.visibility * start_tracking_mask).item() * 100.0
)
d_vis_sum += d_vis
out_metrics[f"accuracy_vis_{thr}"] = d_vis
d_all = reduce_masked_mean(d_, start_tracking_mask).item() * 100.0
d_sum_all += d_all
out_metrics[f"accuracy_{thr}"] = d_all
d_occ_avg = d_occ_sum / len(thrs)
d_vis_avg = d_vis_sum / len(thrs)
d_all_avg = d_sum_all / len(thrs)
sur_thr = 50
dists = torch.norm(
pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt,
dim=-1,
) # B,S,N
dist_ok = 1 - (dists > sur_thr).float() * sample.visibility # B,S,N
survival = torch.cumprod(dist_ok, dim=1) # B,S,N
out_metrics["survival"] = torch.mean(survival).item() * 100.0
out_metrics["accuracy_occ"] = d_occ_avg
out_metrics["accuracy_vis"] = d_vis_avg
out_metrics["accuracy"] = d_all_avg
metrics[sample.seq_name[0]] = out_metrics
for metric_name in out_metrics.keys():
if "avg" not in metrics:
metrics["avg"] = {}
metrics["avg"][metric_name] = float(
np.mean([v[metric_name] for k, v in metrics.items() if k != "avg"])
)
logging.info(f"Metrics: {out_metrics}")
logging.info(f"avg: {metrics['avg']}")
print("metrics", out_metrics)
print("avg", metrics["avg"])
@torch.no_grad()
def evaluate_sequence(
@ -169,6 +166,7 @@ class Evaluator:
test_dataloader: torch.utils.data.DataLoader,
dataset_name: str,
train_mode=False,
visualize_every: int = 1,
writer: Optional[SummaryWriter] = None,
step: Optional[int] = 0,
):
@ -221,7 +219,6 @@ class Evaluator:
pred_tracks = model(sample.video, queries)
if "strided" in dataset_name:
inv_video = sample.video.flip(1).clone()
inv_queries = queries.clone()
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
@ -243,14 +240,14 @@ class Evaluator:
seq_name = sample.seq_name[0]
else:
seq_name = str(ind)
vis.visualize(
sample.video,
pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks,
filename=dataset_name + "_" + seq_name,
writer=writer,
step=step,
)
if ind % visualize_every == 0:
vis.visualize(
sample.video,
pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks,
filename=dataset_name + "_" + seq_name,
writer=writer,
step=step,
)
self.compute_metrics(metrics, sample, pred_tracks, dataset_name)
return metrics

View File

@ -14,9 +14,8 @@ import numpy as np
import torch
from omegaconf import OmegaConf
from cotracker.datasets.badja_dataset import BadjaDataset
from cotracker.datasets.fast_capture_dataset import FastCaptureDataset
from cotracker.datasets.tap_vid_datasets import TapVidDataset
from cotracker.datasets.dr_dataset import DynamicReplicaDataset
from cotracker.datasets.utils import collate_fn
from cotracker.models.evaluation_predictor import EvaluationPredictor
@ -33,23 +32,20 @@ class DefaultConfig:
exp_dir: str = "./outputs"
# Name of the dataset to be used for the evaluation.
dataset_name: str = "badja"
dataset_name: str = "tapvid_davis_first"
# The root directory of the dataset.
dataset_root: str = "./"
# Path to the pre-trained model checkpoint to be used for the evaluation.
# The default value is the path to a specific CoTracker model checkpoint.
# Other available options are commented.
checkpoint: str = "./checkpoints/cotracker_stride_4_wind_8.pth"
# cotracker_stride_4_wind_12
# cotracker_stride_8_wind_16
checkpoint: str = "./checkpoints/cotracker2.pth"
# EvaluationPredictor parameters
# The size (N) of the support grid used in the predictor.
# The total number of points is (N*N).
grid_size: int = 6
grid_size: int = 5
# The size (N) of the local support grid.
local_grid_size: int = 6
local_grid_size: int = 8
# A flag indicating whether to evaluate one ground truth point at a time.
single_point: bool = True
# The number of iterative updates for each sliding window.
@ -111,18 +107,10 @@ def run_eval(cfg: DefaultConfig):
# Constructing the specified dataset
curr_collate_fn = collate_fn
if cfg.dataset_name == "badja":
test_dataset = BadjaDataset(data_root=os.path.join(cfg.dataset_root, "BADJA"))
elif cfg.dataset_name == "fastcapture":
test_dataset = FastCaptureDataset(
data_root=os.path.join(cfg.dataset_root, "fastcapture"),
max_seq_len=100,
max_num_points=20,
)
elif "tapvid" in cfg.dataset_name:
if "tapvid" in cfg.dataset_name:
dataset_type = cfg.dataset_name.split("_")[1]
if dataset_type == "davis":
data_root = os.path.join(cfg.dataset_root, "/tapvid_davis/tapvid_davis.pkl")
data_root = os.path.join(cfg.dataset_root, "tapvid_davis", "tapvid_davis.pkl")
elif dataset_type == "kinetics":
data_root = os.path.join(
cfg.dataset_root, "/kinetics/kinetics-dataset/k700-2020/tapvid_kinetics"
@ -132,6 +120,8 @@ def run_eval(cfg: DefaultConfig):
data_root=data_root,
queried_first=not "strided" in cfg.dataset_name,
)
elif cfg.dataset_name == "dynamic_replica":
test_dataset = DynamicReplicaDataset(sample_len=300, only_first_n_samples=1)
# Creating the DataLoader object
test_dataloader = torch.utils.data.DataLoader(
@ -155,10 +145,8 @@ def run_eval(cfg: DefaultConfig):
print(end - start)
# Saving the evaluation results to a .json file
if not "tapvid" in cfg.dataset_name:
print("evaluate_result", evaluate_result)
else:
evaluate_result = evaluate_result["avg"]
evaluate_result = evaluate_result["avg"]
print("evaluate_result", evaluate_result)
result_file = os.path.join(cfg.exp_dir, f"result_eval_.json")
evaluate_result["time"] = end - start
print(f"Dumping eval results to {result_file}.")

View File

@ -6,63 +6,24 @@
import torch
from cotracker.models.core.cotracker.cotracker import CoTracker
from cotracker.models.core.cotracker.cotracker import CoTracker2
def build_cotracker(
checkpoint: str,
):
if checkpoint is None:
return build_cotracker_stride_4_wind_8()
return build_cotracker()
model_name = checkpoint.split("/")[-1].split(".")[0]
if model_name == "cotracker_stride_4_wind_8":
return build_cotracker_stride_4_wind_8(checkpoint=checkpoint)
elif model_name == "cotracker_stride_4_wind_12":
return build_cotracker_stride_4_wind_12(checkpoint=checkpoint)
elif model_name == "cotracker_stride_8_wind_16":
return build_cotracker_stride_8_wind_16(checkpoint=checkpoint)
if model_name == "cotracker":
return build_cotracker(checkpoint=checkpoint)
else:
raise ValueError(f"Unknown model name {model_name}")
# model used to produce the results in the paper
def build_cotracker_stride_4_wind_8(checkpoint=None):
return _build_cotracker(
stride=4,
sequence_len=8,
checkpoint=checkpoint,
)
def build_cotracker(checkpoint=None):
cotracker = CoTracker2(stride=4, window_len=8, add_space_attn=True)
def build_cotracker_stride_4_wind_12(checkpoint=None):
return _build_cotracker(
stride=4,
sequence_len=12,
checkpoint=checkpoint,
)
# the fastest model
def build_cotracker_stride_8_wind_16(checkpoint=None):
return _build_cotracker(
stride=8,
sequence_len=16,
checkpoint=checkpoint,
)
def _build_cotracker(
stride,
sequence_len,
checkpoint=None,
):
cotracker = CoTracker(
stride=stride,
S=sequence_len,
add_space_attn=True,
space_depth=6,
time_depth=6,
)
if checkpoint is not None:
with open(checkpoint, "rb") as f:
state_dict = torch.load(f, map_location="cpu")

View File

@ -7,9 +7,71 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from typing import Callable
import collections
from torch import Tensor
from itertools import repeat
from einops import rearrange
from timm.models.vision_transformer import Attention, Mlp
from cotracker.models.core.model_utils import bilinear_sampler
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
to_2tuple = _ntuple(2)
class Mlp(nn.Module):
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
bias=True,
drop=0.0,
use_conv=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
class ResidualBlock(nn.Module):
@@ -24,9 +86,7 @@ class ResidualBlock(nn.Module):
stride=stride,
padding_mode="zeros",
)
self.conv2 = nn.Conv2d(
planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, padding_mode="zeros")
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
@@ -75,28 +135,14 @@ class ResidualBlock(nn.Module):
class BasicEncoder(nn.Module):
def __init__(
self, input_dim=3, output_dim=128, stride=8, norm_fn="batch", dropout=0.0
):
def __init__(self, input_dim=3, output_dim=128, stride=4):
super(BasicEncoder, self).__init__()
self.stride = stride
self.norm_fn = norm_fn
self.in_planes = 64
self.norm_fn = "instance"
self.in_planes = output_dim // 2
if self.norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes)
self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2)
elif self.norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(self.in_planes)
self.norm2 = nn.BatchNorm2d(output_dim * 2)
elif self.norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(self.in_planes)
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
elif self.norm_fn == "none":
self.norm1 = nn.Sequential()
self.norm1 = nn.InstanceNorm2d(self.in_planes)
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
self.conv1 = nn.Conv2d(
input_dim,
@@ -107,37 +153,24 @@ class BasicEncoder(nn.Module):
padding_mode="zeros",
)
self.relu1 = nn.ReLU(inplace=True)
self.layer1 = self._make_layer(output_dim // 2, stride=1)
self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
self.layer3 = self._make_layer(output_dim, stride=2)
self.layer4 = self._make_layer(output_dim, stride=2)
self.shallow = False
if self.shallow:
self.layer1 = self._make_layer(64, stride=1)
self.layer2 = self._make_layer(96, stride=2)
self.layer3 = self._make_layer(128, stride=2)
self.conv2 = nn.Conv2d(128 + 96 + 64, output_dim, kernel_size=1)
else:
self.layer1 = self._make_layer(64, stride=1)
self.layer2 = self._make_layer(96, stride=2)
self.layer3 = self._make_layer(128, stride=2)
self.layer4 = self._make_layer(128, stride=2)
self.conv2 = nn.Conv2d(
128 + 128 + 96 + 64,
output_dim * 2,
kernel_size=3,
padding=1,
padding_mode="zeros",
)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
self.conv2 = nn.Conv2d(
output_dim * 3 + output_dim // 4,
output_dim * 2,
kernel_size=3,
padding=1,
padding_mode="zeros",
)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
elif isinstance(m, (nn.InstanceNorm2d)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
@@ -158,122 +191,47 @@ class BasicEncoder(nn.Module):
x = self.norm1(x)
x = self.relu1(x)
if self.shallow:
a = self.layer1(x)
b = self.layer2(a)
c = self.layer3(b)
a = F.interpolate(
a,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
b = F.interpolate(
b,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
c = F.interpolate(
c,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
x = self.conv2(torch.cat([a, b, c], dim=1))
else:
a = self.layer1(x)
b = self.layer2(a)
c = self.layer3(b)
d = self.layer4(c)
a = F.interpolate(
a,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
b = F.interpolate(
b,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
c = F.interpolate(
c,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
d = F.interpolate(
d,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
x = self.conv2(torch.cat([a, b, c, d], dim=1))
x = self.norm2(x)
x = self.relu2(x)
x = self.conv3(x)
a = self.layer1(x)
b = self.layer2(a)
c = self.layer3(b)
d = self.layer4(c)
if self.training and self.dropout is not None:
x = self.dropout(x)
def _bilinear_intepolate(x):
return F.interpolate(
x,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
a = _bilinear_intepolate(a)
b = _bilinear_intepolate(b)
c = _bilinear_intepolate(c)
d = _bilinear_intepolate(d)
x = self.conv2(torch.cat([a, b, c, d], dim=1))
x = self.norm2(x)
x = self.relu2(x)
x = self.conv3(x)
return x
class AttnBlock(nn.Module):
"""
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = Attention(
hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs
)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(
in_features=hidden_size,
hidden_features=mlp_hidden_dim,
act_layer=approx_gelu,
drop=0,
)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
def bilinear_sampler(img, coords, mode="bilinear", mask=False):
"""Wrapper for grid_sample, uses pixel coordinates"""
H, W = img.shape[-2:]
xgrid, ygrid = coords.split([1, 1], dim=-1)
# go to 0,1 then 0,2 then -1,1
xgrid = 2 * xgrid / (W - 1) - 1
ygrid = 2 * ygrid / (H - 1) - 1
grid = torch.cat([xgrid, ygrid], dim=-1)
img = F.grid_sample(img, grid, align_corners=True)
if mask:
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
return img, mask.float()
return img
class CorrBlock:
def __init__(self, fmaps, num_levels=4, radius=4):
def __init__(
self,
fmaps,
num_levels=4,
radius=4,
multiple_track_feats=False,
padding_mode="zeros",
):
B, S, C, H, W = fmaps.shape
self.S, self.C, self.H, self.W = S, C, H, W
self.padding_mode = padding_mode
self.num_levels = num_levels
self.radius = radius
self.fmaps_pyramid = []
self.multiple_track_feats = multiple_track_feats
self.fmaps_pyramid.append(fmaps)
for i in range(self.num_levels - 1):
@@ -292,109 +250,118 @@ class CorrBlock:
out_pyramid = []
for i in range(self.num_levels):
corrs = self.corrs_pyramid[i] # B, S, N, H, W
_, _, _, H, W = corrs.shape
*_, H, W = corrs.shape
dx = torch.linspace(-r, r, 2 * r + 1)
dy = torch.linspace(-r, r, 2 * r + 1)
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
coords.device
)
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device)
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
coords_lvl = centroid_lvl + delta_lvl
corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl)
corrs = bilinear_sampler(
corrs.reshape(B * S * N, 1, H, W),
coords_lvl,
padding_mode=self.padding_mode,
)
corrs = corrs.view(B, S, N, -1)
out_pyramid.append(corrs)
out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
return out.contiguous().float()
out = out.permute(0, 2, 1, 3).contiguous().view(B * N, S, -1).float()
return out
def corr(self, targets):
B, S, N, C = targets.shape
if self.multiple_track_feats:
targets_split = targets.split(C // self.num_levels, dim=-1)
B, S, N, C = targets_split[0].shape
assert C == self.C
assert S == self.S
fmap1 = targets
self.corrs_pyramid = []
for fmaps in self.fmaps_pyramid:
_, _, _, H, W = fmaps.shape
fmap2s = fmaps.view(B, S, C, H * W)
for i, fmaps in enumerate(self.fmaps_pyramid):
*_, H, W = fmaps.shape
fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W)
if self.multiple_track_feats:
fmap1 = targets_split[i]
corrs = torch.matmul(fmap1, fmap2s)
corrs = corrs.view(B, S, N, H, W)
corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W
corrs = corrs / torch.sqrt(torch.tensor(C).float())
self.corrs_pyramid.append(corrs)
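A shape-level sketch of how the rewritten `CorrBlock` is used: `corr()` builds a multi-scale correlation pyramid between track features and the window's feature maps, and `sample()` crops a `(2*radius+1)^2` neighbourhood around each track at every level. Random tensors on CPU; the sizes are assumptions:

```python
import torch
from cotracker.models.core.cotracker.blocks import CorrBlock

B, S, C, H, W, N = 1, 8, 128, 48, 64, 16             # assumed toy sizes
fmaps = torch.randn(B, S, C, H, W)                    # feature maps for one window
track_feat = torch.randn(B, S, N, C)                  # per-track features
coords = torch.rand(B, S, N, 2) * torch.tensor([W - 1.0, H - 1.0])  # (x, y) on the feature grid

corr_block = CorrBlock(fmaps, num_levels=4, radius=3)
corr_block.corr(track_feat)                           # correlation pyramid, B S N H W per level
corr_feats = corr_block.sample(coords)                # (B*N, S, 4 * 7 * 7) = (16, 8, 196)
```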
class UpdateFormer(nn.Module):
"""
Transformer model that updates track estimates.
"""
class Attention(nn.Module):
def __init__(self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False):
super().__init__()
inner_dim = dim_head * num_heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = num_heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
self.to_out = nn.Linear(inner_dim, query_dim)
def forward(self, x, context=None, attn_bias=None):
B, N1, C = x.shape
h = self.heads
q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
context = default(context, x)
k, v = self.to_kv(context).chunk(2, dim=-1)
N2 = context.shape[1]
k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
sim = (q @ k.transpose(-2, -1)) * self.scale
if attn_bias is not None:
sim = sim + attn_bias
attn = sim.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
return self.to_out(x)
class AttnBlock(nn.Module):
def __init__(
self,
space_depth=12,
time_depth=12,
input_dim=320,
hidden_size=384,
num_heads=8,
output_dim=130,
hidden_size,
num_heads,
attn_class: Callable[..., nn.Module] = Attention,
mlp_ratio=4.0,
add_space_attn=True,
**block_kwargs
):
super().__init__()
self.out_channels = 2
self.num_heads = num_heads
self.hidden_size = hidden_size
self.add_space_attn = add_space_attn
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
self.time_blocks = nn.ModuleList(
[
AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
for _ in range(time_depth)
]
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(
in_features=hidden_size,
hidden_features=mlp_hidden_dim,
act_layer=approx_gelu,
drop=0,
)
if add_space_attn:
self.space_blocks = nn.ModuleList(
[
AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
for _ in range(space_depth)
]
def forward(self, x, mask=None):
attn_bias = mask
if mask is not None:
mask = (
(mask[:, None] * mask[:, :, None])
.unsqueeze(1)
.expand(-1, self.attn.num_heads, -1, -1)
)
assert len(self.time_blocks) >= len(self.space_blocks)
self.initialize_weights()
def initialize_weights(self):
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
def forward(self, input_tensor):
x = self.input_transform(input_tensor)
j = 0
for i in range(len(self.time_blocks)):
B, N, T, _ = x.shape
x_time = rearrange(x, "b n t c -> (b n) t c", b=B, t=T, n=N)
x_time = self.time_blocks[i](x_time)
x = rearrange(x_time, "(b n) t c -> b n t c ", b=B, t=T, n=N)
if self.add_space_attn and (
i % (len(self.time_blocks) // len(self.space_blocks)) == 0
):
x_space = rearrange(x, "b n t c -> (b t) n c ", b=B, t=T, n=N)
x_space = self.space_blocks[j](x_space)
x = rearrange(x_space, "(b t) n c -> b n t c ", b=B, t=T, n=N)
j += 1
flow = self.flow_head(x)
return flow
max_neg_value = -torch.finfo(x.dtype).max
attn_bias = (~mask) * max_neg_value
x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
x = x + self.mlp(self.norm2(x))
return x
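For orientation, the new `AttnBlock` is a plain pre-norm attention + MLP block applied to whatever token axis it is given; the mask path is only exercised when an attention mask is supplied. A minimal sketch with random tokens and no mask (sizes are assumptions):

```python
import torch
from cotracker.models.core.cotracker.blocks import AttnBlock

block = AttnBlock(hidden_size=384, num_heads=8, mlp_ratio=4.0)
tokens = torch.randn(2, 100, 384)   # (batch, tokens, channels), assumed toy sizes
out = block(tokens)                 # same shape: residual attention followed by a residual MLP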

@@ -6,102 +6,74 @@
import torch
import torch.nn as nn
from einops import rearrange
import torch.nn.functional as F
from cotracker.models.core.cotracker.blocks import (
BasicEncoder,
CorrBlock,
UpdateFormer,
)
from cotracker.models.core.model_utils import meshgrid2d, bilinear_sample2d, smart_cat
from cotracker.models.core.model_utils import sample_features4d, sample_features5d
from cotracker.models.core.embeddings import (
get_2d_embedding,
get_1d_sincos_pos_embed_from_grid,
get_2d_sincos_pos_embed,
)
from cotracker.models.core.cotracker.blocks import (
Mlp,
BasicEncoder,
AttnBlock,
CorrBlock,
Attention,
)
torch.manual_seed(0)
def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device="cpu"):
if grid_size == 1:
return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2], device=device)[
None, None
]
grid_y, grid_x = meshgrid2d(
1, grid_size, grid_size, stack=False, norm=False, device=device
)
step = interp_shape[1] // 64
if grid_center[0] != 0 or grid_center[1] != 0:
grid_y = grid_y - grid_size / 2.0
grid_x = grid_x - grid_size / 2.0
grid_y = step + grid_y.reshape(1, -1) / float(grid_size - 1) * (
interp_shape[0] - step * 2
)
grid_x = step + grid_x.reshape(1, -1) / float(grid_size - 1) * (
interp_shape[1] - step * 2
)
grid_y = grid_y + grid_center[0]
grid_x = grid_x + grid_center[1]
xy = torch.stack([grid_x, grid_y], dim=-1).to(device)
return xy
def sample_pos_embed(grid_size, embed_dim, coords):
pos_embed = get_2d_sincos_pos_embed(embed_dim=embed_dim, grid_size=grid_size)
pos_embed = (
torch.from_numpy(pos_embed)
.reshape(grid_size[0], grid_size[1], embed_dim)
.float()
.unsqueeze(0)
.to(coords.device)
)
sampled_pos_embed = bilinear_sample2d(
pos_embed.permute(0, 3, 1, 2), coords[:, 0, :, 0], coords[:, 0, :, 1]
)
return sampled_pos_embed
class CoTracker(nn.Module):
class CoTracker2(nn.Module):
def __init__(
self,
S=8,
stride=8,
window_len=8,
stride=4,
add_space_attn=True,
num_heads=8,
hidden_size=384,
space_depth=12,
time_depth=12,
num_virtual_tracks=64,
model_resolution=(384, 512),
):
super(CoTracker, self).__init__()
self.S = S
super(CoTracker2, self).__init__()
self.window_len = window_len
self.stride = stride
self.hidden_dim = 256
self.latent_dim = latent_dim = 128
self.corr_levels = 4
self.corr_radius = 3
self.latent_dim = 128
self.add_space_attn = add_space_attn
self.fnet = BasicEncoder(
output_dim=self.latent_dim, norm_fn="instance", dropout=0, stride=stride
)
self.updateformer = UpdateFormer(
space_depth=space_depth,
time_depth=time_depth,
input_dim=456,
hidden_size=hidden_size,
num_heads=num_heads,
output_dim=latent_dim + 2,
self.fnet = BasicEncoder(output_dim=self.latent_dim)
self.num_virtual_tracks = num_virtual_tracks
self.model_resolution = model_resolution
self.input_dim = 456
self.updateformer = EfficientUpdateFormer(
space_depth=6,
time_depth=6,
input_dim=self.input_dim,
hidden_size=384,
output_dim=self.latent_dim + 2,
mlp_ratio=4.0,
add_space_attn=add_space_attn,
num_virtual_tracks=num_virtual_tracks,
)
time_grid = torch.linspace(0, window_len - 1, window_len).reshape(1, window_len, 1)
self.register_buffer(
"time_emb", get_1d_sincos_pos_embed_from_grid(self.input_dim, time_grid[0])
)
self.register_buffer(
"pos_emb",
get_2d_sincos_pos_embed(
embed_dim=self.input_dim,
grid_size=(
model_resolution[0] // stride,
model_resolution[1] // stride,
),
),
)
self.norm = nn.GroupNorm(1, self.latent_dim)
self.ffeat_updater = nn.Sequential(
self.track_feat_updater = nn.Sequential(
nn.Linear(self.latent_dim, self.latent_dim),
nn.GELU(),
)
@@ -109,243 +81,423 @@ class CoTracker(nn.Module):
nn.Linear(self.latent_dim, 1),
)
def forward_iteration(
def forward_window(
self,
fmaps,
coords_init,
feat_init=None,
vis_init=None,
coords,
track_feat=None,
vis=None,
track_mask=None,
attention_mask=None,
iters=4,
):
B, S_init, N, D = coords_init.shape
assert D == 2
assert B == 1
# B = batch size
# S = number of frames in the window
# N = number of tracks
# C = channels of a point feature vector
# E = positional embedding size
# LRR = local receptive field radius
# D = dimension of the transformer input tokens
B, S, __, H8, W8 = fmaps.shape
# track_feat = B S N C
# vis = B S N 1
# track_mask = B S N 1
# attention_mask = B S N
device = fmaps.device
B, S_init, N, __ = track_mask.shape
B, S, *_ = fmaps.shape
if S_init < S:
coords = torch.cat(
[coords_init, coords_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1
)
vis_init = torch.cat(
[vis_init, vis_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1
)
else:
coords = coords_init.clone()
fcorr_fn = CorrBlock(
fmaps, num_levels=self.corr_levels, radius=self.corr_radius
track_mask = F.pad(track_mask, (0, 0, 0, 0, 0, S - S_init), "constant")
track_mask_vis = (
torch.cat([track_mask, vis], dim=-1).permute(0, 2, 1, 3).reshape(B * N, S, 2)
)
ffeats = feat_init.clone()
times_ = torch.linspace(0, S - 1, S).reshape(1, S, 1)
pos_embed = sample_pos_embed(
grid_size=(H8, W8),
embed_dim=456,
coords=coords,
corr_block = CorrBlock(
fmaps,
num_levels=4,
radius=3,
padding_mode="border",
)
pos_embed = rearrange(pos_embed, "b e n -> (b n) e").unsqueeze(1)
times_embed = (
torch.from_numpy(get_1d_sincos_pos_embed_from_grid(456, times_[0]))[None]
.repeat(B, 1, 1)
.float()
.to(device)
)
coord_predictions = []
sampled_pos_emb = (
sample_features4d(self.pos_emb.repeat(B, 1, 1, 1), coords[:, 0])
.reshape(B * N, self.input_dim)
.unsqueeze(1)
) # B E N -> (B N) 1 E
coord_preds = []
for __ in range(iters):
coords = coords.detach()
fcorr_fn.corr(ffeats)
coords = coords.detach() # B S N 2
corr_block.corr(track_feat)
fcorrs = fcorr_fn.sample(coords) # B, S, N, LRR
LRR = fcorrs.shape[3]
# Sample correlation features around each point
fcorrs = corr_block.sample(coords) # (B N) S LRR
fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, LRR)
flows_ = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
# Get the flow embeddings
flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
flow_emb = get_2d_embedding(flows, 64, cat_coords=True) # N S E
flows_cat = get_2d_embedding(flows_, 64, cat_coords=True)
ffeats_ = ffeats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
track_feat_ = track_feat.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
if track_mask.shape[1] < vis_init.shape[1]:
track_mask = torch.cat(
[
track_mask,
torch.zeros_like(track_mask[:, 0]).repeat(
1, vis_init.shape[1] - track_mask.shape[1], 1, 1
),
],
dim=1,
)
concat = (
torch.cat([track_mask, vis_init], dim=2)
.permute(0, 2, 1, 3)
.reshape(B * N, S, 2)
transformer_input = torch.cat([flow_emb, fcorrs, track_feat_, track_mask_vis], dim=2)
x = transformer_input + sampled_pos_emb + self.time_emb
x = x.view(B, N, S, -1) # (B N) S D -> B N S D
delta = self.updateformer(
x,
attention_mask.reshape(B * S, N), # B S N -> (B S) N
)
transformer_input = torch.cat([flows_cat, fcorrs_, ffeats_, concat], dim=2)
x = transformer_input + pos_embed + times_embed
delta_coords = delta[..., :2].permute(0, 2, 1, 3)
coords = coords + delta_coords
coord_preds.append(coords * self.stride)
x = rearrange(x, "(b n) t d -> b n t d", b=B)
delta = self.updateformer(x)
delta = rearrange(delta, " b n t d -> (b n) t d")
delta_coords_ = delta[:, :, :2]
delta_feats_ = delta[:, :, 2:]
delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
ffeats_ = ffeats.permute(0, 2, 1, 3).reshape(B * N * S, self.latent_dim)
ffeats_ = self.ffeat_updater(self.norm(delta_feats_)) + ffeats_
ffeats = ffeats_.reshape(B, N, S, self.latent_dim).permute(
delta_feats_ = delta[..., 2:].reshape(B * N * S, self.latent_dim)
track_feat_ = track_feat.permute(0, 2, 1, 3).reshape(B * N * S, self.latent_dim)
track_feat_ = self.track_feat_updater(self.norm(delta_feats_)) + track_feat_
track_feat = track_feat_.reshape(B, N, S, self.latent_dim).permute(
0, 2, 1, 3
) # B,S,N,C
) # (B N S) C -> B S N C
coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
coord_predictions.append(coords * self.stride)
vis_pred = self.vis_predictor(track_feat).reshape(B, S, N)
return coord_preds, vis_pred
vis_e = self.vis_predictor(ffeats.reshape(B * S * N, self.latent_dim)).reshape(
B, S, N
def get_track_feat(self, fmaps, queried_frames, queried_coords):
sample_frames = queried_frames[:, None, :, None]
sample_coords = torch.cat(
[
sample_frames,
queried_coords[:, None],
],
dim=-1,
)
return coord_predictions, vis_e, feat_init
sample_track_feats = sample_features5d(fmaps, sample_coords)
return sample_track_feats
def forward(self, rgbs, queries, iters=4, feat_init=None, is_train=False):
B, T, C, H, W = rgbs.shape
def init_video_online_processing(self):
self.online_ind = 0
self.online_track_feat = None
self.online_coords_predicted = None
self.online_vis_predicted = None
def forward(self, video, queries, iters=4, is_train=False, is_online=False):
"""Predict tracks
Args:
video (FloatTensor[B, T, 3, H, W]): input videos.
queries (FloatTensor[B, N, 3]): point queries.
iters (int, optional): number of updates. Defaults to 4.
is_train (bool, optional): enables training mode. Defaults to False.
is_online (bool, optional): enables online mode. Defaults to False. Before enabling, call model.init_video_online_processing().
Returns:
- coords_predicted (FloatTensor[B, T, N, 2]): predicted track coordinates, in pixels of the input video.
- vis_predicted (FloatTensor[B, T, N]): predicted per-frame visibility, in [0, 1] after the sigmoid.
- train_data: `None` if `is_train` is false, otherwise:
- all_vis_predictions (List[FloatTensor[B, S, N, 1]]):
- all_coords_predictions (List[FloatTensor[B, S, N, 2]]):
- mask (BoolTensor[B, T, N]):
"""
B, T, C, H, W = video.shape
B, N, __ = queries.shape
S = self.window_len
device = queries.device
device = rgbs.device
assert B == 1
# INIT for the first sequence
# We want to sort points by the first frame they are visible to add them to the tensor of tracked points consecutively
first_positive_inds = queries[:, :, 0].long()
# B = batch size
# S = number of frames in the window of the padded video
# S_trimmed = actual number of frames in the window
# N = number of tracks
# C = color channels (3 for RGB)
# E = positional embedding size
# LRR = local receptive field radius
# D = dimension of the transformer input tokens
__, sort_inds = torch.sort(first_positive_inds[0], dim=0, descending=False)
inv_sort_inds = torch.argsort(sort_inds, dim=0)
first_positive_sorted_inds = first_positive_inds[0][sort_inds]
# video = B T C H W
# queries = B N 3
# coords_init = B S N 2
# vis_init = B S N 1
assert torch.allclose(
first_positive_inds[0], first_positive_inds[0][sort_inds][inv_sort_inds]
assert S >= 2 # A tracker needs at least two frames to track something
if is_online:
assert T <= S, "Online mode: video chunk must be <= window size."
assert self.online_ind is not None, "Call model.init_video_online_processing() first."
assert not is_train, "Training not supported in online mode."
step = S // 2 # How much the sliding window moves at every step
video = 2 * (video / 255.0) - 1.0
# The first channel is the frame number
# The rest are the coordinates of points we want to track
queried_frames = queries[:, :, 0].long()
queried_coords = queries[..., 1:]
queried_coords = queried_coords / self.stride
# We store our predictions here
coords_predicted = torch.zeros((B, T, N, 2), device=device)
vis_predicted = torch.zeros((B, T, N), device=device)
if is_online:
if self.online_coords_predicted is None:
# Init online predictions with zeros
self.online_coords_predicted = coords_predicted
self.online_vis_predicted = vis_predicted
else:
# Pad online predictions with zeros for the current window
pad = min(step, T - step)
coords_predicted = F.pad(
self.online_coords_predicted, (0, 0, 0, 0, 0, pad), "constant"
)
vis_predicted = F.pad(self.online_vis_predicted, (0, 0, 0, pad), "constant")
all_coords_predictions, all_vis_predictions = [], []
# Pad the video so that an integer number of sliding windows fit into it
# TODO: we may drop this requirement because the transformer should not care
# TODO: pad the features instead of the video
pad = S - T if is_online else (S - T % S) % S # We don't want to pad if T % S == 0
video = F.pad(video.reshape(B, 1, T, C * H * W), (0, 0, 0, pad), "replicate").reshape(
B, -1, C, H, W
)
coords_init = queries[:, :, 1:].reshape(B, 1, N, 2).repeat(
1, self.S, 1, 1
) / float(self.stride)
# Compute convolutional features for the video or for the current chunk in case of online mode
fmaps = self.fnet(video.reshape(-1, C, H, W)).reshape(
B, -1, self.latent_dim, H // self.stride, W // self.stride
)
rgbs = 2 * (rgbs / 255.0) - 1.0
# We compute track features
track_feat = self.get_track_feat(
fmaps,
queried_frames - self.online_ind if is_online else queried_frames,
queried_coords,
).repeat(1, S, 1, 1)
if is_online:
# We update track features for the current window
sample_frames = queried_frames[:, None, :, None] # B 1 N 1
left = 0 if self.online_ind == 0 else self.online_ind + step
right = self.online_ind + S
sample_mask = (sample_frames >= left) & (sample_frames < right)
if self.online_track_feat is None:
self.online_track_feat = torch.zeros_like(track_feat, device=device)
self.online_track_feat += track_feat * sample_mask
track_feat = self.online_track_feat.clone()
# We process ((num_windows - 1) * step + S) frames in total, so there are
# (ceil((T - S) / step) + 1) windows
num_windows = (T - S + step - 1) // step + 1
# We process only the current video chunk in the online mode
indices = [self.online_ind] if is_online else range(0, step * num_windows, step)
traj_e = torch.zeros((B, T, N, 2), device=device)
vis_e = torch.zeros((B, T, N), device=device)
ind_array = torch.arange(T, device=device)
ind_array = ind_array[None, :, None].repeat(B, 1, N)
track_mask = (ind_array >= first_positive_inds[:, None, :]).unsqueeze(-1)
# these are logits, so we initialize visibility with something that would give a value close to 1 after softmax
vis_init = torch.ones((B, self.S, N, 1), device=device).float() * 10
ind = 0
track_mask_ = track_mask[:, :, sort_inds].clone()
coords_init_ = coords_init[:, :, sort_inds].clone()
vis_init_ = vis_init[:, :, sort_inds].clone()
prev_wind_idx = 0
fmaps_ = None
vis_predictions = []
coord_predictions = []
wind_inds = []
while ind < T - self.S // 2:
rgbs_seq = rgbs[:, ind : ind + self.S]
S = S_local = rgbs_seq.shape[1]
if S < self.S:
rgbs_seq = torch.cat(
[rgbs_seq, rgbs_seq[:, -1, None].repeat(1, self.S - S, 1, 1, 1)],
dim=1,
coords_init = queried_coords.reshape(B, 1, N, 2).expand(B, S, N, 2).float()
vis_init = torch.ones((B, S, N, 1), device=device).float() * 10
for ind in indices:
# We copy over coords and vis for tracks that are queried
# by the end of the previous window, which is ind + overlap
if ind > 0:
overlap = S - step
copy_over = (queried_frames < ind + overlap)[:, None, :, None] # B 1 N 1
coords_prev = torch.nn.functional.pad(
coords_predicted[:, ind : ind + overlap] / self.stride,
(0, 0, 0, 0, 0, step),
"replicate",
) # B S N 2
vis_prev = torch.nn.functional.pad(
vis_predicted[:, ind : ind + overlap, :, None].clone(),
(0, 0, 0, 0, 0, step),
"replicate",
) # B S N 1
coords_init = torch.where(
copy_over.expand_as(coords_init), coords_prev, coords_init
)
S = rgbs_seq.shape[1]
rgbs_ = rgbs_seq.reshape(B * S, C, H, W)
vis_init = torch.where(copy_over.expand_as(vis_init), vis_prev, vis_init)
if fmaps_ is None:
fmaps_ = self.fnet(rgbs_)
else:
fmaps_ = torch.cat(
[fmaps_[self.S // 2 :], self.fnet(rgbs_[self.S // 2 :])], dim=0
)
fmaps = fmaps_.reshape(
B, S, self.latent_dim, H // self.stride, W // self.stride
)
# The attention mask is 1 for the spatio-temporal points within
# a track which is updated in the current window
attention_mask = (queried_frames < ind + S).reshape(B, 1, N).repeat(1, S, 1) # B S N
curr_wind_points = torch.nonzero(first_positive_sorted_inds < ind + self.S)
if curr_wind_points.shape[0] == 0:
ind = ind + self.S // 2
continue
wind_idx = curr_wind_points[-1] + 1
# The track mask is 1 for the spatio-temporal points that actually
# need updating: only after being queried, and not if contained
# in a previous window
track_mask = (
queried_frames[:, None, :, None]
<= torch.arange(ind, ind + S, device=device)[None, :, None, None]
).contiguous() # B S N 1
if wind_idx - prev_wind_idx > 0:
fmaps_sample = fmaps[
:, first_positive_sorted_inds[prev_wind_idx:wind_idx] - ind
]
if ind > 0:
track_mask[:, :overlap, :, :] = False
feat_init_ = bilinear_sample2d(
fmaps_sample,
coords_init_[:, 0, prev_wind_idx:wind_idx, 0],
coords_init_[:, 0, prev_wind_idx:wind_idx, 1],
).permute(0, 2, 1)
feat_init_ = feat_init_.unsqueeze(1).repeat(1, self.S, 1, 1)
feat_init = smart_cat(feat_init, feat_init_, dim=2)
if prev_wind_idx > 0:
new_coords = coords[-1][:, self.S // 2 :] / float(self.stride)
coords_init_[:, : self.S // 2, :prev_wind_idx] = new_coords
coords_init_[:, self.S // 2 :, :prev_wind_idx] = new_coords[
:, -1
].repeat(1, self.S // 2, 1, 1)
new_vis = vis[:, self.S // 2 :].unsqueeze(-1)
vis_init_[:, : self.S // 2, :prev_wind_idx] = new_vis
vis_init_[:, self.S // 2 :, :prev_wind_idx] = new_vis[:, -1].repeat(
1, self.S // 2, 1, 1
)
coords, vis, __ = self.forward_iteration(
fmaps=fmaps,
coords_init=coords_init_[:, :, :wind_idx],
feat_init=feat_init[:, :, :wind_idx],
vis_init=vis_init_[:, :, :wind_idx],
track_mask=track_mask_[:, ind : ind + self.S, :wind_idx],
# Predict the coordinates and visibility for the current window
coords, vis = self.forward_window(
fmaps=fmaps if is_online else fmaps[:, ind : ind + S],
coords=coords_init,
track_feat=attention_mask.unsqueeze(-1) * track_feat,
vis=vis_init,
track_mask=track_mask,
attention_mask=attention_mask,
iters=iters,
)
S_trimmed = T if is_online else min(T - ind, S) # accounts for last window duration
coords_predicted[:, ind : ind + S] = coords[-1][:, :S_trimmed]
vis_predicted[:, ind : ind + S] = vis[:, :S_trimmed]
if is_train:
vis_predictions.append(torch.sigmoid(vis[:, :S_local]))
coord_predictions.append([coord[:, :S_local] for coord in coords])
wind_inds.append(wind_idx)
all_coords_predictions.append([coord[:, :S_trimmed] for coord in coords])
all_vis_predictions.append(torch.sigmoid(vis[:, :S_trimmed]))
traj_e[:, ind : ind + self.S, :wind_idx] = coords[-1][:, :S_local]
vis_e[:, ind : ind + self.S, :wind_idx] = vis[:, :S_local]
if is_online:
self.online_ind += step
self.online_coords_predicted = coords_predicted
self.online_vis_predicted = vis_predicted
vis_predicted = torch.sigmoid(vis_predicted)
track_mask_[:, : ind + self.S, :wind_idx] = 0.0
ind = ind + self.S // 2
if is_train:
mask = queried_frames[:, None] <= torch.arange(0, T, device=device)[None, :, None]
train_data = (all_coords_predictions, all_vis_predictions, mask)
else:
train_data = None
prev_wind_idx = wind_idx
return coords_predicted, vis_predicted, train_data
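A usage sketch for the `forward()` API documented above, with randomly initialized weights and a tiny synthetic clip; real use would load a checkpoint, and the two query points here are arbitrary:

```python
import torch
from cotracker.models.core.cotracker.cotracker import CoTracker2

model = CoTracker2().eval()                                   # window_len=8, stride=4, 384x512 resolution
video = torch.randint(0, 255, (1, 8, 3, 384, 512)).float()    # B T C H W, values in [0, 255]
queries = torch.tensor([[[0.0, 100.0, 150.0],                 # (t, x, y): track this point from frame 0
                         [4.0, 200.0, 300.0]]])               # and this one from frame 4
with torch.no_grad():
    tracks, visibility, _ = model(video, queries, iters=4)
# tracks: (1, 8, 2, 2) pixel coordinates per frame; visibility: (1, 8, 2) values in [0, 1]
```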
traj_e = traj_e[:, :, inv_sort_inds]
vis_e = vis_e[:, :, inv_sort_inds]
vis_e = torch.sigmoid(vis_e)
class EfficientUpdateFormer(nn.Module):
"""
Transformer model that updates track estimates.
"""
train_data = (
(vis_predictions, coord_predictions, wind_inds, sort_inds)
if is_train
else None
def __init__(
self,
space_depth=6,
time_depth=6,
input_dim=320,
hidden_size=384,
num_heads=8,
output_dim=130,
mlp_ratio=4.0,
add_space_attn=True,
num_virtual_tracks=64,
):
super().__init__()
self.out_channels = 2
self.num_heads = num_heads
self.hidden_size = hidden_size
self.add_space_attn = add_space_attn
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
self.num_virtual_tracks = num_virtual_tracks
self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
self.time_blocks = nn.ModuleList(
[
AttnBlock(
hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
attn_class=Attention,
)
for _ in range(time_depth)
]
)
return traj_e, feat_init, vis_e, train_data
if add_space_attn:
self.space_virtual_blocks = nn.ModuleList(
[
AttnBlock(
hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
attn_class=Attention,
)
for _ in range(space_depth)
]
)
self.space_point2virtual_blocks = nn.ModuleList(
[
CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio)
for _ in range(space_depth)
]
)
self.space_virtual2point_blocks = nn.ModuleList(
[
CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio)
for _ in range(space_depth)
]
)
assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
self.initialize_weights()
def initialize_weights(self):
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
def forward(self, input_tensor, mask=None):
tokens = self.input_transform(input_tensor)
B, _, T, _ = tokens.shape
virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
tokens = torch.cat([tokens, virtual_tokens], dim=1)
_, N, _, _ = tokens.shape
j = 0
for i in range(len(self.time_blocks)):
time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
time_tokens = self.time_blocks[i](time_tokens)
tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
if self.add_space_attn and (
i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0
):
space_tokens = (
tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
) # B N T C -> (B T) N C
point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
virtual_tokens = self.space_virtual2point_blocks[j](
virtual_tokens, point_tokens, mask=mask
)
virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
point_tokens = self.space_point2virtual_blocks[j](
point_tokens, virtual_tokens, mask=mask
)
space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C
j += 1
tokens = tokens[:, : N - self.num_virtual_tracks]
flow = self.flow_head(tokens)
return flow
class CrossAttnBlock(nn.Module):
def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.norm_context = nn.LayerNorm(hidden_size)
self.cross_attn = Attention(
hidden_size, context_dim=context_dim, num_heads=num_heads, qkv_bias=True, **block_kwargs
)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(
in_features=hidden_size,
hidden_features=mlp_hidden_dim,
act_layer=approx_gelu,
drop=0,
)
def forward(self, x, context, mask=None):
if mask is not None:
if mask.shape[1] == x.shape[1]:
mask = mask[:, None, :, None].expand(
-1, self.cross_attn.heads, -1, context.shape[1]
)
else:
mask = mask[:, None, None].expand(-1, self.cross_attn.heads, x.shape[1], -1)
max_neg_value = -torch.finfo(x.dtype).max
attn_bias = (~mask) * max_neg_value
x = x + self.cross_attn(
self.norm1(x), context=self.norm_context(context), attn_bias=attn_bias
)
x = x + self.mlp(self.norm2(x))
return x
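To make the virtual-track factorization above concrete, here is a shape-only sketch of `EfficientUpdateFormer` on random tokens: time attention runs per track, while spatial mixing goes through the learned virtual tracks via the two cross-attention block lists. Sizes match the model's defaults; the all-ones mask is an assumption:

```python
import torch
from cotracker.models.core.cotracker.cotracker import EfficientUpdateFormer

B, N, T, D_in = 1, 32, 8, 456                    # assumed toy sizes; D_in matches CoTracker2's input_dim
updateformer = EfficientUpdateFormer(input_dim=D_in, hidden_size=384, output_dim=130)
tokens = torch.randn(B, N, T, D_in)              # one token per (track, frame)
mask = torch.ones(B * T, N, dtype=torch.bool)    # every real track participates in spatial attention
delta = updateformer(tokens, mask)               # (B, N, T, 130): 2 coordinate deltas + 128 feature channels
```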

@@ -4,67 +4,98 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple, Union
import torch
import numpy as np
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
def get_2d_sincos_pos_embed(
embed_dim: int, grid_size: Union[int, Tuple[int, int]]
) -> torch.Tensor:
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
It is a wrapper of get_2d_sincos_pos_embed_from_grid.
Args:
- embed_dim: The embedding dimension.
- grid_size: The grid size.
Returns:
- pos_embed: The generated 2D positional embedding.
"""
if isinstance(grid_size, tuple):
grid_size_h, grid_size_w = grid_size
else:
grid_size_h = grid_size_w = grid_size
grid_h = np.arange(grid_size_h, dtype=np.float32)
grid_w = np.arange(grid_size_w, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid_h = torch.arange(grid_size_h, dtype=torch.float)
grid_w = torch.arange(grid_size_w, dtype=torch.float)
grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
grid = torch.stack(grid, dim=0)
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate(
[np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
)
return pos_embed
return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
def get_2d_sincos_pos_embed_from_grid(
embed_dim: int, grid: torch.Tensor
) -> torch.Tensor:
"""
This function generates a 2D positional embedding from a given grid using sine and cosine functions.
Args:
- embed_dim: The embedding dimension.
- grid: The grid to generate the embedding from.
Returns:
- emb: The generated 2D positional embedding.
"""
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
def get_1d_sincos_pos_embed_from_grid(
embed_dim: int, pos: torch.Tensor
) -> torch.Tensor:
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
This function generates a 1D positional embedding from a given grid using sine and cosine functions.
Args:
- embed_dim: The embedding dimension.
- pos: The position to generate the embedding from.
Returns:
- emb: The generated 1D positional embedding.
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega = torch.arange(embed_dim // 2, dtype=torch.double)
omega /= embed_dim / 2.0
omega = 1.0 / 10000 ** omega # (D/2,)
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb_sin = torch.sin(out) # (M, D/2)
emb_cos = torch.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
return emb[None].float()
def get_2d_embedding(xy, C, cat_coords=True):
def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
"""
This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
Args:
- xy: The coordinates to generate the embedding from.
- C: The size of the embedding.
- cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
Returns:
- pe: The generated 2D positional embedding.
"""
B, N, D = xy.shape
assert D == 2
@@ -83,72 +114,7 @@ def get_2d_embedding(xy, C, cat_coords=True):
pe_y[:, :, 0::2] = torch.sin(y * div_term)
pe_y[:, :, 1::2] = torch.cos(y * div_term)
pe = torch.cat([pe_x, pe_y], dim=2) # B, N, C*3
pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*2)
if cat_coords:
pe = torch.cat([xy, pe], dim=2) # B, N, C*3+3
return pe
def get_3d_embedding(xyz, C, cat_coords=True):
B, N, D = xyz.shape
assert D == 3
x = xyz[:, :, 0:1]
y = xyz[:, :, 1:2]
z = xyz[:, :, 2:3]
div_term = (
torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (1000.0 / C)
).reshape(1, 1, int(C / 2))
pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
pe_x[:, :, 0::2] = torch.sin(x * div_term)
pe_x[:, :, 1::2] = torch.cos(x * div_term)
pe_y[:, :, 0::2] = torch.sin(y * div_term)
pe_y[:, :, 1::2] = torch.cos(y * div_term)
pe_z[:, :, 0::2] = torch.sin(z * div_term)
pe_z[:, :, 1::2] = torch.cos(z * div_term)
pe = torch.cat([pe_x, pe_y, pe_z], dim=2) # B, N, C*3
if cat_coords:
pe = torch.cat([pe, xyz], dim=2) # B, N, C*3+3
return pe
def get_4d_embedding(xyzw, C, cat_coords=True):
B, N, D = xyzw.shape
assert D == 4
x = xyzw[:, :, 0:1]
y = xyzw[:, :, 1:2]
z = xyzw[:, :, 2:3]
w = xyzw[:, :, 3:4]
div_term = (
torch.arange(0, C, 2, device=xyzw.device, dtype=torch.float32) * (1000.0 / C)
).reshape(1, 1, int(C / 2))
pe_x = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
pe_y = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
pe_z = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
pe_w = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
pe_x[:, :, 0::2] = torch.sin(x * div_term)
pe_x[:, :, 1::2] = torch.cos(x * div_term)
pe_y[:, :, 0::2] = torch.sin(y * div_term)
pe_y[:, :, 1::2] = torch.cos(y * div_term)
pe_z[:, :, 0::2] = torch.sin(z * div_term)
pe_z[:, :, 1::2] = torch.cos(z * div_term)
pe_w[:, :, 0::2] = torch.sin(w * div_term)
pe_w[:, :, 1::2] = torch.cos(w * div_term)
pe = torch.cat([pe_x, pe_y, pe_z, pe_w], dim=2) # B, N, C*3
if cat_coords:
pe = torch.cat([pe, xyzw], dim=2) # B, N, C*3+3
pe = torch.cat([xy, pe], dim=2) # (B, N, C*2 + 2)
return pe
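A quick shape check for the torch-based embedding helpers above. The sizes mirror how CoTracker2 uses them at stride 4 on a 384x512 model resolution, but are otherwise arbitrary:

```python
import torch
from cotracker.models.core.embeddings import get_2d_sincos_pos_embed, get_2d_embedding

pos = get_2d_sincos_pos_embed(embed_dim=456, grid_size=(96, 128))  # (1, 456, 96, 128) buffer
xy = torch.rand(2, 10, 2) * 100                                    # (B, N, 2) coordinates
emb = get_2d_embedding(xy, C=64, cat_coords=True)                  # (2, 10, 2*64 + 2)
```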

@@ -5,6 +5,8 @@
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from typing import Optional, Tuple
EPS = 1e-6
@@ -15,155 +17,240 @@ def smart_cat(tensor1, tensor2, dim):
return torch.cat([tensor1, tensor2], dim=dim)
def normalize_single(d):
# d is a whatever shape torch tensor
dmin = torch.min(d)
dmax = torch.max(d)
d = (d - dmin) / (EPS + (dmax - dmin))
return d
def get_points_on_a_grid(
size: int,
extent: Tuple[float, ...],
center: Optional[Tuple[float, ...]] = None,
device: Optional[torch.device] = torch.device("cpu"),
):
r"""Get a grid of points covering a rectangular region
`get_points_on_a_grid(size, extent)` generates a :attr:`size` by
:attr:`size` grid of points distributed to cover a rectangular area
specified by `extent`.
The `extent` is a pair of integers :math:`(H,W)` specifying the height
and width of the rectangle.
Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)`
specifying the vertical and horizontal center coordinates. The center
defaults to the middle of the extent.
Points are distributed uniformly within the rectangle leaving a margin
:math:`m=W/64` from the border.
It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of
points :math:`P_{ij}=(x_i, y_i)` where
.. math::
P_{ij} = \left(
c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~
c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i
\right)
Points are returned in row-major order.
Args:
size (int): grid size.
extent (tuple): height and width of the grid extent.
center (tuple, optional): grid center.
device (str, optional): Defaults to `"cpu"`.
Returns:
Tensor: grid.
"""
if size == 1:
return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]
if center is None:
center = [extent[0] / 2, extent[1] / 2]
margin = extent[1] / 64
range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
grid_y, grid_x = torch.meshgrid(
torch.linspace(*range_y, size, device=device),
torch.linspace(*range_x, size, device=device),
indexing="ij",
)
return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
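A small sketch of the new `get_points_on_a_grid`, including the frame-index column that the predictors prepend before passing the points as queries:

```python
import torch
from cotracker.models.core.model_utils import get_points_on_a_grid

xy = get_points_on_a_grid(10, (384, 512))                         # (1, 100, 2) points (x, y), 8 px margin
queries = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2)  # prepend query frame index 0
```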
def normalize(d):
# d is B x whatever. normalize within each element of the batch
out = torch.zeros(d.size())
if d.is_cuda:
out = out.cuda()
B = list(d.size())[0]
for b in list(range(B)):
out[b] = normalize_single(d[b])
return out
def reduce_masked_mean(input, mask, dim=None, keepdim=False):
r"""Masked mean
`reduce_masked_mean(x, mask)` computes the mean of a tensor :attr:`input`
over a mask :attr:`mask`, returning
def meshgrid2d(B, Y, X, stack=False, norm=False, device="cpu"):
# returns a meshgrid sized B x Y x X
.. math::
\text{output} =
\frac
{\sum_{i=1}^N \text{input}_i \cdot \text{mask}_i}
{\epsilon + \sum_{i=1}^N \text{mask}_i}
grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device))
grid_y = torch.reshape(grid_y, [1, Y, 1])
grid_y = grid_y.repeat(B, 1, X)
where :math:`N` is the number of elements in :attr:`input` and
:attr:`mask`, and :math:`\epsilon` is a small constant to avoid
division by zero.
grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device))
grid_x = torch.reshape(grid_x, [1, 1, X])
grid_x = grid_x.repeat(B, Y, 1)
`reduce_masked_mean(x, mask, dim)` computes the mean of a tensor
:attr:`input` over a mask :attr:`mask` along a dimension :attr:`dim`.
Optionally, the dimension can be kept in the output by setting
:attr:`keepdim` to `True`. Tensor :attr:`mask` must be broadcastable to
the same dimension as :attr:`input`.
if stack:
# note we stack in xy order
# (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
grid = torch.stack([grid_x, grid_y], dim=-1)
return grid
else:
return grid_y, grid_x
The interface is similar to `torch.mean()`.
Args:
input (Tensor): input tensor.
mask (Tensor): mask.
dim (int, optional): Dimension to sum over. Defaults to None.
keepdim (bool, optional): Keep the summed dimension. Defaults to False.
Returns:
Tensor: mean tensor.
"""
mask = mask.expand_as(input)
prod = input * mask
def reduce_masked_mean(x, mask, dim=None, keepdim=False):
# x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting
# returns shape-1
# axis can be a list of axes
for (a, b) in zip(x.size(), mask.size()):
assert a == b # some shape mismatch!
prod = x * mask
if dim is None:
numer = torch.sum(prod)
denom = EPS + torch.sum(mask)
denom = torch.sum(mask)
else:
numer = torch.sum(prod, dim=dim, keepdim=keepdim)
denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim)
denom = torch.sum(mask, dim=dim, keepdim=keepdim)
mean = numer / denom
mean = numer / (EPS + denom)
return mean
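A tiny worked example of `reduce_masked_mean`, matching the formula in the docstring:

```python
import torch
from cotracker.models.core.model_utils import reduce_masked_mean

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
reduce_masked_mean(x, mask)           # -> ~1.5: only the two unmasked entries contribute
reduce_masked_mean(x, mask, dim=1)    # -> tensor([~1.5]), reduced along the last dimension
```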
def bilinear_sample2d(im, x, y, return_inbounds=False):
# x and y are each B, N
# output is B, C, N
if len(im.shape) == 5:
B, N, C, H, W = list(im.shape)
else:
B, C, H, W = list(im.shape)
N = list(x.shape)[1]
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
r"""Sample a tensor using bilinear interpolation
x = x.float()
y = y.float()
H_f = torch.tensor(H, dtype=torch.float32)
W_f = torch.tensor(W, dtype=torch.float32)
`bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
coordinates :attr:`coords` using bilinear interpolation. It is the same
as `torch.nn.functional.grid_sample()` but with a different coordinate
convention.
# inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()
The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
:math:`B` is the batch size, :math:`C` is the number of channels,
:math:`H` is the height of the image, and :math:`W` is the width of the
image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
max_y = (H_f - 1).int()
max_x = (W_f - 1).int()
Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
that in this case the order of the components is slightly different
from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
x0 = torch.floor(x).int()
x1 = x0 + 1
y0 = torch.floor(y).int()
y1 = y0 + 1
If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
left-most image pixel and :math:`W-1` to the center of the right-most
pixel.
x0_clip = torch.clamp(x0, 0, max_x)
x1_clip = torch.clamp(x1, 0, max_x)
y0_clip = torch.clamp(y0, 0, max_y)
y1_clip = torch.clamp(y1, 0, max_y)
dim2 = W
dim1 = W * H
If `align_corners` is `False`, the coordinate :math:`x` is assumed to
be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
the left-most pixel and :math:`W` to the right edge of the right-most
pixel.
base = torch.arange(0, B, dtype=torch.int64, device=x.device) * dim1
base = torch.reshape(base, [B, 1]).repeat([1, N])
Similar conventions apply to the :math:`y` for the range
:math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
:math:`[0,T-1]` and :math:`[0,T]`.
base_y0 = base + y0_clip * dim2
base_y1 = base + y1_clip * dim2
Args:
input (Tensor): batch of input images.
coords (Tensor): batch of coordinates.
align_corners (bool, optional): Coordinate convention. Defaults to `True`.
padding_mode (str, optional): Padding mode. Defaults to `"border"`.
idx_y0_x0 = base_y0 + x0_clip
idx_y0_x1 = base_y0 + x1_clip
idx_y1_x0 = base_y1 + x0_clip
idx_y1_x1 = base_y1 + x1_clip
Returns:
Tensor: sampled points.
"""
# use the indices to lookup pixels in the flat image
# im is B x C x H x W
# move C out to last dim
if len(im.shape) == 5:
im_flat = (im.permute(0, 3, 4, 1, 2)).reshape(B * H * W, N, C)
i_y0_x0 = torch.diagonal(im_flat[idx_y0_x0.long()], dim1=1, dim2=2).permute(
0, 2, 1
)
i_y0_x1 = torch.diagonal(im_flat[idx_y0_x1.long()], dim1=1, dim2=2).permute(
0, 2, 1
)
i_y1_x0 = torch.diagonal(im_flat[idx_y1_x0.long()], dim1=1, dim2=2).permute(
0, 2, 1
)
i_y1_x1 = torch.diagonal(im_flat[idx_y1_x1.long()], dim1=1, dim2=2).permute(
0, 2, 1
sizes = input.shape[2:]
assert len(sizes) in [2, 3]
if len(sizes) == 3:
# t x y -> x y t to match dimensions T H W in grid_sample
coords = coords[..., [1, 2, 0]]
if align_corners:
coords = coords * torch.tensor(
[2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
)
else:
im_flat = (im.permute(0, 2, 3, 1)).reshape(B * H * W, C)
i_y0_x0 = im_flat[idx_y0_x0.long()]
i_y0_x1 = im_flat[idx_y0_x1.long()]
i_y1_x0 = im_flat[idx_y1_x0.long()]
i_y1_x1 = im_flat[idx_y1_x1.long()]
coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device)
# Finally calculate interpolated values.
x0_f = x0.float()
x1_f = x1.float()
y0_f = y0.float()
y1_f = y1.float()
coords -= 1
w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)
w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)
w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)
w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)
return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
output = (
w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1
)
# output is B*N x C
output = output.view(B, -1, C)
output = output.permute(0, 2, 1)
# output is B x C x N
if return_inbounds:
x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
inbounds = (x_valid & y_valid).float()
inbounds = inbounds.reshape(
B, N
) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
return output, inbounds
def sample_features4d(input, coords):
r"""Sample spatial features
return output # B, C, N
`sample_features4d(input, coords)` samples the spatial features
:attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
The field is sampled at coordinates :attr:`coords` using bilinear
interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
same convention as :func:`bilinear_sampler` with `align_corners=True`.
The output tensor has one feature per point, and has shape :math:`(B,
R, C)`.
Args:
input (Tensor): spatial features.
coords (Tensor): points.
Returns:
Tensor: sampled features.
"""
B, _, _, _ = input.shape
# B R 2 -> B R 1 2
coords = coords.unsqueeze(2)
# B C R 1
feats = bilinear_sampler(input, coords)
return feats.permute(0, 2, 1, 3).view(
B, -1, feats.shape[1] * feats.shape[3]
) # B C R 1 -> B R C
def sample_features5d(input, coords):
r"""Sample spatio-temporal features
`sample_features5d(input, coords)` works in the same way as
:func:`sample_features4d` but for spatio-temporal features and points:
:attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,
x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.
Args:
input (Tensor): spatio-temporal features.
coords (Tensor): spatio-temporal points.
Returns:
Tensor: sampled features.
"""
B, T, _, _, _ = input.shape
# B T C H W -> B C T H W
input = input.permute(0, 2, 1, 3, 4)
# B R1 R2 3 -> B R1 R2 1 3
coords = coords.unsqueeze(3)
# B C R1 R2 1
feats = bilinear_sampler(input, coords)
return feats.permute(0, 2, 3, 1, 4).view(
B, feats.shape[2], feats.shape[3], feats.shape[1]
) # B C R1 R2 1 -> B R1 R2 C
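A shape-level sketch of the two sampling helpers above, with random tensors; the coordinate ranges assume the `align_corners=True` pixel-centre convention described for `bilinear_sampler`:

```python
import torch
from cotracker.models.core.model_utils import sample_features4d, sample_features5d

fmap = torch.randn(1, 128, 96, 128)                        # B C H W
pts = torch.rand(1, 50, 2) * torch.tensor([127.0, 95.0])   # (x, y) in feature-grid pixels
feats = sample_features4d(fmap, pts)                       # (1, 50, 128): one feature per point

fmaps = torch.randn(1, 8, 128, 96, 128)                    # B T C H W
txy = torch.cat([torch.randint(0, 8, (1, 1, 50, 1)).float(),
                 torch.rand(1, 1, 50, 2) * torch.tensor([127.0, 95.0])], dim=-1)  # (t, x, y)
track_feats = sample_features5d(fmaps, txy)                # (1, 1, 50, 128)
```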

@@ -8,16 +8,17 @@ import torch
import torch.nn.functional as F
from typing import Tuple
from cotracker.models.core.cotracker.cotracker import CoTracker, get_points_on_a_grid
from cotracker.models.core.cotracker.cotracker import CoTracker2
from cotracker.models.core.model_utils import get_points_on_a_grid
class EvaluationPredictor(torch.nn.Module):
def __init__(
self,
cotracker_model: CoTracker,
cotracker_model: CoTracker2,
interp_shape: Tuple[int, int] = (384, 512),
grid_size: int = 6,
local_grid_size: int = 6,
grid_size: int = 5,
local_grid_size: int = 8,
single_point: bool = True,
n_iters: int = 6,
) -> None:
@@ -39,14 +40,14 @@ class EvaluationPredictor(torch.nn.Module):
assert D == 3
assert B == 1
rgbs = video.reshape(B * T, C, H, W)
rgbs = F.interpolate(rgbs, tuple(self.interp_shape), mode="bilinear")
rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
video = video.reshape(B * T, C, H, W)
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
device = rgbs.device
device = video.device
queries[:, :, 1] *= self.interp_shape[1] / W
queries[:, :, 2] *= self.interp_shape[0] / H
queries[:, :, 1] *= (self.interp_shape[1] - 1) / (W - 1)
queries[:, :, 2] *= (self.interp_shape[0] - 1) / (H - 1)
if self.single_point:
traj_e = torch.zeros((B, T, N, 2), device=device)
@@ -56,51 +57,49 @@ class EvaluationPredictor(torch.nn.Module):
t = query[0, 0, 0].long()
traj_e_pind, vis_e_pind = self._process_one_point(rgbs, query)
traj_e_pind, vis_e_pind = self._process_one_point(video, query)
traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1]
vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
else:
if self.grid_size > 0:
xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device)
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(
device
) #
xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
queries = torch.cat([queries, xy], dim=1) #
traj_e, __, vis_e, __ = self.model(
rgbs=rgbs,
traj_e, vis_e, __ = self.model(
video=video,
queries=queries,
iters=self.n_iters,
)
traj_e[:, :, :, 0] *= W / float(self.interp_shape[1])
traj_e[:, :, :, 1] *= H / float(self.interp_shape[0])
traj_e[:, :, :, 0] *= (W - 1) / float(self.interp_shape[1] - 1)
traj_e[:, :, :, 1] *= (H - 1) / float(self.interp_shape[0] - 1)
return traj_e, vis_e
def _process_one_point(self, rgbs, query):
def _process_one_point(self, video, query):
t = query[0, 0, 0].long()
device = rgbs.device
device = query.device
if self.local_grid_size > 0:
xy_target = get_points_on_a_grid(
self.local_grid_size,
(50, 50),
[query[0, 0, 2], query[0, 0, 1]],
[query[0, 0, 2].item(), query[0, 0, 1].item()],
)
xy_target = torch.cat(
[torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2
xy_target = torch.cat([torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2).to(
device
) #
query = torch.cat([query, xy_target], dim=1).to(device) #
query = torch.cat([query, xy_target], dim=1) #
if self.grid_size > 0:
xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device)
xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
query = torch.cat([query, xy], dim=1).to(device) #
query = torch.cat([query, xy], dim=1) #
# crop the video to start from the queried frame
query[0, 0, 0] = 0
traj_e_pind, __, vis_e_pind, __ = self.model(
rgbs=rgbs[:, t:], queries=query, iters=self.n_iters
traj_e_pind, vis_e_pind, __ = self.model(
video=video[:, t:], queries=query, iters=self.n_iters
)
return traj_e_pind, vis_e_pind
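The `EvaluationPredictor` changes above scale queries by `(interp_shape - 1) / (W - 1)` instead of `interp_shape / W`, matching the `align_corners=True` resize of the video. A hedged sketch of the round trip (the helper name is illustrative, not part of the API):

```python
import torch

def rescale_queries(queries, old_hw, new_hw):
    """queries: (B, N, 3) rows of (t, x, y) given at the original resolution."""
    (H, W), (H2, W2) = old_hw, new_hw
    out = queries.clone()
    out[:, :, 1] *= (W2 - 1) / (W - 1)  # x
    out[:, :, 2] *= (H2 - 1) / (H - 1)  # y
    return out

# Corners of a 640x360 video map exactly onto corners of the 512x384 model grid:
q = torch.tensor([[[0.0, 0.0, 0.0], [0.0, 639.0, 359.0]]])
print(rescale_queries(q, (360, 640), (384, 512)))  # last row becomes (0, 511, 383)
```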

View File

@ -7,23 +7,16 @@
import torch
import torch.nn.functional as F
from tqdm import tqdm
from cotracker.models.core.cotracker.cotracker import get_points_on_a_grid
from cotracker.models.core.model_utils import smart_cat
from cotracker.models.build_cotracker import (
build_cotracker,
)
from cotracker.models.core.model_utils import smart_cat, get_points_on_a_grid
from cotracker.models.build_cotracker import build_cotracker
class CoTrackerPredictor(torch.nn.Module):
def __init__(
self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth"
):
def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
super().__init__()
self.interp_shape = (384, 512)
self.support_grid_size = 6
model = build_cotracker(checkpoint)
self.interp_shape = model.model_resolution
self.model = model
self.model.eval()
@ -43,7 +36,6 @@ class CoTrackerPredictor(torch.nn.Module):
grid_query_frame: int = 0, # only for dense and regular grid tracks
backward_tracking: bool = False,
):
if queries is None and grid_size == 0:
tracks, visibilities = self._compute_dense_tracks(
video,
@ -63,9 +55,7 @@ class CoTrackerPredictor(torch.nn.Module):
return tracks, visibilities
def _compute_dense_tracks(
self, video, grid_query_frame, grid_size=30, backward_tracking=False
):
def _compute_dense_tracks(self, video, grid_query_frame, grid_size=30, backward_tracking=False):
*_, H, W = video.shape
grid_step = W // grid_size
grid_width = W // grid_step
@ -73,12 +63,11 @@ class CoTrackerPredictor(torch.nn.Module):
tracks = visibilities = None
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
grid_pts[0, :, 0] = grid_query_frame
for offset in tqdm(range(grid_step * grid_step)):
for offset in range(grid_step * grid_step):
print(f"step {offset} / {grid_step * grid_step}")
ox = offset % grid_step
oy = offset // grid_step
grid_pts[0, :, 1] = (
torch.arange(grid_width).repeat(grid_height) * grid_step + ox
)
grid_pts[0, :, 1] = torch.arange(grid_width).repeat(grid_height) * grid_step + ox
grid_pts[0, :, 2] = (
torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
)
@ -106,21 +95,23 @@ class CoTrackerPredictor(torch.nn.Module):
assert B == 1
video = video.reshape(B * T, C, H, W)
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear")
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
if queries is not None:
queries = queries.clone()
B, N, D = queries.shape
assert D == 3
queries[:, :, 1] *= self.interp_shape[1] / W
queries[:, :, 2] *= self.interp_shape[0] / H
queries = queries.clone()
queries[:, :, 1:] *= queries.new_tensor(
[
(self.interp_shape[1] - 1) / (W - 1),
(self.interp_shape[0] - 1) / (H - 1),
]
)
elif grid_size > 0:
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
if segm_mask is not None:
segm_mask = F.interpolate(
segm_mask, tuple(self.interp_shape), mode="nearest"
)
segm_mask = F.interpolate(segm_mask, tuple(self.interp_shape), mode="nearest")
point_mask = segm_mask[0, 0][
(grid_pts[0, :, 1]).round().long().cpu(),
(grid_pts[0, :, 0]).round().long().cpu(),
@ -133,23 +124,23 @@ class CoTrackerPredictor(torch.nn.Module):
)
if add_support_grid:
grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape, device=video.device)
grid_pts = torch.cat(
[torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
grid_pts = get_points_on_a_grid(
self.support_grid_size, self.interp_shape, device=video.device
)
grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
queries = torch.cat([queries, grid_pts], dim=1)
tracks, __, visibilities, __ = self.model(rgbs=video, queries=queries, iters=6)
tracks, visibilities, __ = self.model.forward(video=video, queries=queries, iters=6)
if backward_tracking:
tracks, visibilities = self._compute_backward_tracks(
video, queries, tracks, visibilities
)
if add_support_grid:
queries[:, -self.support_grid_size ** 2 :, 0] = T - 1
queries[:, -self.support_grid_size**2 :, 0] = T - 1
if add_support_grid:
tracks = tracks[:, :, : -self.support_grid_size ** 2]
visibilities = visibilities[:, :, : -self.support_grid_size ** 2]
tracks = tracks[:, :, : -self.support_grid_size**2]
visibilities = visibilities[:, :, : -self.support_grid_size**2]
thr = 0.9
visibilities = visibilities > thr
@ -158,17 +149,18 @@ class CoTrackerPredictor(torch.nn.Module):
# TODO: batchify
for i in range(len(queries)):
queries_t = queries[i, :tracks.size(2), 0].to(torch.int64)
queries_t = queries[i, : tracks.size(2), 0].to(torch.int64)
arange = torch.arange(0, len(queries_t))
# overwrite the predictions with the query points
tracks[i, queries_t, arange] = queries[i, :tracks.size(2), 1:]
tracks[i, queries_t, arange] = queries[i, : tracks.size(2), 1:]
# correct visibilities, the query points should be visible
visibilities[i, queries_t, arange] = True
tracks[:, :, :, 0] *= W / float(self.interp_shape[1])
tracks[:, :, :, 1] *= H / float(self.interp_shape[0])
tracks *= tracks.new_tensor(
[(W - 1) / (self.interp_shape[1] - 1), (H - 1) / (self.interp_shape[0] - 1)]
)
return tracks, visibilities
def _compute_backward_tracks(self, video, queries, tracks, visibilities):
@ -176,9 +168,7 @@ class CoTrackerPredictor(torch.nn.Module):
inv_queries = queries.clone()
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
inv_tracks, __, inv_visibilities, __ = self.model(
rgbs=inv_video, queries=inv_queries, iters=6
)
inv_tracks, inv_visibilities, __ = self.model(video=inv_video, queries=inv_queries, iters=6)
inv_tracks = inv_tracks.flip(1)
inv_visibilities = inv_visibilities.flip(1)
@ -188,3 +178,79 @@ class CoTrackerPredictor(torch.nn.Module):
tracks[mask] = inv_tracks[mask]
visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
return tracks, visibilities
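A hedged sketch of the backward-tracking idea in `_compute_backward_tracks` above: run the same model on the time-reversed video with time-reversed query frames, then flip the predictions back to forward time before merging them into the frames that precede each query. The function name below is illustrative only:

```python
import torch

def backward_pass(model, video, queries, iters=6):
    # reverse time for both the video and the query frame indices
    inv_video = video.flip(1)                                  # (B, T, C, H, W)
    inv_queries = queries.clone()
    inv_queries[:, :, 0] = video.shape[1] - 1 - inv_queries[:, :, 0]
    inv_tracks, inv_vis, _ = model(video=inv_video, queries=inv_queries, iters=iters)
    # flip predictions back so they are expressed in forward time
    return inv_tracks.flip(1), inv_vis.flip(1)
```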
class CoTrackerOnlinePredictor(torch.nn.Module):
def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
super().__init__()
self.support_grid_size = 6
model = build_cotracker(checkpoint)
self.interp_shape = model.model_resolution
self.step = model.window_len // 2
self.model = model
self.model.eval()
@torch.no_grad()
def forward(
self,
video_chunk,
is_first_step: bool = False,
queries: torch.Tensor = None,
grid_size: int = 10,
grid_query_frame: int = 0,
add_support_grid=False,
):
# Initialize online video processing and save queried points
# This needs to be done before processing *each new video*
if is_first_step:
self.model.init_video_online_processing()
if queries is not None:
B, N, D = queries.shape
assert D == 3
queries = queries.clone()
queries[:, :, 1:] *= queries.new_tensor(
[
(self.interp_shape[1] - 1) / (W - 1),
(self.interp_shape[0] - 1) / (H - 1),
]
)
elif grid_size > 0:
grid_pts = get_points_on_a_grid(
grid_size, self.interp_shape, device=video_chunk.device
)
queries = torch.cat(
[torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
dim=2,
)
if add_support_grid:
grid_pts = get_points_on_a_grid(
self.support_grid_size, self.interp_shape, device=video_chunk.device
)
grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
queries = torch.cat([queries, grid_pts], dim=1)
self.queries = queries
return (None, None)
B, T, C, H, W = video_chunk.shape
video_chunk = video_chunk.reshape(B * T, C, H, W)
video_chunk = F.interpolate(
video_chunk, tuple(self.interp_shape), mode="bilinear", align_corners=True
)
video_chunk = video_chunk.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
tracks, visibilities, __ = self.model(
video=video_chunk,
queries=self.queries,
iters=6,
is_online=True,
)
thr = 0.9
return (
tracks
* tracks.new_tensor(
[
(W - 1) / (self.interp_shape[1] - 1),
(H - 1) / (self.interp_shape[0] - 1),
]
),
visibilities > thr,
)
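Both predictors above can append a square support grid of `support_grid_size**2` extra points, queried from frame 0, so that the user's points are tracked jointly with a regular grid; the same number of points is then sliced off the predictions. A minimal sketch of that bookkeeping, with illustrative helper names:

```python
import torch

def append_support_grid(queries, grid_pts):
    # queries: (B, N, 3) rows of (t, x, y); grid_pts: (B, S*S, 2) pixel coords
    support = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
    return torch.cat([queries, support], dim=1)  # support points queried at frame 0

def strip_support_grid(tracks, visibilities, num_support):
    # drop the trailing support points from the model output
    return tracks[:, :, :-num_support], visibilities[:, :, :-num_support]
```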

View File

@ -3,36 +3,59 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
import cv2
import imageio
import torch
import flow_vis
from matplotlib import cm
import torch.nn.functional as F
import torchvision.transforms as transforms
from moviepy.editor import ImageSequenceClip
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
def read_video_from_path(path):
cap = cv2.VideoCapture(path)
if not cap.isOpened():
print("Error opening video file")
else:
frames = []
while cap.isOpened():
ret, frame = cap.read()
if ret == True:
frames.append(np.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
else:
break
cap.release()
try:
reader = imageio.get_reader(path)
except Exception as e:
print("Error opening video file: ", e)
return None
frames = []
for i, im in enumerate(reader):
frames.append(np.array(im))
return np.stack(frames)
def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True):
# Create a draw object
draw = ImageDraw.Draw(rgb)
# Calculate the bounding box of the circle
left_up_point = (coord[0] - radius, coord[1] - radius)
right_down_point = (coord[0] + radius, coord[1] + radius)
# Draw the circle
draw.ellipse(
[left_up_point, right_down_point],
fill=tuple(color) if visible else None,
outline=tuple(color),
)
return rgb
def draw_line(rgb, coord_y, coord_x, color, linewidth):
draw = ImageDraw.Draw(rgb)
draw.line(
(coord_y[0], coord_y[1], coord_x[0], coord_x[1]),
fill=tuple(color),
width=linewidth,
)
return rgb
def add_weighted(rgb, alpha, original, beta, gamma):
return (rgb * alpha + original * beta + gamma).astype("uint8")
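The visualizer now draws with PIL instead of OpenCV; the helpers above take and return a `PIL.Image`. A short usage sketch with a made-up frame:

```python
import numpy as np
from PIL import Image

frame = np.zeros((120, 160, 3), dtype=np.uint8)            # dummy RGB frame
img = Image.fromarray(frame)
img = draw_circle(img, coord=(40, 60), radius=4, color=(255, 0, 0), visible=True)
img = draw_line(img, (40, 60), (100, 30), color=(0, 255, 0), linewidth=2)
frame = np.array(img)                                       # back to numpy for the writer
```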
class Visualizer:
def __init__(
self,
@ -107,7 +130,7 @@ class Visualizer:
def save_video(self, video, filename, writer=None, step=0):
if writer is not None:
writer.add_video(
f"{filename}_pred_track",
filename,
video.to(torch.uint8),
global_step=step,
fps=self.fps,
@ -116,11 +139,18 @@ class Visualizer:
os.makedirs(self.save_dir, exist_ok=True)
wide_list = list(video.unbind(1))
wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
clip = ImageSequenceClip(wide_list[2:-1], fps=self.fps)
# Write the video file
save_path = os.path.join(self.save_dir, f"{filename}_pred_track.mp4")
clip.write_videofile(save_path, codec="libx264", fps=self.fps, logger=None)
# Prepare the video file path
save_path = os.path.join(self.save_dir, f"{filename}.mp4")
# Create a writer object
video_writer = imageio.get_writer(save_path, fps=self.fps)
# Write frames to the video file
for frame in wide_list[2:-1]:
video_writer.append_data(frame)
video_writer.close()
print(f"Video saved to {save_path}")
@ -149,9 +179,11 @@ class Visualizer:
# process input video
for rgb in video:
res_video.append(rgb.copy())
vector_colors = np.zeros((T, N, 3))
if self.mode == "optical_flow":
import flow_vis
vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
elif segm_mask is None:
if self.mode == "rainbow":
@ -196,9 +228,7 @@ class Visualizer:
if self.tracks_leave_trace != 0:
for t in range(1, T):
first_ind = (
max(0, t - self.tracks_leave_trace)
if self.tracks_leave_trace >= 0
else 0
max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0
)
curr_tracks = tracks[first_ind : t + 1]
curr_colors = vector_colors[first_ind : t + 1]
@ -218,12 +248,11 @@ class Visualizer:
curr_colors,
)
if gt_tracks is not None:
res_video[t] = self._draw_gt_tracks(
res_video[t], gt_tracks[first_ind : t + 1]
)
res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1])
# draw points
for t in range(T):
img = Image.fromarray(np.uint8(res_video[t]))
for i in range(N):
coord = (tracks[t, i, 0], tracks[t, i, 1])
visibile = True
@ -233,15 +262,14 @@ class Visualizer:
if not compensate_for_camera_motion or (
compensate_for_camera_motion and segm_mask[i] > 0
):
cv2.circle(
res_video[t],
coord,
int(self.linewidth * 2),
vector_colors[t, i].tolist(),
thickness=-1 if visibile else 2
-1,
img = draw_circle(
img,
coord=coord,
radius=int(self.linewidth * 2),
color=vector_colors[t, i].astype(int),
visible=visibile,
)
res_video[t] = np.array(img)
# construct the final rgb sequence
if self.show_first_frame > 0:
@ -256,7 +284,7 @@ class Visualizer:
alpha: float = 0.5,
):
T, N, _ = tracks.shape
rgb = Image.fromarray(np.uint8(rgb))
for s in range(T - 1):
vector_color = vector_colors[s]
original = rgb.copy()
@ -265,16 +293,18 @@ class Visualizer:
coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
if coord_y[0] != 0 and coord_y[1] != 0:
cv2.line(
rgb = draw_line(
rgb,
coord_y,
coord_x,
vector_color[i].tolist(),
vector_color[i].astype(int),
self.linewidth,
cv2.LINE_AA,
)
if self.tracks_leave_trace > 0:
rgb = cv2.addWeighted(rgb, alpha, original, 1 - alpha, 0)
rgb = Image.fromarray(
np.uint8(add_weighted(np.array(rgb), alpha, np.array(original), 1 - alpha, 0))
)
rgb = np.array(rgb)
return rgb
def _draw_gt_tracks(
@ -283,8 +313,8 @@ class Visualizer:
gt_tracks: np.ndarray, # T x 2
):
T, N, _ = gt_tracks.shape
color = np.array((211.0, 0.0, 0.0))
color = np.array((211, 0, 0))
rgb = Image.fromarray(np.uint8(rgb))
for t in range(T):
for i in range(N):
gt_tracks = gt_tracks[t][i]
@ -293,22 +323,21 @@ class Visualizer:
length = self.linewidth * 3
coord_y = (int(gt_tracks[0]) + length, int(gt_tracks[1]) + length)
coord_x = (int(gt_tracks[0]) - length, int(gt_tracks[1]) - length)
cv2.line(
rgb = draw_line(
rgb,
coord_y,
coord_x,
color,
self.linewidth,
cv2.LINE_AA,
)
coord_y = (int(gt_tracks[0]) - length, int(gt_tracks[1]) + length)
coord_x = (int(gt_tracks[0]) + length, int(gt_tracks[1]) - length)
cv2.line(
rgb = draw_line(
rgb,
coord_y,
coord_x,
color,
self.linewidth,
cv2.LINE_AA,
)
rgb = np.array(rgb)
return rgb

8
cotracker/version.py Normal file
View File

@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
__version__ = "2.0.0"

30
demo.py
View File

@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.
import os
import cv2
import torch
import argparse
import numpy as np
@ -14,9 +13,18 @@ from PIL import Image
from cotracker.utils.visualizer import Visualizer, read_video_from_path
from cotracker.predictor import CoTrackerPredictor
DEFAULT_DEVICE = ('cuda' if torch.cuda.is_available() else
'mps' if torch.backends.mps.is_available() else
'cpu')
# Unfortunately MPS acceleration does not support all the features we require,
# but we may be able to enable it in the future
DEFAULT_DEVICE = (
# "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
"cuda"
if torch.cuda.is_available()
else "cpu"
)
# if DEFAULT_DEVICE == "mps":
# os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@ -32,15 +40,16 @@ if __name__ == "__main__":
)
parser.add_argument(
"--checkpoint",
default="./checkpoints/cotracker_stride_4_wind_8.pth",
help="cotracker model",
# default="./checkpoints/cotracker.pth",
default=None,
help="CoTracker model parameters",
)
parser.add_argument("--grid_size", type=int, default=0, help="Regular grid size")
parser.add_argument("--grid_size", type=int, default=10, help="Regular grid size")
parser.add_argument(
"--grid_query_frame",
type=int,
default=0,
help="Compute dense and grid tracks starting from this frame ",
help="Compute dense and grid tracks starting from this frame",
)
parser.add_argument(
@ -57,7 +66,10 @@ if __name__ == "__main__":
segm_mask = np.array(Image.open(os.path.join(args.mask_path)))
segm_mask = torch.from_numpy(segm_mask)[None, None]
model = CoTrackerPredictor(checkpoint=args.checkpoint)
if args.checkpoint is not None:
model = CoTrackerPredictor(checkpoint=args.checkpoint)
else:
model = torch.hub.load("facebookresearch/co-tracker", "cotracker2")
model = model.to(DEFAULT_DEVICE)
video = video.to(DEFAULT_DEVICE)
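A hypothetical follow-up to the demo setup above, showing how the loaded `segm_mask` can restrict the tracked grid to an object. The `segm_mask` keyword is an assumption inferred from the mask handling in the `CoTrackerPredictor` diff earlier in this commit:

```python
# Assumed keyword argument; the rest reuses the objects created above in demo.py.
pred_tracks, pred_visibility = model(
    video,
    grid_size=args.grid_size,
    grid_query_frame=args.grid_query_frame,
    segm_mask=segm_mask,
)
```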

13
docs/Makefile Normal file
View File

@ -0,0 +1,13 @@
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = source
BUILDDIR = _build
O = -a
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

View File

@ -0,0 +1,14 @@
Models
======
CoTracker models:
.. currentmodule:: cotracker.models
Model Utils
-----------
.. automodule:: cotracker.models.core.model_utils
:members:
:undoc-members:
:show-inheritance:

View File

@ -0,0 +1,11 @@
Utils
=====
CoTracker provides the following utilities:
.. currentmodule:: cotracker
.. automodule:: cotracker.utils.visualizer
:members:
:undoc-members:
:show-inheritance:

39
docs/source/conf.py Normal file
View File

@ -0,0 +1,39 @@
__version__ = None
exec(open("../../cotracker/version.py", "r").read())
project = "CoTracker"
copyright = "2023-24, Meta Platforms, Inc. and affiliates"
author = "Meta Platforms"
release = __version__
extensions = [
"sphinx.ext.napoleon",
"sphinx.ext.duration",
"sphinx.ext.doctest",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",
"sphinxcontrib.bibtex",
]
intersphinx_mapping = {
"python": ("https://docs.python.org/3/", None),
"sphinx": ("https://www.sphinx-doc.org/en/master/", None),
}
intersphinx_disabled_domains = ["std"]
# templates_path = ["_templates"]
html_theme = "alabaster"
# Ignore >>> when copying code
copybutton_prompt_text = r">>> |\.\.\. "
copybutton_prompt_is_regexp = True
# -- Options for EPUB output
epub_show_urls = "footnote"
# typehints
autodoc_typehints = "description"
# citations
bibtex_bibfiles = ["references.bib"]

29
docs/source/index.rst Normal file
View File

@ -0,0 +1,29 @@
CoTracker
===================================
.. image:: ../../assets/bmx-bumps.gif
:width: 800
:alt: Example of CoTracker in action
Overview
--------
*CoTracker* is an open-source model for jointly tracking points in video :cite:p:`karaev2023cotracker`.
Links
-----
.. toctree::
:glob:
:maxdepth: 1
:caption: Python API
apis/*
Citations
---------
.. bibliography::
:style: unsrt
:filter: docname in docnames

View File

@ -0,0 +1,6 @@
@article{karaev2023cotracker,
title = {CoTracker: It is Better to Track Together},
author = {Nikita Karaev and Ignacio Rocco and Benjamin Graham and Natalia Neverova and Andrea Vedaldi and Christian Rupprecht},
journal = {arXiv:2307.07635},
year = {2023}
}

View File

@ -1,43 +1,39 @@
import os
import torch
import timm
import einops
import tqdm
import cv2
import gradio as gr
from cotracker.utils.visualizer import Visualizer, read_video_from_path
def cotracker_demo(
input_video,
grid_size: int = 10,
grid_query_frame: int = 0,
backward_tracking: bool = False,
tracks_leave_trace: bool = False
):
input_video,
grid_size: int = 10,
grid_query_frame: int = 0,
tracks_leave_trace: bool = False,
):
load_video = read_video_from_path(input_video)
grid_query_frame = min(len(load_video)-1, grid_query_frame)
grid_query_frame = min(len(load_video) - 1, grid_query_frame)
load_video = torch.from_numpy(load_video).permute(0, 3, 1, 2)[None].float()
model = torch.hub.load("facebookresearch/co-tracker", "cotracker2_online")
model = torch.hub.load("facebookresearch/co-tracker", "cotracker_w8")
if torch.cuda.is_available():
model = model.cuda()
load_video = load_video.cuda()
pred_tracks, pred_visibility = model(
load_video,
grid_size=grid_size,
grid_query_frame=grid_query_frame,
backward_tracking=backward_tracking
)
model(video_chunk=load_video, is_first_step=True, grid_size=grid_size)
for ind in range(0, load_video.shape[1] - model.step, model.step):
pred_tracks, pred_visibility = model(
video_chunk=load_video[:, ind : ind + model.step * 2]
) # B T N 2, B T N 1
linewidth = 2
if grid_size < 10:
linewidth = 4
elif grid_size < 20:
linewidth = 3
vis = Visualizer(
save_dir=os.path.join(os.path.dirname(__file__), "results"),
grayscale=False,
@ -45,7 +41,7 @@ def cotracker_demo(
fps=10,
linewidth=linewidth,
show_first_frame=5,
tracks_leave_trace= -1 if tracks_leave_trace else 0,
tracks_leave_trace=-1 if tracks_leave_trace else 0,
)
import time
@ -55,44 +51,39 @@ def cotracker_demo(
filename = str(current_milli_time())
vis.visualize(
load_video,
tracks=pred_tracks,
tracks=pred_tracks,
visibility=pred_visibility,
filename=filename,
query_frame=grid_query_frame,
)
return os.path.join(
os.path.dirname(__file__), "results", f"{filename}_pred_track.mp4"
)
return os.path.join(os.path.dirname(__file__), "results", f"{filename}_pred_track.mp4")
app = gr.Interface(
title = "🎨 CoTracker: It is Better to Track Together",
description = "<div style='text-align: left;'> \
title="🎨 CoTracker: It is Better to Track Together",
description="<div style='text-align: left;'> \
<p>Welcome to <a href='http://co-tracker.github.io' target='_blank'>CoTracker</a>! This space demonstrates point (pixel) tracking in videos. \
Points are sampled on a regular grid and are tracked jointly. </p> \
<p> To get started, simply upload your <b>.mp4</b> video in landscape orientation or click on one of the example videos to load them. The shorter the video, the faster the processing. We recommend submitting short videos of length <b>2-7 seconds</b>.</p> \
<ul style='display: inline-block; text-align: left;'> \
<li>The total number of grid points is the square of <b>Grid Size</b>.</li> \
<li>To specify the starting frame for tracking, adjust <b>Grid Query Frame</b>. Tracks will be visualized only after the selected frame.</li> \
<li>Use <b>Backward Tracking</b> to track points from the selected frame in both directions.</li> \
<li>Check <b>Visualize Track Traces</b> to visualize traces of all the tracked points. </li> \
</ul> \
<p style='text-align: left'>For more details, check out our <a href='https://github.com/facebookresearch/co-tracker' target='_blank'>GitHub Repo</a> </p> \
</div>",
fn=cotracker_demo,
inputs=[
gr.Video(label="Input video", interactive=True),
gr.Slider(minimum=1, maximum=30, step=1, value=10, label="Grid Size"),
gr.Slider(minimum=0, maximum=30, step=1, value=0, label="Grid Query Frame"),
gr.Checkbox(label="Backward Tracking"),
gr.Checkbox(label="Visualize Track Traces"),
],
outputs=gr.Video(label="Video with predicted tracks"),
examples=[
[ "./assets/apple.mp4", 20, 0, False, False ],
[ "./assets/apple.mp4", 10, 30, True, False ],
["./assets/apple.mp4", 20, 0, False, False],
["./assets/apple.mp4", 10, 30, True, False],
],
cache_examples=False
cache_examples=False,
)
app.launch(share=False)

View File

@ -1,7 +1,3 @@
einops
timm
tqdm
opencv-python
matplotlib
moviepy
flow_vis

View File

@ -6,27 +6,33 @@
import torch
dependencies = ["torch", "einops", "timm", "tqdm"]
_COTRACKER_URL = (
"https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth"
)
_COTRACKER_URL = "https://huggingface.co/facebook/cotracker/resolve/main/cotracker2.pth"
def _make_cotracker_predictor(*, pretrained: bool = True, **kwargs):
from cotracker.predictor import CoTrackerPredictor
def _make_cotracker_predictor(*, pretrained: bool = True, online=False, **kwargs):
if online:
from cotracker.predictor import CoTrackerOnlinePredictor
predictor = CoTrackerPredictor(checkpoint=None)
predictor = CoTrackerOnlinePredictor(checkpoint=None)
else:
from cotracker.predictor import CoTrackerPredictor
predictor = CoTrackerPredictor(checkpoint=None)
if pretrained:
state_dict = torch.hub.load_state_dict_from_url(
_COTRACKER_URL, map_location="cpu"
)
state_dict = torch.hub.load_state_dict_from_url(_COTRACKER_URL, map_location="cpu")
predictor.model.load_state_dict(state_dict)
return predictor
def cotracker_w8(*, pretrained: bool = True, **kwargs):
def cotracker2(*, pretrained: bool = True, **kwargs):
"""
CoTracker model with stride 4 and window length 8. (The main model from the paper)
CoTracker2 with stride 4 and window length 8. Can track up to 265*265 points jointly.
"""
return _make_cotracker_predictor(pretrained=pretrained, **kwargs)
return _make_cotracker_predictor(pretrained=pretrained, online=False, **kwargs)
def cotracker2_online(*, pretrained: bool = True, **kwargs):
"""
Online CoTracker2 with stride 4 and window length 8. Can track up to 265*265 points jointly.
"""
return _make_cotracker_predictor(pretrained=pretrained, online=True, **kwargs)

24
launch_training.sh Normal file
View File

@ -0,0 +1,24 @@
#!/bin/bash
EXP_DIR=$1
EXP_NAME=$2
DATE=$3
DATASET_ROOT=$4
NUM_STEPS=$5
echo `which python`
mkdir -p ${EXP_DIR}/${DATE}_${EXP_NAME}/logs/;
export PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH
sbatch --comment=${EXP_NAME} --partition=learn --time=39:00:00 --gpus-per-node=8 --nodes=4 --ntasks-per-node=8 \
--job-name=${EXP_NAME} --cpus-per-task=10 --signal=USR1@60 --open-mode=append \
--output=${EXP_DIR}/${DATE}_${EXP_NAME}/logs/%j_%x_%A_%a_%N.out \
--error=${EXP_DIR}/${DATE}_${EXP_NAME}/logs/%j_%x_%A_%a_%N.err \
--wrap="srun --label python ./train.py --batch_size 1 \
--num_steps ${NUM_STEPS} --ckpt_path ${EXP_DIR}/${DATE}_${EXP_NAME} --model_name cotracker \
--save_freq 200 --sequence_len 24 --eval_datasets dynamic_replica tapvid_davis_first \
--traj_per_sample 768 --sliding_window_len 8 \
--save_every_n_epoch 10 --evaluate_every_n_epoch 10 --model_stride 4 --dataset_root ${DATASET_ROOT} --num_nodes 4 \
--num_virtual_tracks 64"

90
online_demo.py Normal file
View File

@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import argparse
import imageio.v3 as iio
import numpy as np
from cotracker.utils.visualizer import Visualizer
from cotracker.predictor import CoTrackerOnlinePredictor
# Unfortunately MPS acceleration does not support all the features we require,
# but we may be able to enable it in the future
DEFAULT_DEVICE = (
# "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
"cuda"
if torch.cuda.is_available()
else "cpu"
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--video_path",
default="./assets/apple.mp4",
help="path to a video",
)
parser.add_argument(
"--checkpoint",
default=None,
help="CoTracker model parameters",
)
parser.add_argument("--grid_size", type=int, default=10, help="Regular grid size")
parser.add_argument(
"--grid_query_frame",
type=int,
default=0,
help="Compute dense and grid tracks starting from this frame",
)
args = parser.parse_args()
if args.checkpoint is not None:
model = CoTrackerOnlinePredictor(checkpoint=args.checkpoint)
else:
model = torch.hub.load("facebookresearch/co-tracker", "cotracker2_online")
model = model.to(DEFAULT_DEVICE)
window_frames = []
def _process_step(window_frames, is_first_step, grid_size):
video_chunk = (
torch.tensor(np.stack(window_frames[-model.step * 2 :]), device=DEFAULT_DEVICE)
.float()
.permute(0, 3, 1, 2)[None]
) # (1, T, 3, H, W)
return model(video_chunk, is_first_step=is_first_step, grid_size=grid_size)
# Iterating over video frames, processing one window at a time:
is_first_step = True
for i, frame in enumerate(
iio.imiter(
"https://github.com/facebookresearch/co-tracker/blob/main/assets/apple.mp4",
plugin="FFMPEG",
)
):
if i % model.step == 0 and i != 0:
pred_tracks, pred_visibility = _process_step(
window_frames, is_first_step, grid_size=args.grid_size
)
is_first_step = False
window_frames.append(frame)
# Processing the final video frames in case video length is not a multiple of model.step
pred_tracks, pred_visibility = _process_step(
window_frames[-(i % model.step) - model.step - 1 :],
is_first_step,
grid_size=args.grid_size,
)
print("Tracks are computed")
# save a video with predicted tracks
seq_name = args.video_path.split("/")[-1]
video = torch.tensor(np.stack(window_frames), device=DEFAULT_DEVICE).permute(0, 3, 1, 2)[None]
vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame)
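For reference, a toy sketch of the buffering pattern in the loop above: the model is called every `model.step` frames and always receives the last `2 * model.step` buffered frames, so consecutive windows overlap by half a window (the step value below is made up):

```python
# Toy illustration only; a real stream yields image frames, not integers.
step = 4
window_frames = []
for i in range(14):
    if i % step == 0 and i != 0:
        window = window_frames[-step * 2 :]
        print(f"call at frame {i}: buffer covers frames {window[0]}..{window[-1]}")
    window_frames.append(i)
# call at frame 4: buffer covers frames 0..3
# call at frame 8: buffer covers frames 0..7
# call at frame 12: buffer covers frames 4..11
```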

View File

@ -8,11 +8,11 @@ from setuptools import find_packages, setup
setup(
name="cotracker",
version="1.0",
version="2.0",
install_requires=[],
packages=find_packages(exclude="notebooks"),
extras_require={
"all": ["matplotlib", "opencv-python"],
"all": ["matplotlib"],
"dev": ["flake8", "black"],
},
)

View File

@ -0,0 +1,51 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import unittest
from cotracker.models.core.model_utils import bilinear_sampler
class TestBilinearSampler(unittest.TestCase):
# Sample from an image (4d)
def _test4d(self, align_corners):
H, W = 4, 5
# Construct a grid to obtain identity sampling
input = torch.randn(H * W).view(1, 1, H, W).float()
coords = torch.meshgrid(torch.arange(H), torch.arange(W))
coords = torch.stack(coords[::-1], dim=-1).float()[None]
if not align_corners:
coords = coords + 0.5
sampled_input = bilinear_sampler(input, coords, align_corners=align_corners)
torch.testing.assert_close(input, sampled_input)
# Sample from a video (5d)
def _test5d(self, align_corners):
T, H, W = 3, 4, 5
# Construct a grid to obtain identity sampling
input = torch.randn(H * W).view(1, 1, H, W).float()
input = torch.stack([input, input + 1, input + 2], dim=2)
coords = torch.meshgrid(torch.arange(T), torch.arange(W), torch.arange(H))
coords = torch.stack(coords, dim=-1).float().permute(0, 2, 1, 3)[None]
if not align_corners:
coords = coords + 0.5
sampled_input = bilinear_sampler(input, coords, align_corners=align_corners)
torch.testing.assert_close(input, sampled_input)
def test4d(self):
self._test4d(align_corners=True)
self._test4d(align_corners=False)
def test5d(self):
self._test5d(align_corners=True)
self._test5d(align_corners=False)
# run the test
unittest.main()

358
train.py
View File

@ -25,22 +25,35 @@ from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning.lite import LightningLite
from cotracker.models.evaluation_predictor import EvaluationPredictor
from cotracker.models.core.cotracker.cotracker import CoTracker
from cotracker.models.core.cotracker.cotracker import CoTracker2
from cotracker.utils.visualizer import Visualizer
from cotracker.datasets.tap_vid_datasets import TapVidDataset
from cotracker.datasets.badja_dataset import BadjaDataset
from cotracker.datasets.fast_capture_dataset import FastCaptureDataset
from cotracker.datasets.dr_dataset import DynamicReplicaDataset
from cotracker.evaluation.core.evaluator import Evaluator
from cotracker.datasets import kubric_movif_dataset
from cotracker.datasets.utils import collate_fn, collate_fn_train, dataclass_to_cuda_
from cotracker.models.core.cotracker.losses import sequence_loss, balanced_ce_loss
# define the handler function
# for training on a slurm cluster
def sig_handler(signum, frame):
print("caught signal", signum)
print(socket.gethostname(), "USR1 signal caught.")
# do other stuff to cleanup here
print("requeuing job " + os.environ["SLURM_JOB_ID"])
os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"])
sys.exit(-1)
def term_handler(signum, frame):
print("bypassing sigterm", flush=True)
def fetch_optimizer(args, model):
"""Create the optimizer and learning rate scheduler"""
optimizer = optim.AdamW(
model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8
)
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8)
scheduler = optim.lr_scheduler.OneCycleLR(
optimizer,
args.lr,
@ -53,69 +66,61 @@ def fetch_optimizer(args, model):
return optimizer, scheduler
def forward_batch(batch, model, args, loss_fn=None, writer=None, step=0):
rgbs = batch.video
def forward_batch(batch, model, args):
video = batch.video
trajs_g = batch.trajectory
vis_g = batch.visibility
valids = batch.valid
B, T, C, H, W = rgbs.shape
B, T, C, H, W = video.shape
assert C == 3
B, T, N, D = trajs_g.shape
device = rgbs.device
device = video.device
__, first_positive_inds = torch.max(vis_g, dim=1)
# We want to make sure that during training the model sees visible points
# that it does not need to track just yet: they are visible but queried from a later frame
N_rand = N // 4
# inds of visible points in the 1st frame
nonzero_inds = [torch.nonzero(vis_g[0, :, i]) for i in range(N)]
rand_vis_inds = torch.cat(
[
nonzero_row[torch.randint(len(nonzero_row), size=(1,))]
for nonzero_row in nonzero_inds
],
dim=1,
)
first_positive_inds = torch.cat(
[rand_vis_inds[:, :N_rand], first_positive_inds[:, N_rand:]], dim=1
)
nonzero_inds = [[torch.nonzero(vis_g[b, :, i]) for i in range(N)] for b in range(B)]
for b in range(B):
rand_vis_inds = torch.cat(
[
nonzero_row[torch.randint(len(nonzero_row), size=(1,))]
for nonzero_row in nonzero_inds[b]
],
dim=1,
)
first_positive_inds[b] = torch.cat(
[rand_vis_inds[:, :N_rand], first_positive_inds[b : b + 1, N_rand:]], dim=1
)
ind_array_ = torch.arange(T, device=device)
ind_array_ = ind_array_[None, :, None].repeat(B, 1, N)
assert torch.allclose(
vis_g[ind_array_ == first_positive_inds[:, None, :]],
torch.ones_like(vis_g),
)
assert torch.allclose(
vis_g[ind_array_ == rand_vis_inds[:, None, :]], torch.ones_like(vis_g)
)
gather = torch.gather(
trajs_g, 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, 2)
torch.ones(1, device=device),
)
gather = torch.gather(trajs_g, 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, D))
xys = torch.diagonal(gather, dim1=1, dim2=2).permute(0, 2, 1)
queries = torch.cat([first_positive_inds[:, :, None], xys], dim=2)
queries = torch.cat([first_positive_inds[:, :, None], xys[:, :, :2]], dim=2)
predictions, __, visibility, train_data = model(
rgbs=rgbs, queries=queries, iters=args.train_iters, is_train=True
predictions, visibility, train_data = model(
video=video, queries=queries, iters=args.train_iters, is_train=True
)
vis_predictions, coord_predictions, wind_inds, sort_inds = train_data
trajs_g = trajs_g[:, :, sort_inds]
vis_g = vis_g[:, :, sort_inds]
valids = valids[:, :, sort_inds]
coord_predictions, vis_predictions, valid_mask = train_data
vis_gts = []
traj_gts = []
valids_gts = []
for i, wind_idx in enumerate(wind_inds):
ind = i * (args.sliding_window_len // 2)
vis_gts.append(vis_g[:, ind : ind + args.sliding_window_len, :wind_idx])
traj_gts.append(trajs_g[:, ind : ind + args.sliding_window_len, :wind_idx])
valids_gts.append(valids[:, ind : ind + args.sliding_window_len, :wind_idx])
S = args.sliding_window_len
for ind in range(0, args.sequence_len - S // 2, S // 2):
vis_gts.append(vis_g[:, ind : ind + S])
traj_gts.append(trajs_g[:, ind : ind + S])
valids_gts.append(valids[:, ind : ind + S] * valid_mask[:, ind : ind + S])
seq_loss = sequence_loss(coord_predictions, traj_gts, vis_gts, valids_gts, 0.8)
vis_loss = balanced_ce_loss(vis_predictions, vis_gts, valids_gts)
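The query construction in `forward_batch` above picks, for a quarter of the points, a random frame where the point is visible, then reads each point's ground-truth position at its own query frame via a `gather` followed by `diagonal`. A small self-contained sketch of that indexing trick (shapes are made up):

```python
import torch

B, T, N, D = 2, 8, 5, 2
trajs_g = torch.randn(B, T, N, D)                     # ground-truth trajectories
first_positive_inds = torch.randint(0, T, (B, N))     # one query frame per point

gathered = torch.gather(
    trajs_g, 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, D)
)  # (B, N, N, D): entry [b, i, n] is point n's position at point i's query frame
xys = torch.diagonal(gathered, dim1=1, dim2=2).permute(0, 2, 1)  # (B, N, D)

# Equivalent explicit indexing, for comparison:
b_idx = torch.arange(B)[:, None]
n_idx = torch.arange(N)[None, :]
assert torch.allclose(xys, trajs_g[b_idx, first_positive_inds, n_idx])
```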
@ -131,9 +136,17 @@ def forward_batch(batch, model, args, loss_fn=None, writer=None, step=0):
def run_test_eval(evaluator, model, dataloaders, writer, step):
model.eval()
for ds_name, dataloader in dataloaders:
visualize_every = 1
grid_size = 5
if ds_name == "dynamic_replica":
visualize_every = 8
grid_size = 0
elif "tapvid" in ds_name:
visualize_every = 5
predictor = EvaluationPredictor(
model.module.module,
grid_size=6,
grid_size=grid_size,
local_grid_size=0,
single_point=False,
n_iters=6,
@ -148,37 +161,23 @@ def run_test_eval(evaluator, model, dataloaders, writer, step):
train_mode=True,
writer=writer,
step=step,
visualize_every=visualize_every,
)
if ds_name == "badja" or ds_name == "fastcapture" or ("kubric" in ds_name):
metrics = {
**{
f"{ds_name}_avg": np.mean(
[v for k, v in metrics.items() if "accuracy" not in k]
)
},
**{
f"{ds_name}_avg_accuracy": np.mean(
[v for k, v in metrics.items() if "accuracy" in k]
)
},
}
print("avg", np.mean([v for v in metrics.values()]))
if ds_name == "dynamic_replica" or ds_name == "kubric":
metrics = {f"{ds_name}_avg_{k}": v for k, v in metrics["avg"].items()}
if "tapvid" in ds_name:
metrics = {
f"{ds_name}_avg_OA": metrics["avg"]["occlusion_accuracy"] * 100,
f"{ds_name}_avg_delta": metrics["avg"]["average_pts_within_thresh"]
* 100,
f"{ds_name}_avg_Jaccard": metrics["avg"]["average_jaccard"] * 100,
f"{ds_name}_avg_OA": metrics["avg"]["occlusion_accuracy"],
f"{ds_name}_avg_delta": metrics["avg"]["average_pts_within_thresh"],
f"{ds_name}_avg_Jaccard": metrics["avg"]["average_jaccard"],
}
writer.add_scalars(f"Eval", metrics, step)
writer.add_scalars(f"Eval_{ds_name}", metrics, step)
class Logger:
SUM_FREQ = 100
def __init__(self, model, scheduler):
@ -190,24 +189,19 @@ class Logger:
def _print_training_status(self):
metrics_data = [
self.running_loss[k] / Logger.SUM_FREQ
for k in sorted(self.running_loss.keys())
self.running_loss[k] / Logger.SUM_FREQ for k in sorted(self.running_loss.keys())
]
training_str = "[{:6d}] ".format(self.total_steps + 1)
metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data)
# print the training status
logging.info(
f"Training Metrics ({self.total_steps}): {training_str + metrics_str}"
)
logging.info(f"Training Metrics ({self.total_steps}): {training_str + metrics_str}")
if self.writer is None:
self.writer = SummaryWriter(log_dir=os.path.join(args.ckpt_path, "runs"))
for k in self.running_loss:
self.writer.add_scalar(
k, self.running_loss[k] / Logger.SUM_FREQ, self.total_steps
)
self.writer.add_scalar(k, self.running_loss[k] / Logger.SUM_FREQ, self.total_steps)
self.running_loss[k] = 0.0
def push(self, metrics, task):
@ -249,79 +243,56 @@ class Lite(LightningLite):
seed_everything(0)
def seed_worker(worker_id):
worker_seed = torch.initial_seed() % 2 ** 32
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
g = torch.Generator()
g.manual_seed(0)
if self.global_rank == 0:
eval_dataloaders = []
if "dynamic_replica" in args.eval_datasets:
eval_dataset = DynamicReplicaDataset(
sample_len=60, only_first_n_samples=1, rgbd_input=False
)
eval_dataloader_dr = torch.utils.data.DataLoader(
eval_dataset,
batch_size=1,
shuffle=False,
num_workers=1,
collate_fn=collate_fn,
)
eval_dataloaders.append(("dynamic_replica", eval_dataloader_dr))
eval_dataloaders = []
if "badja" in args.eval_datasets:
eval_dataset = BadjaDataset(
data_root=os.path.join(args.dataset_root, "BADJA"),
max_seq_len=args.eval_max_seq_len,
dataset_resolution=args.crop_size,
if "tapvid_davis_first" in args.eval_datasets:
data_root = os.path.join(args.dataset_root, "tapvid/tapvid_davis/tapvid_davis.pkl")
eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root)
eval_dataloader_tapvid_davis = torch.utils.data.DataLoader(
eval_dataset,
batch_size=1,
shuffle=False,
num_workers=1,
collate_fn=collate_fn,
)
eval_dataloaders.append(("tapvid_davis", eval_dataloader_tapvid_davis))
evaluator = Evaluator(args.ckpt_path)
visualizer = Visualizer(
save_dir=args.ckpt_path,
pad_value=80,
fps=1,
show_first_frame=0,
tracks_leave_trace=0,
)
eval_dataloader_badja = torch.utils.data.DataLoader(
eval_dataset,
batch_size=1,
shuffle=False,
num_workers=8,
collate_fn=collate_fn,
)
eval_dataloaders.append(("badja", eval_dataloader_badja))
if "fastcapture" in args.eval_datasets:
eval_dataset = FastCaptureDataset(
data_root=os.path.join(args.dataset_root, "fastcapture"),
max_seq_len=min(100, args.eval_max_seq_len),
max_num_points=40,
dataset_resolution=args.crop_size,
)
eval_dataloader_fastcapture = torch.utils.data.DataLoader(
eval_dataset,
batch_size=1,
shuffle=False,
num_workers=1,
collate_fn=collate_fn,
)
eval_dataloaders.append(("fastcapture", eval_dataloader_fastcapture))
if "tapvid_davis_first" in args.eval_datasets:
data_root = os.path.join(args.dataset_root, "tapvid_davis/tapvid_davis.pkl")
eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root)
eval_dataloader_tapvid_davis = torch.utils.data.DataLoader(
eval_dataset,
batch_size=1,
shuffle=False,
num_workers=1,
collate_fn=collate_fn,
)
eval_dataloaders.append(("tapvid_davis", eval_dataloader_tapvid_davis))
evaluator = Evaluator(args.ckpt_path)
visualizer = Visualizer(
save_dir=args.ckpt_path,
pad_value=80,
fps=1,
show_first_frame=0,
tracks_leave_trace=0,
)
loss_fn = None
if args.model_name == "cotracker":
model = CoTracker(
model = CoTracker2(
stride=args.model_stride,
S=args.sliding_window_len,
window_len=args.sliding_window_len,
add_space_attn=not args.remove_space_attn,
num_heads=args.updateformer_num_heads,
hidden_size=args.updateformer_hidden_size,
space_depth=args.updateformer_space_depth,
time_depth=args.updateformer_time_depth,
num_virtual_tracks=args.num_virtual_tracks,
model_resolution=args.crop_size,
)
else:
raise ValueError(f"Model {args.model_name} doesn't exist")
@ -332,7 +303,7 @@ class Lite(LightningLite):
model.cuda()
train_dataset = kubric_movif_dataset.KubricMovifDataset(
data_root=os.path.join(args.dataset_root, "kubric_movi_f"),
data_root=os.path.join(args.dataset_root, "kubric", "kubric_movi_f_tracks"),
crop_size=args.crop_size,
seq_len=args.sequence_len,
traj_per_sample=args.traj_per_sample,
@ -357,7 +328,8 @@ class Lite(LightningLite):
optimizer, scheduler = fetch_optimizer(args, model)
total_steps = 0
logger = Logger(model, scheduler)
if self.global_rank == 0:
logger = Logger(model, scheduler)
folder_ckpts = [
f
@ -383,9 +355,7 @@ class Lite(LightningLite):
logging.info(f"Load total_steps {total_steps}")
elif args.restore_ckpt is not None:
assert args.restore_ckpt.endswith(".pth") or args.restore_ckpt.endswith(
".pt"
)
assert args.restore_ckpt.endswith(".pth") or args.restore_ckpt.endswith(".pt")
logging.info("Loading checkpoint...")
strict = True
@ -394,9 +364,7 @@ class Lite(LightningLite):
state_dict = state_dict["model"]
if list(state_dict.keys())[0].startswith("module."):
state_dict = {
k.replace("module.", ""): v for k, v in state_dict.items()
}
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
model.load_state_dict(state_dict, strict=strict)
logging.info(f"Done loading checkpoint")
@ -424,33 +392,22 @@ class Lite(LightningLite):
assert model.training
output = forward_batch(
batch,
model,
args,
loss_fn=loss_fn,
writer=logger.writer,
step=total_steps,
)
output = forward_batch(batch, model, args)
loss = 0
for k, v in output.items():
if "loss" in v:
loss += v["loss"]
logger.writer.add_scalar(
f"live_{k}_loss", v["loss"].item(), total_steps
)
if "metrics" in v:
logger.push(v["metrics"], k)
if self.global_rank == 0:
if total_steps % save_freq == save_freq - 1:
if args.model_name == "motion_diffuser":
pred_coords = model.module.module.forward_batch_test(
batch, interp_shape=args.crop_size
for k, v in output.items():
if "loss" in v:
logger.writer.add_scalar(
f"live_{k}_loss", v["loss"].item(), total_steps
)
output["flow"] = {"predictions": pred_coords[0].detach()}
if "metrics" in v:
logger.push(v["metrics"], k)
if total_steps % save_freq == save_freq - 1:
visualizer.visualize(
video=batch.video.clone(),
tracks=batch.trajectory.clone(),
@ -468,9 +425,7 @@ class Lite(LightningLite):
)
if len(output) > 1:
logger.writer.add_scalar(
f"live_total_loss", loss.item(), total_steps
)
logger.writer.add_scalar(f"live_total_loss", loss.item(), total_steps)
logger.writer.add_scalar(
f"learning_rate", optimizer.param_groups[0]["lr"], total_steps
)
@ -492,9 +447,7 @@ class Lite(LightningLite):
total_steps == 1 and args.validate_at_start
):
if (epoch + 1) % args.save_every_n_epoch == 0:
ckpt_iter = "0" * (6 - len(str(total_steps))) + str(
total_steps
)
ckpt_iter = "0" * (6 - len(str(total_steps))) + str(total_steps)
save_path = Path(
f"{args.ckpt_path}/model_{args.model_name}_{ckpt_iter}.pth"
)
@ -526,16 +479,18 @@ class Lite(LightningLite):
if total_steps > args.num_steps:
should_keep_training = False
break
if self.global_rank == 0:
print("FINISHED TRAINING")
print("FINISHED TRAINING")
PATH = f"{args.ckpt_path}/{args.model_name}_final.pth"
torch.save(model.module.module.state_dict(), PATH)
run_test_eval(evaluator, model, eval_dataloaders, logger.writer, total_steps)
logger.close()
PATH = f"{args.ckpt_path}/{args.model_name}_final.pth"
torch.save(model.module.module.state_dict(), PATH)
run_test_eval(evaluator, model, eval_dataloaders, logger.writer, total_steps)
logger.close()
if __name__ == "__main__":
signal.signal(signal.SIGUSR1, sig_handler)
signal.signal(signal.SIGTERM, term_handler)
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", default="cotracker", help="model name")
parser.add_argument("--restore_ckpt", help="path to restore a checkpoint")
@ -543,17 +498,12 @@ if __name__ == "__main__":
parser.add_argument(
"--batch_size", type=int, default=4, help="batch size used during training."
)
parser.add_argument(
"--num_workers", type=int, default=6, help="number of dataloader workers"
)
parser.add_argument("--num_nodes", type=int, default=1)
parser.add_argument("--num_workers", type=int, default=10, help="number of dataloader workers")
parser.add_argument(
"--mixed_precision", action="store_true", help="use mixed precision"
)
parser.add_argument("--mixed_precision", action="store_true", help="use mixed precision")
parser.add_argument("--lr", type=float, default=0.0005, help="max learning rate.")
parser.add_argument(
"--wdecay", type=float, default=0.00001, help="Weight decay in optimizer."
)
parser.add_argument("--wdecay", type=float, default=0.00001, help="Weight decay in optimizer.")
parser.add_argument(
"--num_steps", type=int, default=200000, help="length of training schedule."
)
@ -596,13 +546,11 @@ if __name__ == "__main__":
default=4,
help="number of updates to the disparity field in each forward pass.",
)
parser.add_argument(
"--sequence_len", type=int, default=8, help="train sequence length"
)
parser.add_argument("--sequence_len", type=int, default=8, help="train sequence length")
parser.add_argument(
"--eval_datasets",
nargs="+",
default=["things", "badja"],
default=["tapvid_davis_first"],
help="what datasets to use for evaluation",
)
@ -611,6 +559,12 @@ if __name__ == "__main__":
action="store_true",
help="remove space attention from CoTracker",
)
parser.add_argument(
"--num_virtual_tracks",
type=int,
default=None,
help="stride of the CoTracker feature network",
)
parser.add_argument(
"--dont_use_augs",
action="store_true",
@ -627,30 +581,6 @@ if __name__ == "__main__":
default=8,
help="length of the CoTracker sliding window",
)
parser.add_argument(
"--updateformer_hidden_size",
type=int,
default=384,
help="hidden dimension of the CoTracker transformer model",
)
parser.add_argument(
"--updateformer_num_heads",
type=int,
default=8,
help="number of heads of the CoTracker transformer model",
)
parser.add_argument(
"--updateformer_space_depth",
type=int,
default=12,
help="number of group attention layers in the CoTracker transformer model",
)
parser.add_argument(
"--updateformer_time_depth",
type=int,
default=12,
help="number of time attention layers in the CoTracker transformer model",
)
parser.add_argument(
"--model_stride",
type=int,
@ -680,9 +610,9 @@ if __name__ == "__main__":
from pytorch_lightning.strategies import DDPStrategy
Lite(
strategy=DDPStrategy(find_unused_parameters=True),
strategy=DDPStrategy(find_unused_parameters=False),
devices="auto",
accelerator="gpu",
precision=32,
# num_nodes=4,
num_nodes=args.num_nodes,
).run(args)