Initial commit

This commit is contained in:
nikitakaraevv 2023-07-17 17:49:06 -07:00
commit 6d62d873fa
41 changed files with 5989 additions and 0 deletions

80
CODE_OF_CONDUCT.md Normal file
View File

@ -0,0 +1,80 @@
# Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.
This Code of Conduct also applies outside the project spaces when there is a
reasonable belief that an individual's behavior may have a negative impact on
the project or its community.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <opensource-conduct@fb.com>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq

28
CONTRIBUTING.md Normal file
View File

@ -0,0 +1,28 @@
# CoTracker
We want to make contributing to this project as easy and transparent as possible.
## Pull Requests
We actively welcome your pull requests.
1. Fork the repo and create your branch from `main`.
2. If you've changed APIs, update the documentation.
3. Make sure your code lints.
4. If you haven't already, complete the Contributor License Agreement ("CLA").
## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Meta's open source projects.
Complete your CLA here: <https://code.facebook.com/cla>
## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.
Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.
## License
By contributing to CoTracker, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.

399
LICENSE.md Normal file
View File

@ -0,0 +1,399 @@
Attribution-NonCommercial 4.0 International
=======================================================================
Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.
Using Creative Commons Public Licenses
Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.
Considerations for licensors: Our public licenses are
intended for use by those authorized to give the public
permission to use material in ways otherwise restricted by
copyright and certain other rights. Our licenses are
irrevocable. Licensors should read and understand the terms
and conditions of the license they choose before applying it.
Licensors should also secure all rights necessary before
applying our licenses so that the public can reuse the
material as expected. Licensors should clearly mark any
material not subject to the license. This includes other CC-
licensed material, or material used under an exception or
limitation to copyright. More considerations for licensors:
wiki.creativecommons.org/Considerations_for_licensors
Considerations for the public: By using one of our public
licenses, a licensor grants the public permission to use the
licensed material under specified terms and conditions. If
the licensor's permission is not necessary for any reason--for
example, because of any applicable exception or limitation to
copyright--then that use is not regulated by the license. Our
licenses grant only permissions under copyright and certain
other rights that a licensor has authority to grant. Use of
the licensed material may still be restricted for other
reasons, including because others have copyright or other
rights in the material. A licensor may make special requests,
such as asking that all changes be marked or described.
Although not required by our licenses, you are encouraged to
respect those requests where reasonable. More_considerations
for the public:
wiki.creativecommons.org/Considerations_for_licensees
=======================================================================
Creative Commons Attribution-NonCommercial 4.0 International Public
License
By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution-NonCommercial 4.0 International Public License ("Public
License"). To the extent this Public License may be interpreted as a
contract, You are granted the Licensed Rights in consideration of Your
acceptance of these terms and conditions, and the Licensor grants You
such rights in consideration of benefits the Licensor receives from
making the Licensed Material available under these terms and
conditions.
Section 1 -- Definitions.
a. Adapted Material means material subject to Copyright and Similar
Rights that is derived from or based upon the Licensed Material
and in which the Licensed Material is translated, altered,
arranged, transformed, or otherwise modified in a manner requiring
permission under the Copyright and Similar Rights held by the
Licensor. For purposes of this Public License, where the Licensed
Material is a musical work, performance, or sound recording,
Adapted Material is always produced where the Licensed Material is
synched in timed relation with a moving image.
b. Adapter's License means the license You apply to Your Copyright
and Similar Rights in Your contributions to Adapted Material in
accordance with the terms and conditions of this Public License.
c. Copyright and Similar Rights means copyright and/or similar rights
closely related to copyright including, without limitation,
performance, broadcast, sound recording, and Sui Generis Database
Rights, without regard to how the rights are labeled or
categorized. For purposes of this Public License, the rights
specified in Section 2(b)(1)-(2) are not Copyright and Similar
Rights.
d. Effective Technological Measures means those measures that, in the
absence of proper authority, may not be circumvented under laws
fulfilling obligations under Article 11 of the WIPO Copyright
Treaty adopted on December 20, 1996, and/or similar international
agreements.
e. Exceptions and Limitations means fair use, fair dealing, and/or
any other exception or limitation to Copyright and Similar Rights
that applies to Your use of the Licensed Material.
f. Licensed Material means the artistic or literary work, database,
or other material to which the Licensor applied this Public
License.
g. Licensed Rights means the rights granted to You subject to the
terms and conditions of this Public License, which are limited to
all Copyright and Similar Rights that apply to Your use of the
Licensed Material and that the Licensor has authority to license.
h. Licensor means the individual(s) or entity(ies) granting rights
under this Public License.
i. NonCommercial means not primarily intended for or directed towards
commercial advantage or monetary compensation. For purposes of
this Public License, the exchange of the Licensed Material for
other material subject to Copyright and Similar Rights by digital
file-sharing or similar means is NonCommercial provided there is
no payment of monetary compensation in connection with the
exchange.
j. Share means to provide material to the public by any means or
process that requires permission under the Licensed Rights, such
as reproduction, public display, public performance, distribution,
dissemination, communication, or importation, and to make material
available to the public including in ways that members of the
public may access the material from a place and at a time
individually chosen by them.
k. Sui Generis Database Rights means rights other than copyright
resulting from Directive 96/9/EC of the European Parliament and of
the Council of 11 March 1996 on the legal protection of databases,
as amended and/or succeeded, as well as other essentially
equivalent rights anywhere in the world.
l. You means the individual or entity exercising the Licensed Rights
under this Public License. Your has a corresponding meaning.
Section 2 -- Scope.
a. License grant.
1. Subject to the terms and conditions of this Public License,
the Licensor hereby grants You a worldwide, royalty-free,
non-sublicensable, non-exclusive, irrevocable license to
exercise the Licensed Rights in the Licensed Material to:
a. reproduce and Share the Licensed Material, in whole or
in part, for NonCommercial purposes only; and
b. produce, reproduce, and Share Adapted Material for
NonCommercial purposes only.
2. Exceptions and Limitations. For the avoidance of doubt, where
Exceptions and Limitations apply to Your use, this Public
License does not apply, and You do not need to comply with
its terms and conditions.
3. Term. The term of this Public License is specified in Section
6(a).
4. Media and formats; technical modifications allowed. The
Licensor authorizes You to exercise the Licensed Rights in
all media and formats whether now known or hereafter created,
and to make technical modifications necessary to do so. The
Licensor waives and/or agrees not to assert any right or
authority to forbid You from making technical modifications
necessary to exercise the Licensed Rights, including
technical modifications necessary to circumvent Effective
Technological Measures. For purposes of this Public License,
simply making modifications authorized by this Section 2(a)
(4) never produces Adapted Material.
5. Downstream recipients.
a. Offer from the Licensor -- Licensed Material. Every
recipient of the Licensed Material automatically
receives an offer from the Licensor to exercise the
Licensed Rights under the terms and conditions of this
Public License.
b. No downstream restrictions. You may not offer or impose
any additional or different terms or conditions on, or
apply any Effective Technological Measures to, the
Licensed Material if doing so restricts exercise of the
Licensed Rights by any recipient of the Licensed
Material.
6. No endorsement. Nothing in this Public License constitutes or
may be construed as permission to assert or imply that You
are, or that Your use of the Licensed Material is, connected
with, or sponsored, endorsed, or granted official status by,
the Licensor or others designated to receive attribution as
provided in Section 3(a)(1)(A)(i).
b. Other rights.
1. Moral rights, such as the right of integrity, are not
licensed under this Public License, nor are publicity,
privacy, and/or other similar personality rights; however, to
the extent possible, the Licensor waives and/or agrees not to
assert any such rights held by the Licensor to the limited
extent necessary to allow You to exercise the Licensed
Rights, but not otherwise.
2. Patent and trademark rights are not licensed under this
Public License.
3. To the extent possible, the Licensor waives any right to
collect royalties from You for the exercise of the Licensed
Rights, whether directly or through a collecting society
under any voluntary or waivable statutory or compulsory
licensing scheme. In all other cases the Licensor expressly
reserves any right to collect such royalties, including when
the Licensed Material is used other than for NonCommercial
purposes.
Section 3 -- License Conditions.
Your exercise of the Licensed Rights is expressly made subject to the
following conditions.
a. Attribution.
1. If You Share the Licensed Material (including in modified
form), You must:
a. retain the following if it is supplied by the Licensor
with the Licensed Material:
i. identification of the creator(s) of the Licensed
Material and any others designated to receive
attribution, in any reasonable manner requested by
the Licensor (including by pseudonym if
designated);
ii. a copyright notice;
iii. a notice that refers to this Public License;
iv. a notice that refers to the disclaimer of
warranties;
v. a URI or hyperlink to the Licensed Material to the
extent reasonably practicable;
b. indicate if You modified the Licensed Material and
retain an indication of any previous modifications; and
c. indicate the Licensed Material is licensed under this
Public License, and include the text of, or the URI or
hyperlink to, this Public License.
2. You may satisfy the conditions in Section 3(a)(1) in any
reasonable manner based on the medium, means, and context in
which You Share the Licensed Material. For example, it may be
reasonable to satisfy the conditions by providing a URI or
hyperlink to a resource that includes the required
information.
3. If requested by the Licensor, You must remove any of the
information required by Section 3(a)(1)(A) to the extent
reasonably practicable.
4. If You Share Adapted Material You produce, the Adapter's
License You apply must not prevent recipients of the Adapted
Material from complying with this Public License.
Section 4 -- Sui Generis Database Rights.
Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
to extract, reuse, reproduce, and Share all or a substantial
portion of the contents of the database for NonCommercial purposes
only;
b. if You include all or a substantial portion of the database
contents in a database in which You have Sui Generis Database
Rights, then the database in which You have Sui Generis Database
Rights (but not its individual contents) is Adapted Material; and
c. You must comply with the conditions in Section 3(a) if You Share
all or a substantial portion of the contents of the database.
For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
c. The disclaimer of warranties and limitation of liability provided
above shall be interpreted in a manner that, to the extent
possible, most closely approximates an absolute disclaimer and
waiver of all liability.
Section 6 -- Term and Termination.
a. This Public License applies for the term of the Copyright and
Similar Rights licensed here. However, if You fail to comply with
this Public License, then Your rights under this Public License
terminate automatically.
b. Where Your right to use the Licensed Material has terminated under
Section 6(a), it reinstates:
1. automatically as of the date the violation is cured, provided
it is cured within 30 days of Your discovery of the
violation; or
2. upon express reinstatement by the Licensor.
For the avoidance of doubt, this Section 6(b) does not affect any
right the Licensor may have to seek remedies for Your violations
of this Public License.
c. For the avoidance of doubt, the Licensor may also offer the
Licensed Material under separate terms or conditions or stop
distributing the Licensed Material at any time; however, doing so
will not terminate this Public License.
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
License.
Section 7 -- Other Terms and Conditions.
a. The Licensor shall not be bound by any additional or different
terms or conditions communicated by You unless expressly agreed.
b. Any arrangements, understandings, or agreements regarding the
Licensed Material not stated herein are separate from and
independent of the terms and conditions of this Public License.
Section 8 -- Interpretation.
a. For the avoidance of doubt, this Public License does not, and
shall not be interpreted to, reduce, limit, restrict, or impose
conditions on any use of the Licensed Material that could lawfully
be made without permission under this Public License.
b. To the extent possible, if any provision of this Public License is
deemed unenforceable, it shall be automatically reformed to the
minimum extent necessary to make it enforceable. If the provision
cannot be reformed, it shall be severed from this Public License
without affecting the enforceability of the remaining terms and
conditions.
c. No term or condition of this Public License will be waived and no
failure to comply consented to unless expressly agreed to by the
Licensor.
d. Nothing in this Public License constitutes or may be interpreted
as a limitation upon, or waiver of, any privileges and immunities
that apply to the Licensor or You, including from the legal
processes of any jurisdiction or authority.
=======================================================================
Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the “Licensor.” The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.
Creative Commons may be contacted at creativecommons.org.

94
README.md Normal file
View File

@ -0,0 +1,94 @@
# CoTracker: It is Better to Track Together
**[Meta AI Research, FAIR](https://ai.facebook.com/research/)**; **[University of Oxford, VGG](https://www.robots.ox.ac.uk/~vgg/)**
[Nikita Karaev](https://nikitakaraevv.github.io/), [Ignacio Rocco](https://www.irocco.info/), [Benjamin Graham](https://ai.facebook.com/people/benjamin-graham/), [Natalia Neverova](https://nneverova.github.io/), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/), [Christian Rupprecht](https://chrirupp.github.io/)
[[`Paper`]()] [[`Project`](https://co-tracker.github.io/)] [[`BibTeX`](#citing-cotracker)]
![bmx-bumps](./assets/bmx-bumps.gif)
**CoTracker** is a fast transformer-based model that can track any point in a video. It brings to tracking some of the benefits of Optical Flow.
CoTracker can track:
- **Every pixel** within a video
- Points sampled on a regular grid on any video frame
- Manually selected points
Try these tracking modes for yourself with our [Colab demo](https://github.com/facebookresearch/co-tracker/notebooks/demo.ipynb).
## Installation Instructions
Ensure you have both PyTorch and TorchVision installed on your system. Follow the instructions [here](https://pytorch.org/get-started/locally/) for the installation. We strongly recommend installing both PyTorch and TorchVision with CUDA support.
## Steps to Install CoTracker and its dependencies:
```
git clone https://github.com/facebookresearch/co-tracker
cd co-tracker
pip install -e .
pip install opencv-python einops timm matplotlib moviepy flow_vis
```
## Model Weights Download:
```
mkdir checkpoints
cd checkpoints
wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth
wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_12.pth
wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_8_wind_16.pth
cd ..
```
## Running the Demo:
Try our [Colab demo](https://github.com/facebookresearch/co-tracker/notebooks/demo.ipynb) or run a local demo with 10*10 points sampled on a grid on the first frame of a video:
```
python demo.py --grid_size 10
```
## Evaluation
To reproduce the results presented in the paper, download the following datasets:
- [TAP-Vid](https://github.com/deepmind/tapnet)
- [BADJA](https://github.com/benjiebob/BADJA)
- [ZJU-Mocap (FastCapture)](https://arxiv.org/abs/2303.11898)
And install the necessary dependencies:
```
pip install hydra-core==1.1.0 mediapy tensorboard
```
Then, execute the following command to evaluate on BADJA:
```
python ./cotracker/evaluation/evaluate.py --config-name eval_badja exp_dir=./eval_outputs dataset_root=your/badja/path
```
## Training
To train the CoTracker as described in our paper, you first need to generate annotations for [Google Kubric](https://github.com/google-research/kubric) MOVI-f dataset. Instructions for annotation generation can be found [here](https://github.com/deepmind/tapnet).
Once you have the annotated dataset, you need to make sure you followed the steps for evaluation setup and install the training dependencies:
```
pip install pytorch_lightning==1.6.0
```
launch training on Kubric. Our model was trained using 32 GPUs, and you can adjust the parameters to best suit your hardware setup.
```
python train.py --batch_size 1 --num_workers 28 \
--num_steps 50000 --ckpt_path ./ --model_name cotracker \
--save_freq 200 --sequence_len 24 --eval_datasets tapvid_davis_first badja \
--traj_per_sample 256 --sliding_window_len 8 --updateformer_space_depth 6 --updateformer_time_depth 6 \
--save_every_n_epoch 10 --evaluate_every_n_epoch 10 --model_stride 4
```
## License
The majority of CoTracker is licensed under CC-BY-NC, however portions of the project are available under separate license terms: Particle Video Revisited is licensed under the MIT license, TAP-Vid is licensed under the Apache 2.0 license.
## Citing CoTracker
If you find our repository useful, please consider giving it a star ⭐ and citing our paper in your work:
```
@article{karaev2023cotracker,
title={CoTracker: It is Better to Track Together},
author={Nikita Karaev and Ignacio Rocco and Benjamin Graham and Natalia Neverova and Andrea Vedaldi and Christian Rupprecht},
journal={arxiv},
year={2023}
}
```

BIN
assets/apple.mp4 Normal file

Binary file not shown.

BIN
assets/apple_mask.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

BIN
assets/bmx-bumps.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.3 MiB

5
cotracker/__init__.py Normal file
View File

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@ -0,0 +1,390 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import numpy as np
import os
import json
import imageio
import cv2
from enum import Enum
from cotracker.datasets.utils import CoTrackerData, resize_sample
IGNORE_ANIMALS = [
# "bear.json",
# "camel.json",
"cat_jump.json"
# "cows.json",
# "dog.json",
# "dog-agility.json",
# "horsejump-high.json",
# "horsejump-low.json",
# "impala0.json",
# "rs_dog.json"
"tiger.json"
]
class SMALJointCatalog(Enum):
# body_0 = 0
# body_1 = 1
# body_2 = 2
# body_3 = 3
# body_4 = 4
# body_5 = 5
# body_6 = 6
# upper_right_0 = 7
upper_right_1 = 8
upper_right_2 = 9
upper_right_3 = 10
# upper_left_0 = 11
upper_left_1 = 12
upper_left_2 = 13
upper_left_3 = 14
neck_lower = 15
# neck_upper = 16
# lower_right_0 = 17
lower_right_1 = 18
lower_right_2 = 19
lower_right_3 = 20
# lower_left_0 = 21
lower_left_1 = 22
lower_left_2 = 23
lower_left_3 = 24
tail_0 = 25
# tail_1 = 26
# tail_2 = 27
tail_3 = 28
# tail_4 = 29
# tail_5 = 30
tail_6 = 31
jaw = 32
nose = 33 # ADDED JOINT FOR VERTEX 1863
# chin = 34 # ADDED JOINT FOR VERTEX 26
right_ear = 35 # ADDED JOINT FOR VERTEX 149
left_ear = 36 # ADDED JOINT FOR VERTEX 2124
class SMALJointInfo:
def __init__(self):
# These are the
self.annotated_classes = np.array(
[
8,
9,
10, # upper_right
12,
13,
14, # upper_left
15, # neck
18,
19,
20, # lower_right
22,
23,
24, # lower_left
25,
28,
31, # tail
32,
33, # head
35, # right_ear
36,
]
) # left_ear
self.annotated_markers = np.array(
[
cv2.MARKER_CROSS,
cv2.MARKER_STAR,
cv2.MARKER_TRIANGLE_DOWN,
cv2.MARKER_CROSS,
cv2.MARKER_STAR,
cv2.MARKER_TRIANGLE_DOWN,
cv2.MARKER_CROSS,
cv2.MARKER_CROSS,
cv2.MARKER_STAR,
cv2.MARKER_TRIANGLE_DOWN,
cv2.MARKER_CROSS,
cv2.MARKER_STAR,
cv2.MARKER_TRIANGLE_DOWN,
cv2.MARKER_CROSS,
cv2.MARKER_STAR,
cv2.MARKER_TRIANGLE_DOWN,
cv2.MARKER_CROSS,
cv2.MARKER_STAR,
cv2.MARKER_CROSS,
cv2.MARKER_CROSS,
]
)
self.joint_regions = np.array(
[
0,
0,
0,
0,
0,
0,
0,
1,
1,
1,
1,
2,
2,
2,
2,
3,
3,
4,
4,
4,
4,
5,
5,
5,
5,
6,
6,
6,
6,
6,
6,
6,
7,
7,
7,
8,
9,
]
)
self.annotated_joint_region = self.joint_regions[self.annotated_classes]
self.region_colors = np.array(
[
[250, 190, 190], # body, light pink
[60, 180, 75], # upper_right, green
[230, 25, 75], # upper_left, red
[128, 0, 0], # neck, maroon
[0, 130, 200], # lower_right, blue
[255, 255, 25], # lower_left, yellow
[240, 50, 230], # tail, majenta
[245, 130, 48], # jaw / nose / chin, orange
[29, 98, 115], # right_ear, turquoise
[255, 153, 204],
]
) # left_ear, pink
self.joint_colors = np.array(self.region_colors)[self.annotated_joint_region]
class BADJAData:
def __init__(self, data_root, complete=False):
annotations_path = os.path.join(data_root, "joint_annotations")
self.animal_dict = {}
self.animal_count = 0
self.smal_joint_info = SMALJointInfo()
for __, animal_json in enumerate(sorted(os.listdir(annotations_path))):
if animal_json not in IGNORE_ANIMALS:
json_path = os.path.join(annotations_path, animal_json)
with open(json_path) as json_data:
animal_joint_data = json.load(json_data)
filenames = []
segnames = []
joints = []
visible = []
first_path = animal_joint_data[0]["segmentation_path"]
last_path = animal_joint_data[-1]["segmentation_path"]
first_frame = first_path.split("/")[-1]
last_frame = last_path.split("/")[-1]
if not "extra_videos" in first_path:
animal = first_path.split("/")[-2]
first_frame_int = int(first_frame.split(".")[0])
last_frame_int = int(last_frame.split(".")[0])
for fr in range(first_frame_int, last_frame_int + 1):
ref_file_name = os.path.join(
data_root,
"DAVIS/JPEGImages/Full-Resolution/%s/%05d.jpg"
% (animal, fr),
)
ref_seg_name = os.path.join(
data_root,
"DAVIS/Annotations/Full-Resolution/%s/%05d.png"
% (animal, fr),
)
foundit = False
for ind, image_annotation in enumerate(animal_joint_data):
file_name = os.path.join(
data_root, image_annotation["image_path"]
)
seg_name = os.path.join(
data_root, image_annotation["segmentation_path"]
)
if file_name == ref_file_name:
foundit = True
label_ind = ind
if foundit:
image_annotation = animal_joint_data[label_ind]
file_name = os.path.join(
data_root, image_annotation["image_path"]
)
seg_name = os.path.join(
data_root, image_annotation["segmentation_path"]
)
joint = np.array(image_annotation["joints"])
vis = np.array(image_annotation["visibility"])
else:
file_name = ref_file_name
seg_name = ref_seg_name
joint = None
vis = None
filenames.append(file_name)
segnames.append(seg_name)
joints.append(joint)
visible.append(vis)
if len(filenames):
self.animal_dict[self.animal_count] = (
filenames,
segnames,
joints,
visible,
)
self.animal_count += 1
print("Loaded BADJA dataset")
def get_loader(self):
for __ in range(int(1e6)):
animal_id = np.random.choice(len(self.animal_dict.keys()))
filenames, segnames, joints, visible = self.animal_dict[animal_id]
image_id = np.random.randint(0, len(filenames))
seg_file = segnames[image_id]
image_file = filenames[image_id]
joints = joints[image_id].copy()
joints = joints[self.smal_joint_info.annotated_classes]
visible = visible[image_id][self.smal_joint_info.annotated_classes]
rgb_img = imageio.imread(image_file) # , mode='RGB')
sil_img = imageio.imread(seg_file) # , mode='RGB')
rgb_h, rgb_w, _ = rgb_img.shape
sil_img = cv2.resize(sil_img, (rgb_w, rgb_h), cv2.INTER_NEAREST)
yield rgb_img, sil_img, joints, visible, image_file
def get_video(self, animal_id):
filenames, segnames, joint, visible = self.animal_dict[animal_id]
rgbs = []
segs = []
joints = []
visibles = []
for s in range(len(filenames)):
image_file = filenames[s]
rgb_img = imageio.imread(image_file) # , mode='RGB')
rgb_h, rgb_w, _ = rgb_img.shape
seg_file = segnames[s]
sil_img = imageio.imread(seg_file) # , mode='RGB')
sil_img = cv2.resize(sil_img, (rgb_w, rgb_h), cv2.INTER_NEAREST)
jo = joint[s]
if jo is not None:
joi = joint[s].copy()
joi = joi[self.smal_joint_info.annotated_classes]
vis = visible[s][self.smal_joint_info.annotated_classes]
else:
joi = None
vis = None
rgbs.append(rgb_img)
segs.append(sil_img)
joints.append(joi)
visibles.append(vis)
return rgbs, segs, joints, visibles, filenames[0]
class BadjaDataset(torch.utils.data.Dataset):
def __init__(
self, data_root, max_seq_len=1000, dataset_resolution=(384, 512)
):
self.data_root = data_root
self.badja_data = BADJAData(data_root)
self.max_seq_len = max_seq_len
self.dataset_resolution = dataset_resolution
print(
"found %d unique videos in %s"
% (self.badja_data.animal_count, self.data_root)
)
def __getitem__(self, index):
rgbs, segs, joints, visibles, filename = self.badja_data.get_video(index)
S = len(rgbs)
H, W, __ = rgbs[0].shape
H, W, __ = segs[0].shape
N, __ = joints[0].shape
# let's eliminate the Nones
# note the first one is guaranteed present
for s in range(1, S):
if joints[s] is None:
joints[s] = np.zeros_like(joints[0])
visibles[s] = np.zeros_like(visibles[0])
# eliminate the mystery dim
segs = [seg[:, :, 0] for seg in segs]
rgbs = np.stack(rgbs, 0)
segs = np.stack(segs, 0)
trajs = np.stack(joints, 0)
visibles = np.stack(visibles, 0)
rgbs = torch.from_numpy(rgbs).reshape(S, H, W, 3).permute(0, 3, 1, 2).float()
segs = torch.from_numpy(segs).reshape(S, 1, H, W).float()
trajs = torch.from_numpy(trajs).reshape(S, N, 2).float()
visibles = torch.from_numpy(visibles).reshape(S, N)
rgbs = rgbs[: self.max_seq_len]
segs = segs[: self.max_seq_len]
trajs = trajs[: self.max_seq_len]
visibles = visibles[: self.max_seq_len]
# apparently the coords are in yx order
trajs = torch.flip(trajs, [2])
if "extra_videos" in filename:
seq_name = filename.split("/")[-3]
else:
seq_name = filename.split("/")[-2]
rgbs, trajs, segs = resize_sample(rgbs, trajs, segs, self.dataset_resolution)
return CoTrackerData(rgbs, segs, trajs, visibles, seq_name=seq_name)
def __len__(self):
return self.badja_data.animal_count

View File

@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import torch
# from PIL import Image
import imageio
import numpy as np
from cotracker.datasets.utils import CoTrackerData, resize_sample
class FastCaptureDataset(torch.utils.data.Dataset):
def __init__(
self,
data_root,
max_seq_len=50,
max_num_points=20,
dataset_resolution=(384, 512),
):
self.data_root = data_root
self.seq_names = os.listdir(os.path.join(data_root, "renders_local_rm"))
self.pth_dir = os.path.join(data_root, "zju_tracking")
self.max_seq_len = max_seq_len
self.max_num_points = max_num_points
self.dataset_resolution = dataset_resolution
print("found %d unique videos in %s" % (len(self.seq_names), self.data_root))
def __getitem__(self, index):
seq_name = self.seq_names[index]
spath = os.path.join(self.data_root, "renders_local_rm", seq_name)
pthpath = os.path.join(self.pth_dir, seq_name + ".pth")
rgbs = []
img_paths = sorted(os.listdir(spath))
for i, img_path in enumerate(img_paths):
if i < self.max_seq_len:
rgbs.append(imageio.imread(os.path.join(spath, img_path)))
annot_dict = torch.load(pthpath)
traj_2d = annot_dict["traj_2d"][:, :, : self.max_seq_len]
visibility = annot_dict["visibility"][:, : self.max_seq_len]
S = len(rgbs)
H, W, __ = rgbs[0].shape
*_, S = traj_2d.shape
visibile_pts_first_frame_inds = (visibility[:, 0] > 0).nonzero(as_tuple=False)[
:, 0
]
torch.manual_seed(0)
point_inds = torch.randperm(len(visibile_pts_first_frame_inds))[
: self.max_num_points
]
visible_inds_sampled = visibile_pts_first_frame_inds[point_inds]
rgbs = np.stack(rgbs, 0)
rgbs = torch.from_numpy(rgbs).reshape(S, H, W, 3).permute(0, 3, 1, 2).float()
segs = torch.ones(S, 1, H, W).float()
trajs = traj_2d[visible_inds_sampled].permute(2, 0, 1).float()
visibles = visibility[visible_inds_sampled].permute(1, 0)
rgbs, trajs, segs = resize_sample(rgbs, trajs, segs, self.dataset_resolution)
return CoTrackerData(rgbs, segs, trajs, visibles, seq_name=seq_name)
def __len__(self):
return len(self.seq_names)

View File

@ -0,0 +1,494 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import torch
import imageio
import numpy as np
from cotracker.datasets.utils import CoTrackerData
from torchvision.transforms import ColorJitter, GaussianBlur
from PIL import Image
import cv2
class CoTrackerDataset(torch.utils.data.Dataset):
def __init__(
self,
data_root,
crop_size=(384, 512),
seq_len=24,
traj_per_sample=768,
sample_vis_1st_frame=False,
use_augs=False,
):
super(CoTrackerDataset, self).__init__()
np.random.seed(0)
torch.manual_seed(0)
self.data_root = data_root
self.seq_len = seq_len
self.traj_per_sample = traj_per_sample
self.sample_vis_1st_frame = sample_vis_1st_frame
self.use_augs = use_augs
self.crop_size = crop_size
# photometric augmentation
self.photo_aug = ColorJitter(
brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14
)
self.blur_aug = GaussianBlur(11, sigma=(0.1, 2.0))
self.blur_aug_prob = 0.25
self.color_aug_prob = 0.25
# occlusion augmentation
self.eraser_aug_prob = 0.5
self.eraser_bounds = [2, 100]
self.eraser_max = 10
# occlusion augmentation
self.replace_aug_prob = 0.5
self.replace_bounds = [2, 100]
self.replace_max = 10
# spatial augmentations
self.pad_bounds = [0, 100]
self.crop_size = crop_size
self.resize_lim = [0.25, 2.0] # sample resizes from here
self.resize_delta = 0.2
self.max_crop_offset = 50
self.do_flip = True
self.h_flip_prob = 0.5
self.v_flip_prob = 0.5
def getitem_helper(self, index):
return NotImplementedError
def __getitem__(self, index):
gotit = False
sample, gotit = self.getitem_helper(index)
if not gotit:
print("warning: sampling failed")
# fake sample, so we can still collate
sample = CoTrackerData(
video=torch.zeros(
(self.seq_len, 3, self.crop_size[0], self.crop_size[1])
),
segmentation=torch.zeros(
(self.seq_len, 1, self.crop_size[0], self.crop_size[1])
),
trajectory=torch.zeros((self.seq_len, self.traj_per_sample, 2)),
visibility=torch.zeros((self.seq_len, self.traj_per_sample)),
valid=torch.zeros((self.seq_len, self.traj_per_sample)),
)
return sample, gotit
def add_photometric_augs(self, rgbs, trajs, visibles, eraser=True, replace=True):
T, N, _ = trajs.shape
S = len(rgbs)
H, W = rgbs[0].shape[:2]
assert S == T
if eraser:
############ eraser transform (per image after the first) ############
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
for i in range(1, S):
if np.random.rand() < self.eraser_aug_prob:
for _ in range(
np.random.randint(1, self.eraser_max + 1)
): # number of times to occlude
xc = np.random.randint(0, W)
yc = np.random.randint(0, H)
dx = np.random.randint(
self.eraser_bounds[0], self.eraser_bounds[1]
)
dy = np.random.randint(
self.eraser_bounds[0], self.eraser_bounds[1]
)
x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
mean_color = np.mean(
rgbs[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0
)
rgbs[i][y0:y1, x0:x1, :] = mean_color
occ_inds = np.logical_and(
np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
)
visibles[i, occ_inds] = 0
rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
if replace:
rgbs_alt = [
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
for rgb in rgbs
]
rgbs_alt = [
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
for rgb in rgbs_alt
]
############ replace transform (per image after the first) ############
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
rgbs_alt = [rgb.astype(np.float32) for rgb in rgbs_alt]
for i in range(1, S):
if np.random.rand() < self.replace_aug_prob:
for _ in range(
np.random.randint(1, self.replace_max + 1)
): # number of times to occlude
xc = np.random.randint(0, W)
yc = np.random.randint(0, H)
dx = np.random.randint(
self.replace_bounds[0], self.replace_bounds[1]
)
dy = np.random.randint(
self.replace_bounds[0], self.replace_bounds[1]
)
x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
wid = x1 - x0
hei = y1 - y0
y00 = np.random.randint(0, H - hei)
x00 = np.random.randint(0, W - wid)
fr = np.random.randint(0, S)
rep = rgbs_alt[fr][y00 : y00 + hei, x00 : x00 + wid, :]
rgbs[i][y0:y1, x0:x1, :] = rep
occ_inds = np.logical_and(
np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
)
visibles[i, occ_inds] = 0
rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
############ photometric augmentation ############
if np.random.rand() < self.color_aug_prob:
# random per-frame amount of aug
rgbs = [
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
for rgb in rgbs
]
if np.random.rand() < self.blur_aug_prob:
# random per-frame amount of blur
rgbs = [
np.array(self.blur_aug(Image.fromarray(rgb)), dtype=np.uint8)
for rgb in rgbs
]
return rgbs, trajs, visibles
def add_spatial_augs(self, rgbs, trajs, visibles):
T, N, __ = trajs.shape
S = len(rgbs)
H, W = rgbs[0].shape[:2]
assert S == T
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
############ spatial transform ############
# padding
pad_x0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
pad_x1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
pad_y0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
rgbs = [
np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs
]
trajs[:, :, 0] += pad_x0
trajs[:, :, 1] += pad_y0
H, W = rgbs[0].shape[:2]
# scaling + stretching
scale = np.random.uniform(self.resize_lim[0], self.resize_lim[1])
scale_x = scale
scale_y = scale
H_new = H
W_new = W
scale_delta_x = 0.0
scale_delta_y = 0.0
rgbs_scaled = []
for s in range(S):
if s == 1:
scale_delta_x = np.random.uniform(-self.resize_delta, self.resize_delta)
scale_delta_y = np.random.uniform(-self.resize_delta, self.resize_delta)
elif s > 1:
scale_delta_x = (
scale_delta_x * 0.8
+ np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
)
scale_delta_y = (
scale_delta_y * 0.8
+ np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
)
scale_x = scale_x + scale_delta_x
scale_y = scale_y + scale_delta_y
# bring h/w closer
scale_xy = (scale_x + scale_y) * 0.5
scale_x = scale_x * 0.5 + scale_xy * 0.5
scale_y = scale_y * 0.5 + scale_xy * 0.5
# don't get too crazy
scale_x = np.clip(scale_x, 0.2, 2.0)
scale_y = np.clip(scale_y, 0.2, 2.0)
H_new = int(H * scale_y)
W_new = int(W * scale_x)
# make it at least slightly bigger than the crop area,
# so that the random cropping can add diversity
H_new = np.clip(H_new, self.crop_size[0] + 10, None)
W_new = np.clip(W_new, self.crop_size[1] + 10, None)
# recompute scale in case we clipped
scale_x = W_new / float(W)
scale_y = H_new / float(H)
rgbs_scaled.append(
cv2.resize(rgbs[s], (W_new, H_new), interpolation=cv2.INTER_LINEAR)
)
trajs[s, :, 0] *= scale_x
trajs[s, :, 1] *= scale_y
rgbs = rgbs_scaled
ok_inds = visibles[0, :] > 0
vis_trajs = trajs[:, ok_inds] # S,?,2
if vis_trajs.shape[1] > 0:
mid_x = np.mean(vis_trajs[0, :, 0])
mid_y = np.mean(vis_trajs[0, :, 1])
else:
mid_y = self.crop_size[0]
mid_x = self.crop_size[1]
x0 = int(mid_x - self.crop_size[1] // 2)
y0 = int(mid_y - self.crop_size[0] // 2)
offset_x = 0
offset_y = 0
for s in range(S):
# on each frame, shift a bit more
if s == 1:
offset_x = np.random.randint(
-self.max_crop_offset, self.max_crop_offset
)
offset_y = np.random.randint(
-self.max_crop_offset, self.max_crop_offset
)
elif s > 1:
offset_x = int(
offset_x * 0.8
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1)
* 0.2
)
offset_y = int(
offset_y * 0.8
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1)
* 0.2
)
x0 = x0 + offset_x
y0 = y0 + offset_y
H_new, W_new = rgbs[s].shape[:2]
if H_new == self.crop_size[0]:
y0 = 0
else:
y0 = min(max(0, y0), H_new - self.crop_size[0] - 1)
if W_new == self.crop_size[1]:
x0 = 0
else:
x0 = min(max(0, x0), W_new - self.crop_size[1] - 1)
rgbs[s] = rgbs[s][y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
trajs[s, :, 0] -= x0
trajs[s, :, 1] -= y0
H_new = self.crop_size[0]
W_new = self.crop_size[1]
# flip
h_flipped = False
v_flipped = False
if self.do_flip:
# h flip
if np.random.rand() < self.h_flip_prob:
h_flipped = True
rgbs = [rgb[:, ::-1] for rgb in rgbs]
# v flip
if np.random.rand() < self.v_flip_prob:
v_flipped = True
rgbs = [rgb[::-1] for rgb in rgbs]
if h_flipped:
trajs[:, :, 0] = W_new - trajs[:, :, 0]
if v_flipped:
trajs[:, :, 1] = H_new - trajs[:, :, 1]
return rgbs, trajs
def crop(self, rgbs, trajs):
T, N, _ = trajs.shape
S = len(rgbs)
H, W = rgbs[0].shape[:2]
assert S == T
############ spatial transform ############
H_new = H
W_new = W
# simple random crop
y0 = (
0
if self.crop_size[0] >= H_new
else np.random.randint(0, H_new - self.crop_size[0])
)
x0 = (
0
if self.crop_size[1] >= W_new
else np.random.randint(0, W_new - self.crop_size[1])
)
rgbs = [
rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
for rgb in rgbs
]
trajs[:, :, 0] -= x0
trajs[:, :, 1] -= y0
return rgbs, trajs
class KubricMovifDataset(CoTrackerDataset):
def __init__(
self,
data_root,
crop_size=(384, 512),
seq_len=24,
traj_per_sample=768,
sample_vis_1st_frame=False,
use_augs=False,
):
super(KubricMovifDataset, self).__init__(
data_root=data_root,
crop_size=crop_size,
seq_len=seq_len,
traj_per_sample=traj_per_sample,
sample_vis_1st_frame=sample_vis_1st_frame,
use_augs=use_augs,
)
self.pad_bounds = [0, 25]
self.resize_lim = [0.75, 1.25] # sample resizes from here
self.resize_delta = 0.05
self.max_crop_offset = 15
self.seq_names = [
fname
for fname in os.listdir(data_root)
if os.path.isdir(os.path.join(data_root, fname))
]
print("found %d unique videos in %s" % (len(self.seq_names), self.data_root))
def getitem_helper(self, index):
gotit = True
seq_name = self.seq_names[index]
npy_path = os.path.join(self.data_root, seq_name, seq_name + ".npy")
rgb_path = os.path.join(self.data_root, seq_name, "frames")
img_paths = sorted(os.listdir(rgb_path))
rgbs = []
for i, img_path in enumerate(img_paths):
rgbs.append(imageio.v2.imread(os.path.join(rgb_path, img_path)))
rgbs = np.stack(rgbs)
annot_dict = np.load(npy_path, allow_pickle=True).item()
traj_2d = annot_dict["coords"]
visibility = annot_dict["visibility"]
# random crop
assert self.seq_len <= len(rgbs)
if self.seq_len < len(rgbs):
start_ind = np.random.choice(len(rgbs) - self.seq_len, 1)[0]
rgbs = rgbs[start_ind : start_ind + self.seq_len]
traj_2d = traj_2d[:, start_ind : start_ind + self.seq_len]
visibility = visibility[:, start_ind : start_ind + self.seq_len]
traj_2d = np.transpose(traj_2d, (1, 0, 2))
visibility = np.transpose(np.logical_not(visibility), (1, 0))
if self.use_augs:
rgbs, traj_2d, visibility = self.add_photometric_augs(
rgbs, traj_2d, visibility
)
rgbs, traj_2d = self.add_spatial_augs(rgbs, traj_2d, visibility)
else:
rgbs, traj_2d = self.crop(rgbs, traj_2d)
visibility[traj_2d[:, :, 0] > self.crop_size[1] - 1] = False
visibility[traj_2d[:, :, 0] < 0] = False
visibility[traj_2d[:, :, 1] > self.crop_size[0] - 1] = False
visibility[traj_2d[:, :, 1] < 0] = False
visibility = torch.from_numpy(visibility)
traj_2d = torch.from_numpy(traj_2d)
visibile_pts_first_frame_inds = (visibility[0]).nonzero(as_tuple=False)[:, 0]
if self.sample_vis_1st_frame:
visibile_pts_inds = visibile_pts_first_frame_inds
else:
visibile_pts_mid_frame_inds = (visibility[self.seq_len // 2]).nonzero(
as_tuple=False
)[:, 0]
visibile_pts_inds = torch.cat(
(visibile_pts_first_frame_inds, visibile_pts_mid_frame_inds), dim=0
)
point_inds = torch.randperm(len(visibile_pts_inds))[: self.traj_per_sample]
if len(point_inds) < self.traj_per_sample:
gotit = False
visible_inds_sampled = visibile_pts_inds[point_inds]
trajs = traj_2d[:, visible_inds_sampled].float()
visibles = visibility[:, visible_inds_sampled]
valids = torch.ones((self.seq_len, self.traj_per_sample))
rgbs = torch.from_numpy(np.stack(rgbs)).permute(0, 3, 1, 2).float()
segs = torch.ones((self.seq_len, 1, self.crop_size[0], self.crop_size[1]))
sample = CoTrackerData(
video=rgbs,
segmentation=segs,
trajectory=trajs,
visibility=visibles,
valid=valids,
seq_name=seq_name,
)
return sample, gotit
def __len__(self):
return len(self.seq_names)

View File

@ -0,0 +1,218 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import io
import glob
import torch
import pickle
import numpy as np
import mediapy as media
from PIL import Image
from typing import Mapping, Tuple, Union
from cotracker.datasets.utils import CoTrackerData
DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]]
def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
"""Resize a video to output_size."""
# If you have a GPU, consider replacing this with a GPU-enabled resize op,
# such as a jitted jax.image.resize. It will make things faster.
return media.resize_video(video, output_size)
def sample_queries_first(
target_occluded: np.ndarray,
target_points: np.ndarray,
frames: np.ndarray,
) -> Mapping[str, np.ndarray]:
"""Package a set of frames and tracks for use in TAPNet evaluations.
Given a set of frames and tracks with no query points, use the first
visible point in each track as the query.
Args:
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
where True indicates occluded.
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
is [x,y] scaled between 0 and 1.
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
-1 and 1.
Returns:
A dict with the keys:
video: Video tensor of shape [1, n_frames, height, width, 3]
query_points: Query points of shape [1, n_queries, 3] where
each point is [t, y, x] scaled to the range [-1, 1]
target_points: Target points of shape [1, n_queries, n_frames, 2] where
each point is [x, y] scaled to the range [-1, 1]
"""
valid = np.sum(~target_occluded, axis=1) > 0
target_points = target_points[valid, :]
target_occluded = target_occluded[valid, :]
query_points = []
for i in range(target_points.shape[0]):
index = np.where(target_occluded[i] == 0)[0][0]
x, y = target_points[i, index, 0], target_points[i, index, 1]
query_points.append(np.array([index, y, x])) # [t, y, x]
query_points = np.stack(query_points, axis=0)
return {
"video": frames[np.newaxis, ...],
"query_points": query_points[np.newaxis, ...],
"target_points": target_points[np.newaxis, ...],
"occluded": target_occluded[np.newaxis, ...],
}
def sample_queries_strided(
target_occluded: np.ndarray,
target_points: np.ndarray,
frames: np.ndarray,
query_stride: int = 5,
) -> Mapping[str, np.ndarray]:
"""Package a set of frames and tracks for use in TAPNet evaluations.
Given a set of frames and tracks with no query points, sample queries
strided every query_stride frames, ignoring points that are not visible
at the selected frames.
Args:
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
where True indicates occluded.
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
is [x,y] scaled between 0 and 1.
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
-1 and 1.
query_stride: When sampling query points, search for un-occluded points
every query_stride frames and convert each one into a query.
Returns:
A dict with the keys:
video: Video tensor of shape [1, n_frames, height, width, 3]. The video
has floats scaled to the range [-1, 1].
query_points: Query points of shape [1, n_queries, 3] where
each point is [t, y, x] scaled to the range [-1, 1].
target_points: Target points of shape [1, n_queries, n_frames, 2] where
each point is [x, y] scaled to the range [-1, 1].
trackgroup: Index of the original track that each query point was
sampled from. This is useful for visualization.
"""
tracks = []
occs = []
queries = []
trackgroups = []
total = 0
trackgroup = np.arange(target_occluded.shape[0])
for i in range(0, target_occluded.shape[1], query_stride):
mask = target_occluded[:, i] == 0
query = np.stack(
[
i * np.ones(target_occluded.shape[0:1]),
target_points[:, i, 1],
target_points[:, i, 0],
],
axis=-1,
)
queries.append(query[mask])
tracks.append(target_points[mask])
occs.append(target_occluded[mask])
trackgroups.append(trackgroup[mask])
total += np.array(np.sum(target_occluded[:, i] == 0))
return {
"video": frames[np.newaxis, ...],
"query_points": np.concatenate(queries, axis=0)[np.newaxis, ...],
"target_points": np.concatenate(tracks, axis=0)[np.newaxis, ...],
"occluded": np.concatenate(occs, axis=0)[np.newaxis, ...],
"trackgroup": np.concatenate(trackgroups, axis=0)[np.newaxis, ...],
}
class TapVidDataset(torch.utils.data.Dataset):
def __init__(
self,
data_root,
dataset_type="davis",
resize_to_256=True,
queried_first=True,
):
self.dataset_type = dataset_type
self.resize_to_256 = resize_to_256
self.queried_first = queried_first
if self.dataset_type == "kinetics":
all_paths = glob.glob(os.path.join(data_root, "*_of_0010.pkl"))
points_dataset = []
for pickle_path in all_paths:
with open(pickle_path, "rb") as f:
data = pickle.load(f)
points_dataset = points_dataset + data
self.points_dataset = points_dataset
else:
with open(data_root, "rb") as f:
self.points_dataset = pickle.load(f)
if self.dataset_type == "davis":
self.video_names = list(self.points_dataset.keys())
print("found %d unique videos in %s" % (len(self.points_dataset), data_root))
def __getitem__(self, index):
if self.dataset_type == "davis":
video_name = self.video_names[index]
else:
video_name = index
video = self.points_dataset[video_name]
frames = video["video"]
if isinstance(frames[0], bytes):
# TAP-Vid is stored and JPEG bytes rather than `np.ndarray`s.
def decode(frame):
byteio = io.BytesIO(frame)
img = Image.open(byteio)
return np.array(img)
frames = np.array([decode(frame) for frame in frames])
target_points = self.points_dataset[video_name]["points"]
if self.resize_to_256:
frames = resize_video(frames, [256, 256])
target_points *= np.array([256, 256])
else:
target_points *= np.array([frames.shape[2], frames.shape[1]])
T, H, W, C = frames.shape
N, T, D = target_points.shape
target_occ = self.points_dataset[video_name]["occluded"]
if self.queried_first:
converted = sample_queries_first(target_occ, target_points, frames)
else:
converted = sample_queries_strided(target_occ, target_points, frames)
assert converted["target_points"].shape[1] == converted["query_points"].shape[1]
trajs = (
torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float()
) # T, N, D
rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float()
segs = torch.ones(T, 1, H, W).float()
visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[
0
].permute(
1, 0
) # T, N
query_points = torch.from_numpy(converted["query_points"])[0] # T, N
return CoTrackerData(
rgbs,
segs,
trajs,
visibles,
seq_name=str(video_name),
query_points=query_points,
)
def __len__(self):
return len(self.points_dataset)

114
cotracker/datasets/utils.py Normal file
View File

@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import dataclasses
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Any, Optional
@dataclass(eq=False)
class CoTrackerData:
"""
Dataclass for storing video tracks data.
"""
video: torch.Tensor # B, S, C, H, W
segmentation: torch.Tensor # B, S, 1, H, W
trajectory: torch.Tensor # B, S, N, 2
visibility: torch.Tensor # B, S, N
# optional data
valid: Optional[torch.Tensor] = None # B, S, N
seq_name: Optional[str] = None
query_points: Optional[torch.Tensor] = None # TapVID evaluation format
def collate_fn(batch):
"""
Collate function for video tracks data.
"""
video = torch.stack([b.video for b in batch], dim=0)
segmentation = torch.stack([b.segmentation for b in batch], dim=0)
trajectory = torch.stack([b.trajectory for b in batch], dim=0)
visibility = torch.stack([b.visibility for b in batch], dim=0)
query_points = None
if batch[0].query_points is not None:
query_points = torch.stack([b.query_points for b in batch], dim=0)
seq_name = [b.seq_name for b in batch]
return CoTrackerData(
video,
segmentation,
trajectory,
visibility,
seq_name=seq_name,
query_points=query_points,
)
def collate_fn_train(batch):
"""
Collate function for video tracks data during training.
"""
gotit = [gotit for _, gotit in batch]
video = torch.stack([b.video for b, _ in batch], dim=0)
segmentation = torch.stack([b.segmentation for b, _ in batch], dim=0)
trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0)
visibility = torch.stack([b.visibility for b, _ in batch], dim=0)
valid = torch.stack([b.valid for b, _ in batch], dim=0)
seq_name = [b.seq_name for b, _ in batch]
return (
CoTrackerData(video, segmentation, trajectory, visibility, valid, seq_name),
gotit,
)
def try_to_cuda(t: Any) -> Any:
"""
Try to move the input variable `t` to a cuda device.
Args:
t: Input.
Returns:
t_cuda: `t` moved to a cuda device, if supported.
"""
try:
t = t.float().cuda()
except AttributeError:
pass
return t
def dataclass_to_cuda_(obj):
"""
Move all contents of a dataclass to cuda inplace if supported.
Args:
batch: Input dataclass.
Returns:
batch_cuda: `batch` moved to a cuda device, if supported.
"""
for f in dataclasses.fields(obj):
setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
return obj
def resize_sample(rgbs, trajs_g, segs, interp_shape):
S, C, H, W = rgbs.shape
S, N, D = trajs_g.shape
assert D == 2
rgbs = F.interpolate(rgbs, interp_shape, mode="bilinear")
segs = F.interpolate(segs, interp_shape, mode="nearest")
trajs_g[:, :, 0] *= interp_shape[1] / W
trajs_g[:, :, 1] *= interp_shape[0] / H
return rgbs, trajs_g, segs

View File

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@ -0,0 +1,6 @@
defaults:
- default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: badja

View File

@ -0,0 +1,6 @@
defaults:
- default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: fastcapture

View File

@ -0,0 +1,6 @@
defaults:
- default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: tapvid_davis_first

View File

@ -0,0 +1,6 @@
defaults:
- default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: tapvid_davis_strided

View File

@ -0,0 +1,6 @@
defaults:
- default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: tapvid_kinetics_first

View File

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
from typing import Iterable, Mapping, Tuple, Union
def compute_tapvid_metrics(
query_points: np.ndarray,
gt_occluded: np.ndarray,
gt_tracks: np.ndarray,
pred_occluded: np.ndarray,
pred_tracks: np.ndarray,
query_mode: str,
) -> Mapping[str, np.ndarray]:
"""Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.)
See the TAP-Vid paper for details on the metric computation. All inputs are
given in raster coordinates. The first three arguments should be the direct
outputs of the reader: the 'query_points', 'occluded', and 'target_points'.
The paper metrics assume these are scaled relative to 256x256 images.
pred_occluded and pred_tracks are your algorithm's predictions.
This function takes a batch of inputs, and computes metrics separately for
each video. The metrics for the full benchmark are a simple mean of the
metrics across the full set of videos. These numbers are between 0 and 1,
but the paper multiplies them by 100 to ease reading.
Args:
query_points: The query points, an in the format [t, y, x]. Its size is
[b, n, 3], where b is the batch size and n is the number of queries
gt_occluded: A boolean array of shape [b, n, t], where t is the number
of frames. True indicates that the point is occluded.
gt_tracks: The target points, of shape [b, n, t, 2]. Each point is
in the format [x, y]
pred_occluded: A boolean array of predicted occlusions, in the same
format as gt_occluded.
pred_tracks: An array of track predictions from your algorithm, in the
same format as gt_tracks.
query_mode: Either 'first' or 'strided', depending on how queries are
sampled. If 'first', we assume the prior knowledge that all points
before the query point are occluded, and these are removed from the
evaluation.
Returns:
A dict with the following keys:
occlusion_accuracy: Accuracy at predicting occlusion.
pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points
predicted to be within the given pixel threshold, ignoring occlusion
prediction.
jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given
threshold
average_pts_within_thresh: average across pts_within_{x}
average_jaccard: average across jaccard_{x}
"""
metrics = {}
# Don't evaluate the query point. Numpy doesn't have one_hot, so we
# replicate it by indexing into an identity matrix.
one_hot_eye = np.eye(gt_tracks.shape[2])
query_frame = query_points[..., 0]
query_frame = np.round(query_frame).astype(np.int32)
evaluation_points = one_hot_eye[query_frame] == 0
# If we're using the first point on the track as a query, don't evaluate the
# other points.
if query_mode == "first":
for i in range(gt_occluded.shape[0]):
index = np.where(gt_occluded[i] == 0)[0][0]
evaluation_points[i, :index] = False
elif query_mode != "strided":
raise ValueError("Unknown query mode " + query_mode)
# Occlusion accuracy is simply how often the predicted occlusion equals the
# ground truth.
occ_acc = (
np.sum(
np.equal(pred_occluded, gt_occluded) & evaluation_points,
axis=(1, 2),
)
/ np.sum(evaluation_points)
)
metrics["occlusion_accuracy"] = occ_acc
# Next, convert the predictions and ground truth positions into pixel
# coordinates.
visible = np.logical_not(gt_occluded)
pred_visible = np.logical_not(pred_occluded)
all_frac_within = []
all_jaccard = []
for thresh in [1, 2, 4, 8, 16]:
# True positives are points that are within the threshold and where both
# the prediction and the ground truth are listed as visible.
within_dist = (
np.sum(
np.square(pred_tracks - gt_tracks),
axis=-1,
)
< np.square(thresh)
)
is_correct = np.logical_and(within_dist, visible)
# Compute the frac_within_threshold, which is the fraction of points
# within the threshold among points that are visible in the ground truth,
# ignoring whether they're predicted to be visible.
count_correct = np.sum(
is_correct & evaluation_points,
axis=(1, 2),
)
count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2))
frac_correct = count_correct / count_visible_points
metrics["pts_within_" + str(thresh)] = frac_correct
all_frac_within.append(frac_correct)
true_positives = np.sum(
is_correct & pred_visible & evaluation_points, axis=(1, 2)
)
# The denominator of the jaccard metric is the true positives plus
# false positives plus false negatives. However, note that true positives
# plus false negatives is simply the number of points in the ground truth
# which is easier to compute than trying to compute all three quantities.
# Thus we just add the number of points in the ground truth to the number
# of false positives.
#
# False positives are simply points that are predicted to be visible,
# but the ground truth is not visible or too far from the prediction.
gt_positives = np.sum(visible & evaluation_points, axis=(1, 2))
false_positives = (~visible) & pred_visible
false_positives = false_positives | ((~within_dist) & pred_visible)
false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2))
jaccard = true_positives / (gt_positives + false_positives)
metrics["jaccard_" + str(thresh)] = jaccard
all_jaccard.append(jaccard)
metrics["average_jaccard"] = np.mean(
np.stack(all_jaccard, axis=1),
axis=1,
)
metrics["average_pts_within_thresh"] = np.mean(
np.stack(all_frac_within, axis=1),
axis=1,
)
return metrics

View File

@ -0,0 +1,252 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import defaultdict
import os
from typing import Optional
import torch
from tqdm import tqdm
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from cotracker.datasets.utils import dataclass_to_cuda_
from cotracker.utils.visualizer import Visualizer
from cotracker.models.core.model_utils import reduce_masked_mean
from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics
import logging
class Evaluator:
"""
A class defining the CoTracker evaluator.
"""
def __init__(self, exp_dir) -> None:
# Visualization
self.exp_dir = exp_dir
os.makedirs(exp_dir, exist_ok=True)
self.visualization_filepaths = defaultdict(lambda: defaultdict(list))
self.visualize_dir = os.path.join(exp_dir, "visualisations")
def compute_metrics(self, metrics, sample, pred_trajectory, dataset_name):
if isinstance(pred_trajectory, tuple):
pred_trajectory, pred_visibility = pred_trajectory
else:
pred_visibility = None
if dataset_name == "badja":
sample.segmentation = (sample.segmentation > 0).float()
*_, N, _ = sample.trajectory.shape
accs = []
accs_3px = []
for s1 in range(1, sample.video.shape[1]): # target frame
for n in range(N):
vis = sample.visibility[0, s1, n]
if vis > 0:
coord_e = pred_trajectory[0, s1, n] # 2
coord_g = sample.trajectory[0, s1, n] # 2
dist = torch.sqrt(torch.sum((coord_e - coord_g) ** 2, dim=0))
area = torch.sum(sample.segmentation[0, s1])
# print_('0.2*sqrt(area)', 0.2*torch.sqrt(area))
thr = 0.2 * torch.sqrt(area)
# correct =
accs.append((dist < thr).float())
# print('thr',thr)
accs_3px.append((dist < 3.0).float())
res = torch.mean(torch.stack(accs)) * 100.0
res_3px = torch.mean(torch.stack(accs_3px)) * 100.0
metrics[sample.seq_name[0]] = res.item()
metrics[sample.seq_name[0] + "_accuracy"] = res_3px.item()
print(metrics)
print(
"avg", np.mean([v for k, v in metrics.items() if "accuracy" not in k])
)
print(
"avg acc 3px",
np.mean([v for k, v in metrics.items() if "accuracy" in k]),
)
elif dataset_name == "fastcapture" or ("kubric" in dataset_name):
*_, N, _ = sample.trajectory.shape
accs = []
for s1 in range(1, sample.video.shape[1]): # target frame
for n in range(N):
vis = sample.visibility[0, s1, n]
if vis > 0:
coord_e = pred_trajectory[0, s1, n] # 2
coord_g = sample.trajectory[0, s1, n] # 2
dist = torch.sqrt(torch.sum((coord_e - coord_g) ** 2, dim=0))
thr = 3
correct = (dist < thr).float()
accs.append(correct)
res = torch.mean(torch.stack(accs)) * 100.0
metrics[sample.seq_name[0] + "_accuracy"] = res.item()
print(metrics)
print("avg", np.mean([v for v in metrics.values()]))
elif "tapvid" in dataset_name:
B, T, N, D = sample.trajectory.shape
traj = sample.trajectory.clone()
thr = 0.9
if pred_visibility is None:
logging.warning("visibility is NONE")
pred_visibility = torch.zeros_like(sample.visibility)
if not pred_visibility.dtype == torch.bool:
pred_visibility = pred_visibility > thr
# pred_trajectory
query_points = sample.query_points.clone().cpu().numpy()
pred_visibility = pred_visibility[:, :, :N]
pred_trajectory = pred_trajectory[:, :, :N]
gt_tracks = traj.permute(0, 2, 1, 3).cpu().numpy()
gt_occluded = (
torch.logical_not(sample.visibility.clone().permute(0, 2, 1))
.cpu()
.numpy()
)
pred_occluded = (
torch.logical_not(pred_visibility.clone().permute(0, 2, 1))
.cpu()
.numpy()
)
pred_tracks = pred_trajectory.permute(0, 2, 1, 3).cpu().numpy()
out_metrics = compute_tapvid_metrics(
query_points,
gt_occluded,
gt_tracks,
pred_occluded,
pred_tracks,
query_mode="strided" if "strided" in dataset_name else "first",
)
metrics[sample.seq_name[0]] = out_metrics
for metric_name in out_metrics.keys():
if "avg" not in metrics:
metrics["avg"] = {}
metrics["avg"][metric_name] = np.mean(
[v[metric_name] for k, v in metrics.items() if k != "avg"]
)
logging.info(f"Metrics: {out_metrics}")
logging.info(f"avg: {metrics['avg']}")
print("metrics", out_metrics)
print("avg", metrics["avg"])
else:
rgbs = sample.video
trajs_g = sample.trajectory
valids = sample.valid
vis_g = sample.visibility
B, S, C, H, W = rgbs.shape
assert C == 3
B, S, N, D = trajs_g.shape
assert torch.sum(valids) == B * S * N
vis_g = (torch.sum(vis_g, dim=1, keepdim=True) >= 4).float().repeat(1, S, 1)
ate = torch.norm(pred_trajectory - trajs_g, dim=-1) # B, S, N
metrics["things_all"] = reduce_masked_mean(ate, valids).item()
metrics["things_vis"] = reduce_masked_mean(ate, valids * vis_g).item()
metrics["things_occ"] = reduce_masked_mean(
ate, valids * (1.0 - vis_g)
).item()
@torch.no_grad()
def evaluate_sequence(
self,
model,
test_dataloader: torch.utils.data.DataLoader,
dataset_name: str,
train_mode=False,
writer: Optional[SummaryWriter] = None,
step: Optional[int] = 0,
):
metrics = {}
vis = Visualizer(
save_dir=self.exp_dir,
fps=7,
)
for ind, sample in enumerate(tqdm(test_dataloader)):
if isinstance(sample, tuple):
sample, gotit = sample
if not all(gotit):
print("batch is None")
continue
dataclass_to_cuda_(sample)
if (
not train_mode
and hasattr(model, "sequence_len")
and (sample.visibility[:, : model.sequence_len].sum() == 0)
):
print(f"skipping batch {ind}")
continue
if "tapvid" in dataset_name:
queries = sample.query_points.clone().float()
queries = torch.stack(
[
queries[:, :, 0],
queries[:, :, 2],
queries[:, :, 1],
],
dim=2,
)
else:
queries = torch.cat(
[
torch.zeros_like(sample.trajectory[:, 0, :, :1]),
sample.trajectory[:, 0],
],
dim=2,
)
pred_tracks = model(sample.video, queries)
if "strided" in dataset_name:
inv_video = sample.video.flip(1).clone()
inv_queries = queries.clone()
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
pred_trj, pred_vsb = pred_tracks
inv_pred_trj, inv_pred_vsb = model(inv_video, inv_queries)
inv_pred_trj = inv_pred_trj.flip(1)
inv_pred_vsb = inv_pred_vsb.flip(1)
mask = pred_trj == 0
pred_trj[mask] = inv_pred_trj[mask]
pred_vsb[mask[:, :, :, 0]] = inv_pred_vsb[mask[:, :, :, 0]]
pred_tracks = pred_trj, pred_vsb
if dataset_name == "badja" or dataset_name == "fastcapture":
seq_name = sample.seq_name[0]
else:
seq_name = str(ind)
vis.visualize(
sample.video,
pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks,
filename=dataset_name + "_" + seq_name,
writer=writer,
step=step,
)
self.compute_metrics(metrics, sample, pred_tracks, dataset_name)
return metrics

View File

@ -0,0 +1,179 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
from dataclasses import dataclass, field
import hydra
import numpy as np
import torch
from omegaconf import OmegaConf
from cotracker.datasets.badja_dataset import BadjaDataset
from cotracker.datasets.fast_capture_dataset import FastCaptureDataset
from cotracker.datasets.tap_vid_datasets import TapVidDataset
from cotracker.datasets.utils import collate_fn
from cotracker.models.evaluation_predictor import EvaluationPredictor
from cotracker.evaluation.core.evaluator import Evaluator
from cotracker.models.build_cotracker import (
build_cotracker,
)
@dataclass(eq=False)
class DefaultConfig:
# Directory where all outputs of the experiment will be saved.
exp_dir: str = "./outputs"
# Name of the dataset to be used for the evaluation.
dataset_name: str = "badja"
# The root directory of the dataset.
dataset_root: str = "./"
# Path to the pre-trained model checkpoint to be used for the evaluation.
# The default value is the path to a specific CoTracker model checkpoint.
# Other available options are commented.
checkpoint: str = "./checkpoints/cotracker_stride_4_wind_8.pth"
# cotracker_stride_4_wind_12
# cotracker_stride_8_wind_16
# EvaluationPredictor parameters
# The size (N) of the support grid used in the predictor.
# The total number of points is (N*N).
grid_size: int = 6
# The size (N) of the local support grid.
local_grid_size: int = 6
# A flag indicating whether to evaluate one ground truth point at a time.
single_point: bool = True
# The number of iterative updates for each sliding window.
n_iters: int = 6
seed: int = 0
gpu_idx: int = 0
# Override hydra's working directory to current working dir,
# also disable storing the .hydra logs:
hydra: dict = field(
default_factory=lambda: {
"run": {"dir": "."},
"output_subdir": None,
}
)
def run_eval(cfg: DefaultConfig):
"""
The function evaluates CoTracker on a specified benchmark dataset based on a provided configuration.
Args:
cfg (DefaultConfig): An instance of DefaultConfig class which includes:
- exp_dir (str): The directory path for the experiment.
- dataset_name (str): The name of the dataset to be used.
- dataset_root (str): The root directory of the dataset.
- checkpoint (str): The path to the CoTracker model's checkpoint.
- single_point (bool): A flag indicating whether to evaluate one ground truth point at a time.
- n_iters (int): The number of iterative updates for each sliding window.
- seed (int): The seed for setting the random state for reproducibility.
- gpu_idx (int): The index of the GPU to be used.
"""
# Creating the experiment directory if it doesn't exist
os.makedirs(cfg.exp_dir, exist_ok=True)
# Saving the experiment configuration to a .yaml file in the experiment directory
cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml")
with open(cfg_file, "w") as f:
OmegaConf.save(config=cfg, f=f)
evaluator = Evaluator(cfg.exp_dir)
cotracker_model = build_cotracker(cfg.checkpoint)
# Creating the EvaluationPredictor object
predictor = EvaluationPredictor(
cotracker_model,
grid_size=cfg.grid_size,
local_grid_size=cfg.local_grid_size,
single_point=cfg.single_point,
n_iters=cfg.n_iters,
)
# Setting the random seeds
torch.manual_seed(cfg.seed)
np.random.seed(cfg.seed)
# Constructing the specified dataset
curr_collate_fn = collate_fn
if cfg.dataset_name == "badja":
test_dataset = BadjaDataset(data_root=os.path.join(cfg.dataset_root, "BADJA"))
elif cfg.dataset_name == "fastcapture":
test_dataset = FastCaptureDataset(
data_root=os.path.join(cfg.dataset_root, "fastcapture"),
max_seq_len=100,
max_num_points=20,
)
elif "tapvid" in cfg.dataset_name:
dataset_type = cfg.dataset_name.split("_")[1]
if dataset_type == "davis":
data_root = os.path.join(cfg.dataset_root, "/tapvid_davis/tapvid_davis.pkl")
elif dataset_type == "kinetics":
data_root = os.path.join(
cfg.dataset_root, "/kinetics/kinetics-dataset/k700-2020/tapvid_kinetics"
)
test_dataset = TapVidDataset(
dataset_type=dataset_type,
data_root=data_root,
queried_first=not "strided" in cfg.dataset_name,
)
# Creating the DataLoader object
test_dataloader = torch.utils.data.DataLoader(
test_dataset,
batch_size=1,
shuffle=False,
num_workers=14,
collate_fn=curr_collate_fn,
)
# Timing and conducting the evaluation
import time
start = time.time()
evaluate_result = evaluator.evaluate_sequence(
predictor,
test_dataloader,
dataset_name=cfg.dataset_name,
)
end = time.time()
print(end - start)
# Saving the evaluation results to a .json file
if not "tapvid" in cfg.dataset_name:
print("evaluate_result", evaluate_result)
else:
evaluate_result = evaluate_result["avg"]
result_file = os.path.join(cfg.exp_dir, f"result_eval_.json")
evaluate_result["time"] = end - start
print(f"Dumping eval results to {result_file}.")
with open(result_file, "w") as f:
json.dump(evaluate_result, f)
cs = hydra.core.config_store.ConfigStore.instance()
cs.store(name="default_config_eval", node=DefaultConfig)
@hydra.main(config_path="./configs/", config_name="default_config_eval")
def evaluate(cfg: DefaultConfig) -> None:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
run_eval(cfg)
if __name__ == "__main__":
evaluate()

View File

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from cotracker.models.core.cotracker.cotracker import CoTracker
def build_cotracker(
checkpoint: str,
):
model_name = checkpoint.split("/")[-1].split(".")[0]
if model_name == "cotracker_stride_4_wind_8":
return build_cotracker_stride_4_wind_8(checkpoint=checkpoint)
elif model_name == "cotracker_stride_4_wind_12":
return build_cotracker_stride_4_wind_12(checkpoint=checkpoint)
elif model_name == "cotracker_stride_8_wind_16":
return build_cotracker_stride_8_wind_16(checkpoint=checkpoint)
else:
raise ValueError(f"Unknown model name {model_name}")
# model used to produce the results in the paper
def build_cotracker_stride_4_wind_8(checkpoint=None):
return _build_cotracker(
stride=4,
sequence_len=8,
checkpoint=checkpoint,
)
def build_cotracker_stride_4_wind_12(checkpoint=None):
return _build_cotracker(
stride=4,
sequence_len=12,
checkpoint=checkpoint,
)
# the fastest model
def build_cotracker_stride_8_wind_16(checkpoint=None):
return _build_cotracker(
stride=8,
sequence_len=16,
checkpoint=checkpoint,
)
def _build_cotracker(
stride,
sequence_len,
checkpoint=None,
):
cotracker = CoTracker(
stride=stride,
S=sequence_len,
add_space_attn=True,
space_depth=6,
time_depth=6,
)
if checkpoint is not None:
with open(checkpoint, "rb") as f:
state_dict = torch.load(f, map_location="cpu")
if "model" in state_dict:
state_dict = state_dict["model"]
cotracker.load_state_dict(state_dict)
return cotracker

View File

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@ -0,0 +1,400 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.vision_transformer import Attention, Mlp
class ResidualBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(
in_planes,
planes,
kernel_size=3,
padding=1,
stride=stride,
padding_mode="zeros",
)
self.conv2 = nn.Conv2d(
planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(planes)
self.norm2 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm3 = nn.BatchNorm2d(planes)
elif norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(planes)
self.norm2 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm3 = nn.InstanceNorm2d(planes)
elif norm_fn == "none":
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
if not stride == 1:
self.norm3 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x + y)
class BasicEncoder(nn.Module):
def __init__(
self, input_dim=3, output_dim=128, stride=8, norm_fn="batch", dropout=0.0
):
super(BasicEncoder, self).__init__()
self.stride = stride
self.norm_fn = norm_fn
self.in_planes = 64
if self.norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes)
self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2)
elif self.norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(self.in_planes)
self.norm2 = nn.BatchNorm2d(output_dim * 2)
elif self.norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(self.in_planes)
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
elif self.norm_fn == "none":
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(
input_dim,
self.in_planes,
kernel_size=7,
stride=2,
padding=3,
padding_mode="zeros",
)
self.relu1 = nn.ReLU(inplace=True)
self.shallow = False
if self.shallow:
self.layer1 = self._make_layer(64, stride=1)
self.layer2 = self._make_layer(96, stride=2)
self.layer3 = self._make_layer(128, stride=2)
self.conv2 = nn.Conv2d(128 + 96 + 64, output_dim, kernel_size=1)
else:
self.layer1 = self._make_layer(64, stride=1)
self.layer2 = self._make_layer(96, stride=2)
self.layer3 = self._make_layer(128, stride=2)
self.layer4 = self._make_layer(128, stride=2)
self.conv2 = nn.Conv2d(
128 + 128 + 96 + 64,
output_dim * 2,
kernel_size=3,
padding=1,
padding_mode="zeros",
)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
_, _, H, W = x.shape
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
if self.shallow:
a = self.layer1(x)
b = self.layer2(a)
c = self.layer3(b)
a = F.interpolate(
a,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
b = F.interpolate(
b,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
c = F.interpolate(
c,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
x = self.conv2(torch.cat([a, b, c], dim=1))
else:
a = self.layer1(x)
b = self.layer2(a)
c = self.layer3(b)
d = self.layer4(c)
a = F.interpolate(
a,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
b = F.interpolate(
b,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
c = F.interpolate(
c,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
d = F.interpolate(
d,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
x = self.conv2(torch.cat([a, b, c, d], dim=1))
x = self.norm2(x)
x = self.relu2(x)
x = self.conv3(x)
if self.training and self.dropout is not None:
x = self.dropout(x)
return x
class AttnBlock(nn.Module):
"""
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = Attention(
hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs
)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(
in_features=hidden_size,
hidden_features=mlp_hidden_dim,
act_layer=approx_gelu,
drop=0,
)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
def bilinear_sampler(img, coords, mode="bilinear", mask=False):
"""Wrapper for grid_sample, uses pixel coordinates"""
H, W = img.shape[-2:]
xgrid, ygrid = coords.split([1, 1], dim=-1)
# go to 0,1 then 0,2 then -1,1
xgrid = 2 * xgrid / (W - 1) - 1
ygrid = 2 * ygrid / (H - 1) - 1
grid = torch.cat([xgrid, ygrid], dim=-1)
img = F.grid_sample(img, grid, align_corners=True)
if mask:
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
return img, mask.float()
return img
class CorrBlock:
def __init__(self, fmaps, num_levels=4, radius=4):
B, S, C, H, W = fmaps.shape
self.S, self.C, self.H, self.W = S, C, H, W
self.num_levels = num_levels
self.radius = radius
self.fmaps_pyramid = []
self.fmaps_pyramid.append(fmaps)
for i in range(self.num_levels - 1):
fmaps_ = fmaps.reshape(B * S, C, H, W)
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
_, _, H, W = fmaps_.shape
fmaps = fmaps_.reshape(B, S, C, H, W)
self.fmaps_pyramid.append(fmaps)
def sample(self, coords):
r = self.radius
B, S, N, D = coords.shape
assert D == 2
H, W = self.H, self.W
out_pyramid = []
for i in range(self.num_levels):
corrs = self.corrs_pyramid[i] # B, S, N, H, W
_, _, _, H, W = corrs.shape
dx = torch.linspace(-r, r, 2 * r + 1)
dy = torch.linspace(-r, r, 2 * r + 1)
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
coords.device
)
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
coords_lvl = centroid_lvl + delta_lvl
corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl)
corrs = corrs.view(B, S, N, -1)
out_pyramid.append(corrs)
out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
return out.contiguous().float()
def corr(self, targets):
B, S, N, C = targets.shape
assert C == self.C
assert S == self.S
fmap1 = targets
self.corrs_pyramid = []
for fmaps in self.fmaps_pyramid:
_, _, _, H, W = fmaps.shape
fmap2s = fmaps.view(B, S, C, H * W)
corrs = torch.matmul(fmap1, fmap2s)
corrs = corrs.view(B, S, N, H, W)
corrs = corrs / torch.sqrt(torch.tensor(C).float())
self.corrs_pyramid.append(corrs)
class UpdateFormer(nn.Module):
"""
Transformer model that updates track estimates.
"""
def __init__(
self,
space_depth=12,
time_depth=12,
input_dim=320,
hidden_size=384,
num_heads=8,
output_dim=130,
mlp_ratio=4.0,
add_space_attn=True,
):
super().__init__()
self.out_channels = 2
self.num_heads = num_heads
self.hidden_size = hidden_size
self.add_space_attn = add_space_attn
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
self.time_blocks = nn.ModuleList(
[
AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
for _ in range(time_depth)
]
)
if add_space_attn:
self.space_blocks = nn.ModuleList(
[
AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
for _ in range(space_depth)
]
)
assert len(self.time_blocks) >= len(self.space_blocks)
self.initialize_weights()
def initialize_weights(self):
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
def forward(self, input_tensor):
x = self.input_transform(input_tensor)
j = 0
for i in range(len(self.time_blocks)):
B, N, T, _ = x.shape
x_time = rearrange(x, "b n t c -> (b n) t c", b=B, t=T, n=N)
x_time = self.time_blocks[i](x_time)
x = rearrange(x_time, "(b n) t c -> b n t c ", b=B, t=T, n=N)
if self.add_space_attn and (
i % (len(self.time_blocks) // len(self.space_blocks)) == 0
):
x_space = rearrange(x, "b n t c -> (b t) n c ", b=B, t=T, n=N)
x_space = self.space_blocks[j](x_space)
x = rearrange(x_space, "(b t) n c -> b n t c ", b=B, t=T, n=N)
j += 1
flow = self.flow_head(x)
return flow

View File

@ -0,0 +1,351 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from einops import rearrange
from cotracker.models.core.cotracker.blocks import (
BasicEncoder,
CorrBlock,
UpdateFormer,
)
from cotracker.models.core.model_utils import meshgrid2d, bilinear_sample2d, smart_cat
from cotracker.models.core.embeddings import (
get_2d_embedding,
get_1d_sincos_pos_embed_from_grid,
get_2d_sincos_pos_embed,
)
torch.manual_seed(0)
def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)):
if grid_size == 1:
return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2])[
None, None
].cuda()
grid_y, grid_x = meshgrid2d(
1, grid_size, grid_size, stack=False, norm=False, device="cuda"
)
step = interp_shape[1] // 64
if grid_center[0] != 0 or grid_center[1] != 0:
grid_y = grid_y - grid_size / 2.0
grid_x = grid_x - grid_size / 2.0
grid_y = step + grid_y.reshape(1, -1) / float(grid_size - 1) * (
interp_shape[0] - step * 2
)
grid_x = step + grid_x.reshape(1, -1) / float(grid_size - 1) * (
interp_shape[1] - step * 2
)
grid_y = grid_y + grid_center[0]
grid_x = grid_x + grid_center[1]
xy = torch.stack([grid_x, grid_y], dim=-1).cuda()
return xy
def sample_pos_embed(grid_size, embed_dim, coords):
pos_embed = get_2d_sincos_pos_embed(embed_dim=embed_dim, grid_size=grid_size)
pos_embed = (
torch.from_numpy(pos_embed)
.reshape(grid_size[0], grid_size[1], embed_dim)
.float()
.unsqueeze(0)
.to(coords.device)
)
sampled_pos_embed = bilinear_sample2d(
pos_embed.permute(0, 3, 1, 2), coords[:, 0, :, 0], coords[:, 0, :, 1]
)
return sampled_pos_embed
class CoTracker(nn.Module):
def __init__(
self,
S=8,
stride=8,
add_space_attn=True,
num_heads=8,
hidden_size=384,
space_depth=12,
time_depth=12,
):
super(CoTracker, self).__init__()
self.S = S
self.stride = stride
self.hidden_dim = 256
self.latent_dim = latent_dim = 128
self.corr_levels = 4
self.corr_radius = 3
self.add_space_attn = add_space_attn
self.fnet = BasicEncoder(
output_dim=self.latent_dim, norm_fn="instance", dropout=0, stride=stride
)
self.updateformer = UpdateFormer(
space_depth=space_depth,
time_depth=time_depth,
input_dim=456,
hidden_size=hidden_size,
num_heads=num_heads,
output_dim=latent_dim + 2,
mlp_ratio=4.0,
add_space_attn=add_space_attn,
)
self.norm = nn.GroupNorm(1, self.latent_dim)
self.ffeat_updater = nn.Sequential(
nn.Linear(self.latent_dim, self.latent_dim),
nn.GELU(),
)
self.vis_predictor = nn.Sequential(
nn.Linear(self.latent_dim, 1),
)
def forward_iteration(
self,
fmaps,
coords_init,
feat_init=None,
vis_init=None,
track_mask=None,
iters=4,
):
B, S_init, N, D = coords_init.shape
assert D == 2
assert B == 1
B, S, __, H8, W8 = fmaps.shape
device = fmaps.device
if S_init < S:
coords = torch.cat(
[coords_init, coords_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1
)
vis_init = torch.cat(
[vis_init, vis_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1
)
else:
coords = coords_init.clone()
fcorr_fn = CorrBlock(
fmaps, num_levels=self.corr_levels, radius=self.corr_radius
)
ffeats = feat_init.clone()
times_ = torch.linspace(0, S - 1, S).reshape(1, S, 1)
pos_embed = sample_pos_embed(
grid_size=(H8, W8),
embed_dim=456,
coords=coords,
)
pos_embed = rearrange(pos_embed, "b e n -> (b n) e").unsqueeze(1)
times_embed = (
torch.from_numpy(get_1d_sincos_pos_embed_from_grid(456, times_[0]))[None]
.repeat(B, 1, 1)
.float()
.to(device)
)
coord_predictions = []
for __ in range(iters):
coords = coords.detach()
fcorr_fn.corr(ffeats)
fcorrs = fcorr_fn.sample(coords) # B, S, N, LRR
LRR = fcorrs.shape[3]
fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, LRR)
flows_ = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
flows_cat = get_2d_embedding(flows_, 64, cat_coords=True)
ffeats_ = ffeats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
if track_mask.shape[1] < vis_init.shape[1]:
track_mask = torch.cat(
[
track_mask,
torch.zeros_like(track_mask[:, 0]).repeat(
1, vis_init.shape[1] - track_mask.shape[1], 1, 1
),
],
dim=1,
)
concat = (
torch.cat([track_mask, vis_init], dim=2)
.permute(0, 2, 1, 3)
.reshape(B * N, S, 2)
)
transformer_input = torch.cat([flows_cat, fcorrs_, ffeats_, concat], dim=2)
x = transformer_input + pos_embed + times_embed
x = rearrange(x, "(b n) t d -> b n t d", b=B)
delta = self.updateformer(x)
delta = rearrange(delta, " b n t d -> (b n) t d")
delta_coords_ = delta[:, :, :2]
delta_feats_ = delta[:, :, 2:]
delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
ffeats_ = ffeats.permute(0, 2, 1, 3).reshape(B * N * S, self.latent_dim)
ffeats_ = self.ffeat_updater(self.norm(delta_feats_)) + ffeats_
ffeats = ffeats_.reshape(B, N, S, self.latent_dim).permute(
0, 2, 1, 3
) # B,S,N,C
coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
coord_predictions.append(coords * self.stride)
vis_e = self.vis_predictor(ffeats.reshape(B * S * N, self.latent_dim)).reshape(
B, S, N
)
return coord_predictions, vis_e, feat_init
def forward(self, rgbs, queries, iters=4, feat_init=None, is_train=False):
B, T, C, H, W = rgbs.shape
B, N, __ = queries.shape
device = rgbs.device
assert B == 1
# INIT for the first sequence
# We want to sort points by the first frame they are visible to add them to the tensor of tracked points consequtively
first_positive_inds = queries[:, :, 0].long()
__, sort_inds = torch.sort(first_positive_inds[0], dim=0, descending=False)
inv_sort_inds = torch.argsort(sort_inds, dim=0)
first_positive_sorted_inds = first_positive_inds[0][sort_inds]
assert torch.allclose(
first_positive_inds[0], first_positive_inds[0][sort_inds][inv_sort_inds]
)
coords_init = queries[:, :, 1:].reshape(B, 1, N, 2).repeat(
1, self.S, 1, 1
) / float(self.stride)
rgbs = 2 * (rgbs / 255.0) - 1.0
traj_e = torch.zeros((B, T, N, 2), device=device)
vis_e = torch.zeros((B, T, N), device=device)
ind_array = torch.arange(T, device=device)
ind_array = ind_array[None, :, None].repeat(B, 1, N)
track_mask = (ind_array >= first_positive_inds[:, None, :]).unsqueeze(-1)
# these are logits, so we initialize visibility with something that would give a value close to 1 after softmax
vis_init = torch.ones((B, self.S, N, 1), device=device).float() * 10
ind = 0
track_mask_ = track_mask[:, :, sort_inds].clone()
coords_init_ = coords_init[:, :, sort_inds].clone()
vis_init_ = vis_init[:, :, sort_inds].clone()
prev_wind_idx = 0
fmaps_ = None
vis_predictions = []
coord_predictions = []
wind_inds = []
while ind < T - self.S // 2:
rgbs_seq = rgbs[:, ind : ind + self.S]
S = S_local = rgbs_seq.shape[1]
if S < self.S:
rgbs_seq = torch.cat(
[rgbs_seq, rgbs_seq[:, -1, None].repeat(1, self.S - S, 1, 1, 1)],
dim=1,
)
S = rgbs_seq.shape[1]
rgbs_ = rgbs_seq.reshape(B * S, C, H, W)
if fmaps_ is None:
fmaps_ = self.fnet(rgbs_)
else:
fmaps_ = torch.cat(
[fmaps_[self.S // 2 :], self.fnet(rgbs_[self.S // 2 :])], dim=0
)
fmaps = fmaps_.reshape(
B, S, self.latent_dim, H // self.stride, W // self.stride
)
curr_wind_points = torch.nonzero(first_positive_sorted_inds < ind + self.S)
if curr_wind_points.shape[0] == 0:
ind = ind + self.S // 2
continue
wind_idx = curr_wind_points[-1] + 1
if wind_idx - prev_wind_idx > 0:
fmaps_sample = fmaps[
:, first_positive_sorted_inds[prev_wind_idx:wind_idx] - ind
]
feat_init_ = bilinear_sample2d(
fmaps_sample,
coords_init_[:, 0, prev_wind_idx:wind_idx, 0],
coords_init_[:, 0, prev_wind_idx:wind_idx, 1],
).permute(0, 2, 1)
feat_init_ = feat_init_.unsqueeze(1).repeat(1, self.S, 1, 1)
feat_init = smart_cat(feat_init, feat_init_, dim=2)
if prev_wind_idx > 0:
new_coords = coords[-1][:, self.S // 2 :] / float(self.stride)
coords_init_[:, : self.S // 2, :prev_wind_idx] = new_coords
coords_init_[:, self.S // 2 :, :prev_wind_idx] = new_coords[
:, -1
].repeat(1, self.S // 2, 1, 1)
new_vis = vis[:, self.S // 2 :].unsqueeze(-1)
vis_init_[:, : self.S // 2, :prev_wind_idx] = new_vis
vis_init_[:, self.S // 2 :, :prev_wind_idx] = new_vis[:, -1].repeat(
1, self.S // 2, 1, 1
)
coords, vis, __ = self.forward_iteration(
fmaps=fmaps,
coords_init=coords_init_[:, :, :wind_idx],
feat_init=feat_init[:, :, :wind_idx],
vis_init=vis_init_[:, :, :wind_idx],
track_mask=track_mask_[:, ind : ind + self.S, :wind_idx],
iters=iters,
)
if is_train:
vis_predictions.append(torch.sigmoid(vis[:, :S_local]))
coord_predictions.append([coord[:, :S_local] for coord in coords])
wind_inds.append(wind_idx)
traj_e[:, ind : ind + self.S, :wind_idx] = coords[-1][:, :S_local]
vis_e[:, ind : ind + self.S, :wind_idx] = vis[:, :S_local]
track_mask_[:, : ind + self.S, :wind_idx] = 0.0
ind = ind + self.S // 2
prev_wind_idx = wind_idx
traj_e = traj_e[:, :, inv_sort_inds]
vis_e = vis_e[:, :, inv_sort_inds]
vis_e = torch.sigmoid(vis_e)
train_data = (
(vis_predictions, coord_predictions, wind_inds, sort_inds)
if is_train
else None
)
return traj_e, feat_init, vis_e, train_data

View File

@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from cotracker.models.core.model_utils import reduce_masked_mean
EPS = 1e-6
def balanced_ce_loss(pred, gt, valid=None):
total_balanced_loss = 0.0
for j in range(len(gt)):
B, S, N = gt[j].shape
# pred and gt are the same shape
for (a, b) in zip(pred[j].size(), gt[j].size()):
assert a == b # some shape mismatch!
# if valid is not None:
for (a, b) in zip(pred[j].size(), valid[j].size()):
assert a == b # some shape mismatch!
pos = (gt[j] > 0.95).float()
neg = (gt[j] < 0.05).float()
label = pos * 2.0 - 1.0
a = -label * pred[j]
b = F.relu(a)
loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))
pos_loss = reduce_masked_mean(loss, pos * valid[j])
neg_loss = reduce_masked_mean(loss, neg * valid[j])
balanced_loss = pos_loss + neg_loss
total_balanced_loss += balanced_loss / float(N)
return total_balanced_loss
def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8):
"""Loss function defined over sequence of flow predictions"""
total_flow_loss = 0.0
for j in range(len(flow_gt)):
B, S, N, D = flow_gt[j].shape
assert D == 2
B, S1, N = vis[j].shape
B, S2, N = valids[j].shape
assert S == S1
assert S == S2
n_predictions = len(flow_preds[j])
flow_loss = 0.0
for i in range(n_predictions):
i_weight = gamma ** (n_predictions - i - 1)
flow_pred = flow_preds[j][i]
i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2
i_loss = torch.mean(i_loss, dim=3) # B, S, N
flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j])
flow_loss = flow_loss / n_predictions
total_flow_loss += flow_loss / float(N)
return total_flow_loss

View File

@ -0,0 +1,154 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import numpy as np
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
if isinstance(grid_size, tuple):
grid_size_h, grid_size_w = grid_size
else:
grid_size_h = grid_size_w = grid_size
grid_h = np.arange(grid_size_h, dtype=np.float32)
grid_w = np.arange(grid_size_w, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate(
[np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000 ** omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
def get_2d_embedding(xy, C, cat_coords=True):
B, N, D = xy.shape
assert D == 2
x = xy[:, :, 0:1]
y = xy[:, :, 1:2]
div_term = (
torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
).reshape(1, 1, int(C / 2))
pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
pe_x[:, :, 0::2] = torch.sin(x * div_term)
pe_x[:, :, 1::2] = torch.cos(x * div_term)
pe_y[:, :, 0::2] = torch.sin(y * div_term)
pe_y[:, :, 1::2] = torch.cos(y * div_term)
pe = torch.cat([pe_x, pe_y], dim=2) # B, N, C*3
if cat_coords:
pe = torch.cat([xy, pe], dim=2) # B, N, C*3+3
return pe
def get_3d_embedding(xyz, C, cat_coords=True):
B, N, D = xyz.shape
assert D == 3
x = xyz[:, :, 0:1]
y = xyz[:, :, 1:2]
z = xyz[:, :, 2:3]
div_term = (
torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (1000.0 / C)
).reshape(1, 1, int(C / 2))
pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
pe_x[:, :, 0::2] = torch.sin(x * div_term)
pe_x[:, :, 1::2] = torch.cos(x * div_term)
pe_y[:, :, 0::2] = torch.sin(y * div_term)
pe_y[:, :, 1::2] = torch.cos(y * div_term)
pe_z[:, :, 0::2] = torch.sin(z * div_term)
pe_z[:, :, 1::2] = torch.cos(z * div_term)
pe = torch.cat([pe_x, pe_y, pe_z], dim=2) # B, N, C*3
if cat_coords:
pe = torch.cat([pe, xyz], dim=2) # B, N, C*3+3
return pe
def get_4d_embedding(xyzw, C, cat_coords=True):
B, N, D = xyzw.shape
assert D == 4
x = xyzw[:, :, 0:1]
y = xyzw[:, :, 1:2]
z = xyzw[:, :, 2:3]
w = xyzw[:, :, 3:4]
div_term = (
torch.arange(0, C, 2, device=xyzw.device, dtype=torch.float32) * (1000.0 / C)
).reshape(1, 1, int(C / 2))
pe_x = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
pe_y = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
pe_z = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
pe_w = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
pe_x[:, :, 0::2] = torch.sin(x * div_term)
pe_x[:, :, 1::2] = torch.cos(x * div_term)
pe_y[:, :, 0::2] = torch.sin(y * div_term)
pe_y[:, :, 1::2] = torch.cos(y * div_term)
pe_z[:, :, 0::2] = torch.sin(z * div_term)
pe_z[:, :, 1::2] = torch.cos(z * div_term)
pe_w[:, :, 0::2] = torch.sin(w * div_term)
pe_w[:, :, 1::2] = torch.cos(w * div_term)
pe = torch.cat([pe_x, pe_y, pe_z, pe_w], dim=2) # B, N, C*3
if cat_coords:
pe = torch.cat([pe, xyzw], dim=2) # B, N, C*3+3
return pe

View File

@ -0,0 +1,169 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
EPS = 1e-6
def smart_cat(tensor1, tensor2, dim):
if tensor1 is None:
return tensor2
return torch.cat([tensor1, tensor2], dim=dim)
def normalize_single(d):
# d is a whatever shape torch tensor
dmin = torch.min(d)
dmax = torch.max(d)
d = (d - dmin) / (EPS + (dmax - dmin))
return d
def normalize(d):
# d is B x whatever. normalize within each element of the batch
out = torch.zeros(d.size())
if d.is_cuda:
out = out.cuda()
B = list(d.size())[0]
for b in list(range(B)):
out[b] = normalize_single(d[b])
return out
def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"):
# returns a meshgrid sized B x Y x X
grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device))
grid_y = torch.reshape(grid_y, [1, Y, 1])
grid_y = grid_y.repeat(B, 1, X)
grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device))
grid_x = torch.reshape(grid_x, [1, 1, X])
grid_x = grid_x.repeat(B, Y, 1)
if stack:
# note we stack in xy order
# (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
grid = torch.stack([grid_x, grid_y], dim=-1)
return grid
else:
return grid_y, grid_x
def reduce_masked_mean(x, mask, dim=None, keepdim=False):
# x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting
# returns shape-1
# axis can be a list of axes
for (a, b) in zip(x.size(), mask.size()):
assert a == b # some shape mismatch!
prod = x * mask
if dim is None:
numer = torch.sum(prod)
denom = EPS + torch.sum(mask)
else:
numer = torch.sum(prod, dim=dim, keepdim=keepdim)
denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim)
mean = numer / denom
return mean
def bilinear_sample2d(im, x, y, return_inbounds=False):
# x and y are each B, N
# output is B, C, N
if len(im.shape) == 5:
B, N, C, H, W = list(im.shape)
else:
B, C, H, W = list(im.shape)
N = list(x.shape)[1]
x = x.float()
y = y.float()
H_f = torch.tensor(H, dtype=torch.float32)
W_f = torch.tensor(W, dtype=torch.float32)
# inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()
max_y = (H_f - 1).int()
max_x = (W_f - 1).int()
x0 = torch.floor(x).int()
x1 = x0 + 1
y0 = torch.floor(y).int()
y1 = y0 + 1
x0_clip = torch.clamp(x0, 0, max_x)
x1_clip = torch.clamp(x1, 0, max_x)
y0_clip = torch.clamp(y0, 0, max_y)
y1_clip = torch.clamp(y1, 0, max_y)
dim2 = W
dim1 = W * H
base = torch.arange(0, B, dtype=torch.int64, device=x.device) * dim1
base = torch.reshape(base, [B, 1]).repeat([1, N])
base_y0 = base + y0_clip * dim2
base_y1 = base + y1_clip * dim2
idx_y0_x0 = base_y0 + x0_clip
idx_y0_x1 = base_y0 + x1_clip
idx_y1_x0 = base_y1 + x0_clip
idx_y1_x1 = base_y1 + x1_clip
# use the indices to lookup pixels in the flat image
# im is B x C x H x W
# move C out to last dim
if len(im.shape) == 5:
im_flat = (im.permute(0, 3, 4, 1, 2)).reshape(B * H * W, N, C)
i_y0_x0 = torch.diagonal(im_flat[idx_y0_x0.long()], dim1=1, dim2=2).permute(
0, 2, 1
)
i_y0_x1 = torch.diagonal(im_flat[idx_y0_x1.long()], dim1=1, dim2=2).permute(
0, 2, 1
)
i_y1_x0 = torch.diagonal(im_flat[idx_y1_x0.long()], dim1=1, dim2=2).permute(
0, 2, 1
)
i_y1_x1 = torch.diagonal(im_flat[idx_y1_x1.long()], dim1=1, dim2=2).permute(
0, 2, 1
)
else:
im_flat = (im.permute(0, 2, 3, 1)).reshape(B * H * W, C)
i_y0_x0 = im_flat[idx_y0_x0.long()]
i_y0_x1 = im_flat[idx_y0_x1.long()]
i_y1_x0 = im_flat[idx_y1_x0.long()]
i_y1_x1 = im_flat[idx_y1_x1.long()]
# Finally calculate interpolated values.
x0_f = x0.float()
x1_f = x1.float()
y0_f = y0.float()
y1_f = y1.float()
w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)
w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)
w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)
w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)
output = (
w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1
)
# output is B*N x C
output = output.view(B, -1, C)
output = output.permute(0, 2, 1)
# output is B x C x N
if return_inbounds:
x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
inbounds = (x_valid & y_valid).float()
inbounds = inbounds.reshape(
B, N
) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
return output, inbounds
return output # B, C, N

View File

@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from typing import Tuple
from cotracker.models.core.cotracker.cotracker import CoTracker, get_points_on_a_grid
class EvaluationPredictor(torch.nn.Module):
def __init__(
self,
cotracker_model: CoTracker,
interp_shape: Tuple[int, int] = (384, 512),
grid_size: int = 6,
local_grid_size: int = 6,
single_point: bool = True,
n_iters: int = 6,
) -> None:
super(EvaluationPredictor, self).__init__()
self.grid_size = grid_size
self.local_grid_size = local_grid_size
self.single_point = single_point
self.interp_shape = interp_shape
self.n_iters = n_iters
self.model = cotracker_model
self.model.to("cuda")
self.model.eval()
def forward(self, video, queries):
queries = queries.clone().cuda()
B, T, C, H, W = video.shape
B, N, D = queries.shape
assert D == 3
assert B == 1
rgbs = video.reshape(B * T, C, H, W)
rgbs = F.interpolate(rgbs, tuple(self.interp_shape), mode="bilinear")
rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]).cuda()
queries[:, :, 1] *= self.interp_shape[1] / W
queries[:, :, 2] *= self.interp_shape[0] / H
if self.single_point:
traj_e = torch.zeros((B, T, N, 2)).cuda()
vis_e = torch.zeros((B, T, N)).cuda()
for pind in range((N)):
query = queries[:, pind : pind + 1]
t = query[0, 0, 0].long()
traj_e_pind, vis_e_pind = self._process_one_point(rgbs, query)
traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1]
vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
else:
if self.grid_size > 0:
xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda() #
queries = torch.cat([queries, xy], dim=1) #
traj_e, __, vis_e, __ = self.model(
rgbs=rgbs,
queries=queries,
iters=self.n_iters,
)
traj_e[:, :, :, 0] *= W / float(self.interp_shape[1])
traj_e[:, :, :, 1] *= H / float(self.interp_shape[0])
return traj_e, vis_e
def _process_one_point(self, rgbs, query):
t = query[0, 0, 0].long()
device = rgbs.device
if self.local_grid_size > 0:
xy_target = get_points_on_a_grid(
self.local_grid_size,
(50, 50),
[query[0, 0, 2], query[0, 0, 1]],
)
xy_target = torch.cat(
[torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2
) #
query = torch.cat([query, xy_target], dim=1).to(device) #
if self.grid_size > 0:
xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda() #
query = torch.cat([query, xy], dim=1).to(device) #
# crop the video to start from the queried frame
query[0, 0, 0] = 0
traj_e_pind, __, vis_e_pind, __ = self.model(
rgbs=rgbs[:, t:], queries=query, iters=self.n_iters
)
return traj_e_pind, vis_e_pind

178
cotracker/predictor.py Normal file
View File

@ -0,0 +1,178 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from tqdm import tqdm
from cotracker.models.core.cotracker.cotracker import get_points_on_a_grid
from cotracker.models.core.model_utils import smart_cat
from cotracker.models.build_cotracker import (
build_cotracker,
)
class CoTrackerPredictor(torch.nn.Module):
def __init__(
self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth"
):
super().__init__()
self.interp_shape = (384, 512)
self.support_grid_size = 6
model = build_cotracker(checkpoint)
self.model = model
self.model.to("cuda")
self.model.eval()
@torch.no_grad()
def forward(
self,
video, # (1, T, 3, H, W)
# input prompt types:
# - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
# *backward_tracking=True* will compute tracks in both directions.
# - queries. Queried points of shape (1, N, 3) in format (t, x, y) for frame index and pixel coordinates.
# - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask.
# You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks.
queries: torch.Tensor = None,
segm_mask: torch.Tensor = None, # Segmentation mask of shape (B, 1, H, W)
grid_size: int = 0,
grid_query_frame: int = 0, # only for dense and regular grid tracks
backward_tracking: bool = False,
):
if queries is None and grid_size == 0:
tracks, visibilities = self._compute_dense_tracks(
video,
grid_query_frame=grid_query_frame,
backward_tracking=backward_tracking,
)
else:
tracks, visibilities = self._compute_sparse_tracks(
video,
queries,
segm_mask,
grid_size,
add_support_grid=(grid_size == 0 or segm_mask is not None),
grid_query_frame=grid_query_frame,
backward_tracking=backward_tracking,
)
return tracks, visibilities
def _compute_dense_tracks(
self, video, grid_query_frame, grid_size=50, backward_tracking=False
):
*_, H, W = video.shape
grid_step = W // grid_size
grid_width = W // grid_step
grid_height = H // grid_step
tracks = visibilities = None
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to("cuda")
grid_pts[0, :, 0] = grid_query_frame
for offset in tqdm(range(grid_step * grid_step)):
ox = offset % grid_step
oy = offset // grid_step
grid_pts[0, :, 1] = (
torch.arange(grid_width).repeat(grid_height) * grid_step + ox
)
grid_pts[0, :, 2] = (
torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
)
tracks_step, visibilities_step = self._compute_sparse_tracks(
video=video,
queries=grid_pts,
backward_tracking=backward_tracking,
)
tracks = smart_cat(tracks, tracks_step, dim=2)
visibilities = smart_cat(visibilities, visibilities_step, dim=2)
return tracks, visibilities
def _compute_sparse_tracks(
self,
video,
queries,
segm_mask=None,
grid_size=0,
add_support_grid=False,
grid_query_frame=0,
backward_tracking=False,
):
B, T, C, H, W = video.shape
assert B == 1
video = video.reshape(B * T, C, H, W)
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").cuda()
video = video.reshape(
B, T, 3, self.interp_shape[0], self.interp_shape[1]
).cuda()
if queries is not None:
queries = queries.clone()
B, N, D = queries.shape
assert D == 3
queries[:, :, 1] *= self.interp_shape[1] / W
queries[:, :, 2] *= self.interp_shape[0] / H
elif grid_size > 0:
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape)
if segm_mask is not None:
segm_mask = F.interpolate(
segm_mask, tuple(self.interp_shape), mode="nearest"
)
point_mask = segm_mask[0, 0][
(grid_pts[0, :, 1]).round().long().cpu(),
(grid_pts[0, :, 0]).round().long().cpu(),
].bool()
grid_pts = grid_pts[:, point_mask]
queries = torch.cat(
[torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
dim=2,
)
if add_support_grid:
grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape)
grid_pts = torch.cat(
[torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
)
queries = torch.cat([queries, grid_pts], dim=1)
tracks, __, visibilities, __ = self.model(rgbs=video, queries=queries, iters=6)
if backward_tracking:
tracks, visibilities = self._compute_backward_tracks(
video, queries, tracks, visibilities
)
if add_support_grid:
queries[:, -self.support_grid_size ** 2 :, 0] = T - 1
if add_support_grid:
tracks = tracks[:, :, : -self.support_grid_size ** 2]
visibilities = visibilities[:, :, : -self.support_grid_size ** 2]
thr = 0.9
visibilities = visibilities > thr
tracks[:, :, :, 0] *= W / float(self.interp_shape[1])
tracks[:, :, :, 1] *= H / float(self.interp_shape[0])
return tracks, visibilities
def _compute_backward_tracks(self, video, queries, tracks, visibilities):
inv_video = video.flip(1).clone()
inv_queries = queries.clone()
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
inv_tracks, __, inv_visibilities, __ = self.model(
rgbs=inv_video, queries=inv_queries, iters=6
)
inv_tracks = inv_tracks.flip(1)
inv_visibilities = inv_visibilities.flip(1)
mask = tracks == 0
tracks[mask] = inv_tracks[mask]
visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
return tracks, visibilities

View File

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@ -0,0 +1,291 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
import cv2
import torch
import flow_vis
from matplotlib import cm
import torch.nn.functional as F
import torchvision.transforms as transforms
from moviepy.editor import ImageSequenceClip
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
class Visualizer:
def __init__(
self,
save_dir: str = "./results",
grayscale: bool = False,
pad_value: int = 0,
fps: int = 10,
mode: str = "rainbow", # 'cool', 'optical_flow'
linewidth: int = 2,
show_first_frame: int = 10,
tracks_leave_trace: int = 0, # -1 for infinite
):
self.mode = mode
self.save_dir = save_dir
if mode == "rainbow":
self.color_map = cm.get_cmap("gist_rainbow")
elif mode == "cool":
self.color_map = cm.get_cmap(mode)
self.show_first_frame = show_first_frame
self.grayscale = grayscale
self.tracks_leave_trace = tracks_leave_trace
self.pad_value = pad_value
self.linewidth = linewidth
self.fps = fps
def visualize(
self,
video: torch.Tensor, # (B,T,C,H,W)
tracks: torch.Tensor, # (B,T,N,2)
gt_tracks: torch.Tensor = None, # (B,T,N,2)
segm_mask: torch.Tensor = None, # (B,1,H,W)
filename: str = "video",
writer: SummaryWriter = None,
step: int = 0,
query_frame: int = 0,
save_video: bool = True,
compensate_for_camera_motion: bool = False,
):
if compensate_for_camera_motion:
assert segm_mask is not None
if segm_mask is not None:
coords = tracks[0, query_frame].round().long()
segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()
video = F.pad(
video,
(self.pad_value, self.pad_value, self.pad_value, self.pad_value),
"constant",
255,
)
tracks = tracks + self.pad_value
if self.grayscale:
transform = transforms.Grayscale()
video = transform(video)
video = video.repeat(1, 1, 3, 1, 1)
res_video = self.draw_tracks_on_video(
video=video,
tracks=tracks,
segm_mask=segm_mask,
gt_tracks=gt_tracks,
query_frame=query_frame,
compensate_for_camera_motion=compensate_for_camera_motion,
)
if save_video:
self.save_video(res_video, filename=filename, writer=writer, step=step)
return res_video
def save_video(self, video, filename, writer=None, step=0):
if writer is not None:
writer.add_video(
f"{filename}_pred_track",
video.to(torch.uint8),
global_step=step,
fps=self.fps,
)
else:
os.makedirs(self.save_dir, exist_ok=True)
wide_list = list(video.unbind(1))
wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
clip = ImageSequenceClip(wide_list[2:-1], fps=self.fps)
# Write the video file
save_path = os.path.join(self.save_dir, f"{filename}_pred_track.mp4")
clip.write_videofile(save_path, codec="libx264", fps=self.fps, logger=None)
print(f"Video saved to {save_path}")
def draw_tracks_on_video(
self,
video: torch.Tensor,
tracks: torch.Tensor,
segm_mask: torch.Tensor = None,
gt_tracks=None,
query_frame: int = 0,
compensate_for_camera_motion=False,
):
B, T, C, H, W = video.shape
_, _, N, D = tracks.shape
assert D == 2
assert C == 3
video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C
tracks = tracks[0].long().detach().cpu().numpy() # S, N, 2
if gt_tracks is not None:
gt_tracks = gt_tracks[0].detach().cpu().numpy()
res_video = []
# process input video
for rgb in video:
res_video.append(rgb.copy())
vector_colors = np.zeros((T, N, 3))
if self.mode == "optical_flow":
vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
elif segm_mask is None:
if self.mode == "rainbow":
y_min, y_max = (
tracks[query_frame, :, 1].min(),
tracks[query_frame, :, 1].max(),
)
norm = plt.Normalize(y_min, y_max)
for n in range(N):
color = self.color_map(norm(tracks[query_frame, n, 1]))
color = np.array(color[:3])[None] * 255
vector_colors[:, n] = np.repeat(color, T, axis=0)
else:
# color changes with time
for t in range(T):
color = np.array(self.color_map(t / T)[:3])[None] * 255
vector_colors[t] = np.repeat(color, N, axis=0)
else:
if self.mode == "rainbow":
vector_colors[:, segm_mask <= 0, :] = 255
y_min, y_max = (
tracks[0, segm_mask > 0, 1].min(),
tracks[0, segm_mask > 0, 1].max(),
)
norm = plt.Normalize(y_min, y_max)
for n in range(N):
if segm_mask[n] > 0:
color = self.color_map(norm(tracks[0, n, 1]))
color = np.array(color[:3])[None] * 255
vector_colors[:, n] = np.repeat(color, T, axis=0)
else:
# color changes with segm class
segm_mask = segm_mask.cpu()
color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
vector_colors = np.repeat(color[None], T, axis=0)
# draw tracks
if self.tracks_leave_trace != 0:
for t in range(1, T):
first_ind = (
max(0, t - self.tracks_leave_trace)
if self.tracks_leave_trace >= 0
else 0
)
curr_tracks = tracks[first_ind : t + 1]
curr_colors = vector_colors[first_ind : t + 1]
if compensate_for_camera_motion:
diff = (
tracks[first_ind : t + 1, segm_mask <= 0]
- tracks[t : t + 1, segm_mask <= 0]
).mean(1)[:, None]
curr_tracks = curr_tracks - diff
curr_tracks = curr_tracks[:, segm_mask > 0]
curr_colors = curr_colors[:, segm_mask > 0]
res_video[t] = self._draw_pred_tracks(
res_video[t],
curr_tracks,
curr_colors,
)
if gt_tracks is not None:
res_video[t] = self._draw_gt_tracks(
res_video[t], gt_tracks[first_ind : t + 1]
)
# draw points
for t in range(T):
for i in range(N):
coord = (tracks[t, i, 0], tracks[t, i, 1])
if coord[0] != 0 and coord[1] != 0:
if not compensate_for_camera_motion or (
compensate_for_camera_motion and segm_mask[i] > 0
):
cv2.circle(
res_video[t],
coord,
int(self.linewidth * 2),
vector_colors[t, i].tolist(),
-1,
)
# construct the final rgb sequence
if self.show_first_frame > 0:
res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()
def _draw_pred_tracks(
self,
rgb: np.ndarray, # H x W x 3
tracks: np.ndarray, # T x 2
vector_colors: np.ndarray,
alpha: float = 0.5,
):
T, N, _ = tracks.shape
for s in range(T - 1):
vector_color = vector_colors[s]
original = rgb.copy()
alpha = (s / T) ** 2
for i in range(N):
coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
if coord_y[0] != 0 and coord_y[1] != 0:
cv2.line(
rgb,
coord_y,
coord_x,
vector_color[i].tolist(),
self.linewidth,
cv2.LINE_AA,
)
if self.tracks_leave_trace > 0:
rgb = cv2.addWeighted(rgb, alpha, original, 1 - alpha, 0)
return rgb
def _draw_gt_tracks(
self,
rgb: np.ndarray, # H x W x 3,
gt_tracks: np.ndarray, # T x 2
):
T, N, _ = gt_tracks.shape
color = np.array((211.0, 0.0, 0.0))
for t in range(T):
for i in range(N):
gt_tracks = gt_tracks[t][i]
# draw a red cross
if gt_tracks[0] > 0 and gt_tracks[1] > 0:
length = self.linewidth * 3
coord_y = (int(gt_tracks[0]) + length, int(gt_tracks[1]) + length)
coord_x = (int(gt_tracks[0]) - length, int(gt_tracks[1]) - length)
cv2.line(
rgb,
coord_y,
coord_x,
color,
self.linewidth,
cv2.LINE_AA,
)
coord_y = (int(gt_tracks[0]) - length, int(gt_tracks[1]) + length)
coord_x = (int(gt_tracks[0]) + length, int(gt_tracks[1]) - length)
cv2.line(
rgb,
coord_y,
coord_x,
color,
self.linewidth,
cv2.LINE_AA,
)
return rgb

71
demo.py Normal file
View File

@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import torch
import argparse
import numpy as np
from torchvision.io import read_video
from PIL import Image
from cotracker.utils.visualizer import Visualizer
from cotracker.predictor import CoTrackerPredictor
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--video_path",
default="./assets/apple.mp4",
help="path to a video",
)
parser.add_argument(
"--mask_path",
default="./assets/apple_mask.png",
help="path to a segmentation mask",
)
parser.add_argument(
"--checkpoint",
default="./checkpoints/cotracker_stride_4_wind_8.pth",
help="cotracker model",
)
parser.add_argument("--grid_size", type=int, default=0, help="Regular grid size")
parser.add_argument(
"--grid_query_frame",
type=int,
default=0,
help="Compute dense and grid tracks starting from this frame ",
)
parser.add_argument(
"--backward_tracking",
action="store_true",
help="Compute tracks in both directions, not only forward",
)
args = parser.parse_args()
# load the input video frame by frame
video = read_video(args.video_path)
video = video[0].permute(0, 3, 1, 2)[None].float()
segm_mask = np.array(Image.open(os.path.join(args.mask_path)))
segm_mask = torch.from_numpy(segm_mask)[None, None]
model = CoTrackerPredictor(checkpoint=args.checkpoint)
pred_tracks, pred_visibility = model(
video,
grid_size=args.grid_size,
grid_query_frame=args.grid_query_frame,
backward_tracking=args.backward_tracking,
# segm_mask=segm_mask
)
print("computed")
# save a video with predicted tracks
seq_name = args.video_path.split("/")[-1]
vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
vis.visualize(video, pred_tracks, query_frame=args.grid_query_frame)

924
notebooks/demo.ipynb Normal file

File diff suppressed because one or more lines are too long

18
setup.py Normal file
View File

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from setuptools import find_packages, setup
setup(
name="cotracker",
version="1.0",
install_requires=[],
packages=find_packages(exclude="notebooks"),
extras_require={
"all": ["matplotlib", "opencv-python"],
"dev": ["flake8", "black"],
},
)

665
train.py Normal file
View File

@ -0,0 +1,665 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import random
import torch
import signal
import socket
import sys
import json
import numpy as np
import argparse
import logging
from pathlib import Path
from tqdm import tqdm
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler
from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning.lite import LightningLite
from cotracker.models.evaluation_predictor import EvaluationPredictor
from cotracker.models.core.cotracker.cotracker import CoTracker
from cotracker.utils.visualizer import Visualizer
from cotracker.datasets.tap_vid_datasets import TapVidDataset
from cotracker.datasets.badja_dataset import BadjaDataset
from cotracker.datasets.fast_capture_dataset import FastCaptureDataset
from cotracker.evaluation.core.evaluator import Evaluator
from cotracker.datasets import kubric_movif_dataset
from cotracker.datasets.utils import collate_fn, collate_fn_train, dataclass_to_cuda_
from cotracker.models.core.cotracker.losses import sequence_loss, balanced_ce_loss
# define the handler function
# for training on a slurm cluster
def sig_handler(signum, frame):
print("caught signal", signum)
print(socket.gethostname(), "USR1 signal caught.")
# do other stuff to cleanup here
print("requeuing job " + os.environ["SLURM_JOB_ID"])
os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"])
sys.exit(-1)
def term_handler(signum, frame):
print("bypassing sigterm", flush=True)
def fetch_optimizer(args, model):
"""Create the optimizer and learning rate scheduler"""
optimizer = optim.AdamW(
model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8
)
scheduler = optim.lr_scheduler.OneCycleLR(
optimizer,
args.lr,
args.num_steps + 100,
pct_start=0.05,
cycle_momentum=False,
anneal_strategy="linear",
)
return optimizer, scheduler
def forward_batch(batch, model, args, loss_fn=None, writer=None, step=0):
rgbs = batch.video
trajs_g = batch.trajectory
vis_g = batch.visibility
valids = batch.valid
B, T, C, H, W = rgbs.shape
assert C == 3
B, T, N, D = trajs_g.shape
device = rgbs.device
__, first_positive_inds = torch.max(vis_g, dim=1)
# We want to make sure that during training the model sees visible points
# that it does not need to track just yet: they are visible but queried from a later frame
N_rand = N // 4
# inds of visible points in the 1st frame
nonzero_inds = [torch.nonzero(vis_g[0, :, i]) for i in range(N)]
rand_vis_inds = torch.cat(
[
nonzero_row[torch.randint(len(nonzero_row), size=(1,))]
for nonzero_row in nonzero_inds
],
dim=1,
)
first_positive_inds = torch.cat(
[rand_vis_inds[:, :N_rand], first_positive_inds[:, N_rand:]], dim=1
)
ind_array_ = torch.arange(T, device=device)
ind_array_ = ind_array_[None, :, None].repeat(B, 1, N)
assert torch.allclose(
vis_g[ind_array_ == first_positive_inds[:, None, :]],
torch.ones_like(vis_g),
)
assert torch.allclose(
vis_g[ind_array_ == rand_vis_inds[:, None, :]], torch.ones_like(vis_g)
)
gather = torch.gather(
trajs_g, 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, 2)
)
xys = torch.diagonal(gather, dim1=1, dim2=2).permute(0, 2, 1)
queries = torch.cat([first_positive_inds[:, :, None], xys], dim=2)
predictions, __, visibility, train_data = model(
rgbs=rgbs, queries=queries, iters=args.train_iters, is_train=True
)
vis_predictions, coord_predictions, wind_inds, sort_inds = train_data
trajs_g = trajs_g[:, :, sort_inds]
vis_g = vis_g[:, :, sort_inds]
valids = valids[:, :, sort_inds]
vis_gts = []
traj_gts = []
valids_gts = []
for i, wind_idx in enumerate(wind_inds):
ind = i * (args.sliding_window_len // 2)
vis_gts.append(vis_g[:, ind : ind + args.sliding_window_len, :wind_idx])
traj_gts.append(trajs_g[:, ind : ind + args.sliding_window_len, :wind_idx])
valids_gts.append(valids[:, ind : ind + args.sliding_window_len, :wind_idx])
seq_loss = sequence_loss(coord_predictions, traj_gts, vis_gts, valids_gts, 0.8)
vis_loss = balanced_ce_loss(vis_predictions, vis_gts, valids_gts)
output = {"flow": {"predictions": predictions[0].detach()}}
output["flow"]["loss"] = seq_loss.mean()
output["visibility"] = {
"loss": vis_loss.mean() * 10.0,
"predictions": visibility[0].detach(),
}
return output
def run_test_eval(evaluator, model, dataloaders, writer, step):
model.eval()
for ds_name, dataloader in dataloaders:
predictor = EvaluationPredictor(
model.module.module,
grid_size=6,
local_grid_size=0,
single_point=False,
n_iters=6,
)
metrics = evaluator.evaluate_sequence(
model=predictor,
test_dataloader=dataloader,
dataset_name=ds_name,
train_mode=True,
writer=writer,
step=step,
)
if ds_name == "badja" or ds_name == "fastcapture" or ("kubric" in ds_name):
metrics = {
**{
f"{ds_name}_avg": np.mean(
[v for k, v in metrics.items() if "accuracy" not in k]
)
},
**{
f"{ds_name}_avg_accuracy": np.mean(
[v for k, v in metrics.items() if "accuracy" in k]
)
},
}
print("avg", np.mean([v for v in metrics.values()]))
if "tapvid" in ds_name:
metrics = {
f"{ds_name}_avg_OA": metrics["avg"]["occlusion_accuracy"] * 100,
f"{ds_name}_avg_delta": metrics["avg"]["average_pts_within_thresh"]
* 100,
f"{ds_name}_avg_Jaccard": metrics["avg"]["average_jaccard"] * 100,
}
writer.add_scalars(f"Eval", metrics, step)
class Logger:
SUM_FREQ = 100
def __init__(self, model, scheduler):
self.model = model
self.scheduler = scheduler
self.total_steps = 0
self.running_loss = {}
self.writer = SummaryWriter(log_dir=os.path.join(args.ckpt_path, "runs"))
def _print_training_status(self):
metrics_data = [
self.running_loss[k] / Logger.SUM_FREQ
for k in sorted(self.running_loss.keys())
]
training_str = "[{:6d}] ".format(self.total_steps + 1)
metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data)
# print the training status
logging.info(
f"Training Metrics ({self.total_steps}): {training_str + metrics_str}"
)
if self.writer is None:
self.writer = SummaryWriter(log_dir=os.path.join(args.ckpt_path, "runs"))
for k in self.running_loss:
self.writer.add_scalar(
k, self.running_loss[k] / Logger.SUM_FREQ, self.total_steps
)
self.running_loss[k] = 0.0
def push(self, metrics, task):
self.total_steps += 1
for key in metrics:
task_key = str(key) + "_" + task
if task_key not in self.running_loss:
self.running_loss[task_key] = 0.0
self.running_loss[task_key] += metrics[key]
if self.total_steps % Logger.SUM_FREQ == Logger.SUM_FREQ - 1:
self._print_training_status()
self.running_loss = {}
def write_dict(self, results):
if self.writer is None:
self.writer = SummaryWriter(log_dir=os.path.join(args.ckpt_path, "runs"))
for key in results:
self.writer.add_scalar(key, results[key], self.total_steps)
def close(self):
self.writer.close()
class Lite(LightningLite):
def run(self, args):
def seed_everything(seed: int):
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(0)
def seed_worker(worker_id):
worker_seed = torch.initial_seed() % 2 ** 32
np.random.seed(worker_seed)
random.seed(worker_seed)
g = torch.Generator()
g.manual_seed(0)
eval_dataloaders = []
if "badja" in args.eval_datasets:
eval_dataset = BadjaDataset(
data_root=os.path.join(args.dataset_root, "BADJA"),
max_seq_len=args.eval_max_seq_len,
dataset_resolution=args.crop_size,
)
eval_dataloader_badja = torch.utils.data.DataLoader(
eval_dataset,
batch_size=1,
shuffle=False,
num_workers=8,
collate_fn=collate_fn,
)
eval_dataloaders.append(("badja", eval_dataloader_badja))
if "fastcapture" in args.eval_datasets:
eval_dataset = FastCaptureDataset(
data_root=os.path.join(args.dataset_root, "fastcapture"),
max_seq_len=min(100, args.eval_max_seq_len),
max_num_points=40,
dataset_resolution=args.crop_size,
)
eval_dataloader_fastcapture = torch.utils.data.DataLoader(
eval_dataset,
batch_size=1,
shuffle=False,
num_workers=1,
collate_fn=collate_fn,
)
eval_dataloaders.append(("fastcapture", eval_dataloader_fastcapture))
if "tapvid_davis_first" in args.eval_datasets:
data_root = os.path.join(
args.dataset_root, "/tapvid_davis/tapvid_davis.pkl"
)
eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root)
eval_dataloader_tapvid_davis = torch.utils.data.DataLoader(
eval_dataset,
batch_size=1,
shuffle=False,
num_workers=1,
collate_fn=collate_fn,
)
eval_dataloaders.append(("tapvid_davis", eval_dataloader_tapvid_davis))
evaluator = Evaluator(args.ckpt_path)
visualizer = Visualizer(
save_dir=args.ckpt_path,
pad_value=80,
fps=1,
show_first_frame=0,
tracks_leave_trace=0,
)
loss_fn = None
if args.model_name == "cotracker":
model = CoTracker(
stride=args.model_stride,
S=args.sliding_window_len,
add_space_attn=not args.remove_space_attn,
num_heads=args.updateformer_num_heads,
hidden_size=args.updateformer_hidden_size,
space_depth=args.updateformer_space_depth,
time_depth=args.updateformer_time_depth,
)
else:
raise ValueError(f"Model {args.model_name} doesn't exist")
with open(args.ckpt_path + "/meta.json", "w") as file:
json.dump(vars(args), file, sort_keys=True, indent=4)
model.cuda()
train_dataset = kubric_movif_dataset.KubricMovifDataset(
data_root=os.path.join(args.dataset_root, "kubric_movi_f"),
crop_size=args.crop_size,
seq_len=args.sequence_len,
traj_per_sample=args.traj_per_sample,
sample_vis_1st_frame=args.sample_vis_1st_frame,
use_augs=not args.dont_use_augs,
)
train_loader = DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers,
worker_init_fn=seed_worker,
generator=g,
pin_memory=True,
collate_fn=collate_fn_train,
drop_last=True,
)
train_loader = self.setup_dataloaders(train_loader, move_to_device=False)
print("LEN TRAIN LOADER", len(train_loader))
optimizer, scheduler = fetch_optimizer(args, model)
total_steps = 0
logger = Logger(model, scheduler)
folder_ckpts = [
f
for f in os.listdir(args.ckpt_path)
if not os.path.isdir(f) and f.endswith(".pth") and not "final" in f
]
if len(folder_ckpts) > 0:
ckpt_path = sorted(folder_ckpts)[-1]
ckpt = self.load(os.path.join(args.ckpt_path, ckpt_path))
logging.info(f"Loading checkpoint {ckpt_path}")
if "model" in ckpt:
model.load_state_dict(ckpt["model"])
else:
model.load_state_dict(ckpt)
if "optimizer" in ckpt:
logging.info("Load optimizer")
optimizer.load_state_dict(ckpt["optimizer"])
if "scheduler" in ckpt:
logging.info("Load scheduler")
scheduler.load_state_dict(ckpt["scheduler"])
if "total_steps" in ckpt:
total_steps = ckpt["total_steps"]
logging.info(f"Load total_steps {total_steps}")
elif args.restore_ckpt is not None:
assert args.restore_ckpt.endswith(".pth") or args.restore_ckpt.endswith(
".pt"
)
logging.info("Loading checkpoint...")
strict = True
state_dict = self.load(args.restore_ckpt)
if "model" in state_dict:
state_dict = state_dict["model"]
if list(state_dict.keys())[0].startswith("module."):
state_dict = {
k.replace("module.", ""): v for k, v in state_dict.items()
}
model.load_state_dict(state_dict, strict=strict)
logging.info(f"Done loading checkpoint")
model, optimizer = self.setup(model, optimizer, move_to_device=False)
# model.cuda()
model.train()
save_freq = args.save_freq
scaler = GradScaler(enabled=args.mixed_precision)
should_keep_training = True
global_batch_num = 0
epoch = -1
while should_keep_training:
epoch += 1
for i_batch, batch in enumerate(tqdm(train_loader)):
batch, gotit = batch
if not all(gotit):
print("batch is None")
continue
dataclass_to_cuda_(batch)
optimizer.zero_grad()
assert model.training
output = forward_batch(
batch,
model,
args,
loss_fn=loss_fn,
writer=logger.writer,
step=total_steps,
)
loss = 0
for k, v in output.items():
if "loss" in v:
loss += v["loss"]
logger.writer.add_scalar(
f"live_{k}_loss", v["loss"].item(), total_steps
)
if "metrics" in v:
logger.push(v["metrics"], k)
if self.global_rank == 0:
if total_steps % save_freq == save_freq - 1:
if args.model_name == "motion_diffuser":
pred_coords = model.module.module.forward_batch_test(
batch, interp_shape=args.crop_size
)
output["flow"] = {"predictions": pred_coords[0].detach()}
visualizer.visualize(
video=batch.video.clone(),
tracks=batch.trajectory.clone(),
filename="train_gt_traj",
writer=logger.writer,
step=total_steps,
)
visualizer.visualize(
video=batch.video.clone(),
tracks=output["flow"]["predictions"][None],
filename="train_pred_traj",
writer=logger.writer,
step=total_steps,
)
if len(output) > 1:
logger.writer.add_scalar(
f"live_total_loss", loss.item(), total_steps
)
logger.writer.add_scalar(
f"learning_rate", optimizer.param_groups[0]["lr"], total_steps
)
global_batch_num += 1
self.barrier()
self.backward(scaler.scale(loss))
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0)
scaler.step(optimizer)
scheduler.step()
scaler.update()
total_steps += 1
if self.global_rank == 0:
if (i_batch >= len(train_loader) - 1) or (
total_steps == 1 and args.validate_at_start
):
if (epoch + 1) % args.save_every_n_epoch == 0:
ckpt_iter = "0" * (6 - len(str(total_steps))) + str(
total_steps
)
save_path = Path(
f"{args.ckpt_path}/model_{args.model_name}_{ckpt_iter}.pth"
)
save_dict = {
"model": model.module.module.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"total_steps": total_steps,
}
logging.info(f"Saving file {save_path}")
self.save(save_dict, save_path)
if (epoch + 1) % args.evaluate_every_n_epoch == 0 or (
args.validate_at_start and epoch == 0
):
run_test_eval(
evaluator,
model,
eval_dataloaders,
logger.writer,
total_steps,
)
model.train()
torch.cuda.empty_cache()
self.barrier()
if total_steps > args.num_steps:
should_keep_training = False
break
print("FINISHED TRAINING")
PATH = f"{args.ckpt_path}/{args.model_name}_final.pth"
torch.save(model.module.module.state_dict(), PATH)
run_test_eval(evaluator, model, eval_dataloaders, logger.writer, total_steps)
logger.close()
if __name__ == "__main__":
signal.signal(signal.SIGUSR1, sig_handler)
signal.signal(signal.SIGTERM, term_handler)
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", default="cotracker", help="model name")
parser.add_argument("--restore_ckpt", help="restore checkpoint")
parser.add_argument("--ckpt_path", help="restore checkpoint")
parser.add_argument(
"--batch_size", type=int, default=4, help="batch size used during training."
)
parser.add_argument(
"--num_workers", type=int, default=6, help="left right consistency loss"
)
parser.add_argument(
"--mixed_precision", action="store_true", help="use mixed precision"
)
parser.add_argument("--lr", type=float, default=0.0005, help="max learning rate.")
parser.add_argument(
"--wdecay", type=float, default=0.00001, help="Weight decay in optimizer."
)
parser.add_argument(
"--num_steps", type=int, default=200000, help="length of training schedule."
)
parser.add_argument(
"--evaluate_every_n_epoch",
type=int,
default=1,
help="number of flow-field updates during validation forward pass",
)
parser.add_argument(
"--save_every_n_epoch",
type=int,
default=1,
help="number of flow-field updates during validation forward pass",
)
parser.add_argument(
"--validate_at_start", action="store_true", help="use mixed precision"
)
parser.add_argument("--save_freq", type=int, default=100, help="save_freq")
parser.add_argument("--traj_per_sample", type=int, default=768, help="save_freq")
parser.add_argument("--dataset_root", type=str, help="path lo all the datasets")
parser.add_argument(
"--train_iters",
type=int,
default=4,
help="number of updates to the disparity field in each forward pass.",
)
parser.add_argument(
"--sequence_len", type=int, default=8, help="train sequence length"
)
parser.add_argument(
"--eval_datasets",
nargs="+",
default=["things", "badja", "fastcapture"],
help="eval datasets.",
)
parser.add_argument(
"--remove_space_attn", action="store_true", help="use mixed precision"
)
parser.add_argument(
"--dont_use_augs", action="store_true", help="use mixed precision"
)
parser.add_argument(
"--sample_vis_1st_frame", action="store_true", help="use mixed precision"
)
parser.add_argument(
"--sliding_window_len", type=int, default=8, help="use mixed precision"
)
parser.add_argument(
"--updateformer_hidden_size", type=int, default=384, help="use mixed precision"
)
parser.add_argument(
"--updateformer_num_heads", type=int, default=8, help="use mixed precision"
)
parser.add_argument(
"--updateformer_space_depth", type=int, default=12, help="use mixed precision"
)
parser.add_argument(
"--updateformer_time_depth", type=int, default=12, help="use mixed precision"
)
parser.add_argument(
"--model_stride", type=int, default=8, help="use mixed precision"
)
parser.add_argument(
"--crop_size",
type=int,
nargs="+",
default=[384, 512],
help="use mixed precision",
)
parser.add_argument(
"--eval_max_seq_len", type=int, default=1000, help="use mixed precision"
)
args = parser.parse_args()
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
)
Path(args.ckpt_path).mkdir(exist_ok=True, parents=True)
from pytorch_lightning.strategies import DDPStrategy
Lite(
strategy=DDPStrategy(find_unused_parameters=True),
devices="auto",
accelerator="gpu",
precision=32,
num_nodes=4,
).run(args)