Initial commit
This commit is contained in:
commit
6d62d873fa
80
CODE_OF_CONDUCT.md
Normal file
80
CODE_OF_CONDUCT.md
Normal file
@ -0,0 +1,80 @@
|
||||
# Code of Conduct
|
||||
|
||||
## Our Pledge
|
||||
|
||||
In the interest of fostering an open and welcoming environment, we as
|
||||
contributors and maintainers pledge to make participation in our project and
|
||||
our community a harassment-free experience for everyone, regardless of age, body
|
||||
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
||||
level of experience, education, socio-economic status, nationality, personal
|
||||
appearance, race, religion, or sexual identity and orientation.
|
||||
|
||||
## Our Standards
|
||||
|
||||
Examples of behavior that contributes to creating a positive environment
|
||||
include:
|
||||
|
||||
* Using welcoming and inclusive language
|
||||
* Being respectful of differing viewpoints and experiences
|
||||
* Gracefully accepting constructive criticism
|
||||
* Focusing on what is best for the community
|
||||
* Showing empathy towards other community members
|
||||
|
||||
Examples of unacceptable behavior by participants include:
|
||||
|
||||
* The use of sexualized language or imagery and unwelcome sexual attention or
|
||||
advances
|
||||
* Trolling, insulting/derogatory comments, and personal or political attacks
|
||||
* Public or private harassment
|
||||
* Publishing others' private information, such as a physical or electronic
|
||||
address, without explicit permission
|
||||
* Other conduct which could reasonably be considered inappropriate in a
|
||||
professional setting
|
||||
|
||||
## Our Responsibilities
|
||||
|
||||
Project maintainers are responsible for clarifying the standards of acceptable
|
||||
behavior and are expected to take appropriate and fair corrective action in
|
||||
response to any instances of unacceptable behavior.
|
||||
|
||||
Project maintainers have the right and responsibility to remove, edit, or
|
||||
reject comments, commits, code, wiki edits, issues, and other contributions
|
||||
that are not aligned to this Code of Conduct, or to ban temporarily or
|
||||
permanently any contributor for other behaviors that they deem inappropriate,
|
||||
threatening, offensive, or harmful.
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies within all project spaces, and it also applies when
|
||||
an individual is representing the project or its community in public spaces.
|
||||
Examples of representing a project or community include using an official
|
||||
project e-mail address, posting via an official social media account, or acting
|
||||
as an appointed representative at an online or offline event. Representation of
|
||||
a project may be further defined and clarified by project maintainers.
|
||||
|
||||
This Code of Conduct also applies outside the project spaces when there is a
|
||||
reasonable belief that an individual's behavior may have a negative impact on
|
||||
the project or its community.
|
||||
|
||||
## Enforcement
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported by contacting the project team at <opensource-conduct@fb.com>. All
|
||||
complaints will be reviewed and investigated and will result in a response that
|
||||
is deemed necessary and appropriate to the circumstances. The project team is
|
||||
obligated to maintain confidentiality with regard to the reporter of an incident.
|
||||
Further details of specific enforcement policies may be posted separately.
|
||||
|
||||
Project maintainers who do not follow or enforce the Code of Conduct in good
|
||||
faith may face temporary or permanent repercussions as determined by other
|
||||
members of the project's leadership.
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
||||
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
||||
|
||||
[homepage]: https://www.contributor-covenant.org
|
||||
|
||||
For answers to common questions about this code of conduct, see
|
||||
https://www.contributor-covenant.org/faq
|
28
CONTRIBUTING.md
Normal file
28
CONTRIBUTING.md
Normal file
@ -0,0 +1,28 @@
|
||||
# CoTracker
|
||||
We want to make contributing to this project as easy and transparent as possible.
|
||||
|
||||
## Pull Requests
|
||||
We actively welcome your pull requests.
|
||||
|
||||
1. Fork the repo and create your branch from `main`.
|
||||
2. If you've changed APIs, update the documentation.
|
||||
3. Make sure your code lints.
|
||||
4. If you haven't already, complete the Contributor License Agreement ("CLA").
|
||||
|
||||
## Contributor License Agreement ("CLA")
|
||||
In order to accept your pull request, we need you to submit a CLA. You only need
|
||||
to do this once to work on any of Meta's open source projects.
|
||||
|
||||
Complete your CLA here: <https://code.facebook.com/cla>
|
||||
|
||||
## Issues
|
||||
We use GitHub issues to track public bugs. Please ensure your description is
|
||||
clear and has sufficient instructions to be able to reproduce the issue.
|
||||
|
||||
Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
|
||||
disclosure of security bugs. In those cases, please go through the process
|
||||
outlined on that page and do not file a public issue.
|
||||
|
||||
## License
|
||||
By contributing to CoTracker, you agree that your contributions will be licensed
|
||||
under the LICENSE file in the root directory of this source tree.
|
399
LICENSE.md
Normal file
399
LICENSE.md
Normal file
@ -0,0 +1,399 @@
|
||||
Attribution-NonCommercial 4.0 International
|
||||
|
||||
=======================================================================
|
||||
|
||||
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
||||
does not provide legal services or legal advice. Distribution of
|
||||
Creative Commons public licenses does not create a lawyer-client or
|
||||
other relationship. Creative Commons makes its licenses and related
|
||||
information available on an "as-is" basis. Creative Commons gives no
|
||||
warranties regarding its licenses, any material licensed under their
|
||||
terms and conditions, or any related information. Creative Commons
|
||||
disclaims all liability for damages resulting from their use to the
|
||||
fullest extent possible.
|
||||
|
||||
Using Creative Commons Public Licenses
|
||||
|
||||
Creative Commons public licenses provide a standard set of terms and
|
||||
conditions that creators and other rights holders may use to share
|
||||
original works of authorship and other material subject to copyright
|
||||
and certain other rights specified in the public license below. The
|
||||
following considerations are for informational purposes only, are not
|
||||
exhaustive, and do not form part of our licenses.
|
||||
|
||||
Considerations for licensors: Our public licenses are
|
||||
intended for use by those authorized to give the public
|
||||
permission to use material in ways otherwise restricted by
|
||||
copyright and certain other rights. Our licenses are
|
||||
irrevocable. Licensors should read and understand the terms
|
||||
and conditions of the license they choose before applying it.
|
||||
Licensors should also secure all rights necessary before
|
||||
applying our licenses so that the public can reuse the
|
||||
material as expected. Licensors should clearly mark any
|
||||
material not subject to the license. This includes other CC-
|
||||
licensed material, or material used under an exception or
|
||||
limitation to copyright. More considerations for licensors:
|
||||
wiki.creativecommons.org/Considerations_for_licensors
|
||||
|
||||
Considerations for the public: By using one of our public
|
||||
licenses, a licensor grants the public permission to use the
|
||||
licensed material under specified terms and conditions. If
|
||||
the licensor's permission is not necessary for any reason--for
|
||||
example, because of any applicable exception or limitation to
|
||||
copyright--then that use is not regulated by the license. Our
|
||||
licenses grant only permissions under copyright and certain
|
||||
other rights that a licensor has authority to grant. Use of
|
||||
the licensed material may still be restricted for other
|
||||
reasons, including because others have copyright or other
|
||||
rights in the material. A licensor may make special requests,
|
||||
such as asking that all changes be marked or described.
|
||||
Although not required by our licenses, you are encouraged to
|
||||
respect those requests where reasonable. More_considerations
|
||||
for the public:
|
||||
wiki.creativecommons.org/Considerations_for_licensees
|
||||
|
||||
=======================================================================
|
||||
|
||||
Creative Commons Attribution-NonCommercial 4.0 International Public
|
||||
License
|
||||
|
||||
By exercising the Licensed Rights (defined below), You accept and agree
|
||||
to be bound by the terms and conditions of this Creative Commons
|
||||
Attribution-NonCommercial 4.0 International Public License ("Public
|
||||
License"). To the extent this Public License may be interpreted as a
|
||||
contract, You are granted the Licensed Rights in consideration of Your
|
||||
acceptance of these terms and conditions, and the Licensor grants You
|
||||
such rights in consideration of benefits the Licensor receives from
|
||||
making the Licensed Material available under these terms and
|
||||
conditions.
|
||||
|
||||
Section 1 -- Definitions.
|
||||
|
||||
a. Adapted Material means material subject to Copyright and Similar
|
||||
Rights that is derived from or based upon the Licensed Material
|
||||
and in which the Licensed Material is translated, altered,
|
||||
arranged, transformed, or otherwise modified in a manner requiring
|
||||
permission under the Copyright and Similar Rights held by the
|
||||
Licensor. For purposes of this Public License, where the Licensed
|
||||
Material is a musical work, performance, or sound recording,
|
||||
Adapted Material is always produced where the Licensed Material is
|
||||
synched in timed relation with a moving image.
|
||||
|
||||
b. Adapter's License means the license You apply to Your Copyright
|
||||
and Similar Rights in Your contributions to Adapted Material in
|
||||
accordance with the terms and conditions of this Public License.
|
||||
|
||||
c. Copyright and Similar Rights means copyright and/or similar rights
|
||||
closely related to copyright including, without limitation,
|
||||
performance, broadcast, sound recording, and Sui Generis Database
|
||||
Rights, without regard to how the rights are labeled or
|
||||
categorized. For purposes of this Public License, the rights
|
||||
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
||||
Rights.
|
||||
d. Effective Technological Measures means those measures that, in the
|
||||
absence of proper authority, may not be circumvented under laws
|
||||
fulfilling obligations under Article 11 of the WIPO Copyright
|
||||
Treaty adopted on December 20, 1996, and/or similar international
|
||||
agreements.
|
||||
|
||||
e. Exceptions and Limitations means fair use, fair dealing, and/or
|
||||
any other exception or limitation to Copyright and Similar Rights
|
||||
that applies to Your use of the Licensed Material.
|
||||
|
||||
f. Licensed Material means the artistic or literary work, database,
|
||||
or other material to which the Licensor applied this Public
|
||||
License.
|
||||
|
||||
g. Licensed Rights means the rights granted to You subject to the
|
||||
terms and conditions of this Public License, which are limited to
|
||||
all Copyright and Similar Rights that apply to Your use of the
|
||||
Licensed Material and that the Licensor has authority to license.
|
||||
|
||||
h. Licensor means the individual(s) or entity(ies) granting rights
|
||||
under this Public License.
|
||||
|
||||
i. NonCommercial means not primarily intended for or directed towards
|
||||
commercial advantage or monetary compensation. For purposes of
|
||||
this Public License, the exchange of the Licensed Material for
|
||||
other material subject to Copyright and Similar Rights by digital
|
||||
file-sharing or similar means is NonCommercial provided there is
|
||||
no payment of monetary compensation in connection with the
|
||||
exchange.
|
||||
|
||||
j. Share means to provide material to the public by any means or
|
||||
process that requires permission under the Licensed Rights, such
|
||||
as reproduction, public display, public performance, distribution,
|
||||
dissemination, communication, or importation, and to make material
|
||||
available to the public including in ways that members of the
|
||||
public may access the material from a place and at a time
|
||||
individually chosen by them.
|
||||
|
||||
k. Sui Generis Database Rights means rights other than copyright
|
||||
resulting from Directive 96/9/EC of the European Parliament and of
|
||||
the Council of 11 March 1996 on the legal protection of databases,
|
||||
as amended and/or succeeded, as well as other essentially
|
||||
equivalent rights anywhere in the world.
|
||||
|
||||
l. You means the individual or entity exercising the Licensed Rights
|
||||
under this Public License. Your has a corresponding meaning.
|
||||
|
||||
Section 2 -- Scope.
|
||||
|
||||
a. License grant.
|
||||
|
||||
1. Subject to the terms and conditions of this Public License,
|
||||
the Licensor hereby grants You a worldwide, royalty-free,
|
||||
non-sublicensable, non-exclusive, irrevocable license to
|
||||
exercise the Licensed Rights in the Licensed Material to:
|
||||
|
||||
a. reproduce and Share the Licensed Material, in whole or
|
||||
in part, for NonCommercial purposes only; and
|
||||
|
||||
b. produce, reproduce, and Share Adapted Material for
|
||||
NonCommercial purposes only.
|
||||
|
||||
2. Exceptions and Limitations. For the avoidance of doubt, where
|
||||
Exceptions and Limitations apply to Your use, this Public
|
||||
License does not apply, and You do not need to comply with
|
||||
its terms and conditions.
|
||||
|
||||
3. Term. The term of this Public License is specified in Section
|
||||
6(a).
|
||||
|
||||
4. Media and formats; technical modifications allowed. The
|
||||
Licensor authorizes You to exercise the Licensed Rights in
|
||||
all media and formats whether now known or hereafter created,
|
||||
and to make technical modifications necessary to do so. The
|
||||
Licensor waives and/or agrees not to assert any right or
|
||||
authority to forbid You from making technical modifications
|
||||
necessary to exercise the Licensed Rights, including
|
||||
technical modifications necessary to circumvent Effective
|
||||
Technological Measures. For purposes of this Public License,
|
||||
simply making modifications authorized by this Section 2(a)
|
||||
(4) never produces Adapted Material.
|
||||
|
||||
5. Downstream recipients.
|
||||
|
||||
a. Offer from the Licensor -- Licensed Material. Every
|
||||
recipient of the Licensed Material automatically
|
||||
receives an offer from the Licensor to exercise the
|
||||
Licensed Rights under the terms and conditions of this
|
||||
Public License.
|
||||
|
||||
b. No downstream restrictions. You may not offer or impose
|
||||
any additional or different terms or conditions on, or
|
||||
apply any Effective Technological Measures to, the
|
||||
Licensed Material if doing so restricts exercise of the
|
||||
Licensed Rights by any recipient of the Licensed
|
||||
Material.
|
||||
|
||||
6. No endorsement. Nothing in this Public License constitutes or
|
||||
may be construed as permission to assert or imply that You
|
||||
are, or that Your use of the Licensed Material is, connected
|
||||
with, or sponsored, endorsed, or granted official status by,
|
||||
the Licensor or others designated to receive attribution as
|
||||
provided in Section 3(a)(1)(A)(i).
|
||||
|
||||
b. Other rights.
|
||||
|
||||
1. Moral rights, such as the right of integrity, are not
|
||||
licensed under this Public License, nor are publicity,
|
||||
privacy, and/or other similar personality rights; however, to
|
||||
the extent possible, the Licensor waives and/or agrees not to
|
||||
assert any such rights held by the Licensor to the limited
|
||||
extent necessary to allow You to exercise the Licensed
|
||||
Rights, but not otherwise.
|
||||
|
||||
2. Patent and trademark rights are not licensed under this
|
||||
Public License.
|
||||
|
||||
3. To the extent possible, the Licensor waives any right to
|
||||
collect royalties from You for the exercise of the Licensed
|
||||
Rights, whether directly or through a collecting society
|
||||
under any voluntary or waivable statutory or compulsory
|
||||
licensing scheme. In all other cases the Licensor expressly
|
||||
reserves any right to collect such royalties, including when
|
||||
the Licensed Material is used other than for NonCommercial
|
||||
purposes.
|
||||
|
||||
Section 3 -- License Conditions.
|
||||
|
||||
Your exercise of the Licensed Rights is expressly made subject to the
|
||||
following conditions.
|
||||
|
||||
a. Attribution.
|
||||
|
||||
1. If You Share the Licensed Material (including in modified
|
||||
form), You must:
|
||||
|
||||
a. retain the following if it is supplied by the Licensor
|
||||
with the Licensed Material:
|
||||
|
||||
i. identification of the creator(s) of the Licensed
|
||||
Material and any others designated to receive
|
||||
attribution, in any reasonable manner requested by
|
||||
the Licensor (including by pseudonym if
|
||||
designated);
|
||||
|
||||
ii. a copyright notice;
|
||||
|
||||
iii. a notice that refers to this Public License;
|
||||
|
||||
iv. a notice that refers to the disclaimer of
|
||||
warranties;
|
||||
|
||||
v. a URI or hyperlink to the Licensed Material to the
|
||||
extent reasonably practicable;
|
||||
|
||||
b. indicate if You modified the Licensed Material and
|
||||
retain an indication of any previous modifications; and
|
||||
|
||||
c. indicate the Licensed Material is licensed under this
|
||||
Public License, and include the text of, or the URI or
|
||||
hyperlink to, this Public License.
|
||||
|
||||
2. You may satisfy the conditions in Section 3(a)(1) in any
|
||||
reasonable manner based on the medium, means, and context in
|
||||
which You Share the Licensed Material. For example, it may be
|
||||
reasonable to satisfy the conditions by providing a URI or
|
||||
hyperlink to a resource that includes the required
|
||||
information.
|
||||
|
||||
3. If requested by the Licensor, You must remove any of the
|
||||
information required by Section 3(a)(1)(A) to the extent
|
||||
reasonably practicable.
|
||||
|
||||
4. If You Share Adapted Material You produce, the Adapter's
|
||||
License You apply must not prevent recipients of the Adapted
|
||||
Material from complying with this Public License.
|
||||
|
||||
Section 4 -- Sui Generis Database Rights.
|
||||
|
||||
Where the Licensed Rights include Sui Generis Database Rights that
|
||||
apply to Your use of the Licensed Material:
|
||||
|
||||
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
||||
to extract, reuse, reproduce, and Share all or a substantial
|
||||
portion of the contents of the database for NonCommercial purposes
|
||||
only;
|
||||
|
||||
b. if You include all or a substantial portion of the database
|
||||
contents in a database in which You have Sui Generis Database
|
||||
Rights, then the database in which You have Sui Generis Database
|
||||
Rights (but not its individual contents) is Adapted Material; and
|
||||
|
||||
c. You must comply with the conditions in Section 3(a) if You Share
|
||||
all or a substantial portion of the contents of the database.
|
||||
|
||||
For the avoidance of doubt, this Section 4 supplements and does not
|
||||
replace Your obligations under this Public License where the Licensed
|
||||
Rights include other Copyright and Similar Rights.
|
||||
|
||||
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
||||
|
||||
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
||||
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
||||
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
||||
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
||||
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
||||
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
||||
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
||||
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
||||
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
||||
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
||||
|
||||
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
||||
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
||||
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
||||
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
||||
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
||||
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
||||
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
||||
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
||||
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
||||
|
||||
c. The disclaimer of warranties and limitation of liability provided
|
||||
above shall be interpreted in a manner that, to the extent
|
||||
possible, most closely approximates an absolute disclaimer and
|
||||
waiver of all liability.
|
||||
|
||||
Section 6 -- Term and Termination.
|
||||
|
||||
a. This Public License applies for the term of the Copyright and
|
||||
Similar Rights licensed here. However, if You fail to comply with
|
||||
this Public License, then Your rights under this Public License
|
||||
terminate automatically.
|
||||
|
||||
b. Where Your right to use the Licensed Material has terminated under
|
||||
Section 6(a), it reinstates:
|
||||
|
||||
1. automatically as of the date the violation is cured, provided
|
||||
it is cured within 30 days of Your discovery of the
|
||||
violation; or
|
||||
|
||||
2. upon express reinstatement by the Licensor.
|
||||
|
||||
For the avoidance of doubt, this Section 6(b) does not affect any
|
||||
right the Licensor may have to seek remedies for Your violations
|
||||
of this Public License.
|
||||
|
||||
c. For the avoidance of doubt, the Licensor may also offer the
|
||||
Licensed Material under separate terms or conditions or stop
|
||||
distributing the Licensed Material at any time; however, doing so
|
||||
will not terminate this Public License.
|
||||
|
||||
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
||||
License.
|
||||
|
||||
Section 7 -- Other Terms and Conditions.
|
||||
|
||||
a. The Licensor shall not be bound by any additional or different
|
||||
terms or conditions communicated by You unless expressly agreed.
|
||||
|
||||
b. Any arrangements, understandings, or agreements regarding the
|
||||
Licensed Material not stated herein are separate from and
|
||||
independent of the terms and conditions of this Public License.
|
||||
|
||||
Section 8 -- Interpretation.
|
||||
|
||||
a. For the avoidance of doubt, this Public License does not, and
|
||||
shall not be interpreted to, reduce, limit, restrict, or impose
|
||||
conditions on any use of the Licensed Material that could lawfully
|
||||
be made without permission under this Public License.
|
||||
|
||||
b. To the extent possible, if any provision of this Public License is
|
||||
deemed unenforceable, it shall be automatically reformed to the
|
||||
minimum extent necessary to make it enforceable. If the provision
|
||||
cannot be reformed, it shall be severed from this Public License
|
||||
without affecting the enforceability of the remaining terms and
|
||||
conditions.
|
||||
|
||||
c. No term or condition of this Public License will be waived and no
|
||||
failure to comply consented to unless expressly agreed to by the
|
||||
Licensor.
|
||||
|
||||
d. Nothing in this Public License constitutes or may be interpreted
|
||||
as a limitation upon, or waiver of, any privileges and immunities
|
||||
that apply to the Licensor or You, including from the legal
|
||||
processes of any jurisdiction or authority.
|
||||
|
||||
=======================================================================
|
||||
|
||||
Creative Commons is not a party to its public
|
||||
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
||||
its public licenses to material it publishes and in those instances
|
||||
will be considered the “Licensor.” The text of the Creative Commons
|
||||
public licenses is dedicated to the public domain under the CC0 Public
|
||||
Domain Dedication. Except for the limited purpose of indicating that
|
||||
material is shared under a Creative Commons public license or as
|
||||
otherwise permitted by the Creative Commons policies published at
|
||||
creativecommons.org/policies, Creative Commons does not authorize the
|
||||
use of the trademark "Creative Commons" or any other trademark or logo
|
||||
of Creative Commons without its prior written consent including,
|
||||
without limitation, in connection with any unauthorized modifications
|
||||
to any of its public licenses or any other arrangements,
|
||||
understandings, or agreements concerning use of licensed material. For
|
||||
the avoidance of doubt, this paragraph does not form part of the
|
||||
public licenses.
|
||||
|
||||
Creative Commons may be contacted at creativecommons.org.
|
94
README.md
Normal file
94
README.md
Normal file
@ -0,0 +1,94 @@
|
||||
# CoTracker: It is Better to Track Together
|
||||
|
||||
**[Meta AI Research, FAIR](https://ai.facebook.com/research/)**; **[University of Oxford, VGG](https://www.robots.ox.ac.uk/~vgg/)**
|
||||
|
||||
[Nikita Karaev](https://nikitakaraevv.github.io/), [Ignacio Rocco](https://www.irocco.info/), [Benjamin Graham](https://ai.facebook.com/people/benjamin-graham/), [Natalia Neverova](https://nneverova.github.io/), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/), [Christian Rupprecht](https://chrirupp.github.io/)
|
||||
|
||||
[[`Paper`]()] [[`Project`](https://co-tracker.github.io/)] [[`BibTeX`](#citing-cotracker)]
|
||||
|
||||

|
||||
|
||||
**CoTracker** is a fast transformer-based model that can track any point in a video. It brings to tracking some of the benefits of Optical Flow.
|
||||
|
||||
CoTracker can track:
|
||||
- **Every pixel** within a video
|
||||
- Points sampled on a regular grid on any video frame
|
||||
- Manually selected points
|
||||
|
||||
Try these tracking modes for yourself with our [Colab demo](https://github.com/facebookresearch/co-tracker/notebooks/demo.ipynb).
|
||||
|
||||
|
||||
|
||||
## Installation Instructions
|
||||
Ensure you have both PyTorch and TorchVision installed on your system. Follow the instructions [here](https://pytorch.org/get-started/locally/) for the installation. We strongly recommend installing both PyTorch and TorchVision with CUDA support.
|
||||
|
||||
## Steps to Install CoTracker and its dependencies:
|
||||
```
|
||||
git clone https://github.com/facebookresearch/co-tracker
|
||||
cd co-tracker
|
||||
pip install -e .
|
||||
pip install opencv-python einops timm matplotlib moviepy flow_vis
|
||||
```
|
||||
|
||||
|
||||
## Model Weights Download:
|
||||
```
|
||||
mkdir checkpoints
|
||||
cd checkpoints
|
||||
wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth
|
||||
wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_12.pth
|
||||
wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_8_wind_16.pth
|
||||
cd ..
|
||||
```
|
||||
|
||||
|
||||
## Running the Demo:
|
||||
Try our [Colab demo](https://github.com/facebookresearch/co-tracker/notebooks/demo.ipynb) or run a local demo with 10*10 points sampled on a grid on the first frame of a video:
|
||||
```
|
||||
python demo.py --grid_size 10
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
To reproduce the results presented in the paper, download the following datasets:
|
||||
- [TAP-Vid](https://github.com/deepmind/tapnet)
|
||||
- [BADJA](https://github.com/benjiebob/BADJA)
|
||||
- [ZJU-Mocap (FastCapture)](https://arxiv.org/abs/2303.11898)
|
||||
|
||||
And install the necessary dependencies:
|
||||
```
|
||||
pip install hydra-core==1.1.0 mediapy tensorboard
|
||||
```
|
||||
Then, execute the following command to evaluate on BADJA:
|
||||
```
|
||||
python ./cotracker/evaluation/evaluate.py --config-name eval_badja exp_dir=./eval_outputs dataset_root=your/badja/path
|
||||
```
|
||||
|
||||
## Training
|
||||
To train the CoTracker as described in our paper, you first need to generate annotations for [Google Kubric](https://github.com/google-research/kubric) MOVI-f dataset. Instructions for annotation generation can be found [here](https://github.com/deepmind/tapnet).
|
||||
|
||||
Once you have the annotated dataset, you need to make sure you followed the steps for evaluation setup and install the training dependencies:
|
||||
```
|
||||
pip install pytorch_lightning==1.6.0
|
||||
```
|
||||
launch training on Kubric. Our model was trained using 32 GPUs, and you can adjust the parameters to best suit your hardware setup.
|
||||
```
|
||||
python train.py --batch_size 1 --num_workers 28 \
|
||||
--num_steps 50000 --ckpt_path ./ --model_name cotracker \
|
||||
--save_freq 200 --sequence_len 24 --eval_datasets tapvid_davis_first badja \
|
||||
--traj_per_sample 256 --sliding_window_len 8 --updateformer_space_depth 6 --updateformer_time_depth 6 \
|
||||
--save_every_n_epoch 10 --evaluate_every_n_epoch 10 --model_stride 4
|
||||
```
|
||||
|
||||
## License
|
||||
The majority of CoTracker is licensed under CC-BY-NC, however portions of the project are available under separate license terms: Particle Video Revisited is licensed under the MIT license, TAP-Vid is licensed under the Apache 2.0 license.
|
||||
|
||||
## Citing CoTracker
|
||||
If you find our repository useful, please consider giving it a star ⭐ and citing our paper in your work:
|
||||
```
|
||||
@article{karaev2023cotracker,
|
||||
title={CoTracker: It is Better to Track Together},
|
||||
author={Nikita Karaev and Ignacio Rocco and Benjamin Graham and Natalia Neverova and Andrea Vedaldi and Christian Rupprecht},
|
||||
journal={arxiv},
|
||||
year={2023}
|
||||
}
|
||||
```
|
BIN
assets/apple.mp4
Normal file
BIN
assets/apple.mp4
Normal file
Binary file not shown.
BIN
assets/apple_mask.png
Normal file
BIN
assets/apple_mask.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 14 KiB |
BIN
assets/bmx-bumps.gif
Normal file
BIN
assets/bmx-bumps.gif
Normal file
Binary file not shown.
After Width: | Height: | Size: 7.3 MiB |
5
cotracker/__init__.py
Normal file
5
cotracker/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
5
cotracker/datasets/__init__.py
Normal file
5
cotracker/datasets/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
390
cotracker/datasets/badja_dataset.py
Normal file
390
cotracker/datasets/badja_dataset.py
Normal file
@ -0,0 +1,390 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
import json
|
||||
import imageio
|
||||
import cv2
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from cotracker.datasets.utils import CoTrackerData, resize_sample
|
||||
|
||||
IGNORE_ANIMALS = [
|
||||
# "bear.json",
|
||||
# "camel.json",
|
||||
"cat_jump.json"
|
||||
# "cows.json",
|
||||
# "dog.json",
|
||||
# "dog-agility.json",
|
||||
# "horsejump-high.json",
|
||||
# "horsejump-low.json",
|
||||
# "impala0.json",
|
||||
# "rs_dog.json"
|
||||
"tiger.json"
|
||||
]
|
||||
|
||||
|
||||
class SMALJointCatalog(Enum):
|
||||
# body_0 = 0
|
||||
# body_1 = 1
|
||||
# body_2 = 2
|
||||
# body_3 = 3
|
||||
# body_4 = 4
|
||||
# body_5 = 5
|
||||
# body_6 = 6
|
||||
# upper_right_0 = 7
|
||||
upper_right_1 = 8
|
||||
upper_right_2 = 9
|
||||
upper_right_3 = 10
|
||||
# upper_left_0 = 11
|
||||
upper_left_1 = 12
|
||||
upper_left_2 = 13
|
||||
upper_left_3 = 14
|
||||
neck_lower = 15
|
||||
# neck_upper = 16
|
||||
# lower_right_0 = 17
|
||||
lower_right_1 = 18
|
||||
lower_right_2 = 19
|
||||
lower_right_3 = 20
|
||||
# lower_left_0 = 21
|
||||
lower_left_1 = 22
|
||||
lower_left_2 = 23
|
||||
lower_left_3 = 24
|
||||
tail_0 = 25
|
||||
# tail_1 = 26
|
||||
# tail_2 = 27
|
||||
tail_3 = 28
|
||||
# tail_4 = 29
|
||||
# tail_5 = 30
|
||||
tail_6 = 31
|
||||
jaw = 32
|
||||
nose = 33 # ADDED JOINT FOR VERTEX 1863
|
||||
# chin = 34 # ADDED JOINT FOR VERTEX 26
|
||||
right_ear = 35 # ADDED JOINT FOR VERTEX 149
|
||||
left_ear = 36 # ADDED JOINT FOR VERTEX 2124
|
||||
|
||||
|
||||
class SMALJointInfo:
|
||||
def __init__(self):
|
||||
# These are the
|
||||
self.annotated_classes = np.array(
|
||||
[
|
||||
8,
|
||||
9,
|
||||
10, # upper_right
|
||||
12,
|
||||
13,
|
||||
14, # upper_left
|
||||
15, # neck
|
||||
18,
|
||||
19,
|
||||
20, # lower_right
|
||||
22,
|
||||
23,
|
||||
24, # lower_left
|
||||
25,
|
||||
28,
|
||||
31, # tail
|
||||
32,
|
||||
33, # head
|
||||
35, # right_ear
|
||||
36,
|
||||
]
|
||||
) # left_ear
|
||||
|
||||
self.annotated_markers = np.array(
|
||||
[
|
||||
cv2.MARKER_CROSS,
|
||||
cv2.MARKER_STAR,
|
||||
cv2.MARKER_TRIANGLE_DOWN,
|
||||
cv2.MARKER_CROSS,
|
||||
cv2.MARKER_STAR,
|
||||
cv2.MARKER_TRIANGLE_DOWN,
|
||||
cv2.MARKER_CROSS,
|
||||
cv2.MARKER_CROSS,
|
||||
cv2.MARKER_STAR,
|
||||
cv2.MARKER_TRIANGLE_DOWN,
|
||||
cv2.MARKER_CROSS,
|
||||
cv2.MARKER_STAR,
|
||||
cv2.MARKER_TRIANGLE_DOWN,
|
||||
cv2.MARKER_CROSS,
|
||||
cv2.MARKER_STAR,
|
||||
cv2.MARKER_TRIANGLE_DOWN,
|
||||
cv2.MARKER_CROSS,
|
||||
cv2.MARKER_STAR,
|
||||
cv2.MARKER_CROSS,
|
||||
cv2.MARKER_CROSS,
|
||||
]
|
||||
)
|
||||
|
||||
self.joint_regions = np.array(
|
||||
[
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
3,
|
||||
3,
|
||||
4,
|
||||
4,
|
||||
4,
|
||||
4,
|
||||
5,
|
||||
5,
|
||||
5,
|
||||
5,
|
||||
6,
|
||||
6,
|
||||
6,
|
||||
6,
|
||||
6,
|
||||
6,
|
||||
6,
|
||||
7,
|
||||
7,
|
||||
7,
|
||||
8,
|
||||
9,
|
||||
]
|
||||
)
|
||||
|
||||
self.annotated_joint_region = self.joint_regions[self.annotated_classes]
|
||||
self.region_colors = np.array(
|
||||
[
|
||||
[250, 190, 190], # body, light pink
|
||||
[60, 180, 75], # upper_right, green
|
||||
[230, 25, 75], # upper_left, red
|
||||
[128, 0, 0], # neck, maroon
|
||||
[0, 130, 200], # lower_right, blue
|
||||
[255, 255, 25], # lower_left, yellow
|
||||
[240, 50, 230], # tail, majenta
|
||||
[245, 130, 48], # jaw / nose / chin, orange
|
||||
[29, 98, 115], # right_ear, turquoise
|
||||
[255, 153, 204],
|
||||
]
|
||||
) # left_ear, pink
|
||||
|
||||
self.joint_colors = np.array(self.region_colors)[self.annotated_joint_region]
|
||||
|
||||
|
||||
class BADJAData:
|
||||
def __init__(self, data_root, complete=False):
|
||||
annotations_path = os.path.join(data_root, "joint_annotations")
|
||||
|
||||
self.animal_dict = {}
|
||||
self.animal_count = 0
|
||||
self.smal_joint_info = SMALJointInfo()
|
||||
for __, animal_json in enumerate(sorted(os.listdir(annotations_path))):
|
||||
if animal_json not in IGNORE_ANIMALS:
|
||||
json_path = os.path.join(annotations_path, animal_json)
|
||||
with open(json_path) as json_data:
|
||||
animal_joint_data = json.load(json_data)
|
||||
|
||||
filenames = []
|
||||
segnames = []
|
||||
joints = []
|
||||
visible = []
|
||||
|
||||
first_path = animal_joint_data[0]["segmentation_path"]
|
||||
last_path = animal_joint_data[-1]["segmentation_path"]
|
||||
first_frame = first_path.split("/")[-1]
|
||||
last_frame = last_path.split("/")[-1]
|
||||
|
||||
if not "extra_videos" in first_path:
|
||||
animal = first_path.split("/")[-2]
|
||||
|
||||
first_frame_int = int(first_frame.split(".")[0])
|
||||
last_frame_int = int(last_frame.split(".")[0])
|
||||
|
||||
for fr in range(first_frame_int, last_frame_int + 1):
|
||||
ref_file_name = os.path.join(
|
||||
data_root,
|
||||
"DAVIS/JPEGImages/Full-Resolution/%s/%05d.jpg"
|
||||
% (animal, fr),
|
||||
)
|
||||
ref_seg_name = os.path.join(
|
||||
data_root,
|
||||
"DAVIS/Annotations/Full-Resolution/%s/%05d.png"
|
||||
% (animal, fr),
|
||||
)
|
||||
|
||||
foundit = False
|
||||
for ind, image_annotation in enumerate(animal_joint_data):
|
||||
file_name = os.path.join(
|
||||
data_root, image_annotation["image_path"]
|
||||
)
|
||||
seg_name = os.path.join(
|
||||
data_root, image_annotation["segmentation_path"]
|
||||
)
|
||||
|
||||
if file_name == ref_file_name:
|
||||
foundit = True
|
||||
label_ind = ind
|
||||
|
||||
if foundit:
|
||||
image_annotation = animal_joint_data[label_ind]
|
||||
file_name = os.path.join(
|
||||
data_root, image_annotation["image_path"]
|
||||
)
|
||||
seg_name = os.path.join(
|
||||
data_root, image_annotation["segmentation_path"]
|
||||
)
|
||||
joint = np.array(image_annotation["joints"])
|
||||
vis = np.array(image_annotation["visibility"])
|
||||
else:
|
||||
file_name = ref_file_name
|
||||
seg_name = ref_seg_name
|
||||
joint = None
|
||||
vis = None
|
||||
|
||||
filenames.append(file_name)
|
||||
segnames.append(seg_name)
|
||||
joints.append(joint)
|
||||
visible.append(vis)
|
||||
|
||||
if len(filenames):
|
||||
self.animal_dict[self.animal_count] = (
|
||||
filenames,
|
||||
segnames,
|
||||
joints,
|
||||
visible,
|
||||
)
|
||||
self.animal_count += 1
|
||||
print("Loaded BADJA dataset")
|
||||
|
||||
def get_loader(self):
|
||||
for __ in range(int(1e6)):
|
||||
animal_id = np.random.choice(len(self.animal_dict.keys()))
|
||||
filenames, segnames, joints, visible = self.animal_dict[animal_id]
|
||||
|
||||
image_id = np.random.randint(0, len(filenames))
|
||||
|
||||
seg_file = segnames[image_id]
|
||||
image_file = filenames[image_id]
|
||||
|
||||
joints = joints[image_id].copy()
|
||||
joints = joints[self.smal_joint_info.annotated_classes]
|
||||
visible = visible[image_id][self.smal_joint_info.annotated_classes]
|
||||
|
||||
rgb_img = imageio.imread(image_file) # , mode='RGB')
|
||||
sil_img = imageio.imread(seg_file) # , mode='RGB')
|
||||
|
||||
rgb_h, rgb_w, _ = rgb_img.shape
|
||||
sil_img = cv2.resize(sil_img, (rgb_w, rgb_h), cv2.INTER_NEAREST)
|
||||
|
||||
yield rgb_img, sil_img, joints, visible, image_file
|
||||
|
||||
def get_video(self, animal_id):
|
||||
filenames, segnames, joint, visible = self.animal_dict[animal_id]
|
||||
|
||||
rgbs = []
|
||||
segs = []
|
||||
joints = []
|
||||
visibles = []
|
||||
|
||||
for s in range(len(filenames)):
|
||||
image_file = filenames[s]
|
||||
rgb_img = imageio.imread(image_file) # , mode='RGB')
|
||||
rgb_h, rgb_w, _ = rgb_img.shape
|
||||
|
||||
seg_file = segnames[s]
|
||||
sil_img = imageio.imread(seg_file) # , mode='RGB')
|
||||
sil_img = cv2.resize(sil_img, (rgb_w, rgb_h), cv2.INTER_NEAREST)
|
||||
|
||||
jo = joint[s]
|
||||
|
||||
if jo is not None:
|
||||
joi = joint[s].copy()
|
||||
joi = joi[self.smal_joint_info.annotated_classes]
|
||||
vis = visible[s][self.smal_joint_info.annotated_classes]
|
||||
else:
|
||||
joi = None
|
||||
vis = None
|
||||
|
||||
rgbs.append(rgb_img)
|
||||
segs.append(sil_img)
|
||||
joints.append(joi)
|
||||
visibles.append(vis)
|
||||
|
||||
return rgbs, segs, joints, visibles, filenames[0]
|
||||
|
||||
|
||||
class BadjaDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self, data_root, max_seq_len=1000, dataset_resolution=(384, 512)
|
||||
):
|
||||
|
||||
self.data_root = data_root
|
||||
self.badja_data = BADJAData(data_root)
|
||||
self.max_seq_len = max_seq_len
|
||||
self.dataset_resolution = dataset_resolution
|
||||
print(
|
||||
"found %d unique videos in %s"
|
||||
% (self.badja_data.animal_count, self.data_root)
|
||||
)
|
||||
|
||||
def __getitem__(self, index):
|
||||
|
||||
rgbs, segs, joints, visibles, filename = self.badja_data.get_video(index)
|
||||
S = len(rgbs)
|
||||
H, W, __ = rgbs[0].shape
|
||||
H, W, __ = segs[0].shape
|
||||
|
||||
N, __ = joints[0].shape
|
||||
|
||||
# let's eliminate the Nones
|
||||
# note the first one is guaranteed present
|
||||
for s in range(1, S):
|
||||
if joints[s] is None:
|
||||
joints[s] = np.zeros_like(joints[0])
|
||||
visibles[s] = np.zeros_like(visibles[0])
|
||||
|
||||
# eliminate the mystery dim
|
||||
segs = [seg[:, :, 0] for seg in segs]
|
||||
|
||||
rgbs = np.stack(rgbs, 0)
|
||||
segs = np.stack(segs, 0)
|
||||
trajs = np.stack(joints, 0)
|
||||
visibles = np.stack(visibles, 0)
|
||||
|
||||
rgbs = torch.from_numpy(rgbs).reshape(S, H, W, 3).permute(0, 3, 1, 2).float()
|
||||
segs = torch.from_numpy(segs).reshape(S, 1, H, W).float()
|
||||
trajs = torch.from_numpy(trajs).reshape(S, N, 2).float()
|
||||
visibles = torch.from_numpy(visibles).reshape(S, N)
|
||||
|
||||
rgbs = rgbs[: self.max_seq_len]
|
||||
segs = segs[: self.max_seq_len]
|
||||
trajs = trajs[: self.max_seq_len]
|
||||
visibles = visibles[: self.max_seq_len]
|
||||
# apparently the coords are in yx order
|
||||
trajs = torch.flip(trajs, [2])
|
||||
|
||||
if "extra_videos" in filename:
|
||||
seq_name = filename.split("/")[-3]
|
||||
else:
|
||||
seq_name = filename.split("/")[-2]
|
||||
|
||||
rgbs, trajs, segs = resize_sample(rgbs, trajs, segs, self.dataset_resolution)
|
||||
|
||||
return CoTrackerData(rgbs, segs, trajs, visibles, seq_name=seq_name)
|
||||
|
||||
def __len__(self):
|
||||
return self.badja_data.animal_count
|
72
cotracker/datasets/fast_capture_dataset.py
Normal file
72
cotracker/datasets/fast_capture_dataset.py
Normal file
@ -0,0 +1,72 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import torch
|
||||
|
||||
# from PIL import Image
|
||||
import imageio
|
||||
import numpy as np
|
||||
from cotracker.datasets.utils import CoTrackerData, resize_sample
|
||||
|
||||
|
||||
class FastCaptureDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
max_seq_len=50,
|
||||
max_num_points=20,
|
||||
dataset_resolution=(384, 512),
|
||||
):
|
||||
|
||||
self.data_root = data_root
|
||||
self.seq_names = os.listdir(os.path.join(data_root, "renders_local_rm"))
|
||||
self.pth_dir = os.path.join(data_root, "zju_tracking")
|
||||
self.max_seq_len = max_seq_len
|
||||
self.max_num_points = max_num_points
|
||||
self.dataset_resolution = dataset_resolution
|
||||
print("found %d unique videos in %s" % (len(self.seq_names), self.data_root))
|
||||
|
||||
def __getitem__(self, index):
|
||||
seq_name = self.seq_names[index]
|
||||
spath = os.path.join(self.data_root, "renders_local_rm", seq_name)
|
||||
pthpath = os.path.join(self.pth_dir, seq_name + ".pth")
|
||||
|
||||
rgbs = []
|
||||
img_paths = sorted(os.listdir(spath))
|
||||
for i, img_path in enumerate(img_paths):
|
||||
if i < self.max_seq_len:
|
||||
rgbs.append(imageio.imread(os.path.join(spath, img_path)))
|
||||
|
||||
annot_dict = torch.load(pthpath)
|
||||
traj_2d = annot_dict["traj_2d"][:, :, : self.max_seq_len]
|
||||
visibility = annot_dict["visibility"][:, : self.max_seq_len]
|
||||
|
||||
S = len(rgbs)
|
||||
H, W, __ = rgbs[0].shape
|
||||
*_, S = traj_2d.shape
|
||||
visibile_pts_first_frame_inds = (visibility[:, 0] > 0).nonzero(as_tuple=False)[
|
||||
:, 0
|
||||
]
|
||||
torch.manual_seed(0)
|
||||
point_inds = torch.randperm(len(visibile_pts_first_frame_inds))[
|
||||
: self.max_num_points
|
||||
]
|
||||
visible_inds_sampled = visibile_pts_first_frame_inds[point_inds]
|
||||
|
||||
rgbs = np.stack(rgbs, 0)
|
||||
rgbs = torch.from_numpy(rgbs).reshape(S, H, W, 3).permute(0, 3, 1, 2).float()
|
||||
|
||||
segs = torch.ones(S, 1, H, W).float()
|
||||
trajs = traj_2d[visible_inds_sampled].permute(2, 0, 1).float()
|
||||
visibles = visibility[visible_inds_sampled].permute(1, 0)
|
||||
|
||||
rgbs, trajs, segs = resize_sample(rgbs, trajs, segs, self.dataset_resolution)
|
||||
|
||||
return CoTrackerData(rgbs, segs, trajs, visibles, seq_name=seq_name)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.seq_names)
|
494
cotracker/datasets/kubric_movif_dataset.py
Normal file
494
cotracker/datasets/kubric_movif_dataset.py
Normal file
@ -0,0 +1,494 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import torch
|
||||
|
||||
import imageio
|
||||
import numpy as np
|
||||
|
||||
from cotracker.datasets.utils import CoTrackerData
|
||||
from torchvision.transforms import ColorJitter, GaussianBlur
|
||||
from PIL import Image
|
||||
import cv2
|
||||
|
||||
|
||||
class CoTrackerDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
crop_size=(384, 512),
|
||||
seq_len=24,
|
||||
traj_per_sample=768,
|
||||
sample_vis_1st_frame=False,
|
||||
use_augs=False,
|
||||
):
|
||||
super(CoTrackerDataset, self).__init__()
|
||||
np.random.seed(0)
|
||||
torch.manual_seed(0)
|
||||
self.data_root = data_root
|
||||
self.seq_len = seq_len
|
||||
self.traj_per_sample = traj_per_sample
|
||||
self.sample_vis_1st_frame = sample_vis_1st_frame
|
||||
self.use_augs = use_augs
|
||||
self.crop_size = crop_size
|
||||
|
||||
# photometric augmentation
|
||||
self.photo_aug = ColorJitter(
|
||||
brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14
|
||||
)
|
||||
self.blur_aug = GaussianBlur(11, sigma=(0.1, 2.0))
|
||||
|
||||
self.blur_aug_prob = 0.25
|
||||
self.color_aug_prob = 0.25
|
||||
|
||||
# occlusion augmentation
|
||||
self.eraser_aug_prob = 0.5
|
||||
self.eraser_bounds = [2, 100]
|
||||
self.eraser_max = 10
|
||||
|
||||
# occlusion augmentation
|
||||
self.replace_aug_prob = 0.5
|
||||
self.replace_bounds = [2, 100]
|
||||
self.replace_max = 10
|
||||
|
||||
# spatial augmentations
|
||||
self.pad_bounds = [0, 100]
|
||||
self.crop_size = crop_size
|
||||
self.resize_lim = [0.25, 2.0] # sample resizes from here
|
||||
self.resize_delta = 0.2
|
||||
self.max_crop_offset = 50
|
||||
|
||||
self.do_flip = True
|
||||
self.h_flip_prob = 0.5
|
||||
self.v_flip_prob = 0.5
|
||||
|
||||
def getitem_helper(self, index):
|
||||
return NotImplementedError
|
||||
|
||||
def __getitem__(self, index):
|
||||
gotit = False
|
||||
|
||||
sample, gotit = self.getitem_helper(index)
|
||||
if not gotit:
|
||||
print("warning: sampling failed")
|
||||
# fake sample, so we can still collate
|
||||
sample = CoTrackerData(
|
||||
video=torch.zeros(
|
||||
(self.seq_len, 3, self.crop_size[0], self.crop_size[1])
|
||||
),
|
||||
segmentation=torch.zeros(
|
||||
(self.seq_len, 1, self.crop_size[0], self.crop_size[1])
|
||||
),
|
||||
trajectory=torch.zeros((self.seq_len, self.traj_per_sample, 2)),
|
||||
visibility=torch.zeros((self.seq_len, self.traj_per_sample)),
|
||||
valid=torch.zeros((self.seq_len, self.traj_per_sample)),
|
||||
)
|
||||
|
||||
return sample, gotit
|
||||
|
||||
def add_photometric_augs(self, rgbs, trajs, visibles, eraser=True, replace=True):
|
||||
T, N, _ = trajs.shape
|
||||
|
||||
S = len(rgbs)
|
||||
H, W = rgbs[0].shape[:2]
|
||||
assert S == T
|
||||
|
||||
if eraser:
|
||||
############ eraser transform (per image after the first) ############
|
||||
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
|
||||
for i in range(1, S):
|
||||
if np.random.rand() < self.eraser_aug_prob:
|
||||
for _ in range(
|
||||
np.random.randint(1, self.eraser_max + 1)
|
||||
): # number of times to occlude
|
||||
|
||||
xc = np.random.randint(0, W)
|
||||
yc = np.random.randint(0, H)
|
||||
dx = np.random.randint(
|
||||
self.eraser_bounds[0], self.eraser_bounds[1]
|
||||
)
|
||||
dy = np.random.randint(
|
||||
self.eraser_bounds[0], self.eraser_bounds[1]
|
||||
)
|
||||
x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
|
||||
x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
|
||||
y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
|
||||
y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
|
||||
|
||||
mean_color = np.mean(
|
||||
rgbs[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0
|
||||
)
|
||||
rgbs[i][y0:y1, x0:x1, :] = mean_color
|
||||
|
||||
occ_inds = np.logical_and(
|
||||
np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
|
||||
np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
|
||||
)
|
||||
visibles[i, occ_inds] = 0
|
||||
rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
|
||||
|
||||
if replace:
|
||||
|
||||
rgbs_alt = [
|
||||
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
|
||||
for rgb in rgbs
|
||||
]
|
||||
rgbs_alt = [
|
||||
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
|
||||
for rgb in rgbs_alt
|
||||
]
|
||||
|
||||
############ replace transform (per image after the first) ############
|
||||
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
|
||||
rgbs_alt = [rgb.astype(np.float32) for rgb in rgbs_alt]
|
||||
for i in range(1, S):
|
||||
if np.random.rand() < self.replace_aug_prob:
|
||||
for _ in range(
|
||||
np.random.randint(1, self.replace_max + 1)
|
||||
): # number of times to occlude
|
||||
xc = np.random.randint(0, W)
|
||||
yc = np.random.randint(0, H)
|
||||
dx = np.random.randint(
|
||||
self.replace_bounds[0], self.replace_bounds[1]
|
||||
)
|
||||
dy = np.random.randint(
|
||||
self.replace_bounds[0], self.replace_bounds[1]
|
||||
)
|
||||
x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
|
||||
x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
|
||||
y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
|
||||
y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
|
||||
|
||||
wid = x1 - x0
|
||||
hei = y1 - y0
|
||||
y00 = np.random.randint(0, H - hei)
|
||||
x00 = np.random.randint(0, W - wid)
|
||||
fr = np.random.randint(0, S)
|
||||
rep = rgbs_alt[fr][y00 : y00 + hei, x00 : x00 + wid, :]
|
||||
rgbs[i][y0:y1, x0:x1, :] = rep
|
||||
|
||||
occ_inds = np.logical_and(
|
||||
np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
|
||||
np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
|
||||
)
|
||||
visibles[i, occ_inds] = 0
|
||||
rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
|
||||
|
||||
############ photometric augmentation ############
|
||||
if np.random.rand() < self.color_aug_prob:
|
||||
# random per-frame amount of aug
|
||||
rgbs = [
|
||||
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
|
||||
for rgb in rgbs
|
||||
]
|
||||
|
||||
if np.random.rand() < self.blur_aug_prob:
|
||||
# random per-frame amount of blur
|
||||
rgbs = [
|
||||
np.array(self.blur_aug(Image.fromarray(rgb)), dtype=np.uint8)
|
||||
for rgb in rgbs
|
||||
]
|
||||
|
||||
return rgbs, trajs, visibles
|
||||
|
||||
def add_spatial_augs(self, rgbs, trajs, visibles):
|
||||
T, N, __ = trajs.shape
|
||||
|
||||
S = len(rgbs)
|
||||
H, W = rgbs[0].shape[:2]
|
||||
assert S == T
|
||||
|
||||
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
|
||||
|
||||
############ spatial transform ############
|
||||
|
||||
# padding
|
||||
pad_x0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
||||
pad_x1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
||||
pad_y0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
||||
pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
||||
|
||||
rgbs = [
|
||||
np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs
|
||||
]
|
||||
trajs[:, :, 0] += pad_x0
|
||||
trajs[:, :, 1] += pad_y0
|
||||
H, W = rgbs[0].shape[:2]
|
||||
|
||||
# scaling + stretching
|
||||
scale = np.random.uniform(self.resize_lim[0], self.resize_lim[1])
|
||||
scale_x = scale
|
||||
scale_y = scale
|
||||
H_new = H
|
||||
W_new = W
|
||||
|
||||
scale_delta_x = 0.0
|
||||
scale_delta_y = 0.0
|
||||
|
||||
rgbs_scaled = []
|
||||
for s in range(S):
|
||||
if s == 1:
|
||||
scale_delta_x = np.random.uniform(-self.resize_delta, self.resize_delta)
|
||||
scale_delta_y = np.random.uniform(-self.resize_delta, self.resize_delta)
|
||||
elif s > 1:
|
||||
scale_delta_x = (
|
||||
scale_delta_x * 0.8
|
||||
+ np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
|
||||
)
|
||||
scale_delta_y = (
|
||||
scale_delta_y * 0.8
|
||||
+ np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
|
||||
)
|
||||
scale_x = scale_x + scale_delta_x
|
||||
scale_y = scale_y + scale_delta_y
|
||||
|
||||
# bring h/w closer
|
||||
scale_xy = (scale_x + scale_y) * 0.5
|
||||
scale_x = scale_x * 0.5 + scale_xy * 0.5
|
||||
scale_y = scale_y * 0.5 + scale_xy * 0.5
|
||||
|
||||
# don't get too crazy
|
||||
scale_x = np.clip(scale_x, 0.2, 2.0)
|
||||
scale_y = np.clip(scale_y, 0.2, 2.0)
|
||||
|
||||
H_new = int(H * scale_y)
|
||||
W_new = int(W * scale_x)
|
||||
|
||||
# make it at least slightly bigger than the crop area,
|
||||
# so that the random cropping can add diversity
|
||||
H_new = np.clip(H_new, self.crop_size[0] + 10, None)
|
||||
W_new = np.clip(W_new, self.crop_size[1] + 10, None)
|
||||
# recompute scale in case we clipped
|
||||
scale_x = W_new / float(W)
|
||||
scale_y = H_new / float(H)
|
||||
|
||||
rgbs_scaled.append(
|
||||
cv2.resize(rgbs[s], (W_new, H_new), interpolation=cv2.INTER_LINEAR)
|
||||
)
|
||||
trajs[s, :, 0] *= scale_x
|
||||
trajs[s, :, 1] *= scale_y
|
||||
rgbs = rgbs_scaled
|
||||
|
||||
ok_inds = visibles[0, :] > 0
|
||||
vis_trajs = trajs[:, ok_inds] # S,?,2
|
||||
|
||||
if vis_trajs.shape[1] > 0:
|
||||
mid_x = np.mean(vis_trajs[0, :, 0])
|
||||
mid_y = np.mean(vis_trajs[0, :, 1])
|
||||
else:
|
||||
mid_y = self.crop_size[0]
|
||||
mid_x = self.crop_size[1]
|
||||
|
||||
x0 = int(mid_x - self.crop_size[1] // 2)
|
||||
y0 = int(mid_y - self.crop_size[0] // 2)
|
||||
|
||||
offset_x = 0
|
||||
offset_y = 0
|
||||
|
||||
for s in range(S):
|
||||
# on each frame, shift a bit more
|
||||
if s == 1:
|
||||
offset_x = np.random.randint(
|
||||
-self.max_crop_offset, self.max_crop_offset
|
||||
)
|
||||
offset_y = np.random.randint(
|
||||
-self.max_crop_offset, self.max_crop_offset
|
||||
)
|
||||
elif s > 1:
|
||||
offset_x = int(
|
||||
offset_x * 0.8
|
||||
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1)
|
||||
* 0.2
|
||||
)
|
||||
offset_y = int(
|
||||
offset_y * 0.8
|
||||
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1)
|
||||
* 0.2
|
||||
)
|
||||
x0 = x0 + offset_x
|
||||
y0 = y0 + offset_y
|
||||
|
||||
H_new, W_new = rgbs[s].shape[:2]
|
||||
if H_new == self.crop_size[0]:
|
||||
y0 = 0
|
||||
else:
|
||||
y0 = min(max(0, y0), H_new - self.crop_size[0] - 1)
|
||||
|
||||
if W_new == self.crop_size[1]:
|
||||
x0 = 0
|
||||
else:
|
||||
x0 = min(max(0, x0), W_new - self.crop_size[1] - 1)
|
||||
|
||||
rgbs[s] = rgbs[s][y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
|
||||
trajs[s, :, 0] -= x0
|
||||
trajs[s, :, 1] -= y0
|
||||
|
||||
H_new = self.crop_size[0]
|
||||
W_new = self.crop_size[1]
|
||||
|
||||
# flip
|
||||
h_flipped = False
|
||||
v_flipped = False
|
||||
if self.do_flip:
|
||||
# h flip
|
||||
if np.random.rand() < self.h_flip_prob:
|
||||
h_flipped = True
|
||||
rgbs = [rgb[:, ::-1] for rgb in rgbs]
|
||||
# v flip
|
||||
if np.random.rand() < self.v_flip_prob:
|
||||
v_flipped = True
|
||||
rgbs = [rgb[::-1] for rgb in rgbs]
|
||||
if h_flipped:
|
||||
trajs[:, :, 0] = W_new - trajs[:, :, 0]
|
||||
if v_flipped:
|
||||
trajs[:, :, 1] = H_new - trajs[:, :, 1]
|
||||
|
||||
return rgbs, trajs
|
||||
|
||||
def crop(self, rgbs, trajs):
|
||||
T, N, _ = trajs.shape
|
||||
|
||||
S = len(rgbs)
|
||||
H, W = rgbs[0].shape[:2]
|
||||
assert S == T
|
||||
|
||||
############ spatial transform ############
|
||||
|
||||
H_new = H
|
||||
W_new = W
|
||||
|
||||
# simple random crop
|
||||
y0 = (
|
||||
0
|
||||
if self.crop_size[0] >= H_new
|
||||
else np.random.randint(0, H_new - self.crop_size[0])
|
||||
)
|
||||
x0 = (
|
||||
0
|
||||
if self.crop_size[1] >= W_new
|
||||
else np.random.randint(0, W_new - self.crop_size[1])
|
||||
)
|
||||
rgbs = [
|
||||
rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
|
||||
for rgb in rgbs
|
||||
]
|
||||
|
||||
trajs[:, :, 0] -= x0
|
||||
trajs[:, :, 1] -= y0
|
||||
|
||||
return rgbs, trajs
|
||||
|
||||
|
||||
class KubricMovifDataset(CoTrackerDataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
crop_size=(384, 512),
|
||||
seq_len=24,
|
||||
traj_per_sample=768,
|
||||
sample_vis_1st_frame=False,
|
||||
use_augs=False,
|
||||
):
|
||||
super(KubricMovifDataset, self).__init__(
|
||||
data_root=data_root,
|
||||
crop_size=crop_size,
|
||||
seq_len=seq_len,
|
||||
traj_per_sample=traj_per_sample,
|
||||
sample_vis_1st_frame=sample_vis_1st_frame,
|
||||
use_augs=use_augs,
|
||||
)
|
||||
|
||||
self.pad_bounds = [0, 25]
|
||||
self.resize_lim = [0.75, 1.25] # sample resizes from here
|
||||
self.resize_delta = 0.05
|
||||
self.max_crop_offset = 15
|
||||
self.seq_names = [
|
||||
fname
|
||||
for fname in os.listdir(data_root)
|
||||
if os.path.isdir(os.path.join(data_root, fname))
|
||||
]
|
||||
print("found %d unique videos in %s" % (len(self.seq_names), self.data_root))
|
||||
|
||||
def getitem_helper(self, index):
|
||||
gotit = True
|
||||
seq_name = self.seq_names[index]
|
||||
|
||||
npy_path = os.path.join(self.data_root, seq_name, seq_name + ".npy")
|
||||
rgb_path = os.path.join(self.data_root, seq_name, "frames")
|
||||
|
||||
img_paths = sorted(os.listdir(rgb_path))
|
||||
rgbs = []
|
||||
for i, img_path in enumerate(img_paths):
|
||||
rgbs.append(imageio.v2.imread(os.path.join(rgb_path, img_path)))
|
||||
|
||||
rgbs = np.stack(rgbs)
|
||||
annot_dict = np.load(npy_path, allow_pickle=True).item()
|
||||
traj_2d = annot_dict["coords"]
|
||||
visibility = annot_dict["visibility"]
|
||||
|
||||
# random crop
|
||||
assert self.seq_len <= len(rgbs)
|
||||
if self.seq_len < len(rgbs):
|
||||
start_ind = np.random.choice(len(rgbs) - self.seq_len, 1)[0]
|
||||
|
||||
rgbs = rgbs[start_ind : start_ind + self.seq_len]
|
||||
traj_2d = traj_2d[:, start_ind : start_ind + self.seq_len]
|
||||
visibility = visibility[:, start_ind : start_ind + self.seq_len]
|
||||
|
||||
traj_2d = np.transpose(traj_2d, (1, 0, 2))
|
||||
visibility = np.transpose(np.logical_not(visibility), (1, 0))
|
||||
if self.use_augs:
|
||||
rgbs, traj_2d, visibility = self.add_photometric_augs(
|
||||
rgbs, traj_2d, visibility
|
||||
)
|
||||
rgbs, traj_2d = self.add_spatial_augs(rgbs, traj_2d, visibility)
|
||||
else:
|
||||
rgbs, traj_2d = self.crop(rgbs, traj_2d)
|
||||
|
||||
visibility[traj_2d[:, :, 0] > self.crop_size[1] - 1] = False
|
||||
visibility[traj_2d[:, :, 0] < 0] = False
|
||||
visibility[traj_2d[:, :, 1] > self.crop_size[0] - 1] = False
|
||||
visibility[traj_2d[:, :, 1] < 0] = False
|
||||
|
||||
visibility = torch.from_numpy(visibility)
|
||||
traj_2d = torch.from_numpy(traj_2d)
|
||||
|
||||
visibile_pts_first_frame_inds = (visibility[0]).nonzero(as_tuple=False)[:, 0]
|
||||
|
||||
if self.sample_vis_1st_frame:
|
||||
visibile_pts_inds = visibile_pts_first_frame_inds
|
||||
else:
|
||||
visibile_pts_mid_frame_inds = (visibility[self.seq_len // 2]).nonzero(
|
||||
as_tuple=False
|
||||
)[:, 0]
|
||||
visibile_pts_inds = torch.cat(
|
||||
(visibile_pts_first_frame_inds, visibile_pts_mid_frame_inds), dim=0
|
||||
)
|
||||
point_inds = torch.randperm(len(visibile_pts_inds))[: self.traj_per_sample]
|
||||
if len(point_inds) < self.traj_per_sample:
|
||||
gotit = False
|
||||
|
||||
visible_inds_sampled = visibile_pts_inds[point_inds]
|
||||
|
||||
trajs = traj_2d[:, visible_inds_sampled].float()
|
||||
visibles = visibility[:, visible_inds_sampled]
|
||||
valids = torch.ones((self.seq_len, self.traj_per_sample))
|
||||
|
||||
rgbs = torch.from_numpy(np.stack(rgbs)).permute(0, 3, 1, 2).float()
|
||||
segs = torch.ones((self.seq_len, 1, self.crop_size[0], self.crop_size[1]))
|
||||
sample = CoTrackerData(
|
||||
video=rgbs,
|
||||
segmentation=segs,
|
||||
trajectory=trajs,
|
||||
visibility=visibles,
|
||||
valid=valids,
|
||||
seq_name=seq_name,
|
||||
)
|
||||
return sample, gotit
|
||||
|
||||
def __len__(self):
|
||||
return len(self.seq_names)
|
218
cotracker/datasets/tap_vid_datasets.py
Normal file
218
cotracker/datasets/tap_vid_datasets.py
Normal file
@ -0,0 +1,218 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import io
|
||||
import glob
|
||||
import torch
|
||||
import pickle
|
||||
import numpy as np
|
||||
import mediapy as media
|
||||
|
||||
from PIL import Image
|
||||
from typing import Mapping, Tuple, Union
|
||||
|
||||
from cotracker.datasets.utils import CoTrackerData
|
||||
|
||||
DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]]
|
||||
|
||||
|
||||
def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
|
||||
"""Resize a video to output_size."""
|
||||
# If you have a GPU, consider replacing this with a GPU-enabled resize op,
|
||||
# such as a jitted jax.image.resize. It will make things faster.
|
||||
return media.resize_video(video, output_size)
|
||||
|
||||
|
||||
def sample_queries_first(
|
||||
target_occluded: np.ndarray,
|
||||
target_points: np.ndarray,
|
||||
frames: np.ndarray,
|
||||
) -> Mapping[str, np.ndarray]:
|
||||
"""Package a set of frames and tracks for use in TAPNet evaluations.
|
||||
Given a set of frames and tracks with no query points, use the first
|
||||
visible point in each track as the query.
|
||||
Args:
|
||||
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
|
||||
where True indicates occluded.
|
||||
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
|
||||
is [x,y] scaled between 0 and 1.
|
||||
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
|
||||
-1 and 1.
|
||||
Returns:
|
||||
A dict with the keys:
|
||||
video: Video tensor of shape [1, n_frames, height, width, 3]
|
||||
query_points: Query points of shape [1, n_queries, 3] where
|
||||
each point is [t, y, x] scaled to the range [-1, 1]
|
||||
target_points: Target points of shape [1, n_queries, n_frames, 2] where
|
||||
each point is [x, y] scaled to the range [-1, 1]
|
||||
"""
|
||||
valid = np.sum(~target_occluded, axis=1) > 0
|
||||
target_points = target_points[valid, :]
|
||||
target_occluded = target_occluded[valid, :]
|
||||
|
||||
query_points = []
|
||||
for i in range(target_points.shape[0]):
|
||||
index = np.where(target_occluded[i] == 0)[0][0]
|
||||
x, y = target_points[i, index, 0], target_points[i, index, 1]
|
||||
query_points.append(np.array([index, y, x])) # [t, y, x]
|
||||
query_points = np.stack(query_points, axis=0)
|
||||
|
||||
return {
|
||||
"video": frames[np.newaxis, ...],
|
||||
"query_points": query_points[np.newaxis, ...],
|
||||
"target_points": target_points[np.newaxis, ...],
|
||||
"occluded": target_occluded[np.newaxis, ...],
|
||||
}
|
||||
|
||||
|
||||
def sample_queries_strided(
|
||||
target_occluded: np.ndarray,
|
||||
target_points: np.ndarray,
|
||||
frames: np.ndarray,
|
||||
query_stride: int = 5,
|
||||
) -> Mapping[str, np.ndarray]:
|
||||
"""Package a set of frames and tracks for use in TAPNet evaluations.
|
||||
|
||||
Given a set of frames and tracks with no query points, sample queries
|
||||
strided every query_stride frames, ignoring points that are not visible
|
||||
at the selected frames.
|
||||
|
||||
Args:
|
||||
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
|
||||
where True indicates occluded.
|
||||
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
|
||||
is [x,y] scaled between 0 and 1.
|
||||
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
|
||||
-1 and 1.
|
||||
query_stride: When sampling query points, search for un-occluded points
|
||||
every query_stride frames and convert each one into a query.
|
||||
|
||||
Returns:
|
||||
A dict with the keys:
|
||||
video: Video tensor of shape [1, n_frames, height, width, 3]. The video
|
||||
has floats scaled to the range [-1, 1].
|
||||
query_points: Query points of shape [1, n_queries, 3] where
|
||||
each point is [t, y, x] scaled to the range [-1, 1].
|
||||
target_points: Target points of shape [1, n_queries, n_frames, 2] where
|
||||
each point is [x, y] scaled to the range [-1, 1].
|
||||
trackgroup: Index of the original track that each query point was
|
||||
sampled from. This is useful for visualization.
|
||||
"""
|
||||
tracks = []
|
||||
occs = []
|
||||
queries = []
|
||||
trackgroups = []
|
||||
total = 0
|
||||
trackgroup = np.arange(target_occluded.shape[0])
|
||||
for i in range(0, target_occluded.shape[1], query_stride):
|
||||
mask = target_occluded[:, i] == 0
|
||||
query = np.stack(
|
||||
[
|
||||
i * np.ones(target_occluded.shape[0:1]),
|
||||
target_points[:, i, 1],
|
||||
target_points[:, i, 0],
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
queries.append(query[mask])
|
||||
tracks.append(target_points[mask])
|
||||
occs.append(target_occluded[mask])
|
||||
trackgroups.append(trackgroup[mask])
|
||||
total += np.array(np.sum(target_occluded[:, i] == 0))
|
||||
|
||||
return {
|
||||
"video": frames[np.newaxis, ...],
|
||||
"query_points": np.concatenate(queries, axis=0)[np.newaxis, ...],
|
||||
"target_points": np.concatenate(tracks, axis=0)[np.newaxis, ...],
|
||||
"occluded": np.concatenate(occs, axis=0)[np.newaxis, ...],
|
||||
"trackgroup": np.concatenate(trackgroups, axis=0)[np.newaxis, ...],
|
||||
}
|
||||
|
||||
|
||||
class TapVidDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
dataset_type="davis",
|
||||
resize_to_256=True,
|
||||
queried_first=True,
|
||||
):
|
||||
self.dataset_type = dataset_type
|
||||
self.resize_to_256 = resize_to_256
|
||||
self.queried_first = queried_first
|
||||
if self.dataset_type == "kinetics":
|
||||
all_paths = glob.glob(os.path.join(data_root, "*_of_0010.pkl"))
|
||||
points_dataset = []
|
||||
for pickle_path in all_paths:
|
||||
with open(pickle_path, "rb") as f:
|
||||
data = pickle.load(f)
|
||||
points_dataset = points_dataset + data
|
||||
self.points_dataset = points_dataset
|
||||
else:
|
||||
with open(data_root, "rb") as f:
|
||||
self.points_dataset = pickle.load(f)
|
||||
if self.dataset_type == "davis":
|
||||
self.video_names = list(self.points_dataset.keys())
|
||||
print("found %d unique videos in %s" % (len(self.points_dataset), data_root))
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.dataset_type == "davis":
|
||||
video_name = self.video_names[index]
|
||||
else:
|
||||
video_name = index
|
||||
video = self.points_dataset[video_name]
|
||||
frames = video["video"]
|
||||
|
||||
if isinstance(frames[0], bytes):
|
||||
# TAP-Vid is stored and JPEG bytes rather than `np.ndarray`s.
|
||||
def decode(frame):
|
||||
byteio = io.BytesIO(frame)
|
||||
img = Image.open(byteio)
|
||||
return np.array(img)
|
||||
|
||||
frames = np.array([decode(frame) for frame in frames])
|
||||
|
||||
target_points = self.points_dataset[video_name]["points"]
|
||||
if self.resize_to_256:
|
||||
frames = resize_video(frames, [256, 256])
|
||||
target_points *= np.array([256, 256])
|
||||
else:
|
||||
target_points *= np.array([frames.shape[2], frames.shape[1]])
|
||||
|
||||
T, H, W, C = frames.shape
|
||||
N, T, D = target_points.shape
|
||||
|
||||
target_occ = self.points_dataset[video_name]["occluded"]
|
||||
if self.queried_first:
|
||||
converted = sample_queries_first(target_occ, target_points, frames)
|
||||
else:
|
||||
converted = sample_queries_strided(target_occ, target_points, frames)
|
||||
assert converted["target_points"].shape[1] == converted["query_points"].shape[1]
|
||||
|
||||
trajs = (
|
||||
torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float()
|
||||
) # T, N, D
|
||||
|
||||
rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float()
|
||||
segs = torch.ones(T, 1, H, W).float()
|
||||
visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[
|
||||
0
|
||||
].permute(
|
||||
1, 0
|
||||
) # T, N
|
||||
query_points = torch.from_numpy(converted["query_points"])[0] # T, N
|
||||
return CoTrackerData(
|
||||
rgbs,
|
||||
segs,
|
||||
trajs,
|
||||
visibles,
|
||||
seq_name=str(video_name),
|
||||
query_points=query_points,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.points_dataset)
|
114
cotracker/datasets/utils.py
Normal file
114
cotracker/datasets/utils.py
Normal file
@ -0,0 +1,114 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import torch
|
||||
import dataclasses
|
||||
import torch.nn.functional as F
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class CoTrackerData:
|
||||
"""
|
||||
Dataclass for storing video tracks data.
|
||||
"""
|
||||
|
||||
video: torch.Tensor # B, S, C, H, W
|
||||
segmentation: torch.Tensor # B, S, 1, H, W
|
||||
trajectory: torch.Tensor # B, S, N, 2
|
||||
visibility: torch.Tensor # B, S, N
|
||||
# optional data
|
||||
valid: Optional[torch.Tensor] = None # B, S, N
|
||||
seq_name: Optional[str] = None
|
||||
query_points: Optional[torch.Tensor] = None # TapVID evaluation format
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
"""
|
||||
Collate function for video tracks data.
|
||||
"""
|
||||
video = torch.stack([b.video for b in batch], dim=0)
|
||||
segmentation = torch.stack([b.segmentation for b in batch], dim=0)
|
||||
trajectory = torch.stack([b.trajectory for b in batch], dim=0)
|
||||
visibility = torch.stack([b.visibility for b in batch], dim=0)
|
||||
query_points = None
|
||||
if batch[0].query_points is not None:
|
||||
query_points = torch.stack([b.query_points for b in batch], dim=0)
|
||||
seq_name = [b.seq_name for b in batch]
|
||||
|
||||
return CoTrackerData(
|
||||
video,
|
||||
segmentation,
|
||||
trajectory,
|
||||
visibility,
|
||||
seq_name=seq_name,
|
||||
query_points=query_points,
|
||||
)
|
||||
|
||||
|
||||
def collate_fn_train(batch):
|
||||
"""
|
||||
Collate function for video tracks data during training.
|
||||
"""
|
||||
gotit = [gotit for _, gotit in batch]
|
||||
video = torch.stack([b.video for b, _ in batch], dim=0)
|
||||
segmentation = torch.stack([b.segmentation for b, _ in batch], dim=0)
|
||||
trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0)
|
||||
visibility = torch.stack([b.visibility for b, _ in batch], dim=0)
|
||||
valid = torch.stack([b.valid for b, _ in batch], dim=0)
|
||||
seq_name = [b.seq_name for b, _ in batch]
|
||||
return (
|
||||
CoTrackerData(video, segmentation, trajectory, visibility, valid, seq_name),
|
||||
gotit,
|
||||
)
|
||||
|
||||
|
||||
def try_to_cuda(t: Any) -> Any:
|
||||
"""
|
||||
Try to move the input variable `t` to a cuda device.
|
||||
|
||||
Args:
|
||||
t: Input.
|
||||
|
||||
Returns:
|
||||
t_cuda: `t` moved to a cuda device, if supported.
|
||||
"""
|
||||
try:
|
||||
t = t.float().cuda()
|
||||
except AttributeError:
|
||||
pass
|
||||
return t
|
||||
|
||||
|
||||
def dataclass_to_cuda_(obj):
|
||||
"""
|
||||
Move all contents of a dataclass to cuda inplace if supported.
|
||||
|
||||
Args:
|
||||
batch: Input dataclass.
|
||||
|
||||
Returns:
|
||||
batch_cuda: `batch` moved to a cuda device, if supported.
|
||||
"""
|
||||
for f in dataclasses.fields(obj):
|
||||
setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
|
||||
return obj
|
||||
|
||||
|
||||
def resize_sample(rgbs, trajs_g, segs, interp_shape):
|
||||
S, C, H, W = rgbs.shape
|
||||
S, N, D = trajs_g.shape
|
||||
|
||||
assert D == 2
|
||||
|
||||
rgbs = F.interpolate(rgbs, interp_shape, mode="bilinear")
|
||||
segs = F.interpolate(segs, interp_shape, mode="nearest")
|
||||
|
||||
trajs_g[:, :, 0] *= interp_shape[1] / W
|
||||
trajs_g[:, :, 1] *= interp_shape[0] / H
|
||||
return rgbs, trajs_g, segs
|
5
cotracker/evaluation/__init__.py
Normal file
5
cotracker/evaluation/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
6
cotracker/evaluation/configs/eval_badja.yaml
Normal file
6
cotracker/evaluation/configs/eval_badja.yaml
Normal file
@ -0,0 +1,6 @@
|
||||
defaults:
|
||||
- default_config_eval
|
||||
exp_dir: ./outputs/cotracker
|
||||
dataset_name: badja
|
||||
|
||||
|
6
cotracker/evaluation/configs/eval_fastcapture.yaml
Normal file
6
cotracker/evaluation/configs/eval_fastcapture.yaml
Normal file
@ -0,0 +1,6 @@
|
||||
defaults:
|
||||
- default_config_eval
|
||||
exp_dir: ./outputs/cotracker
|
||||
dataset_name: fastcapture
|
||||
|
||||
|
@ -0,0 +1,6 @@
|
||||
defaults:
|
||||
- default_config_eval
|
||||
exp_dir: ./outputs/cotracker
|
||||
dataset_name: tapvid_davis_first
|
||||
|
||||
|
@ -0,0 +1,6 @@
|
||||
defaults:
|
||||
- default_config_eval
|
||||
exp_dir: ./outputs/cotracker
|
||||
dataset_name: tapvid_davis_strided
|
||||
|
||||
|
@ -0,0 +1,6 @@
|
||||
defaults:
|
||||
- default_config_eval
|
||||
exp_dir: ./outputs/cotracker
|
||||
dataset_name: tapvid_kinetics_first
|
||||
|
||||
|
5
cotracker/evaluation/core/__init__.py
Normal file
5
cotracker/evaluation/core/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
144
cotracker/evaluation/core/eval_utils.py
Normal file
144
cotracker/evaluation/core/eval_utils.py
Normal file
@ -0,0 +1,144 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from typing import Iterable, Mapping, Tuple, Union
|
||||
|
||||
|
||||
def compute_tapvid_metrics(
|
||||
query_points: np.ndarray,
|
||||
gt_occluded: np.ndarray,
|
||||
gt_tracks: np.ndarray,
|
||||
pred_occluded: np.ndarray,
|
||||
pred_tracks: np.ndarray,
|
||||
query_mode: str,
|
||||
) -> Mapping[str, np.ndarray]:
|
||||
"""Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.)
|
||||
See the TAP-Vid paper for details on the metric computation. All inputs are
|
||||
given in raster coordinates. The first three arguments should be the direct
|
||||
outputs of the reader: the 'query_points', 'occluded', and 'target_points'.
|
||||
The paper metrics assume these are scaled relative to 256x256 images.
|
||||
pred_occluded and pred_tracks are your algorithm's predictions.
|
||||
This function takes a batch of inputs, and computes metrics separately for
|
||||
each video. The metrics for the full benchmark are a simple mean of the
|
||||
metrics across the full set of videos. These numbers are between 0 and 1,
|
||||
but the paper multiplies them by 100 to ease reading.
|
||||
Args:
|
||||
query_points: The query points, an in the format [t, y, x]. Its size is
|
||||
[b, n, 3], where b is the batch size and n is the number of queries
|
||||
gt_occluded: A boolean array of shape [b, n, t], where t is the number
|
||||
of frames. True indicates that the point is occluded.
|
||||
gt_tracks: The target points, of shape [b, n, t, 2]. Each point is
|
||||
in the format [x, y]
|
||||
pred_occluded: A boolean array of predicted occlusions, in the same
|
||||
format as gt_occluded.
|
||||
pred_tracks: An array of track predictions from your algorithm, in the
|
||||
same format as gt_tracks.
|
||||
query_mode: Either 'first' or 'strided', depending on how queries are
|
||||
sampled. If 'first', we assume the prior knowledge that all points
|
||||
before the query point are occluded, and these are removed from the
|
||||
evaluation.
|
||||
Returns:
|
||||
A dict with the following keys:
|
||||
occlusion_accuracy: Accuracy at predicting occlusion.
|
||||
pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points
|
||||
predicted to be within the given pixel threshold, ignoring occlusion
|
||||
prediction.
|
||||
jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given
|
||||
threshold
|
||||
average_pts_within_thresh: average across pts_within_{x}
|
||||
average_jaccard: average across jaccard_{x}
|
||||
"""
|
||||
|
||||
metrics = {}
|
||||
|
||||
# Don't evaluate the query point. Numpy doesn't have one_hot, so we
|
||||
# replicate it by indexing into an identity matrix.
|
||||
one_hot_eye = np.eye(gt_tracks.shape[2])
|
||||
query_frame = query_points[..., 0]
|
||||
query_frame = np.round(query_frame).astype(np.int32)
|
||||
evaluation_points = one_hot_eye[query_frame] == 0
|
||||
|
||||
# If we're using the first point on the track as a query, don't evaluate the
|
||||
# other points.
|
||||
if query_mode == "first":
|
||||
for i in range(gt_occluded.shape[0]):
|
||||
index = np.where(gt_occluded[i] == 0)[0][0]
|
||||
evaluation_points[i, :index] = False
|
||||
elif query_mode != "strided":
|
||||
raise ValueError("Unknown query mode " + query_mode)
|
||||
|
||||
# Occlusion accuracy is simply how often the predicted occlusion equals the
|
||||
# ground truth.
|
||||
occ_acc = (
|
||||
np.sum(
|
||||
np.equal(pred_occluded, gt_occluded) & evaluation_points,
|
||||
axis=(1, 2),
|
||||
)
|
||||
/ np.sum(evaluation_points)
|
||||
)
|
||||
metrics["occlusion_accuracy"] = occ_acc
|
||||
|
||||
# Next, convert the predictions and ground truth positions into pixel
|
||||
# coordinates.
|
||||
visible = np.logical_not(gt_occluded)
|
||||
pred_visible = np.logical_not(pred_occluded)
|
||||
all_frac_within = []
|
||||
all_jaccard = []
|
||||
for thresh in [1, 2, 4, 8, 16]:
|
||||
# True positives are points that are within the threshold and where both
|
||||
# the prediction and the ground truth are listed as visible.
|
||||
within_dist = (
|
||||
np.sum(
|
||||
np.square(pred_tracks - gt_tracks),
|
||||
axis=-1,
|
||||
)
|
||||
< np.square(thresh)
|
||||
)
|
||||
is_correct = np.logical_and(within_dist, visible)
|
||||
|
||||
# Compute the frac_within_threshold, which is the fraction of points
|
||||
# within the threshold among points that are visible in the ground truth,
|
||||
# ignoring whether they're predicted to be visible.
|
||||
count_correct = np.sum(
|
||||
is_correct & evaluation_points,
|
||||
axis=(1, 2),
|
||||
)
|
||||
count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2))
|
||||
frac_correct = count_correct / count_visible_points
|
||||
metrics["pts_within_" + str(thresh)] = frac_correct
|
||||
all_frac_within.append(frac_correct)
|
||||
|
||||
true_positives = np.sum(
|
||||
is_correct & pred_visible & evaluation_points, axis=(1, 2)
|
||||
)
|
||||
|
||||
# The denominator of the jaccard metric is the true positives plus
|
||||
# false positives plus false negatives. However, note that true positives
|
||||
# plus false negatives is simply the number of points in the ground truth
|
||||
# which is easier to compute than trying to compute all three quantities.
|
||||
# Thus we just add the number of points in the ground truth to the number
|
||||
# of false positives.
|
||||
#
|
||||
# False positives are simply points that are predicted to be visible,
|
||||
# but the ground truth is not visible or too far from the prediction.
|
||||
gt_positives = np.sum(visible & evaluation_points, axis=(1, 2))
|
||||
false_positives = (~visible) & pred_visible
|
||||
false_positives = false_positives | ((~within_dist) & pred_visible)
|
||||
false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2))
|
||||
jaccard = true_positives / (gt_positives + false_positives)
|
||||
metrics["jaccard_" + str(thresh)] = jaccard
|
||||
all_jaccard.append(jaccard)
|
||||
metrics["average_jaccard"] = np.mean(
|
||||
np.stack(all_jaccard, axis=1),
|
||||
axis=1,
|
||||
)
|
||||
metrics["average_pts_within_thresh"] = np.mean(
|
||||
np.stack(all_frac_within, axis=1),
|
||||
axis=1,
|
||||
)
|
||||
return metrics
|
252
cotracker/evaluation/core/evaluator.py
Normal file
252
cotracker/evaluation/core/evaluator.py
Normal file
@ -0,0 +1,252 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from collections import defaultdict
|
||||
import os
|
||||
from typing import Optional
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from cotracker.datasets.utils import dataclass_to_cuda_
|
||||
from cotracker.utils.visualizer import Visualizer
|
||||
from cotracker.models.core.model_utils import reduce_masked_mean
|
||||
from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
class Evaluator:
|
||||
"""
|
||||
A class defining the CoTracker evaluator.
|
||||
"""
|
||||
|
||||
def __init__(self, exp_dir) -> None:
|
||||
# Visualization
|
||||
self.exp_dir = exp_dir
|
||||
os.makedirs(exp_dir, exist_ok=True)
|
||||
self.visualization_filepaths = defaultdict(lambda: defaultdict(list))
|
||||
self.visualize_dir = os.path.join(exp_dir, "visualisations")
|
||||
|
||||
def compute_metrics(self, metrics, sample, pred_trajectory, dataset_name):
|
||||
if isinstance(pred_trajectory, tuple):
|
||||
pred_trajectory, pred_visibility = pred_trajectory
|
||||
else:
|
||||
pred_visibility = None
|
||||
if dataset_name == "badja":
|
||||
sample.segmentation = (sample.segmentation > 0).float()
|
||||
*_, N, _ = sample.trajectory.shape
|
||||
accs = []
|
||||
accs_3px = []
|
||||
for s1 in range(1, sample.video.shape[1]): # target frame
|
||||
for n in range(N):
|
||||
vis = sample.visibility[0, s1, n]
|
||||
if vis > 0:
|
||||
coord_e = pred_trajectory[0, s1, n] # 2
|
||||
coord_g = sample.trajectory[0, s1, n] # 2
|
||||
dist = torch.sqrt(torch.sum((coord_e - coord_g) ** 2, dim=0))
|
||||
area = torch.sum(sample.segmentation[0, s1])
|
||||
# print_('0.2*sqrt(area)', 0.2*torch.sqrt(area))
|
||||
thr = 0.2 * torch.sqrt(area)
|
||||
# correct =
|
||||
accs.append((dist < thr).float())
|
||||
# print('thr',thr)
|
||||
accs_3px.append((dist < 3.0).float())
|
||||
|
||||
res = torch.mean(torch.stack(accs)) * 100.0
|
||||
res_3px = torch.mean(torch.stack(accs_3px)) * 100.0
|
||||
metrics[sample.seq_name[0]] = res.item()
|
||||
metrics[sample.seq_name[0] + "_accuracy"] = res_3px.item()
|
||||
print(metrics)
|
||||
print(
|
||||
"avg", np.mean([v for k, v in metrics.items() if "accuracy" not in k])
|
||||
)
|
||||
print(
|
||||
"avg acc 3px",
|
||||
np.mean([v for k, v in metrics.items() if "accuracy" in k]),
|
||||
)
|
||||
elif dataset_name == "fastcapture" or ("kubric" in dataset_name):
|
||||
*_, N, _ = sample.trajectory.shape
|
||||
accs = []
|
||||
for s1 in range(1, sample.video.shape[1]): # target frame
|
||||
for n in range(N):
|
||||
vis = sample.visibility[0, s1, n]
|
||||
if vis > 0:
|
||||
coord_e = pred_trajectory[0, s1, n] # 2
|
||||
coord_g = sample.trajectory[0, s1, n] # 2
|
||||
dist = torch.sqrt(torch.sum((coord_e - coord_g) ** 2, dim=0))
|
||||
thr = 3
|
||||
correct = (dist < thr).float()
|
||||
accs.append(correct)
|
||||
|
||||
res = torch.mean(torch.stack(accs)) * 100.0
|
||||
metrics[sample.seq_name[0] + "_accuracy"] = res.item()
|
||||
print(metrics)
|
||||
print("avg", np.mean([v for v in metrics.values()]))
|
||||
elif "tapvid" in dataset_name:
|
||||
B, T, N, D = sample.trajectory.shape
|
||||
traj = sample.trajectory.clone()
|
||||
thr = 0.9
|
||||
|
||||
if pred_visibility is None:
|
||||
logging.warning("visibility is NONE")
|
||||
pred_visibility = torch.zeros_like(sample.visibility)
|
||||
|
||||
if not pred_visibility.dtype == torch.bool:
|
||||
pred_visibility = pred_visibility > thr
|
||||
|
||||
# pred_trajectory
|
||||
query_points = sample.query_points.clone().cpu().numpy()
|
||||
|
||||
pred_visibility = pred_visibility[:, :, :N]
|
||||
pred_trajectory = pred_trajectory[:, :, :N]
|
||||
|
||||
gt_tracks = traj.permute(0, 2, 1, 3).cpu().numpy()
|
||||
gt_occluded = (
|
||||
torch.logical_not(sample.visibility.clone().permute(0, 2, 1))
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
|
||||
pred_occluded = (
|
||||
torch.logical_not(pred_visibility.clone().permute(0, 2, 1))
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
pred_tracks = pred_trajectory.permute(0, 2, 1, 3).cpu().numpy()
|
||||
|
||||
out_metrics = compute_tapvid_metrics(
|
||||
query_points,
|
||||
gt_occluded,
|
||||
gt_tracks,
|
||||
pred_occluded,
|
||||
pred_tracks,
|
||||
query_mode="strided" if "strided" in dataset_name else "first",
|
||||
)
|
||||
|
||||
metrics[sample.seq_name[0]] = out_metrics
|
||||
for metric_name in out_metrics.keys():
|
||||
if "avg" not in metrics:
|
||||
metrics["avg"] = {}
|
||||
metrics["avg"][metric_name] = np.mean(
|
||||
[v[metric_name] for k, v in metrics.items() if k != "avg"]
|
||||
)
|
||||
|
||||
logging.info(f"Metrics: {out_metrics}")
|
||||
logging.info(f"avg: {metrics['avg']}")
|
||||
print("metrics", out_metrics)
|
||||
print("avg", metrics["avg"])
|
||||
else:
|
||||
rgbs = sample.video
|
||||
trajs_g = sample.trajectory
|
||||
valids = sample.valid
|
||||
vis_g = sample.visibility
|
||||
|
||||
B, S, C, H, W = rgbs.shape
|
||||
assert C == 3
|
||||
B, S, N, D = trajs_g.shape
|
||||
|
||||
assert torch.sum(valids) == B * S * N
|
||||
|
||||
vis_g = (torch.sum(vis_g, dim=1, keepdim=True) >= 4).float().repeat(1, S, 1)
|
||||
|
||||
ate = torch.norm(pred_trajectory - trajs_g, dim=-1) # B, S, N
|
||||
|
||||
metrics["things_all"] = reduce_masked_mean(ate, valids).item()
|
||||
metrics["things_vis"] = reduce_masked_mean(ate, valids * vis_g).item()
|
||||
metrics["things_occ"] = reduce_masked_mean(
|
||||
ate, valids * (1.0 - vis_g)
|
||||
).item()
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate_sequence(
|
||||
self,
|
||||
model,
|
||||
test_dataloader: torch.utils.data.DataLoader,
|
||||
dataset_name: str,
|
||||
train_mode=False,
|
||||
writer: Optional[SummaryWriter] = None,
|
||||
step: Optional[int] = 0,
|
||||
):
|
||||
metrics = {}
|
||||
|
||||
vis = Visualizer(
|
||||
save_dir=self.exp_dir,
|
||||
fps=7,
|
||||
)
|
||||
|
||||
for ind, sample in enumerate(tqdm(test_dataloader)):
|
||||
if isinstance(sample, tuple):
|
||||
sample, gotit = sample
|
||||
if not all(gotit):
|
||||
print("batch is None")
|
||||
continue
|
||||
dataclass_to_cuda_(sample)
|
||||
|
||||
if (
|
||||
not train_mode
|
||||
and hasattr(model, "sequence_len")
|
||||
and (sample.visibility[:, : model.sequence_len].sum() == 0)
|
||||
):
|
||||
print(f"skipping batch {ind}")
|
||||
continue
|
||||
|
||||
if "tapvid" in dataset_name:
|
||||
queries = sample.query_points.clone().float()
|
||||
|
||||
queries = torch.stack(
|
||||
[
|
||||
queries[:, :, 0],
|
||||
queries[:, :, 2],
|
||||
queries[:, :, 1],
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
else:
|
||||
queries = torch.cat(
|
||||
[
|
||||
torch.zeros_like(sample.trajectory[:, 0, :, :1]),
|
||||
sample.trajectory[:, 0],
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
|
||||
pred_tracks = model(sample.video, queries)
|
||||
if "strided" in dataset_name:
|
||||
|
||||
inv_video = sample.video.flip(1).clone()
|
||||
inv_queries = queries.clone()
|
||||
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
|
||||
|
||||
pred_trj, pred_vsb = pred_tracks
|
||||
inv_pred_trj, inv_pred_vsb = model(inv_video, inv_queries)
|
||||
|
||||
inv_pred_trj = inv_pred_trj.flip(1)
|
||||
inv_pred_vsb = inv_pred_vsb.flip(1)
|
||||
|
||||
mask = pred_trj == 0
|
||||
|
||||
pred_trj[mask] = inv_pred_trj[mask]
|
||||
pred_vsb[mask[:, :, :, 0]] = inv_pred_vsb[mask[:, :, :, 0]]
|
||||
|
||||
pred_tracks = pred_trj, pred_vsb
|
||||
|
||||
if dataset_name == "badja" or dataset_name == "fastcapture":
|
||||
seq_name = sample.seq_name[0]
|
||||
else:
|
||||
seq_name = str(ind)
|
||||
|
||||
vis.visualize(
|
||||
sample.video,
|
||||
pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks,
|
||||
filename=dataset_name + "_" + seq_name,
|
||||
writer=writer,
|
||||
step=step,
|
||||
)
|
||||
|
||||
self.compute_metrics(metrics, sample, pred_tracks, dataset_name)
|
||||
return metrics
|
179
cotracker/evaluation/evaluate.py
Normal file
179
cotracker/evaluation/evaluate.py
Normal file
@ -0,0 +1,179 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from cotracker.datasets.badja_dataset import BadjaDataset
|
||||
from cotracker.datasets.fast_capture_dataset import FastCaptureDataset
|
||||
from cotracker.datasets.tap_vid_datasets import TapVidDataset
|
||||
from cotracker.datasets.utils import collate_fn
|
||||
|
||||
from cotracker.models.evaluation_predictor import EvaluationPredictor
|
||||
|
||||
from cotracker.evaluation.core.evaluator import Evaluator
|
||||
from cotracker.models.build_cotracker import (
|
||||
build_cotracker,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class DefaultConfig:
|
||||
# Directory where all outputs of the experiment will be saved.
|
||||
exp_dir: str = "./outputs"
|
||||
|
||||
# Name of the dataset to be used for the evaluation.
|
||||
dataset_name: str = "badja"
|
||||
# The root directory of the dataset.
|
||||
dataset_root: str = "./"
|
||||
|
||||
# Path to the pre-trained model checkpoint to be used for the evaluation.
|
||||
# The default value is the path to a specific CoTracker model checkpoint.
|
||||
# Other available options are commented.
|
||||
checkpoint: str = "./checkpoints/cotracker_stride_4_wind_8.pth"
|
||||
# cotracker_stride_4_wind_12
|
||||
# cotracker_stride_8_wind_16
|
||||
|
||||
# EvaluationPredictor parameters
|
||||
# The size (N) of the support grid used in the predictor.
|
||||
# The total number of points is (N*N).
|
||||
grid_size: int = 6
|
||||
# The size (N) of the local support grid.
|
||||
local_grid_size: int = 6
|
||||
# A flag indicating whether to evaluate one ground truth point at a time.
|
||||
single_point: bool = True
|
||||
# The number of iterative updates for each sliding window.
|
||||
n_iters: int = 6
|
||||
|
||||
seed: int = 0
|
||||
gpu_idx: int = 0
|
||||
|
||||
# Override hydra's working directory to current working dir,
|
||||
# also disable storing the .hydra logs:
|
||||
hydra: dict = field(
|
||||
default_factory=lambda: {
|
||||
"run": {"dir": "."},
|
||||
"output_subdir": None,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def run_eval(cfg: DefaultConfig):
|
||||
"""
|
||||
The function evaluates CoTracker on a specified benchmark dataset based on a provided configuration.
|
||||
|
||||
Args:
|
||||
cfg (DefaultConfig): An instance of DefaultConfig class which includes:
|
||||
- exp_dir (str): The directory path for the experiment.
|
||||
- dataset_name (str): The name of the dataset to be used.
|
||||
- dataset_root (str): The root directory of the dataset.
|
||||
- checkpoint (str): The path to the CoTracker model's checkpoint.
|
||||
- single_point (bool): A flag indicating whether to evaluate one ground truth point at a time.
|
||||
- n_iters (int): The number of iterative updates for each sliding window.
|
||||
- seed (int): The seed for setting the random state for reproducibility.
|
||||
- gpu_idx (int): The index of the GPU to be used.
|
||||
"""
|
||||
# Creating the experiment directory if it doesn't exist
|
||||
os.makedirs(cfg.exp_dir, exist_ok=True)
|
||||
|
||||
# Saving the experiment configuration to a .yaml file in the experiment directory
|
||||
cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml")
|
||||
with open(cfg_file, "w") as f:
|
||||
OmegaConf.save(config=cfg, f=f)
|
||||
|
||||
evaluator = Evaluator(cfg.exp_dir)
|
||||
cotracker_model = build_cotracker(cfg.checkpoint)
|
||||
|
||||
# Creating the EvaluationPredictor object
|
||||
predictor = EvaluationPredictor(
|
||||
cotracker_model,
|
||||
grid_size=cfg.grid_size,
|
||||
local_grid_size=cfg.local_grid_size,
|
||||
single_point=cfg.single_point,
|
||||
n_iters=cfg.n_iters,
|
||||
)
|
||||
|
||||
# Setting the random seeds
|
||||
torch.manual_seed(cfg.seed)
|
||||
np.random.seed(cfg.seed)
|
||||
|
||||
# Constructing the specified dataset
|
||||
curr_collate_fn = collate_fn
|
||||
if cfg.dataset_name == "badja":
|
||||
test_dataset = BadjaDataset(data_root=os.path.join(cfg.dataset_root, "BADJA"))
|
||||
elif cfg.dataset_name == "fastcapture":
|
||||
test_dataset = FastCaptureDataset(
|
||||
data_root=os.path.join(cfg.dataset_root, "fastcapture"),
|
||||
max_seq_len=100,
|
||||
max_num_points=20,
|
||||
)
|
||||
elif "tapvid" in cfg.dataset_name:
|
||||
dataset_type = cfg.dataset_name.split("_")[1]
|
||||
if dataset_type == "davis":
|
||||
data_root = os.path.join(cfg.dataset_root, "/tapvid_davis/tapvid_davis.pkl")
|
||||
elif dataset_type == "kinetics":
|
||||
data_root = os.path.join(
|
||||
cfg.dataset_root, "/kinetics/kinetics-dataset/k700-2020/tapvid_kinetics"
|
||||
)
|
||||
test_dataset = TapVidDataset(
|
||||
dataset_type=dataset_type,
|
||||
data_root=data_root,
|
||||
queried_first=not "strided" in cfg.dataset_name,
|
||||
)
|
||||
|
||||
# Creating the DataLoader object
|
||||
test_dataloader = torch.utils.data.DataLoader(
|
||||
test_dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=14,
|
||||
collate_fn=curr_collate_fn,
|
||||
)
|
||||
|
||||
# Timing and conducting the evaluation
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
evaluate_result = evaluator.evaluate_sequence(
|
||||
predictor,
|
||||
test_dataloader,
|
||||
dataset_name=cfg.dataset_name,
|
||||
)
|
||||
end = time.time()
|
||||
print(end - start)
|
||||
|
||||
# Saving the evaluation results to a .json file
|
||||
if not "tapvid" in cfg.dataset_name:
|
||||
print("evaluate_result", evaluate_result)
|
||||
else:
|
||||
evaluate_result = evaluate_result["avg"]
|
||||
result_file = os.path.join(cfg.exp_dir, f"result_eval_.json")
|
||||
evaluate_result["time"] = end - start
|
||||
print(f"Dumping eval results to {result_file}.")
|
||||
with open(result_file, "w") as f:
|
||||
json.dump(evaluate_result, f)
|
||||
|
||||
|
||||
cs = hydra.core.config_store.ConfigStore.instance()
|
||||
cs.store(name="default_config_eval", node=DefaultConfig)
|
||||
|
||||
|
||||
@hydra.main(config_path="./configs/", config_name="default_config_eval")
|
||||
def evaluate(cfg: DefaultConfig) -> None:
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
|
||||
run_eval(cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
evaluate()
|
5
cotracker/models/__init__.py
Normal file
5
cotracker/models/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
70
cotracker/models/build_cotracker.py
Normal file
70
cotracker/models/build_cotracker.py
Normal file
@ -0,0 +1,70 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
from cotracker.models.core.cotracker.cotracker import CoTracker
|
||||
|
||||
|
||||
def build_cotracker(
|
||||
checkpoint: str,
|
||||
):
|
||||
model_name = checkpoint.split("/")[-1].split(".")[0]
|
||||
if model_name == "cotracker_stride_4_wind_8":
|
||||
return build_cotracker_stride_4_wind_8(checkpoint=checkpoint)
|
||||
elif model_name == "cotracker_stride_4_wind_12":
|
||||
return build_cotracker_stride_4_wind_12(checkpoint=checkpoint)
|
||||
elif model_name == "cotracker_stride_8_wind_16":
|
||||
return build_cotracker_stride_8_wind_16(checkpoint=checkpoint)
|
||||
else:
|
||||
raise ValueError(f"Unknown model name {model_name}")
|
||||
|
||||
|
||||
# model used to produce the results in the paper
|
||||
def build_cotracker_stride_4_wind_8(checkpoint=None):
|
||||
return _build_cotracker(
|
||||
stride=4,
|
||||
sequence_len=8,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def build_cotracker_stride_4_wind_12(checkpoint=None):
|
||||
return _build_cotracker(
|
||||
stride=4,
|
||||
sequence_len=12,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
# the fastest model
|
||||
def build_cotracker_stride_8_wind_16(checkpoint=None):
|
||||
return _build_cotracker(
|
||||
stride=8,
|
||||
sequence_len=16,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def _build_cotracker(
|
||||
stride,
|
||||
sequence_len,
|
||||
checkpoint=None,
|
||||
):
|
||||
cotracker = CoTracker(
|
||||
stride=stride,
|
||||
S=sequence_len,
|
||||
add_space_attn=True,
|
||||
space_depth=6,
|
||||
time_depth=6,
|
||||
)
|
||||
if checkpoint is not None:
|
||||
with open(checkpoint, "rb") as f:
|
||||
state_dict = torch.load(f, map_location="cpu")
|
||||
if "model" in state_dict:
|
||||
state_dict = state_dict["model"]
|
||||
cotracker.load_state_dict(state_dict)
|
||||
return cotracker
|
5
cotracker/models/core/__init__.py
Normal file
5
cotracker/models/core/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
5
cotracker/models/core/cotracker/__init__.py
Normal file
5
cotracker/models/core/cotracker/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
400
cotracker/models/core/cotracker/blocks.py
Normal file
400
cotracker/models/core/cotracker/blocks.py
Normal file
@ -0,0 +1,400 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange
|
||||
from timm.models.vision_transformer import Attention, Mlp
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
|
||||
super(ResidualBlock, self).__init__()
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_planes,
|
||||
planes,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
stride=stride,
|
||||
padding_mode="zeros",
|
||||
)
|
||||
self.conv2 = nn.Conv2d(
|
||||
planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
|
||||
)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
num_groups = planes // 8
|
||||
|
||||
if norm_fn == "group":
|
||||
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||
|
||||
elif norm_fn == "batch":
|
||||
self.norm1 = nn.BatchNorm2d(planes)
|
||||
self.norm2 = nn.BatchNorm2d(planes)
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.BatchNorm2d(planes)
|
||||
|
||||
elif norm_fn == "instance":
|
||||
self.norm1 = nn.InstanceNorm2d(planes)
|
||||
self.norm2 = nn.InstanceNorm2d(planes)
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.InstanceNorm2d(planes)
|
||||
|
||||
elif norm_fn == "none":
|
||||
self.norm1 = nn.Sequential()
|
||||
self.norm2 = nn.Sequential()
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.Sequential()
|
||||
|
||||
if stride == 1:
|
||||
self.downsample = None
|
||||
|
||||
else:
|
||||
self.downsample = nn.Sequential(
|
||||
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = x
|
||||
y = self.relu(self.norm1(self.conv1(y)))
|
||||
y = self.relu(self.norm2(self.conv2(y)))
|
||||
|
||||
if self.downsample is not None:
|
||||
x = self.downsample(x)
|
||||
|
||||
return self.relu(x + y)
|
||||
|
||||
|
||||
class BasicEncoder(nn.Module):
|
||||
def __init__(
|
||||
self, input_dim=3, output_dim=128, stride=8, norm_fn="batch", dropout=0.0
|
||||
):
|
||||
super(BasicEncoder, self).__init__()
|
||||
self.stride = stride
|
||||
self.norm_fn = norm_fn
|
||||
self.in_planes = 64
|
||||
|
||||
if self.norm_fn == "group":
|
||||
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes)
|
||||
self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2)
|
||||
|
||||
elif self.norm_fn == "batch":
|
||||
self.norm1 = nn.BatchNorm2d(self.in_planes)
|
||||
self.norm2 = nn.BatchNorm2d(output_dim * 2)
|
||||
|
||||
elif self.norm_fn == "instance":
|
||||
self.norm1 = nn.InstanceNorm2d(self.in_planes)
|
||||
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
|
||||
|
||||
elif self.norm_fn == "none":
|
||||
self.norm1 = nn.Sequential()
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
input_dim,
|
||||
self.in_planes,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
padding_mode="zeros",
|
||||
)
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
|
||||
self.shallow = False
|
||||
if self.shallow:
|
||||
self.layer1 = self._make_layer(64, stride=1)
|
||||
self.layer2 = self._make_layer(96, stride=2)
|
||||
self.layer3 = self._make_layer(128, stride=2)
|
||||
self.conv2 = nn.Conv2d(128 + 96 + 64, output_dim, kernel_size=1)
|
||||
else:
|
||||
self.layer1 = self._make_layer(64, stride=1)
|
||||
self.layer2 = self._make_layer(96, stride=2)
|
||||
self.layer3 = self._make_layer(128, stride=2)
|
||||
self.layer4 = self._make_layer(128, stride=2)
|
||||
|
||||
self.conv2 = nn.Conv2d(
|
||||
128 + 128 + 96 + 64,
|
||||
output_dim * 2,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
padding_mode="zeros",
|
||||
)
|
||||
self.relu2 = nn.ReLU(inplace=True)
|
||||
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
|
||||
|
||||
self.dropout = None
|
||||
if dropout > 0:
|
||||
self.dropout = nn.Dropout2d(p=dropout)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
||||
if m.weight is not None:
|
||||
nn.init.constant_(m.weight, 1)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def _make_layer(self, dim, stride=1):
|
||||
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
||||
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
||||
layers = (layer1, layer2)
|
||||
|
||||
self.in_planes = dim
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
_, _, H, W = x.shape
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = self.relu1(x)
|
||||
|
||||
if self.shallow:
|
||||
a = self.layer1(x)
|
||||
b = self.layer2(a)
|
||||
c = self.layer3(b)
|
||||
a = F.interpolate(
|
||||
a,
|
||||
(H // self.stride, W // self.stride),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
b = F.interpolate(
|
||||
b,
|
||||
(H // self.stride, W // self.stride),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
c = F.interpolate(
|
||||
c,
|
||||
(H // self.stride, W // self.stride),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
x = self.conv2(torch.cat([a, b, c], dim=1))
|
||||
else:
|
||||
a = self.layer1(x)
|
||||
b = self.layer2(a)
|
||||
c = self.layer3(b)
|
||||
d = self.layer4(c)
|
||||
a = F.interpolate(
|
||||
a,
|
||||
(H // self.stride, W // self.stride),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
b = F.interpolate(
|
||||
b,
|
||||
(H // self.stride, W // self.stride),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
c = F.interpolate(
|
||||
c,
|
||||
(H // self.stride, W // self.stride),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
d = F.interpolate(
|
||||
d,
|
||||
(H // self.stride, W // self.stride),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
x = self.conv2(torch.cat([a, b, c, d], dim=1))
|
||||
x = self.norm2(x)
|
||||
x = self.relu2(x)
|
||||
x = self.conv3(x)
|
||||
|
||||
if self.training and self.dropout is not None:
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
"""
|
||||
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
|
||||
super().__init__()
|
||||
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.attn = Attention(
|
||||
hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.mlp = Mlp(
|
||||
in_features=hidden_size,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=approx_gelu,
|
||||
drop=0,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.attn(self.norm1(x))
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
def bilinear_sampler(img, coords, mode="bilinear", mask=False):
|
||||
"""Wrapper for grid_sample, uses pixel coordinates"""
|
||||
H, W = img.shape[-2:]
|
||||
xgrid, ygrid = coords.split([1, 1], dim=-1)
|
||||
# go to 0,1 then 0,2 then -1,1
|
||||
xgrid = 2 * xgrid / (W - 1) - 1
|
||||
ygrid = 2 * ygrid / (H - 1) - 1
|
||||
|
||||
grid = torch.cat([xgrid, ygrid], dim=-1)
|
||||
img = F.grid_sample(img, grid, align_corners=True)
|
||||
|
||||
if mask:
|
||||
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
|
||||
return img, mask.float()
|
||||
|
||||
return img
|
||||
|
||||
|
||||
class CorrBlock:
|
||||
def __init__(self, fmaps, num_levels=4, radius=4):
|
||||
B, S, C, H, W = fmaps.shape
|
||||
self.S, self.C, self.H, self.W = S, C, H, W
|
||||
|
||||
self.num_levels = num_levels
|
||||
self.radius = radius
|
||||
self.fmaps_pyramid = []
|
||||
|
||||
self.fmaps_pyramid.append(fmaps)
|
||||
for i in range(self.num_levels - 1):
|
||||
fmaps_ = fmaps.reshape(B * S, C, H, W)
|
||||
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
|
||||
_, _, H, W = fmaps_.shape
|
||||
fmaps = fmaps_.reshape(B, S, C, H, W)
|
||||
self.fmaps_pyramid.append(fmaps)
|
||||
|
||||
def sample(self, coords):
|
||||
r = self.radius
|
||||
B, S, N, D = coords.shape
|
||||
assert D == 2
|
||||
|
||||
H, W = self.H, self.W
|
||||
out_pyramid = []
|
||||
for i in range(self.num_levels):
|
||||
corrs = self.corrs_pyramid[i] # B, S, N, H, W
|
||||
_, _, _, H, W = corrs.shape
|
||||
|
||||
dx = torch.linspace(-r, r, 2 * r + 1)
|
||||
dy = torch.linspace(-r, r, 2 * r + 1)
|
||||
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
|
||||
coords.device
|
||||
)
|
||||
|
||||
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
|
||||
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
|
||||
coords_lvl = centroid_lvl + delta_lvl
|
||||
|
||||
corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl)
|
||||
corrs = corrs.view(B, S, N, -1)
|
||||
out_pyramid.append(corrs)
|
||||
|
||||
out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
|
||||
return out.contiguous().float()
|
||||
|
||||
def corr(self, targets):
|
||||
B, S, N, C = targets.shape
|
||||
assert C == self.C
|
||||
assert S == self.S
|
||||
|
||||
fmap1 = targets
|
||||
|
||||
self.corrs_pyramid = []
|
||||
for fmaps in self.fmaps_pyramid:
|
||||
_, _, _, H, W = fmaps.shape
|
||||
fmap2s = fmaps.view(B, S, C, H * W)
|
||||
corrs = torch.matmul(fmap1, fmap2s)
|
||||
corrs = corrs.view(B, S, N, H, W)
|
||||
corrs = corrs / torch.sqrt(torch.tensor(C).float())
|
||||
self.corrs_pyramid.append(corrs)
|
||||
|
||||
|
||||
class UpdateFormer(nn.Module):
|
||||
"""
|
||||
Transformer model that updates track estimates.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
space_depth=12,
|
||||
time_depth=12,
|
||||
input_dim=320,
|
||||
hidden_size=384,
|
||||
num_heads=8,
|
||||
output_dim=130,
|
||||
mlp_ratio=4.0,
|
||||
add_space_attn=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = 2
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.add_space_attn = add_space_attn
|
||||
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
|
||||
self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
|
||||
|
||||
self.time_blocks = nn.ModuleList(
|
||||
[
|
||||
AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
|
||||
for _ in range(time_depth)
|
||||
]
|
||||
)
|
||||
|
||||
if add_space_attn:
|
||||
self.space_blocks = nn.ModuleList(
|
||||
[
|
||||
AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
|
||||
for _ in range(space_depth)
|
||||
]
|
||||
)
|
||||
assert len(self.time_blocks) >= len(self.space_blocks)
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self):
|
||||
def _basic_init(module):
|
||||
if isinstance(module, nn.Linear):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0)
|
||||
|
||||
self.apply(_basic_init)
|
||||
|
||||
def forward(self, input_tensor):
|
||||
x = self.input_transform(input_tensor)
|
||||
|
||||
j = 0
|
||||
for i in range(len(self.time_blocks)):
|
||||
B, N, T, _ = x.shape
|
||||
x_time = rearrange(x, "b n t c -> (b n) t c", b=B, t=T, n=N)
|
||||
x_time = self.time_blocks[i](x_time)
|
||||
|
||||
x = rearrange(x_time, "(b n) t c -> b n t c ", b=B, t=T, n=N)
|
||||
if self.add_space_attn and (
|
||||
i % (len(self.time_blocks) // len(self.space_blocks)) == 0
|
||||
):
|
||||
x_space = rearrange(x, "b n t c -> (b t) n c ", b=B, t=T, n=N)
|
||||
x_space = self.space_blocks[j](x_space)
|
||||
x = rearrange(x_space, "(b t) n c -> b n t c ", b=B, t=T, n=N)
|
||||
j += 1
|
||||
|
||||
flow = self.flow_head(x)
|
||||
return flow
|
351
cotracker/models/core/cotracker/cotracker.py
Normal file
351
cotracker/models/core/cotracker/cotracker.py
Normal file
@ -0,0 +1,351 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
|
||||
from cotracker.models.core.cotracker.blocks import (
|
||||
BasicEncoder,
|
||||
CorrBlock,
|
||||
UpdateFormer,
|
||||
)
|
||||
|
||||
from cotracker.models.core.model_utils import meshgrid2d, bilinear_sample2d, smart_cat
|
||||
from cotracker.models.core.embeddings import (
|
||||
get_2d_embedding,
|
||||
get_1d_sincos_pos_embed_from_grid,
|
||||
get_2d_sincos_pos_embed,
|
||||
)
|
||||
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)):
|
||||
if grid_size == 1:
|
||||
return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2])[
|
||||
None, None
|
||||
].cuda()
|
||||
|
||||
grid_y, grid_x = meshgrid2d(
|
||||
1, grid_size, grid_size, stack=False, norm=False, device="cuda"
|
||||
)
|
||||
step = interp_shape[1] // 64
|
||||
if grid_center[0] != 0 or grid_center[1] != 0:
|
||||
grid_y = grid_y - grid_size / 2.0
|
||||
grid_x = grid_x - grid_size / 2.0
|
||||
grid_y = step + grid_y.reshape(1, -1) / float(grid_size - 1) * (
|
||||
interp_shape[0] - step * 2
|
||||
)
|
||||
grid_x = step + grid_x.reshape(1, -1) / float(grid_size - 1) * (
|
||||
interp_shape[1] - step * 2
|
||||
)
|
||||
|
||||
grid_y = grid_y + grid_center[0]
|
||||
grid_x = grid_x + grid_center[1]
|
||||
xy = torch.stack([grid_x, grid_y], dim=-1).cuda()
|
||||
return xy
|
||||
|
||||
|
||||
def sample_pos_embed(grid_size, embed_dim, coords):
|
||||
pos_embed = get_2d_sincos_pos_embed(embed_dim=embed_dim, grid_size=grid_size)
|
||||
pos_embed = (
|
||||
torch.from_numpy(pos_embed)
|
||||
.reshape(grid_size[0], grid_size[1], embed_dim)
|
||||
.float()
|
||||
.unsqueeze(0)
|
||||
.to(coords.device)
|
||||
)
|
||||
sampled_pos_embed = bilinear_sample2d(
|
||||
pos_embed.permute(0, 3, 1, 2), coords[:, 0, :, 0], coords[:, 0, :, 1]
|
||||
)
|
||||
return sampled_pos_embed
|
||||
|
||||
|
||||
class CoTracker(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
S=8,
|
||||
stride=8,
|
||||
add_space_attn=True,
|
||||
num_heads=8,
|
||||
hidden_size=384,
|
||||
space_depth=12,
|
||||
time_depth=12,
|
||||
):
|
||||
super(CoTracker, self).__init__()
|
||||
self.S = S
|
||||
self.stride = stride
|
||||
self.hidden_dim = 256
|
||||
self.latent_dim = latent_dim = 128
|
||||
self.corr_levels = 4
|
||||
self.corr_radius = 3
|
||||
self.add_space_attn = add_space_attn
|
||||
self.fnet = BasicEncoder(
|
||||
output_dim=self.latent_dim, norm_fn="instance", dropout=0, stride=stride
|
||||
)
|
||||
|
||||
self.updateformer = UpdateFormer(
|
||||
space_depth=space_depth,
|
||||
time_depth=time_depth,
|
||||
input_dim=456,
|
||||
hidden_size=hidden_size,
|
||||
num_heads=num_heads,
|
||||
output_dim=latent_dim + 2,
|
||||
mlp_ratio=4.0,
|
||||
add_space_attn=add_space_attn,
|
||||
)
|
||||
|
||||
self.norm = nn.GroupNorm(1, self.latent_dim)
|
||||
self.ffeat_updater = nn.Sequential(
|
||||
nn.Linear(self.latent_dim, self.latent_dim),
|
||||
nn.GELU(),
|
||||
)
|
||||
self.vis_predictor = nn.Sequential(
|
||||
nn.Linear(self.latent_dim, 1),
|
||||
)
|
||||
|
||||
def forward_iteration(
|
||||
self,
|
||||
fmaps,
|
||||
coords_init,
|
||||
feat_init=None,
|
||||
vis_init=None,
|
||||
track_mask=None,
|
||||
iters=4,
|
||||
):
|
||||
B, S_init, N, D = coords_init.shape
|
||||
assert D == 2
|
||||
assert B == 1
|
||||
|
||||
B, S, __, H8, W8 = fmaps.shape
|
||||
|
||||
device = fmaps.device
|
||||
|
||||
if S_init < S:
|
||||
coords = torch.cat(
|
||||
[coords_init, coords_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1
|
||||
)
|
||||
vis_init = torch.cat(
|
||||
[vis_init, vis_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1
|
||||
)
|
||||
else:
|
||||
coords = coords_init.clone()
|
||||
|
||||
fcorr_fn = CorrBlock(
|
||||
fmaps, num_levels=self.corr_levels, radius=self.corr_radius
|
||||
)
|
||||
|
||||
ffeats = feat_init.clone()
|
||||
|
||||
times_ = torch.linspace(0, S - 1, S).reshape(1, S, 1)
|
||||
|
||||
pos_embed = sample_pos_embed(
|
||||
grid_size=(H8, W8),
|
||||
embed_dim=456,
|
||||
coords=coords,
|
||||
)
|
||||
pos_embed = rearrange(pos_embed, "b e n -> (b n) e").unsqueeze(1)
|
||||
times_embed = (
|
||||
torch.from_numpy(get_1d_sincos_pos_embed_from_grid(456, times_[0]))[None]
|
||||
.repeat(B, 1, 1)
|
||||
.float()
|
||||
.to(device)
|
||||
)
|
||||
coord_predictions = []
|
||||
|
||||
for __ in range(iters):
|
||||
coords = coords.detach()
|
||||
fcorr_fn.corr(ffeats)
|
||||
|
||||
fcorrs = fcorr_fn.sample(coords) # B, S, N, LRR
|
||||
LRR = fcorrs.shape[3]
|
||||
|
||||
fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, LRR)
|
||||
flows_ = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
|
||||
|
||||
flows_cat = get_2d_embedding(flows_, 64, cat_coords=True)
|
||||
ffeats_ = ffeats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
|
||||
|
||||
if track_mask.shape[1] < vis_init.shape[1]:
|
||||
track_mask = torch.cat(
|
||||
[
|
||||
track_mask,
|
||||
torch.zeros_like(track_mask[:, 0]).repeat(
|
||||
1, vis_init.shape[1] - track_mask.shape[1], 1, 1
|
||||
),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
concat = (
|
||||
torch.cat([track_mask, vis_init], dim=2)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(B * N, S, 2)
|
||||
)
|
||||
|
||||
transformer_input = torch.cat([flows_cat, fcorrs_, ffeats_, concat], dim=2)
|
||||
x = transformer_input + pos_embed + times_embed
|
||||
|
||||
x = rearrange(x, "(b n) t d -> b n t d", b=B)
|
||||
|
||||
delta = self.updateformer(x)
|
||||
|
||||
delta = rearrange(delta, " b n t d -> (b n) t d")
|
||||
|
||||
delta_coords_ = delta[:, :, :2]
|
||||
delta_feats_ = delta[:, :, 2:]
|
||||
|
||||
delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
|
||||
ffeats_ = ffeats.permute(0, 2, 1, 3).reshape(B * N * S, self.latent_dim)
|
||||
|
||||
ffeats_ = self.ffeat_updater(self.norm(delta_feats_)) + ffeats_
|
||||
|
||||
ffeats = ffeats_.reshape(B, N, S, self.latent_dim).permute(
|
||||
0, 2, 1, 3
|
||||
) # B,S,N,C
|
||||
|
||||
coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
|
||||
coord_predictions.append(coords * self.stride)
|
||||
|
||||
vis_e = self.vis_predictor(ffeats.reshape(B * S * N, self.latent_dim)).reshape(
|
||||
B, S, N
|
||||
)
|
||||
return coord_predictions, vis_e, feat_init
|
||||
|
||||
def forward(self, rgbs, queries, iters=4, feat_init=None, is_train=False):
|
||||
B, T, C, H, W = rgbs.shape
|
||||
B, N, __ = queries.shape
|
||||
|
||||
device = rgbs.device
|
||||
assert B == 1
|
||||
# INIT for the first sequence
|
||||
# We want to sort points by the first frame they are visible to add them to the tensor of tracked points consequtively
|
||||
first_positive_inds = queries[:, :, 0].long()
|
||||
|
||||
__, sort_inds = torch.sort(first_positive_inds[0], dim=0, descending=False)
|
||||
inv_sort_inds = torch.argsort(sort_inds, dim=0)
|
||||
first_positive_sorted_inds = first_positive_inds[0][sort_inds]
|
||||
|
||||
assert torch.allclose(
|
||||
first_positive_inds[0], first_positive_inds[0][sort_inds][inv_sort_inds]
|
||||
)
|
||||
|
||||
coords_init = queries[:, :, 1:].reshape(B, 1, N, 2).repeat(
|
||||
1, self.S, 1, 1
|
||||
) / float(self.stride)
|
||||
|
||||
rgbs = 2 * (rgbs / 255.0) - 1.0
|
||||
|
||||
traj_e = torch.zeros((B, T, N, 2), device=device)
|
||||
vis_e = torch.zeros((B, T, N), device=device)
|
||||
|
||||
ind_array = torch.arange(T, device=device)
|
||||
ind_array = ind_array[None, :, None].repeat(B, 1, N)
|
||||
|
||||
track_mask = (ind_array >= first_positive_inds[:, None, :]).unsqueeze(-1)
|
||||
# these are logits, so we initialize visibility with something that would give a value close to 1 after softmax
|
||||
vis_init = torch.ones((B, self.S, N, 1), device=device).float() * 10
|
||||
|
||||
ind = 0
|
||||
|
||||
track_mask_ = track_mask[:, :, sort_inds].clone()
|
||||
coords_init_ = coords_init[:, :, sort_inds].clone()
|
||||
vis_init_ = vis_init[:, :, sort_inds].clone()
|
||||
|
||||
prev_wind_idx = 0
|
||||
fmaps_ = None
|
||||
vis_predictions = []
|
||||
coord_predictions = []
|
||||
wind_inds = []
|
||||
while ind < T - self.S // 2:
|
||||
rgbs_seq = rgbs[:, ind : ind + self.S]
|
||||
|
||||
S = S_local = rgbs_seq.shape[1]
|
||||
if S < self.S:
|
||||
rgbs_seq = torch.cat(
|
||||
[rgbs_seq, rgbs_seq[:, -1, None].repeat(1, self.S - S, 1, 1, 1)],
|
||||
dim=1,
|
||||
)
|
||||
S = rgbs_seq.shape[1]
|
||||
rgbs_ = rgbs_seq.reshape(B * S, C, H, W)
|
||||
|
||||
if fmaps_ is None:
|
||||
fmaps_ = self.fnet(rgbs_)
|
||||
else:
|
||||
fmaps_ = torch.cat(
|
||||
[fmaps_[self.S // 2 :], self.fnet(rgbs_[self.S // 2 :])], dim=0
|
||||
)
|
||||
fmaps = fmaps_.reshape(
|
||||
B, S, self.latent_dim, H // self.stride, W // self.stride
|
||||
)
|
||||
|
||||
curr_wind_points = torch.nonzero(first_positive_sorted_inds < ind + self.S)
|
||||
if curr_wind_points.shape[0] == 0:
|
||||
ind = ind + self.S // 2
|
||||
continue
|
||||
wind_idx = curr_wind_points[-1] + 1
|
||||
|
||||
if wind_idx - prev_wind_idx > 0:
|
||||
fmaps_sample = fmaps[
|
||||
:, first_positive_sorted_inds[prev_wind_idx:wind_idx] - ind
|
||||
]
|
||||
|
||||
feat_init_ = bilinear_sample2d(
|
||||
fmaps_sample,
|
||||
coords_init_[:, 0, prev_wind_idx:wind_idx, 0],
|
||||
coords_init_[:, 0, prev_wind_idx:wind_idx, 1],
|
||||
).permute(0, 2, 1)
|
||||
|
||||
feat_init_ = feat_init_.unsqueeze(1).repeat(1, self.S, 1, 1)
|
||||
feat_init = smart_cat(feat_init, feat_init_, dim=2)
|
||||
|
||||
if prev_wind_idx > 0:
|
||||
new_coords = coords[-1][:, self.S // 2 :] / float(self.stride)
|
||||
|
||||
coords_init_[:, : self.S // 2, :prev_wind_idx] = new_coords
|
||||
coords_init_[:, self.S // 2 :, :prev_wind_idx] = new_coords[
|
||||
:, -1
|
||||
].repeat(1, self.S // 2, 1, 1)
|
||||
|
||||
new_vis = vis[:, self.S // 2 :].unsqueeze(-1)
|
||||
vis_init_[:, : self.S // 2, :prev_wind_idx] = new_vis
|
||||
vis_init_[:, self.S // 2 :, :prev_wind_idx] = new_vis[:, -1].repeat(
|
||||
1, self.S // 2, 1, 1
|
||||
)
|
||||
|
||||
coords, vis, __ = self.forward_iteration(
|
||||
fmaps=fmaps,
|
||||
coords_init=coords_init_[:, :, :wind_idx],
|
||||
feat_init=feat_init[:, :, :wind_idx],
|
||||
vis_init=vis_init_[:, :, :wind_idx],
|
||||
track_mask=track_mask_[:, ind : ind + self.S, :wind_idx],
|
||||
iters=iters,
|
||||
)
|
||||
if is_train:
|
||||
vis_predictions.append(torch.sigmoid(vis[:, :S_local]))
|
||||
coord_predictions.append([coord[:, :S_local] for coord in coords])
|
||||
wind_inds.append(wind_idx)
|
||||
|
||||
traj_e[:, ind : ind + self.S, :wind_idx] = coords[-1][:, :S_local]
|
||||
vis_e[:, ind : ind + self.S, :wind_idx] = vis[:, :S_local]
|
||||
|
||||
track_mask_[:, : ind + self.S, :wind_idx] = 0.0
|
||||
ind = ind + self.S // 2
|
||||
|
||||
prev_wind_idx = wind_idx
|
||||
|
||||
traj_e = traj_e[:, :, inv_sort_inds]
|
||||
vis_e = vis_e[:, :, inv_sort_inds]
|
||||
|
||||
vis_e = torch.sigmoid(vis_e)
|
||||
|
||||
train_data = (
|
||||
(vis_predictions, coord_predictions, wind_inds, sort_inds)
|
||||
if is_train
|
||||
else None
|
||||
)
|
||||
return traj_e, feat_init, vis_e, train_data
|
61
cotracker/models/core/cotracker/losses.py
Normal file
61
cotracker/models/core/cotracker/losses.py
Normal file
@ -0,0 +1,61 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from cotracker.models.core.model_utils import reduce_masked_mean
|
||||
|
||||
EPS = 1e-6
|
||||
|
||||
|
||||
def balanced_ce_loss(pred, gt, valid=None):
|
||||
total_balanced_loss = 0.0
|
||||
for j in range(len(gt)):
|
||||
B, S, N = gt[j].shape
|
||||
# pred and gt are the same shape
|
||||
for (a, b) in zip(pred[j].size(), gt[j].size()):
|
||||
assert a == b # some shape mismatch!
|
||||
# if valid is not None:
|
||||
for (a, b) in zip(pred[j].size(), valid[j].size()):
|
||||
assert a == b # some shape mismatch!
|
||||
|
||||
pos = (gt[j] > 0.95).float()
|
||||
neg = (gt[j] < 0.05).float()
|
||||
|
||||
label = pos * 2.0 - 1.0
|
||||
a = -label * pred[j]
|
||||
b = F.relu(a)
|
||||
loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))
|
||||
|
||||
pos_loss = reduce_masked_mean(loss, pos * valid[j])
|
||||
neg_loss = reduce_masked_mean(loss, neg * valid[j])
|
||||
|
||||
balanced_loss = pos_loss + neg_loss
|
||||
total_balanced_loss += balanced_loss / float(N)
|
||||
return total_balanced_loss
|
||||
|
||||
|
||||
def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8):
|
||||
"""Loss function defined over sequence of flow predictions"""
|
||||
total_flow_loss = 0.0
|
||||
for j in range(len(flow_gt)):
|
||||
B, S, N, D = flow_gt[j].shape
|
||||
assert D == 2
|
||||
B, S1, N = vis[j].shape
|
||||
B, S2, N = valids[j].shape
|
||||
assert S == S1
|
||||
assert S == S2
|
||||
n_predictions = len(flow_preds[j])
|
||||
flow_loss = 0.0
|
||||
for i in range(n_predictions):
|
||||
i_weight = gamma ** (n_predictions - i - 1)
|
||||
flow_pred = flow_preds[j][i]
|
||||
i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2
|
||||
i_loss = torch.mean(i_loss, dim=3) # B, S, N
|
||||
flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j])
|
||||
flow_loss = flow_loss / n_predictions
|
||||
total_flow_loss += flow_loss / float(N)
|
||||
return total_flow_loss
|
154
cotracker/models/core/embeddings.py
Normal file
154
cotracker/models/core/embeddings.py
Normal file
@ -0,0 +1,154 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
||||
"""
|
||||
grid_size: int of the grid height and width
|
||||
return:
|
||||
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
"""
|
||||
if isinstance(grid_size, tuple):
|
||||
grid_size_h, grid_size_w = grid_size
|
||||
else:
|
||||
grid_size_h = grid_size_w = grid_size
|
||||
grid_h = np.arange(grid_size_h, dtype=np.float32)
|
||||
grid_w = np.arange(grid_size_w, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
|
||||
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
if cls_token and extra_tokens > 0:
|
||||
pos_embed = np.concatenate(
|
||||
[np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
|
||||
)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
assert embed_dim % 2 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (M,)
|
||||
out: (M, D)
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
||||
omega /= embed_dim / 2.0
|
||||
omega = 1.0 / 10000 ** omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_2d_embedding(xy, C, cat_coords=True):
|
||||
B, N, D = xy.shape
|
||||
assert D == 2
|
||||
|
||||
x = xy[:, :, 0:1]
|
||||
y = xy[:, :, 1:2]
|
||||
div_term = (
|
||||
torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
|
||||
).reshape(1, 1, int(C / 2))
|
||||
|
||||
pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
||||
pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
||||
|
||||
pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
||||
pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
||||
|
||||
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
||||
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
||||
|
||||
pe = torch.cat([pe_x, pe_y], dim=2) # B, N, C*3
|
||||
if cat_coords:
|
||||
pe = torch.cat([xy, pe], dim=2) # B, N, C*3+3
|
||||
return pe
|
||||
|
||||
|
||||
def get_3d_embedding(xyz, C, cat_coords=True):
|
||||
B, N, D = xyz.shape
|
||||
assert D == 3
|
||||
|
||||
x = xyz[:, :, 0:1]
|
||||
y = xyz[:, :, 1:2]
|
||||
z = xyz[:, :, 2:3]
|
||||
div_term = (
|
||||
torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (1000.0 / C)
|
||||
).reshape(1, 1, int(C / 2))
|
||||
|
||||
pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
|
||||
pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
|
||||
pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
|
||||
|
||||
pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
||||
pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
||||
|
||||
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
||||
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
||||
|
||||
pe_z[:, :, 0::2] = torch.sin(z * div_term)
|
||||
pe_z[:, :, 1::2] = torch.cos(z * div_term)
|
||||
|
||||
pe = torch.cat([pe_x, pe_y, pe_z], dim=2) # B, N, C*3
|
||||
if cat_coords:
|
||||
pe = torch.cat([pe, xyz], dim=2) # B, N, C*3+3
|
||||
return pe
|
||||
|
||||
|
||||
def get_4d_embedding(xyzw, C, cat_coords=True):
|
||||
B, N, D = xyzw.shape
|
||||
assert D == 4
|
||||
|
||||
x = xyzw[:, :, 0:1]
|
||||
y = xyzw[:, :, 1:2]
|
||||
z = xyzw[:, :, 2:3]
|
||||
w = xyzw[:, :, 3:4]
|
||||
div_term = (
|
||||
torch.arange(0, C, 2, device=xyzw.device, dtype=torch.float32) * (1000.0 / C)
|
||||
).reshape(1, 1, int(C / 2))
|
||||
|
||||
pe_x = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
|
||||
pe_y = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
|
||||
pe_z = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
|
||||
pe_w = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
|
||||
|
||||
pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
||||
pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
||||
|
||||
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
||||
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
||||
|
||||
pe_z[:, :, 0::2] = torch.sin(z * div_term)
|
||||
pe_z[:, :, 1::2] = torch.cos(z * div_term)
|
||||
|
||||
pe_w[:, :, 0::2] = torch.sin(w * div_term)
|
||||
pe_w[:, :, 1::2] = torch.cos(w * div_term)
|
||||
|
||||
pe = torch.cat([pe_x, pe_y, pe_z, pe_w], dim=2) # B, N, C*3
|
||||
if cat_coords:
|
||||
pe = torch.cat([pe, xyzw], dim=2) # B, N, C*3+3
|
||||
return pe
|
169
cotracker/models/core/model_utils.py
Normal file
169
cotracker/models/core/model_utils.py
Normal file
@ -0,0 +1,169 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
EPS = 1e-6
|
||||
|
||||
|
||||
def smart_cat(tensor1, tensor2, dim):
|
||||
if tensor1 is None:
|
||||
return tensor2
|
||||
return torch.cat([tensor1, tensor2], dim=dim)
|
||||
|
||||
|
||||
def normalize_single(d):
|
||||
# d is a whatever shape torch tensor
|
||||
dmin = torch.min(d)
|
||||
dmax = torch.max(d)
|
||||
d = (d - dmin) / (EPS + (dmax - dmin))
|
||||
return d
|
||||
|
||||
|
||||
def normalize(d):
|
||||
# d is B x whatever. normalize within each element of the batch
|
||||
out = torch.zeros(d.size())
|
||||
if d.is_cuda:
|
||||
out = out.cuda()
|
||||
B = list(d.size())[0]
|
||||
for b in list(range(B)):
|
||||
out[b] = normalize_single(d[b])
|
||||
return out
|
||||
|
||||
|
||||
def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"):
|
||||
# returns a meshgrid sized B x Y x X
|
||||
|
||||
grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device))
|
||||
grid_y = torch.reshape(grid_y, [1, Y, 1])
|
||||
grid_y = grid_y.repeat(B, 1, X)
|
||||
|
||||
grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device))
|
||||
grid_x = torch.reshape(grid_x, [1, 1, X])
|
||||
grid_x = grid_x.repeat(B, Y, 1)
|
||||
|
||||
if stack:
|
||||
# note we stack in xy order
|
||||
# (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
|
||||
grid = torch.stack([grid_x, grid_y], dim=-1)
|
||||
return grid
|
||||
else:
|
||||
return grid_y, grid_x
|
||||
|
||||
|
||||
def reduce_masked_mean(x, mask, dim=None, keepdim=False):
|
||||
# x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting
|
||||
# returns shape-1
|
||||
# axis can be a list of axes
|
||||
for (a, b) in zip(x.size(), mask.size()):
|
||||
assert a == b # some shape mismatch!
|
||||
prod = x * mask
|
||||
if dim is None:
|
||||
numer = torch.sum(prod)
|
||||
denom = EPS + torch.sum(mask)
|
||||
else:
|
||||
numer = torch.sum(prod, dim=dim, keepdim=keepdim)
|
||||
denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim)
|
||||
|
||||
mean = numer / denom
|
||||
return mean
|
||||
|
||||
|
||||
def bilinear_sample2d(im, x, y, return_inbounds=False):
|
||||
# x and y are each B, N
|
||||
# output is B, C, N
|
||||
if len(im.shape) == 5:
|
||||
B, N, C, H, W = list(im.shape)
|
||||
else:
|
||||
B, C, H, W = list(im.shape)
|
||||
N = list(x.shape)[1]
|
||||
|
||||
x = x.float()
|
||||
y = y.float()
|
||||
H_f = torch.tensor(H, dtype=torch.float32)
|
||||
W_f = torch.tensor(W, dtype=torch.float32)
|
||||
|
||||
# inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()
|
||||
|
||||
max_y = (H_f - 1).int()
|
||||
max_x = (W_f - 1).int()
|
||||
|
||||
x0 = torch.floor(x).int()
|
||||
x1 = x0 + 1
|
||||
y0 = torch.floor(y).int()
|
||||
y1 = y0 + 1
|
||||
|
||||
x0_clip = torch.clamp(x0, 0, max_x)
|
||||
x1_clip = torch.clamp(x1, 0, max_x)
|
||||
y0_clip = torch.clamp(y0, 0, max_y)
|
||||
y1_clip = torch.clamp(y1, 0, max_y)
|
||||
dim2 = W
|
||||
dim1 = W * H
|
||||
|
||||
base = torch.arange(0, B, dtype=torch.int64, device=x.device) * dim1
|
||||
base = torch.reshape(base, [B, 1]).repeat([1, N])
|
||||
|
||||
base_y0 = base + y0_clip * dim2
|
||||
base_y1 = base + y1_clip * dim2
|
||||
|
||||
idx_y0_x0 = base_y0 + x0_clip
|
||||
idx_y0_x1 = base_y0 + x1_clip
|
||||
idx_y1_x0 = base_y1 + x0_clip
|
||||
idx_y1_x1 = base_y1 + x1_clip
|
||||
|
||||
# use the indices to lookup pixels in the flat image
|
||||
# im is B x C x H x W
|
||||
# move C out to last dim
|
||||
if len(im.shape) == 5:
|
||||
im_flat = (im.permute(0, 3, 4, 1, 2)).reshape(B * H * W, N, C)
|
||||
i_y0_x0 = torch.diagonal(im_flat[idx_y0_x0.long()], dim1=1, dim2=2).permute(
|
||||
0, 2, 1
|
||||
)
|
||||
i_y0_x1 = torch.diagonal(im_flat[idx_y0_x1.long()], dim1=1, dim2=2).permute(
|
||||
0, 2, 1
|
||||
)
|
||||
i_y1_x0 = torch.diagonal(im_flat[idx_y1_x0.long()], dim1=1, dim2=2).permute(
|
||||
0, 2, 1
|
||||
)
|
||||
i_y1_x1 = torch.diagonal(im_flat[idx_y1_x1.long()], dim1=1, dim2=2).permute(
|
||||
0, 2, 1
|
||||
)
|
||||
else:
|
||||
im_flat = (im.permute(0, 2, 3, 1)).reshape(B * H * W, C)
|
||||
i_y0_x0 = im_flat[idx_y0_x0.long()]
|
||||
i_y0_x1 = im_flat[idx_y0_x1.long()]
|
||||
i_y1_x0 = im_flat[idx_y1_x0.long()]
|
||||
i_y1_x1 = im_flat[idx_y1_x1.long()]
|
||||
|
||||
# Finally calculate interpolated values.
|
||||
x0_f = x0.float()
|
||||
x1_f = x1.float()
|
||||
y0_f = y0.float()
|
||||
y1_f = y1.float()
|
||||
|
||||
w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)
|
||||
w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)
|
||||
w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)
|
||||
w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)
|
||||
|
||||
output = (
|
||||
w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1
|
||||
)
|
||||
# output is B*N x C
|
||||
output = output.view(B, -1, C)
|
||||
output = output.permute(0, 2, 1)
|
||||
# output is B x C x N
|
||||
|
||||
if return_inbounds:
|
||||
x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
|
||||
y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
|
||||
inbounds = (x_valid & y_valid).float()
|
||||
inbounds = inbounds.reshape(
|
||||
B, N
|
||||
) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
|
||||
return output, inbounds
|
||||
|
||||
return output # B, C, N
|
103
cotracker/models/evaluation_predictor.py
Normal file
103
cotracker/models/evaluation_predictor.py
Normal file
@ -0,0 +1,103 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Tuple
|
||||
|
||||
from cotracker.models.core.cotracker.cotracker import CoTracker, get_points_on_a_grid
|
||||
|
||||
|
||||
class EvaluationPredictor(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cotracker_model: CoTracker,
|
||||
interp_shape: Tuple[int, int] = (384, 512),
|
||||
grid_size: int = 6,
|
||||
local_grid_size: int = 6,
|
||||
single_point: bool = True,
|
||||
n_iters: int = 6,
|
||||
) -> None:
|
||||
super(EvaluationPredictor, self).__init__()
|
||||
self.grid_size = grid_size
|
||||
self.local_grid_size = local_grid_size
|
||||
self.single_point = single_point
|
||||
self.interp_shape = interp_shape
|
||||
self.n_iters = n_iters
|
||||
|
||||
self.model = cotracker_model
|
||||
self.model.to("cuda")
|
||||
self.model.eval()
|
||||
|
||||
def forward(self, video, queries):
|
||||
queries = queries.clone().cuda()
|
||||
B, T, C, H, W = video.shape
|
||||
B, N, D = queries.shape
|
||||
|
||||
assert D == 3
|
||||
assert B == 1
|
||||
|
||||
rgbs = video.reshape(B * T, C, H, W)
|
||||
rgbs = F.interpolate(rgbs, tuple(self.interp_shape), mode="bilinear")
|
||||
rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]).cuda()
|
||||
|
||||
queries[:, :, 1] *= self.interp_shape[1] / W
|
||||
queries[:, :, 2] *= self.interp_shape[0] / H
|
||||
|
||||
if self.single_point:
|
||||
traj_e = torch.zeros((B, T, N, 2)).cuda()
|
||||
vis_e = torch.zeros((B, T, N)).cuda()
|
||||
for pind in range((N)):
|
||||
query = queries[:, pind : pind + 1]
|
||||
|
||||
t = query[0, 0, 0].long()
|
||||
|
||||
traj_e_pind, vis_e_pind = self._process_one_point(rgbs, query)
|
||||
traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1]
|
||||
vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
|
||||
else:
|
||||
if self.grid_size > 0:
|
||||
xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
|
||||
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda() #
|
||||
queries = torch.cat([queries, xy], dim=1) #
|
||||
|
||||
traj_e, __, vis_e, __ = self.model(
|
||||
rgbs=rgbs,
|
||||
queries=queries,
|
||||
iters=self.n_iters,
|
||||
)
|
||||
|
||||
traj_e[:, :, :, 0] *= W / float(self.interp_shape[1])
|
||||
traj_e[:, :, :, 1] *= H / float(self.interp_shape[0])
|
||||
return traj_e, vis_e
|
||||
|
||||
def _process_one_point(self, rgbs, query):
|
||||
t = query[0, 0, 0].long()
|
||||
|
||||
device = rgbs.device
|
||||
if self.local_grid_size > 0:
|
||||
xy_target = get_points_on_a_grid(
|
||||
self.local_grid_size,
|
||||
(50, 50),
|
||||
[query[0, 0, 2], query[0, 0, 1]],
|
||||
)
|
||||
|
||||
xy_target = torch.cat(
|
||||
[torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2
|
||||
) #
|
||||
query = torch.cat([query, xy_target], dim=1).to(device) #
|
||||
|
||||
if self.grid_size > 0:
|
||||
xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
|
||||
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda() #
|
||||
query = torch.cat([query, xy], dim=1).to(device) #
|
||||
# crop the video to start from the queried frame
|
||||
query[0, 0, 0] = 0
|
||||
traj_e_pind, __, vis_e_pind, __ = self.model(
|
||||
rgbs=rgbs[:, t:], queries=query, iters=self.n_iters
|
||||
)
|
||||
|
||||
return traj_e_pind, vis_e_pind
|
178
cotracker/predictor.py
Normal file
178
cotracker/predictor.py
Normal file
@ -0,0 +1,178 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from tqdm import tqdm
|
||||
from cotracker.models.core.cotracker.cotracker import get_points_on_a_grid
|
||||
from cotracker.models.core.model_utils import smart_cat
|
||||
from cotracker.models.build_cotracker import (
|
||||
build_cotracker,
|
||||
)
|
||||
|
||||
|
||||
class CoTrackerPredictor(torch.nn.Module):
|
||||
def __init__(
|
||||
self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth"
|
||||
):
|
||||
super().__init__()
|
||||
self.interp_shape = (384, 512)
|
||||
self.support_grid_size = 6
|
||||
model = build_cotracker(checkpoint)
|
||||
|
||||
self.model = model
|
||||
self.model.to("cuda")
|
||||
self.model.eval()
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
video, # (1, T, 3, H, W)
|
||||
# input prompt types:
|
||||
# - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
|
||||
# *backward_tracking=True* will compute tracks in both directions.
|
||||
# - queries. Queried points of shape (1, N, 3) in format (t, x, y) for frame index and pixel coordinates.
|
||||
# - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask.
|
||||
# You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks.
|
||||
queries: torch.Tensor = None,
|
||||
segm_mask: torch.Tensor = None, # Segmentation mask of shape (B, 1, H, W)
|
||||
grid_size: int = 0,
|
||||
grid_query_frame: int = 0, # only for dense and regular grid tracks
|
||||
backward_tracking: bool = False,
|
||||
):
|
||||
|
||||
if queries is None and grid_size == 0:
|
||||
tracks, visibilities = self._compute_dense_tracks(
|
||||
video,
|
||||
grid_query_frame=grid_query_frame,
|
||||
backward_tracking=backward_tracking,
|
||||
)
|
||||
else:
|
||||
tracks, visibilities = self._compute_sparse_tracks(
|
||||
video,
|
||||
queries,
|
||||
segm_mask,
|
||||
grid_size,
|
||||
add_support_grid=(grid_size == 0 or segm_mask is not None),
|
||||
grid_query_frame=grid_query_frame,
|
||||
backward_tracking=backward_tracking,
|
||||
)
|
||||
|
||||
return tracks, visibilities
|
||||
|
||||
def _compute_dense_tracks(
|
||||
self, video, grid_query_frame, grid_size=50, backward_tracking=False
|
||||
):
|
||||
*_, H, W = video.shape
|
||||
grid_step = W // grid_size
|
||||
grid_width = W // grid_step
|
||||
grid_height = H // grid_step
|
||||
tracks = visibilities = None
|
||||
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to("cuda")
|
||||
grid_pts[0, :, 0] = grid_query_frame
|
||||
for offset in tqdm(range(grid_step * grid_step)):
|
||||
ox = offset % grid_step
|
||||
oy = offset // grid_step
|
||||
grid_pts[0, :, 1] = (
|
||||
torch.arange(grid_width).repeat(grid_height) * grid_step + ox
|
||||
)
|
||||
grid_pts[0, :, 2] = (
|
||||
torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
|
||||
)
|
||||
tracks_step, visibilities_step = self._compute_sparse_tracks(
|
||||
video=video,
|
||||
queries=grid_pts,
|
||||
backward_tracking=backward_tracking,
|
||||
)
|
||||
tracks = smart_cat(tracks, tracks_step, dim=2)
|
||||
visibilities = smart_cat(visibilities, visibilities_step, dim=2)
|
||||
|
||||
return tracks, visibilities
|
||||
|
||||
def _compute_sparse_tracks(
|
||||
self,
|
||||
video,
|
||||
queries,
|
||||
segm_mask=None,
|
||||
grid_size=0,
|
||||
add_support_grid=False,
|
||||
grid_query_frame=0,
|
||||
backward_tracking=False,
|
||||
):
|
||||
B, T, C, H, W = video.shape
|
||||
assert B == 1
|
||||
|
||||
video = video.reshape(B * T, C, H, W)
|
||||
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").cuda()
|
||||
video = video.reshape(
|
||||
B, T, 3, self.interp_shape[0], self.interp_shape[1]
|
||||
).cuda()
|
||||
|
||||
if queries is not None:
|
||||
queries = queries.clone()
|
||||
B, N, D = queries.shape
|
||||
assert D == 3
|
||||
queries[:, :, 1] *= self.interp_shape[1] / W
|
||||
queries[:, :, 2] *= self.interp_shape[0] / H
|
||||
elif grid_size > 0:
|
||||
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape)
|
||||
if segm_mask is not None:
|
||||
segm_mask = F.interpolate(
|
||||
segm_mask, tuple(self.interp_shape), mode="nearest"
|
||||
)
|
||||
point_mask = segm_mask[0, 0][
|
||||
(grid_pts[0, :, 1]).round().long().cpu(),
|
||||
(grid_pts[0, :, 0]).round().long().cpu(),
|
||||
].bool()
|
||||
grid_pts = grid_pts[:, point_mask]
|
||||
|
||||
queries = torch.cat(
|
||||
[torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
|
||||
dim=2,
|
||||
)
|
||||
|
||||
if add_support_grid:
|
||||
grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape)
|
||||
grid_pts = torch.cat(
|
||||
[torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
|
||||
)
|
||||
queries = torch.cat([queries, grid_pts], dim=1)
|
||||
|
||||
tracks, __, visibilities, __ = self.model(rgbs=video, queries=queries, iters=6)
|
||||
|
||||
if backward_tracking:
|
||||
tracks, visibilities = self._compute_backward_tracks(
|
||||
video, queries, tracks, visibilities
|
||||
)
|
||||
if add_support_grid:
|
||||
queries[:, -self.support_grid_size ** 2 :, 0] = T - 1
|
||||
if add_support_grid:
|
||||
tracks = tracks[:, :, : -self.support_grid_size ** 2]
|
||||
visibilities = visibilities[:, :, : -self.support_grid_size ** 2]
|
||||
thr = 0.9
|
||||
visibilities = visibilities > thr
|
||||
tracks[:, :, :, 0] *= W / float(self.interp_shape[1])
|
||||
tracks[:, :, :, 1] *= H / float(self.interp_shape[0])
|
||||
return tracks, visibilities
|
||||
|
||||
def _compute_backward_tracks(self, video, queries, tracks, visibilities):
|
||||
inv_video = video.flip(1).clone()
|
||||
inv_queries = queries.clone()
|
||||
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
|
||||
|
||||
inv_tracks, __, inv_visibilities, __ = self.model(
|
||||
rgbs=inv_video, queries=inv_queries, iters=6
|
||||
)
|
||||
|
||||
inv_tracks = inv_tracks.flip(1)
|
||||
inv_visibilities = inv_visibilities.flip(1)
|
||||
|
||||
mask = tracks == 0
|
||||
|
||||
tracks[mask] = inv_tracks[mask]
|
||||
visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
|
||||
return tracks, visibilities
|
5
cotracker/utils/__init__.py
Normal file
5
cotracker/utils/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
291
cotracker/utils/visualizer.py
Normal file
291
cotracker/utils/visualizer.py
Normal file
@ -0,0 +1,291 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
import flow_vis
|
||||
|
||||
from matplotlib import cm
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms as transforms
|
||||
from moviepy.editor import ImageSequenceClip
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
class Visualizer:
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: str = "./results",
|
||||
grayscale: bool = False,
|
||||
pad_value: int = 0,
|
||||
fps: int = 10,
|
||||
mode: str = "rainbow", # 'cool', 'optical_flow'
|
||||
linewidth: int = 2,
|
||||
show_first_frame: int = 10,
|
||||
tracks_leave_trace: int = 0, # -1 for infinite
|
||||
):
|
||||
self.mode = mode
|
||||
self.save_dir = save_dir
|
||||
if mode == "rainbow":
|
||||
self.color_map = cm.get_cmap("gist_rainbow")
|
||||
elif mode == "cool":
|
||||
self.color_map = cm.get_cmap(mode)
|
||||
self.show_first_frame = show_first_frame
|
||||
self.grayscale = grayscale
|
||||
self.tracks_leave_trace = tracks_leave_trace
|
||||
self.pad_value = pad_value
|
||||
self.linewidth = linewidth
|
||||
self.fps = fps
|
||||
|
||||
def visualize(
|
||||
self,
|
||||
video: torch.Tensor, # (B,T,C,H,W)
|
||||
tracks: torch.Tensor, # (B,T,N,2)
|
||||
gt_tracks: torch.Tensor = None, # (B,T,N,2)
|
||||
segm_mask: torch.Tensor = None, # (B,1,H,W)
|
||||
filename: str = "video",
|
||||
writer: SummaryWriter = None,
|
||||
step: int = 0,
|
||||
query_frame: int = 0,
|
||||
save_video: bool = True,
|
||||
compensate_for_camera_motion: bool = False,
|
||||
):
|
||||
if compensate_for_camera_motion:
|
||||
assert segm_mask is not None
|
||||
if segm_mask is not None:
|
||||
coords = tracks[0, query_frame].round().long()
|
||||
segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()
|
||||
|
||||
video = F.pad(
|
||||
video,
|
||||
(self.pad_value, self.pad_value, self.pad_value, self.pad_value),
|
||||
"constant",
|
||||
255,
|
||||
)
|
||||
tracks = tracks + self.pad_value
|
||||
|
||||
if self.grayscale:
|
||||
transform = transforms.Grayscale()
|
||||
video = transform(video)
|
||||
video = video.repeat(1, 1, 3, 1, 1)
|
||||
|
||||
res_video = self.draw_tracks_on_video(
|
||||
video=video,
|
||||
tracks=tracks,
|
||||
segm_mask=segm_mask,
|
||||
gt_tracks=gt_tracks,
|
||||
query_frame=query_frame,
|
||||
compensate_for_camera_motion=compensate_for_camera_motion,
|
||||
)
|
||||
if save_video:
|
||||
self.save_video(res_video, filename=filename, writer=writer, step=step)
|
||||
return res_video
|
||||
|
||||
def save_video(self, video, filename, writer=None, step=0):
|
||||
if writer is not None:
|
||||
writer.add_video(
|
||||
f"{filename}_pred_track",
|
||||
video.to(torch.uint8),
|
||||
global_step=step,
|
||||
fps=self.fps,
|
||||
)
|
||||
else:
|
||||
os.makedirs(self.save_dir, exist_ok=True)
|
||||
wide_list = list(video.unbind(1))
|
||||
wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
|
||||
clip = ImageSequenceClip(wide_list[2:-1], fps=self.fps)
|
||||
|
||||
# Write the video file
|
||||
save_path = os.path.join(self.save_dir, f"{filename}_pred_track.mp4")
|
||||
clip.write_videofile(save_path, codec="libx264", fps=self.fps, logger=None)
|
||||
|
||||
print(f"Video saved to {save_path}")
|
||||
|
||||
def draw_tracks_on_video(
|
||||
self,
|
||||
video: torch.Tensor,
|
||||
tracks: torch.Tensor,
|
||||
segm_mask: torch.Tensor = None,
|
||||
gt_tracks=None,
|
||||
query_frame: int = 0,
|
||||
compensate_for_camera_motion=False,
|
||||
):
|
||||
B, T, C, H, W = video.shape
|
||||
_, _, N, D = tracks.shape
|
||||
|
||||
assert D == 2
|
||||
assert C == 3
|
||||
video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C
|
||||
tracks = tracks[0].long().detach().cpu().numpy() # S, N, 2
|
||||
if gt_tracks is not None:
|
||||
gt_tracks = gt_tracks[0].detach().cpu().numpy()
|
||||
|
||||
res_video = []
|
||||
|
||||
# process input video
|
||||
for rgb in video:
|
||||
res_video.append(rgb.copy())
|
||||
|
||||
vector_colors = np.zeros((T, N, 3))
|
||||
if self.mode == "optical_flow":
|
||||
vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
|
||||
elif segm_mask is None:
|
||||
if self.mode == "rainbow":
|
||||
y_min, y_max = (
|
||||
tracks[query_frame, :, 1].min(),
|
||||
tracks[query_frame, :, 1].max(),
|
||||
)
|
||||
norm = plt.Normalize(y_min, y_max)
|
||||
for n in range(N):
|
||||
color = self.color_map(norm(tracks[query_frame, n, 1]))
|
||||
color = np.array(color[:3])[None] * 255
|
||||
vector_colors[:, n] = np.repeat(color, T, axis=0)
|
||||
else:
|
||||
# color changes with time
|
||||
for t in range(T):
|
||||
color = np.array(self.color_map(t / T)[:3])[None] * 255
|
||||
vector_colors[t] = np.repeat(color, N, axis=0)
|
||||
else:
|
||||
if self.mode == "rainbow":
|
||||
vector_colors[:, segm_mask <= 0, :] = 255
|
||||
|
||||
y_min, y_max = (
|
||||
tracks[0, segm_mask > 0, 1].min(),
|
||||
tracks[0, segm_mask > 0, 1].max(),
|
||||
)
|
||||
norm = plt.Normalize(y_min, y_max)
|
||||
for n in range(N):
|
||||
if segm_mask[n] > 0:
|
||||
color = self.color_map(norm(tracks[0, n, 1]))
|
||||
color = np.array(color[:3])[None] * 255
|
||||
vector_colors[:, n] = np.repeat(color, T, axis=0)
|
||||
|
||||
else:
|
||||
# color changes with segm class
|
||||
segm_mask = segm_mask.cpu()
|
||||
color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
|
||||
color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
|
||||
color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
|
||||
vector_colors = np.repeat(color[None], T, axis=0)
|
||||
|
||||
# draw tracks
|
||||
if self.tracks_leave_trace != 0:
|
||||
for t in range(1, T):
|
||||
first_ind = (
|
||||
max(0, t - self.tracks_leave_trace)
|
||||
if self.tracks_leave_trace >= 0
|
||||
else 0
|
||||
)
|
||||
curr_tracks = tracks[first_ind : t + 1]
|
||||
curr_colors = vector_colors[first_ind : t + 1]
|
||||
if compensate_for_camera_motion:
|
||||
diff = (
|
||||
tracks[first_ind : t + 1, segm_mask <= 0]
|
||||
- tracks[t : t + 1, segm_mask <= 0]
|
||||
).mean(1)[:, None]
|
||||
|
||||
curr_tracks = curr_tracks - diff
|
||||
curr_tracks = curr_tracks[:, segm_mask > 0]
|
||||
curr_colors = curr_colors[:, segm_mask > 0]
|
||||
|
||||
res_video[t] = self._draw_pred_tracks(
|
||||
res_video[t],
|
||||
curr_tracks,
|
||||
curr_colors,
|
||||
)
|
||||
if gt_tracks is not None:
|
||||
res_video[t] = self._draw_gt_tracks(
|
||||
res_video[t], gt_tracks[first_ind : t + 1]
|
||||
)
|
||||
|
||||
# draw points
|
||||
for t in range(T):
|
||||
for i in range(N):
|
||||
coord = (tracks[t, i, 0], tracks[t, i, 1])
|
||||
if coord[0] != 0 and coord[1] != 0:
|
||||
if not compensate_for_camera_motion or (
|
||||
compensate_for_camera_motion and segm_mask[i] > 0
|
||||
):
|
||||
cv2.circle(
|
||||
res_video[t],
|
||||
coord,
|
||||
int(self.linewidth * 2),
|
||||
vector_colors[t, i].tolist(),
|
||||
-1,
|
||||
)
|
||||
|
||||
# construct the final rgb sequence
|
||||
if self.show_first_frame > 0:
|
||||
res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
|
||||
return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()
|
||||
|
||||
def _draw_pred_tracks(
|
||||
self,
|
||||
rgb: np.ndarray, # H x W x 3
|
||||
tracks: np.ndarray, # T x 2
|
||||
vector_colors: np.ndarray,
|
||||
alpha: float = 0.5,
|
||||
):
|
||||
T, N, _ = tracks.shape
|
||||
|
||||
for s in range(T - 1):
|
||||
vector_color = vector_colors[s]
|
||||
original = rgb.copy()
|
||||
alpha = (s / T) ** 2
|
||||
for i in range(N):
|
||||
coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
|
||||
coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
|
||||
if coord_y[0] != 0 and coord_y[1] != 0:
|
||||
cv2.line(
|
||||
rgb,
|
||||
coord_y,
|
||||
coord_x,
|
||||
vector_color[i].tolist(),
|
||||
self.linewidth,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
if self.tracks_leave_trace > 0:
|
||||
rgb = cv2.addWeighted(rgb, alpha, original, 1 - alpha, 0)
|
||||
return rgb
|
||||
|
||||
def _draw_gt_tracks(
|
||||
self,
|
||||
rgb: np.ndarray, # H x W x 3,
|
||||
gt_tracks: np.ndarray, # T x 2
|
||||
):
|
||||
T, N, _ = gt_tracks.shape
|
||||
color = np.array((211.0, 0.0, 0.0))
|
||||
|
||||
for t in range(T):
|
||||
for i in range(N):
|
||||
gt_tracks = gt_tracks[t][i]
|
||||
# draw a red cross
|
||||
if gt_tracks[0] > 0 and gt_tracks[1] > 0:
|
||||
length = self.linewidth * 3
|
||||
coord_y = (int(gt_tracks[0]) + length, int(gt_tracks[1]) + length)
|
||||
coord_x = (int(gt_tracks[0]) - length, int(gt_tracks[1]) - length)
|
||||
cv2.line(
|
||||
rgb,
|
||||
coord_y,
|
||||
coord_x,
|
||||
color,
|
||||
self.linewidth,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
coord_y = (int(gt_tracks[0]) - length, int(gt_tracks[1]) + length)
|
||||
coord_x = (int(gt_tracks[0]) + length, int(gt_tracks[1]) - length)
|
||||
cv2.line(
|
||||
rgb,
|
||||
coord_y,
|
||||
coord_x,
|
||||
color,
|
||||
self.linewidth,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
return rgb
|
71
demo.py
Normal file
71
demo.py
Normal file
@ -0,0 +1,71 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import torch
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from torchvision.io import read_video
|
||||
from PIL import Image
|
||||
from cotracker.utils.visualizer import Visualizer
|
||||
from cotracker.predictor import CoTrackerPredictor
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--video_path",
|
||||
default="./assets/apple.mp4",
|
||||
help="path to a video",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mask_path",
|
||||
default="./assets/apple_mask.png",
|
||||
help="path to a segmentation mask",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
default="./checkpoints/cotracker_stride_4_wind_8.pth",
|
||||
help="cotracker model",
|
||||
)
|
||||
parser.add_argument("--grid_size", type=int, default=0, help="Regular grid size")
|
||||
parser.add_argument(
|
||||
"--grid_query_frame",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Compute dense and grid tracks starting from this frame ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--backward_tracking",
|
||||
action="store_true",
|
||||
help="Compute tracks in both directions, not only forward",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# load the input video frame by frame
|
||||
video = read_video(args.video_path)
|
||||
video = video[0].permute(0, 3, 1, 2)[None].float()
|
||||
segm_mask = np.array(Image.open(os.path.join(args.mask_path)))
|
||||
segm_mask = torch.from_numpy(segm_mask)[None, None]
|
||||
|
||||
model = CoTrackerPredictor(checkpoint=args.checkpoint)
|
||||
|
||||
pred_tracks, pred_visibility = model(
|
||||
video,
|
||||
grid_size=args.grid_size,
|
||||
grid_query_frame=args.grid_query_frame,
|
||||
backward_tracking=args.backward_tracking,
|
||||
# segm_mask=segm_mask
|
||||
)
|
||||
print("computed")
|
||||
|
||||
# save a video with predicted tracks
|
||||
seq_name = args.video_path.split("/")[-1]
|
||||
vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
|
||||
vis.visualize(video, pred_tracks, query_frame=args.grid_query_frame)
|
924
notebooks/demo.ipynb
Normal file
924
notebooks/demo.ipynb
Normal file
File diff suppressed because one or more lines are too long
18
setup.py
Normal file
18
setup.py
Normal file
@ -0,0 +1,18 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
setup(
|
||||
name="cotracker",
|
||||
version="1.0",
|
||||
install_requires=[],
|
||||
packages=find_packages(exclude="notebooks"),
|
||||
extras_require={
|
||||
"all": ["matplotlib", "opencv-python"],
|
||||
"dev": ["flake8", "black"],
|
||||
},
|
||||
)
|
665
train.py
Normal file
665
train.py
Normal file
@ -0,0 +1,665 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import random
|
||||
import torch
|
||||
import signal
|
||||
import socket
|
||||
import sys
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.cuda.amp import GradScaler
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from pytorch_lightning.lite import LightningLite
|
||||
|
||||
from cotracker.models.evaluation_predictor import EvaluationPredictor
|
||||
from cotracker.models.core.cotracker.cotracker import CoTracker
|
||||
from cotracker.utils.visualizer import Visualizer
|
||||
from cotracker.datasets.tap_vid_datasets import TapVidDataset
|
||||
from cotracker.datasets.badja_dataset import BadjaDataset
|
||||
from cotracker.datasets.fast_capture_dataset import FastCaptureDataset
|
||||
from cotracker.evaluation.core.evaluator import Evaluator
|
||||
from cotracker.datasets import kubric_movif_dataset
|
||||
from cotracker.datasets.utils import collate_fn, collate_fn_train, dataclass_to_cuda_
|
||||
from cotracker.models.core.cotracker.losses import sequence_loss, balanced_ce_loss
|
||||
|
||||
|
||||
# define the handler function
|
||||
# for training on a slurm cluster
|
||||
def sig_handler(signum, frame):
|
||||
print("caught signal", signum)
|
||||
print(socket.gethostname(), "USR1 signal caught.")
|
||||
# do other stuff to cleanup here
|
||||
print("requeuing job " + os.environ["SLURM_JOB_ID"])
|
||||
os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"])
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
def term_handler(signum, frame):
|
||||
print("bypassing sigterm", flush=True)
|
||||
|
||||
|
||||
def fetch_optimizer(args, model):
|
||||
"""Create the optimizer and learning rate scheduler"""
|
||||
optimizer = optim.AdamW(
|
||||
model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8
|
||||
)
|
||||
scheduler = optim.lr_scheduler.OneCycleLR(
|
||||
optimizer,
|
||||
args.lr,
|
||||
args.num_steps + 100,
|
||||
pct_start=0.05,
|
||||
cycle_momentum=False,
|
||||
anneal_strategy="linear",
|
||||
)
|
||||
|
||||
return optimizer, scheduler
|
||||
|
||||
|
||||
def forward_batch(batch, model, args, loss_fn=None, writer=None, step=0):
|
||||
rgbs = batch.video
|
||||
trajs_g = batch.trajectory
|
||||
vis_g = batch.visibility
|
||||
valids = batch.valid
|
||||
B, T, C, H, W = rgbs.shape
|
||||
assert C == 3
|
||||
B, T, N, D = trajs_g.shape
|
||||
device = rgbs.device
|
||||
|
||||
__, first_positive_inds = torch.max(vis_g, dim=1)
|
||||
# We want to make sure that during training the model sees visible points
|
||||
# that it does not need to track just yet: they are visible but queried from a later frame
|
||||
N_rand = N // 4
|
||||
# inds of visible points in the 1st frame
|
||||
nonzero_inds = [torch.nonzero(vis_g[0, :, i]) for i in range(N)]
|
||||
rand_vis_inds = torch.cat(
|
||||
[
|
||||
nonzero_row[torch.randint(len(nonzero_row), size=(1,))]
|
||||
for nonzero_row in nonzero_inds
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
first_positive_inds = torch.cat(
|
||||
[rand_vis_inds[:, :N_rand], first_positive_inds[:, N_rand:]], dim=1
|
||||
)
|
||||
ind_array_ = torch.arange(T, device=device)
|
||||
ind_array_ = ind_array_[None, :, None].repeat(B, 1, N)
|
||||
assert torch.allclose(
|
||||
vis_g[ind_array_ == first_positive_inds[:, None, :]],
|
||||
torch.ones_like(vis_g),
|
||||
)
|
||||
assert torch.allclose(
|
||||
vis_g[ind_array_ == rand_vis_inds[:, None, :]], torch.ones_like(vis_g)
|
||||
)
|
||||
|
||||
gather = torch.gather(
|
||||
trajs_g, 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, 2)
|
||||
)
|
||||
xys = torch.diagonal(gather, dim1=1, dim2=2).permute(0, 2, 1)
|
||||
|
||||
queries = torch.cat([first_positive_inds[:, :, None], xys], dim=2)
|
||||
|
||||
predictions, __, visibility, train_data = model(
|
||||
rgbs=rgbs, queries=queries, iters=args.train_iters, is_train=True
|
||||
)
|
||||
vis_predictions, coord_predictions, wind_inds, sort_inds = train_data
|
||||
|
||||
trajs_g = trajs_g[:, :, sort_inds]
|
||||
vis_g = vis_g[:, :, sort_inds]
|
||||
valids = valids[:, :, sort_inds]
|
||||
|
||||
vis_gts = []
|
||||
traj_gts = []
|
||||
valids_gts = []
|
||||
|
||||
for i, wind_idx in enumerate(wind_inds):
|
||||
ind = i * (args.sliding_window_len // 2)
|
||||
|
||||
vis_gts.append(vis_g[:, ind : ind + args.sliding_window_len, :wind_idx])
|
||||
traj_gts.append(trajs_g[:, ind : ind + args.sliding_window_len, :wind_idx])
|
||||
valids_gts.append(valids[:, ind : ind + args.sliding_window_len, :wind_idx])
|
||||
|
||||
seq_loss = sequence_loss(coord_predictions, traj_gts, vis_gts, valids_gts, 0.8)
|
||||
vis_loss = balanced_ce_loss(vis_predictions, vis_gts, valids_gts)
|
||||
|
||||
output = {"flow": {"predictions": predictions[0].detach()}}
|
||||
output["flow"]["loss"] = seq_loss.mean()
|
||||
output["visibility"] = {
|
||||
"loss": vis_loss.mean() * 10.0,
|
||||
"predictions": visibility[0].detach(),
|
||||
}
|
||||
return output
|
||||
|
||||
|
||||
def run_test_eval(evaluator, model, dataloaders, writer, step):
|
||||
model.eval()
|
||||
for ds_name, dataloader in dataloaders:
|
||||
predictor = EvaluationPredictor(
|
||||
model.module.module,
|
||||
grid_size=6,
|
||||
local_grid_size=0,
|
||||
single_point=False,
|
||||
n_iters=6,
|
||||
)
|
||||
|
||||
metrics = evaluator.evaluate_sequence(
|
||||
model=predictor,
|
||||
test_dataloader=dataloader,
|
||||
dataset_name=ds_name,
|
||||
train_mode=True,
|
||||
writer=writer,
|
||||
step=step,
|
||||
)
|
||||
|
||||
if ds_name == "badja" or ds_name == "fastcapture" or ("kubric" in ds_name):
|
||||
|
||||
metrics = {
|
||||
**{
|
||||
f"{ds_name}_avg": np.mean(
|
||||
[v for k, v in metrics.items() if "accuracy" not in k]
|
||||
)
|
||||
},
|
||||
**{
|
||||
f"{ds_name}_avg_accuracy": np.mean(
|
||||
[v for k, v in metrics.items() if "accuracy" in k]
|
||||
)
|
||||
},
|
||||
}
|
||||
print("avg", np.mean([v for v in metrics.values()]))
|
||||
|
||||
if "tapvid" in ds_name:
|
||||
metrics = {
|
||||
f"{ds_name}_avg_OA": metrics["avg"]["occlusion_accuracy"] * 100,
|
||||
f"{ds_name}_avg_delta": metrics["avg"]["average_pts_within_thresh"]
|
||||
* 100,
|
||||
f"{ds_name}_avg_Jaccard": metrics["avg"]["average_jaccard"] * 100,
|
||||
}
|
||||
|
||||
writer.add_scalars(f"Eval", metrics, step)
|
||||
|
||||
|
||||
class Logger:
|
||||
|
||||
SUM_FREQ = 100
|
||||
|
||||
def __init__(self, model, scheduler):
|
||||
self.model = model
|
||||
self.scheduler = scheduler
|
||||
self.total_steps = 0
|
||||
self.running_loss = {}
|
||||
self.writer = SummaryWriter(log_dir=os.path.join(args.ckpt_path, "runs"))
|
||||
|
||||
def _print_training_status(self):
|
||||
metrics_data = [
|
||||
self.running_loss[k] / Logger.SUM_FREQ
|
||||
for k in sorted(self.running_loss.keys())
|
||||
]
|
||||
training_str = "[{:6d}] ".format(self.total_steps + 1)
|
||||
metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data)
|
||||
|
||||
# print the training status
|
||||
logging.info(
|
||||
f"Training Metrics ({self.total_steps}): {training_str + metrics_str}"
|
||||
)
|
||||
|
||||
if self.writer is None:
|
||||
self.writer = SummaryWriter(log_dir=os.path.join(args.ckpt_path, "runs"))
|
||||
|
||||
for k in self.running_loss:
|
||||
self.writer.add_scalar(
|
||||
k, self.running_loss[k] / Logger.SUM_FREQ, self.total_steps
|
||||
)
|
||||
self.running_loss[k] = 0.0
|
||||
|
||||
def push(self, metrics, task):
|
||||
self.total_steps += 1
|
||||
|
||||
for key in metrics:
|
||||
task_key = str(key) + "_" + task
|
||||
if task_key not in self.running_loss:
|
||||
self.running_loss[task_key] = 0.0
|
||||
|
||||
self.running_loss[task_key] += metrics[key]
|
||||
|
||||
if self.total_steps % Logger.SUM_FREQ == Logger.SUM_FREQ - 1:
|
||||
self._print_training_status()
|
||||
self.running_loss = {}
|
||||
|
||||
def write_dict(self, results):
|
||||
if self.writer is None:
|
||||
self.writer = SummaryWriter(log_dir=os.path.join(args.ckpt_path, "runs"))
|
||||
|
||||
for key in results:
|
||||
self.writer.add_scalar(key, results[key], self.total_steps)
|
||||
|
||||
def close(self):
|
||||
self.writer.close()
|
||||
|
||||
|
||||
class Lite(LightningLite):
|
||||
def run(self, args):
|
||||
def seed_everything(seed: int):
|
||||
random.seed(seed)
|
||||
os.environ["PYTHONHASHSEED"] = str(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
seed_everything(0)
|
||||
|
||||
def seed_worker(worker_id):
|
||||
worker_seed = torch.initial_seed() % 2 ** 32
|
||||
np.random.seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
g = torch.Generator()
|
||||
g.manual_seed(0)
|
||||
|
||||
eval_dataloaders = []
|
||||
if "badja" in args.eval_datasets:
|
||||
eval_dataset = BadjaDataset(
|
||||
data_root=os.path.join(args.dataset_root, "BADJA"),
|
||||
max_seq_len=args.eval_max_seq_len,
|
||||
dataset_resolution=args.crop_size,
|
||||
)
|
||||
eval_dataloader_badja = torch.utils.data.DataLoader(
|
||||
eval_dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=8,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
eval_dataloaders.append(("badja", eval_dataloader_badja))
|
||||
|
||||
if "fastcapture" in args.eval_datasets:
|
||||
eval_dataset = FastCaptureDataset(
|
||||
data_root=os.path.join(args.dataset_root, "fastcapture"),
|
||||
max_seq_len=min(100, args.eval_max_seq_len),
|
||||
max_num_points=40,
|
||||
dataset_resolution=args.crop_size,
|
||||
)
|
||||
eval_dataloader_fastcapture = torch.utils.data.DataLoader(
|
||||
eval_dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=1,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
eval_dataloaders.append(("fastcapture", eval_dataloader_fastcapture))
|
||||
|
||||
if "tapvid_davis_first" in args.eval_datasets:
|
||||
data_root = os.path.join(
|
||||
args.dataset_root, "/tapvid_davis/tapvid_davis.pkl"
|
||||
)
|
||||
eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root)
|
||||
eval_dataloader_tapvid_davis = torch.utils.data.DataLoader(
|
||||
eval_dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=1,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
eval_dataloaders.append(("tapvid_davis", eval_dataloader_tapvid_davis))
|
||||
|
||||
evaluator = Evaluator(args.ckpt_path)
|
||||
|
||||
visualizer = Visualizer(
|
||||
save_dir=args.ckpt_path,
|
||||
pad_value=80,
|
||||
fps=1,
|
||||
show_first_frame=0,
|
||||
tracks_leave_trace=0,
|
||||
)
|
||||
|
||||
loss_fn = None
|
||||
|
||||
if args.model_name == "cotracker":
|
||||
|
||||
model = CoTracker(
|
||||
stride=args.model_stride,
|
||||
S=args.sliding_window_len,
|
||||
add_space_attn=not args.remove_space_attn,
|
||||
num_heads=args.updateformer_num_heads,
|
||||
hidden_size=args.updateformer_hidden_size,
|
||||
space_depth=args.updateformer_space_depth,
|
||||
time_depth=args.updateformer_time_depth,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Model {args.model_name} doesn't exist")
|
||||
|
||||
with open(args.ckpt_path + "/meta.json", "w") as file:
|
||||
json.dump(vars(args), file, sort_keys=True, indent=4)
|
||||
|
||||
model.cuda()
|
||||
|
||||
train_dataset = kubric_movif_dataset.KubricMovifDataset(
|
||||
data_root=os.path.join(args.dataset_root, "kubric_movi_f"),
|
||||
crop_size=args.crop_size,
|
||||
seq_len=args.sequence_len,
|
||||
traj_per_sample=args.traj_per_sample,
|
||||
sample_vis_1st_frame=args.sample_vis_1st_frame,
|
||||
use_augs=not args.dont_use_augs,
|
||||
)
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=args.num_workers,
|
||||
worker_init_fn=seed_worker,
|
||||
generator=g,
|
||||
pin_memory=True,
|
||||
collate_fn=collate_fn_train,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
train_loader = self.setup_dataloaders(train_loader, move_to_device=False)
|
||||
print("LEN TRAIN LOADER", len(train_loader))
|
||||
optimizer, scheduler = fetch_optimizer(args, model)
|
||||
|
||||
total_steps = 0
|
||||
logger = Logger(model, scheduler)
|
||||
|
||||
folder_ckpts = [
|
||||
f
|
||||
for f in os.listdir(args.ckpt_path)
|
||||
if not os.path.isdir(f) and f.endswith(".pth") and not "final" in f
|
||||
]
|
||||
if len(folder_ckpts) > 0:
|
||||
ckpt_path = sorted(folder_ckpts)[-1]
|
||||
ckpt = self.load(os.path.join(args.ckpt_path, ckpt_path))
|
||||
logging.info(f"Loading checkpoint {ckpt_path}")
|
||||
if "model" in ckpt:
|
||||
model.load_state_dict(ckpt["model"])
|
||||
else:
|
||||
model.load_state_dict(ckpt)
|
||||
if "optimizer" in ckpt:
|
||||
logging.info("Load optimizer")
|
||||
optimizer.load_state_dict(ckpt["optimizer"])
|
||||
if "scheduler" in ckpt:
|
||||
logging.info("Load scheduler")
|
||||
scheduler.load_state_dict(ckpt["scheduler"])
|
||||
if "total_steps" in ckpt:
|
||||
total_steps = ckpt["total_steps"]
|
||||
logging.info(f"Load total_steps {total_steps}")
|
||||
|
||||
elif args.restore_ckpt is not None:
|
||||
assert args.restore_ckpt.endswith(".pth") or args.restore_ckpt.endswith(
|
||||
".pt"
|
||||
)
|
||||
logging.info("Loading checkpoint...")
|
||||
|
||||
strict = True
|
||||
state_dict = self.load(args.restore_ckpt)
|
||||
if "model" in state_dict:
|
||||
state_dict = state_dict["model"]
|
||||
|
||||
if list(state_dict.keys())[0].startswith("module."):
|
||||
state_dict = {
|
||||
k.replace("module.", ""): v for k, v in state_dict.items()
|
||||
}
|
||||
model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
logging.info(f"Done loading checkpoint")
|
||||
model, optimizer = self.setup(model, optimizer, move_to_device=False)
|
||||
# model.cuda()
|
||||
model.train()
|
||||
|
||||
save_freq = args.save_freq
|
||||
scaler = GradScaler(enabled=args.mixed_precision)
|
||||
|
||||
should_keep_training = True
|
||||
global_batch_num = 0
|
||||
epoch = -1
|
||||
|
||||
while should_keep_training:
|
||||
epoch += 1
|
||||
for i_batch, batch in enumerate(tqdm(train_loader)):
|
||||
batch, gotit = batch
|
||||
if not all(gotit):
|
||||
print("batch is None")
|
||||
continue
|
||||
dataclass_to_cuda_(batch)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
assert model.training
|
||||
|
||||
output = forward_batch(
|
||||
batch,
|
||||
model,
|
||||
args,
|
||||
loss_fn=loss_fn,
|
||||
writer=logger.writer,
|
||||
step=total_steps,
|
||||
)
|
||||
|
||||
loss = 0
|
||||
for k, v in output.items():
|
||||
if "loss" in v:
|
||||
loss += v["loss"]
|
||||
logger.writer.add_scalar(
|
||||
f"live_{k}_loss", v["loss"].item(), total_steps
|
||||
)
|
||||
if "metrics" in v:
|
||||
logger.push(v["metrics"], k)
|
||||
|
||||
if self.global_rank == 0:
|
||||
if total_steps % save_freq == save_freq - 1:
|
||||
if args.model_name == "motion_diffuser":
|
||||
pred_coords = model.module.module.forward_batch_test(
|
||||
batch, interp_shape=args.crop_size
|
||||
)
|
||||
|
||||
output["flow"] = {"predictions": pred_coords[0].detach()}
|
||||
visualizer.visualize(
|
||||
video=batch.video.clone(),
|
||||
tracks=batch.trajectory.clone(),
|
||||
filename="train_gt_traj",
|
||||
writer=logger.writer,
|
||||
step=total_steps,
|
||||
)
|
||||
|
||||
visualizer.visualize(
|
||||
video=batch.video.clone(),
|
||||
tracks=output["flow"]["predictions"][None],
|
||||
filename="train_pred_traj",
|
||||
writer=logger.writer,
|
||||
step=total_steps,
|
||||
)
|
||||
|
||||
if len(output) > 1:
|
||||
logger.writer.add_scalar(
|
||||
f"live_total_loss", loss.item(), total_steps
|
||||
)
|
||||
logger.writer.add_scalar(
|
||||
f"learning_rate", optimizer.param_groups[0]["lr"], total_steps
|
||||
)
|
||||
global_batch_num += 1
|
||||
|
||||
self.barrier()
|
||||
|
||||
self.backward(scaler.scale(loss))
|
||||
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0)
|
||||
|
||||
scaler.step(optimizer)
|
||||
scheduler.step()
|
||||
scaler.update()
|
||||
total_steps += 1
|
||||
if self.global_rank == 0:
|
||||
if (i_batch >= len(train_loader) - 1) or (
|
||||
total_steps == 1 and args.validate_at_start
|
||||
):
|
||||
if (epoch + 1) % args.save_every_n_epoch == 0:
|
||||
ckpt_iter = "0" * (6 - len(str(total_steps))) + str(
|
||||
total_steps
|
||||
)
|
||||
save_path = Path(
|
||||
f"{args.ckpt_path}/model_{args.model_name}_{ckpt_iter}.pth"
|
||||
)
|
||||
|
||||
save_dict = {
|
||||
"model": model.module.module.state_dict(),
|
||||
"optimizer": optimizer.state_dict(),
|
||||
"scheduler": scheduler.state_dict(),
|
||||
"total_steps": total_steps,
|
||||
}
|
||||
|
||||
logging.info(f"Saving file {save_path}")
|
||||
self.save(save_dict, save_path)
|
||||
|
||||
if (epoch + 1) % args.evaluate_every_n_epoch == 0 or (
|
||||
args.validate_at_start and epoch == 0
|
||||
):
|
||||
run_test_eval(
|
||||
evaluator,
|
||||
model,
|
||||
eval_dataloaders,
|
||||
logger.writer,
|
||||
total_steps,
|
||||
)
|
||||
model.train()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self.barrier()
|
||||
if total_steps > args.num_steps:
|
||||
should_keep_training = False
|
||||
break
|
||||
|
||||
print("FINISHED TRAINING")
|
||||
|
||||
PATH = f"{args.ckpt_path}/{args.model_name}_final.pth"
|
||||
torch.save(model.module.module.state_dict(), PATH)
|
||||
run_test_eval(evaluator, model, eval_dataloaders, logger.writer, total_steps)
|
||||
logger.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
signal.signal(signal.SIGUSR1, sig_handler)
|
||||
signal.signal(signal.SIGTERM, term_handler)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_name", default="cotracker", help="model name")
|
||||
parser.add_argument("--restore_ckpt", help="restore checkpoint")
|
||||
parser.add_argument("--ckpt_path", help="restore checkpoint")
|
||||
parser.add_argument(
|
||||
"--batch_size", type=int, default=4, help="batch size used during training."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_workers", type=int, default=6, help="left right consistency loss"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mixed_precision", action="store_true", help="use mixed precision"
|
||||
)
|
||||
parser.add_argument("--lr", type=float, default=0.0005, help="max learning rate.")
|
||||
parser.add_argument(
|
||||
"--wdecay", type=float, default=0.00001, help="Weight decay in optimizer."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_steps", type=int, default=200000, help="length of training schedule."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--evaluate_every_n_epoch",
|
||||
type=int,
|
||||
default=1,
|
||||
help="number of flow-field updates during validation forward pass",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_every_n_epoch",
|
||||
type=int,
|
||||
default=1,
|
||||
help="number of flow-field updates during validation forward pass",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validate_at_start", action="store_true", help="use mixed precision"
|
||||
)
|
||||
parser.add_argument("--save_freq", type=int, default=100, help="save_freq")
|
||||
parser.add_argument("--traj_per_sample", type=int, default=768, help="save_freq")
|
||||
parser.add_argument("--dataset_root", type=str, help="path lo all the datasets")
|
||||
|
||||
parser.add_argument(
|
||||
"--train_iters",
|
||||
type=int,
|
||||
default=4,
|
||||
help="number of updates to the disparity field in each forward pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sequence_len", type=int, default=8, help="train sequence length"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_datasets",
|
||||
nargs="+",
|
||||
default=["things", "badja", "fastcapture"],
|
||||
help="eval datasets.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--remove_space_attn", action="store_true", help="use mixed precision"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dont_use_augs", action="store_true", help="use mixed precision"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample_vis_1st_frame", action="store_true", help="use mixed precision"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sliding_window_len", type=int, default=8, help="use mixed precision"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--updateformer_hidden_size", type=int, default=384, help="use mixed precision"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--updateformer_num_heads", type=int, default=8, help="use mixed precision"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--updateformer_space_depth", type=int, default=12, help="use mixed precision"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--updateformer_time_depth", type=int, default=12, help="use mixed precision"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_stride", type=int, default=8, help="use mixed precision"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crop_size",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[384, 512],
|
||||
help="use mixed precision",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_max_seq_len", type=int, default=1000, help="use mixed precision"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
|
||||
)
|
||||
|
||||
Path(args.ckpt_path).mkdir(exist_ok=True, parents=True)
|
||||
from pytorch_lightning.strategies import DDPStrategy
|
||||
|
||||
Lite(
|
||||
strategy=DDPStrategy(find_unused_parameters=True),
|
||||
devices="auto",
|
||||
accelerator="gpu",
|
||||
precision=32,
|
||||
num_nodes=4,
|
||||
).run(args)
|
Loading…
Reference in New Issue
Block a user