release cotracker 2.0
This commit is contained in:
parent
3df96621ed
commit
f8fab323c4
257
README.md
257
README.md
@ -13,111 +13,218 @@
|
||||
<img alt="Spaces" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue">
|
||||
</a>
|
||||
|
||||
<img width="500" src="./assets/bmx-bumps.gif" />
|
||||
<img width="1100" src="./assets/teaser.png" />
|
||||
|
||||
**CoTracker** is a fast transformer-based model that can track any point in a video. It brings to tracking some of the benefits of Optical Flow.
|
||||
|
||||
|
||||
CoTracker can track:
|
||||
- **Every pixel** in a video
|
||||
- Points sampled on a regular grid on any video frame
|
||||
- Manually selected points
|
||||
|
||||
Try these tracking modes for yourself with our [Colab demo](https://colab.research.google.com/github/facebookresearch/co-tracker/blob/master/notebooks/demo.ipynb) or in the [Hugging Face Space](https://huggingface.co/spaces/facebook/cotracker).
|
||||
- **Any pixel** in a video
|
||||
- A **quasi-dense** set of pixels together
|
||||
- Points can be manually selected or sampled on a grid in any video frame
|
||||
|
||||
Try these tracking modes for yourself with our [Colab demo](https://colab.research.google.com/github/facebookresearch/co-tracker/blob/master/notebooks/demo.ipynb) or in the [Hugging Face Space 🤗](https://huggingface.co/spaces/facebook/cotracker).
|
||||
|
||||
**Updates:**
|
||||
|
||||
- [December 27, 2023] 📣 CoTracker2 is now available! It can now track many more (up to **265*265**!) points jointly and it has a cleaner and more memory-efficient implementation. It also supports online processing. See the [updated paper](https://arxiv.org/abs/2307.07635) for more details. The old version remains available [here](https://github.com/facebookresearch/co-tracker/tree/8d364031971f6b3efec945dd15c468a183e58212).
|
||||
|
||||
- [September 5, 2023] 📣 You can now run our Gradio demo [locally](./gradio_demo/app.py)!
|
||||
|
||||
## Quick start
|
||||
The easiest way to use CoTracker is to load a pretrained model from `torch.hub`:
|
||||
|
||||
### Offline mode:
|
||||
```pip install imageio[ffmpeg]```, then:
|
||||
```python
|
||||
import torch
|
||||
# Download the video
|
||||
url = 'https://github.com/facebookresearch/co-tracker/blob/main/assets/apple.mp4'
|
||||
|
||||
import imageio.v3 as iio
|
||||
frames = iio.imread(url, plugin="FFMPEG") # plugin="pyav"
|
||||
|
||||
device = 'cuda'
|
||||
grid_size = 10
|
||||
video = torch.tensor(frames).permute(0, 3, 1, 2)[None].float().to(device) # B T C H W
|
||||
|
||||
# Run Offline CoTracker:
|
||||
cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker2").to(device)
|
||||
pred_tracks, pred_visibility = cotracker(video, grid_size=grid_size) # B T N 2, B T N 1
|
||||
```
|
||||
### Online mode:
|
||||
```python
|
||||
cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker2_online").to(device)
|
||||
|
||||
# Run Online CoTracker, the same model with a different API:
|
||||
# Initialize online processing
|
||||
cotracker(video_chunk=video, is_first_step=True, grid_size=grid_size)
|
||||
|
||||
# Process the video
|
||||
for ind in range(0, video.shape[1] - cotracker.step, cotracker.step):
|
||||
pred_tracks, pred_visibility = cotracker(
|
||||
video_chunk=video[:, ind : ind + cotracker.step * 2]
|
||||
) # B T N 2, B T N 1
|
||||
```
|
||||
Online processing is more memory-efficient and allows for the processing of longer videos. However, in the example provided above, the video length is known! See [the online demo](./online_demo.py) for an example of tracking from an online stream with an unknown video length.
|
||||
|
||||
### Visualize predicted tracks:
|
||||
```pip install matplotlib```, then:
|
||||
```python
|
||||
from cotracker.utils.visualizer import Visualizer
|
||||
|
||||
vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
|
||||
vis.visualize(video, pred_tracks, pred_visibility)
|
||||
```
|
||||
|
||||
We offer a number of other ways to interact with CoTracker:
|
||||
|
||||
1. Interactive Gradio demo:
|
||||
- A demo is available in the [`facebook/cotracker` Hugging Face Space 🤗](https://huggingface.co/spaces/facebook/cotracker).
|
||||
- You can use the gradio demo locally by running [`python -m gradio_demo.app`](./gradio_demo/app.py) after installing the required packages: `pip install -r gradio_demo/requirements.txt`.
|
||||
2. Jupyter notebook:
|
||||
- You can run the notebook in
|
||||
[Google Colab](https://colab.research.google.com/github/facebookresearch/co-tracker/blob/master/notebooks/demo.ipynb).
|
||||
- Or explore the notebook located at [`notebooks/demo.ipynb`](./notebooks/demo.ipynb).
|
||||
2. You can [install](#installation-instructions) CoTracker _locally_ and then:
|
||||
- Run an *offline* demo with 10 ⨉ 10 points sampled on a grid on the first frame of a video (results will be saved to `./saved_videos/demo.mp4`)):
|
||||
|
||||
```bash
|
||||
python demo.py --grid_size 10
|
||||
```
|
||||
- Run an *online* demo:
|
||||
|
||||
```bash
|
||||
python online_demo.py
|
||||
```
|
||||
|
||||
A GPU is strongly recommended for using CoTracker locally.
|
||||
|
||||
<img width="500" src="./assets/bmx-bumps.gif" />
|
||||
|
||||
### Update: September 5, 2023
|
||||
📣 You can now run our Gradio demo [locally](./gradio_demo/app.py)!
|
||||
|
||||
## Installation Instructions
|
||||
Ensure you have both PyTorch and TorchVision installed on your system. Follow the instructions [here](https://pytorch.org/get-started/locally/) for the installation. We strongly recommend installing both PyTorch and TorchVision with CUDA support.
|
||||
You can use a Pretrained Model via PyTorch Hub, as described above, or install CoTracker from this GitHub repo.
|
||||
This is the best way if you need to run our local demo or evaluate/train CoTracker.
|
||||
|
||||
### Pretrained models via PyTorch Hub
|
||||
The easiest way to use CoTracker is to load a pretrained model from torch.hub:
|
||||
```
|
||||
pip install einops timm tqdm
|
||||
```
|
||||
```
|
||||
import torch
|
||||
import timm
|
||||
import einops
|
||||
import tqdm
|
||||
Ensure you have both _PyTorch_ and _TorchVision_ installed on your system. Follow the instructions [here](https://pytorch.org/get-started/locally/) for the installation.
|
||||
We strongly recommend installing both PyTorch and TorchVision with CUDA support, although for small tasks CoTracker can be run on CPU.
|
||||
|
||||
cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker_w8")
|
||||
```
|
||||
Another option is to install it from this gihub repo. That's the best way if you need to run our demo or evaluate / train CoTracker:
|
||||
### Steps to Install CoTracker and its dependencies:
|
||||
```
|
||||
|
||||
|
||||
|
||||
### Install a Development Version
|
||||
|
||||
```bash
|
||||
git clone https://github.com/facebookresearch/co-tracker
|
||||
cd co-tracker
|
||||
pip install -e .
|
||||
pip install opencv-python einops timm matplotlib moviepy flow_vis
|
||||
pip install matplotlib flow_vis tqdm tensorboard
|
||||
```
|
||||
|
||||
You can manually download the CoTracker2 checkpoint from the links below and place it in the `checkpoints` folder as follows:
|
||||
|
||||
### Download Model Weights:
|
||||
```
|
||||
mkdir checkpoints
|
||||
```bash
|
||||
mkdir -p checkpoints
|
||||
cd checkpoints
|
||||
wget https://huggingface.co/facebook/cotracker/resolve/main/cotracker2.pth
|
||||
cd ..
|
||||
```
|
||||
For old checkpoints, see [this section](#previous-version).
|
||||
|
||||
## Evaluation
|
||||
|
||||
To reproduce the results presented in the paper, download the following datasets:
|
||||
|
||||
- [TAP-Vid](https://github.com/deepmind/tapnet)
|
||||
- [Dynamic Replica](https://dynamic-stereo.github.io/)
|
||||
|
||||
And install the necessary dependencies:
|
||||
|
||||
```bash
|
||||
pip install hydra-core==1.1.0 mediapy
|
||||
```
|
||||
|
||||
Then, execute the following command to evaluate on TAP-Vid DAVIS:
|
||||
|
||||
```bash
|
||||
python ./cotracker/evaluation/evaluate.py --config-name eval_tapvid_davis_first exp_dir=./eval_outputs dataset_root=your/tapvid/path
|
||||
```
|
||||
|
||||
By default, evaluation will be slow since it is done for one target point at a time, which ensures robustness and fairness, as described in the paper.
|
||||
|
||||
We have fixed some bugs and retrained the model after updating the paper. These are the numbers that you should be able to reproduce using the released checkpoint and the current version of the codebase:
|
||||
| | DAVIS First, AJ | DAVIS First, $\delta_\text{avg}^\text{vis}$ | DAVIS First, OA | DAVIS Strided, AJ | DAVIS Strided, $\delta_\text{avg}^\text{vis}$ | DAVIS Strided, OA | DR, $\delta_\text{avg}$| DR, $\delta_\text{avg}^\text{vis}$| DR, $\delta_\text{avg}^\text{occ}$|
|
||||
| :---: |:---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
|
||||
| CoTracker2, 27.12.23 | 60.9 | 75.4 | 88.4 | 65.1 | 79.0 | 89.4 | 61.4 | 68.4 | 38.2
|
||||
|
||||
|
||||
## Training
|
||||
|
||||
To train the CoTracker as described in our paper, you first need to generate annotations for [Google Kubric](https://github.com/google-research/kubric) MOVI-f dataset.
|
||||
Instructions for annotation generation can be found [here](https://github.com/deepmind/tapnet).
|
||||
You can also find a discussion on dataset generation in [this issue](https://github.com/facebookresearch/co-tracker/issues/8).
|
||||
|
||||
Once you have the annotated dataset, you need to make sure you followed the steps for evaluation setup and install the training dependencies:
|
||||
|
||||
```bash
|
||||
pip install pytorch_lightning==1.6.0 tensorboard
|
||||
```
|
||||
|
||||
Now you can launch training on Kubric.
|
||||
Our model was trained for 50000 iterations on 32 GPUs (4 nodes with 8 GPUs).
|
||||
Modify _dataset_root_ and _ckpt_path_ accordingly before running this command. For training on 4 nodes, add `--num_nodes 4`.
|
||||
|
||||
```bash
|
||||
python train.py --batch_size 1 \
|
||||
--num_steps 50000 --ckpt_path ./ --dataset_root ./datasets --model_name cotracker \
|
||||
--save_freq 200 --sequence_len 24 --eval_datasets dynamic_replica tapvid_davis_first \
|
||||
--traj_per_sample 768 --sliding_window_len 8 \
|
||||
--num_virtual_tracks 64 --model_stride 4
|
||||
```
|
||||
|
||||
|
||||
## Development
|
||||
|
||||
### Building the documentation
|
||||
|
||||
To build CoTracker documentation, first install the dependencies:
|
||||
|
||||
```bash
|
||||
pip install sphinx
|
||||
pip install sphinxcontrib-bibtex
|
||||
```
|
||||
|
||||
Then you can use this command to generate the documentation in the `docs/_build/html` folder:
|
||||
|
||||
```bash
|
||||
make -C docs html
|
||||
```
|
||||
|
||||
|
||||
## Previous version
|
||||
The old version of the code is available [here](https://github.com/facebookresearch/co-tracker/tree/8d364031971f6b3efec945dd15c468a183e58212).
|
||||
You can also download the corresponding checkpoints:
|
||||
```bash
|
||||
wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth
|
||||
wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_12.pth
|
||||
wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_8_wind_16.pth
|
||||
cd ..
|
||||
```
|
||||
|
||||
|
||||
## Usage:
|
||||
We offer a number of ways to interact with CoTracker:
|
||||
1. A demo is available in the [`facebook/cotracker` Hugging Face Space](https://huggingface.co/spaces/facebook/cotracker).
|
||||
2. You can run the extended demo in Colab:
|
||||
[Colab notebook](https://colab.research.google.com/github/facebookresearch/co-tracker/blob/master/notebooks/demo.ipynb)
|
||||
3. You can use the gradio demo locally by running [`python -m gradio_demo.app`](./gradio_demo/app.py) after installing the required packages: ```pip install -r gradio_demo/requirements.txt```.
|
||||
4. You can play with CoTracker by running the Jupyter notebook located at [`notebooks/demo.ipynb`](./notebooks/demo.ipynb) locally (if you have a GPU).
|
||||
5. Finally, you can run a local demo with 10*10 points sampled on a grid on the first frame of a video:
|
||||
```
|
||||
python demo.py --grid_size 10
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
To reproduce the results presented in the paper, download the following datasets:
|
||||
- [TAP-Vid](https://github.com/deepmind/tapnet)
|
||||
- [BADJA](https://github.com/benjiebob/BADJA)
|
||||
- [ZJU-Mocap (FastCapture)](https://arxiv.org/abs/2303.11898)
|
||||
|
||||
And install the necessary dependencies:
|
||||
```
|
||||
pip install hydra-core==1.1.0 mediapy
|
||||
```
|
||||
Then, execute the following command to evaluate on BADJA:
|
||||
```
|
||||
python ./cotracker/evaluation/evaluate.py --config-name eval_badja exp_dir=./eval_outputs dataset_root=your/badja/path
|
||||
```
|
||||
By default, evaluation will be slow since it is done for one target point at a time, which ensures robustness and fairness, as described in the paper.
|
||||
|
||||
## Training
|
||||
To train the CoTracker as described in our paper, you first need to generate annotations for [Google Kubric](https://github.com/google-research/kubric) MOVI-f dataset. Instructions for annotation generation can be found [here](https://github.com/deepmind/tapnet).
|
||||
|
||||
Once you have the annotated dataset, you need to make sure you followed the steps for evaluation setup and install the training dependencies:
|
||||
```
|
||||
pip install pytorch_lightning==1.6.0 tensorboard
|
||||
```
|
||||
Now you can launch training on Kubric. Our model was trained for 50000 iterations on 32 GPUs (4 nodes with 8 GPUs).
|
||||
Modify *dataset_root* and *ckpt_path* accordingly before running this command:
|
||||
```
|
||||
python train.py --batch_size 1 --num_workers 28 \
|
||||
--num_steps 50000 --ckpt_path ./ --dataset_root ./datasets --model_name cotracker \
|
||||
--save_freq 200 --sequence_len 24 --eval_datasets tapvid_davis_first badja \
|
||||
--traj_per_sample 256 --sliding_window_len 8 --updateformer_space_depth 6 --updateformer_time_depth 6 \
|
||||
--save_every_n_epoch 10 --evaluate_every_n_epoch 10 --model_stride 4
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
The majority of CoTracker is licensed under CC-BY-NC, however portions of the project are available under separate license terms: Particle Video Revisited is licensed under the MIT license, TAP-Vid is licensed under the Apache 2.0 license.
|
||||
|
||||
## Acknowledgments
|
||||
|
||||
We would like to thank [PIPs](https://github.com/aharley/pips) and [TAP-Vid](https://github.com/deepmind/tapnet) for publicly releasing their code and data. We also want to thank [Luke Melas-Kyriazi](https://lukemelas.github.io/) for proofreading the paper, [Jianyuan Wang](https://jytime.github.io/), [Roman Shapovalov](https://shapovalov.ro/) and [Adam W. Harley](https://adamharley.com/) for the insightful discussions.
|
||||
|
||||
## Citing CoTracker
|
||||
|
||||
If you find our repository useful, please consider giving it a star ⭐ and citing our paper in your work:
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{karaev2023cotracker,
|
||||
title={CoTracker: It is Better to Track Together},
|
||||
author={Nikita Karaev and Ignacio Rocco and Benjamin Graham and Natalia Neverova and Andrea Vedaldi and Christian Rupprecht},
|
||||
|
BIN
assets/apple.mp4
BIN
assets/apple.mp4
Binary file not shown.
BIN
assets/teaser.png
Normal file
BIN
assets/teaser.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 7.0 MiB |
@ -1,390 +0,0 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
import json
|
||||
import imageio
|
||||
import cv2
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from cotracker.datasets.utils import CoTrackerData, resize_sample
|
||||
|
||||
IGNORE_ANIMALS = [
|
||||
# "bear.json",
|
||||
# "camel.json",
|
||||
"cat_jump.json"
|
||||
# "cows.json",
|
||||
# "dog.json",
|
||||
# "dog-agility.json",
|
||||
# "horsejump-high.json",
|
||||
# "horsejump-low.json",
|
||||
# "impala0.json",
|
||||
# "rs_dog.json"
|
||||
"tiger.json"
|
||||
]
|
||||
|
||||
|
||||
class SMALJointCatalog(Enum):
|
||||
# body_0 = 0
|
||||
# body_1 = 1
|
||||
# body_2 = 2
|
||||
# body_3 = 3
|
||||
# body_4 = 4
|
||||
# body_5 = 5
|
||||
# body_6 = 6
|
||||
# upper_right_0 = 7
|
||||
upper_right_1 = 8
|
||||
upper_right_2 = 9
|
||||
upper_right_3 = 10
|
||||
# upper_left_0 = 11
|
||||
upper_left_1 = 12
|
||||
upper_left_2 = 13
|
||||
upper_left_3 = 14
|
||||
neck_lower = 15
|
||||
# neck_upper = 16
|
||||
# lower_right_0 = 17
|
||||
lower_right_1 = 18
|
||||
lower_right_2 = 19
|
||||
lower_right_3 = 20
|
||||
# lower_left_0 = 21
|
||||
lower_left_1 = 22
|
||||
lower_left_2 = 23
|
||||
lower_left_3 = 24
|
||||
tail_0 = 25
|
||||
# tail_1 = 26
|
||||
# tail_2 = 27
|
||||
tail_3 = 28
|
||||
# tail_4 = 29
|
||||
# tail_5 = 30
|
||||
tail_6 = 31
|
||||
jaw = 32
|
||||
nose = 33 # ADDED JOINT FOR VERTEX 1863
|
||||
# chin = 34 # ADDED JOINT FOR VERTEX 26
|
||||
right_ear = 35 # ADDED JOINT FOR VERTEX 149
|
||||
left_ear = 36 # ADDED JOINT FOR VERTEX 2124
|
||||
|
||||
|
||||
class SMALJointInfo:
|
||||
def __init__(self):
|
||||
# These are the
|
||||
self.annotated_classes = np.array(
|
||||
[
|
||||
8,
|
||||
9,
|
||||
10, # upper_right
|
||||
12,
|
||||
13,
|
||||
14, # upper_left
|
||||
15, # neck
|
||||
18,
|
||||
19,
|
||||
20, # lower_right
|
||||
22,
|
||||
23,
|
||||
24, # lower_left
|
||||
25,
|
||||
28,
|
||||
31, # tail
|
||||
32,
|
||||
33, # head
|
||||
35, # right_ear
|
||||
36,
|
||||
]
|
||||
) # left_ear
|
||||
|
||||
self.annotated_markers = np.array(
|
||||
[
|
||||
cv2.MARKER_CROSS,
|
||||
cv2.MARKER_STAR,
|
||||
cv2.MARKER_TRIANGLE_DOWN,
|
||||
cv2.MARKER_CROSS,
|
||||
cv2.MARKER_STAR,
|
||||
cv2.MARKER_TRIANGLE_DOWN,
|
||||
cv2.MARKER_CROSS,
|
||||
cv2.MARKER_CROSS,
|
||||
cv2.MARKER_STAR,
|
||||
cv2.MARKER_TRIANGLE_DOWN,
|
||||
cv2.MARKER_CROSS,
|
||||
cv2.MARKER_STAR,
|
||||
cv2.MARKER_TRIANGLE_DOWN,
|
||||
cv2.MARKER_CROSS,
|
||||
cv2.MARKER_STAR,
|
||||
cv2.MARKER_TRIANGLE_DOWN,
|
||||
cv2.MARKER_CROSS,
|
||||
cv2.MARKER_STAR,
|
||||
cv2.MARKER_CROSS,
|
||||
cv2.MARKER_CROSS,
|
||||
]
|
||||
)
|
||||
|
||||
self.joint_regions = np.array(
|
||||
[
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
3,
|
||||
3,
|
||||
4,
|
||||
4,
|
||||
4,
|
||||
4,
|
||||
5,
|
||||
5,
|
||||
5,
|
||||
5,
|
||||
6,
|
||||
6,
|
||||
6,
|
||||
6,
|
||||
6,
|
||||
6,
|
||||
6,
|
||||
7,
|
||||
7,
|
||||
7,
|
||||
8,
|
||||
9,
|
||||
]
|
||||
)
|
||||
|
||||
self.annotated_joint_region = self.joint_regions[self.annotated_classes]
|
||||
self.region_colors = np.array(
|
||||
[
|
||||
[250, 190, 190], # body, light pink
|
||||
[60, 180, 75], # upper_right, green
|
||||
[230, 25, 75], # upper_left, red
|
||||
[128, 0, 0], # neck, maroon
|
||||
[0, 130, 200], # lower_right, blue
|
||||
[255, 255, 25], # lower_left, yellow
|
||||
[240, 50, 230], # tail, majenta
|
||||
[245, 130, 48], # jaw / nose / chin, orange
|
||||
[29, 98, 115], # right_ear, turquoise
|
||||
[255, 153, 204],
|
||||
]
|
||||
) # left_ear, pink
|
||||
|
||||
self.joint_colors = np.array(self.region_colors)[self.annotated_joint_region]
|
||||
|
||||
|
||||
class BADJAData:
|
||||
def __init__(self, data_root, complete=False):
|
||||
annotations_path = os.path.join(data_root, "joint_annotations")
|
||||
|
||||
self.animal_dict = {}
|
||||
self.animal_count = 0
|
||||
self.smal_joint_info = SMALJointInfo()
|
||||
for __, animal_json in enumerate(sorted(os.listdir(annotations_path))):
|
||||
if animal_json not in IGNORE_ANIMALS:
|
||||
json_path = os.path.join(annotations_path, animal_json)
|
||||
with open(json_path) as json_data:
|
||||
animal_joint_data = json.load(json_data)
|
||||
|
||||
filenames = []
|
||||
segnames = []
|
||||
joints = []
|
||||
visible = []
|
||||
|
||||
first_path = animal_joint_data[0]["segmentation_path"]
|
||||
last_path = animal_joint_data[-1]["segmentation_path"]
|
||||
first_frame = first_path.split("/")[-1]
|
||||
last_frame = last_path.split("/")[-1]
|
||||
|
||||
if not "extra_videos" in first_path:
|
||||
animal = first_path.split("/")[-2]
|
||||
|
||||
first_frame_int = int(first_frame.split(".")[0])
|
||||
last_frame_int = int(last_frame.split(".")[0])
|
||||
|
||||
for fr in range(first_frame_int, last_frame_int + 1):
|
||||
ref_file_name = os.path.join(
|
||||
data_root,
|
||||
"DAVIS/JPEGImages/Full-Resolution/%s/%05d.jpg"
|
||||
% (animal, fr),
|
||||
)
|
||||
ref_seg_name = os.path.join(
|
||||
data_root,
|
||||
"DAVIS/Annotations/Full-Resolution/%s/%05d.png"
|
||||
% (animal, fr),
|
||||
)
|
||||
|
||||
foundit = False
|
||||
for ind, image_annotation in enumerate(animal_joint_data):
|
||||
file_name = os.path.join(
|
||||
data_root, image_annotation["image_path"]
|
||||
)
|
||||
seg_name = os.path.join(
|
||||
data_root, image_annotation["segmentation_path"]
|
||||
)
|
||||
|
||||
if file_name == ref_file_name:
|
||||
foundit = True
|
||||
label_ind = ind
|
||||
|
||||
if foundit:
|
||||
image_annotation = animal_joint_data[label_ind]
|
||||
file_name = os.path.join(
|
||||
data_root, image_annotation["image_path"]
|
||||
)
|
||||
seg_name = os.path.join(
|
||||
data_root, image_annotation["segmentation_path"]
|
||||
)
|
||||
joint = np.array(image_annotation["joints"])
|
||||
vis = np.array(image_annotation["visibility"])
|
||||
else:
|
||||
file_name = ref_file_name
|
||||
seg_name = ref_seg_name
|
||||
joint = None
|
||||
vis = None
|
||||
|
||||
filenames.append(file_name)
|
||||
segnames.append(seg_name)
|
||||
joints.append(joint)
|
||||
visible.append(vis)
|
||||
|
||||
if len(filenames):
|
||||
self.animal_dict[self.animal_count] = (
|
||||
filenames,
|
||||
segnames,
|
||||
joints,
|
||||
visible,
|
||||
)
|
||||
self.animal_count += 1
|
||||
print("Loaded BADJA dataset")
|
||||
|
||||
def get_loader(self):
|
||||
for __ in range(int(1e6)):
|
||||
animal_id = np.random.choice(len(self.animal_dict.keys()))
|
||||
filenames, segnames, joints, visible = self.animal_dict[animal_id]
|
||||
|
||||
image_id = np.random.randint(0, len(filenames))
|
||||
|
||||
seg_file = segnames[image_id]
|
||||
image_file = filenames[image_id]
|
||||
|
||||
joints = joints[image_id].copy()
|
||||
joints = joints[self.smal_joint_info.annotated_classes]
|
||||
visible = visible[image_id][self.smal_joint_info.annotated_classes]
|
||||
|
||||
rgb_img = imageio.imread(image_file) # , mode='RGB')
|
||||
sil_img = imageio.imread(seg_file) # , mode='RGB')
|
||||
|
||||
rgb_h, rgb_w, _ = rgb_img.shape
|
||||
sil_img = cv2.resize(sil_img, (rgb_w, rgb_h), cv2.INTER_NEAREST)
|
||||
|
||||
yield rgb_img, sil_img, joints, visible, image_file
|
||||
|
||||
def get_video(self, animal_id):
|
||||
filenames, segnames, joint, visible = self.animal_dict[animal_id]
|
||||
|
||||
rgbs = []
|
||||
segs = []
|
||||
joints = []
|
||||
visibles = []
|
||||
|
||||
for s in range(len(filenames)):
|
||||
image_file = filenames[s]
|
||||
rgb_img = imageio.imread(image_file) # , mode='RGB')
|
||||
rgb_h, rgb_w, _ = rgb_img.shape
|
||||
|
||||
seg_file = segnames[s]
|
||||
sil_img = imageio.imread(seg_file) # , mode='RGB')
|
||||
sil_img = cv2.resize(sil_img, (rgb_w, rgb_h), cv2.INTER_NEAREST)
|
||||
|
||||
jo = joint[s]
|
||||
|
||||
if jo is not None:
|
||||
joi = joint[s].copy()
|
||||
joi = joi[self.smal_joint_info.annotated_classes]
|
||||
vis = visible[s][self.smal_joint_info.annotated_classes]
|
||||
else:
|
||||
joi = None
|
||||
vis = None
|
||||
|
||||
rgbs.append(rgb_img)
|
||||
segs.append(sil_img)
|
||||
joints.append(joi)
|
||||
visibles.append(vis)
|
||||
|
||||
return rgbs, segs, joints, visibles, filenames[0]
|
||||
|
||||
|
||||
class BadjaDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self, data_root, max_seq_len=1000, dataset_resolution=(384, 512)
|
||||
):
|
||||
|
||||
self.data_root = data_root
|
||||
self.badja_data = BADJAData(data_root)
|
||||
self.max_seq_len = max_seq_len
|
||||
self.dataset_resolution = dataset_resolution
|
||||
print(
|
||||
"found %d unique videos in %s"
|
||||
% (self.badja_data.animal_count, self.data_root)
|
||||
)
|
||||
|
||||
def __getitem__(self, index):
|
||||
|
||||
rgbs, segs, joints, visibles, filename = self.badja_data.get_video(index)
|
||||
S = len(rgbs)
|
||||
H, W, __ = rgbs[0].shape
|
||||
H, W, __ = segs[0].shape
|
||||
|
||||
N, __ = joints[0].shape
|
||||
|
||||
# let's eliminate the Nones
|
||||
# note the first one is guaranteed present
|
||||
for s in range(1, S):
|
||||
if joints[s] is None:
|
||||
joints[s] = np.zeros_like(joints[0])
|
||||
visibles[s] = np.zeros_like(visibles[0])
|
||||
|
||||
# eliminate the mystery dim
|
||||
segs = [seg[:, :, 0] for seg in segs]
|
||||
|
||||
rgbs = np.stack(rgbs, 0)
|
||||
segs = np.stack(segs, 0)
|
||||
trajs = np.stack(joints, 0)
|
||||
visibles = np.stack(visibles, 0)
|
||||
|
||||
rgbs = torch.from_numpy(rgbs).reshape(S, H, W, 3).permute(0, 3, 1, 2).float()
|
||||
segs = torch.from_numpy(segs).reshape(S, 1, H, W).float()
|
||||
trajs = torch.from_numpy(trajs).reshape(S, N, 2).float()
|
||||
visibles = torch.from_numpy(visibles).reshape(S, N)
|
||||
|
||||
rgbs = rgbs[: self.max_seq_len]
|
||||
segs = segs[: self.max_seq_len]
|
||||
trajs = trajs[: self.max_seq_len]
|
||||
visibles = visibles[: self.max_seq_len]
|
||||
# apparently the coords are in yx order
|
||||
trajs = torch.flip(trajs, [2])
|
||||
|
||||
if "extra_videos" in filename:
|
||||
seq_name = filename.split("/")[-3]
|
||||
else:
|
||||
seq_name = filename.split("/")[-2]
|
||||
|
||||
rgbs, trajs, segs = resize_sample(rgbs, trajs, segs, self.dataset_resolution)
|
||||
|
||||
return CoTrackerData(rgbs, segs, trajs, visibles, seq_name=seq_name)
|
||||
|
||||
def __len__(self):
|
||||
return self.badja_data.animal_count
|
166
cotracker/datasets/dataclass_utils.py
Normal file
166
cotracker/datasets/dataclass_utils.py
Normal file
@ -0,0 +1,166 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import json
|
||||
import dataclasses
|
||||
import numpy as np
|
||||
from dataclasses import Field, MISSING
|
||||
from typing import IO, TypeVar, Type, get_args, get_origin, Union, Any, Tuple
|
||||
|
||||
_X = TypeVar("_X")
|
||||
|
||||
|
||||
def load_dataclass(f: IO, cls: Type[_X], binary: bool = False) -> _X:
|
||||
"""
|
||||
Loads to a @dataclass or collection hierarchy including dataclasses
|
||||
from a json recursively.
|
||||
Call it like load_dataclass(f, typing.List[FrameAnnotationAnnotation]).
|
||||
raises KeyError if json has keys not mapping to the dataclass fields.
|
||||
|
||||
Args:
|
||||
f: Either a path to a file, or a file opened for writing.
|
||||
cls: The class of the loaded dataclass.
|
||||
binary: Set to True if `f` is a file handle, else False.
|
||||
"""
|
||||
if binary:
|
||||
asdict = json.loads(f.read().decode("utf8"))
|
||||
else:
|
||||
asdict = json.load(f)
|
||||
|
||||
# in the list case, run a faster "vectorized" version
|
||||
cls = get_args(cls)[0]
|
||||
res = list(_dataclass_list_from_dict_list(asdict, cls))
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
|
||||
"""Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
|
||||
if get_origin(type_) is Union:
|
||||
args = get_args(type_)
|
||||
if len(args) == 2 and args[1] == type(None): # noqa E721
|
||||
return True, args[0]
|
||||
if type_ is Any:
|
||||
return True, Any
|
||||
|
||||
return False, type_
|
||||
|
||||
|
||||
def _unwrap_type(tp):
|
||||
# strips Optional wrapper, if any
|
||||
if get_origin(tp) is Union:
|
||||
args = get_args(tp)
|
||||
if len(args) == 2 and any(a is type(None) for a in args): # noqa: E721
|
||||
# this is typing.Optional
|
||||
return args[0] if args[1] is type(None) else args[1] # noqa: E721
|
||||
return tp
|
||||
|
||||
|
||||
def _get_dataclass_field_default(field: Field) -> Any:
|
||||
if field.default_factory is not MISSING:
|
||||
# pyre-fixme[29]: `Union[dataclasses._MISSING_TYPE,
|
||||
# dataclasses._DefaultFactory[typing.Any]]` is not a function.
|
||||
return field.default_factory()
|
||||
elif field.default is not MISSING:
|
||||
return field.default
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def _dataclass_list_from_dict_list(dlist, typeannot):
|
||||
"""
|
||||
Vectorised version of `_dataclass_from_dict`.
|
||||
The output should be equivalent to
|
||||
`[_dataclass_from_dict(d, typeannot) for d in dlist]`.
|
||||
|
||||
Args:
|
||||
dlist: list of objects to convert.
|
||||
typeannot: type of each of those objects.
|
||||
Returns:
|
||||
iterator or list over converted objects of the same length as `dlist`.
|
||||
|
||||
Raises:
|
||||
ValueError: it assumes the objects have None's in consistent places across
|
||||
objects, otherwise it would ignore some values. This generally holds for
|
||||
auto-generated annotations, but otherwise use `_dataclass_from_dict`.
|
||||
"""
|
||||
|
||||
cls = get_origin(typeannot) or typeannot
|
||||
|
||||
if typeannot is Any:
|
||||
return dlist
|
||||
if all(obj is None for obj in dlist): # 1st recursion base: all None nodes
|
||||
return dlist
|
||||
if any(obj is None for obj in dlist):
|
||||
# filter out Nones and recurse on the resulting list
|
||||
idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None]
|
||||
idx, notnone = zip(*idx_notnone)
|
||||
converted = _dataclass_list_from_dict_list(notnone, typeannot)
|
||||
res = [None] * len(dlist)
|
||||
for i, obj in zip(idx, converted):
|
||||
res[i] = obj
|
||||
return res
|
||||
|
||||
is_optional, contained_type = _resolve_optional(typeannot)
|
||||
if is_optional:
|
||||
return _dataclass_list_from_dict_list(dlist, contained_type)
|
||||
|
||||
# otherwise, we dispatch by the type of the provided annotation to convert to
|
||||
if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
|
||||
# For namedtuple, call the function recursively on the lists of corresponding keys
|
||||
types = cls.__annotations__.values()
|
||||
dlist_T = zip(*dlist)
|
||||
res_T = [
|
||||
_dataclass_list_from_dict_list(key_list, tp) for key_list, tp in zip(dlist_T, types)
|
||||
]
|
||||
return [cls(*converted_as_tuple) for converted_as_tuple in zip(*res_T)]
|
||||
elif issubclass(cls, (list, tuple)):
|
||||
# For list/tuple, call the function recursively on the lists of corresponding positions
|
||||
types = get_args(typeannot)
|
||||
if len(types) == 1: # probably List; replicate for all items
|
||||
types = types * len(dlist[0])
|
||||
dlist_T = zip(*dlist)
|
||||
res_T = (
|
||||
_dataclass_list_from_dict_list(pos_list, tp) for pos_list, tp in zip(dlist_T, types)
|
||||
)
|
||||
if issubclass(cls, tuple):
|
||||
return list(zip(*res_T))
|
||||
else:
|
||||
return [cls(converted_as_tuple) for converted_as_tuple in zip(*res_T)]
|
||||
elif issubclass(cls, dict):
|
||||
# For the dictionary, call the function recursively on concatenated keys and vertices
|
||||
key_t, val_t = get_args(typeannot)
|
||||
all_keys_res = _dataclass_list_from_dict_list(
|
||||
[k for obj in dlist for k in obj.keys()], key_t
|
||||
)
|
||||
all_vals_res = _dataclass_list_from_dict_list(
|
||||
[k for obj in dlist for k in obj.values()], val_t
|
||||
)
|
||||
indices = np.cumsum([len(obj) for obj in dlist])
|
||||
assert indices[-1] == len(all_keys_res)
|
||||
|
||||
keys = np.split(list(all_keys_res), indices[:-1])
|
||||
all_vals_res_iter = iter(all_vals_res)
|
||||
return [cls(zip(k, all_vals_res_iter)) for k in keys]
|
||||
elif not dataclasses.is_dataclass(typeannot):
|
||||
return dlist
|
||||
|
||||
# dataclass node: 2nd recursion base; call the function recursively on the lists
|
||||
# of the corresponding fields
|
||||
assert dataclasses.is_dataclass(cls)
|
||||
fieldtypes = {
|
||||
f.name: (_unwrap_type(f.type), _get_dataclass_field_default(f))
|
||||
for f in dataclasses.fields(typeannot)
|
||||
}
|
||||
|
||||
# NOTE the default object is shared here
|
||||
key_lists = (
|
||||
_dataclass_list_from_dict_list([obj.get(k, default) for obj in dlist], type_)
|
||||
for k, (type_, default) in fieldtypes.items()
|
||||
)
|
||||
transposed = zip(*key_lists)
|
||||
return [cls(*vals_as_tuple) for vals_as_tuple in transposed]
|
161
cotracker/datasets/dr_dataset.py
Normal file
161
cotracker/datasets/dr_dataset.py
Normal file
@ -0,0 +1,161 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import os
|
||||
import gzip
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.utils.data as data
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Any, Dict, Tuple
|
||||
|
||||
from cotracker.datasets.utils import CoTrackerData
|
||||
from cotracker.datasets.dataclass_utils import load_dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageAnnotation:
|
||||
# path to jpg file, relative w.r.t. dataset_root
|
||||
path: str
|
||||
# H x W
|
||||
size: Tuple[int, int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DynamicReplicaFrameAnnotation:
|
||||
"""A dataclass used to load annotations from json."""
|
||||
|
||||
# can be used to join with `SequenceAnnotation`
|
||||
sequence_name: str
|
||||
# 0-based, continuous frame number within sequence
|
||||
frame_number: int
|
||||
# timestamp in seconds from the video start
|
||||
frame_timestamp: float
|
||||
|
||||
image: ImageAnnotation
|
||||
meta: Optional[Dict[str, Any]] = None
|
||||
|
||||
camera_name: Optional[str] = None
|
||||
trajectories: Optional[str] = None
|
||||
|
||||
|
||||
class DynamicReplicaDataset(data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
root,
|
||||
split="valid",
|
||||
traj_per_sample=256,
|
||||
crop_size=None,
|
||||
sample_len=-1,
|
||||
only_first_n_samples=-1,
|
||||
rgbd_input=False,
|
||||
):
|
||||
super(DynamicReplicaDataset, self).__init__()
|
||||
self.root = root
|
||||
self.sample_len = sample_len
|
||||
self.split = split
|
||||
self.traj_per_sample = traj_per_sample
|
||||
self.rgbd_input = rgbd_input
|
||||
self.crop_size = crop_size
|
||||
frame_annotations_file = f"frame_annotations_{split}.jgz"
|
||||
self.sample_list = []
|
||||
with gzip.open(
|
||||
os.path.join(root, split, frame_annotations_file), "rt", encoding="utf8"
|
||||
) as zipfile:
|
||||
frame_annots_list = load_dataclass(zipfile, List[DynamicReplicaFrameAnnotation])
|
||||
seq_annot = defaultdict(list)
|
||||
for frame_annot in frame_annots_list:
|
||||
if frame_annot.camera_name == "left":
|
||||
seq_annot[frame_annot.sequence_name].append(frame_annot)
|
||||
|
||||
for seq_name in seq_annot.keys():
|
||||
seq_len = len(seq_annot[seq_name])
|
||||
|
||||
step = self.sample_len if self.sample_len > 0 else seq_len
|
||||
counter = 0
|
||||
|
||||
for ref_idx in range(0, seq_len, step):
|
||||
sample = seq_annot[seq_name][ref_idx : ref_idx + step]
|
||||
self.sample_list.append(sample)
|
||||
counter += 1
|
||||
if only_first_n_samples > 0 and counter >= only_first_n_samples:
|
||||
break
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sample_list)
|
||||
|
||||
def crop(self, rgbs, trajs):
|
||||
T, N, _ = trajs.shape
|
||||
|
||||
S = len(rgbs)
|
||||
H, W = rgbs[0].shape[:2]
|
||||
assert S == T
|
||||
|
||||
H_new = H
|
||||
W_new = W
|
||||
|
||||
# simple random crop
|
||||
y0 = 0 if self.crop_size[0] >= H_new else (H_new - self.crop_size[0]) // 2
|
||||
x0 = 0 if self.crop_size[1] >= W_new else (W_new - self.crop_size[1]) // 2
|
||||
rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]
|
||||
|
||||
trajs[:, :, 0] -= x0
|
||||
trajs[:, :, 1] -= y0
|
||||
|
||||
return rgbs, trajs
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.sample_list[index]
|
||||
T = len(sample)
|
||||
rgbs, visibilities, traj_2d = [], [], []
|
||||
|
||||
H, W = sample[0].image.size
|
||||
image_size = (H, W)
|
||||
|
||||
for i in range(T):
|
||||
traj_path = os.path.join(self.root, self.split, sample[i].trajectories["path"])
|
||||
traj = torch.load(traj_path)
|
||||
|
||||
visibilities.append(traj["verts_inds_vis"].numpy())
|
||||
|
||||
rgbs.append(traj["img"].numpy())
|
||||
traj_2d.append(traj["traj_2d"].numpy()[..., :2])
|
||||
|
||||
traj_2d = np.stack(traj_2d)
|
||||
visibility = np.stack(visibilities)
|
||||
T, N, D = traj_2d.shape
|
||||
# subsample trajectories for augmentations
|
||||
visible_inds_sampled = torch.randperm(N)[: self.traj_per_sample]
|
||||
|
||||
traj_2d = traj_2d[:, visible_inds_sampled]
|
||||
visibility = visibility[:, visible_inds_sampled]
|
||||
|
||||
if self.crop_size is not None:
|
||||
rgbs, traj_2d = self.crop(rgbs, traj_2d)
|
||||
H, W, _ = rgbs[0].shape
|
||||
image_size = self.crop_size
|
||||
|
||||
visibility[traj_2d[:, :, 0] > image_size[1] - 1] = False
|
||||
visibility[traj_2d[:, :, 0] < 0] = False
|
||||
visibility[traj_2d[:, :, 1] > image_size[0] - 1] = False
|
||||
visibility[traj_2d[:, :, 1] < 0] = False
|
||||
|
||||
# filter out points that're visible for less than 10 frames
|
||||
visible_inds_resampled = visibility.sum(0) > 10
|
||||
traj_2d = torch.from_numpy(traj_2d[:, visible_inds_resampled])
|
||||
visibility = torch.from_numpy(visibility[:, visible_inds_resampled])
|
||||
|
||||
rgbs = np.stack(rgbs, 0)
|
||||
video = torch.from_numpy(rgbs).reshape(T, H, W, 3).permute(0, 3, 1, 2).float()
|
||||
return CoTrackerData(
|
||||
video=video,
|
||||
trajectory=traj_2d,
|
||||
visibility=visibility,
|
||||
valid=torch.ones(T, N),
|
||||
seq_name=sample[0].sequence_name,
|
||||
)
|
@ -1,72 +0,0 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import torch
|
||||
|
||||
# from PIL import Image
|
||||
import imageio
|
||||
import numpy as np
|
||||
from cotracker.datasets.utils import CoTrackerData, resize_sample
|
||||
|
||||
|
||||
class FastCaptureDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
max_seq_len=50,
|
||||
max_num_points=20,
|
||||
dataset_resolution=(384, 512),
|
||||
):
|
||||
|
||||
self.data_root = data_root
|
||||
self.seq_names = os.listdir(os.path.join(data_root, "renders_local_rm"))
|
||||
self.pth_dir = os.path.join(data_root, "zju_tracking")
|
||||
self.max_seq_len = max_seq_len
|
||||
self.max_num_points = max_num_points
|
||||
self.dataset_resolution = dataset_resolution
|
||||
print("found %d unique videos in %s" % (len(self.seq_names), self.data_root))
|
||||
|
||||
def __getitem__(self, index):
|
||||
seq_name = self.seq_names[index]
|
||||
spath = os.path.join(self.data_root, "renders_local_rm", seq_name)
|
||||
pthpath = os.path.join(self.pth_dir, seq_name + ".pth")
|
||||
|
||||
rgbs = []
|
||||
img_paths = sorted(os.listdir(spath))
|
||||
for i, img_path in enumerate(img_paths):
|
||||
if i < self.max_seq_len:
|
||||
rgbs.append(imageio.imread(os.path.join(spath, img_path)))
|
||||
|
||||
annot_dict = torch.load(pthpath)
|
||||
traj_2d = annot_dict["traj_2d"][:, :, : self.max_seq_len]
|
||||
visibility = annot_dict["visibility"][:, : self.max_seq_len]
|
||||
|
||||
S = len(rgbs)
|
||||
H, W, __ = rgbs[0].shape
|
||||
*_, S = traj_2d.shape
|
||||
visibile_pts_first_frame_inds = (visibility[:, 0] > 0).nonzero(as_tuple=False)[
|
||||
:, 0
|
||||
]
|
||||
torch.manual_seed(0)
|
||||
point_inds = torch.randperm(len(visibile_pts_first_frame_inds))[
|
||||
: self.max_num_points
|
||||
]
|
||||
visible_inds_sampled = visibile_pts_first_frame_inds[point_inds]
|
||||
|
||||
rgbs = np.stack(rgbs, 0)
|
||||
rgbs = torch.from_numpy(rgbs).reshape(S, H, W, 3).permute(0, 3, 1, 2).float()
|
||||
|
||||
segs = torch.ones(S, 1, H, W).float()
|
||||
trajs = traj_2d[visible_inds_sampled].permute(2, 0, 1).float()
|
||||
visibles = visibility[visible_inds_sampled].permute(1, 0)
|
||||
|
||||
rgbs, trajs, segs = resize_sample(rgbs, trajs, segs, self.dataset_resolution)
|
||||
|
||||
return CoTrackerData(rgbs, segs, trajs, visibles, seq_name=seq_name)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.seq_names)
|
@ -6,6 +6,7 @@
|
||||
|
||||
import os
|
||||
import torch
|
||||
import cv2
|
||||
|
||||
import imageio
|
||||
import numpy as np
|
||||
@ -13,7 +14,6 @@ import numpy as np
|
||||
from cotracker.datasets.utils import CoTrackerData
|
||||
from torchvision.transforms import ColorJitter, GaussianBlur
|
||||
from PIL import Image
|
||||
import cv2
|
||||
|
||||
|
||||
class CoTrackerDataset(torch.utils.data.Dataset):
|
||||
@ -37,9 +37,7 @@ class CoTrackerDataset(torch.utils.data.Dataset):
|
||||
self.crop_size = crop_size
|
||||
|
||||
# photometric augmentation
|
||||
self.photo_aug = ColorJitter(
|
||||
brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14
|
||||
)
|
||||
self.photo_aug = ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14)
|
||||
self.blur_aug = GaussianBlur(11, sigma=(0.1, 2.0))
|
||||
|
||||
self.blur_aug_prob = 0.25
|
||||
@ -77,12 +75,7 @@ class CoTrackerDataset(torch.utils.data.Dataset):
|
||||
print("warning: sampling failed")
|
||||
# fake sample, so we can still collate
|
||||
sample = CoTrackerData(
|
||||
video=torch.zeros(
|
||||
(self.seq_len, 3, self.crop_size[0], self.crop_size[1])
|
||||
),
|
||||
segmentation=torch.zeros(
|
||||
(self.seq_len, 1, self.crop_size[0], self.crop_size[1])
|
||||
),
|
||||
video=torch.zeros((self.seq_len, 3, self.crop_size[0], self.crop_size[1])),
|
||||
trajectory=torch.zeros((self.seq_len, self.traj_per_sample, 2)),
|
||||
visibility=torch.zeros((self.seq_len, self.traj_per_sample)),
|
||||
valid=torch.zeros((self.seq_len, self.traj_per_sample)),
|
||||
@ -105,23 +98,16 @@ class CoTrackerDataset(torch.utils.data.Dataset):
|
||||
for _ in range(
|
||||
np.random.randint(1, self.eraser_max + 1)
|
||||
): # number of times to occlude
|
||||
|
||||
xc = np.random.randint(0, W)
|
||||
yc = np.random.randint(0, H)
|
||||
dx = np.random.randint(
|
||||
self.eraser_bounds[0], self.eraser_bounds[1]
|
||||
)
|
||||
dy = np.random.randint(
|
||||
self.eraser_bounds[0], self.eraser_bounds[1]
|
||||
)
|
||||
dx = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
|
||||
dy = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
|
||||
x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
|
||||
x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
|
||||
y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
|
||||
y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
|
||||
|
||||
mean_color = np.mean(
|
||||
rgbs[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0
|
||||
)
|
||||
mean_color = np.mean(rgbs[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0)
|
||||
rgbs[i][y0:y1, x0:x1, :] = mean_color
|
||||
|
||||
occ_inds = np.logical_and(
|
||||
@ -132,14 +118,11 @@ class CoTrackerDataset(torch.utils.data.Dataset):
|
||||
rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
|
||||
|
||||
if replace:
|
||||
|
||||
rgbs_alt = [
|
||||
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
|
||||
for rgb in rgbs
|
||||
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs
|
||||
]
|
||||
rgbs_alt = [
|
||||
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
|
||||
for rgb in rgbs_alt
|
||||
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs_alt
|
||||
]
|
||||
|
||||
############ replace transform (per image after the first) ############
|
||||
@ -152,12 +135,8 @@ class CoTrackerDataset(torch.utils.data.Dataset):
|
||||
): # number of times to occlude
|
||||
xc = np.random.randint(0, W)
|
||||
yc = np.random.randint(0, H)
|
||||
dx = np.random.randint(
|
||||
self.replace_bounds[0], self.replace_bounds[1]
|
||||
)
|
||||
dy = np.random.randint(
|
||||
self.replace_bounds[0], self.replace_bounds[1]
|
||||
)
|
||||
dx = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
|
||||
dy = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
|
||||
x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
|
||||
x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
|
||||
y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
|
||||
@ -181,17 +160,11 @@ class CoTrackerDataset(torch.utils.data.Dataset):
|
||||
############ photometric augmentation ############
|
||||
if np.random.rand() < self.color_aug_prob:
|
||||
# random per-frame amount of aug
|
||||
rgbs = [
|
||||
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
|
||||
for rgb in rgbs
|
||||
]
|
||||
rgbs = [np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]
|
||||
|
||||
if np.random.rand() < self.blur_aug_prob:
|
||||
# random per-frame amount of blur
|
||||
rgbs = [
|
||||
np.array(self.blur_aug(Image.fromarray(rgb)), dtype=np.uint8)
|
||||
for rgb in rgbs
|
||||
]
|
||||
rgbs = [np.array(self.blur_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]
|
||||
|
||||
return rgbs, trajs, visibles
|
||||
|
||||
@ -212,9 +185,7 @@ class CoTrackerDataset(torch.utils.data.Dataset):
|
||||
pad_y0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
||||
pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
||||
|
||||
rgbs = [
|
||||
np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs
|
||||
]
|
||||
rgbs = [np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs]
|
||||
trajs[:, :, 0] += pad_x0
|
||||
trajs[:, :, 1] += pad_y0
|
||||
H, W = rgbs[0].shape[:2]
|
||||
@ -263,12 +234,9 @@ class CoTrackerDataset(torch.utils.data.Dataset):
|
||||
H_new = np.clip(H_new, self.crop_size[0] + 10, None)
|
||||
W_new = np.clip(W_new, self.crop_size[1] + 10, None)
|
||||
# recompute scale in case we clipped
|
||||
scale_x = W_new / float(W)
|
||||
scale_y = H_new / float(H)
|
||||
|
||||
rgbs_scaled.append(
|
||||
cv2.resize(rgbs[s], (W_new, H_new), interpolation=cv2.INTER_LINEAR)
|
||||
)
|
||||
scale_x = (W_new - 1) / float(W - 1)
|
||||
scale_y = (H_new - 1) / float(H - 1)
|
||||
rgbs_scaled.append(cv2.resize(rgbs[s], (W_new, H_new), interpolation=cv2.INTER_LINEAR))
|
||||
trajs[s, :, 0] *= scale_x
|
||||
trajs[s, :, 1] *= scale_y
|
||||
rgbs = rgbs_scaled
|
||||
@ -292,22 +260,16 @@ class CoTrackerDataset(torch.utils.data.Dataset):
|
||||
for s in range(S):
|
||||
# on each frame, shift a bit more
|
||||
if s == 1:
|
||||
offset_x = np.random.randint(
|
||||
-self.max_crop_offset, self.max_crop_offset
|
||||
)
|
||||
offset_y = np.random.randint(
|
||||
-self.max_crop_offset, self.max_crop_offset
|
||||
)
|
||||
offset_x = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
|
||||
offset_y = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
|
||||
elif s > 1:
|
||||
offset_x = int(
|
||||
offset_x * 0.8
|
||||
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1)
|
||||
* 0.2
|
||||
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
|
||||
)
|
||||
offset_y = int(
|
||||
offset_y * 0.8
|
||||
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1)
|
||||
* 0.2
|
||||
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
|
||||
)
|
||||
x0 = x0 + offset_x
|
||||
y0 = y0 + offset_y
|
||||
@ -362,20 +324,9 @@ class CoTrackerDataset(torch.utils.data.Dataset):
|
||||
W_new = W
|
||||
|
||||
# simple random crop
|
||||
y0 = (
|
||||
0
|
||||
if self.crop_size[0] >= H_new
|
||||
else np.random.randint(0, H_new - self.crop_size[0])
|
||||
)
|
||||
x0 = (
|
||||
0
|
||||
if self.crop_size[1] >= W_new
|
||||
else np.random.randint(0, W_new - self.crop_size[1])
|
||||
)
|
||||
rgbs = [
|
||||
rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
|
||||
for rgb in rgbs
|
||||
]
|
||||
y0 = 0 if self.crop_size[0] >= H_new else np.random.randint(0, H_new - self.crop_size[0])
|
||||
x0 = 0 if self.crop_size[1] >= W_new else np.random.randint(0, W_new - self.crop_size[1])
|
||||
rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]
|
||||
|
||||
trajs[:, :, 0] -= x0
|
||||
trajs[:, :, 1] -= y0
|
||||
@ -442,9 +393,7 @@ class KubricMovifDataset(CoTrackerDataset):
|
||||
traj_2d = np.transpose(traj_2d, (1, 0, 2))
|
||||
visibility = np.transpose(np.logical_not(visibility), (1, 0))
|
||||
if self.use_augs:
|
||||
rgbs, traj_2d, visibility = self.add_photometric_augs(
|
||||
rgbs, traj_2d, visibility
|
||||
)
|
||||
rgbs, traj_2d, visibility = self.add_photometric_augs(rgbs, traj_2d, visibility)
|
||||
rgbs, traj_2d = self.add_spatial_augs(rgbs, traj_2d, visibility)
|
||||
else:
|
||||
rgbs, traj_2d = self.crop(rgbs, traj_2d)
|
||||
@ -462,9 +411,9 @@ class KubricMovifDataset(CoTrackerDataset):
|
||||
if self.sample_vis_1st_frame:
|
||||
visibile_pts_inds = visibile_pts_first_frame_inds
|
||||
else:
|
||||
visibile_pts_mid_frame_inds = (visibility[self.seq_len // 2]).nonzero(
|
||||
as_tuple=False
|
||||
)[:, 0]
|
||||
visibile_pts_mid_frame_inds = (visibility[self.seq_len // 2]).nonzero(as_tuple=False)[
|
||||
:, 0
|
||||
]
|
||||
visibile_pts_inds = torch.cat(
|
||||
(visibile_pts_first_frame_inds, visibile_pts_mid_frame_inds), dim=0
|
||||
)
|
||||
@ -479,10 +428,8 @@ class KubricMovifDataset(CoTrackerDataset):
|
||||
valids = torch.ones((self.seq_len, self.traj_per_sample))
|
||||
|
||||
rgbs = torch.from_numpy(np.stack(rgbs)).permute(0, 3, 1, 2).float()
|
||||
segs = torch.ones((self.seq_len, 1, self.crop_size[0], self.crop_size[1]))
|
||||
sample = CoTrackerData(
|
||||
video=rgbs,
|
||||
segmentation=segs,
|
||||
trajectory=trajs,
|
||||
visibility=visibles,
|
||||
valid=valids,
|
||||
|
@ -179,12 +179,9 @@ class TapVidDataset(torch.utils.data.Dataset):
|
||||
target_points = self.points_dataset[video_name]["points"]
|
||||
if self.resize_to_256:
|
||||
frames = resize_video(frames, [256, 256])
|
||||
target_points *= np.array([256, 256])
|
||||
target_points *= np.array([255, 255]) # 1 should be mapped to 256-1
|
||||
else:
|
||||
target_points *= np.array([frames.shape[2], frames.shape[1]])
|
||||
|
||||
T, H, W, C = frames.shape
|
||||
N, T, D = target_points.shape
|
||||
target_points *= np.array([frames.shape[2] - 1, frames.shape[1] - 1])
|
||||
|
||||
target_occ = self.points_dataset[video_name]["occluded"]
|
||||
if self.queried_first:
|
||||
@ -193,21 +190,15 @@ class TapVidDataset(torch.utils.data.Dataset):
|
||||
converted = sample_queries_strided(target_occ, target_points, frames)
|
||||
assert converted["target_points"].shape[1] == converted["query_points"].shape[1]
|
||||
|
||||
trajs = (
|
||||
torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float()
|
||||
) # T, N, D
|
||||
trajs = torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float() # T, N, D
|
||||
|
||||
rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float()
|
||||
segs = torch.ones(T, 1, H, W).float()
|
||||
visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[
|
||||
0
|
||||
].permute(
|
||||
visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[0].permute(
|
||||
1, 0
|
||||
) # T, N
|
||||
query_points = torch.from_numpy(converted["query_points"])[0] # T, N
|
||||
return CoTrackerData(
|
||||
rgbs,
|
||||
segs,
|
||||
trajs,
|
||||
visibles,
|
||||
seq_name=str(video_name),
|
||||
|
@ -19,11 +19,11 @@ class CoTrackerData:
|
||||
"""
|
||||
|
||||
video: torch.Tensor # B, S, C, H, W
|
||||
segmentation: torch.Tensor # B, S, 1, H, W
|
||||
trajectory: torch.Tensor # B, S, N, 2
|
||||
visibility: torch.Tensor # B, S, N
|
||||
# optional data
|
||||
valid: Optional[torch.Tensor] = None # B, S, N
|
||||
segmentation: Optional[torch.Tensor] = None # B, S, 1, H, W
|
||||
seq_name: Optional[str] = None
|
||||
query_points: Optional[torch.Tensor] = None # TapVID evaluation format
|
||||
|
||||
@ -33,19 +33,20 @@ def collate_fn(batch):
|
||||
Collate function for video tracks data.
|
||||
"""
|
||||
video = torch.stack([b.video for b in batch], dim=0)
|
||||
segmentation = torch.stack([b.segmentation for b in batch], dim=0)
|
||||
trajectory = torch.stack([b.trajectory for b in batch], dim=0)
|
||||
visibility = torch.stack([b.visibility for b in batch], dim=0)
|
||||
query_points = None
|
||||
query_points = segmentation = None
|
||||
if batch[0].query_points is not None:
|
||||
query_points = torch.stack([b.query_points for b in batch], dim=0)
|
||||
if batch[0].segmentation is not None:
|
||||
segmentation = torch.stack([b.segmentation for b in batch], dim=0)
|
||||
seq_name = [b.seq_name for b in batch]
|
||||
|
||||
return CoTrackerData(
|
||||
video,
|
||||
segmentation,
|
||||
trajectory,
|
||||
visibility,
|
||||
video=video,
|
||||
trajectory=trajectory,
|
||||
visibility=visibility,
|
||||
segmentation=segmentation,
|
||||
seq_name=seq_name,
|
||||
query_points=query_points,
|
||||
)
|
||||
@ -57,13 +58,18 @@ def collate_fn_train(batch):
|
||||
"""
|
||||
gotit = [gotit for _, gotit in batch]
|
||||
video = torch.stack([b.video for b, _ in batch], dim=0)
|
||||
segmentation = torch.stack([b.segmentation for b, _ in batch], dim=0)
|
||||
trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0)
|
||||
visibility = torch.stack([b.visibility for b, _ in batch], dim=0)
|
||||
valid = torch.stack([b.valid for b, _ in batch], dim=0)
|
||||
seq_name = [b.seq_name for b, _ in batch]
|
||||
return (
|
||||
CoTrackerData(video, segmentation, trajectory, visibility, valid, seq_name),
|
||||
CoTrackerData(
|
||||
video=video,
|
||||
trajectory=trajectory,
|
||||
visibility=visibility,
|
||||
valid=valid,
|
||||
seq_name=seq_name,
|
||||
),
|
||||
gotit,
|
||||
)
|
||||
|
||||
@ -98,17 +104,3 @@ def dataclass_to_cuda_(obj):
|
||||
for f in dataclasses.fields(obj):
|
||||
setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
|
||||
return obj
|
||||
|
||||
|
||||
def resize_sample(rgbs, trajs_g, segs, interp_shape):
|
||||
S, C, H, W = rgbs.shape
|
||||
S, N, D = trajs_g.shape
|
||||
|
||||
assert D == 2
|
||||
|
||||
rgbs = F.interpolate(rgbs, interp_shape, mode="bilinear")
|
||||
segs = F.interpolate(segs, interp_shape, mode="nearest")
|
||||
|
||||
trajs_g[:, :, 0] *= interp_shape[1] / W
|
||||
trajs_g[:, :, 1] *= interp_shape[0] / H
|
||||
return rgbs, trajs_g, segs
|
||||
|
@ -1,6 +1,6 @@
|
||||
defaults:
|
||||
- default_config_eval
|
||||
exp_dir: ./outputs/cotracker
|
||||
dataset_name: badja
|
||||
dataset_name: dynamic_replica
|
||||
|
||||
|
@ -1,6 +0,0 @@
|
||||
defaults:
|
||||
- default_config_eval
|
||||
exp_dir: ./outputs/cotracker
|
||||
dataset_name: fastcapture
|
||||
|
||||
|
@ -37,57 +37,7 @@ class Evaluator:
|
||||
pred_trajectory, pred_visibility = pred_trajectory
|
||||
else:
|
||||
pred_visibility = None
|
||||
if dataset_name == "badja":
|
||||
sample.segmentation = (sample.segmentation > 0).float()
|
||||
*_, N, _ = sample.trajectory.shape
|
||||
accs = []
|
||||
accs_3px = []
|
||||
for s1 in range(1, sample.video.shape[1]): # target frame
|
||||
for n in range(N):
|
||||
vis = sample.visibility[0, s1, n]
|
||||
if vis > 0:
|
||||
coord_e = pred_trajectory[0, s1, n] # 2
|
||||
coord_g = sample.trajectory[0, s1, n] # 2
|
||||
dist = torch.sqrt(torch.sum((coord_e - coord_g) ** 2, dim=0))
|
||||
area = torch.sum(sample.segmentation[0, s1])
|
||||
# print_('0.2*sqrt(area)', 0.2*torch.sqrt(area))
|
||||
thr = 0.2 * torch.sqrt(area)
|
||||
# correct =
|
||||
accs.append((dist < thr).float())
|
||||
# print('thr',thr)
|
||||
accs_3px.append((dist < 3.0).float())
|
||||
|
||||
res = torch.mean(torch.stack(accs)) * 100.0
|
||||
res_3px = torch.mean(torch.stack(accs_3px)) * 100.0
|
||||
metrics[sample.seq_name[0]] = res.item()
|
||||
metrics[sample.seq_name[0] + "_accuracy"] = res_3px.item()
|
||||
print(metrics)
|
||||
print(
|
||||
"avg", np.mean([v for k, v in metrics.items() if "accuracy" not in k])
|
||||
)
|
||||
print(
|
||||
"avg acc 3px",
|
||||
np.mean([v for k, v in metrics.items() if "accuracy" in k]),
|
||||
)
|
||||
elif dataset_name == "fastcapture" or ("kubric" in dataset_name):
|
||||
*_, N, _ = sample.trajectory.shape
|
||||
accs = []
|
||||
for s1 in range(1, sample.video.shape[1]): # target frame
|
||||
for n in range(N):
|
||||
vis = sample.visibility[0, s1, n]
|
||||
if vis > 0:
|
||||
coord_e = pred_trajectory[0, s1, n] # 2
|
||||
coord_g = sample.trajectory[0, s1, n] # 2
|
||||
dist = torch.sqrt(torch.sum((coord_e - coord_g) ** 2, dim=0))
|
||||
thr = 3
|
||||
correct = (dist < thr).float()
|
||||
accs.append(correct)
|
||||
|
||||
res = torch.mean(torch.stack(accs)) * 100.0
|
||||
metrics[sample.seq_name[0] + "_accuracy"] = res.item()
|
||||
print(metrics)
|
||||
print("avg", np.mean([v for v in metrics.values()]))
|
||||
elif "tapvid" in dataset_name:
|
||||
if "tapvid" in dataset_name:
|
||||
B, T, N, D = sample.trajectory.shape
|
||||
traj = sample.trajectory.clone()
|
||||
thr = 0.9
|
||||
@ -99,7 +49,6 @@ class Evaluator:
|
||||
if not pred_visibility.dtype == torch.bool:
|
||||
pred_visibility = pred_visibility > thr
|
||||
|
||||
# pred_trajectory
|
||||
query_points = sample.query_points.clone().cpu().numpy()
|
||||
|
||||
pred_visibility = pred_visibility[:, :, :N]
|
||||
@ -107,15 +56,11 @@ class Evaluator:
|
||||
|
||||
gt_tracks = traj.permute(0, 2, 1, 3).cpu().numpy()
|
||||
gt_occluded = (
|
||||
torch.logical_not(sample.visibility.clone().permute(0, 2, 1))
|
||||
.cpu()
|
||||
.numpy()
|
||||
torch.logical_not(sample.visibility.clone().permute(0, 2, 1)).cpu().numpy()
|
||||
)
|
||||
|
||||
pred_occluded = (
|
||||
torch.logical_not(pred_visibility.clone().permute(0, 2, 1))
|
||||
.cpu()
|
||||
.numpy()
|
||||
torch.logical_not(pred_visibility.clone().permute(0, 2, 1)).cpu().numpy()
|
||||
)
|
||||
pred_tracks = pred_trajectory.permute(0, 2, 1, 3).cpu().numpy()
|
||||
|
||||
@ -140,27 +85,79 @@ class Evaluator:
|
||||
logging.info(f"avg: {metrics['avg']}")
|
||||
print("metrics", out_metrics)
|
||||
print("avg", metrics["avg"])
|
||||
else:
|
||||
rgbs = sample.video
|
||||
trajs_g = sample.trajectory
|
||||
valids = sample.valid
|
||||
vis_g = sample.visibility
|
||||
elif dataset_name == "dynamic_replica" or dataset_name == "pointodyssey":
|
||||
*_, N, _ = sample.trajectory.shape
|
||||
B, T, N = sample.visibility.shape
|
||||
H, W = sample.video.shape[-2:]
|
||||
device = sample.video.device
|
||||
|
||||
B, S, C, H, W = rgbs.shape
|
||||
assert C == 3
|
||||
B, S, N, D = trajs_g.shape
|
||||
out_metrics = {}
|
||||
|
||||
assert torch.sum(valids) == B * S * N
|
||||
d_vis_sum = d_occ_sum = d_sum_all = 0.0
|
||||
thrs = [1, 2, 4, 8, 16]
|
||||
sx_ = (W - 1) / 255.0
|
||||
sy_ = (H - 1) / 255.0
|
||||
sc_py = np.array([sx_, sy_]).reshape([1, 1, 2])
|
||||
sc_pt = torch.from_numpy(sc_py).float().to(device)
|
||||
__, first_visible_inds = torch.max(sample.visibility, dim=1)
|
||||
|
||||
vis_g = (torch.sum(vis_g, dim=1, keepdim=True) >= 4).float().repeat(1, S, 1)
|
||||
frame_ids_tensor = torch.arange(T, device=device)[None, :, None].repeat(B, 1, N)
|
||||
start_tracking_mask = frame_ids_tensor > (first_visible_inds.unsqueeze(1))
|
||||
|
||||
ate = torch.norm(pred_trajectory - trajs_g, dim=-1) # B, S, N
|
||||
for thr in thrs:
|
||||
d_ = (
|
||||
torch.norm(
|
||||
pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt,
|
||||
dim=-1,
|
||||
)
|
||||
< thr
|
||||
).float() # B,S-1,N
|
||||
d_occ = (
|
||||
reduce_masked_mean(d_, (1 - sample.visibility) * start_tracking_mask).item()
|
||||
* 100.0
|
||||
)
|
||||
d_occ_sum += d_occ
|
||||
out_metrics[f"accuracy_occ_{thr}"] = d_occ
|
||||
|
||||
metrics["things_all"] = reduce_masked_mean(ate, valids).item()
|
||||
metrics["things_vis"] = reduce_masked_mean(ate, valids * vis_g).item()
|
||||
metrics["things_occ"] = reduce_masked_mean(
|
||||
ate, valids * (1.0 - vis_g)
|
||||
).item()
|
||||
d_vis = (
|
||||
reduce_masked_mean(d_, sample.visibility * start_tracking_mask).item() * 100.0
|
||||
)
|
||||
d_vis_sum += d_vis
|
||||
out_metrics[f"accuracy_vis_{thr}"] = d_vis
|
||||
|
||||
d_all = reduce_masked_mean(d_, start_tracking_mask).item() * 100.0
|
||||
d_sum_all += d_all
|
||||
out_metrics[f"accuracy_{thr}"] = d_all
|
||||
|
||||
d_occ_avg = d_occ_sum / len(thrs)
|
||||
d_vis_avg = d_vis_sum / len(thrs)
|
||||
d_all_avg = d_sum_all / len(thrs)
|
||||
|
||||
sur_thr = 50
|
||||
dists = torch.norm(
|
||||
pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt,
|
||||
dim=-1,
|
||||
) # B,S,N
|
||||
dist_ok = 1 - (dists > sur_thr).float() * sample.visibility # B,S,N
|
||||
survival = torch.cumprod(dist_ok, dim=1) # B,S,N
|
||||
out_metrics["survival"] = torch.mean(survival).item() * 100.0
|
||||
|
||||
out_metrics["accuracy_occ"] = d_occ_avg
|
||||
out_metrics["accuracy_vis"] = d_vis_avg
|
||||
out_metrics["accuracy"] = d_all_avg
|
||||
|
||||
metrics[sample.seq_name[0]] = out_metrics
|
||||
for metric_name in out_metrics.keys():
|
||||
if "avg" not in metrics:
|
||||
metrics["avg"] = {}
|
||||
metrics["avg"][metric_name] = float(
|
||||
np.mean([v[metric_name] for k, v in metrics.items() if k != "avg"])
|
||||
)
|
||||
|
||||
logging.info(f"Metrics: {out_metrics}")
|
||||
logging.info(f"avg: {metrics['avg']}")
|
||||
print("metrics", out_metrics)
|
||||
print("avg", metrics["avg"])
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate_sequence(
|
||||
@ -169,6 +166,7 @@ class Evaluator:
|
||||
test_dataloader: torch.utils.data.DataLoader,
|
||||
dataset_name: str,
|
||||
train_mode=False,
|
||||
visualize_every: int = 1,
|
||||
writer: Optional[SummaryWriter] = None,
|
||||
step: Optional[int] = 0,
|
||||
):
|
||||
@ -221,7 +219,6 @@ class Evaluator:
|
||||
|
||||
pred_tracks = model(sample.video, queries)
|
||||
if "strided" in dataset_name:
|
||||
|
||||
inv_video = sample.video.flip(1).clone()
|
||||
inv_queries = queries.clone()
|
||||
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
|
||||
@ -243,14 +240,14 @@ class Evaluator:
|
||||
seq_name = sample.seq_name[0]
|
||||
else:
|
||||
seq_name = str(ind)
|
||||
|
||||
vis.visualize(
|
||||
sample.video,
|
||||
pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks,
|
||||
filename=dataset_name + "_" + seq_name,
|
||||
writer=writer,
|
||||
step=step,
|
||||
)
|
||||
if ind % visualize_every == 0:
|
||||
vis.visualize(
|
||||
sample.video,
|
||||
pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks,
|
||||
filename=dataset_name + "_" + seq_name,
|
||||
writer=writer,
|
||||
step=step,
|
||||
)
|
||||
|
||||
self.compute_metrics(metrics, sample, pred_tracks, dataset_name)
|
||||
return metrics
|
||||
|
@ -14,9 +14,8 @@ import numpy as np
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from cotracker.datasets.badja_dataset import BadjaDataset
|
||||
from cotracker.datasets.fast_capture_dataset import FastCaptureDataset
|
||||
from cotracker.datasets.tap_vid_datasets import TapVidDataset
|
||||
from cotracker.datasets.dr_dataset import DynamicReplicaDataset
|
||||
from cotracker.datasets.utils import collate_fn
|
||||
|
||||
from cotracker.models.evaluation_predictor import EvaluationPredictor
|
||||
@ -33,23 +32,20 @@ class DefaultConfig:
|
||||
exp_dir: str = "./outputs"
|
||||
|
||||
# Name of the dataset to be used for the evaluation.
|
||||
dataset_name: str = "badja"
|
||||
dataset_name: str = "tapvid_davis_first"
|
||||
# The root directory of the dataset.
|
||||
dataset_root: str = "./"
|
||||
|
||||
# Path to the pre-trained model checkpoint to be used for the evaluation.
|
||||
# The default value is the path to a specific CoTracker model checkpoint.
|
||||
# Other available options are commented.
|
||||
checkpoint: str = "./checkpoints/cotracker_stride_4_wind_8.pth"
|
||||
# cotracker_stride_4_wind_12
|
||||
# cotracker_stride_8_wind_16
|
||||
checkpoint: str = "./checkpoints/cotracker2.pth"
|
||||
|
||||
# EvaluationPredictor parameters
|
||||
# The size (N) of the support grid used in the predictor.
|
||||
# The total number of points is (N*N).
|
||||
grid_size: int = 6
|
||||
grid_size: int = 5
|
||||
# The size (N) of the local support grid.
|
||||
local_grid_size: int = 6
|
||||
local_grid_size: int = 8
|
||||
# A flag indicating whether to evaluate one ground truth point at a time.
|
||||
single_point: bool = True
|
||||
# The number of iterative updates for each sliding window.
|
||||
@ -111,18 +107,10 @@ def run_eval(cfg: DefaultConfig):
|
||||
|
||||
# Constructing the specified dataset
|
||||
curr_collate_fn = collate_fn
|
||||
if cfg.dataset_name == "badja":
|
||||
test_dataset = BadjaDataset(data_root=os.path.join(cfg.dataset_root, "BADJA"))
|
||||
elif cfg.dataset_name == "fastcapture":
|
||||
test_dataset = FastCaptureDataset(
|
||||
data_root=os.path.join(cfg.dataset_root, "fastcapture"),
|
||||
max_seq_len=100,
|
||||
max_num_points=20,
|
||||
)
|
||||
elif "tapvid" in cfg.dataset_name:
|
||||
if "tapvid" in cfg.dataset_name:
|
||||
dataset_type = cfg.dataset_name.split("_")[1]
|
||||
if dataset_type == "davis":
|
||||
data_root = os.path.join(cfg.dataset_root, "/tapvid_davis/tapvid_davis.pkl")
|
||||
data_root = os.path.join(cfg.dataset_root, "tapvid_davis", "tapvid_davis.pkl")
|
||||
elif dataset_type == "kinetics":
|
||||
data_root = os.path.join(
|
||||
cfg.dataset_root, "/kinetics/kinetics-dataset/k700-2020/tapvid_kinetics"
|
||||
@ -132,6 +120,8 @@ def run_eval(cfg: DefaultConfig):
|
||||
data_root=data_root,
|
||||
queried_first=not "strided" in cfg.dataset_name,
|
||||
)
|
||||
elif cfg.dataset_name == "dynamic_replica":
|
||||
test_dataset = DynamicReplicaDataset(sample_len=300, only_first_n_samples=1)
|
||||
|
||||
# Creating the DataLoader object
|
||||
test_dataloader = torch.utils.data.DataLoader(
|
||||
@ -155,10 +145,8 @@ def run_eval(cfg: DefaultConfig):
|
||||
print(end - start)
|
||||
|
||||
# Saving the evaluation results to a .json file
|
||||
if not "tapvid" in cfg.dataset_name:
|
||||
print("evaluate_result", evaluate_result)
|
||||
else:
|
||||
evaluate_result = evaluate_result["avg"]
|
||||
evaluate_result = evaluate_result["avg"]
|
||||
print("evaluate_result", evaluate_result)
|
||||
result_file = os.path.join(cfg.exp_dir, f"result_eval_.json")
|
||||
evaluate_result["time"] = end - start
|
||||
print(f"Dumping eval results to {result_file}.")
|
||||
|
@ -6,63 +6,24 @@
|
||||
|
||||
import torch
|
||||
|
||||
from cotracker.models.core.cotracker.cotracker import CoTracker
|
||||
from cotracker.models.core.cotracker.cotracker import CoTracker2
|
||||
|
||||
|
||||
def build_cotracker(
|
||||
checkpoint: str,
|
||||
):
|
||||
if checkpoint is None:
|
||||
return build_cotracker_stride_4_wind_8()
|
||||
return build_cotracker()
|
||||
model_name = checkpoint.split("/")[-1].split(".")[0]
|
||||
if model_name == "cotracker_stride_4_wind_8":
|
||||
return build_cotracker_stride_4_wind_8(checkpoint=checkpoint)
|
||||
elif model_name == "cotracker_stride_4_wind_12":
|
||||
return build_cotracker_stride_4_wind_12(checkpoint=checkpoint)
|
||||
elif model_name == "cotracker_stride_8_wind_16":
|
||||
return build_cotracker_stride_8_wind_16(checkpoint=checkpoint)
|
||||
if model_name == "cotracker":
|
||||
return build_cotracker(checkpoint=checkpoint)
|
||||
else:
|
||||
raise ValueError(f"Unknown model name {model_name}")
|
||||
|
||||
|
||||
# model used to produce the results in the paper
|
||||
def build_cotracker_stride_4_wind_8(checkpoint=None):
|
||||
return _build_cotracker(
|
||||
stride=4,
|
||||
sequence_len=8,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
def build_cotracker(checkpoint=None):
|
||||
cotracker = CoTracker2(stride=4, window_len=8, add_space_attn=True)
|
||||
|
||||
|
||||
def build_cotracker_stride_4_wind_12(checkpoint=None):
|
||||
return _build_cotracker(
|
||||
stride=4,
|
||||
sequence_len=12,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
# the fastest model
|
||||
def build_cotracker_stride_8_wind_16(checkpoint=None):
|
||||
return _build_cotracker(
|
||||
stride=8,
|
||||
sequence_len=16,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def _build_cotracker(
|
||||
stride,
|
||||
sequence_len,
|
||||
checkpoint=None,
|
||||
):
|
||||
cotracker = CoTracker(
|
||||
stride=stride,
|
||||
S=sequence_len,
|
||||
add_space_attn=True,
|
||||
space_depth=6,
|
||||
time_depth=6,
|
||||
)
|
||||
if checkpoint is not None:
|
||||
with open(checkpoint, "rb") as f:
|
||||
state_dict = torch.load(f, map_location="cpu")
|
||||
|
@ -7,9 +7,71 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from functools import partial
|
||||
from typing import Callable
|
||||
import collections
|
||||
from torch import Tensor
|
||||
from itertools import repeat
|
||||
|
||||
from einops import rearrange
|
||||
from timm.models.vision_transformer import Attention, Mlp
|
||||
from cotracker.models.core.model_utils import bilinear_sampler
|
||||
|
||||
|
||||
# From PyTorch internals
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
||||
return tuple(x)
|
||||
return tuple(repeat(x, n))
|
||||
|
||||
return parse
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
|
||||
to_2tuple = _ntuple(2)
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=None,
|
||||
bias=True,
|
||||
drop=0.0,
|
||||
use_conv=False,
|
||||
):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
bias = to_2tuple(bias)
|
||||
drop_probs = to_2tuple(drop)
|
||||
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
||||
|
||||
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
||||
self.act = act_layer()
|
||||
self.drop1 = nn.Dropout(drop_probs[0])
|
||||
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
|
||||
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
|
||||
self.drop2 = nn.Dropout(drop_probs[1])
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop1(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop2(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
@ -24,9 +86,7 @@ class ResidualBlock(nn.Module):
|
||||
stride=stride,
|
||||
padding_mode="zeros",
|
||||
)
|
||||
self.conv2 = nn.Conv2d(
|
||||
planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
|
||||
)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, padding_mode="zeros")
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
num_groups = planes // 8
|
||||
@ -75,28 +135,14 @@ class ResidualBlock(nn.Module):
|
||||
|
||||
|
||||
class BasicEncoder(nn.Module):
|
||||
def __init__(
|
||||
self, input_dim=3, output_dim=128, stride=8, norm_fn="batch", dropout=0.0
|
||||
):
|
||||
def __init__(self, input_dim=3, output_dim=128, stride=4):
|
||||
super(BasicEncoder, self).__init__()
|
||||
self.stride = stride
|
||||
self.norm_fn = norm_fn
|
||||
self.in_planes = 64
|
||||
self.norm_fn = "instance"
|
||||
self.in_planes = output_dim // 2
|
||||
|
||||
if self.norm_fn == "group":
|
||||
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes)
|
||||
self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2)
|
||||
|
||||
elif self.norm_fn == "batch":
|
||||
self.norm1 = nn.BatchNorm2d(self.in_planes)
|
||||
self.norm2 = nn.BatchNorm2d(output_dim * 2)
|
||||
|
||||
elif self.norm_fn == "instance":
|
||||
self.norm1 = nn.InstanceNorm2d(self.in_planes)
|
||||
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
|
||||
|
||||
elif self.norm_fn == "none":
|
||||
self.norm1 = nn.Sequential()
|
||||
self.norm1 = nn.InstanceNorm2d(self.in_planes)
|
||||
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
input_dim,
|
||||
@ -107,37 +153,24 @@ class BasicEncoder(nn.Module):
|
||||
padding_mode="zeros",
|
||||
)
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
self.layer1 = self._make_layer(output_dim // 2, stride=1)
|
||||
self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
|
||||
self.layer3 = self._make_layer(output_dim, stride=2)
|
||||
self.layer4 = self._make_layer(output_dim, stride=2)
|
||||
|
||||
self.shallow = False
|
||||
if self.shallow:
|
||||
self.layer1 = self._make_layer(64, stride=1)
|
||||
self.layer2 = self._make_layer(96, stride=2)
|
||||
self.layer3 = self._make_layer(128, stride=2)
|
||||
self.conv2 = nn.Conv2d(128 + 96 + 64, output_dim, kernel_size=1)
|
||||
else:
|
||||
self.layer1 = self._make_layer(64, stride=1)
|
||||
self.layer2 = self._make_layer(96, stride=2)
|
||||
self.layer3 = self._make_layer(128, stride=2)
|
||||
self.layer4 = self._make_layer(128, stride=2)
|
||||
|
||||
self.conv2 = nn.Conv2d(
|
||||
128 + 128 + 96 + 64,
|
||||
output_dim * 2,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
padding_mode="zeros",
|
||||
)
|
||||
self.relu2 = nn.ReLU(inplace=True)
|
||||
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
|
||||
|
||||
self.dropout = None
|
||||
if dropout > 0:
|
||||
self.dropout = nn.Dropout2d(p=dropout)
|
||||
|
||||
self.conv2 = nn.Conv2d(
|
||||
output_dim * 3 + output_dim // 4,
|
||||
output_dim * 2,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
padding_mode="zeros",
|
||||
)
|
||||
self.relu2 = nn.ReLU(inplace=True)
|
||||
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
||||
elif isinstance(m, (nn.InstanceNorm2d)):
|
||||
if m.weight is not None:
|
||||
nn.init.constant_(m.weight, 1)
|
||||
if m.bias is not None:
|
||||
@ -158,122 +191,47 @@ class BasicEncoder(nn.Module):
|
||||
x = self.norm1(x)
|
||||
x = self.relu1(x)
|
||||
|
||||
if self.shallow:
|
||||
a = self.layer1(x)
|
||||
b = self.layer2(a)
|
||||
c = self.layer3(b)
|
||||
a = F.interpolate(
|
||||
a,
|
||||
(H // self.stride, W // self.stride),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
b = F.interpolate(
|
||||
b,
|
||||
(H // self.stride, W // self.stride),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
c = F.interpolate(
|
||||
c,
|
||||
(H // self.stride, W // self.stride),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
x = self.conv2(torch.cat([a, b, c], dim=1))
|
||||
else:
|
||||
a = self.layer1(x)
|
||||
b = self.layer2(a)
|
||||
c = self.layer3(b)
|
||||
d = self.layer4(c)
|
||||
a = F.interpolate(
|
||||
a,
|
||||
(H // self.stride, W // self.stride),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
b = F.interpolate(
|
||||
b,
|
||||
(H // self.stride, W // self.stride),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
c = F.interpolate(
|
||||
c,
|
||||
(H // self.stride, W // self.stride),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
d = F.interpolate(
|
||||
d,
|
||||
(H // self.stride, W // self.stride),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
x = self.conv2(torch.cat([a, b, c, d], dim=1))
|
||||
x = self.norm2(x)
|
||||
x = self.relu2(x)
|
||||
x = self.conv3(x)
|
||||
a = self.layer1(x)
|
||||
b = self.layer2(a)
|
||||
c = self.layer3(b)
|
||||
d = self.layer4(c)
|
||||
|
||||
if self.training and self.dropout is not None:
|
||||
x = self.dropout(x)
|
||||
def _bilinear_intepolate(x):
|
||||
return F.interpolate(
|
||||
x,
|
||||
(H // self.stride, W // self.stride),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
|
||||
a = _bilinear_intepolate(a)
|
||||
b = _bilinear_intepolate(b)
|
||||
c = _bilinear_intepolate(c)
|
||||
d = _bilinear_intepolate(d)
|
||||
|
||||
x = self.conv2(torch.cat([a, b, c, d], dim=1))
|
||||
x = self.norm2(x)
|
||||
x = self.relu2(x)
|
||||
x = self.conv3(x)
|
||||
return x
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
"""
|
||||
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
|
||||
super().__init__()
|
||||
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.attn = Attention(
|
||||
hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.mlp = Mlp(
|
||||
in_features=hidden_size,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=approx_gelu,
|
||||
drop=0,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.attn(self.norm1(x))
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
def bilinear_sampler(img, coords, mode="bilinear", mask=False):
|
||||
"""Wrapper for grid_sample, uses pixel coordinates"""
|
||||
H, W = img.shape[-2:]
|
||||
xgrid, ygrid = coords.split([1, 1], dim=-1)
|
||||
# go to 0,1 then 0,2 then -1,1
|
||||
xgrid = 2 * xgrid / (W - 1) - 1
|
||||
ygrid = 2 * ygrid / (H - 1) - 1
|
||||
|
||||
grid = torch.cat([xgrid, ygrid], dim=-1)
|
||||
img = F.grid_sample(img, grid, align_corners=True)
|
||||
|
||||
if mask:
|
||||
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
|
||||
return img, mask.float()
|
||||
|
||||
return img
|
||||
|
||||
|
||||
class CorrBlock:
|
||||
def __init__(self, fmaps, num_levels=4, radius=4):
|
||||
def __init__(
|
||||
self,
|
||||
fmaps,
|
||||
num_levels=4,
|
||||
radius=4,
|
||||
multiple_track_feats=False,
|
||||
padding_mode="zeros",
|
||||
):
|
||||
B, S, C, H, W = fmaps.shape
|
||||
self.S, self.C, self.H, self.W = S, C, H, W
|
||||
|
||||
self.padding_mode = padding_mode
|
||||
self.num_levels = num_levels
|
||||
self.radius = radius
|
||||
self.fmaps_pyramid = []
|
||||
self.multiple_track_feats = multiple_track_feats
|
||||
|
||||
self.fmaps_pyramid.append(fmaps)
|
||||
for i in range(self.num_levels - 1):
|
||||
@ -292,109 +250,118 @@ class CorrBlock:
|
||||
out_pyramid = []
|
||||
for i in range(self.num_levels):
|
||||
corrs = self.corrs_pyramid[i] # B, S, N, H, W
|
||||
_, _, _, H, W = corrs.shape
|
||||
*_, H, W = corrs.shape
|
||||
|
||||
dx = torch.linspace(-r, r, 2 * r + 1)
|
||||
dy = torch.linspace(-r, r, 2 * r + 1)
|
||||
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
|
||||
coords.device
|
||||
)
|
||||
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device)
|
||||
|
||||
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
|
||||
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
|
||||
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
|
||||
coords_lvl = centroid_lvl + delta_lvl
|
||||
|
||||
corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl)
|
||||
corrs = bilinear_sampler(
|
||||
corrs.reshape(B * S * N, 1, H, W),
|
||||
coords_lvl,
|
||||
padding_mode=self.padding_mode,
|
||||
)
|
||||
corrs = corrs.view(B, S, N, -1)
|
||||
out_pyramid.append(corrs)
|
||||
|
||||
out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
|
||||
return out.contiguous().float()
|
||||
out = out.permute(0, 2, 1, 3).contiguous().view(B * N, S, -1).float()
|
||||
return out
|
||||
|
||||
def corr(self, targets):
|
||||
B, S, N, C = targets.shape
|
||||
if self.multiple_track_feats:
|
||||
targets_split = targets.split(C // self.num_levels, dim=-1)
|
||||
B, S, N, C = targets_split[0].shape
|
||||
|
||||
assert C == self.C
|
||||
assert S == self.S
|
||||
|
||||
fmap1 = targets
|
||||
|
||||
self.corrs_pyramid = []
|
||||
for fmaps in self.fmaps_pyramid:
|
||||
_, _, _, H, W = fmaps.shape
|
||||
fmap2s = fmaps.view(B, S, C, H * W)
|
||||
for i, fmaps in enumerate(self.fmaps_pyramid):
|
||||
*_, H, W = fmaps.shape
|
||||
fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W)
|
||||
if self.multiple_track_feats:
|
||||
fmap1 = targets_split[i]
|
||||
corrs = torch.matmul(fmap1, fmap2s)
|
||||
corrs = corrs.view(B, S, N, H, W)
|
||||
corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W
|
||||
corrs = corrs / torch.sqrt(torch.tensor(C).float())
|
||||
self.corrs_pyramid.append(corrs)
|
||||
|
||||
|
||||
class UpdateFormer(nn.Module):
|
||||
"""
|
||||
Transformer model that updates track estimates.
|
||||
"""
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * num_heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = num_heads
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
|
||||
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
|
||||
self.to_out = nn.Linear(inner_dim, query_dim)
|
||||
|
||||
def forward(self, x, context=None, attn_bias=None):
|
||||
B, N1, C = x.shape
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
|
||||
context = default(context, x)
|
||||
k, v = self.to_kv(context).chunk(2, dim=-1)
|
||||
|
||||
N2 = context.shape[1]
|
||||
k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
|
||||
v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
|
||||
|
||||
sim = (q @ k.transpose(-2, -1)) * self.scale
|
||||
|
||||
if attn_bias is not None:
|
||||
sim = sim + attn_bias
|
||||
attn = sim.softmax(dim=-1)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
|
||||
return self.to_out(x)
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
space_depth=12,
|
||||
time_depth=12,
|
||||
input_dim=320,
|
||||
hidden_size=384,
|
||||
num_heads=8,
|
||||
output_dim=130,
|
||||
hidden_size,
|
||||
num_heads,
|
||||
attn_class: Callable[..., nn.Module] = Attention,
|
||||
mlp_ratio=4.0,
|
||||
add_space_attn=True,
|
||||
**block_kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = 2
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.add_space_attn = add_space_attn
|
||||
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
|
||||
self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
|
||||
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
|
||||
|
||||
self.time_blocks = nn.ModuleList(
|
||||
[
|
||||
AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
|
||||
for _ in range(time_depth)
|
||||
]
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.mlp = Mlp(
|
||||
in_features=hidden_size,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=approx_gelu,
|
||||
drop=0,
|
||||
)
|
||||
|
||||
if add_space_attn:
|
||||
self.space_blocks = nn.ModuleList(
|
||||
[
|
||||
AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
|
||||
for _ in range(space_depth)
|
||||
]
|
||||
def forward(self, x, mask=None):
|
||||
attn_bias = mask
|
||||
if mask is not None:
|
||||
mask = (
|
||||
(mask[:, None] * mask[:, :, None])
|
||||
.unsqueeze(1)
|
||||
.expand(-1, self.attn.num_heads, -1, -1)
|
||||
)
|
||||
assert len(self.time_blocks) >= len(self.space_blocks)
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self):
|
||||
def _basic_init(module):
|
||||
if isinstance(module, nn.Linear):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0)
|
||||
|
||||
self.apply(_basic_init)
|
||||
|
||||
def forward(self, input_tensor):
|
||||
x = self.input_transform(input_tensor)
|
||||
|
||||
j = 0
|
||||
for i in range(len(self.time_blocks)):
|
||||
B, N, T, _ = x.shape
|
||||
x_time = rearrange(x, "b n t c -> (b n) t c", b=B, t=T, n=N)
|
||||
x_time = self.time_blocks[i](x_time)
|
||||
|
||||
x = rearrange(x_time, "(b n) t c -> b n t c ", b=B, t=T, n=N)
|
||||
if self.add_space_attn and (
|
||||
i % (len(self.time_blocks) // len(self.space_blocks)) == 0
|
||||
):
|
||||
x_space = rearrange(x, "b n t c -> (b t) n c ", b=B, t=T, n=N)
|
||||
x_space = self.space_blocks[j](x_space)
|
||||
x = rearrange(x_space, "(b t) n c -> b n t c ", b=B, t=T, n=N)
|
||||
j += 1
|
||||
|
||||
flow = self.flow_head(x)
|
||||
return flow
|
||||
max_neg_value = -torch.finfo(x.dtype).max
|
||||
attn_bias = (~mask) * max_neg_value
|
||||
x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
@ -6,102 +6,74 @@
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
import torch.nn.functional as F
|
||||
|
||||
from cotracker.models.core.cotracker.blocks import (
|
||||
BasicEncoder,
|
||||
CorrBlock,
|
||||
UpdateFormer,
|
||||
)
|
||||
|
||||
from cotracker.models.core.model_utils import meshgrid2d, bilinear_sample2d, smart_cat
|
||||
from cotracker.models.core.model_utils import sample_features4d, sample_features5d
|
||||
from cotracker.models.core.embeddings import (
|
||||
get_2d_embedding,
|
||||
get_1d_sincos_pos_embed_from_grid,
|
||||
get_2d_sincos_pos_embed,
|
||||
)
|
||||
|
||||
from cotracker.models.core.cotracker.blocks import (
|
||||
Mlp,
|
||||
BasicEncoder,
|
||||
AttnBlock,
|
||||
CorrBlock,
|
||||
Attention,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device="cpu"):
|
||||
if grid_size == 1:
|
||||
return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2], device=device)[
|
||||
None, None
|
||||
]
|
||||
|
||||
grid_y, grid_x = meshgrid2d(
|
||||
1, grid_size, grid_size, stack=False, norm=False, device=device
|
||||
)
|
||||
step = interp_shape[1] // 64
|
||||
if grid_center[0] != 0 or grid_center[1] != 0:
|
||||
grid_y = grid_y - grid_size / 2.0
|
||||
grid_x = grid_x - grid_size / 2.0
|
||||
grid_y = step + grid_y.reshape(1, -1) / float(grid_size - 1) * (
|
||||
interp_shape[0] - step * 2
|
||||
)
|
||||
grid_x = step + grid_x.reshape(1, -1) / float(grid_size - 1) * (
|
||||
interp_shape[1] - step * 2
|
||||
)
|
||||
|
||||
grid_y = grid_y + grid_center[0]
|
||||
grid_x = grid_x + grid_center[1]
|
||||
xy = torch.stack([grid_x, grid_y], dim=-1).to(device)
|
||||
return xy
|
||||
|
||||
|
||||
def sample_pos_embed(grid_size, embed_dim, coords):
|
||||
pos_embed = get_2d_sincos_pos_embed(embed_dim=embed_dim, grid_size=grid_size)
|
||||
pos_embed = (
|
||||
torch.from_numpy(pos_embed)
|
||||
.reshape(grid_size[0], grid_size[1], embed_dim)
|
||||
.float()
|
||||
.unsqueeze(0)
|
||||
.to(coords.device)
|
||||
)
|
||||
sampled_pos_embed = bilinear_sample2d(
|
||||
pos_embed.permute(0, 3, 1, 2), coords[:, 0, :, 0], coords[:, 0, :, 1]
|
||||
)
|
||||
return sampled_pos_embed
|
||||
|
||||
|
||||
class CoTracker(nn.Module):
|
||||
class CoTracker2(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
S=8,
|
||||
stride=8,
|
||||
window_len=8,
|
||||
stride=4,
|
||||
add_space_attn=True,
|
||||
num_heads=8,
|
||||
hidden_size=384,
|
||||
space_depth=12,
|
||||
time_depth=12,
|
||||
num_virtual_tracks=64,
|
||||
model_resolution=(384, 512),
|
||||
):
|
||||
super(CoTracker, self).__init__()
|
||||
self.S = S
|
||||
super(CoTracker2, self).__init__()
|
||||
self.window_len = window_len
|
||||
self.stride = stride
|
||||
self.hidden_dim = 256
|
||||
self.latent_dim = latent_dim = 128
|
||||
self.corr_levels = 4
|
||||
self.corr_radius = 3
|
||||
self.latent_dim = 128
|
||||
self.add_space_attn = add_space_attn
|
||||
self.fnet = BasicEncoder(
|
||||
output_dim=self.latent_dim, norm_fn="instance", dropout=0, stride=stride
|
||||
)
|
||||
|
||||
self.updateformer = UpdateFormer(
|
||||
space_depth=space_depth,
|
||||
time_depth=time_depth,
|
||||
input_dim=456,
|
||||
hidden_size=hidden_size,
|
||||
num_heads=num_heads,
|
||||
output_dim=latent_dim + 2,
|
||||
self.fnet = BasicEncoder(output_dim=self.latent_dim)
|
||||
self.num_virtual_tracks = num_virtual_tracks
|
||||
self.model_resolution = model_resolution
|
||||
self.input_dim = 456
|
||||
self.updateformer = EfficientUpdateFormer(
|
||||
space_depth=6,
|
||||
time_depth=6,
|
||||
input_dim=self.input_dim,
|
||||
hidden_size=384,
|
||||
output_dim=self.latent_dim + 2,
|
||||
mlp_ratio=4.0,
|
||||
add_space_attn=add_space_attn,
|
||||
num_virtual_tracks=num_virtual_tracks,
|
||||
)
|
||||
|
||||
time_grid = torch.linspace(0, window_len - 1, window_len).reshape(1, window_len, 1)
|
||||
|
||||
self.register_buffer(
|
||||
"time_emb", get_1d_sincos_pos_embed_from_grid(self.input_dim, time_grid[0])
|
||||
)
|
||||
|
||||
self.register_buffer(
|
||||
"pos_emb",
|
||||
get_2d_sincos_pos_embed(
|
||||
embed_dim=self.input_dim,
|
||||
grid_size=(
|
||||
model_resolution[0] // stride,
|
||||
model_resolution[1] // stride,
|
||||
),
|
||||
),
|
||||
)
|
||||
self.norm = nn.GroupNorm(1, self.latent_dim)
|
||||
self.ffeat_updater = nn.Sequential(
|
||||
self.track_feat_updater = nn.Sequential(
|
||||
nn.Linear(self.latent_dim, self.latent_dim),
|
||||
nn.GELU(),
|
||||
)
|
||||
@ -109,243 +81,423 @@ class CoTracker(nn.Module):
|
||||
nn.Linear(self.latent_dim, 1),
|
||||
)
|
||||
|
||||
def forward_iteration(
|
||||
def forward_window(
|
||||
self,
|
||||
fmaps,
|
||||
coords_init,
|
||||
feat_init=None,
|
||||
vis_init=None,
|
||||
coords,
|
||||
track_feat=None,
|
||||
vis=None,
|
||||
track_mask=None,
|
||||
attention_mask=None,
|
||||
iters=4,
|
||||
):
|
||||
B, S_init, N, D = coords_init.shape
|
||||
assert D == 2
|
||||
assert B == 1
|
||||
# B = batch size
|
||||
# S = number of frames in the window)
|
||||
# N = number of tracks
|
||||
# C = channels of a point feature vector
|
||||
# E = positional embedding size
|
||||
# LRR = local receptive field radius
|
||||
# D = dimension of the transformer input tokens
|
||||
|
||||
B, S, __, H8, W8 = fmaps.shape
|
||||
# track_feat = B S N C
|
||||
# vis = B S N 1
|
||||
# track_mask = B S N 1
|
||||
# attention_mask = B S N
|
||||
|
||||
device = fmaps.device
|
||||
B, S_init, N, __ = track_mask.shape
|
||||
B, S, *_ = fmaps.shape
|
||||
|
||||
if S_init < S:
|
||||
coords = torch.cat(
|
||||
[coords_init, coords_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1
|
||||
)
|
||||
vis_init = torch.cat(
|
||||
[vis_init, vis_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1
|
||||
)
|
||||
else:
|
||||
coords = coords_init.clone()
|
||||
|
||||
fcorr_fn = CorrBlock(
|
||||
fmaps, num_levels=self.corr_levels, radius=self.corr_radius
|
||||
track_mask = F.pad(track_mask, (0, 0, 0, 0, 0, S - S_init), "constant")
|
||||
track_mask_vis = (
|
||||
torch.cat([track_mask, vis], dim=-1).permute(0, 2, 1, 3).reshape(B * N, S, 2)
|
||||
)
|
||||
|
||||
ffeats = feat_init.clone()
|
||||
|
||||
times_ = torch.linspace(0, S - 1, S).reshape(1, S, 1)
|
||||
|
||||
pos_embed = sample_pos_embed(
|
||||
grid_size=(H8, W8),
|
||||
embed_dim=456,
|
||||
coords=coords,
|
||||
corr_block = CorrBlock(
|
||||
fmaps,
|
||||
num_levels=4,
|
||||
radius=3,
|
||||
padding_mode="border",
|
||||
)
|
||||
pos_embed = rearrange(pos_embed, "b e n -> (b n) e").unsqueeze(1)
|
||||
times_embed = (
|
||||
torch.from_numpy(get_1d_sincos_pos_embed_from_grid(456, times_[0]))[None]
|
||||
.repeat(B, 1, 1)
|
||||
.float()
|
||||
.to(device)
|
||||
)
|
||||
coord_predictions = []
|
||||
|
||||
sampled_pos_emb = (
|
||||
sample_features4d(self.pos_emb.repeat(B, 1, 1, 1), coords[:, 0])
|
||||
.reshape(B * N, self.input_dim)
|
||||
.unsqueeze(1)
|
||||
) # B E N -> (B N) 1 E
|
||||
|
||||
coord_preds = []
|
||||
for __ in range(iters):
|
||||
coords = coords.detach()
|
||||
fcorr_fn.corr(ffeats)
|
||||
coords = coords.detach() # B S N 2
|
||||
corr_block.corr(track_feat)
|
||||
|
||||
fcorrs = fcorr_fn.sample(coords) # B, S, N, LRR
|
||||
LRR = fcorrs.shape[3]
|
||||
# Sample correlation features around each point
|
||||
fcorrs = corr_block.sample(coords) # (B N) S LRR
|
||||
|
||||
fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, LRR)
|
||||
flows_ = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
|
||||
# Get the flow embeddings
|
||||
flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
|
||||
flow_emb = get_2d_embedding(flows, 64, cat_coords=True) # N S E
|
||||
|
||||
flows_cat = get_2d_embedding(flows_, 64, cat_coords=True)
|
||||
ffeats_ = ffeats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
|
||||
track_feat_ = track_feat.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
|
||||
|
||||
if track_mask.shape[1] < vis_init.shape[1]:
|
||||
track_mask = torch.cat(
|
||||
[
|
||||
track_mask,
|
||||
torch.zeros_like(track_mask[:, 0]).repeat(
|
||||
1, vis_init.shape[1] - track_mask.shape[1], 1, 1
|
||||
),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
concat = (
|
||||
torch.cat([track_mask, vis_init], dim=2)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(B * N, S, 2)
|
||||
transformer_input = torch.cat([flow_emb, fcorrs, track_feat_, track_mask_vis], dim=2)
|
||||
x = transformer_input + sampled_pos_emb + self.time_emb
|
||||
x = x.view(B, N, S, -1) # (B N) S D -> B N S D
|
||||
|
||||
delta = self.updateformer(
|
||||
x,
|
||||
attention_mask.reshape(B * S, N), # B S N -> (B S) N
|
||||
)
|
||||
|
||||
transformer_input = torch.cat([flows_cat, fcorrs_, ffeats_, concat], dim=2)
|
||||
x = transformer_input + pos_embed + times_embed
|
||||
delta_coords = delta[..., :2].permute(0, 2, 1, 3)
|
||||
coords = coords + delta_coords
|
||||
coord_preds.append(coords * self.stride)
|
||||
|
||||
x = rearrange(x, "(b n) t d -> b n t d", b=B)
|
||||
|
||||
delta = self.updateformer(x)
|
||||
|
||||
delta = rearrange(delta, " b n t d -> (b n) t d")
|
||||
|
||||
delta_coords_ = delta[:, :, :2]
|
||||
delta_feats_ = delta[:, :, 2:]
|
||||
|
||||
delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
|
||||
ffeats_ = ffeats.permute(0, 2, 1, 3).reshape(B * N * S, self.latent_dim)
|
||||
|
||||
ffeats_ = self.ffeat_updater(self.norm(delta_feats_)) + ffeats_
|
||||
|
||||
ffeats = ffeats_.reshape(B, N, S, self.latent_dim).permute(
|
||||
delta_feats_ = delta[..., 2:].reshape(B * N * S, self.latent_dim)
|
||||
track_feat_ = track_feat.permute(0, 2, 1, 3).reshape(B * N * S, self.latent_dim)
|
||||
track_feat_ = self.track_feat_updater(self.norm(delta_feats_)) + track_feat_
|
||||
track_feat = track_feat_.reshape(B, N, S, self.latent_dim).permute(
|
||||
0, 2, 1, 3
|
||||
) # B,S,N,C
|
||||
) # (B N S) C -> B S N C
|
||||
|
||||
coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
|
||||
coord_predictions.append(coords * self.stride)
|
||||
vis_pred = self.vis_predictor(track_feat).reshape(B, S, N)
|
||||
return coord_preds, vis_pred
|
||||
|
||||
vis_e = self.vis_predictor(ffeats.reshape(B * S * N, self.latent_dim)).reshape(
|
||||
B, S, N
|
||||
def get_track_feat(self, fmaps, queried_frames, queried_coords):
|
||||
sample_frames = queried_frames[:, None, :, None]
|
||||
sample_coords = torch.cat(
|
||||
[
|
||||
sample_frames,
|
||||
queried_coords[:, None],
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
return coord_predictions, vis_e, feat_init
|
||||
sample_track_feats = sample_features5d(fmaps, sample_coords)
|
||||
return sample_track_feats
|
||||
|
||||
def forward(self, rgbs, queries, iters=4, feat_init=None, is_train=False):
|
||||
B, T, C, H, W = rgbs.shape
|
||||
def init_video_online_processing(self):
|
||||
self.online_ind = 0
|
||||
self.online_track_feat = None
|
||||
self.online_coords_predicted = None
|
||||
self.online_vis_predicted = None
|
||||
|
||||
def forward(self, video, queries, iters=4, is_train=False, is_online=False):
|
||||
"""Predict tracks
|
||||
|
||||
Args:
|
||||
video (FloatTensor[B, T, 3]): input videos.
|
||||
queries (FloatTensor[B, N, 3]): point queries.
|
||||
iters (int, optional): number of updates. Defaults to 4.
|
||||
is_train (bool, optional): enables training mode. Defaults to False.
|
||||
is_online (bool, optional): enables online mode. Defaults to False. Before enabling, call model.init_video_online_processing().
|
||||
|
||||
Returns:
|
||||
- coords_predicted (FloatTensor[B, T, N, 2]):
|
||||
- vis_predicted (FloatTensor[B, T, N]):
|
||||
- train_data: `None` if `is_train` is false, otherwise:
|
||||
- all_vis_predictions (List[FloatTensor[B, S, N, 1]]):
|
||||
- all_coords_predictions (List[FloatTensor[B, S, N, 2]]):
|
||||
- mask (BoolTensor[B, T, N]):
|
||||
"""
|
||||
B, T, C, H, W = video.shape
|
||||
B, N, __ = queries.shape
|
||||
S = self.window_len
|
||||
device = queries.device
|
||||
|
||||
device = rgbs.device
|
||||
assert B == 1
|
||||
# INIT for the first sequence
|
||||
# We want to sort points by the first frame they are visible to add them to the tensor of tracked points consequtively
|
||||
first_positive_inds = queries[:, :, 0].long()
|
||||
# B = batch size
|
||||
# S = number of frames in the window of the padded video
|
||||
# S_trimmed = actual number of frames in the window
|
||||
# N = number of tracks
|
||||
# C = color channels (3 for RGB)
|
||||
# E = positional embedding size
|
||||
# LRR = local receptive field radius
|
||||
# D = dimension of the transformer input tokens
|
||||
|
||||
__, sort_inds = torch.sort(first_positive_inds[0], dim=0, descending=False)
|
||||
inv_sort_inds = torch.argsort(sort_inds, dim=0)
|
||||
first_positive_sorted_inds = first_positive_inds[0][sort_inds]
|
||||
# video = B T C H W
|
||||
# queries = B N 3
|
||||
# coords_init = B S N 2
|
||||
# vis_init = B S N 1
|
||||
|
||||
assert torch.allclose(
|
||||
first_positive_inds[0], first_positive_inds[0][sort_inds][inv_sort_inds]
|
||||
assert S >= 2 # A tracker needs at least two frames to track something
|
||||
if is_online:
|
||||
assert T <= S, "Online mode: video chunk must be <= window size."
|
||||
assert self.online_ind is not None, "Call model.init_video_online_processing() first."
|
||||
assert not is_train, "Training not supported in online mode."
|
||||
step = S // 2 # How much the sliding window moves at every step
|
||||
video = 2 * (video / 255.0) - 1.0
|
||||
|
||||
# The first channel is the frame number
|
||||
# The rest are the coordinates of points we want to track
|
||||
queried_frames = queries[:, :, 0].long()
|
||||
|
||||
queried_coords = queries[..., 1:]
|
||||
queried_coords = queried_coords / self.stride
|
||||
|
||||
# We store our predictions here
|
||||
coords_predicted = torch.zeros((B, T, N, 2), device=device)
|
||||
vis_predicted = torch.zeros((B, T, N), device=device)
|
||||
if is_online:
|
||||
if self.online_coords_predicted is None:
|
||||
# Init online predictions with zeros
|
||||
self.online_coords_predicted = coords_predicted
|
||||
self.online_vis_predicted = vis_predicted
|
||||
else:
|
||||
# Pad online predictions with zeros for the current window
|
||||
pad = min(step, T - step)
|
||||
coords_predicted = F.pad(
|
||||
self.online_coords_predicted, (0, 0, 0, 0, 0, pad), "constant"
|
||||
)
|
||||
vis_predicted = F.pad(self.online_vis_predicted, (0, 0, 0, pad), "constant")
|
||||
all_coords_predictions, all_vis_predictions = [], []
|
||||
|
||||
# Pad the video so that an integer number of sliding windows fit into it
|
||||
# TODO: we may drop this requirement because the transformer should not care
|
||||
# TODO: pad the features instead of the video
|
||||
pad = S - T if is_online else (S - T % S) % S # We don't want to pad if T % S == 0
|
||||
video = F.pad(video.reshape(B, 1, T, C * H * W), (0, 0, 0, pad), "replicate").reshape(
|
||||
B, -1, C, H, W
|
||||
)
|
||||
|
||||
coords_init = queries[:, :, 1:].reshape(B, 1, N, 2).repeat(
|
||||
1, self.S, 1, 1
|
||||
) / float(self.stride)
|
||||
# Compute convolutional features for the video or for the current chunk in case of online mode
|
||||
fmaps = self.fnet(video.reshape(-1, C, H, W)).reshape(
|
||||
B, -1, self.latent_dim, H // self.stride, W // self.stride
|
||||
)
|
||||
|
||||
rgbs = 2 * (rgbs / 255.0) - 1.0
|
||||
# We compute track features
|
||||
track_feat = self.get_track_feat(
|
||||
fmaps,
|
||||
queried_frames - self.online_ind if is_online else queried_frames,
|
||||
queried_coords,
|
||||
).repeat(1, S, 1, 1)
|
||||
if is_online:
|
||||
# We update track features for the current window
|
||||
sample_frames = queried_frames[:, None, :, None] # B 1 N 1
|
||||
left = 0 if self.online_ind == 0 else self.online_ind + step
|
||||
right = self.online_ind + S
|
||||
sample_mask = (sample_frames >= left) & (sample_frames < right)
|
||||
if self.online_track_feat is None:
|
||||
self.online_track_feat = torch.zeros_like(track_feat, device=device)
|
||||
self.online_track_feat += track_feat * sample_mask
|
||||
track_feat = self.online_track_feat.clone()
|
||||
# We process ((num_windows - 1) * step + S) frames in total, so there are
|
||||
# (ceil((T - S) / step) + 1) windows
|
||||
num_windows = (T - S + step - 1) // step + 1
|
||||
# We process only the current video chunk in the online mode
|
||||
indices = [self.online_ind] if is_online else range(0, step * num_windows, step)
|
||||
|
||||
traj_e = torch.zeros((B, T, N, 2), device=device)
|
||||
vis_e = torch.zeros((B, T, N), device=device)
|
||||
|
||||
ind_array = torch.arange(T, device=device)
|
||||
ind_array = ind_array[None, :, None].repeat(B, 1, N)
|
||||
|
||||
track_mask = (ind_array >= first_positive_inds[:, None, :]).unsqueeze(-1)
|
||||
# these are logits, so we initialize visibility with something that would give a value close to 1 after softmax
|
||||
vis_init = torch.ones((B, self.S, N, 1), device=device).float() * 10
|
||||
|
||||
ind = 0
|
||||
|
||||
track_mask_ = track_mask[:, :, sort_inds].clone()
|
||||
coords_init_ = coords_init[:, :, sort_inds].clone()
|
||||
vis_init_ = vis_init[:, :, sort_inds].clone()
|
||||
|
||||
prev_wind_idx = 0
|
||||
fmaps_ = None
|
||||
vis_predictions = []
|
||||
coord_predictions = []
|
||||
wind_inds = []
|
||||
while ind < T - self.S // 2:
|
||||
rgbs_seq = rgbs[:, ind : ind + self.S]
|
||||
|
||||
S = S_local = rgbs_seq.shape[1]
|
||||
if S < self.S:
|
||||
rgbs_seq = torch.cat(
|
||||
[rgbs_seq, rgbs_seq[:, -1, None].repeat(1, self.S - S, 1, 1, 1)],
|
||||
dim=1,
|
||||
coords_init = queried_coords.reshape(B, 1, N, 2).expand(B, S, N, 2).float()
|
||||
vis_init = torch.ones((B, S, N, 1), device=device).float() * 10
|
||||
for ind in indices:
|
||||
# We copy over coords and vis for tracks that are queried
|
||||
# by the end of the previous window, which is ind + overlap
|
||||
if ind > 0:
|
||||
overlap = S - step
|
||||
copy_over = (queried_frames < ind + overlap)[:, None, :, None] # B 1 N 1
|
||||
coords_prev = torch.nn.functional.pad(
|
||||
coords_predicted[:, ind : ind + overlap] / self.stride,
|
||||
(0, 0, 0, 0, 0, step),
|
||||
"replicate",
|
||||
) # B S N 2
|
||||
vis_prev = torch.nn.functional.pad(
|
||||
vis_predicted[:, ind : ind + overlap, :, None].clone(),
|
||||
(0, 0, 0, 0, 0, step),
|
||||
"replicate",
|
||||
) # B S N 1
|
||||
coords_init = torch.where(
|
||||
copy_over.expand_as(coords_init), coords_prev, coords_init
|
||||
)
|
||||
S = rgbs_seq.shape[1]
|
||||
rgbs_ = rgbs_seq.reshape(B * S, C, H, W)
|
||||
vis_init = torch.where(copy_over.expand_as(vis_init), vis_prev, vis_init)
|
||||
|
||||
if fmaps_ is None:
|
||||
fmaps_ = self.fnet(rgbs_)
|
||||
else:
|
||||
fmaps_ = torch.cat(
|
||||
[fmaps_[self.S // 2 :], self.fnet(rgbs_[self.S // 2 :])], dim=0
|
||||
)
|
||||
fmaps = fmaps_.reshape(
|
||||
B, S, self.latent_dim, H // self.stride, W // self.stride
|
||||
)
|
||||
# The attention mask is 1 for the spatio-temporal points within
|
||||
# a track which is updated in the current window
|
||||
attention_mask = (queried_frames < ind + S).reshape(B, 1, N).repeat(1, S, 1) # B S N
|
||||
|
||||
curr_wind_points = torch.nonzero(first_positive_sorted_inds < ind + self.S)
|
||||
if curr_wind_points.shape[0] == 0:
|
||||
ind = ind + self.S // 2
|
||||
continue
|
||||
wind_idx = curr_wind_points[-1] + 1
|
||||
# The track mask is 1 for the spatio-temporal points that actually
|
||||
# need updating: only after begin queried, and not if contained
|
||||
# in a previous window
|
||||
track_mask = (
|
||||
queried_frames[:, None, :, None]
|
||||
<= torch.arange(ind, ind + S, device=device)[None, :, None, None]
|
||||
).contiguous() # B S N 1
|
||||
|
||||
if wind_idx - prev_wind_idx > 0:
|
||||
fmaps_sample = fmaps[
|
||||
:, first_positive_sorted_inds[prev_wind_idx:wind_idx] - ind
|
||||
]
|
||||
if ind > 0:
|
||||
track_mask[:, :overlap, :, :] = False
|
||||
|
||||
feat_init_ = bilinear_sample2d(
|
||||
fmaps_sample,
|
||||
coords_init_[:, 0, prev_wind_idx:wind_idx, 0],
|
||||
coords_init_[:, 0, prev_wind_idx:wind_idx, 1],
|
||||
).permute(0, 2, 1)
|
||||
|
||||
feat_init_ = feat_init_.unsqueeze(1).repeat(1, self.S, 1, 1)
|
||||
feat_init = smart_cat(feat_init, feat_init_, dim=2)
|
||||
|
||||
if prev_wind_idx > 0:
|
||||
new_coords = coords[-1][:, self.S // 2 :] / float(self.stride)
|
||||
|
||||
coords_init_[:, : self.S // 2, :prev_wind_idx] = new_coords
|
||||
coords_init_[:, self.S // 2 :, :prev_wind_idx] = new_coords[
|
||||
:, -1
|
||||
].repeat(1, self.S // 2, 1, 1)
|
||||
|
||||
new_vis = vis[:, self.S // 2 :].unsqueeze(-1)
|
||||
vis_init_[:, : self.S // 2, :prev_wind_idx] = new_vis
|
||||
vis_init_[:, self.S // 2 :, :prev_wind_idx] = new_vis[:, -1].repeat(
|
||||
1, self.S // 2, 1, 1
|
||||
)
|
||||
|
||||
coords, vis, __ = self.forward_iteration(
|
||||
fmaps=fmaps,
|
||||
coords_init=coords_init_[:, :, :wind_idx],
|
||||
feat_init=feat_init[:, :, :wind_idx],
|
||||
vis_init=vis_init_[:, :, :wind_idx],
|
||||
track_mask=track_mask_[:, ind : ind + self.S, :wind_idx],
|
||||
# Predict the coordinates and visibility for the current window
|
||||
coords, vis = self.forward_window(
|
||||
fmaps=fmaps if is_online else fmaps[:, ind : ind + S],
|
||||
coords=coords_init,
|
||||
track_feat=attention_mask.unsqueeze(-1) * track_feat,
|
||||
vis=vis_init,
|
||||
track_mask=track_mask,
|
||||
attention_mask=attention_mask,
|
||||
iters=iters,
|
||||
)
|
||||
|
||||
S_trimmed = T if is_online else min(T - ind, S) # accounts for last window duration
|
||||
coords_predicted[:, ind : ind + S] = coords[-1][:, :S_trimmed]
|
||||
vis_predicted[:, ind : ind + S] = vis[:, :S_trimmed]
|
||||
if is_train:
|
||||
vis_predictions.append(torch.sigmoid(vis[:, :S_local]))
|
||||
coord_predictions.append([coord[:, :S_local] for coord in coords])
|
||||
wind_inds.append(wind_idx)
|
||||
all_coords_predictions.append([coord[:, :S_trimmed] for coord in coords])
|
||||
all_vis_predictions.append(torch.sigmoid(vis[:, :S_trimmed]))
|
||||
|
||||
traj_e[:, ind : ind + self.S, :wind_idx] = coords[-1][:, :S_local]
|
||||
vis_e[:, ind : ind + self.S, :wind_idx] = vis[:, :S_local]
|
||||
if is_online:
|
||||
self.online_ind += step
|
||||
self.online_coords_predicted = coords_predicted
|
||||
self.online_vis_predicted = vis_predicted
|
||||
vis_predicted = torch.sigmoid(vis_predicted)
|
||||
|
||||
track_mask_[:, : ind + self.S, :wind_idx] = 0.0
|
||||
ind = ind + self.S // 2
|
||||
if is_train:
|
||||
mask = queried_frames[:, None] <= torch.arange(0, T, device=device)[None, :, None]
|
||||
train_data = (all_coords_predictions, all_vis_predictions, mask)
|
||||
else:
|
||||
train_data = None
|
||||
|
||||
prev_wind_idx = wind_idx
|
||||
return coords_predicted, vis_predicted, train_data
|
||||
|
||||
traj_e = traj_e[:, :, inv_sort_inds]
|
||||
vis_e = vis_e[:, :, inv_sort_inds]
|
||||
|
||||
vis_e = torch.sigmoid(vis_e)
|
||||
class EfficientUpdateFormer(nn.Module):
|
||||
"""
|
||||
Transformer model that updates track estimates.
|
||||
"""
|
||||
|
||||
train_data = (
|
||||
(vis_predictions, coord_predictions, wind_inds, sort_inds)
|
||||
if is_train
|
||||
else None
|
||||
def __init__(
|
||||
self,
|
||||
space_depth=6,
|
||||
time_depth=6,
|
||||
input_dim=320,
|
||||
hidden_size=384,
|
||||
num_heads=8,
|
||||
output_dim=130,
|
||||
mlp_ratio=4.0,
|
||||
add_space_attn=True,
|
||||
num_virtual_tracks=64,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = 2
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.add_space_attn = add_space_attn
|
||||
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
|
||||
self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
|
||||
self.num_virtual_tracks = num_virtual_tracks
|
||||
self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
|
||||
self.time_blocks = nn.ModuleList(
|
||||
[
|
||||
AttnBlock(
|
||||
hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
attn_class=Attention,
|
||||
)
|
||||
for _ in range(time_depth)
|
||||
]
|
||||
)
|
||||
return traj_e, feat_init, vis_e, train_data
|
||||
|
||||
if add_space_attn:
|
||||
self.space_virtual_blocks = nn.ModuleList(
|
||||
[
|
||||
AttnBlock(
|
||||
hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
attn_class=Attention,
|
||||
)
|
||||
for _ in range(space_depth)
|
||||
]
|
||||
)
|
||||
self.space_point2virtual_blocks = nn.ModuleList(
|
||||
[
|
||||
CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio)
|
||||
for _ in range(space_depth)
|
||||
]
|
||||
)
|
||||
self.space_virtual2point_blocks = nn.ModuleList(
|
||||
[
|
||||
CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio)
|
||||
for _ in range(space_depth)
|
||||
]
|
||||
)
|
||||
assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self):
|
||||
def _basic_init(module):
|
||||
if isinstance(module, nn.Linear):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0)
|
||||
|
||||
self.apply(_basic_init)
|
||||
|
||||
def forward(self, input_tensor, mask=None):
|
||||
tokens = self.input_transform(input_tensor)
|
||||
B, _, T, _ = tokens.shape
|
||||
virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
|
||||
tokens = torch.cat([tokens, virtual_tokens], dim=1)
|
||||
_, N, _, _ = tokens.shape
|
||||
|
||||
j = 0
|
||||
for i in range(len(self.time_blocks)):
|
||||
time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
|
||||
time_tokens = self.time_blocks[i](time_tokens)
|
||||
|
||||
tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
|
||||
if self.add_space_attn and (
|
||||
i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0
|
||||
):
|
||||
space_tokens = (
|
||||
tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
|
||||
) # B N T C -> (B T) N C
|
||||
point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
|
||||
virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
|
||||
|
||||
virtual_tokens = self.space_virtual2point_blocks[j](
|
||||
virtual_tokens, point_tokens, mask=mask
|
||||
)
|
||||
virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
|
||||
point_tokens = self.space_point2virtual_blocks[j](
|
||||
point_tokens, virtual_tokens, mask=mask
|
||||
)
|
||||
space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
|
||||
tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C
|
||||
j += 1
|
||||
tokens = tokens[:, : N - self.num_virtual_tracks]
|
||||
flow = self.flow_head(tokens)
|
||||
return flow
|
||||
|
||||
|
||||
class CrossAttnBlock(nn.Module):
|
||||
def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
|
||||
super().__init__()
|
||||
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.norm_context = nn.LayerNorm(hidden_size)
|
||||
self.cross_attn = Attention(
|
||||
hidden_size, context_dim=context_dim, num_heads=num_heads, qkv_bias=True, **block_kwargs
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.mlp = Mlp(
|
||||
in_features=hidden_size,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=approx_gelu,
|
||||
drop=0,
|
||||
)
|
||||
|
||||
def forward(self, x, context, mask=None):
|
||||
if mask is not None:
|
||||
if mask.shape[1] == x.shape[1]:
|
||||
mask = mask[:, None, :, None].expand(
|
||||
-1, self.cross_attn.heads, -1, context.shape[1]
|
||||
)
|
||||
else:
|
||||
mask = mask[:, None, None].expand(-1, self.cross_attn.heads, x.shape[1], -1)
|
||||
|
||||
max_neg_value = -torch.finfo(x.dtype).max
|
||||
attn_bias = (~mask) * max_neg_value
|
||||
x = x + self.cross_attn(
|
||||
self.norm1(x), context=self.norm_context(context), attn_bias=attn_bias
|
||||
)
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
@ -4,67 +4,98 @@
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Tuple, Union
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
||||
def get_2d_sincos_pos_embed(
|
||||
embed_dim: int, grid_size: Union[int, Tuple[int, int]]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
grid_size: int of the grid height and width
|
||||
return:
|
||||
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
|
||||
It is a wrapper of get_2d_sincos_pos_embed_from_grid.
|
||||
Args:
|
||||
- embed_dim: The embedding dimension.
|
||||
- grid_size: The grid size.
|
||||
Returns:
|
||||
- pos_embed: The generated 2D positional embedding.
|
||||
"""
|
||||
if isinstance(grid_size, tuple):
|
||||
grid_size_h, grid_size_w = grid_size
|
||||
else:
|
||||
grid_size_h = grid_size_w = grid_size
|
||||
grid_h = np.arange(grid_size_h, dtype=np.float32)
|
||||
grid_w = np.arange(grid_size_w, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
|
||||
grid_h = torch.arange(grid_size_h, dtype=torch.float)
|
||||
grid_w = torch.arange(grid_size_w, dtype=torch.float)
|
||||
grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
|
||||
grid = torch.stack(grid, dim=0)
|
||||
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
if cls_token and extra_tokens > 0:
|
||||
pos_embed = np.concatenate(
|
||||
[np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
|
||||
)
|
||||
return pos_embed
|
||||
return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
def get_2d_sincos_pos_embed_from_grid(
|
||||
embed_dim: int, grid: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function generates a 2D positional embedding from a given grid using sine and cosine functions.
|
||||
|
||||
Args:
|
||||
- embed_dim: The embedding dimension.
|
||||
- grid: The grid to generate the embedding from.
|
||||
|
||||
Returns:
|
||||
- emb: The generated 2D positional embedding.
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
def get_1d_sincos_pos_embed_from_grid(
|
||||
embed_dim: int, pos: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (M,)
|
||||
out: (M, D)
|
||||
This function generates a 1D positional embedding from a given grid using sine and cosine functions.
|
||||
|
||||
Args:
|
||||
- embed_dim: The embedding dimension.
|
||||
- pos: The position to generate the embedding from.
|
||||
|
||||
Returns:
|
||||
- emb: The generated 1D positional embedding.
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
||||
omega = torch.arange(embed_dim // 2, dtype=torch.double)
|
||||
omega /= embed_dim / 2.0
|
||||
omega = 1.0 / 10000 ** omega # (D/2,)
|
||||
omega = 1.0 / 10000**omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
||||
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
emb_sin = torch.sin(out) # (M, D/2)
|
||||
emb_cos = torch.cos(out) # (M, D/2)
|
||||
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
return emb
|
||||
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
||||
return emb[None].float()
|
||||
|
||||
|
||||
def get_2d_embedding(xy, C, cat_coords=True):
|
||||
def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
|
||||
|
||||
Args:
|
||||
- xy: The coordinates to generate the embedding from.
|
||||
- C: The size of the embedding.
|
||||
- cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
|
||||
|
||||
Returns:
|
||||
- pe: The generated 2D positional embedding.
|
||||
"""
|
||||
B, N, D = xy.shape
|
||||
assert D == 2
|
||||
|
||||
@ -83,72 +114,7 @@ def get_2d_embedding(xy, C, cat_coords=True):
|
||||
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
||||
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
||||
|
||||
pe = torch.cat([pe_x, pe_y], dim=2) # B, N, C*3
|
||||
pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3)
|
||||
if cat_coords:
|
||||
pe = torch.cat([xy, pe], dim=2) # B, N, C*3+3
|
||||
return pe
|
||||
|
||||
|
||||
def get_3d_embedding(xyz, C, cat_coords=True):
|
||||
B, N, D = xyz.shape
|
||||
assert D == 3
|
||||
|
||||
x = xyz[:, :, 0:1]
|
||||
y = xyz[:, :, 1:2]
|
||||
z = xyz[:, :, 2:3]
|
||||
div_term = (
|
||||
torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (1000.0 / C)
|
||||
).reshape(1, 1, int(C / 2))
|
||||
|
||||
pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
|
||||
pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
|
||||
pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
|
||||
|
||||
pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
||||
pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
||||
|
||||
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
||||
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
||||
|
||||
pe_z[:, :, 0::2] = torch.sin(z * div_term)
|
||||
pe_z[:, :, 1::2] = torch.cos(z * div_term)
|
||||
|
||||
pe = torch.cat([pe_x, pe_y, pe_z], dim=2) # B, N, C*3
|
||||
if cat_coords:
|
||||
pe = torch.cat([pe, xyz], dim=2) # B, N, C*3+3
|
||||
return pe
|
||||
|
||||
|
||||
def get_4d_embedding(xyzw, C, cat_coords=True):
|
||||
B, N, D = xyzw.shape
|
||||
assert D == 4
|
||||
|
||||
x = xyzw[:, :, 0:1]
|
||||
y = xyzw[:, :, 1:2]
|
||||
z = xyzw[:, :, 2:3]
|
||||
w = xyzw[:, :, 3:4]
|
||||
div_term = (
|
||||
torch.arange(0, C, 2, device=xyzw.device, dtype=torch.float32) * (1000.0 / C)
|
||||
).reshape(1, 1, int(C / 2))
|
||||
|
||||
pe_x = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
|
||||
pe_y = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
|
||||
pe_z = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
|
||||
pe_w = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
|
||||
|
||||
pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
||||
pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
||||
|
||||
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
||||
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
||||
|
||||
pe_z[:, :, 0::2] = torch.sin(z * div_term)
|
||||
pe_z[:, :, 1::2] = torch.cos(z * div_term)
|
||||
|
||||
pe_w[:, :, 0::2] = torch.sin(w * div_term)
|
||||
pe_w[:, :, 1::2] = torch.cos(w * div_term)
|
||||
|
||||
pe = torch.cat([pe_x, pe_y, pe_z, pe_w], dim=2) # B, N, C*3
|
||||
if cat_coords:
|
||||
pe = torch.cat([pe, xyzw], dim=2) # B, N, C*3+3
|
||||
pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3)
|
||||
return pe
|
||||
|
@ -5,6 +5,8 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, Tuple
|
||||
|
||||
EPS = 1e-6
|
||||
|
||||
@ -15,155 +17,240 @@ def smart_cat(tensor1, tensor2, dim):
|
||||
return torch.cat([tensor1, tensor2], dim=dim)
|
||||
|
||||
|
||||
def normalize_single(d):
|
||||
# d is a whatever shape torch tensor
|
||||
dmin = torch.min(d)
|
||||
dmax = torch.max(d)
|
||||
d = (d - dmin) / (EPS + (dmax - dmin))
|
||||
return d
|
||||
def get_points_on_a_grid(
|
||||
size: int,
|
||||
extent: Tuple[float, ...],
|
||||
center: Optional[Tuple[float, ...]] = None,
|
||||
device: Optional[torch.device] = torch.device("cpu"),
|
||||
):
|
||||
r"""Get a grid of points covering a rectangular region
|
||||
|
||||
`get_points_on_a_grid(size, extent)` generates a :attr:`size` by
|
||||
:attr:`size` grid fo points distributed to cover a rectangular area
|
||||
specified by `extent`.
|
||||
|
||||
The `extent` is a pair of integer :math:`(H,W)` specifying the height
|
||||
and width of the rectangle.
|
||||
|
||||
Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)`
|
||||
specifying the vertical and horizontal center coordinates. The center
|
||||
defaults to the middle of the extent.
|
||||
|
||||
Points are distributed uniformly within the rectangle leaving a margin
|
||||
:math:`m=W/64` from the border.
|
||||
|
||||
It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of
|
||||
points :math:`P_{ij}=(x_i, y_i)` where
|
||||
|
||||
.. math::
|
||||
P_{ij} = \left(
|
||||
c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~
|
||||
c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i
|
||||
\right)
|
||||
|
||||
Points are returned in row-major order.
|
||||
|
||||
Args:
|
||||
size (int): grid size.
|
||||
extent (tuple): height and with of the grid extent.
|
||||
center (tuple, optional): grid center.
|
||||
device (str, optional): Defaults to `"cpu"`.
|
||||
|
||||
Returns:
|
||||
Tensor: grid.
|
||||
"""
|
||||
if size == 1:
|
||||
return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]
|
||||
|
||||
if center is None:
|
||||
center = [extent[0] / 2, extent[1] / 2]
|
||||
|
||||
margin = extent[1] / 64
|
||||
range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
|
||||
range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
|
||||
grid_y, grid_x = torch.meshgrid(
|
||||
torch.linspace(*range_y, size, device=device),
|
||||
torch.linspace(*range_x, size, device=device),
|
||||
indexing="ij",
|
||||
)
|
||||
return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
|
||||
|
||||
|
||||
def normalize(d):
|
||||
# d is B x whatever. normalize within each element of the batch
|
||||
out = torch.zeros(d.size())
|
||||
if d.is_cuda:
|
||||
out = out.cuda()
|
||||
B = list(d.size())[0]
|
||||
for b in list(range(B)):
|
||||
out[b] = normalize_single(d[b])
|
||||
return out
|
||||
def reduce_masked_mean(input, mask, dim=None, keepdim=False):
|
||||
r"""Masked mean
|
||||
|
||||
`reduce_masked_mean(x, mask)` computes the mean of a tensor :attr:`input`
|
||||
over a mask :attr:`mask`, returning
|
||||
|
||||
def meshgrid2d(B, Y, X, stack=False, norm=False, device="cpu"):
|
||||
# returns a meshgrid sized B x Y x X
|
||||
.. math::
|
||||
\text{output} =
|
||||
\frac
|
||||
{\sum_{i=1}^N \text{input}_i \cdot \text{mask}_i}
|
||||
{\epsilon + \sum_{i=1}^N \text{mask}_i}
|
||||
|
||||
grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device))
|
||||
grid_y = torch.reshape(grid_y, [1, Y, 1])
|
||||
grid_y = grid_y.repeat(B, 1, X)
|
||||
where :math:`N` is the number of elements in :attr:`input` and
|
||||
:attr:`mask`, and :math:`\epsilon` is a small constant to avoid
|
||||
division by zero.
|
||||
|
||||
grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device))
|
||||
grid_x = torch.reshape(grid_x, [1, 1, X])
|
||||
grid_x = grid_x.repeat(B, Y, 1)
|
||||
`reduced_masked_mean(x, mask, dim)` computes the mean of a tensor
|
||||
:attr:`input` over a mask :attr:`mask` along a dimension :attr:`dim`.
|
||||
Optionally, the dimension can be kept in the output by setting
|
||||
:attr:`keepdim` to `True`. Tensor :attr:`mask` must be broadcastable to
|
||||
the same dimension as :attr:`input`.
|
||||
|
||||
if stack:
|
||||
# note we stack in xy order
|
||||
# (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
|
||||
grid = torch.stack([grid_x, grid_y], dim=-1)
|
||||
return grid
|
||||
else:
|
||||
return grid_y, grid_x
|
||||
The interface is similar to `torch.mean()`.
|
||||
|
||||
Args:
|
||||
inout (Tensor): input tensor.
|
||||
mask (Tensor): mask.
|
||||
dim (int, optional): Dimension to sum over. Defaults to None.
|
||||
keepdim (bool, optional): Keep the summed dimension. Defaults to False.
|
||||
|
||||
Returns:
|
||||
Tensor: mean tensor.
|
||||
"""
|
||||
|
||||
mask = mask.expand_as(input)
|
||||
|
||||
prod = input * mask
|
||||
|
||||
def reduce_masked_mean(x, mask, dim=None, keepdim=False):
|
||||
# x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting
|
||||
# returns shape-1
|
||||
# axis can be a list of axes
|
||||
for (a, b) in zip(x.size(), mask.size()):
|
||||
assert a == b # some shape mismatch!
|
||||
prod = x * mask
|
||||
if dim is None:
|
||||
numer = torch.sum(prod)
|
||||
denom = EPS + torch.sum(mask)
|
||||
denom = torch.sum(mask)
|
||||
else:
|
||||
numer = torch.sum(prod, dim=dim, keepdim=keepdim)
|
||||
denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim)
|
||||
denom = torch.sum(mask, dim=dim, keepdim=keepdim)
|
||||
|
||||
mean = numer / denom
|
||||
mean = numer / (EPS + denom)
|
||||
return mean
|
||||
|
||||
|
||||
def bilinear_sample2d(im, x, y, return_inbounds=False):
|
||||
# x and y are each B, N
|
||||
# output is B, C, N
|
||||
if len(im.shape) == 5:
|
||||
B, N, C, H, W = list(im.shape)
|
||||
else:
|
||||
B, C, H, W = list(im.shape)
|
||||
N = list(x.shape)[1]
|
||||
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
|
||||
r"""Sample a tensor using bilinear interpolation
|
||||
|
||||
x = x.float()
|
||||
y = y.float()
|
||||
H_f = torch.tensor(H, dtype=torch.float32)
|
||||
W_f = torch.tensor(W, dtype=torch.float32)
|
||||
`bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
|
||||
coordinates :attr:`coords` using bilinear interpolation. It is the same
|
||||
as `torch.nn.functional.grid_sample()` but with a different coordinate
|
||||
convention.
|
||||
|
||||
# inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()
|
||||
The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
|
||||
:math:`B` is the batch size, :math:`C` is the number of channels,
|
||||
:math:`H` is the height of the image, and :math:`W` is the width of the
|
||||
image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
|
||||
interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
|
||||
|
||||
max_y = (H_f - 1).int()
|
||||
max_x = (W_f - 1).int()
|
||||
Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
|
||||
in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
|
||||
that in this case the order of the components is slightly different
|
||||
from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
|
||||
|
||||
x0 = torch.floor(x).int()
|
||||
x1 = x0 + 1
|
||||
y0 = torch.floor(y).int()
|
||||
y1 = y0 + 1
|
||||
If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
|
||||
in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
|
||||
left-most image pixel :math:`W-1` to the center of the right-most
|
||||
pixel.
|
||||
|
||||
x0_clip = torch.clamp(x0, 0, max_x)
|
||||
x1_clip = torch.clamp(x1, 0, max_x)
|
||||
y0_clip = torch.clamp(y0, 0, max_y)
|
||||
y1_clip = torch.clamp(y1, 0, max_y)
|
||||
dim2 = W
|
||||
dim1 = W * H
|
||||
If `align_corners` is `False`, the coordinate :math:`x` is assumed to
|
||||
be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
|
||||
the left-most pixel :math:`W` to the right edge of the right-most
|
||||
pixel.
|
||||
|
||||
base = torch.arange(0, B, dtype=torch.int64, device=x.device) * dim1
|
||||
base = torch.reshape(base, [B, 1]).repeat([1, N])
|
||||
Similar conventions apply to the :math:`y` for the range
|
||||
:math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
|
||||
:math:`[0,T-1]` and :math:`[0,T]`.
|
||||
|
||||
base_y0 = base + y0_clip * dim2
|
||||
base_y1 = base + y1_clip * dim2
|
||||
Args:
|
||||
input (Tensor): batch of input images.
|
||||
coords (Tensor): batch of coordinates.
|
||||
align_corners (bool, optional): Coordinate convention. Defaults to `True`.
|
||||
padding_mode (str, optional): Padding mode. Defaults to `"border"`.
|
||||
|
||||
idx_y0_x0 = base_y0 + x0_clip
|
||||
idx_y0_x1 = base_y0 + x1_clip
|
||||
idx_y1_x0 = base_y1 + x0_clip
|
||||
idx_y1_x1 = base_y1 + x1_clip
|
||||
Returns:
|
||||
Tensor: sampled points.
|
||||
"""
|
||||
|
||||
# use the indices to lookup pixels in the flat image
|
||||
# im is B x C x H x W
|
||||
# move C out to last dim
|
||||
if len(im.shape) == 5:
|
||||
im_flat = (im.permute(0, 3, 4, 1, 2)).reshape(B * H * W, N, C)
|
||||
i_y0_x0 = torch.diagonal(im_flat[idx_y0_x0.long()], dim1=1, dim2=2).permute(
|
||||
0, 2, 1
|
||||
)
|
||||
i_y0_x1 = torch.diagonal(im_flat[idx_y0_x1.long()], dim1=1, dim2=2).permute(
|
||||
0, 2, 1
|
||||
)
|
||||
i_y1_x0 = torch.diagonal(im_flat[idx_y1_x0.long()], dim1=1, dim2=2).permute(
|
||||
0, 2, 1
|
||||
)
|
||||
i_y1_x1 = torch.diagonal(im_flat[idx_y1_x1.long()], dim1=1, dim2=2).permute(
|
||||
0, 2, 1
|
||||
sizes = input.shape[2:]
|
||||
|
||||
assert len(sizes) in [2, 3]
|
||||
|
||||
if len(sizes) == 3:
|
||||
# t x y -> x y t to match dimensions T H W in grid_sample
|
||||
coords = coords[..., [1, 2, 0]]
|
||||
|
||||
if align_corners:
|
||||
coords = coords * torch.tensor(
|
||||
[2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
|
||||
)
|
||||
else:
|
||||
im_flat = (im.permute(0, 2, 3, 1)).reshape(B * H * W, C)
|
||||
i_y0_x0 = im_flat[idx_y0_x0.long()]
|
||||
i_y0_x1 = im_flat[idx_y0_x1.long()]
|
||||
i_y1_x0 = im_flat[idx_y1_x0.long()]
|
||||
i_y1_x1 = im_flat[idx_y1_x1.long()]
|
||||
coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device)
|
||||
|
||||
# Finally calculate interpolated values.
|
||||
x0_f = x0.float()
|
||||
x1_f = x1.float()
|
||||
y0_f = y0.float()
|
||||
y1_f = y1.float()
|
||||
coords -= 1
|
||||
|
||||
w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)
|
||||
w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)
|
||||
w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)
|
||||
w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)
|
||||
return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
|
||||
|
||||
output = (
|
||||
w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1
|
||||
)
|
||||
# output is B*N x C
|
||||
output = output.view(B, -1, C)
|
||||
output = output.permute(0, 2, 1)
|
||||
# output is B x C x N
|
||||
|
||||
if return_inbounds:
|
||||
x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
|
||||
y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
|
||||
inbounds = (x_valid & y_valid).float()
|
||||
inbounds = inbounds.reshape(
|
||||
B, N
|
||||
) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
|
||||
return output, inbounds
|
||||
def sample_features4d(input, coords):
|
||||
r"""Sample spatial features
|
||||
|
||||
return output # B, C, N
|
||||
`sample_features4d(input, coords)` samples the spatial features
|
||||
:attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
|
||||
|
||||
The field is sampled at coordinates :attr:`coords` using bilinear
|
||||
interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
|
||||
3)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
|
||||
same convention as :func:`bilinear_sampler` with `align_corners=True`.
|
||||
|
||||
The output tensor has one feature per point, and has shape :math:`(B,
|
||||
R, C)`.
|
||||
|
||||
Args:
|
||||
input (Tensor): spatial features.
|
||||
coords (Tensor): points.
|
||||
|
||||
Returns:
|
||||
Tensor: sampled features.
|
||||
"""
|
||||
|
||||
B, _, _, _ = input.shape
|
||||
|
||||
# B R 2 -> B R 1 2
|
||||
coords = coords.unsqueeze(2)
|
||||
|
||||
# B C R 1
|
||||
feats = bilinear_sampler(input, coords)
|
||||
|
||||
return feats.permute(0, 2, 1, 3).view(
|
||||
B, -1, feats.shape[1] * feats.shape[3]
|
||||
) # B C R 1 -> B R C
|
||||
|
||||
|
||||
def sample_features5d(input, coords):
|
||||
r"""Sample spatio-temporal features
|
||||
|
||||
`sample_features5d(input, coords)` works in the same way as
|
||||
:func:`sample_features4d` but for spatio-temporal features and points:
|
||||
:attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
|
||||
a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,
|
||||
x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.
|
||||
|
||||
Args:
|
||||
input (Tensor): spatio-temporal features.
|
||||
coords (Tensor): spatio-temporal points.
|
||||
|
||||
Returns:
|
||||
Tensor: sampled features.
|
||||
"""
|
||||
|
||||
B, T, _, _, _ = input.shape
|
||||
|
||||
# B T C H W -> B C T H W
|
||||
input = input.permute(0, 2, 1, 3, 4)
|
||||
|
||||
# B R1 R2 3 -> B R1 R2 1 3
|
||||
coords = coords.unsqueeze(3)
|
||||
|
||||
# B C R1 R2 1
|
||||
feats = bilinear_sampler(input, coords)
|
||||
|
||||
return feats.permute(0, 2, 3, 1, 4).view(
|
||||
B, feats.shape[2], feats.shape[3], feats.shape[1]
|
||||
) # B C R1 R2 1 -> B R1 R2 C
|
||||
|
@ -8,16 +8,17 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Tuple
|
||||
|
||||
from cotracker.models.core.cotracker.cotracker import CoTracker, get_points_on_a_grid
|
||||
from cotracker.models.core.cotracker.cotracker import CoTracker2
|
||||
from cotracker.models.core.model_utils import get_points_on_a_grid
|
||||
|
||||
|
||||
class EvaluationPredictor(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cotracker_model: CoTracker,
|
||||
cotracker_model: CoTracker2,
|
||||
interp_shape: Tuple[int, int] = (384, 512),
|
||||
grid_size: int = 6,
|
||||
local_grid_size: int = 6,
|
||||
grid_size: int = 5,
|
||||
local_grid_size: int = 8,
|
||||
single_point: bool = True,
|
||||
n_iters: int = 6,
|
||||
) -> None:
|
||||
@ -39,14 +40,14 @@ class EvaluationPredictor(torch.nn.Module):
|
||||
assert D == 3
|
||||
assert B == 1
|
||||
|
||||
rgbs = video.reshape(B * T, C, H, W)
|
||||
rgbs = F.interpolate(rgbs, tuple(self.interp_shape), mode="bilinear")
|
||||
rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
|
||||
video = video.reshape(B * T, C, H, W)
|
||||
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
|
||||
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
|
||||
|
||||
device = rgbs.device
|
||||
device = video.device
|
||||
|
||||
queries[:, :, 1] *= self.interp_shape[1] / W
|
||||
queries[:, :, 2] *= self.interp_shape[0] / H
|
||||
queries[:, :, 1] *= (self.interp_shape[1] - 1) / (W - 1)
|
||||
queries[:, :, 2] *= (self.interp_shape[0] - 1) / (H - 1)
|
||||
|
||||
if self.single_point:
|
||||
traj_e = torch.zeros((B, T, N, 2), device=device)
|
||||
@ -56,51 +57,49 @@ class EvaluationPredictor(torch.nn.Module):
|
||||
|
||||
t = query[0, 0, 0].long()
|
||||
|
||||
traj_e_pind, vis_e_pind = self._process_one_point(rgbs, query)
|
||||
traj_e_pind, vis_e_pind = self._process_one_point(video, query)
|
||||
traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1]
|
||||
vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
|
||||
else:
|
||||
if self.grid_size > 0:
|
||||
xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device)
|
||||
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(
|
||||
device
|
||||
) #
|
||||
xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
|
||||
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
|
||||
queries = torch.cat([queries, xy], dim=1) #
|
||||
|
||||
traj_e, __, vis_e, __ = self.model(
|
||||
rgbs=rgbs,
|
||||
traj_e, vis_e, __ = self.model(
|
||||
video=video,
|
||||
queries=queries,
|
||||
iters=self.n_iters,
|
||||
)
|
||||
|
||||
traj_e[:, :, :, 0] *= W / float(self.interp_shape[1])
|
||||
traj_e[:, :, :, 1] *= H / float(self.interp_shape[0])
|
||||
traj_e[:, :, :, 0] *= (W - 1) / float(self.interp_shape[1] - 1)
|
||||
traj_e[:, :, :, 1] *= (H - 1) / float(self.interp_shape[0] - 1)
|
||||
return traj_e, vis_e
|
||||
|
||||
def _process_one_point(self, rgbs, query):
|
||||
def _process_one_point(self, video, query):
|
||||
t = query[0, 0, 0].long()
|
||||
|
||||
device = rgbs.device
|
||||
device = query.device
|
||||
if self.local_grid_size > 0:
|
||||
xy_target = get_points_on_a_grid(
|
||||
self.local_grid_size,
|
||||
(50, 50),
|
||||
[query[0, 0, 2], query[0, 0, 1]],
|
||||
[query[0, 0, 2].item(), query[0, 0, 1].item()],
|
||||
)
|
||||
|
||||
xy_target = torch.cat(
|
||||
[torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2
|
||||
xy_target = torch.cat([torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2).to(
|
||||
device
|
||||
) #
|
||||
query = torch.cat([query, xy_target], dim=1).to(device) #
|
||||
query = torch.cat([query, xy_target], dim=1) #
|
||||
|
||||
if self.grid_size > 0:
|
||||
xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device)
|
||||
xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
|
||||
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
|
||||
query = torch.cat([query, xy], dim=1).to(device) #
|
||||
query = torch.cat([query, xy], dim=1) #
|
||||
# crop the video to start from the queried frame
|
||||
query[0, 0, 0] = 0
|
||||
traj_e_pind, __, vis_e_pind, __ = self.model(
|
||||
rgbs=rgbs[:, t:], queries=query, iters=self.n_iters
|
||||
traj_e_pind, vis_e_pind, __ = self.model(
|
||||
video=video[:, t:], queries=query, iters=self.n_iters
|
||||
)
|
||||
|
||||
return traj_e_pind, vis_e_pind
|
||||
|
@ -7,23 +7,16 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from tqdm import tqdm
|
||||
from cotracker.models.core.cotracker.cotracker import get_points_on_a_grid
|
||||
from cotracker.models.core.model_utils import smart_cat
|
||||
from cotracker.models.build_cotracker import (
|
||||
build_cotracker,
|
||||
)
|
||||
from cotracker.models.core.model_utils import smart_cat, get_points_on_a_grid
|
||||
from cotracker.models.build_cotracker import build_cotracker
|
||||
|
||||
|
||||
class CoTrackerPredictor(torch.nn.Module):
|
||||
def __init__(
|
||||
self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth"
|
||||
):
|
||||
def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
|
||||
super().__init__()
|
||||
self.interp_shape = (384, 512)
|
||||
self.support_grid_size = 6
|
||||
model = build_cotracker(checkpoint)
|
||||
|
||||
self.interp_shape = model.model_resolution
|
||||
self.model = model
|
||||
self.model.eval()
|
||||
|
||||
@ -43,7 +36,6 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
grid_query_frame: int = 0, # only for dense and regular grid tracks
|
||||
backward_tracking: bool = False,
|
||||
):
|
||||
|
||||
if queries is None and grid_size == 0:
|
||||
tracks, visibilities = self._compute_dense_tracks(
|
||||
video,
|
||||
@ -63,9 +55,7 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
|
||||
return tracks, visibilities
|
||||
|
||||
def _compute_dense_tracks(
|
||||
self, video, grid_query_frame, grid_size=30, backward_tracking=False
|
||||
):
|
||||
def _compute_dense_tracks(self, video, grid_query_frame, grid_size=30, backward_tracking=False):
|
||||
*_, H, W = video.shape
|
||||
grid_step = W // grid_size
|
||||
grid_width = W // grid_step
|
||||
@ -73,12 +63,11 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
tracks = visibilities = None
|
||||
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
|
||||
grid_pts[0, :, 0] = grid_query_frame
|
||||
for offset in tqdm(range(grid_step * grid_step)):
|
||||
for offset in range(grid_step * grid_step):
|
||||
print(f"step {offset} / {grid_step * grid_step}")
|
||||
ox = offset % grid_step
|
||||
oy = offset // grid_step
|
||||
grid_pts[0, :, 1] = (
|
||||
torch.arange(grid_width).repeat(grid_height) * grid_step + ox
|
||||
)
|
||||
grid_pts[0, :, 1] = torch.arange(grid_width).repeat(grid_height) * grid_step + ox
|
||||
grid_pts[0, :, 2] = (
|
||||
torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
|
||||
)
|
||||
@ -106,21 +95,23 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
assert B == 1
|
||||
|
||||
video = video.reshape(B * T, C, H, W)
|
||||
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear")
|
||||
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
|
||||
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
|
||||
|
||||
if queries is not None:
|
||||
queries = queries.clone()
|
||||
B, N, D = queries.shape
|
||||
assert D == 3
|
||||
queries[:, :, 1] *= self.interp_shape[1] / W
|
||||
queries[:, :, 2] *= self.interp_shape[0] / H
|
||||
queries = queries.clone()
|
||||
queries[:, :, 1:] *= queries.new_tensor(
|
||||
[
|
||||
(self.interp_shape[1] - 1) / (W - 1),
|
||||
(self.interp_shape[0] - 1) / (H - 1),
|
||||
]
|
||||
)
|
||||
elif grid_size > 0:
|
||||
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
|
||||
if segm_mask is not None:
|
||||
segm_mask = F.interpolate(
|
||||
segm_mask, tuple(self.interp_shape), mode="nearest"
|
||||
)
|
||||
segm_mask = F.interpolate(segm_mask, tuple(self.interp_shape), mode="nearest")
|
||||
point_mask = segm_mask[0, 0][
|
||||
(grid_pts[0, :, 1]).round().long().cpu(),
|
||||
(grid_pts[0, :, 0]).round().long().cpu(),
|
||||
@ -133,23 +124,23 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
)
|
||||
|
||||
if add_support_grid:
|
||||
grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape, device=video.device)
|
||||
grid_pts = torch.cat(
|
||||
[torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
|
||||
grid_pts = get_points_on_a_grid(
|
||||
self.support_grid_size, self.interp_shape, device=video.device
|
||||
)
|
||||
grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
|
||||
queries = torch.cat([queries, grid_pts], dim=1)
|
||||
|
||||
tracks, __, visibilities, __ = self.model(rgbs=video, queries=queries, iters=6)
|
||||
tracks, visibilities, __ = self.model.forward(video=video, queries=queries, iters=6)
|
||||
|
||||
if backward_tracking:
|
||||
tracks, visibilities = self._compute_backward_tracks(
|
||||
video, queries, tracks, visibilities
|
||||
)
|
||||
if add_support_grid:
|
||||
queries[:, -self.support_grid_size ** 2 :, 0] = T - 1
|
||||
queries[:, -self.support_grid_size**2 :, 0] = T - 1
|
||||
if add_support_grid:
|
||||
tracks = tracks[:, :, : -self.support_grid_size ** 2]
|
||||
visibilities = visibilities[:, :, : -self.support_grid_size ** 2]
|
||||
tracks = tracks[:, :, : -self.support_grid_size**2]
|
||||
visibilities = visibilities[:, :, : -self.support_grid_size**2]
|
||||
thr = 0.9
|
||||
visibilities = visibilities > thr
|
||||
|
||||
@ -158,17 +149,18 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
|
||||
# TODO: batchify
|
||||
for i in range(len(queries)):
|
||||
queries_t = queries[i, :tracks.size(2), 0].to(torch.int64)
|
||||
queries_t = queries[i, : tracks.size(2), 0].to(torch.int64)
|
||||
arange = torch.arange(0, len(queries_t))
|
||||
|
||||
# overwrite the predictions with the query points
|
||||
tracks[i, queries_t, arange] = queries[i, :tracks.size(2), 1:]
|
||||
tracks[i, queries_t, arange] = queries[i, : tracks.size(2), 1:]
|
||||
|
||||
# correct visibilities, the query points should be visible
|
||||
visibilities[i, queries_t, arange] = True
|
||||
|
||||
tracks[:, :, :, 0] *= W / float(self.interp_shape[1])
|
||||
tracks[:, :, :, 1] *= H / float(self.interp_shape[0])
|
||||
tracks *= tracks.new_tensor(
|
||||
[(W - 1) / (self.interp_shape[1] - 1), (H - 1) / (self.interp_shape[0] - 1)]
|
||||
)
|
||||
return tracks, visibilities
|
||||
|
||||
def _compute_backward_tracks(self, video, queries, tracks, visibilities):
|
||||
@ -176,9 +168,7 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
inv_queries = queries.clone()
|
||||
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
|
||||
|
||||
inv_tracks, __, inv_visibilities, __ = self.model(
|
||||
rgbs=inv_video, queries=inv_queries, iters=6
|
||||
)
|
||||
inv_tracks, inv_visibilities, __ = self.model(video=inv_video, queries=inv_queries, iters=6)
|
||||
|
||||
inv_tracks = inv_tracks.flip(1)
|
||||
inv_visibilities = inv_visibilities.flip(1)
|
||||
@ -188,3 +178,79 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
tracks[mask] = inv_tracks[mask]
|
||||
visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
|
||||
return tracks, visibilities
|
||||
|
||||
|
||||
class CoTrackerOnlinePredictor(torch.nn.Module):
|
||||
def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
|
||||
super().__init__()
|
||||
self.support_grid_size = 6
|
||||
model = build_cotracker(checkpoint)
|
||||
self.interp_shape = model.model_resolution
|
||||
self.step = model.window_len // 2
|
||||
self.model = model
|
||||
self.model.eval()
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
video_chunk,
|
||||
is_first_step: bool = False,
|
||||
queries: torch.Tensor = None,
|
||||
grid_size: int = 10,
|
||||
grid_query_frame: int = 0,
|
||||
add_support_grid=False,
|
||||
):
|
||||
# Initialize online video processing and save queried points
|
||||
# This needs to be done before processing *each new video*
|
||||
if is_first_step:
|
||||
self.model.init_video_online_processing()
|
||||
if queries is not None:
|
||||
B, N, D = queries.shape
|
||||
assert D == 3
|
||||
queries = queries.clone()
|
||||
queries[:, :, 1:] *= queries.new_tensor(
|
||||
[
|
||||
(self.interp_shape[1] - 1) / (W - 1),
|
||||
(self.interp_shape[0] - 1) / (H - 1),
|
||||
]
|
||||
)
|
||||
elif grid_size > 0:
|
||||
grid_pts = get_points_on_a_grid(
|
||||
grid_size, self.interp_shape, device=video_chunk.device
|
||||
)
|
||||
queries = torch.cat(
|
||||
[torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
|
||||
dim=2,
|
||||
)
|
||||
if add_support_grid:
|
||||
grid_pts = get_points_on_a_grid(
|
||||
self.support_grid_size, self.interp_shape, device=video_chunk.device
|
||||
)
|
||||
grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
|
||||
queries = torch.cat([queries, grid_pts], dim=1)
|
||||
self.queries = queries
|
||||
return (None, None)
|
||||
B, T, C, H, W = video_chunk.shape
|
||||
video_chunk = video_chunk.reshape(B * T, C, H, W)
|
||||
video_chunk = F.interpolate(
|
||||
video_chunk, tuple(self.interp_shape), mode="bilinear", align_corners=True
|
||||
)
|
||||
video_chunk = video_chunk.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
|
||||
|
||||
tracks, visibilities, __ = self.model(
|
||||
video=video_chunk,
|
||||
queries=self.queries,
|
||||
iters=6,
|
||||
is_online=True,
|
||||
)
|
||||
thr = 0.9
|
||||
return (
|
||||
tracks
|
||||
* tracks.new_tensor(
|
||||
[
|
||||
(W - 1) / (self.interp_shape[1] - 1),
|
||||
(H - 1) / (self.interp_shape[0] - 1),
|
||||
]
|
||||
),
|
||||
visibilities > thr,
|
||||
)
|
||||
|
@ -3,36 +3,59 @@
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
import imageio
|
||||
import torch
|
||||
import flow_vis
|
||||
|
||||
from matplotlib import cm
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms as transforms
|
||||
from moviepy.editor import ImageSequenceClip
|
||||
import matplotlib.pyplot as plt
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
|
||||
def read_video_from_path(path):
|
||||
cap = cv2.VideoCapture(path)
|
||||
if not cap.isOpened():
|
||||
print("Error opening video file")
|
||||
else:
|
||||
frames = []
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if ret == True:
|
||||
frames.append(np.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
|
||||
else:
|
||||
break
|
||||
cap.release()
|
||||
try:
|
||||
reader = imageio.get_reader(path)
|
||||
except Exception as e:
|
||||
print("Error opening video file: ", e)
|
||||
return None
|
||||
frames = []
|
||||
for i, im in enumerate(reader):
|
||||
frames.append(np.array(im))
|
||||
return np.stack(frames)
|
||||
|
||||
|
||||
def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True):
|
||||
# Create a draw object
|
||||
draw = ImageDraw.Draw(rgb)
|
||||
# Calculate the bounding box of the circle
|
||||
left_up_point = (coord[0] - radius, coord[1] - radius)
|
||||
right_down_point = (coord[0] + radius, coord[1] + radius)
|
||||
# Draw the circle
|
||||
draw.ellipse(
|
||||
[left_up_point, right_down_point],
|
||||
fill=tuple(color) if visible else None,
|
||||
outline=tuple(color),
|
||||
)
|
||||
return rgb
|
||||
|
||||
|
||||
def draw_line(rgb, coord_y, coord_x, color, linewidth):
|
||||
draw = ImageDraw.Draw(rgb)
|
||||
draw.line(
|
||||
(coord_y[0], coord_y[1], coord_x[0], coord_x[1]),
|
||||
fill=tuple(color),
|
||||
width=linewidth,
|
||||
)
|
||||
return rgb
|
||||
|
||||
|
||||
def add_weighted(rgb, alpha, original, beta, gamma):
|
||||
return (rgb * alpha + original * beta + gamma).astype("uint8")
|
||||
|
||||
|
||||
class Visualizer:
|
||||
def __init__(
|
||||
self,
|
||||
@ -107,7 +130,7 @@ class Visualizer:
|
||||
def save_video(self, video, filename, writer=None, step=0):
|
||||
if writer is not None:
|
||||
writer.add_video(
|
||||
f"{filename}_pred_track",
|
||||
filename,
|
||||
video.to(torch.uint8),
|
||||
global_step=step,
|
||||
fps=self.fps,
|
||||
@ -116,11 +139,18 @@ class Visualizer:
|
||||
os.makedirs(self.save_dir, exist_ok=True)
|
||||
wide_list = list(video.unbind(1))
|
||||
wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
|
||||
clip = ImageSequenceClip(wide_list[2:-1], fps=self.fps)
|
||||
|
||||
# Write the video file
|
||||
save_path = os.path.join(self.save_dir, f"{filename}_pred_track.mp4")
|
||||
clip.write_videofile(save_path, codec="libx264", fps=self.fps, logger=None)
|
||||
# Prepare the video file path
|
||||
save_path = os.path.join(self.save_dir, f"{filename}.mp4")
|
||||
|
||||
# Create a writer object
|
||||
video_writer = imageio.get_writer(save_path, fps=self.fps)
|
||||
|
||||
# Write frames to the video file
|
||||
for frame in wide_list[2:-1]:
|
||||
video_writer.append_data(frame)
|
||||
|
||||
video_writer.close()
|
||||
|
||||
print(f"Video saved to {save_path}")
|
||||
|
||||
@ -149,9 +179,11 @@ class Visualizer:
|
||||
# process input video
|
||||
for rgb in video:
|
||||
res_video.append(rgb.copy())
|
||||
|
||||
vector_colors = np.zeros((T, N, 3))
|
||||
|
||||
if self.mode == "optical_flow":
|
||||
import flow_vis
|
||||
|
||||
vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
|
||||
elif segm_mask is None:
|
||||
if self.mode == "rainbow":
|
||||
@ -196,9 +228,7 @@ class Visualizer:
|
||||
if self.tracks_leave_trace != 0:
|
||||
for t in range(1, T):
|
||||
first_ind = (
|
||||
max(0, t - self.tracks_leave_trace)
|
||||
if self.tracks_leave_trace >= 0
|
||||
else 0
|
||||
max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0
|
||||
)
|
||||
curr_tracks = tracks[first_ind : t + 1]
|
||||
curr_colors = vector_colors[first_ind : t + 1]
|
||||
@ -218,12 +248,11 @@ class Visualizer:
|
||||
curr_colors,
|
||||
)
|
||||
if gt_tracks is not None:
|
||||
res_video[t] = self._draw_gt_tracks(
|
||||
res_video[t], gt_tracks[first_ind : t + 1]
|
||||
)
|
||||
res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1])
|
||||
|
||||
# draw points
|
||||
for t in range(T):
|
||||
img = Image.fromarray(np.uint8(res_video[t]))
|
||||
for i in range(N):
|
||||
coord = (tracks[t, i, 0], tracks[t, i, 1])
|
||||
visibile = True
|
||||
@ -233,15 +262,14 @@ class Visualizer:
|
||||
if not compensate_for_camera_motion or (
|
||||
compensate_for_camera_motion and segm_mask[i] > 0
|
||||
):
|
||||
|
||||
cv2.circle(
|
||||
res_video[t],
|
||||
coord,
|
||||
int(self.linewidth * 2),
|
||||
vector_colors[t, i].tolist(),
|
||||
thickness=-1 if visibile else 2
|
||||
-1,
|
||||
img = draw_circle(
|
||||
img,
|
||||
coord=coord,
|
||||
radius=int(self.linewidth * 2),
|
||||
color=vector_colors[t, i].astype(int),
|
||||
visible=visibile,
|
||||
)
|
||||
res_video[t] = np.array(img)
|
||||
|
||||
# construct the final rgb sequence
|
||||
if self.show_first_frame > 0:
|
||||
@ -256,7 +284,7 @@ class Visualizer:
|
||||
alpha: float = 0.5,
|
||||
):
|
||||
T, N, _ = tracks.shape
|
||||
|
||||
rgb = Image.fromarray(np.uint8(rgb))
|
||||
for s in range(T - 1):
|
||||
vector_color = vector_colors[s]
|
||||
original = rgb.copy()
|
||||
@ -265,16 +293,18 @@ class Visualizer:
|
||||
coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
|
||||
coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
|
||||
if coord_y[0] != 0 and coord_y[1] != 0:
|
||||
cv2.line(
|
||||
rgb = draw_line(
|
||||
rgb,
|
||||
coord_y,
|
||||
coord_x,
|
||||
vector_color[i].tolist(),
|
||||
vector_color[i].astype(int),
|
||||
self.linewidth,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
if self.tracks_leave_trace > 0:
|
||||
rgb = cv2.addWeighted(rgb, alpha, original, 1 - alpha, 0)
|
||||
rgb = Image.fromarray(
|
||||
np.uint8(add_weighted(np.array(rgb), alpha, np.array(original), 1 - alpha, 0))
|
||||
)
|
||||
rgb = np.array(rgb)
|
||||
return rgb
|
||||
|
||||
def _draw_gt_tracks(
|
||||
@ -283,8 +313,8 @@ class Visualizer:
|
||||
gt_tracks: np.ndarray, # T x 2
|
||||
):
|
||||
T, N, _ = gt_tracks.shape
|
||||
color = np.array((211.0, 0.0, 0.0))
|
||||
|
||||
color = np.array((211, 0, 0))
|
||||
rgb = Image.fromarray(np.uint8(rgb))
|
||||
for t in range(T):
|
||||
for i in range(N):
|
||||
gt_tracks = gt_tracks[t][i]
|
||||
@ -293,22 +323,21 @@ class Visualizer:
|
||||
length = self.linewidth * 3
|
||||
coord_y = (int(gt_tracks[0]) + length, int(gt_tracks[1]) + length)
|
||||
coord_x = (int(gt_tracks[0]) - length, int(gt_tracks[1]) - length)
|
||||
cv2.line(
|
||||
rgb = draw_line(
|
||||
rgb,
|
||||
coord_y,
|
||||
coord_x,
|
||||
color,
|
||||
self.linewidth,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
coord_y = (int(gt_tracks[0]) - length, int(gt_tracks[1]) + length)
|
||||
coord_x = (int(gt_tracks[0]) + length, int(gt_tracks[1]) - length)
|
||||
cv2.line(
|
||||
rgb = draw_line(
|
||||
rgb,
|
||||
coord_y,
|
||||
coord_x,
|
||||
color,
|
||||
self.linewidth,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
rgb = np.array(rgb)
|
||||
return rgb
|
||||
|
8
cotracker/version.py
Normal file
8
cotracker/version.py
Normal file
@ -0,0 +1,8 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
__version__ = "2.0.0"
|
30
demo.py
30
demo.py
@ -5,7 +5,6 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import torch
|
||||
import argparse
|
||||
import numpy as np
|
||||
@ -14,9 +13,18 @@ from PIL import Image
|
||||
from cotracker.utils.visualizer import Visualizer, read_video_from_path
|
||||
from cotracker.predictor import CoTrackerPredictor
|
||||
|
||||
DEFAULT_DEVICE = ('cuda' if torch.cuda.is_available() else
|
||||
'mps' if torch.backends.mps.is_available() else
|
||||
'cpu')
|
||||
# Unfortunately MPS acceleration does not support all the features we require,
|
||||
# but we may be able to enable it in the future
|
||||
|
||||
DEFAULT_DEVICE = (
|
||||
# "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
|
||||
# if DEFAULT_DEVICE == "mps":
|
||||
# os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
@ -32,15 +40,16 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
default="./checkpoints/cotracker_stride_4_wind_8.pth",
|
||||
help="cotracker model",
|
||||
# default="./checkpoints/cotracker.pth",
|
||||
default=None,
|
||||
help="CoTracker model parameters",
|
||||
)
|
||||
parser.add_argument("--grid_size", type=int, default=0, help="Regular grid size")
|
||||
parser.add_argument("--grid_size", type=int, default=10, help="Regular grid size")
|
||||
parser.add_argument(
|
||||
"--grid_query_frame",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Compute dense and grid tracks starting from this frame ",
|
||||
help="Compute dense and grid tracks starting from this frame",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -57,7 +66,10 @@ if __name__ == "__main__":
|
||||
segm_mask = np.array(Image.open(os.path.join(args.mask_path)))
|
||||
segm_mask = torch.from_numpy(segm_mask)[None, None]
|
||||
|
||||
model = CoTrackerPredictor(checkpoint=args.checkpoint)
|
||||
if args.checkpoint is not None:
|
||||
model = CoTrackerPredictor(checkpoint=args.checkpoint)
|
||||
else:
|
||||
model = torch.hub.load("facebookresearch/co-tracker", "cotracker2")
|
||||
model = model.to(DEFAULT_DEVICE)
|
||||
video = video.to(DEFAULT_DEVICE)
|
||||
|
||||
|
13
docs/Makefile
Normal file
13
docs/Makefile
Normal file
@ -0,0 +1,13 @@
|
||||
SPHINXOPTS ?=
|
||||
SPHINXBUILD ?= sphinx-build
|
||||
SOURCEDIR = source
|
||||
BUILDDIR = _build
|
||||
O = -a
|
||||
|
||||
help:
|
||||
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
|
||||
.PHONY: help Makefile
|
||||
|
||||
%: Makefile
|
||||
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
14
docs/source/apis/models.rst
Normal file
14
docs/source/apis/models.rst
Normal file
@ -0,0 +1,14 @@
|
||||
Models
|
||||
======
|
||||
|
||||
CoTracker models:
|
||||
|
||||
.. currentmodule:: cotracker.models
|
||||
|
||||
Model Utils
|
||||
-----------
|
||||
|
||||
.. automodule:: cotracker.models.core.model_utils
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
11
docs/source/apis/utils.rst
Normal file
11
docs/source/apis/utils.rst
Normal file
@ -0,0 +1,11 @@
|
||||
Utils
|
||||
=====
|
||||
|
||||
CoTracker utilizes the following utilities:
|
||||
|
||||
.. currentmodule:: cotracker
|
||||
|
||||
.. automodule:: cotracker.utils.visualizer
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
39
docs/source/conf.py
Normal file
39
docs/source/conf.py
Normal file
@ -0,0 +1,39 @@
|
||||
__version__ = None
|
||||
exec(open("../../cotracker/version.py", "r").read())
|
||||
|
||||
project = "CoTracker"
|
||||
copyright = "2023-24, Meta Platforms, Inc. and affiliates"
|
||||
author = "Meta Platforms"
|
||||
release = __version__
|
||||
|
||||
extensions = [
|
||||
"sphinx.ext.napoleon",
|
||||
"sphinx.ext.duration",
|
||||
"sphinx.ext.doctest",
|
||||
"sphinx.ext.autodoc",
|
||||
"sphinx.ext.autosummary",
|
||||
"sphinx.ext.intersphinx",
|
||||
"sphinxcontrib.bibtex",
|
||||
]
|
||||
|
||||
intersphinx_mapping = {
|
||||
"python": ("https://docs.python.org/3/", None),
|
||||
"sphinx": ("https://www.sphinx-doc.org/en/master/", None),
|
||||
}
|
||||
intersphinx_disabled_domains = ["std"]
|
||||
|
||||
# templates_path = ["_templates"]
|
||||
html_theme = "alabaster"
|
||||
|
||||
# Ignore >>> when copying code
|
||||
copybutton_prompt_text = r">>> |\.\.\. "
|
||||
copybutton_prompt_is_regexp = True
|
||||
|
||||
# -- Options for EPUB output
|
||||
epub_show_urls = "footnote"
|
||||
|
||||
# typehints
|
||||
autodoc_typehints = "description"
|
||||
|
||||
# citations
|
||||
bibtex_bibfiles = ["references.bib"]
|
29
docs/source/index.rst
Normal file
29
docs/source/index.rst
Normal file
@ -0,0 +1,29 @@
|
||||
gsplat
|
||||
===================================
|
||||
|
||||
.. image:: ../../assets/bmx-bumps.gif
|
||||
:width: 800
|
||||
:alt: Example of cotracker in action
|
||||
|
||||
Overview
|
||||
--------
|
||||
|
||||
*CoTracker* is an open-source tracker :cite:p:`karaev2023cotracker`.
|
||||
|
||||
Links
|
||||
-----
|
||||
|
||||
.. toctree::
|
||||
:glob:
|
||||
:maxdepth: 1
|
||||
:caption: Python API
|
||||
|
||||
apis/*
|
||||
|
||||
|
||||
Citations
|
||||
---------
|
||||
|
||||
.. bibliography::
|
||||
:style: unsrt
|
||||
:filter: docname in docnames
|
6
docs/source/references.bib
Normal file
6
docs/source/references.bib
Normal file
@ -0,0 +1,6 @@
|
||||
@article{karaev2023cotracker,
|
||||
title = {CoTracker: It is Better to Track Together},
|
||||
author = {Nikita Karaev and Ignacio Rocco and Benjamin Graham and Natalia Neverova and Andrea Vedaldi and Christian Rupprecht},
|
||||
journal = {arXiv:2307.07635},
|
||||
year = {2023}
|
||||
}
|
@ -1,43 +1,39 @@
|
||||
import os
|
||||
import torch
|
||||
import timm
|
||||
import einops
|
||||
import tqdm
|
||||
import cv2
|
||||
import gradio as gr
|
||||
|
||||
from cotracker.utils.visualizer import Visualizer, read_video_from_path
|
||||
|
||||
|
||||
def cotracker_demo(
|
||||
input_video,
|
||||
grid_size: int = 10,
|
||||
grid_query_frame: int = 0,
|
||||
backward_tracking: bool = False,
|
||||
tracks_leave_trace: bool = False
|
||||
):
|
||||
input_video,
|
||||
grid_size: int = 10,
|
||||
grid_query_frame: int = 0,
|
||||
tracks_leave_trace: bool = False,
|
||||
):
|
||||
load_video = read_video_from_path(input_video)
|
||||
|
||||
grid_query_frame = min(len(load_video)-1, grid_query_frame)
|
||||
grid_query_frame = min(len(load_video) - 1, grid_query_frame)
|
||||
load_video = torch.from_numpy(load_video).permute(0, 3, 1, 2)[None].float()
|
||||
|
||||
model = torch.hub.load("facebookresearch/co-tracker", "cotracker2_online")
|
||||
|
||||
model = torch.hub.load("facebookresearch/co-tracker", "cotracker_w8")
|
||||
if torch.cuda.is_available():
|
||||
model = model.cuda()
|
||||
load_video = load_video.cuda()
|
||||
pred_tracks, pred_visibility = model(
|
||||
load_video,
|
||||
grid_size=grid_size,
|
||||
grid_query_frame=grid_query_frame,
|
||||
backward_tracking=backward_tracking
|
||||
)
|
||||
|
||||
model(video_chunk=load_video, is_first_step=True, grid_size=grid_size)
|
||||
for ind in range(0, load_video.shape[1] - model.step, model.step):
|
||||
pred_tracks, pred_visibility = model(
|
||||
video_chunk=load_video[:, ind : ind + model.step * 2]
|
||||
) # B T N 2, B T N 1
|
||||
|
||||
linewidth = 2
|
||||
if grid_size < 10:
|
||||
linewidth = 4
|
||||
elif grid_size < 20:
|
||||
linewidth = 3
|
||||
|
||||
|
||||
vis = Visualizer(
|
||||
save_dir=os.path.join(os.path.dirname(__file__), "results"),
|
||||
grayscale=False,
|
||||
@ -45,7 +41,7 @@ def cotracker_demo(
|
||||
fps=10,
|
||||
linewidth=linewidth,
|
||||
show_first_frame=5,
|
||||
tracks_leave_trace= -1 if tracks_leave_trace else 0,
|
||||
tracks_leave_trace=-1 if tracks_leave_trace else 0,
|
||||
)
|
||||
import time
|
||||
|
||||
@ -55,44 +51,39 @@ def cotracker_demo(
|
||||
filename = str(current_milli_time())
|
||||
vis.visualize(
|
||||
load_video,
|
||||
tracks=pred_tracks,
|
||||
tracks=pred_tracks,
|
||||
visibility=pred_visibility,
|
||||
filename=filename,
|
||||
query_frame=grid_query_frame,
|
||||
)
|
||||
return os.path.join(
|
||||
os.path.dirname(__file__), "results", f"{filename}_pred_track.mp4"
|
||||
)
|
||||
return os.path.join(os.path.dirname(__file__), "results", f"{filename}_pred_track.mp4")
|
||||
|
||||
|
||||
app = gr.Interface(
|
||||
title = "🎨 CoTracker: It is Better to Track Together",
|
||||
description = "<div style='text-align: left;'> \
|
||||
title="🎨 CoTracker: It is Better to Track Together",
|
||||
description="<div style='text-align: left;'> \
|
||||
<p>Welcome to <a href='http://co-tracker.github.io' target='_blank'>CoTracker</a>! This space demonstrates point (pixel) tracking in videos. \
|
||||
Points are sampled on a regular grid and are tracked jointly. </p> \
|
||||
<p> To get started, simply upload your <b>.mp4</b> video in landscape orientation or click on one of the example videos to load them. The shorter the video, the faster the processing. We recommend submitting short videos of length <b>2-7 seconds</b>.</p> \
|
||||
<ul style='display: inline-block; text-align: left;'> \
|
||||
<li>The total number of grid points is the square of <b>Grid Size</b>.</li> \
|
||||
<li>To specify the starting frame for tracking, adjust <b>Grid Query Frame</b>. Tracks will be visualized only after the selected frame.</li> \
|
||||
<li>Use <b>Backward Tracking</b> to track points from the selected frame in both directions.</li> \
|
||||
<li>Check <b>Visualize Track Traces</b> to visualize traces of all the tracked points. </li> \
|
||||
</ul> \
|
||||
<p style='text-align: left'>For more details, check out our <a href='https://github.com/facebookresearch/co-tracker' target='_blank'>GitHub Repo</a> ⭐</p> \
|
||||
</div>",
|
||||
|
||||
fn=cotracker_demo,
|
||||
inputs=[
|
||||
gr.Video(label="Input video", interactive=True),
|
||||
gr.Slider(minimum=1, maximum=30, step=1, value=10, label="Grid Size"),
|
||||
gr.Slider(minimum=0, maximum=30, step=1, value=0, label="Grid Query Frame"),
|
||||
gr.Checkbox(label="Backward Tracking"),
|
||||
gr.Checkbox(label="Visualize Track Traces"),
|
||||
],
|
||||
outputs=gr.Video(label="Video with predicted tracks"),
|
||||
examples=[
|
||||
[ "./assets/apple.mp4", 20, 0, False, False ],
|
||||
[ "./assets/apple.mp4", 10, 30, True, False ],
|
||||
["./assets/apple.mp4", 20, 0, False, False],
|
||||
["./assets/apple.mp4", 10, 30, True, False],
|
||||
],
|
||||
cache_examples=False
|
||||
cache_examples=False,
|
||||
)
|
||||
app.launch(share=False)
|
||||
|
@ -1,7 +1,3 @@
|
||||
einops
|
||||
timm
|
||||
tqdm
|
||||
opencv-python
|
||||
matplotlib
|
||||
moviepy
|
||||
flow_vis
|
||||
|
34
hubconf.py
34
hubconf.py
@ -6,27 +6,33 @@
|
||||
|
||||
import torch
|
||||
|
||||
dependencies = ["torch", "einops", "timm", "tqdm"]
|
||||
|
||||
_COTRACKER_URL = (
|
||||
"https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth"
|
||||
)
|
||||
_COTRACKER_URL = "https://huggingface.co/facebook/cotracker/resolve/main/cotracker2.pth"
|
||||
|
||||
|
||||
def _make_cotracker_predictor(*, pretrained: bool = True, **kwargs):
|
||||
from cotracker.predictor import CoTrackerPredictor
|
||||
def _make_cotracker_predictor(*, pretrained: bool = True, online=False, **kwargs):
|
||||
if online:
|
||||
from cotracker.predictor import CoTrackerOnlinePredictor
|
||||
|
||||
predictor = CoTrackerPredictor(checkpoint=None)
|
||||
predictor = CoTrackerOnlinePredictor(checkpoint=None)
|
||||
else:
|
||||
from cotracker.predictor import CoTrackerPredictor
|
||||
|
||||
predictor = CoTrackerPredictor(checkpoint=None)
|
||||
if pretrained:
|
||||
state_dict = torch.hub.load_state_dict_from_url(
|
||||
_COTRACKER_URL, map_location="cpu"
|
||||
)
|
||||
state_dict = torch.hub.load_state_dict_from_url(_COTRACKER_URL, map_location="cpu")
|
||||
predictor.model.load_state_dict(state_dict)
|
||||
return predictor
|
||||
|
||||
|
||||
def cotracker_w8(*, pretrained: bool = True, **kwargs):
|
||||
def cotracker2(*, pretrained: bool = True, **kwargs):
|
||||
"""
|
||||
CoTracker model with stride 4 and window length 8. (The main model from the paper)
|
||||
CoTracker2 with stride 4 and window length 8. Can track up to 265*265 points jointly.
|
||||
"""
|
||||
return _make_cotracker_predictor(pretrained=pretrained, **kwargs)
|
||||
return _make_cotracker_predictor(pretrained=pretrained, online=False, **kwargs)
|
||||
|
||||
|
||||
def cotracker2_online(*, pretrained: bool = True, **kwargs):
|
||||
"""
|
||||
Online CoTracker2 with stride 4 and window length 8. Can track up to 265*265 points jointly.
|
||||
"""
|
||||
return _make_cotracker_predictor(pretrained=pretrained, online=True, **kwargs)
|
||||
|
24
launch_training.sh
Normal file
24
launch_training.sh
Normal file
@ -0,0 +1,24 @@
|
||||
#!/bin/bash
|
||||
|
||||
EXP_DIR=$1
|
||||
EXP_NAME=$2
|
||||
DATE=$3
|
||||
DATASET_ROOT=$4
|
||||
NUM_STEPS=$5
|
||||
|
||||
|
||||
echo `which python`
|
||||
|
||||
mkdir -p ${EXP_DIR}/${DATE}_${EXP_NAME}/logs/;
|
||||
|
||||
export PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH
|
||||
sbatch --comment=${EXP_NAME} --partition=learn --time=39:00:00 --gpus-per-node=8 --nodes=4 --ntasks-per-node=8 \
|
||||
--job-name=${EXP_NAME} --cpus-per-task=10 --signal=USR1@60 --open-mode=append \
|
||||
--output=${EXP_DIR}/${DATE}_${EXP_NAME}/logs/%j_%x_%A_%a_%N.out \
|
||||
--error=${EXP_DIR}/${DATE}_${EXP_NAME}/logs/%j_%x_%A_%a_%N.err \
|
||||
--wrap="srun --label python ./train.py --batch_size 1 \
|
||||
--num_steps ${NUM_STEPS} --ckpt_path ${EXP_DIR}/${DATE}_${EXP_NAME} --model_name cotracker \
|
||||
--save_freq 200 --sequence_len 24 --eval_datasets dynamic_replica tapvid_davis_first \
|
||||
--traj_per_sample 768 --sliding_window_len 8 \
|
||||
--save_every_n_epoch 10 --evaluate_every_n_epoch 10 --model_stride 4 --dataset_root ${DATASET_ROOT} --num_nodes 4 \
|
||||
--num_virtual_tracks 64"
|
90
online_demo.py
Normal file
90
online_demo.py
Normal file
@ -0,0 +1,90 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import argparse
|
||||
import imageio.v3 as iio
|
||||
import numpy as np
|
||||
|
||||
from cotracker.utils.visualizer import Visualizer
|
||||
from cotracker.predictor import CoTrackerOnlinePredictor
|
||||
|
||||
# Unfortunately MPS acceleration does not support all the features we require,
|
||||
# but we may be able to enable it in the future
|
||||
|
||||
DEFAULT_DEVICE = (
|
||||
# "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--video_path",
|
||||
default="./assets/apple.mp4",
|
||||
help="path to a video",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
default=None,
|
||||
help="CoTracker model parameters",
|
||||
)
|
||||
parser.add_argument("--grid_size", type=int, default=10, help="Regular grid size")
|
||||
parser.add_argument(
|
||||
"--grid_query_frame",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Compute dense and grid tracks starting from this frame",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.checkpoint is not None:
|
||||
model = CoTrackerOnlinePredictor(checkpoint=args.checkpoint)
|
||||
else:
|
||||
model = torch.hub.load("facebookresearch/co-tracker", "cotracker2_online")
|
||||
model = model.to(DEFAULT_DEVICE)
|
||||
|
||||
window_frames = []
|
||||
|
||||
def _process_step(window_frames, is_first_step, grid_size):
|
||||
video_chunk = (
|
||||
torch.tensor(np.stack(window_frames[-model.step * 2 :]), device=DEFAULT_DEVICE)
|
||||
.float()
|
||||
.permute(0, 3, 1, 2)[None]
|
||||
) # (1, T, 3, H, W)
|
||||
return model(video_chunk, is_first_step=is_first_step, grid_size=grid_size)
|
||||
|
||||
# Iterating over video frames, processing one window at a time:
|
||||
is_first_step = True
|
||||
for i, frame in enumerate(
|
||||
iio.imiter(
|
||||
"https://github.com/facebookresearch/co-tracker/blob/main/assets/apple.mp4",
|
||||
plugin="FFMPEG",
|
||||
)
|
||||
):
|
||||
if i % model.step == 0 and i != 0:
|
||||
pred_tracks, pred_visibility = _process_step(
|
||||
window_frames, is_first_step, grid_size=args.grid_size
|
||||
)
|
||||
is_first_step = False
|
||||
window_frames.append(frame)
|
||||
# Processing the final video frames in case video length is not a multiple of model.step
|
||||
pred_tracks, pred_visibility = _process_step(
|
||||
window_frames[-(i % model.step) - model.step - 1 :],
|
||||
is_first_step,
|
||||
grid_size=args.grid_size,
|
||||
)
|
||||
|
||||
print("Tracks are computed")
|
||||
|
||||
# save a video with predicted tracks
|
||||
seq_name = args.video_path.split("/")[-1]
|
||||
video = torch.tensor(np.stack(window_frames), device=DEFAULT_DEVICE).permute(0, 3, 1, 2)[None]
|
||||
vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
|
||||
vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame)
|
4
setup.py
4
setup.py
@ -8,11 +8,11 @@ from setuptools import find_packages, setup
|
||||
|
||||
setup(
|
||||
name="cotracker",
|
||||
version="1.0",
|
||||
version="2.0",
|
||||
install_requires=[],
|
||||
packages=find_packages(exclude="notebooks"),
|
||||
extras_require={
|
||||
"all": ["matplotlib", "opencv-python"],
|
||||
"all": ["matplotlib"],
|
||||
"dev": ["flake8", "black"],
|
||||
},
|
||||
)
|
||||
|
51
tests/test_bilinear_sample.py
Normal file
51
tests/test_bilinear_sample.py
Normal file
@ -0,0 +1,51 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import torch
|
||||
import unittest
|
||||
|
||||
from cotracker.models.core.model_utils import bilinear_sampler
|
||||
|
||||
|
||||
class TestBilinearSampler(unittest.TestCase):
|
||||
# Sample from an image (4d)
|
||||
def _test4d(self, align_corners):
|
||||
H, W = 4, 5
|
||||
# Construct a grid to obtain indentity sampling
|
||||
input = torch.randn(H * W).view(1, 1, H, W).float()
|
||||
coords = torch.meshgrid(torch.arange(H), torch.arange(W))
|
||||
coords = torch.stack(coords[::-1], dim=-1).float()[None]
|
||||
if not align_corners:
|
||||
coords = coords + 0.5
|
||||
sampled_input = bilinear_sampler(input, coords, align_corners=align_corners)
|
||||
torch.testing.assert_close(input, sampled_input)
|
||||
|
||||
# Sample from a video (5d)
|
||||
def _test5d(self, align_corners):
|
||||
T, H, W = 3, 4, 5
|
||||
# Construct a grid to obtain indentity sampling
|
||||
input = torch.randn(H * W).view(1, 1, H, W).float()
|
||||
input = torch.stack([input, input + 1, input + 2], dim=2)
|
||||
coords = torch.meshgrid(torch.arange(T), torch.arange(W), torch.arange(H))
|
||||
coords = torch.stack(coords, dim=-1).float().permute(0, 2, 1, 3)[None]
|
||||
|
||||
if not align_corners:
|
||||
coords = coords + 0.5
|
||||
sampled_input = bilinear_sampler(input, coords, align_corners=align_corners)
|
||||
torch.testing.assert_close(input, sampled_input)
|
||||
|
||||
def test4d(self):
|
||||
self._test4d(align_corners=True)
|
||||
self._test4d(align_corners=False)
|
||||
|
||||
def test5d(self):
|
||||
self._test5d(align_corners=True)
|
||||
self._test5d(align_corners=False)
|
||||
|
||||
|
||||
# run the test
|
||||
unittest.main()
|
358
train.py
358
train.py
@ -25,22 +25,35 @@ from torch.utils.tensorboard import SummaryWriter
|
||||
from pytorch_lightning.lite import LightningLite
|
||||
|
||||
from cotracker.models.evaluation_predictor import EvaluationPredictor
|
||||
from cotracker.models.core.cotracker.cotracker import CoTracker
|
||||
from cotracker.models.core.cotracker.cotracker import CoTracker2
|
||||
from cotracker.utils.visualizer import Visualizer
|
||||
from cotracker.datasets.tap_vid_datasets import TapVidDataset
|
||||
from cotracker.datasets.badja_dataset import BadjaDataset
|
||||
from cotracker.datasets.fast_capture_dataset import FastCaptureDataset
|
||||
|
||||
from cotracker.datasets.dr_dataset import DynamicReplicaDataset
|
||||
from cotracker.evaluation.core.evaluator import Evaluator
|
||||
from cotracker.datasets import kubric_movif_dataset
|
||||
from cotracker.datasets.utils import collate_fn, collate_fn_train, dataclass_to_cuda_
|
||||
from cotracker.models.core.cotracker.losses import sequence_loss, balanced_ce_loss
|
||||
|
||||
|
||||
# define the handler function
|
||||
# for training on a slurm cluster
|
||||
def sig_handler(signum, frame):
|
||||
print("caught signal", signum)
|
||||
print(socket.gethostname(), "USR1 signal caught.")
|
||||
# do other stuff to cleanup here
|
||||
print("requeuing job " + os.environ["SLURM_JOB_ID"])
|
||||
os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"])
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
def term_handler(signum, frame):
|
||||
print("bypassing sigterm", flush=True)
|
||||
|
||||
|
||||
def fetch_optimizer(args, model):
|
||||
"""Create the optimizer and learning rate scheduler"""
|
||||
optimizer = optim.AdamW(
|
||||
model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8
|
||||
)
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8)
|
||||
scheduler = optim.lr_scheduler.OneCycleLR(
|
||||
optimizer,
|
||||
args.lr,
|
||||
@ -53,69 +66,61 @@ def fetch_optimizer(args, model):
|
||||
return optimizer, scheduler
|
||||
|
||||
|
||||
def forward_batch(batch, model, args, loss_fn=None, writer=None, step=0):
|
||||
rgbs = batch.video
|
||||
def forward_batch(batch, model, args):
|
||||
video = batch.video
|
||||
trajs_g = batch.trajectory
|
||||
vis_g = batch.visibility
|
||||
valids = batch.valid
|
||||
B, T, C, H, W = rgbs.shape
|
||||
B, T, C, H, W = video.shape
|
||||
assert C == 3
|
||||
B, T, N, D = trajs_g.shape
|
||||
device = rgbs.device
|
||||
device = video.device
|
||||
|
||||
__, first_positive_inds = torch.max(vis_g, dim=1)
|
||||
# We want to make sure that during training the model sees visible points
|
||||
# that it does not need to track just yet: they are visible but queried from a later frame
|
||||
N_rand = N // 4
|
||||
# inds of visible points in the 1st frame
|
||||
nonzero_inds = [torch.nonzero(vis_g[0, :, i]) for i in range(N)]
|
||||
rand_vis_inds = torch.cat(
|
||||
[
|
||||
nonzero_row[torch.randint(len(nonzero_row), size=(1,))]
|
||||
for nonzero_row in nonzero_inds
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
first_positive_inds = torch.cat(
|
||||
[rand_vis_inds[:, :N_rand], first_positive_inds[:, N_rand:]], dim=1
|
||||
)
|
||||
nonzero_inds = [[torch.nonzero(vis_g[b, :, i]) for i in range(N)] for b in range(B)]
|
||||
|
||||
for b in range(B):
|
||||
rand_vis_inds = torch.cat(
|
||||
[
|
||||
nonzero_row[torch.randint(len(nonzero_row), size=(1,))]
|
||||
for nonzero_row in nonzero_inds[b]
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
first_positive_inds[b] = torch.cat(
|
||||
[rand_vis_inds[:, :N_rand], first_positive_inds[b : b + 1, N_rand:]], dim=1
|
||||
)
|
||||
|
||||
ind_array_ = torch.arange(T, device=device)
|
||||
ind_array_ = ind_array_[None, :, None].repeat(B, 1, N)
|
||||
assert torch.allclose(
|
||||
vis_g[ind_array_ == first_positive_inds[:, None, :]],
|
||||
torch.ones_like(vis_g),
|
||||
)
|
||||
assert torch.allclose(
|
||||
vis_g[ind_array_ == rand_vis_inds[:, None, :]], torch.ones_like(vis_g)
|
||||
)
|
||||
|
||||
gather = torch.gather(
|
||||
trajs_g, 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, 2)
|
||||
torch.ones(1, device=device),
|
||||
)
|
||||
gather = torch.gather(trajs_g, 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, D))
|
||||
xys = torch.diagonal(gather, dim1=1, dim2=2).permute(0, 2, 1)
|
||||
|
||||
queries = torch.cat([first_positive_inds[:, :, None], xys], dim=2)
|
||||
queries = torch.cat([first_positive_inds[:, :, None], xys[:, :, :2]], dim=2)
|
||||
|
||||
predictions, __, visibility, train_data = model(
|
||||
rgbs=rgbs, queries=queries, iters=args.train_iters, is_train=True
|
||||
predictions, visibility, train_data = model(
|
||||
video=video, queries=queries, iters=args.train_iters, is_train=True
|
||||
)
|
||||
vis_predictions, coord_predictions, wind_inds, sort_inds = train_data
|
||||
|
||||
trajs_g = trajs_g[:, :, sort_inds]
|
||||
vis_g = vis_g[:, :, sort_inds]
|
||||
valids = valids[:, :, sort_inds]
|
||||
coord_predictions, vis_predictions, valid_mask = train_data
|
||||
|
||||
vis_gts = []
|
||||
traj_gts = []
|
||||
valids_gts = []
|
||||
|
||||
for i, wind_idx in enumerate(wind_inds):
|
||||
ind = i * (args.sliding_window_len // 2)
|
||||
|
||||
vis_gts.append(vis_g[:, ind : ind + args.sliding_window_len, :wind_idx])
|
||||
traj_gts.append(trajs_g[:, ind : ind + args.sliding_window_len, :wind_idx])
|
||||
valids_gts.append(valids[:, ind : ind + args.sliding_window_len, :wind_idx])
|
||||
|
||||
S = args.sliding_window_len
|
||||
for ind in range(0, args.sequence_len - S // 2, S // 2):
|
||||
vis_gts.append(vis_g[:, ind : ind + S])
|
||||
traj_gts.append(trajs_g[:, ind : ind + S])
|
||||
valids_gts.append(valids[:, ind : ind + S] * valid_mask[:, ind : ind + S])
|
||||
|
||||
seq_loss = sequence_loss(coord_predictions, traj_gts, vis_gts, valids_gts, 0.8)
|
||||
vis_loss = balanced_ce_loss(vis_predictions, vis_gts, valids_gts)
|
||||
|
||||
@ -131,9 +136,17 @@ def forward_batch(batch, model, args, loss_fn=None, writer=None, step=0):
|
||||
def run_test_eval(evaluator, model, dataloaders, writer, step):
|
||||
model.eval()
|
||||
for ds_name, dataloader in dataloaders:
|
||||
visualize_every = 1
|
||||
grid_size = 5
|
||||
if ds_name == "dynamic_replica":
|
||||
visualize_every = 8
|
||||
grid_size = 0
|
||||
elif "tapvid" in ds_name:
|
||||
visualize_every = 5
|
||||
|
||||
predictor = EvaluationPredictor(
|
||||
model.module.module,
|
||||
grid_size=6,
|
||||
grid_size=grid_size,
|
||||
local_grid_size=0,
|
||||
single_point=False,
|
||||
n_iters=6,
|
||||
@ -148,37 +161,23 @@ def run_test_eval(evaluator, model, dataloaders, writer, step):
|
||||
train_mode=True,
|
||||
writer=writer,
|
||||
step=step,
|
||||
visualize_every=visualize_every,
|
||||
)
|
||||
|
||||
if ds_name == "badja" or ds_name == "fastcapture" or ("kubric" in ds_name):
|
||||
|
||||
metrics = {
|
||||
**{
|
||||
f"{ds_name}_avg": np.mean(
|
||||
[v for k, v in metrics.items() if "accuracy" not in k]
|
||||
)
|
||||
},
|
||||
**{
|
||||
f"{ds_name}_avg_accuracy": np.mean(
|
||||
[v for k, v in metrics.items() if "accuracy" in k]
|
||||
)
|
||||
},
|
||||
}
|
||||
print("avg", np.mean([v for v in metrics.values()]))
|
||||
if ds_name == "dynamic_replica" or ds_name == "kubric":
|
||||
metrics = {f"{ds_name}_avg_{k}": v for k, v in metrics["avg"].items()}
|
||||
|
||||
if "tapvid" in ds_name:
|
||||
metrics = {
|
||||
f"{ds_name}_avg_OA": metrics["avg"]["occlusion_accuracy"] * 100,
|
||||
f"{ds_name}_avg_delta": metrics["avg"]["average_pts_within_thresh"]
|
||||
* 100,
|
||||
f"{ds_name}_avg_Jaccard": metrics["avg"]["average_jaccard"] * 100,
|
||||
f"{ds_name}_avg_OA": metrics["avg"]["occlusion_accuracy"],
|
||||
f"{ds_name}_avg_delta": metrics["avg"]["average_pts_within_thresh"],
|
||||
f"{ds_name}_avg_Jaccard": metrics["avg"]["average_jaccard"],
|
||||
}
|
||||
|
||||
writer.add_scalars(f"Eval", metrics, step)
|
||||
writer.add_scalars(f"Eval_{ds_name}", metrics, step)
|
||||
|
||||
|
||||
class Logger:
|
||||
|
||||
SUM_FREQ = 100
|
||||
|
||||
def __init__(self, model, scheduler):
|
||||
@ -190,24 +189,19 @@ class Logger:
|
||||
|
||||
def _print_training_status(self):
|
||||
metrics_data = [
|
||||
self.running_loss[k] / Logger.SUM_FREQ
|
||||
for k in sorted(self.running_loss.keys())
|
||||
self.running_loss[k] / Logger.SUM_FREQ for k in sorted(self.running_loss.keys())
|
||||
]
|
||||
training_str = "[{:6d}] ".format(self.total_steps + 1)
|
||||
metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data)
|
||||
|
||||
# print the training status
|
||||
logging.info(
|
||||
f"Training Metrics ({self.total_steps}): {training_str + metrics_str}"
|
||||
)
|
||||
logging.info(f"Training Metrics ({self.total_steps}): {training_str + metrics_str}")
|
||||
|
||||
if self.writer is None:
|
||||
self.writer = SummaryWriter(log_dir=os.path.join(args.ckpt_path, "runs"))
|
||||
|
||||
for k in self.running_loss:
|
||||
self.writer.add_scalar(
|
||||
k, self.running_loss[k] / Logger.SUM_FREQ, self.total_steps
|
||||
)
|
||||
self.writer.add_scalar(k, self.running_loss[k] / Logger.SUM_FREQ, self.total_steps)
|
||||
self.running_loss[k] = 0.0
|
||||
|
||||
def push(self, metrics, task):
|
||||
@ -249,79 +243,56 @@ class Lite(LightningLite):
|
||||
seed_everything(0)
|
||||
|
||||
def seed_worker(worker_id):
|
||||
worker_seed = torch.initial_seed() % 2 ** 32
|
||||
worker_seed = torch.initial_seed() % 2**32
|
||||
np.random.seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
g = torch.Generator()
|
||||
g.manual_seed(0)
|
||||
if self.global_rank == 0:
|
||||
eval_dataloaders = []
|
||||
if "dynamic_replica" in args.eval_datasets:
|
||||
eval_dataset = DynamicReplicaDataset(
|
||||
sample_len=60, only_first_n_samples=1, rgbd_input=False
|
||||
)
|
||||
eval_dataloader_dr = torch.utils.data.DataLoader(
|
||||
eval_dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=1,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
eval_dataloaders.append(("dynamic_replica", eval_dataloader_dr))
|
||||
|
||||
eval_dataloaders = []
|
||||
if "badja" in args.eval_datasets:
|
||||
eval_dataset = BadjaDataset(
|
||||
data_root=os.path.join(args.dataset_root, "BADJA"),
|
||||
max_seq_len=args.eval_max_seq_len,
|
||||
dataset_resolution=args.crop_size,
|
||||
if "tapvid_davis_first" in args.eval_datasets:
|
||||
data_root = os.path.join(args.dataset_root, "tapvid/tapvid_davis/tapvid_davis.pkl")
|
||||
eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root)
|
||||
eval_dataloader_tapvid_davis = torch.utils.data.DataLoader(
|
||||
eval_dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=1,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
eval_dataloaders.append(("tapvid_davis", eval_dataloader_tapvid_davis))
|
||||
|
||||
evaluator = Evaluator(args.ckpt_path)
|
||||
|
||||
visualizer = Visualizer(
|
||||
save_dir=args.ckpt_path,
|
||||
pad_value=80,
|
||||
fps=1,
|
||||
show_first_frame=0,
|
||||
tracks_leave_trace=0,
|
||||
)
|
||||
eval_dataloader_badja = torch.utils.data.DataLoader(
|
||||
eval_dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=8,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
eval_dataloaders.append(("badja", eval_dataloader_badja))
|
||||
|
||||
if "fastcapture" in args.eval_datasets:
|
||||
eval_dataset = FastCaptureDataset(
|
||||
data_root=os.path.join(args.dataset_root, "fastcapture"),
|
||||
max_seq_len=min(100, args.eval_max_seq_len),
|
||||
max_num_points=40,
|
||||
dataset_resolution=args.crop_size,
|
||||
)
|
||||
eval_dataloader_fastcapture = torch.utils.data.DataLoader(
|
||||
eval_dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=1,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
eval_dataloaders.append(("fastcapture", eval_dataloader_fastcapture))
|
||||
|
||||
if "tapvid_davis_first" in args.eval_datasets:
|
||||
data_root = os.path.join(args.dataset_root, "tapvid_davis/tapvid_davis.pkl")
|
||||
eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root)
|
||||
eval_dataloader_tapvid_davis = torch.utils.data.DataLoader(
|
||||
eval_dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=1,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
eval_dataloaders.append(("tapvid_davis", eval_dataloader_tapvid_davis))
|
||||
|
||||
evaluator = Evaluator(args.ckpt_path)
|
||||
|
||||
visualizer = Visualizer(
|
||||
save_dir=args.ckpt_path,
|
||||
pad_value=80,
|
||||
fps=1,
|
||||
show_first_frame=0,
|
||||
tracks_leave_trace=0,
|
||||
)
|
||||
|
||||
loss_fn = None
|
||||
|
||||
if args.model_name == "cotracker":
|
||||
|
||||
model = CoTracker(
|
||||
model = CoTracker2(
|
||||
stride=args.model_stride,
|
||||
S=args.sliding_window_len,
|
||||
window_len=args.sliding_window_len,
|
||||
add_space_attn=not args.remove_space_attn,
|
||||
num_heads=args.updateformer_num_heads,
|
||||
hidden_size=args.updateformer_hidden_size,
|
||||
space_depth=args.updateformer_space_depth,
|
||||
time_depth=args.updateformer_time_depth,
|
||||
num_virtual_tracks=args.num_virtual_tracks,
|
||||
model_resolution=args.crop_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Model {args.model_name} doesn't exist")
|
||||
@ -332,7 +303,7 @@ class Lite(LightningLite):
|
||||
model.cuda()
|
||||
|
||||
train_dataset = kubric_movif_dataset.KubricMovifDataset(
|
||||
data_root=os.path.join(args.dataset_root, "kubric_movi_f"),
|
||||
data_root=os.path.join(args.dataset_root, "kubric", "kubric_movi_f_tracks"),
|
||||
crop_size=args.crop_size,
|
||||
seq_len=args.sequence_len,
|
||||
traj_per_sample=args.traj_per_sample,
|
||||
@ -357,7 +328,8 @@ class Lite(LightningLite):
|
||||
optimizer, scheduler = fetch_optimizer(args, model)
|
||||
|
||||
total_steps = 0
|
||||
logger = Logger(model, scheduler)
|
||||
if self.global_rank == 0:
|
||||
logger = Logger(model, scheduler)
|
||||
|
||||
folder_ckpts = [
|
||||
f
|
||||
@ -383,9 +355,7 @@ class Lite(LightningLite):
|
||||
logging.info(f"Load total_steps {total_steps}")
|
||||
|
||||
elif args.restore_ckpt is not None:
|
||||
assert args.restore_ckpt.endswith(".pth") or args.restore_ckpt.endswith(
|
||||
".pt"
|
||||
)
|
||||
assert args.restore_ckpt.endswith(".pth") or args.restore_ckpt.endswith(".pt")
|
||||
logging.info("Loading checkpoint...")
|
||||
|
||||
strict = True
|
||||
@ -394,9 +364,7 @@ class Lite(LightningLite):
|
||||
state_dict = state_dict["model"]
|
||||
|
||||
if list(state_dict.keys())[0].startswith("module."):
|
||||
state_dict = {
|
||||
k.replace("module.", ""): v for k, v in state_dict.items()
|
||||
}
|
||||
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
|
||||
model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
logging.info(f"Done loading checkpoint")
|
||||
@ -424,33 +392,22 @@ class Lite(LightningLite):
|
||||
|
||||
assert model.training
|
||||
|
||||
output = forward_batch(
|
||||
batch,
|
||||
model,
|
||||
args,
|
||||
loss_fn=loss_fn,
|
||||
writer=logger.writer,
|
||||
step=total_steps,
|
||||
)
|
||||
output = forward_batch(batch, model, args)
|
||||
|
||||
loss = 0
|
||||
for k, v in output.items():
|
||||
if "loss" in v:
|
||||
loss += v["loss"]
|
||||
logger.writer.add_scalar(
|
||||
f"live_{k}_loss", v["loss"].item(), total_steps
|
||||
)
|
||||
if "metrics" in v:
|
||||
logger.push(v["metrics"], k)
|
||||
|
||||
if self.global_rank == 0:
|
||||
if total_steps % save_freq == save_freq - 1:
|
||||
if args.model_name == "motion_diffuser":
|
||||
pred_coords = model.module.module.forward_batch_test(
|
||||
batch, interp_shape=args.crop_size
|
||||
for k, v in output.items():
|
||||
if "loss" in v:
|
||||
logger.writer.add_scalar(
|
||||
f"live_{k}_loss", v["loss"].item(), total_steps
|
||||
)
|
||||
|
||||
output["flow"] = {"predictions": pred_coords[0].detach()}
|
||||
if "metrics" in v:
|
||||
logger.push(v["metrics"], k)
|
||||
if total_steps % save_freq == save_freq - 1:
|
||||
visualizer.visualize(
|
||||
video=batch.video.clone(),
|
||||
tracks=batch.trajectory.clone(),
|
||||
@ -468,9 +425,7 @@ class Lite(LightningLite):
|
||||
)
|
||||
|
||||
if len(output) > 1:
|
||||
logger.writer.add_scalar(
|
||||
f"live_total_loss", loss.item(), total_steps
|
||||
)
|
||||
logger.writer.add_scalar(f"live_total_loss", loss.item(), total_steps)
|
||||
logger.writer.add_scalar(
|
||||
f"learning_rate", optimizer.param_groups[0]["lr"], total_steps
|
||||
)
|
||||
@ -492,9 +447,7 @@ class Lite(LightningLite):
|
||||
total_steps == 1 and args.validate_at_start
|
||||
):
|
||||
if (epoch + 1) % args.save_every_n_epoch == 0:
|
||||
ckpt_iter = "0" * (6 - len(str(total_steps))) + str(
|
||||
total_steps
|
||||
)
|
||||
ckpt_iter = "0" * (6 - len(str(total_steps))) + str(total_steps)
|
||||
save_path = Path(
|
||||
f"{args.ckpt_path}/model_{args.model_name}_{ckpt_iter}.pth"
|
||||
)
|
||||
@ -526,16 +479,18 @@ class Lite(LightningLite):
|
||||
if total_steps > args.num_steps:
|
||||
should_keep_training = False
|
||||
break
|
||||
if self.global_rank == 0:
|
||||
print("FINISHED TRAINING")
|
||||
|
||||
print("FINISHED TRAINING")
|
||||
|
||||
PATH = f"{args.ckpt_path}/{args.model_name}_final.pth"
|
||||
torch.save(model.module.module.state_dict(), PATH)
|
||||
run_test_eval(evaluator, model, eval_dataloaders, logger.writer, total_steps)
|
||||
logger.close()
|
||||
PATH = f"{args.ckpt_path}/{args.model_name}_final.pth"
|
||||
torch.save(model.module.module.state_dict(), PATH)
|
||||
run_test_eval(evaluator, model, eval_dataloaders, logger.writer, total_steps)
|
||||
logger.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
signal.signal(signal.SIGUSR1, sig_handler)
|
||||
signal.signal(signal.SIGTERM, term_handler)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_name", default="cotracker", help="model name")
|
||||
parser.add_argument("--restore_ckpt", help="path to restore a checkpoint")
|
||||
@ -543,17 +498,12 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--batch_size", type=int, default=4, help="batch size used during training."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_workers", type=int, default=6, help="number of dataloader workers"
|
||||
)
|
||||
parser.add_argument("--num_nodes", type=int, default=1)
|
||||
parser.add_argument("--num_workers", type=int, default=10, help="number of dataloader workers")
|
||||
|
||||
parser.add_argument(
|
||||
"--mixed_precision", action="store_true", help="use mixed precision"
|
||||
)
|
||||
parser.add_argument("--mixed_precision", action="store_true", help="use mixed precision")
|
||||
parser.add_argument("--lr", type=float, default=0.0005, help="max learning rate.")
|
||||
parser.add_argument(
|
||||
"--wdecay", type=float, default=0.00001, help="Weight decay in optimizer."
|
||||
)
|
||||
parser.add_argument("--wdecay", type=float, default=0.00001, help="Weight decay in optimizer.")
|
||||
parser.add_argument(
|
||||
"--num_steps", type=int, default=200000, help="length of training schedule."
|
||||
)
|
||||
@ -596,13 +546,11 @@ if __name__ == "__main__":
|
||||
default=4,
|
||||
help="number of updates to the disparity field in each forward pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sequence_len", type=int, default=8, help="train sequence length"
|
||||
)
|
||||
parser.add_argument("--sequence_len", type=int, default=8, help="train sequence length")
|
||||
parser.add_argument(
|
||||
"--eval_datasets",
|
||||
nargs="+",
|
||||
default=["things", "badja"],
|
||||
default=["tapvid_davis_first"],
|
||||
help="what datasets to use for evaluation",
|
||||
)
|
||||
|
||||
@ -611,6 +559,12 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
help="remove space attention from CoTracker",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_virtual_tracks",
|
||||
type=int,
|
||||
default=None,
|
||||
help="stride of the CoTracker feature network",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dont_use_augs",
|
||||
action="store_true",
|
||||
@ -627,30 +581,6 @@ if __name__ == "__main__":
|
||||
default=8,
|
||||
help="length of the CoTracker sliding window",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--updateformer_hidden_size",
|
||||
type=int,
|
||||
default=384,
|
||||
help="hidden dimension of the CoTracker transformer model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--updateformer_num_heads",
|
||||
type=int,
|
||||
default=8,
|
||||
help="number of heads of the CoTracker transformer model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--updateformer_space_depth",
|
||||
type=int,
|
||||
default=12,
|
||||
help="number of group attention layers in the CoTracker transformer model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--updateformer_time_depth",
|
||||
type=int,
|
||||
default=12,
|
||||
help="number of time attention layers in the CoTracker transformer model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_stride",
|
||||
type=int,
|
||||
@ -680,9 +610,9 @@ if __name__ == "__main__":
|
||||
from pytorch_lightning.strategies import DDPStrategy
|
||||
|
||||
Lite(
|
||||
strategy=DDPStrategy(find_unused_parameters=True),
|
||||
strategy=DDPStrategy(find_unused_parameters=False),
|
||||
devices="auto",
|
||||
accelerator="gpu",
|
||||
precision=32,
|
||||
# num_nodes=4,
|
||||
num_nodes=args.num_nodes,
|
||||
).run(args)
|
||||
|
Loading…
Reference in New Issue
Block a user