Initial commit

commit 6d62d873fa
CODE_OF_CONDUCT.md (new file, 80 lines)

@@ -0,0 +1,80 @@
# Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.

## Scope

This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers.

This Code of Conduct also applies outside the project spaces when there is a reasonable belief that an individual's behavior may have a negative impact on the project or its community.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at <opensource-conduct@fb.com>. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see https://www.contributor-covenant.org/faq
CONTRIBUTING.md (new file, 28 lines)

@@ -0,0 +1,28 @@
# CoTracker

We want to make contributing to this project as easy and transparent as possible.

## Pull Requests

We actively welcome your pull requests.

1. Fork the repo and create your branch from `main`.
2. If you've changed APIs, update the documentation.
3. Make sure your code lints.
4. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")

In order to accept your pull request, we need you to submit a CLA. You only need to do this once to work on any of Meta's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues

We use GitHub issues to track public bugs. Please ensure your description is clear and has sufficient instructions to be able to reproduce the issue.

Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe disclosure of security bugs. In those cases, please go through the process outlined on that page and do not file a public issue.

## License

By contributing to CoTracker, you agree that your contributions will be licensed under the LICENSE file in the root directory of this source tree.
LICENSE.md (new file, 399 lines)

@@ -0,0 +1,399 @@
Attribution-NonCommercial 4.0 International

=======================================================================

Creative Commons Corporation ("Creative Commons") is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an "as-is" basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.

Using Creative Commons Public Licenses

Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.

Considerations for licensors: Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. More considerations for licensors: wiki.creativecommons.org/Considerations_for_licensors

Considerations for the public: By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor's permission is not necessary for any reason--for example, because of any applicable exception or limitation to copyright--then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. More considerations for the public: wiki.creativecommons.org/Considerations_for_licensees

=======================================================================

Creative Commons Attribution-NonCommercial 4.0 International Public License

By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.

Section 1 -- Definitions.

a. Adapted Material means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.

b. Adapter's License means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.

c. Copyright and Similar Rights means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.

d. Effective Technological Measures means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.

e. Exceptions and Limitations means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.

f. Licensed Material means the artistic or literary work, database, or other material to which the Licensor applied this Public License.

g. Licensed Rights means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.

h. Licensor means the individual(s) or entity(ies) granting rights under this Public License.

i. NonCommercial means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.

j. Share means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.

k. Sui Generis Database Rights means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.

l. You means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.

Section 2 -- Scope.

a. License grant.

   1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:

      a. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and

      b. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.

   2. Exceptions and Limitations. For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.

   3. Term. The term of this Public License is specified in Section 6(a).

   4. Media and formats; technical modifications allowed. The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.

   5. Downstream recipients.

      a. Offer from the Licensor -- Licensed Material. Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.

      b. No downstream restrictions. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.

   6. No endorsement. Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).

b. Other rights.

   1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.

   2. Patent and trademark rights are not licensed under this Public License.

   3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.

Section 3 -- License Conditions.

Your exercise of the Licensed Rights is expressly made subject to the following conditions.

a. Attribution.

   1. If You Share the Licensed Material (including in modified form), You must:

      a. retain the following if it is supplied by the Licensor with the Licensed Material:

         i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);

         ii. a copyright notice;

         iii. a notice that refers to this Public License;

         iv. a notice that refers to the disclaimer of warranties;

         v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;

      b. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and

      c. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.

   2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.

   3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.

   4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License.

Section 4 -- Sui Generis Database Rights.

Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:

a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;

b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and

c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.

For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.

Section 5 -- Disclaimer of Warranties and Limitation of Liability.

a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.

b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.

c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.

Section 6 -- Term and Termination.

a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.

b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:

   1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or

   2. upon express reinstatement by the Licensor.

   For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.

c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.

d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.

Section 7 -- Other Terms and Conditions.

a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.

b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.

Section 8 -- Interpretation.

a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.

b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.

c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.

d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.

=======================================================================

Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the "Licensor." The text of the Creative Commons public licenses is dedicated to the public domain under the CC0 Public Domain Dedication. Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at creativecommons.org/policies, Creative Commons does not authorize the use of the trademark "Creative Commons" or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.

Creative Commons may be contacted at creativecommons.org.
README.md (new file, 94 lines)

@@ -0,0 +1,94 @@
# CoTracker: It is Better to Track Together

**[Meta AI Research, FAIR](https://ai.facebook.com/research/)**; **[University of Oxford, VGG](https://www.robots.ox.ac.uk/~vgg/)**

[Nikita Karaev](https://nikitakaraevv.github.io/), [Ignacio Rocco](https://www.irocco.info/), [Benjamin Graham](https://ai.facebook.com/people/benjamin-graham/), [Natalia Neverova](https://nneverova.github.io/), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/), [Christian Rupprecht](https://chrirupp.github.io/)

[[`Paper`]()] [[`Project`](https://co-tracker.github.io/)] [[`BibTeX`](#citing-cotracker)]



**CoTracker** is a fast transformer-based model that can track any point in a video. It brings to tracking some of the benefits of Optical Flow.

CoTracker can track:

- **Every pixel** within a video
- Points sampled on a regular grid on any video frame
- Manually selected points

Try these tracking modes for yourself with our [Colab demo](https://github.com/facebookresearch/co-tracker/notebooks/demo.ipynb).
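Programmatically, the choice of mode comes down to which query points are handed to the tracker. Below is a minimal sketch, in plain PyTorch, of building queries for a regular grid plus a few manually selected points; the per-point `(frame_index, x, y)` layout is an illustrative assumption, not an interface documented in this README.

```
import torch

def make_grid_queries(height, width, grid_size=10, frame_index=0):
    # Regular grid_size x grid_size grid of (t, x, y) query points on one frame.
    ys = torch.linspace(0, height - 1, grid_size)
    xs = torch.linspace(0, width - 1, grid_size)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    t = torch.full_like(grid_x, frame_index)
    return torch.stack([t, grid_x, grid_y], dim=-1).reshape(-1, 3)  # (grid_size**2, 3)

# Manually selected points: one (t, x, y) row per point of interest (values are examples).
manual_queries = torch.tensor([[0, 400.0, 350.0], [0, 600.0, 500.0]])

queries = torch.cat([make_grid_queries(720, 1280), manual_queries], dim=0)
print(queries.shape)  # torch.Size([102, 3])
```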

## Installation Instructions

Ensure you have both PyTorch and TorchVision installed on your system. Follow the instructions [here](https://pytorch.org/get-started/locally/) for the installation. We strongly recommend installing both PyTorch and TorchVision with CUDA support.

## Steps to Install CoTracker and its dependencies:

```
git clone https://github.com/facebookresearch/co-tracker
cd co-tracker
pip install -e .
pip install opencv-python einops timm matplotlib moviepy flow_vis
```
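A quick way to confirm the environment matches the recommendation above (PyTorch and TorchVision with CUDA support) before installing the remaining dependencies:

```
import torch
import torchvision

# Report the installed versions and whether a CUDA device is visible.
print("torch:", torch.__version__, "torchvision:", torchvision.__version__)
print("CUDA available:", torch.cuda.is_available())
```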

## Model Weights Download:

```
mkdir checkpoints
cd checkpoints
wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth
wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_12.pth
wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_8_wind_16.pth
cd ..
```
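Each checkpoint is presumably an ordinary `torch.save` artifact; a short sanity check after downloading (the internal key layout is not documented here, so the snippet only inspects whatever it finds) might look like:

```
import torch

# Load one checkpoint on the CPU and peek at its top-level structure.
ckpt = torch.load("checkpoints/cotracker_stride_4_wind_8.pth", map_location="cpu")
print(type(ckpt))
if isinstance(ckpt, dict):
    print(list(ckpt.keys())[:10])  # first few top-level keys
```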

## Running the Demo:

Try our [Colab demo](https://github.com/facebookresearch/co-tracker/notebooks/demo.ipynb) or run a local demo with a 10 × 10 grid of points sampled on the first frame of a video:

```
python demo.py --grid_size 10
```
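For reference, the dataset classes in this commit represent a clip as a float tensor of shape `(S, 3, H, W)` (frames, channels, height, width). A hedged sketch of reading the bundled `assets/apple.mp4` into that layout with `imageio` (already used by the dataset code); whether the demo script expects exactly this layout is an assumption:

```
import imageio
import numpy as np
import torch

# Read every frame of the clip into an (S, H, W, 3) uint8 array.
# (Reading .mp4 may additionally require the imageio-ffmpeg plugin.)
reader = imageio.get_reader("assets/apple.mp4")
frames = np.stack([frame for frame in reader], axis=0)

# Convert to the (S, 3, H, W) float layout used by the dataset classes in this repo.
video = torch.from_numpy(frames).permute(0, 3, 1, 2).float()
print(video.shape)
```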

## Evaluation

To reproduce the results presented in the paper, download the following datasets:

- [TAP-Vid](https://github.com/deepmind/tapnet)
- [BADJA](https://github.com/benjiebob/BADJA)
- [ZJU-Mocap (FastCapture)](https://arxiv.org/abs/2303.11898)

And install the necessary dependencies:

```
pip install hydra-core==1.1.0 mediapy tensorboard
```

Then, execute the following command to evaluate on BADJA:

```
python ./cotracker/evaluation/evaluate.py --config-name eval_badja exp_dir=./eval_outputs dataset_root=your/badja/path
```
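The BADJA loader used for this evaluation ships in this commit (`cotracker/datasets/badja_dataset.py`). A minimal sketch of loading it directly and inspecting one sample, assuming `your/badja/path` points at the extracted dataset:

```
from cotracker.datasets.badja_dataset import BadjaDataset

dataset = BadjaDataset(data_root="your/badja/path")
sample = dataset[0]  # one CoTrackerData sample per animal sequence

# video: (S, 3, H, W); trajectory: (S, N, 2) in (x, y) order; visibility: (S, N)
print(len(dataset), sample.seq_name)
print(sample.video.shape, sample.trajectory.shape, sample.visibility.shape)
```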

## Training

To train CoTracker as described in our paper, you first need to generate annotations for the [Google Kubric](https://github.com/google-research/kubric) MOVI-f dataset. Instructions for annotation generation can be found [here](https://github.com/deepmind/tapnet).

Once you have the annotated dataset, make sure you have followed the evaluation setup steps above and install the training dependencies:

```
pip install pytorch_lightning==1.6.0
```

Then launch training on Kubric. Our model was trained using 32 GPUs, and you can adjust the parameters to best suit your hardware setup.

```
python train.py --batch_size 1 --num_workers 28 \
  --num_steps 50000 --ckpt_path ./ --model_name cotracker \
  --save_freq 200 --sequence_len 24 --eval_datasets tapvid_davis_first badja \
  --traj_per_sample 256 --sliding_window_len 8 --updateformer_space_depth 6 --updateformer_time_depth 6 \
  --save_every_n_epoch 10 --evaluate_every_n_epoch 10 --model_stride 4
```
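One detail worth knowing before adapting this setup: `CoTrackerDataset.__getitem__` in `cotracker/datasets/kubric_movif_dataset.py` returns a `(sample, gotit)` pair, where `gotit` flags whether point sampling succeeded, so a DataLoader needs a collate function that is aware of that pair. A hedged sketch of such a collate function (the actual training script may handle this differently):

```
import torch

def collate_with_flags(batch):
    # Each element is (CoTrackerData, gotit); keep the flags alongside the samples.
    samples, gotits = zip(*batch)
    videos = torch.stack([s.video for s in samples], dim=0)
    trajs = torch.stack([s.trajectory for s in samples], dim=0)
    vis = torch.stack([s.visibility for s in samples], dim=0)
    valid = torch.stack([s.valid for s in samples], dim=0)
    return videos, trajs, vis, valid, torch.tensor(gotits, dtype=torch.bool)
```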

## License

The majority of CoTracker is licensed under CC-BY-NC; however, portions of the project are available under separate license terms: Particle Video Revisited is licensed under the MIT license and TAP-Vid under the Apache 2.0 license.

## Citing CoTracker

If you find our repository useful, please consider giving it a star ⭐ and citing our paper in your work:

```
@article{karaev2023cotracker,
  title={CoTracker: It is Better to Track Together},
  author={Nikita Karaev and Ignacio Rocco and Benjamin Graham and Natalia Neverova and Andrea Vedaldi and Christian Rupprecht},
  journal={arxiv},
  year={2023}
}
```
assets/apple.mp4 (new binary file, not shown)

assets/apple_mask.png (new binary file, 14 KiB, not shown)

assets/bmx-bumps.gif (new binary file, 7.3 MiB, not shown)
cotracker/__init__.py (new file, 5 lines)

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

cotracker/datasets/__init__.py (new file, 5 lines)

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
cotracker/datasets/badja_dataset.py (new file, 390 lines)

@@ -0,0 +1,390 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import numpy as np
import os

import json
import imageio
import cv2

from enum import Enum

from cotracker.datasets.utils import CoTrackerData, resize_sample

IGNORE_ANIMALS = [
    # "bear.json",
    # "camel.json",
    "cat_jump.json"
    # "cows.json",
    # "dog.json",
    # "dog-agility.json",
    # "horsejump-high.json",
    # "horsejump-low.json",
    # "impala0.json",
    # "rs_dog.json"
    "tiger.json"
]


class SMALJointCatalog(Enum):
    # body_0 = 0
    # body_1 = 1
    # body_2 = 2
    # body_3 = 3
    # body_4 = 4
    # body_5 = 5
    # body_6 = 6
    # upper_right_0 = 7
    upper_right_1 = 8
    upper_right_2 = 9
    upper_right_3 = 10
    # upper_left_0 = 11
    upper_left_1 = 12
    upper_left_2 = 13
    upper_left_3 = 14
    neck_lower = 15
    # neck_upper = 16
    # lower_right_0 = 17
    lower_right_1 = 18
    lower_right_2 = 19
    lower_right_3 = 20
    # lower_left_0 = 21
    lower_left_1 = 22
    lower_left_2 = 23
    lower_left_3 = 24
    tail_0 = 25
    # tail_1 = 26
    # tail_2 = 27
    tail_3 = 28
    # tail_4 = 29
    # tail_5 = 30
    tail_6 = 31
    jaw = 32
    nose = 33  # ADDED JOINT FOR VERTEX 1863
    # chin = 34  # ADDED JOINT FOR VERTEX 26
    right_ear = 35  # ADDED JOINT FOR VERTEX 149
    left_ear = 36  # ADDED JOINT FOR VERTEX 2124


class SMALJointInfo:
    def __init__(self):
        # These are the
        self.annotated_classes = np.array(
            [
                8, 9, 10,  # upper_right
                12, 13, 14,  # upper_left
                15,  # neck
                18, 19, 20,  # lower_right
                22, 23, 24,  # lower_left
                25, 28, 31,  # tail
                32, 33,  # head
                35,  # right_ear
                36,
            ]
        )  # left_ear

        self.annotated_markers = np.array(
            [
                cv2.MARKER_CROSS, cv2.MARKER_STAR, cv2.MARKER_TRIANGLE_DOWN,
                cv2.MARKER_CROSS, cv2.MARKER_STAR, cv2.MARKER_TRIANGLE_DOWN,
                cv2.MARKER_CROSS,
                cv2.MARKER_CROSS, cv2.MARKER_STAR, cv2.MARKER_TRIANGLE_DOWN,
                cv2.MARKER_CROSS, cv2.MARKER_STAR, cv2.MARKER_TRIANGLE_DOWN,
                cv2.MARKER_CROSS, cv2.MARKER_STAR, cv2.MARKER_TRIANGLE_DOWN,
                cv2.MARKER_CROSS, cv2.MARKER_STAR,
                cv2.MARKER_CROSS,
                cv2.MARKER_CROSS,
            ]
        )

        self.joint_regions = np.array(
            [
                0, 0, 0, 0, 0, 0, 0,
                1, 1, 1, 1,
                2, 2, 2, 2,
                3, 3,
                4, 4, 4, 4,
                5, 5, 5, 5,
                6, 6, 6, 6, 6, 6, 6,
                7, 7, 7,
                8,
                9,
            ]
        )

        self.annotated_joint_region = self.joint_regions[self.annotated_classes]
        self.region_colors = np.array(
            [
                [250, 190, 190],  # body, light pink
                [60, 180, 75],  # upper_right, green
                [230, 25, 75],  # upper_left, red
                [128, 0, 0],  # neck, maroon
                [0, 130, 200],  # lower_right, blue
                [255, 255, 25],  # lower_left, yellow
                [240, 50, 230],  # tail, majenta
                [245, 130, 48],  # jaw / nose / chin, orange
                [29, 98, 115],  # right_ear, turquoise
                [255, 153, 204],
            ]
        )  # left_ear, pink

        self.joint_colors = np.array(self.region_colors)[self.annotated_joint_region]


class BADJAData:
    def __init__(self, data_root, complete=False):
        annotations_path = os.path.join(data_root, "joint_annotations")

        self.animal_dict = {}
        self.animal_count = 0
        self.smal_joint_info = SMALJointInfo()
        for __, animal_json in enumerate(sorted(os.listdir(annotations_path))):
            if animal_json not in IGNORE_ANIMALS:
                json_path = os.path.join(annotations_path, animal_json)
                with open(json_path) as json_data:
                    animal_joint_data = json.load(json_data)

                filenames = []
                segnames = []
                joints = []
                visible = []

                first_path = animal_joint_data[0]["segmentation_path"]
                last_path = animal_joint_data[-1]["segmentation_path"]
                first_frame = first_path.split("/")[-1]
                last_frame = last_path.split("/")[-1]

                if not "extra_videos" in first_path:
                    animal = first_path.split("/")[-2]

                    first_frame_int = int(first_frame.split(".")[0])
                    last_frame_int = int(last_frame.split(".")[0])

                    for fr in range(first_frame_int, last_frame_int + 1):
                        ref_file_name = os.path.join(
                            data_root,
                            "DAVIS/JPEGImages/Full-Resolution/%s/%05d.jpg" % (animal, fr),
                        )
                        ref_seg_name = os.path.join(
                            data_root,
                            "DAVIS/Annotations/Full-Resolution/%s/%05d.png" % (animal, fr),
                        )

                        foundit = False
                        for ind, image_annotation in enumerate(animal_joint_data):
                            file_name = os.path.join(data_root, image_annotation["image_path"])
                            seg_name = os.path.join(data_root, image_annotation["segmentation_path"])

                            if file_name == ref_file_name:
                                foundit = True
                                label_ind = ind

                        if foundit:
                            image_annotation = animal_joint_data[label_ind]
                            file_name = os.path.join(data_root, image_annotation["image_path"])
                            seg_name = os.path.join(data_root, image_annotation["segmentation_path"])
                            joint = np.array(image_annotation["joints"])
                            vis = np.array(image_annotation["visibility"])
                        else:
                            file_name = ref_file_name
                            seg_name = ref_seg_name
                            joint = None
                            vis = None

                        filenames.append(file_name)
                        segnames.append(seg_name)
                        joints.append(joint)
                        visible.append(vis)

                if len(filenames):
                    self.animal_dict[self.animal_count] = (
                        filenames,
                        segnames,
                        joints,
                        visible,
                    )
                    self.animal_count += 1
        print("Loaded BADJA dataset")

    def get_loader(self):
        for __ in range(int(1e6)):
            animal_id = np.random.choice(len(self.animal_dict.keys()))
            filenames, segnames, joints, visible = self.animal_dict[animal_id]

            image_id = np.random.randint(0, len(filenames))

            seg_file = segnames[image_id]
            image_file = filenames[image_id]

            joints = joints[image_id].copy()
            joints = joints[self.smal_joint_info.annotated_classes]
            visible = visible[image_id][self.smal_joint_info.annotated_classes]

            rgb_img = imageio.imread(image_file)  # , mode='RGB')
            sil_img = imageio.imread(seg_file)  # , mode='RGB')

            rgb_h, rgb_w, _ = rgb_img.shape
            sil_img = cv2.resize(sil_img, (rgb_w, rgb_h), cv2.INTER_NEAREST)

            yield rgb_img, sil_img, joints, visible, image_file

    def get_video(self, animal_id):
        filenames, segnames, joint, visible = self.animal_dict[animal_id]

        rgbs = []
        segs = []
        joints = []
        visibles = []

        for s in range(len(filenames)):
            image_file = filenames[s]
            rgb_img = imageio.imread(image_file)  # , mode='RGB')
            rgb_h, rgb_w, _ = rgb_img.shape

            seg_file = segnames[s]
            sil_img = imageio.imread(seg_file)  # , mode='RGB')
            sil_img = cv2.resize(sil_img, (rgb_w, rgb_h), cv2.INTER_NEAREST)

            jo = joint[s]

            if jo is not None:
                joi = joint[s].copy()
                joi = joi[self.smal_joint_info.annotated_classes]
                vis = visible[s][self.smal_joint_info.annotated_classes]
            else:
                joi = None
                vis = None

            rgbs.append(rgb_img)
            segs.append(sil_img)
            joints.append(joi)
            visibles.append(vis)

        return rgbs, segs, joints, visibles, filenames[0]


class BadjaDataset(torch.utils.data.Dataset):
    def __init__(self, data_root, max_seq_len=1000, dataset_resolution=(384, 512)):

        self.data_root = data_root
        self.badja_data = BADJAData(data_root)
        self.max_seq_len = max_seq_len
        self.dataset_resolution = dataset_resolution
        print(
            "found %d unique videos in %s"
            % (self.badja_data.animal_count, self.data_root)
        )

    def __getitem__(self, index):

        rgbs, segs, joints, visibles, filename = self.badja_data.get_video(index)
        S = len(rgbs)
        H, W, __ = rgbs[0].shape
        H, W, __ = segs[0].shape

        N, __ = joints[0].shape

        # let's eliminate the Nones
        # note the first one is guaranteed present
        for s in range(1, S):
            if joints[s] is None:
                joints[s] = np.zeros_like(joints[0])
                visibles[s] = np.zeros_like(visibles[0])

        # eliminate the mystery dim
        segs = [seg[:, :, 0] for seg in segs]

        rgbs = np.stack(rgbs, 0)
        segs = np.stack(segs, 0)
        trajs = np.stack(joints, 0)
        visibles = np.stack(visibles, 0)

        rgbs = torch.from_numpy(rgbs).reshape(S, H, W, 3).permute(0, 3, 1, 2).float()
        segs = torch.from_numpy(segs).reshape(S, 1, H, W).float()
        trajs = torch.from_numpy(trajs).reshape(S, N, 2).float()
        visibles = torch.from_numpy(visibles).reshape(S, N)

        rgbs = rgbs[: self.max_seq_len]
        segs = segs[: self.max_seq_len]
        trajs = trajs[: self.max_seq_len]
        visibles = visibles[: self.max_seq_len]
        # apparently the coords are in yx order
        trajs = torch.flip(trajs, [2])

        if "extra_videos" in filename:
            seq_name = filename.split("/")[-3]
        else:
            seq_name = filename.split("/")[-2]

        rgbs, trajs, segs = resize_sample(rgbs, trajs, segs, self.dataset_resolution)

        return CoTrackerData(rgbs, segs, trajs, visibles, seq_name=seq_name)

    def __len__(self):
        return self.badja_data.animal_count
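The standard BADJA metric is PCK, where a predicted keypoint counts as correct if it lands within 0.2 · sqrt(segmentation-mask area) of the ground truth. The exact protocol implemented by `evaluate.py` is not part of this file, so the following is only an illustrative sketch of that conventional metric over the tensors produced by `BadjaDataset` above:

```
import torch

def badja_pck(pred_trajs, gt_trajs, visibles, segs, thr=0.2):
    # pred_trajs, gt_trajs: (S, N, 2); visibles: (S, N); segs: (S, 1, H, W) binary masks.
    area = segs.flatten(1).sum(dim=-1)           # (S,) mask area per frame
    dist = (pred_trajs - gt_trajs).norm(dim=-1)  # (S, N) pixel error
    correct = (dist < thr * area.sqrt().unsqueeze(1)) & (visibles > 0)
    return correct.sum().float() / (visibles > 0).sum().clamp(min=1)
```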
cotracker/datasets/fast_capture_dataset.py (new file, 72 lines)

@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch

# from PIL import Image
import imageio
import numpy as np
from cotracker.datasets.utils import CoTrackerData, resize_sample


class FastCaptureDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data_root,
        max_seq_len=50,
        max_num_points=20,
        dataset_resolution=(384, 512),
    ):

        self.data_root = data_root
        self.seq_names = os.listdir(os.path.join(data_root, "renders_local_rm"))
        self.pth_dir = os.path.join(data_root, "zju_tracking")
        self.max_seq_len = max_seq_len
        self.max_num_points = max_num_points
        self.dataset_resolution = dataset_resolution
        print("found %d unique videos in %s" % (len(self.seq_names), self.data_root))

    def __getitem__(self, index):
        seq_name = self.seq_names[index]
        spath = os.path.join(self.data_root, "renders_local_rm", seq_name)
        pthpath = os.path.join(self.pth_dir, seq_name + ".pth")

        rgbs = []
        img_paths = sorted(os.listdir(spath))
        for i, img_path in enumerate(img_paths):
            if i < self.max_seq_len:
                rgbs.append(imageio.imread(os.path.join(spath, img_path)))

        annot_dict = torch.load(pthpath)
        traj_2d = annot_dict["traj_2d"][:, :, : self.max_seq_len]
        visibility = annot_dict["visibility"][:, : self.max_seq_len]

        S = len(rgbs)
        H, W, __ = rgbs[0].shape
        *_, S = traj_2d.shape
        visibile_pts_first_frame_inds = (visibility[:, 0] > 0).nonzero(as_tuple=False)[:, 0]
        torch.manual_seed(0)
        point_inds = torch.randperm(len(visibile_pts_first_frame_inds))[: self.max_num_points]
        visible_inds_sampled = visibile_pts_first_frame_inds[point_inds]

        rgbs = np.stack(rgbs, 0)
        rgbs = torch.from_numpy(rgbs).reshape(S, H, W, 3).permute(0, 3, 1, 2).float()

        segs = torch.ones(S, 1, H, W).float()
        trajs = traj_2d[visible_inds_sampled].permute(2, 0, 1).float()
        visibles = visibility[visible_inds_sampled].permute(1, 0)

        rgbs, trajs, segs = resize_sample(rgbs, trajs, segs, self.dataset_resolution)

        return CoTrackerData(rgbs, segs, trajs, visibles, seq_name=seq_name)

    def __len__(self):
        return len(self.seq_names)
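Both dataset classes above are ordinary `torch.utils.data.Dataset` implementations and can be iterated directly. A small usage sketch (the directory layout under `data_root` is assumed to contain the `renders_local_rm` and `zju_tracking` sub-directories referenced in `__init__`; batching would additionally need a collate function aware of `CoTrackerData`):

```
from cotracker.datasets.fast_capture_dataset import FastCaptureDataset

dataset = FastCaptureDataset(
    data_root="path/to/fastcapture", max_seq_len=50, max_num_points=20
)

for i in range(len(dataset)):
    sample = dataset[i]  # one CoTrackerData sample per sequence
    print(sample.seq_name, tuple(sample.video.shape), tuple(sample.trajectory.shape))
    break
```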
cotracker/datasets/kubric_movif_dataset.py (new file, 494 lines)

@@ -0,0 +1,494 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import imageio
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from cotracker.datasets.utils import CoTrackerData
|
||||||
|
from torchvision.transforms import ColorJitter, GaussianBlur
|
||||||
|
from PIL import Image
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
|
class CoTrackerDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
data_root,
|
||||||
|
crop_size=(384, 512),
|
||||||
|
seq_len=24,
|
||||||
|
traj_per_sample=768,
|
||||||
|
sample_vis_1st_frame=False,
|
||||||
|
use_augs=False,
|
||||||
|
):
|
||||||
|
super(CoTrackerDataset, self).__init__()
|
||||||
|
np.random.seed(0)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
self.data_root = data_root
|
||||||
|
self.seq_len = seq_len
|
||||||
|
self.traj_per_sample = traj_per_sample
|
||||||
|
self.sample_vis_1st_frame = sample_vis_1st_frame
|
||||||
|
self.use_augs = use_augs
|
||||||
|
self.crop_size = crop_size
|
||||||
|
|
||||||
|
# photometric augmentation
|
||||||
|
self.photo_aug = ColorJitter(
|
||||||
|
brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14
|
||||||
|
)
|
||||||
|
self.blur_aug = GaussianBlur(11, sigma=(0.1, 2.0))
|
||||||
|
|
||||||
|
self.blur_aug_prob = 0.25
|
||||||
|
self.color_aug_prob = 0.25
|
||||||
|
|
||||||
|
# occlusion augmentation
|
||||||
|
self.eraser_aug_prob = 0.5
|
||||||
|
self.eraser_bounds = [2, 100]
|
||||||
|
self.eraser_max = 10
|
||||||
|
|
||||||
|
# occlusion augmentation
|
||||||
|
self.replace_aug_prob = 0.5
|
||||||
|
self.replace_bounds = [2, 100]
|
||||||
|
self.replace_max = 10
|
||||||
|
|
||||||
|
# spatial augmentations
|
||||||
|
self.pad_bounds = [0, 100]
|
||||||
|
self.crop_size = crop_size
|
||||||
|
self.resize_lim = [0.25, 2.0] # sample resizes from here
|
||||||
|
self.resize_delta = 0.2
|
||||||
|
self.max_crop_offset = 50
|
||||||
|
|
||||||
|
self.do_flip = True
|
||||||
|
self.h_flip_prob = 0.5
|
||||||
|
self.v_flip_prob = 0.5
|
||||||
|
|
||||||
|
def getitem_helper(self, index):
|
||||||
|
return NotImplementedError
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
gotit = False
|
||||||
|
|
||||||
|
sample, gotit = self.getitem_helper(index)
|
||||||
|
if not gotit:
|
||||||
|
print("warning: sampling failed")
|
||||||
|
# fake sample, so we can still collate
|
||||||
|
sample = CoTrackerData(
|
||||||
|
video=torch.zeros(
|
||||||
|
(self.seq_len, 3, self.crop_size[0], self.crop_size[1])
|
||||||
|
),
|
||||||
|
segmentation=torch.zeros(
|
||||||
|
(self.seq_len, 1, self.crop_size[0], self.crop_size[1])
|
||||||
|
),
|
||||||
|
trajectory=torch.zeros((self.seq_len, self.traj_per_sample, 2)),
|
||||||
|
visibility=torch.zeros((self.seq_len, self.traj_per_sample)),
|
||||||
|
valid=torch.zeros((self.seq_len, self.traj_per_sample)),
|
||||||
|
)
|
||||||
|
|
||||||
|
return sample, gotit
|
||||||
|
|
||||||
|
def add_photometric_augs(self, rgbs, trajs, visibles, eraser=True, replace=True):
|
||||||
|
T, N, _ = trajs.shape
|
||||||
|
|
||||||
|
S = len(rgbs)
|
||||||
|
H, W = rgbs[0].shape[:2]
|
||||||
|
assert S == T
|
||||||
|
|
||||||
|
if eraser:
|
||||||
|
############ eraser transform (per image after the first) ############
|
||||||
|
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
|
||||||
|
            for i in range(1, S):
                if np.random.rand() < self.eraser_aug_prob:
                    for _ in range(
                        np.random.randint(1, self.eraser_max + 1)
                    ):  # number of times to occlude

                        xc = np.random.randint(0, W)
                        yc = np.random.randint(0, H)
                        dx = np.random.randint(
                            self.eraser_bounds[0], self.eraser_bounds[1]
                        )
                        dy = np.random.randint(
                            self.eraser_bounds[0], self.eraser_bounds[1]
                        )
                        x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
                        x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
                        y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
                        y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)

                        mean_color = np.mean(
                            rgbs[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0
                        )
                        rgbs[i][y0:y1, x0:x1, :] = mean_color

                        occ_inds = np.logical_and(
                            np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
                            np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
                        )
                        visibles[i, occ_inds] = 0
            rgbs = [rgb.astype(np.uint8) for rgb in rgbs]

        if replace:

            rgbs_alt = [
                np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
                for rgb in rgbs
            ]
            rgbs_alt = [
                np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
                for rgb in rgbs_alt
            ]

            ############ replace transform (per image after the first) ############
            rgbs = [rgb.astype(np.float32) for rgb in rgbs]
            rgbs_alt = [rgb.astype(np.float32) for rgb in rgbs_alt]
            for i in range(1, S):
                if np.random.rand() < self.replace_aug_prob:
                    for _ in range(
                        np.random.randint(1, self.replace_max + 1)
                    ):  # number of times to occlude
                        xc = np.random.randint(0, W)
                        yc = np.random.randint(0, H)
                        dx = np.random.randint(
                            self.replace_bounds[0], self.replace_bounds[1]
                        )
                        dy = np.random.randint(
                            self.replace_bounds[0], self.replace_bounds[1]
                        )
                        x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
                        x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
                        y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
                        y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)

                        wid = x1 - x0
                        hei = y1 - y0
                        y00 = np.random.randint(0, H - hei)
                        x00 = np.random.randint(0, W - wid)
                        fr = np.random.randint(0, S)
                        rep = rgbs_alt[fr][y00 : y00 + hei, x00 : x00 + wid, :]
                        rgbs[i][y0:y1, x0:x1, :] = rep

                        occ_inds = np.logical_and(
                            np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
                            np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
                        )
                        visibles[i, occ_inds] = 0
            rgbs = [rgb.astype(np.uint8) for rgb in rgbs]

        ############ photometric augmentation ############
        if np.random.rand() < self.color_aug_prob:
            # random per-frame amount of aug
            rgbs = [
                np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
                for rgb in rgbs
            ]

        if np.random.rand() < self.blur_aug_prob:
            # random per-frame amount of blur
            rgbs = [
                np.array(self.blur_aug(Image.fromarray(rgb)), dtype=np.uint8)
                for rgb in rgbs
            ]

        return rgbs, trajs, visibles

def add_spatial_augs(self, rgbs, trajs, visibles):
|
||||||
|
T, N, __ = trajs.shape
|
||||||
|
|
||||||
|
S = len(rgbs)
|
||||||
|
H, W = rgbs[0].shape[:2]
|
||||||
|
assert S == T
|
||||||
|
|
||||||
|
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
|
||||||
|
|
||||||
|
############ spatial transform ############
|
||||||
|
|
||||||
|
# padding
|
||||||
|
pad_x0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
||||||
|
pad_x1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
||||||
|
pad_y0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
||||||
|
pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
||||||
|
|
||||||
|
rgbs = [
|
||||||
|
np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs
|
||||||
|
]
|
||||||
|
trajs[:, :, 0] += pad_x0
|
||||||
|
trajs[:, :, 1] += pad_y0
|
||||||
|
H, W = rgbs[0].shape[:2]
|
||||||
|
|
||||||
|
# scaling + stretching
|
||||||
|
scale = np.random.uniform(self.resize_lim[0], self.resize_lim[1])
|
||||||
|
scale_x = scale
|
||||||
|
scale_y = scale
|
||||||
|
H_new = H
|
||||||
|
W_new = W
|
||||||
|
|
||||||
|
scale_delta_x = 0.0
|
||||||
|
scale_delta_y = 0.0
|
||||||
|
|
||||||
|
rgbs_scaled = []
|
||||||
|
for s in range(S):
|
||||||
|
if s == 1:
|
||||||
|
scale_delta_x = np.random.uniform(-self.resize_delta, self.resize_delta)
|
||||||
|
scale_delta_y = np.random.uniform(-self.resize_delta, self.resize_delta)
|
||||||
|
elif s > 1:
|
||||||
|
scale_delta_x = (
|
||||||
|
scale_delta_x * 0.8
|
||||||
|
+ np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
|
||||||
|
)
|
||||||
|
scale_delta_y = (
|
||||||
|
scale_delta_y * 0.8
|
||||||
|
+ np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
|
||||||
|
)
|
||||||
|
scale_x = scale_x + scale_delta_x
|
||||||
|
scale_y = scale_y + scale_delta_y
|
||||||
|
|
||||||
|
# bring h/w closer
|
||||||
|
scale_xy = (scale_x + scale_y) * 0.5
|
||||||
|
scale_x = scale_x * 0.5 + scale_xy * 0.5
|
||||||
|
scale_y = scale_y * 0.5 + scale_xy * 0.5
|
||||||
|
|
||||||
|
# don't get too crazy
|
||||||
|
scale_x = np.clip(scale_x, 0.2, 2.0)
|
||||||
|
scale_y = np.clip(scale_y, 0.2, 2.0)
|
||||||
|
|
||||||
|
H_new = int(H * scale_y)
|
||||||
|
W_new = int(W * scale_x)
|
||||||
|
|
||||||
|
# make it at least slightly bigger than the crop area,
|
||||||
|
# so that the random cropping can add diversity
|
||||||
|
H_new = np.clip(H_new, self.crop_size[0] + 10, None)
|
||||||
|
W_new = np.clip(W_new, self.crop_size[1] + 10, None)
|
||||||
|
# recompute scale in case we clipped
|
||||||
|
scale_x = W_new / float(W)
|
||||||
|
scale_y = H_new / float(H)
|
||||||
|
|
||||||
|
rgbs_scaled.append(
|
||||||
|
cv2.resize(rgbs[s], (W_new, H_new), interpolation=cv2.INTER_LINEAR)
|
||||||
|
)
|
||||||
|
trajs[s, :, 0] *= scale_x
|
||||||
|
trajs[s, :, 1] *= scale_y
|
||||||
|
rgbs = rgbs_scaled
|
||||||
|
|
||||||
|
ok_inds = visibles[0, :] > 0
|
||||||
|
vis_trajs = trajs[:, ok_inds] # S,?,2
|
||||||
|
|
||||||
|
if vis_trajs.shape[1] > 0:
|
||||||
|
mid_x = np.mean(vis_trajs[0, :, 0])
|
||||||
|
mid_y = np.mean(vis_trajs[0, :, 1])
|
||||||
|
else:
|
||||||
|
mid_y = self.crop_size[0]
|
||||||
|
mid_x = self.crop_size[1]
|
||||||
|
|
||||||
|
x0 = int(mid_x - self.crop_size[1] // 2)
|
||||||
|
y0 = int(mid_y - self.crop_size[0] // 2)
|
||||||
|
|
||||||
|
offset_x = 0
|
||||||
|
offset_y = 0
|
||||||
|
|
||||||
|
for s in range(S):
|
||||||
|
# on each frame, shift a bit more
|
||||||
|
if s == 1:
|
||||||
|
offset_x = np.random.randint(
|
||||||
|
-self.max_crop_offset, self.max_crop_offset
|
||||||
|
)
|
||||||
|
offset_y = np.random.randint(
|
||||||
|
-self.max_crop_offset, self.max_crop_offset
|
||||||
|
)
|
||||||
|
elif s > 1:
|
||||||
|
offset_x = int(
|
||||||
|
offset_x * 0.8
|
||||||
|
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1)
|
||||||
|
* 0.2
|
||||||
|
)
|
||||||
|
offset_y = int(
|
||||||
|
offset_y * 0.8
|
||||||
|
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1)
|
||||||
|
* 0.2
|
||||||
|
)
|
||||||
|
x0 = x0 + offset_x
|
||||||
|
y0 = y0 + offset_y
|
||||||
|
|
||||||
|
H_new, W_new = rgbs[s].shape[:2]
|
||||||
|
if H_new == self.crop_size[0]:
|
||||||
|
y0 = 0
|
||||||
|
else:
|
||||||
|
y0 = min(max(0, y0), H_new - self.crop_size[0] - 1)
|
||||||
|
|
||||||
|
if W_new == self.crop_size[1]:
|
||||||
|
x0 = 0
|
||||||
|
else:
|
||||||
|
x0 = min(max(0, x0), W_new - self.crop_size[1] - 1)
|
||||||
|
|
||||||
|
rgbs[s] = rgbs[s][y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
|
||||||
|
trajs[s, :, 0] -= x0
|
||||||
|
trajs[s, :, 1] -= y0
|
||||||
|
|
||||||
|
H_new = self.crop_size[0]
|
||||||
|
W_new = self.crop_size[1]
|
||||||
|
|
||||||
|
# flip
|
||||||
|
h_flipped = False
|
||||||
|
v_flipped = False
|
||||||
|
if self.do_flip:
|
||||||
|
# h flip
|
||||||
|
if np.random.rand() < self.h_flip_prob:
|
||||||
|
h_flipped = True
|
||||||
|
rgbs = [rgb[:, ::-1] for rgb in rgbs]
|
||||||
|
# v flip
|
||||||
|
if np.random.rand() < self.v_flip_prob:
|
||||||
|
v_flipped = True
|
||||||
|
rgbs = [rgb[::-1] for rgb in rgbs]
|
||||||
|
if h_flipped:
|
||||||
|
trajs[:, :, 0] = W_new - trajs[:, :, 0]
|
||||||
|
if v_flipped:
|
||||||
|
trajs[:, :, 1] = H_new - trajs[:, :, 1]
|
||||||
|
|
||||||
|
return rgbs, trajs
|
||||||
|
|
||||||
|
def crop(self, rgbs, trajs):
|
||||||
|
T, N, _ = trajs.shape
|
||||||
|
|
||||||
|
S = len(rgbs)
|
||||||
|
H, W = rgbs[0].shape[:2]
|
||||||
|
assert S == T
|
||||||
|
|
||||||
|
############ spatial transform ############
|
||||||
|
|
||||||
|
H_new = H
|
||||||
|
W_new = W
|
||||||
|
|
||||||
|
# simple random crop
|
||||||
|
y0 = (
|
||||||
|
0
|
||||||
|
if self.crop_size[0] >= H_new
|
||||||
|
else np.random.randint(0, H_new - self.crop_size[0])
|
||||||
|
)
|
||||||
|
x0 = (
|
||||||
|
0
|
||||||
|
if self.crop_size[1] >= W_new
|
||||||
|
else np.random.randint(0, W_new - self.crop_size[1])
|
||||||
|
)
|
||||||
|
rgbs = [
|
||||||
|
rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
|
||||||
|
for rgb in rgbs
|
||||||
|
]
|
||||||
|
|
||||||
|
trajs[:, :, 0] -= x0
|
||||||
|
trajs[:, :, 1] -= y0
|
||||||
|
|
||||||
|
return rgbs, trajs
|
||||||
|
|
||||||
|
|
||||||
|
class KubricMovifDataset(CoTrackerDataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
data_root,
|
||||||
|
crop_size=(384, 512),
|
||||||
|
seq_len=24,
|
||||||
|
traj_per_sample=768,
|
||||||
|
sample_vis_1st_frame=False,
|
||||||
|
use_augs=False,
|
||||||
|
):
|
||||||
|
super(KubricMovifDataset, self).__init__(
|
||||||
|
data_root=data_root,
|
||||||
|
crop_size=crop_size,
|
||||||
|
seq_len=seq_len,
|
||||||
|
traj_per_sample=traj_per_sample,
|
||||||
|
sample_vis_1st_frame=sample_vis_1st_frame,
|
||||||
|
use_augs=use_augs,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.pad_bounds = [0, 25]
|
||||||
|
self.resize_lim = [0.75, 1.25] # sample resizes from here
|
||||||
|
self.resize_delta = 0.05
|
||||||
|
self.max_crop_offset = 15
|
||||||
|
self.seq_names = [
|
||||||
|
fname
|
||||||
|
for fname in os.listdir(data_root)
|
||||||
|
if os.path.isdir(os.path.join(data_root, fname))
|
||||||
|
]
|
||||||
|
print("found %d unique videos in %s" % (len(self.seq_names), self.data_root))
|
||||||
|
|
||||||
|
def getitem_helper(self, index):
|
||||||
|
gotit = True
|
||||||
|
seq_name = self.seq_names[index]
|
||||||
|
|
||||||
|
npy_path = os.path.join(self.data_root, seq_name, seq_name + ".npy")
|
||||||
|
rgb_path = os.path.join(self.data_root, seq_name, "frames")
|
||||||
|
|
||||||
|
img_paths = sorted(os.listdir(rgb_path))
|
||||||
|
rgbs = []
|
||||||
|
for i, img_path in enumerate(img_paths):
|
||||||
|
rgbs.append(imageio.v2.imread(os.path.join(rgb_path, img_path)))
|
||||||
|
|
||||||
|
rgbs = np.stack(rgbs)
|
||||||
|
annot_dict = np.load(npy_path, allow_pickle=True).item()
|
||||||
|
traj_2d = annot_dict["coords"]
|
||||||
|
visibility = annot_dict["visibility"]
|
||||||
|
|
||||||
|
# random crop
|
||||||
|
assert self.seq_len <= len(rgbs)
|
||||||
|
if self.seq_len < len(rgbs):
|
||||||
|
start_ind = np.random.choice(len(rgbs) - self.seq_len, 1)[0]
|
||||||
|
|
||||||
|
rgbs = rgbs[start_ind : start_ind + self.seq_len]
|
||||||
|
traj_2d = traj_2d[:, start_ind : start_ind + self.seq_len]
|
||||||
|
visibility = visibility[:, start_ind : start_ind + self.seq_len]
|
||||||
|
|
||||||
|
traj_2d = np.transpose(traj_2d, (1, 0, 2))
|
||||||
|
visibility = np.transpose(np.logical_not(visibility), (1, 0))
|
||||||
|
if self.use_augs:
|
||||||
|
rgbs, traj_2d, visibility = self.add_photometric_augs(
|
||||||
|
rgbs, traj_2d, visibility
|
||||||
|
)
|
||||||
|
rgbs, traj_2d = self.add_spatial_augs(rgbs, traj_2d, visibility)
|
||||||
|
else:
|
||||||
|
rgbs, traj_2d = self.crop(rgbs, traj_2d)
|
||||||
|
|
||||||
|
visibility[traj_2d[:, :, 0] > self.crop_size[1] - 1] = False
|
||||||
|
visibility[traj_2d[:, :, 0] < 0] = False
|
||||||
|
visibility[traj_2d[:, :, 1] > self.crop_size[0] - 1] = False
|
||||||
|
visibility[traj_2d[:, :, 1] < 0] = False
|
||||||
|
|
||||||
|
visibility = torch.from_numpy(visibility)
|
||||||
|
traj_2d = torch.from_numpy(traj_2d)
|
||||||
|
|
||||||
|
visibile_pts_first_frame_inds = (visibility[0]).nonzero(as_tuple=False)[:, 0]
|
||||||
|
|
||||||
|
if self.sample_vis_1st_frame:
|
||||||
|
visibile_pts_inds = visibile_pts_first_frame_inds
|
||||||
|
else:
|
||||||
|
visibile_pts_mid_frame_inds = (visibility[self.seq_len // 2]).nonzero(
|
||||||
|
as_tuple=False
|
||||||
|
)[:, 0]
|
||||||
|
visibile_pts_inds = torch.cat(
|
||||||
|
(visibile_pts_first_frame_inds, visibile_pts_mid_frame_inds), dim=0
|
||||||
|
)
|
||||||
|
point_inds = torch.randperm(len(visibile_pts_inds))[: self.traj_per_sample]
|
||||||
|
if len(point_inds) < self.traj_per_sample:
|
||||||
|
gotit = False
|
||||||
|
|
||||||
|
visible_inds_sampled = visibile_pts_inds[point_inds]
|
||||||
|
|
||||||
|
trajs = traj_2d[:, visible_inds_sampled].float()
|
||||||
|
visibles = visibility[:, visible_inds_sampled]
|
||||||
|
valids = torch.ones((self.seq_len, self.traj_per_sample))
|
||||||
|
|
||||||
|
rgbs = torch.from_numpy(np.stack(rgbs)).permute(0, 3, 1, 2).float()
|
||||||
|
segs = torch.ones((self.seq_len, 1, self.crop_size[0], self.crop_size[1]))
|
||||||
|
sample = CoTrackerData(
|
||||||
|
video=rgbs,
|
||||||
|
segmentation=segs,
|
||||||
|
trajectory=trajs,
|
||||||
|
visibility=visibles,
|
||||||
|
valid=valids,
|
||||||
|
seq_name=seq_name,
|
||||||
|
)
|
||||||
|
return sample, gotit
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.seq_names)
|
218  cotracker/datasets/tap_vid_datasets.py  Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import io
|
||||||
|
import glob
|
||||||
|
import torch
|
||||||
|
import pickle
|
||||||
|
import numpy as np
|
||||||
|
import mediapy as media
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from typing import Mapping, Tuple, Union
|
||||||
|
|
||||||
|
from cotracker.datasets.utils import CoTrackerData
|
||||||
|
|
||||||
|
DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]]
|
||||||
|
|
||||||
|
|
||||||
|
def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
|
||||||
|
"""Resize a video to output_size."""
|
||||||
|
# If you have a GPU, consider replacing this with a GPU-enabled resize op,
|
||||||
|
# such as a jitted jax.image.resize. It will make things faster.
|
||||||
|
return media.resize_video(video, output_size)
|
||||||
|
|
||||||
|
|
||||||
|
def sample_queries_first(
|
||||||
|
target_occluded: np.ndarray,
|
||||||
|
target_points: np.ndarray,
|
||||||
|
frames: np.ndarray,
|
||||||
|
) -> Mapping[str, np.ndarray]:
|
||||||
|
"""Package a set of frames and tracks for use in TAPNet evaluations.
|
||||||
|
Given a set of frames and tracks with no query points, use the first
|
||||||
|
visible point in each track as the query.
|
||||||
|
Args:
|
||||||
|
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
|
||||||
|
where True indicates occluded.
|
||||||
|
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
|
||||||
|
is [x,y] scaled between 0 and 1.
|
||||||
|
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
|
||||||
|
-1 and 1.
|
||||||
|
Returns:
|
||||||
|
A dict with the keys:
|
||||||
|
video: Video tensor of shape [1, n_frames, height, width, 3]
|
||||||
|
query_points: Query points of shape [1, n_queries, 3] where
|
||||||
|
each point is [t, y, x] scaled to the range [-1, 1]
|
||||||
|
target_points: Target points of shape [1, n_queries, n_frames, 2] where
|
||||||
|
each point is [x, y] scaled to the range [-1, 1]
|
||||||
|
"""
|
||||||
|
valid = np.sum(~target_occluded, axis=1) > 0
|
||||||
|
target_points = target_points[valid, :]
|
||||||
|
target_occluded = target_occluded[valid, :]
|
||||||
|
|
||||||
|
query_points = []
|
||||||
|
for i in range(target_points.shape[0]):
|
||||||
|
index = np.where(target_occluded[i] == 0)[0][0]
|
||||||
|
x, y = target_points[i, index, 0], target_points[i, index, 1]
|
||||||
|
query_points.append(np.array([index, y, x])) # [t, y, x]
|
||||||
|
query_points = np.stack(query_points, axis=0)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"video": frames[np.newaxis, ...],
|
||||||
|
"query_points": query_points[np.newaxis, ...],
|
||||||
|
"target_points": target_points[np.newaxis, ...],
|
||||||
|
"occluded": target_occluded[np.newaxis, ...],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def sample_queries_strided(
|
||||||
|
target_occluded: np.ndarray,
|
||||||
|
target_points: np.ndarray,
|
||||||
|
frames: np.ndarray,
|
||||||
|
query_stride: int = 5,
|
||||||
|
) -> Mapping[str, np.ndarray]:
|
||||||
|
"""Package a set of frames and tracks for use in TAPNet evaluations.
|
||||||
|
|
||||||
|
Given a set of frames and tracks with no query points, sample queries
|
||||||
|
strided every query_stride frames, ignoring points that are not visible
|
||||||
|
at the selected frames.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
|
||||||
|
where True indicates occluded.
|
||||||
|
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
|
||||||
|
is [x,y] scaled between 0 and 1.
|
||||||
|
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
|
||||||
|
-1 and 1.
|
||||||
|
query_stride: When sampling query points, search for un-occluded points
|
||||||
|
every query_stride frames and convert each one into a query.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dict with the keys:
|
||||||
|
video: Video tensor of shape [1, n_frames, height, width, 3]. The video
|
||||||
|
has floats scaled to the range [-1, 1].
|
||||||
|
query_points: Query points of shape [1, n_queries, 3] where
|
||||||
|
each point is [t, y, x] scaled to the range [-1, 1].
|
||||||
|
target_points: Target points of shape [1, n_queries, n_frames, 2] where
|
||||||
|
each point is [x, y] scaled to the range [-1, 1].
|
||||||
|
trackgroup: Index of the original track that each query point was
|
||||||
|
sampled from. This is useful for visualization.
|
||||||
|
"""
|
||||||
|
tracks = []
|
||||||
|
occs = []
|
||||||
|
queries = []
|
||||||
|
trackgroups = []
|
||||||
|
total = 0
|
||||||
|
trackgroup = np.arange(target_occluded.shape[0])
|
||||||
|
for i in range(0, target_occluded.shape[1], query_stride):
|
||||||
|
mask = target_occluded[:, i] == 0
|
||||||
|
query = np.stack(
|
||||||
|
[
|
||||||
|
i * np.ones(target_occluded.shape[0:1]),
|
||||||
|
target_points[:, i, 1],
|
||||||
|
target_points[:, i, 0],
|
||||||
|
],
|
||||||
|
axis=-1,
|
||||||
|
)
|
||||||
|
queries.append(query[mask])
|
||||||
|
tracks.append(target_points[mask])
|
||||||
|
occs.append(target_occluded[mask])
|
||||||
|
trackgroups.append(trackgroup[mask])
|
||||||
|
total += np.array(np.sum(target_occluded[:, i] == 0))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"video": frames[np.newaxis, ...],
|
||||||
|
"query_points": np.concatenate(queries, axis=0)[np.newaxis, ...],
|
||||||
|
"target_points": np.concatenate(tracks, axis=0)[np.newaxis, ...],
|
||||||
|
"occluded": np.concatenate(occs, axis=0)[np.newaxis, ...],
|
||||||
|
"trackgroup": np.concatenate(trackgroups, axis=0)[np.newaxis, ...],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TapVidDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
data_root,
|
||||||
|
dataset_type="davis",
|
||||||
|
resize_to_256=True,
|
||||||
|
queried_first=True,
|
||||||
|
):
|
||||||
|
self.dataset_type = dataset_type
|
||||||
|
self.resize_to_256 = resize_to_256
|
||||||
|
self.queried_first = queried_first
|
||||||
|
if self.dataset_type == "kinetics":
|
||||||
|
all_paths = glob.glob(os.path.join(data_root, "*_of_0010.pkl"))
|
||||||
|
points_dataset = []
|
||||||
|
for pickle_path in all_paths:
|
||||||
|
with open(pickle_path, "rb") as f:
|
||||||
|
data = pickle.load(f)
|
||||||
|
points_dataset = points_dataset + data
|
||||||
|
self.points_dataset = points_dataset
|
||||||
|
else:
|
||||||
|
with open(data_root, "rb") as f:
|
||||||
|
self.points_dataset = pickle.load(f)
|
||||||
|
if self.dataset_type == "davis":
|
||||||
|
self.video_names = list(self.points_dataset.keys())
|
||||||
|
print("found %d unique videos in %s" % (len(self.points_dataset), data_root))
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
if self.dataset_type == "davis":
|
||||||
|
video_name = self.video_names[index]
|
||||||
|
else:
|
||||||
|
video_name = index
|
||||||
|
video = self.points_dataset[video_name]
|
||||||
|
frames = video["video"]
|
||||||
|
|
||||||
|
if isinstance(frames[0], bytes):
|
||||||
|
# TAP-Vid is stored and JPEG bytes rather than `np.ndarray`s.
|
||||||
|
def decode(frame):
|
||||||
|
byteio = io.BytesIO(frame)
|
||||||
|
img = Image.open(byteio)
|
||||||
|
return np.array(img)
|
||||||
|
|
||||||
|
frames = np.array([decode(frame) for frame in frames])
|
||||||
|
|
||||||
|
target_points = self.points_dataset[video_name]["points"]
|
||||||
|
if self.resize_to_256:
|
||||||
|
frames = resize_video(frames, [256, 256])
|
||||||
|
target_points *= np.array([256, 256])
|
||||||
|
else:
|
||||||
|
target_points *= np.array([frames.shape[2], frames.shape[1]])
|
||||||
|
|
||||||
|
T, H, W, C = frames.shape
|
||||||
|
N, T, D = target_points.shape
|
||||||
|
|
||||||
|
target_occ = self.points_dataset[video_name]["occluded"]
|
||||||
|
if self.queried_first:
|
||||||
|
converted = sample_queries_first(target_occ, target_points, frames)
|
||||||
|
else:
|
||||||
|
converted = sample_queries_strided(target_occ, target_points, frames)
|
||||||
|
assert converted["target_points"].shape[1] == converted["query_points"].shape[1]
|
||||||
|
|
||||||
|
trajs = (
|
||||||
|
torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float()
|
||||||
|
) # T, N, D
|
||||||
|
|
||||||
|
rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float()
|
||||||
|
segs = torch.ones(T, 1, H, W).float()
|
||||||
|
visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[
|
||||||
|
0
|
||||||
|
].permute(
|
||||||
|
1, 0
|
||||||
|
) # T, N
|
||||||
|
query_points = torch.from_numpy(converted["query_points"])[0] # T, N
|
||||||
|
return CoTrackerData(
|
||||||
|
rgbs,
|
||||||
|
segs,
|
||||||
|
trajs,
|
||||||
|
visibles,
|
||||||
|
seq_name=str(video_name),
|
||||||
|
query_points=query_points,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.points_dataset)
|
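For reference, a minimal sketch (illustrative only, not part of the original commit) of how the `sample_queries_first` helper above packages a toy track; all shapes and values below are made-up placeholders:

import numpy as np
from cotracker.datasets.tap_vid_datasets import sample_queries_first

# One track over three frames, occluded in frame 0, so the query comes from frame 1.
occluded = np.array([[True, False, False]])          # [n_tracks, n_frames]
points = np.random.rand(1, 3, 2)                     # [n_tracks, n_frames, 2], scaled to [0, 1]
frames = np.zeros((3, 32, 32, 3), dtype=np.float32)  # [n_frames, height, width, 3]

packed = sample_queries_first(occluded, points, frames)
print(packed["query_points"].shape)  # (1, 1, 3) -- each query is [t, y, x] with t == 1 here
print(packed["video"].shape)         # (1, 3, 32, 32, 3)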
114  cotracker/datasets/utils.py  Normal file
@@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import torch
import dataclasses
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Any, Optional


@dataclass(eq=False)
class CoTrackerData:
    """
    Dataclass for storing video tracks data.
    """

    video: torch.Tensor  # B, S, C, H, W
    segmentation: torch.Tensor  # B, S, 1, H, W
    trajectory: torch.Tensor  # B, S, N, 2
    visibility: torch.Tensor  # B, S, N
    # optional data
    valid: Optional[torch.Tensor] = None  # B, S, N
    seq_name: Optional[str] = None
    query_points: Optional[torch.Tensor] = None  # TapVID evaluation format


def collate_fn(batch):
    """
    Collate function for video tracks data.
    """
    video = torch.stack([b.video for b in batch], dim=0)
    segmentation = torch.stack([b.segmentation for b in batch], dim=0)
    trajectory = torch.stack([b.trajectory for b in batch], dim=0)
    visibility = torch.stack([b.visibility for b in batch], dim=0)
    query_points = None
    if batch[0].query_points is not None:
        query_points = torch.stack([b.query_points for b in batch], dim=0)
    seq_name = [b.seq_name for b in batch]

    return CoTrackerData(
        video,
        segmentation,
        trajectory,
        visibility,
        seq_name=seq_name,
        query_points=query_points,
    )


def collate_fn_train(batch):
    """
    Collate function for video tracks data during training.
    """
    gotit = [gotit for _, gotit in batch]
    video = torch.stack([b.video for b, _ in batch], dim=0)
    segmentation = torch.stack([b.segmentation for b, _ in batch], dim=0)
    trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0)
    visibility = torch.stack([b.visibility for b, _ in batch], dim=0)
    valid = torch.stack([b.valid for b, _ in batch], dim=0)
    seq_name = [b.seq_name for b, _ in batch]
    return (
        CoTrackerData(video, segmentation, trajectory, visibility, valid, seq_name),
        gotit,
    )


def try_to_cuda(t: Any) -> Any:
    """
    Try to move the input variable `t` to a cuda device.

    Args:
        t: Input.

    Returns:
        t_cuda: `t` moved to a cuda device, if supported.
    """
    try:
        t = t.float().cuda()
    except AttributeError:
        pass
    return t


def dataclass_to_cuda_(obj):
    """
    Move all contents of a dataclass to cuda inplace if supported.

    Args:
        batch: Input dataclass.

    Returns:
        batch_cuda: `batch` moved to a cuda device, if supported.
    """
    for f in dataclasses.fields(obj):
        setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
    return obj


def resize_sample(rgbs, trajs_g, segs, interp_shape):
    S, C, H, W = rgbs.shape
    S, N, D = trajs_g.shape

    assert D == 2

    rgbs = F.interpolate(rgbs, interp_shape, mode="bilinear")
    segs = F.interpolate(segs, interp_shape, mode="nearest")

    trajs_g[:, :, 0] *= interp_shape[1] / W
    trajs_g[:, :, 1] *= interp_shape[0] / H
    return rgbs, trajs_g, segs
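A minimal sketch (illustrative only, not part of the original commit) of what `collate_fn` above produces for a list of toy `CoTrackerData` samples; the shapes below are placeholders:

import torch
from cotracker.datasets.utils import CoTrackerData, collate_fn

def toy_sample(name):
    # S=8 frames, N=16 tracks, 64x64 video; all values are dummies.
    return CoTrackerData(
        video=torch.zeros(8, 3, 64, 64),
        segmentation=torch.ones(8, 1, 64, 64),
        trajectory=torch.zeros(8, 16, 2),
        visibility=torch.ones(8, 16),
        seq_name=name,
    )

batch = collate_fn([toy_sample("a"), toy_sample("b")])
print(batch.video.shape)       # torch.Size([2, 8, 3, 64, 64]) -- new leading batch dim
print(batch.trajectory.shape)  # torch.Size([2, 8, 16, 2])
print(batch.seq_name)          # ['a', 'b']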
5  cotracker/evaluation/__init__.py  Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
6  cotracker/evaluation/configs/eval_badja.yaml  Normal file
@@ -0,0 +1,6 @@
defaults:
  - default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: badja
6  cotracker/evaluation/configs/eval_fastcapture.yaml  Normal file
@@ -0,0 +1,6 @@
defaults:
  - default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: fastcapture
@@ -0,0 +1,6 @@
defaults:
  - default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: tapvid_davis_first
@@ -0,0 +1,6 @@
defaults:
  - default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: tapvid_davis_strided
@@ -0,0 +1,6 @@
defaults:
  - default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: tapvid_kinetics_first
5  cotracker/evaluation/core/__init__.py  Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
144  cotracker/evaluation/core/eval_utils.py  Normal file
@@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

from typing import Iterable, Mapping, Tuple, Union


def compute_tapvid_metrics(
    query_points: np.ndarray,
    gt_occluded: np.ndarray,
    gt_tracks: np.ndarray,
    pred_occluded: np.ndarray,
    pred_tracks: np.ndarray,
    query_mode: str,
) -> Mapping[str, np.ndarray]:
    """Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.)

    See the TAP-Vid paper for details on the metric computation. All inputs are
    given in raster coordinates. The first three arguments should be the direct
    outputs of the reader: the 'query_points', 'occluded', and 'target_points'.
    The paper metrics assume these are scaled relative to 256x256 images.
    pred_occluded and pred_tracks are your algorithm's predictions.

    This function takes a batch of inputs, and computes metrics separately for
    each video. The metrics for the full benchmark are a simple mean of the
    metrics across the full set of videos. These numbers are between 0 and 1,
    but the paper multiplies them by 100 to ease reading.

    Args:
        query_points: The query points, in the format [t, y, x]. Its size is
            [b, n, 3], where b is the batch size and n is the number of queries
        gt_occluded: A boolean array of shape [b, n, t], where t is the number
            of frames. True indicates that the point is occluded.
        gt_tracks: The target points, of shape [b, n, t, 2]. Each point is
            in the format [x, y]
        pred_occluded: A boolean array of predicted occlusions, in the same
            format as gt_occluded.
        pred_tracks: An array of track predictions from your algorithm, in the
            same format as gt_tracks.
        query_mode: Either 'first' or 'strided', depending on how queries are
            sampled. If 'first', we assume the prior knowledge that all points
            before the query point are occluded, and these are removed from the
            evaluation.

    Returns:
        A dict with the following keys:
            occlusion_accuracy: Accuracy at predicting occlusion.
            pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points
                predicted to be within the given pixel threshold, ignoring
                occlusion prediction.
            jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given
                threshold
            average_pts_within_thresh: average across pts_within_{x}
            average_jaccard: average across jaccard_{x}
    """

    metrics = {}

    # Don't evaluate the query point. Numpy doesn't have one_hot, so we
    # replicate it by indexing into an identity matrix.
    one_hot_eye = np.eye(gt_tracks.shape[2])
    query_frame = query_points[..., 0]
    query_frame = np.round(query_frame).astype(np.int32)
    evaluation_points = one_hot_eye[query_frame] == 0

    # If we're using the first point on the track as a query, don't evaluate the
    # other points.
    if query_mode == "first":
        for i in range(gt_occluded.shape[0]):
            index = np.where(gt_occluded[i] == 0)[0][0]
            evaluation_points[i, :index] = False
    elif query_mode != "strided":
        raise ValueError("Unknown query mode " + query_mode)

    # Occlusion accuracy is simply how often the predicted occlusion equals the
    # ground truth.
    occ_acc = (
        np.sum(
            np.equal(pred_occluded, gt_occluded) & evaluation_points,
            axis=(1, 2),
        )
        / np.sum(evaluation_points)
    )
    metrics["occlusion_accuracy"] = occ_acc

    # Next, convert the predictions and ground truth positions into pixel
    # coordinates.
    visible = np.logical_not(gt_occluded)
    pred_visible = np.logical_not(pred_occluded)
    all_frac_within = []
    all_jaccard = []
    for thresh in [1, 2, 4, 8, 16]:
        # True positives are points that are within the threshold and where both
        # the prediction and the ground truth are listed as visible.
        within_dist = (
            np.sum(
                np.square(pred_tracks - gt_tracks),
                axis=-1,
            )
            < np.square(thresh)
        )
        is_correct = np.logical_and(within_dist, visible)

        # Compute the frac_within_threshold, which is the fraction of points
        # within the threshold among points that are visible in the ground truth,
        # ignoring whether they're predicted to be visible.
        count_correct = np.sum(
            is_correct & evaluation_points,
            axis=(1, 2),
        )
        count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2))
        frac_correct = count_correct / count_visible_points
        metrics["pts_within_" + str(thresh)] = frac_correct
        all_frac_within.append(frac_correct)

        true_positives = np.sum(
            is_correct & pred_visible & evaluation_points, axis=(1, 2)
        )

        # The denominator of the jaccard metric is the true positives plus
        # false positives plus false negatives. However, note that true positives
        # plus false negatives is simply the number of points in the ground truth
        # which is easier to compute than trying to compute all three quantities.
        # Thus we just add the number of points in the ground truth to the number
        # of false positives.
        #
        # False positives are simply points that are predicted to be visible,
        # but the ground truth is not visible or too far from the prediction.
        gt_positives = np.sum(visible & evaluation_points, axis=(1, 2))
        false_positives = (~visible) & pred_visible
        false_positives = false_positives | ((~within_dist) & pred_visible)
        false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2))
        jaccard = true_positives / (gt_positives + false_positives)
        metrics["jaccard_" + str(thresh)] = jaccard
        all_jaccard.append(jaccard)
    metrics["average_jaccard"] = np.mean(
        np.stack(all_jaccard, axis=1),
        axis=1,
    )
    metrics["average_pts_within_thresh"] = np.mean(
        np.stack(all_frac_within, axis=1),
        axis=1,
    )
    return metrics
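A quick sanity-check sketch (illustrative only, not part of the original commit): feeding a perfect prediction into `compute_tapvid_metrics` above should return 1.0 for every metric. All arrays below are toy placeholders:

import numpy as np
from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics

# One video, 2 tracks, 4 frames; nothing occluded, every query taken at frame 0.
gt_tracks = np.random.rand(1, 2, 4, 2) * 256       # [b, n, t, 2] in pixel coordinates
gt_occluded = np.zeros((1, 2, 4), dtype=bool)
query_points = np.zeros((1, 2, 3))                 # [t, y, x] with t == 0
query_points[:, :, 1:] = gt_tracks[:, :, 0, ::-1]  # y, x of each track in frame 0

# Using the ground truth as the "prediction" yields a perfect score.
metrics = compute_tapvid_metrics(
    query_points, gt_occluded, gt_tracks, gt_occluded, gt_tracks, query_mode="first"
)
print(metrics["occlusion_accuracy"], metrics["average_jaccard"])  # [1.] [1.]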
252  cotracker/evaluation/core/evaluator.py  Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from cotracker.datasets.utils import dataclass_to_cuda_
|
||||||
|
from cotracker.utils.visualizer import Visualizer
|
||||||
|
from cotracker.models.core.model_utils import reduce_masked_mean
|
||||||
|
from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
class Evaluator:
|
||||||
|
"""
|
||||||
|
A class defining the CoTracker evaluator.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, exp_dir) -> None:
|
||||||
|
# Visualization
|
||||||
|
self.exp_dir = exp_dir
|
||||||
|
os.makedirs(exp_dir, exist_ok=True)
|
||||||
|
self.visualization_filepaths = defaultdict(lambda: defaultdict(list))
|
||||||
|
self.visualize_dir = os.path.join(exp_dir, "visualisations")
|
||||||
|
|
||||||
|
def compute_metrics(self, metrics, sample, pred_trajectory, dataset_name):
|
||||||
|
if isinstance(pred_trajectory, tuple):
|
||||||
|
pred_trajectory, pred_visibility = pred_trajectory
|
||||||
|
else:
|
||||||
|
pred_visibility = None
|
||||||
|
if dataset_name == "badja":
|
||||||
|
sample.segmentation = (sample.segmentation > 0).float()
|
||||||
|
*_, N, _ = sample.trajectory.shape
|
||||||
|
accs = []
|
||||||
|
accs_3px = []
|
||||||
|
for s1 in range(1, sample.video.shape[1]): # target frame
|
||||||
|
for n in range(N):
|
||||||
|
vis = sample.visibility[0, s1, n]
|
||||||
|
if vis > 0:
|
||||||
|
coord_e = pred_trajectory[0, s1, n] # 2
|
||||||
|
coord_g = sample.trajectory[0, s1, n] # 2
|
||||||
|
dist = torch.sqrt(torch.sum((coord_e - coord_g) ** 2, dim=0))
|
||||||
|
area = torch.sum(sample.segmentation[0, s1])
|
||||||
|
# print_('0.2*sqrt(area)', 0.2*torch.sqrt(area))
|
||||||
|
thr = 0.2 * torch.sqrt(area)
|
||||||
|
# correct =
|
||||||
|
accs.append((dist < thr).float())
|
||||||
|
# print('thr',thr)
|
||||||
|
accs_3px.append((dist < 3.0).float())
|
||||||
|
|
||||||
|
res = torch.mean(torch.stack(accs)) * 100.0
|
||||||
|
res_3px = torch.mean(torch.stack(accs_3px)) * 100.0
|
||||||
|
metrics[sample.seq_name[0]] = res.item()
|
||||||
|
metrics[sample.seq_name[0] + "_accuracy"] = res_3px.item()
|
||||||
|
print(metrics)
|
||||||
|
print(
|
||||||
|
"avg", np.mean([v for k, v in metrics.items() if "accuracy" not in k])
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"avg acc 3px",
|
||||||
|
np.mean([v for k, v in metrics.items() if "accuracy" in k]),
|
||||||
|
)
|
||||||
|
elif dataset_name == "fastcapture" or ("kubric" in dataset_name):
|
||||||
|
*_, N, _ = sample.trajectory.shape
|
||||||
|
accs = []
|
||||||
|
for s1 in range(1, sample.video.shape[1]): # target frame
|
||||||
|
for n in range(N):
|
||||||
|
vis = sample.visibility[0, s1, n]
|
||||||
|
if vis > 0:
|
||||||
|
coord_e = pred_trajectory[0, s1, n] # 2
|
||||||
|
coord_g = sample.trajectory[0, s1, n] # 2
|
||||||
|
dist = torch.sqrt(torch.sum((coord_e - coord_g) ** 2, dim=0))
|
||||||
|
thr = 3
|
||||||
|
correct = (dist < thr).float()
|
||||||
|
accs.append(correct)
|
||||||
|
|
||||||
|
res = torch.mean(torch.stack(accs)) * 100.0
|
||||||
|
metrics[sample.seq_name[0] + "_accuracy"] = res.item()
|
||||||
|
print(metrics)
|
||||||
|
print("avg", np.mean([v for v in metrics.values()]))
|
||||||
|
elif "tapvid" in dataset_name:
|
||||||
|
B, T, N, D = sample.trajectory.shape
|
||||||
|
traj = sample.trajectory.clone()
|
||||||
|
thr = 0.9
|
||||||
|
|
||||||
|
if pred_visibility is None:
|
||||||
|
logging.warning("visibility is NONE")
|
||||||
|
pred_visibility = torch.zeros_like(sample.visibility)
|
||||||
|
|
||||||
|
if not pred_visibility.dtype == torch.bool:
|
||||||
|
pred_visibility = pred_visibility > thr
|
||||||
|
|
||||||
|
# pred_trajectory
|
||||||
|
query_points = sample.query_points.clone().cpu().numpy()
|
||||||
|
|
||||||
|
pred_visibility = pred_visibility[:, :, :N]
|
||||||
|
pred_trajectory = pred_trajectory[:, :, :N]
|
||||||
|
|
||||||
|
gt_tracks = traj.permute(0, 2, 1, 3).cpu().numpy()
|
||||||
|
gt_occluded = (
|
||||||
|
torch.logical_not(sample.visibility.clone().permute(0, 2, 1))
|
||||||
|
.cpu()
|
||||||
|
.numpy()
|
||||||
|
)
|
||||||
|
|
||||||
|
pred_occluded = (
|
||||||
|
torch.logical_not(pred_visibility.clone().permute(0, 2, 1))
|
||||||
|
.cpu()
|
||||||
|
.numpy()
|
||||||
|
)
|
||||||
|
pred_tracks = pred_trajectory.permute(0, 2, 1, 3).cpu().numpy()
|
||||||
|
|
||||||
|
out_metrics = compute_tapvid_metrics(
|
||||||
|
query_points,
|
||||||
|
gt_occluded,
|
||||||
|
gt_tracks,
|
||||||
|
pred_occluded,
|
||||||
|
pred_tracks,
|
||||||
|
query_mode="strided" if "strided" in dataset_name else "first",
|
||||||
|
)
|
||||||
|
|
||||||
|
metrics[sample.seq_name[0]] = out_metrics
|
||||||
|
for metric_name in out_metrics.keys():
|
||||||
|
if "avg" not in metrics:
|
||||||
|
metrics["avg"] = {}
|
||||||
|
metrics["avg"][metric_name] = np.mean(
|
||||||
|
[v[metric_name] for k, v in metrics.items() if k != "avg"]
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info(f"Metrics: {out_metrics}")
|
||||||
|
logging.info(f"avg: {metrics['avg']}")
|
||||||
|
print("metrics", out_metrics)
|
||||||
|
print("avg", metrics["avg"])
|
||||||
|
else:
|
||||||
|
rgbs = sample.video
|
||||||
|
trajs_g = sample.trajectory
|
||||||
|
valids = sample.valid
|
||||||
|
vis_g = sample.visibility
|
||||||
|
|
||||||
|
B, S, C, H, W = rgbs.shape
|
||||||
|
assert C == 3
|
||||||
|
B, S, N, D = trajs_g.shape
|
||||||
|
|
||||||
|
assert torch.sum(valids) == B * S * N
|
||||||
|
|
||||||
|
vis_g = (torch.sum(vis_g, dim=1, keepdim=True) >= 4).float().repeat(1, S, 1)
|
||||||
|
|
||||||
|
ate = torch.norm(pred_trajectory - trajs_g, dim=-1) # B, S, N
|
||||||
|
|
||||||
|
metrics["things_all"] = reduce_masked_mean(ate, valids).item()
|
||||||
|
metrics["things_vis"] = reduce_masked_mean(ate, valids * vis_g).item()
|
||||||
|
metrics["things_occ"] = reduce_masked_mean(
|
||||||
|
ate, valids * (1.0 - vis_g)
|
||||||
|
).item()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def evaluate_sequence(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
test_dataloader: torch.utils.data.DataLoader,
|
||||||
|
dataset_name: str,
|
||||||
|
train_mode=False,
|
||||||
|
writer: Optional[SummaryWriter] = None,
|
||||||
|
step: Optional[int] = 0,
|
||||||
|
):
|
||||||
|
metrics = {}
|
||||||
|
|
||||||
|
vis = Visualizer(
|
||||||
|
save_dir=self.exp_dir,
|
||||||
|
fps=7,
|
||||||
|
)
|
||||||
|
|
||||||
|
for ind, sample in enumerate(tqdm(test_dataloader)):
|
||||||
|
if isinstance(sample, tuple):
|
||||||
|
sample, gotit = sample
|
||||||
|
if not all(gotit):
|
||||||
|
print("batch is None")
|
||||||
|
continue
|
||||||
|
dataclass_to_cuda_(sample)
|
||||||
|
|
||||||
|
if (
|
||||||
|
not train_mode
|
||||||
|
and hasattr(model, "sequence_len")
|
||||||
|
and (sample.visibility[:, : model.sequence_len].sum() == 0)
|
||||||
|
):
|
||||||
|
print(f"skipping batch {ind}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if "tapvid" in dataset_name:
|
||||||
|
queries = sample.query_points.clone().float()
|
||||||
|
|
||||||
|
queries = torch.stack(
|
||||||
|
[
|
||||||
|
queries[:, :, 0],
|
||||||
|
queries[:, :, 2],
|
||||||
|
queries[:, :, 1],
|
||||||
|
],
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
queries = torch.cat(
|
||||||
|
[
|
||||||
|
torch.zeros_like(sample.trajectory[:, 0, :, :1]),
|
||||||
|
sample.trajectory[:, 0],
|
||||||
|
],
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
pred_tracks = model(sample.video, queries)
|
||||||
|
if "strided" in dataset_name:
|
||||||
|
|
||||||
|
inv_video = sample.video.flip(1).clone()
|
||||||
|
inv_queries = queries.clone()
|
||||||
|
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
|
||||||
|
|
||||||
|
pred_trj, pred_vsb = pred_tracks
|
||||||
|
inv_pred_trj, inv_pred_vsb = model(inv_video, inv_queries)
|
||||||
|
|
||||||
|
inv_pred_trj = inv_pred_trj.flip(1)
|
||||||
|
inv_pred_vsb = inv_pred_vsb.flip(1)
|
||||||
|
|
||||||
|
mask = pred_trj == 0
|
||||||
|
|
||||||
|
pred_trj[mask] = inv_pred_trj[mask]
|
||||||
|
pred_vsb[mask[:, :, :, 0]] = inv_pred_vsb[mask[:, :, :, 0]]
|
||||||
|
|
||||||
|
pred_tracks = pred_trj, pred_vsb
|
||||||
|
|
||||||
|
if dataset_name == "badja" or dataset_name == "fastcapture":
|
||||||
|
seq_name = sample.seq_name[0]
|
||||||
|
else:
|
||||||
|
seq_name = str(ind)
|
||||||
|
|
||||||
|
vis.visualize(
|
||||||
|
sample.video,
|
||||||
|
pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks,
|
||||||
|
filename=dataset_name + "_" + seq_name,
|
||||||
|
writer=writer,
|
||||||
|
step=step,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.compute_metrics(metrics, sample, pred_tracks, dataset_name)
|
||||||
|
return metrics
|
179  cotracker/evaluation/evaluate.py  Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
import hydra
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
from cotracker.datasets.badja_dataset import BadjaDataset
|
||||||
|
from cotracker.datasets.fast_capture_dataset import FastCaptureDataset
|
||||||
|
from cotracker.datasets.tap_vid_datasets import TapVidDataset
|
||||||
|
from cotracker.datasets.utils import collate_fn
|
||||||
|
|
||||||
|
from cotracker.models.evaluation_predictor import EvaluationPredictor
|
||||||
|
|
||||||
|
from cotracker.evaluation.core.evaluator import Evaluator
|
||||||
|
from cotracker.models.build_cotracker import (
|
||||||
|
build_cotracker,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=False)
|
||||||
|
class DefaultConfig:
|
||||||
|
# Directory where all outputs of the experiment will be saved.
|
||||||
|
exp_dir: str = "./outputs"
|
||||||
|
|
||||||
|
# Name of the dataset to be used for the evaluation.
|
||||||
|
dataset_name: str = "badja"
|
||||||
|
# The root directory of the dataset.
|
||||||
|
dataset_root: str = "./"
|
||||||
|
|
||||||
|
# Path to the pre-trained model checkpoint to be used for the evaluation.
|
||||||
|
# The default value is the path to a specific CoTracker model checkpoint.
|
||||||
|
# Other available options are commented.
|
||||||
|
checkpoint: str = "./checkpoints/cotracker_stride_4_wind_8.pth"
|
||||||
|
# cotracker_stride_4_wind_12
|
||||||
|
# cotracker_stride_8_wind_16
|
||||||
|
|
||||||
|
# EvaluationPredictor parameters
|
||||||
|
# The size (N) of the support grid used in the predictor.
|
||||||
|
# The total number of points is (N*N).
|
||||||
|
grid_size: int = 6
|
||||||
|
# The size (N) of the local support grid.
|
||||||
|
local_grid_size: int = 6
|
||||||
|
# A flag indicating whether to evaluate one ground truth point at a time.
|
||||||
|
single_point: bool = True
|
||||||
|
# The number of iterative updates for each sliding window.
|
||||||
|
n_iters: int = 6
|
||||||
|
|
||||||
|
seed: int = 0
|
||||||
|
gpu_idx: int = 0
|
||||||
|
|
||||||
|
# Override hydra's working directory to current working dir,
|
||||||
|
# also disable storing the .hydra logs:
|
||||||
|
hydra: dict = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"run": {"dir": "."},
|
||||||
|
"output_subdir": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def run_eval(cfg: DefaultConfig):
|
||||||
|
"""
|
||||||
|
The function evaluates CoTracker on a specified benchmark dataset based on a provided configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg (DefaultConfig): An instance of DefaultConfig class which includes:
|
||||||
|
- exp_dir (str): The directory path for the experiment.
|
||||||
|
- dataset_name (str): The name of the dataset to be used.
|
||||||
|
- dataset_root (str): The root directory of the dataset.
|
||||||
|
- checkpoint (str): The path to the CoTracker model's checkpoint.
|
||||||
|
- single_point (bool): A flag indicating whether to evaluate one ground truth point at a time.
|
||||||
|
- n_iters (int): The number of iterative updates for each sliding window.
|
||||||
|
- seed (int): The seed for setting the random state for reproducibility.
|
||||||
|
- gpu_idx (int): The index of the GPU to be used.
|
||||||
|
"""
|
||||||
|
# Creating the experiment directory if it doesn't exist
|
||||||
|
os.makedirs(cfg.exp_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Saving the experiment configuration to a .yaml file in the experiment directory
|
||||||
|
cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml")
|
||||||
|
with open(cfg_file, "w") as f:
|
||||||
|
OmegaConf.save(config=cfg, f=f)
|
||||||
|
|
||||||
|
evaluator = Evaluator(cfg.exp_dir)
|
||||||
|
cotracker_model = build_cotracker(cfg.checkpoint)
|
||||||
|
|
||||||
|
# Creating the EvaluationPredictor object
|
||||||
|
predictor = EvaluationPredictor(
|
||||||
|
cotracker_model,
|
||||||
|
grid_size=cfg.grid_size,
|
||||||
|
local_grid_size=cfg.local_grid_size,
|
||||||
|
single_point=cfg.single_point,
|
||||||
|
n_iters=cfg.n_iters,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setting the random seeds
|
||||||
|
torch.manual_seed(cfg.seed)
|
||||||
|
np.random.seed(cfg.seed)
|
||||||
|
|
||||||
|
# Constructing the specified dataset
|
||||||
|
curr_collate_fn = collate_fn
|
||||||
|
if cfg.dataset_name == "badja":
|
||||||
|
test_dataset = BadjaDataset(data_root=os.path.join(cfg.dataset_root, "BADJA"))
|
||||||
|
elif cfg.dataset_name == "fastcapture":
|
||||||
|
test_dataset = FastCaptureDataset(
|
||||||
|
data_root=os.path.join(cfg.dataset_root, "fastcapture"),
|
||||||
|
max_seq_len=100,
|
||||||
|
max_num_points=20,
|
||||||
|
)
|
||||||
|
elif "tapvid" in cfg.dataset_name:
|
||||||
|
dataset_type = cfg.dataset_name.split("_")[1]
|
||||||
|
if dataset_type == "davis":
|
||||||
|
data_root = os.path.join(cfg.dataset_root, "/tapvid_davis/tapvid_davis.pkl")
|
||||||
|
elif dataset_type == "kinetics":
|
||||||
|
data_root = os.path.join(
|
||||||
|
cfg.dataset_root, "/kinetics/kinetics-dataset/k700-2020/tapvid_kinetics"
|
||||||
|
)
|
||||||
|
test_dataset = TapVidDataset(
|
||||||
|
dataset_type=dataset_type,
|
||||||
|
data_root=data_root,
|
||||||
|
queried_first=not "strided" in cfg.dataset_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Creating the DataLoader object
|
||||||
|
test_dataloader = torch.utils.data.DataLoader(
|
||||||
|
test_dataset,
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=14,
|
||||||
|
collate_fn=curr_collate_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Timing and conducting the evaluation
|
||||||
|
import time
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
evaluate_result = evaluator.evaluate_sequence(
|
||||||
|
predictor,
|
||||||
|
test_dataloader,
|
||||||
|
dataset_name=cfg.dataset_name,
|
||||||
|
)
|
||||||
|
end = time.time()
|
||||||
|
print(end - start)
|
||||||
|
|
||||||
|
# Saving the evaluation results to a .json file
|
||||||
|
if not "tapvid" in cfg.dataset_name:
|
||||||
|
print("evaluate_result", evaluate_result)
|
||||||
|
else:
|
||||||
|
evaluate_result = evaluate_result["avg"]
|
||||||
|
result_file = os.path.join(cfg.exp_dir, f"result_eval_.json")
|
||||||
|
evaluate_result["time"] = end - start
|
||||||
|
print(f"Dumping eval results to {result_file}.")
|
||||||
|
with open(result_file, "w") as f:
|
||||||
|
json.dump(evaluate_result, f)
|
||||||
|
|
||||||
|
|
||||||
|
cs = hydra.core.config_store.ConfigStore.instance()
|
||||||
|
cs.store(name="default_config_eval", node=DefaultConfig)
|
||||||
|
|
||||||
|
|
||||||
|
@hydra.main(config_path="./configs/", config_name="default_config_eval")
|
||||||
|
def evaluate(cfg: DefaultConfig) -> None:
|
||||||
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
|
||||||
|
run_eval(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
evaluate()
|
5  cotracker/models/__init__.py  Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
70  cotracker/models/build_cotracker.py  Normal file
@@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch

from cotracker.models.core.cotracker.cotracker import CoTracker


def build_cotracker(
    checkpoint: str,
):
    model_name = checkpoint.split("/")[-1].split(".")[0]
    if model_name == "cotracker_stride_4_wind_8":
        return build_cotracker_stride_4_wind_8(checkpoint=checkpoint)
    elif model_name == "cotracker_stride_4_wind_12":
        return build_cotracker_stride_4_wind_12(checkpoint=checkpoint)
    elif model_name == "cotracker_stride_8_wind_16":
        return build_cotracker_stride_8_wind_16(checkpoint=checkpoint)
    else:
        raise ValueError(f"Unknown model name {model_name}")


# model used to produce the results in the paper
def build_cotracker_stride_4_wind_8(checkpoint=None):
    return _build_cotracker(
        stride=4,
        sequence_len=8,
        checkpoint=checkpoint,
    )


def build_cotracker_stride_4_wind_12(checkpoint=None):
    return _build_cotracker(
        stride=4,
        sequence_len=12,
        checkpoint=checkpoint,
    )


# the fastest model
def build_cotracker_stride_8_wind_16(checkpoint=None):
    return _build_cotracker(
        stride=8,
        sequence_len=16,
        checkpoint=checkpoint,
    )


def _build_cotracker(
    stride,
    sequence_len,
    checkpoint=None,
):
    cotracker = CoTracker(
        stride=stride,
        S=sequence_len,
        add_space_attn=True,
        space_depth=6,
        time_depth=6,
    )
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f, map_location="cpu")
            if "model" in state_dict:
                state_dict = state_dict["model"]
        cotracker.load_state_dict(state_dict)
    return cotracker
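A minimal sketch (illustrative only, not part of the original commit) of loading a model through `build_cotracker` above; the checkpoint path is the default used by the evaluation config and is assumed to exist locally:

from cotracker.models.build_cotracker import build_cotracker

# The variant is selected from the checkpoint filename, so the file has to be
# named after one of the three builders defined above.
model = build_cotracker("./checkpoints/cotracker_stride_4_wind_8.pth")
model.eval()
print(sum(p.numel() for p in model.parameters()))  # total parameter count of the loaded model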
5  cotracker/models/core/__init__.py  Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
5  cotracker/models/core/cotracker/__init__.py  Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
400  cotracker/models/core/cotracker/blocks.py  Normal file
@@ -0,0 +1,400 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
from timm.models.vision_transformer import Attention, Mlp
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
|
||||||
|
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
|
||||||
|
super(ResidualBlock, self).__init__()
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(
|
||||||
|
in_planes,
|
||||||
|
planes,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
stride=stride,
|
||||||
|
padding_mode="zeros",
|
||||||
|
)
|
||||||
|
self.conv2 = nn.Conv2d(
|
||||||
|
planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
|
||||||
|
)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
num_groups = planes // 8
|
||||||
|
|
||||||
|
if norm_fn == "group":
|
||||||
|
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||||
|
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||||
|
if not stride == 1:
|
||||||
|
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||||
|
|
||||||
|
elif norm_fn == "batch":
|
||||||
|
self.norm1 = nn.BatchNorm2d(planes)
|
||||||
|
self.norm2 = nn.BatchNorm2d(planes)
|
||||||
|
if not stride == 1:
|
||||||
|
self.norm3 = nn.BatchNorm2d(planes)
|
||||||
|
|
||||||
|
elif norm_fn == "instance":
|
||||||
|
self.norm1 = nn.InstanceNorm2d(planes)
|
||||||
|
self.norm2 = nn.InstanceNorm2d(planes)
|
||||||
|
if not stride == 1:
|
||||||
|
self.norm3 = nn.InstanceNorm2d(planes)
|
||||||
|
|
||||||
|
elif norm_fn == "none":
|
||||||
|
self.norm1 = nn.Sequential()
|
||||||
|
self.norm2 = nn.Sequential()
|
||||||
|
if not stride == 1:
|
||||||
|
self.norm3 = nn.Sequential()
|
||||||
|
|
||||||
|
if stride == 1:
|
||||||
|
self.downsample = None
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.downsample = nn.Sequential(
|
||||||
|
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
y = x
|
||||||
|
y = self.relu(self.norm1(self.conv1(y)))
|
||||||
|
y = self.relu(self.norm2(self.conv2(y)))
|
||||||
|
|
||||||
|
if self.downsample is not None:
|
||||||
|
x = self.downsample(x)
|
||||||
|
|
||||||
|
return self.relu(x + y)
|
||||||
|
|
||||||
|
|
||||||
|
class BasicEncoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, input_dim=3, output_dim=128, stride=8, norm_fn="batch", dropout=0.0
|
||||||
|
):
|
||||||
|
super(BasicEncoder, self).__init__()
|
||||||
|
self.stride = stride
|
||||||
|
self.norm_fn = norm_fn
|
||||||
|
self.in_planes = 64
|
||||||
|
|
||||||
|
if self.norm_fn == "group":
|
||||||
|
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes)
|
||||||
|
self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2)
|
||||||
|
|
||||||
|
elif self.norm_fn == "batch":
|
||||||
|
self.norm1 = nn.BatchNorm2d(self.in_planes)
|
||||||
|
self.norm2 = nn.BatchNorm2d(output_dim * 2)
|
||||||
|
|
||||||
|
elif self.norm_fn == "instance":
|
||||||
|
self.norm1 = nn.InstanceNorm2d(self.in_planes)
|
||||||
|
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
|
||||||
|
|
||||||
|
elif self.norm_fn == "none":
|
||||||
|
self.norm1 = nn.Sequential()
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(
|
||||||
|
input_dim,
|
||||||
|
self.in_planes,
|
||||||
|
kernel_size=7,
|
||||||
|
stride=2,
|
||||||
|
padding=3,
|
||||||
|
padding_mode="zeros",
|
||||||
|
)
|
||||||
|
self.relu1 = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
self.shallow = False
|
||||||
|
if self.shallow:
|
||||||
|
self.layer1 = self._make_layer(64, stride=1)
|
||||||
|
self.layer2 = self._make_layer(96, stride=2)
|
||||||
|
self.layer3 = self._make_layer(128, stride=2)
|
||||||
|
self.conv2 = nn.Conv2d(128 + 96 + 64, output_dim, kernel_size=1)
|
||||||
|
else:
|
||||||
|
self.layer1 = self._make_layer(64, stride=1)
|
||||||
|
self.layer2 = self._make_layer(96, stride=2)
|
||||||
|
self.layer3 = self._make_layer(128, stride=2)
|
||||||
|
self.layer4 = self._make_layer(128, stride=2)
|
||||||
|
|
||||||
|
self.conv2 = nn.Conv2d(
|
||||||
|
128 + 128 + 96 + 64,
|
||||||
|
output_dim * 2,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
padding_mode="zeros",
|
||||||
|
)
|
||||||
|
self.relu2 = nn.ReLU(inplace=True)
|
||||||
|
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
|
||||||
|
|
||||||
|
self.dropout = None
|
||||||
|
if dropout > 0:
|
||||||
|
self.dropout = nn.Dropout2d(p=dropout)
|
||||||
|
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
||||||
|
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
||||||
|
if m.weight is not None:
|
||||||
|
nn.init.constant_(m.weight, 1)
|
||||||
|
if m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def _make_layer(self, dim, stride=1):
|
||||||
|
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
||||||
|
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
||||||
|
layers = (layer1, layer2)
|
||||||
|
|
||||||
|
self.in_planes = dim
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
_, _, H, W = x.shape
|
||||||
|
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.norm1(x)
|
||||||
|
x = self.relu1(x)
|
||||||
|
|
||||||
|
if self.shallow:
|
||||||
|
a = self.layer1(x)
|
||||||
|
b = self.layer2(a)
|
||||||
|
c = self.layer3(b)
|
||||||
|
a = F.interpolate(
|
||||||
|
a,
|
||||||
|
(H // self.stride, W // self.stride),
|
||||||
|
mode="bilinear",
|
||||||
|
align_corners=True,
|
||||||
|
)
|
||||||
|
b = F.interpolate(
|
||||||
|
b,
|
||||||
|
(H // self.stride, W // self.stride),
|
||||||
|
mode="bilinear",
|
||||||
|
align_corners=True,
|
||||||
|
)
|
||||||
|
c = F.interpolate(
|
||||||
|
c,
|
||||||
|
(H // self.stride, W // self.stride),
|
||||||
|
mode="bilinear",
|
||||||
|
align_corners=True,
|
||||||
|
)
|
||||||
|
x = self.conv2(torch.cat([a, b, c], dim=1))
|
||||||
|
else:
|
||||||
|
a = self.layer1(x)
|
||||||
|
b = self.layer2(a)
|
||||||
|
c = self.layer3(b)
|
||||||
|
d = self.layer4(c)
|
||||||
|
a = F.interpolate(
|
||||||
|
a,
|
||||||
|
(H // self.stride, W // self.stride),
|
||||||
|
mode="bilinear",
|
||||||
|
align_corners=True,
|
||||||
|
)
|
||||||
|
b = F.interpolate(
|
||||||
|
b,
|
||||||
|
(H // self.stride, W // self.stride),
|
||||||
|
mode="bilinear",
|
||||||
|
align_corners=True,
|
||||||
|
)
|
||||||
|
c = F.interpolate(
|
||||||
|
c,
|
||||||
|
(H // self.stride, W // self.stride),
|
||||||
|
mode="bilinear",
|
||||||
|
align_corners=True,
|
||||||
|
)
|
||||||
|
d = F.interpolate(
|
||||||
|
d,
|
||||||
|
(H // self.stride, W // self.stride),
|
||||||
|
mode="bilinear",
|
||||||
|
align_corners=True,
|
||||||
|
)
|
||||||
|
x = self.conv2(torch.cat([a, b, c, d], dim=1))
|
||||||
|
x = self.norm2(x)
|
||||||
|
x = self.relu2(x)
|
||||||
|
x = self.conv3(x)
|
||||||
|
|
||||||
|
if self.training and self.dropout is not None:
|
||||||
|
x = self.dropout(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AttnBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.attn = Attention(
|
||||||
|
hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||||
|
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||||
|
self.mlp = Mlp(
|
||||||
|
in_features=hidden_size,
|
||||||
|
hidden_features=mlp_hidden_dim,
|
||||||
|
act_layer=approx_gelu,
|
||||||
|
drop=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x + self.attn(self.norm1(x))
|
||||||
|
x = x + self.mlp(self.norm2(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def bilinear_sampler(img, coords, mode="bilinear", mask=False):
|
||||||
|
"""Wrapper for grid_sample, uses pixel coordinates"""
|
||||||
|
H, W = img.shape[-2:]
|
||||||
|
xgrid, ygrid = coords.split([1, 1], dim=-1)
|
||||||
|
# go to 0,1 then 0,2 then -1,1
|
||||||
|
xgrid = 2 * xgrid / (W - 1) - 1
|
||||||
|
ygrid = 2 * ygrid / (H - 1) - 1
|
||||||
|
|
||||||
|
grid = torch.cat([xgrid, ygrid], dim=-1)
|
||||||
|
img = F.grid_sample(img, grid, align_corners=True)
|
||||||
|
|
||||||
|
if mask:
|
||||||
|
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
|
||||||
|
return img, mask.float()
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
class CorrBlock:
|
||||||
|
def __init__(self, fmaps, num_levels=4, radius=4):
|
||||||
|
B, S, C, H, W = fmaps.shape
|
||||||
|
self.S, self.C, self.H, self.W = S, C, H, W
|
||||||
|
|
||||||
|
self.num_levels = num_levels
|
||||||
|
self.radius = radius
|
||||||
|
self.fmaps_pyramid = []
|
||||||
|
|
||||||
|
self.fmaps_pyramid.append(fmaps)
|
||||||
|
for i in range(self.num_levels - 1):
|
||||||
|
fmaps_ = fmaps.reshape(B * S, C, H, W)
|
||||||
|
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
|
||||||
|
_, _, H, W = fmaps_.shape
|
||||||
|
fmaps = fmaps_.reshape(B, S, C, H, W)
|
||||||
|
self.fmaps_pyramid.append(fmaps)
|
||||||
|
|
||||||
|
def sample(self, coords):
|
||||||
|
r = self.radius
|
||||||
|
B, S, N, D = coords.shape
|
||||||
|
assert D == 2
|
||||||
|
|
||||||
|
H, W = self.H, self.W
|
||||||
|
out_pyramid = []
|
||||||
|
for i in range(self.num_levels):
|
||||||
|
corrs = self.corrs_pyramid[i] # B, S, N, H, W
|
||||||
|
_, _, _, H, W = corrs.shape
|
||||||
|
|
||||||
|
dx = torch.linspace(-r, r, 2 * r + 1)
|
||||||
|
dy = torch.linspace(-r, r, 2 * r + 1)
|
||||||
|
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
|
||||||
|
coords.device
|
||||||
|
)
|
||||||
|
|
||||||
|
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
|
||||||
|
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
|
||||||
|
coords_lvl = centroid_lvl + delta_lvl
|
||||||
|
|
||||||
|
corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl)
|
||||||
|
corrs = corrs.view(B, S, N, -1)
|
||||||
|
out_pyramid.append(corrs)
|
||||||
|
|
||||||
|
out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
|
||||||
|
return out.contiguous().float()
|
||||||
|
|
||||||
|
def corr(self, targets):
|
||||||
|
B, S, N, C = targets.shape
|
||||||
|
assert C == self.C
|
||||||
|
assert S == self.S
|
||||||
|
|
||||||
|
fmap1 = targets
|
||||||
|
|
||||||
|
self.corrs_pyramid = []
|
||||||
|
for fmaps in self.fmaps_pyramid:
|
||||||
|
_, _, _, H, W = fmaps.shape
|
||||||
|
fmap2s = fmaps.view(B, S, C, H * W)
|
||||||
|
corrs = torch.matmul(fmap1, fmap2s)
|
||||||
|
corrs = corrs.view(B, S, N, H, W)
|
||||||
|
corrs = corrs / torch.sqrt(torch.tensor(C).float())
|
||||||
|
self.corrs_pyramid.append(corrs)
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateFormer(nn.Module):
|
||||||
|
"""
|
||||||
|
Transformer model that updates track estimates.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
space_depth=12,
|
||||||
|
time_depth=12,
|
||||||
|
input_dim=320,
|
||||||
|
hidden_size=384,
|
||||||
|
num_heads=8,
|
||||||
|
output_dim=130,
|
||||||
|
mlp_ratio=4.0,
|
||||||
|
add_space_attn=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.out_channels = 2
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.add_space_attn = add_space_attn
|
||||||
|
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
|
||||||
|
self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
|
||||||
|
|
||||||
|
self.time_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
|
||||||
|
for _ in range(time_depth)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if add_space_attn:
|
||||||
|
self.space_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
|
||||||
|
for _ in range(space_depth)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert len(self.time_blocks) >= len(self.space_blocks)
|
||||||
|
self.initialize_weights()
|
||||||
|
|
||||||
|
def initialize_weights(self):
|
||||||
|
def _basic_init(module):
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
torch.nn.init.xavier_uniform_(module.weight)
|
||||||
|
if module.bias is not None:
|
||||||
|
nn.init.constant_(module.bias, 0)
|
||||||
|
|
||||||
|
self.apply(_basic_init)
|
||||||
|
|
||||||
|
def forward(self, input_tensor):
|
||||||
|
x = self.input_transform(input_tensor)
|
||||||
|
|
||||||
|
j = 0
|
||||||
|
for i in range(len(self.time_blocks)):
|
||||||
|
B, N, T, _ = x.shape
|
||||||
|
x_time = rearrange(x, "b n t c -> (b n) t c", b=B, t=T, n=N)
|
||||||
|
x_time = self.time_blocks[i](x_time)
|
||||||
|
|
||||||
|
x = rearrange(x_time, "(b n) t c -> b n t c ", b=B, t=T, n=N)
|
||||||
|
if self.add_space_attn and (
|
||||||
|
i % (len(self.time_blocks) // len(self.space_blocks)) == 0
|
||||||
|
):
|
||||||
|
x_space = rearrange(x, "b n t c -> (b t) n c ", b=B, t=T, n=N)
|
||||||
|
x_space = self.space_blocks[j](x_space)
|
||||||
|
x = rearrange(x_space, "(b t) n c -> b n t c ", b=B, t=T, n=N)
|
||||||
|
j += 1
|
||||||
|
|
||||||
|
flow = self.flow_head(x)
|
||||||
|
return flow
|
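For orientation, here is a minimal shape-check sketch of how CorrBlock is meant to be driven: build a pyramid from per-frame feature maps, correlate the per-track query features against it, then sample local patches around the current coordinate estimates. The tensor sizes below are illustrative assumptions, not values the model uses.

# Illustrative CorrBlock sketch; sizes are assumptions small enough to run on CPU.
import torch
from cotracker.models.core.cotracker.blocks import CorrBlock

B, S, C, H, W, N = 1, 8, 128, 48, 64, 16
fmaps = torch.randn(B, S, C, H, W)                  # per-frame feature maps
targets = torch.randn(B, S, N, C)                   # per-track query features
coords = torch.rand(B, S, N, 2) * torch.tensor([W - 1.0, H - 1.0])  # (x, y) in feature-map pixels

corr = CorrBlock(fmaps, num_levels=4, radius=3)
corr.corr(targets)                                  # build the correlation pyramid
patches = corr.sample(coords)                       # (B, S, N, num_levels * (2*radius+1)**2)
print(patches.shape)                                # torch.Size([1, 8, 16, 196])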
351
cotracker/models/core/cotracker/cotracker.py
Normal file
@@ -0,0 +1,351 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from einops import rearrange

from cotracker.models.core.cotracker.blocks import (
    BasicEncoder,
    CorrBlock,
    UpdateFormer,
)

from cotracker.models.core.model_utils import meshgrid2d, bilinear_sample2d, smart_cat
from cotracker.models.core.embeddings import (
    get_2d_embedding,
    get_1d_sincos_pos_embed_from_grid,
    get_2d_sincos_pos_embed,
)


torch.manual_seed(0)


def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)):
    if grid_size == 1:
        return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2])[
            None, None
        ].cuda()

    grid_y, grid_x = meshgrid2d(
        1, grid_size, grid_size, stack=False, norm=False, device="cuda"
    )
    step = interp_shape[1] // 64
    if grid_center[0] != 0 or grid_center[1] != 0:
        grid_y = grid_y - grid_size / 2.0
        grid_x = grid_x - grid_size / 2.0
    grid_y = step + grid_y.reshape(1, -1) / float(grid_size - 1) * (
        interp_shape[0] - step * 2
    )
    grid_x = step + grid_x.reshape(1, -1) / float(grid_size - 1) * (
        interp_shape[1] - step * 2
    )

    grid_y = grid_y + grid_center[0]
    grid_x = grid_x + grid_center[1]
    xy = torch.stack([grid_x, grid_y], dim=-1).cuda()
    return xy


def sample_pos_embed(grid_size, embed_dim, coords):
    pos_embed = get_2d_sincos_pos_embed(embed_dim=embed_dim, grid_size=grid_size)
    pos_embed = (
        torch.from_numpy(pos_embed)
        .reshape(grid_size[0], grid_size[1], embed_dim)
        .float()
        .unsqueeze(0)
        .to(coords.device)
    )
    sampled_pos_embed = bilinear_sample2d(
        pos_embed.permute(0, 3, 1, 2), coords[:, 0, :, 0], coords[:, 0, :, 1]
    )
    return sampled_pos_embed


class CoTracker(nn.Module):
    def __init__(
        self,
        S=8,
        stride=8,
        add_space_attn=True,
        num_heads=8,
        hidden_size=384,
        space_depth=12,
        time_depth=12,
    ):
        super(CoTracker, self).__init__()
        self.S = S
        self.stride = stride
        self.hidden_dim = 256
        self.latent_dim = latent_dim = 128
        self.corr_levels = 4
        self.corr_radius = 3
        self.add_space_attn = add_space_attn
        self.fnet = BasicEncoder(
            output_dim=self.latent_dim, norm_fn="instance", dropout=0, stride=stride
        )

        self.updateformer = UpdateFormer(
            space_depth=space_depth,
            time_depth=time_depth,
            input_dim=456,
            hidden_size=hidden_size,
            num_heads=num_heads,
            output_dim=latent_dim + 2,
            mlp_ratio=4.0,
            add_space_attn=add_space_attn,
        )

        self.norm = nn.GroupNorm(1, self.latent_dim)
        self.ffeat_updater = nn.Sequential(
            nn.Linear(self.latent_dim, self.latent_dim),
            nn.GELU(),
        )
        self.vis_predictor = nn.Sequential(
            nn.Linear(self.latent_dim, 1),
        )

    def forward_iteration(
        self,
        fmaps,
        coords_init,
        feat_init=None,
        vis_init=None,
        track_mask=None,
        iters=4,
    ):
        B, S_init, N, D = coords_init.shape
        assert D == 2
        assert B == 1

        B, S, __, H8, W8 = fmaps.shape

        device = fmaps.device

        if S_init < S:
            coords = torch.cat(
                [coords_init, coords_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1
            )
            vis_init = torch.cat(
                [vis_init, vis_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1
            )
        else:
            coords = coords_init.clone()

        fcorr_fn = CorrBlock(
            fmaps, num_levels=self.corr_levels, radius=self.corr_radius
        )

        ffeats = feat_init.clone()

        times_ = torch.linspace(0, S - 1, S).reshape(1, S, 1)

        pos_embed = sample_pos_embed(
            grid_size=(H8, W8),
            embed_dim=456,
            coords=coords,
        )
        pos_embed = rearrange(pos_embed, "b e n -> (b n) e").unsqueeze(1)
        times_embed = (
            torch.from_numpy(get_1d_sincos_pos_embed_from_grid(456, times_[0]))[None]
            .repeat(B, 1, 1)
            .float()
            .to(device)
        )
        coord_predictions = []

        for __ in range(iters):
            coords = coords.detach()
            fcorr_fn.corr(ffeats)

            fcorrs = fcorr_fn.sample(coords)  # B, S, N, LRR
            LRR = fcorrs.shape[3]

            fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, LRR)
            flows_ = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)

            flows_cat = get_2d_embedding(flows_, 64, cat_coords=True)
            ffeats_ = ffeats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)

            if track_mask.shape[1] < vis_init.shape[1]:
                track_mask = torch.cat(
                    [
                        track_mask,
                        torch.zeros_like(track_mask[:, 0]).repeat(
                            1, vis_init.shape[1] - track_mask.shape[1], 1, 1
                        ),
                    ],
                    dim=1,
                )
            concat = (
                torch.cat([track_mask, vis_init], dim=2)
                .permute(0, 2, 1, 3)
                .reshape(B * N, S, 2)
            )

            transformer_input = torch.cat([flows_cat, fcorrs_, ffeats_, concat], dim=2)
            x = transformer_input + pos_embed + times_embed

            x = rearrange(x, "(b n) t d -> b n t d", b=B)

            delta = self.updateformer(x)

            delta = rearrange(delta, " b n t d -> (b n) t d")

            delta_coords_ = delta[:, :, :2]
            delta_feats_ = delta[:, :, 2:]

            delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
            ffeats_ = ffeats.permute(0, 2, 1, 3).reshape(B * N * S, self.latent_dim)

            ffeats_ = self.ffeat_updater(self.norm(delta_feats_)) + ffeats_

            ffeats = ffeats_.reshape(B, N, S, self.latent_dim).permute(
                0, 2, 1, 3
            )  # B,S,N,C

            coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
            coord_predictions.append(coords * self.stride)

        vis_e = self.vis_predictor(ffeats.reshape(B * S * N, self.latent_dim)).reshape(
            B, S, N
        )
        return coord_predictions, vis_e, feat_init

    def forward(self, rgbs, queries, iters=4, feat_init=None, is_train=False):
        B, T, C, H, W = rgbs.shape
        B, N, __ = queries.shape

        device = rgbs.device
        assert B == 1
        # INIT for the first sequence
        # We want to sort points by the first frame they are visible to add them to the tensor of tracked points consecutively
        first_positive_inds = queries[:, :, 0].long()

        __, sort_inds = torch.sort(first_positive_inds[0], dim=0, descending=False)
        inv_sort_inds = torch.argsort(sort_inds, dim=0)
        first_positive_sorted_inds = first_positive_inds[0][sort_inds]

        assert torch.allclose(
            first_positive_inds[0], first_positive_inds[0][sort_inds][inv_sort_inds]
        )

        coords_init = queries[:, :, 1:].reshape(B, 1, N, 2).repeat(
            1, self.S, 1, 1
        ) / float(self.stride)

        rgbs = 2 * (rgbs / 255.0) - 1.0

        traj_e = torch.zeros((B, T, N, 2), device=device)
        vis_e = torch.zeros((B, T, N), device=device)

        ind_array = torch.arange(T, device=device)
        ind_array = ind_array[None, :, None].repeat(B, 1, N)

        track_mask = (ind_array >= first_positive_inds[:, None, :]).unsqueeze(-1)
        # these are logits, so we initialize visibility with something that would give a value close to 1 after sigmoid
        vis_init = torch.ones((B, self.S, N, 1), device=device).float() * 10

        ind = 0

        track_mask_ = track_mask[:, :, sort_inds].clone()
        coords_init_ = coords_init[:, :, sort_inds].clone()
        vis_init_ = vis_init[:, :, sort_inds].clone()

        prev_wind_idx = 0
        fmaps_ = None
        vis_predictions = []
        coord_predictions = []
        wind_inds = []
        while ind < T - self.S // 2:
            rgbs_seq = rgbs[:, ind : ind + self.S]

            S = S_local = rgbs_seq.shape[1]
            if S < self.S:
                rgbs_seq = torch.cat(
                    [rgbs_seq, rgbs_seq[:, -1, None].repeat(1, self.S - S, 1, 1, 1)],
                    dim=1,
                )
                S = rgbs_seq.shape[1]
            rgbs_ = rgbs_seq.reshape(B * S, C, H, W)

            if fmaps_ is None:
                fmaps_ = self.fnet(rgbs_)
            else:
                fmaps_ = torch.cat(
                    [fmaps_[self.S // 2 :], self.fnet(rgbs_[self.S // 2 :])], dim=0
                )
            fmaps = fmaps_.reshape(
                B, S, self.latent_dim, H // self.stride, W // self.stride
            )

            curr_wind_points = torch.nonzero(first_positive_sorted_inds < ind + self.S)
            if curr_wind_points.shape[0] == 0:
                ind = ind + self.S // 2
                continue
            wind_idx = curr_wind_points[-1] + 1

            if wind_idx - prev_wind_idx > 0:
                fmaps_sample = fmaps[
                    :, first_positive_sorted_inds[prev_wind_idx:wind_idx] - ind
                ]

                feat_init_ = bilinear_sample2d(
                    fmaps_sample,
                    coords_init_[:, 0, prev_wind_idx:wind_idx, 0],
                    coords_init_[:, 0, prev_wind_idx:wind_idx, 1],
                ).permute(0, 2, 1)

                feat_init_ = feat_init_.unsqueeze(1).repeat(1, self.S, 1, 1)
                feat_init = smart_cat(feat_init, feat_init_, dim=2)

            if prev_wind_idx > 0:
                new_coords = coords[-1][:, self.S // 2 :] / float(self.stride)

                coords_init_[:, : self.S // 2, :prev_wind_idx] = new_coords
                coords_init_[:, self.S // 2 :, :prev_wind_idx] = new_coords[
                    :, -1
                ].repeat(1, self.S // 2, 1, 1)

                new_vis = vis[:, self.S // 2 :].unsqueeze(-1)
                vis_init_[:, : self.S // 2, :prev_wind_idx] = new_vis
                vis_init_[:, self.S // 2 :, :prev_wind_idx] = new_vis[:, -1].repeat(
                    1, self.S // 2, 1, 1
                )

            coords, vis, __ = self.forward_iteration(
                fmaps=fmaps,
                coords_init=coords_init_[:, :, :wind_idx],
                feat_init=feat_init[:, :, :wind_idx],
                vis_init=vis_init_[:, :, :wind_idx],
                track_mask=track_mask_[:, ind : ind + self.S, :wind_idx],
                iters=iters,
            )
            if is_train:
                vis_predictions.append(torch.sigmoid(vis[:, :S_local]))
                coord_predictions.append([coord[:, :S_local] for coord in coords])
                wind_inds.append(wind_idx)

            traj_e[:, ind : ind + self.S, :wind_idx] = coords[-1][:, :S_local]
            vis_e[:, ind : ind + self.S, :wind_idx] = vis[:, :S_local]

            track_mask_[:, : ind + self.S, :wind_idx] = 0.0
            ind = ind + self.S // 2

            prev_wind_idx = wind_idx

        traj_e = traj_e[:, :, inv_sort_inds]
        vis_e = vis_e[:, :, inv_sort_inds]

        vis_e = torch.sigmoid(vis_e)

        train_data = (
            (vis_predictions, coord_predictions, wind_inds, sort_inds)
            if is_train
            else None
        )
        return traj_e, feat_init, vis_e, train_data
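A hedged sketch of calling the raw CoTracker module directly: queries are rows of (t, x, y), the batch size must be 1, and the video is processed in sliding windows of S frames with S//2 overlap. With randomly initialized weights the outputs are meaningless; this only illustrates the expected input format and output shapes, and the sizes chosen are assumptions.

# Shape-only sketch of the raw module (untrained weights; normally a checkpoint is loaded instead).
import torch
from cotracker.models.core.cotracker.cotracker import CoTracker

model = CoTracker(S=8, stride=8).eval()
video = torch.randint(0, 255, (1, 16, 3, 384, 512)).float()  # (B, T, 3, H, W), values in [0, 255]
# each query is (t, x, y): first visible frame plus pixel coordinates
queries = torch.tensor([[[0.0, 100.0, 150.0], [4.0, 200.0, 300.0]]])

with torch.no_grad():
    traj, _, vis, _ = model(rgbs=video, queries=queries, iters=4)
print(traj.shape, vis.shape)  # (1, 16, 2, 2) trajectories and (1, 16, 2) visibility scores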
61
cotracker/models/core/cotracker/losses.py
Normal file
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F
from cotracker.models.core.model_utils import reduce_masked_mean

EPS = 1e-6


def balanced_ce_loss(pred, gt, valid=None):
    total_balanced_loss = 0.0
    for j in range(len(gt)):
        B, S, N = gt[j].shape
        # pred and gt are the same shape
        for (a, b) in zip(pred[j].size(), gt[j].size()):
            assert a == b  # some shape mismatch!
        # if valid is not None:
        for (a, b) in zip(pred[j].size(), valid[j].size()):
            assert a == b  # some shape mismatch!

        pos = (gt[j] > 0.95).float()
        neg = (gt[j] < 0.05).float()

        label = pos * 2.0 - 1.0
        a = -label * pred[j]
        b = F.relu(a)
        loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))

        pos_loss = reduce_masked_mean(loss, pos * valid[j])
        neg_loss = reduce_masked_mean(loss, neg * valid[j])

        balanced_loss = pos_loss + neg_loss
        total_balanced_loss += balanced_loss / float(N)
    return total_balanced_loss


def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8):
    """Loss function defined over sequence of flow predictions"""
    total_flow_loss = 0.0
    for j in range(len(flow_gt)):
        B, S, N, D = flow_gt[j].shape
        assert D == 2
        B, S1, N = vis[j].shape
        B, S2, N = valids[j].shape
        assert S == S1
        assert S == S2
        n_predictions = len(flow_preds[j])
        flow_loss = 0.0
        for i in range(n_predictions):
            i_weight = gamma ** (n_predictions - i - 1)
            flow_pred = flow_preds[j][i]
            i_loss = (flow_pred - flow_gt[j]).abs()  # B, S, N, 2
            i_loss = torch.mean(i_loss, dim=3)  # B, S, N
            flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j])
        flow_loss = flow_loss / n_predictions
        total_flow_loss += flow_loss / float(N)
    return total_flow_loss
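Both losses take lists with one entry per sliding window: sequence_loss expects, for each window, the list of per-iteration coordinate predictions and weighs later iterations more heavily (gamma^(n-i-1)), while balanced_ce_loss expects visibility logits. A hedged toy sketch of combining them follows; the training loop itself is not part of this commit and the tensors below are random stand-ins.

# Toy combination of the two losses for one window (illustrative assumptions only).
import torch
from cotracker.models.core.cotracker.losses import sequence_loss, balanced_ce_loss

flow_gt = [torch.randn(1, 8, 4, 2)]                                   # B=1, S=8, N=4, (x, y)
flow_preds = [[flow_gt[0] + 0.1 * torch.randn_like(flow_gt[0]) for _ in range(4)]]  # 4 iterations
vis_gt = [torch.randint(0, 2, (1, 8, 4)).float()]
valid = [torch.ones(1, 8, 4)]
vis_logits = [torch.randn(1, 8, 4)]

loss = sequence_loss(flow_preds, flow_gt, vis_gt, valid, gamma=0.8) + balanced_ce_loss(
    vis_logits, vis_gt, valid
)
print(loss)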
154
cotracker/models/core/embeddings.py
Normal file
@@ -0,0 +1,154 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import numpy as np


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    if isinstance(grid_size, tuple):
        grid_size_h, grid_size_w = grid_size
    else:
        grid_size_h = grid_size_w = grid_size
    grid_h = np.arange(grid_size_h, dtype=np.float32)
    grid_w = np.arange(grid_size_w, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate(
            [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
        )
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000 ** omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def get_2d_embedding(xy, C, cat_coords=True):
    B, N, D = xy.shape
    assert D == 2

    x = xy[:, :, 0:1]
    y = xy[:, :, 1:2]
    div_term = (
        torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
    ).reshape(1, 1, int(C / 2))

    pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
    pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)

    pe_x[:, :, 0::2] = torch.sin(x * div_term)
    pe_x[:, :, 1::2] = torch.cos(x * div_term)

    pe_y[:, :, 0::2] = torch.sin(y * div_term)
    pe_y[:, :, 1::2] = torch.cos(y * div_term)

    pe = torch.cat([pe_x, pe_y], dim=2)  # B, N, C*2
    if cat_coords:
        pe = torch.cat([xy, pe], dim=2)  # B, N, C*2+2
    return pe


def get_3d_embedding(xyz, C, cat_coords=True):
    B, N, D = xyz.shape
    assert D == 3

    x = xyz[:, :, 0:1]
    y = xyz[:, :, 1:2]
    z = xyz[:, :, 2:3]
    div_term = (
        torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (1000.0 / C)
    ).reshape(1, 1, int(C / 2))

    pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
    pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
    pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)

    pe_x[:, :, 0::2] = torch.sin(x * div_term)
    pe_x[:, :, 1::2] = torch.cos(x * div_term)

    pe_y[:, :, 0::2] = torch.sin(y * div_term)
    pe_y[:, :, 1::2] = torch.cos(y * div_term)

    pe_z[:, :, 0::2] = torch.sin(z * div_term)
    pe_z[:, :, 1::2] = torch.cos(z * div_term)

    pe = torch.cat([pe_x, pe_y, pe_z], dim=2)  # B, N, C*3
    if cat_coords:
        pe = torch.cat([pe, xyz], dim=2)  # B, N, C*3+3
    return pe


def get_4d_embedding(xyzw, C, cat_coords=True):
    B, N, D = xyzw.shape
    assert D == 4

    x = xyzw[:, :, 0:1]
    y = xyzw[:, :, 1:2]
    z = xyzw[:, :, 2:3]
    w = xyzw[:, :, 3:4]
    div_term = (
        torch.arange(0, C, 2, device=xyzw.device, dtype=torch.float32) * (1000.0 / C)
    ).reshape(1, 1, int(C / 2))

    pe_x = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
    pe_y = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
    pe_z = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
    pe_w = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)

    pe_x[:, :, 0::2] = torch.sin(x * div_term)
    pe_x[:, :, 1::2] = torch.cos(x * div_term)

    pe_y[:, :, 0::2] = torch.sin(y * div_term)
    pe_y[:, :, 1::2] = torch.cos(y * div_term)

    pe_z[:, :, 0::2] = torch.sin(z * div_term)
    pe_z[:, :, 1::2] = torch.cos(z * div_term)

    pe_w[:, :, 0::2] = torch.sin(w * div_term)
    pe_w[:, :, 1::2] = torch.cos(w * div_term)

    pe = torch.cat([pe_x, pe_y, pe_z, pe_w], dim=2)  # B, N, C*4
    if cat_coords:
        pe = torch.cat([pe, xyzw], dim=2)  # B, N, C*4+4
    return pe
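A quick shape check for the two embedding helpers used by the tracker: the numpy sincos embedding produces one vector per feature-map cell, and get_2d_embedding turns per-step displacements into sin/cos features with the raw coordinates appended. The grid size and C value below mirror the defaults used elsewhere in the model but are otherwise illustrative.

# Illustrative shape check (assumed sizes).
import torch
from cotracker.models.core.embeddings import get_2d_sincos_pos_embed, get_2d_embedding

pos = get_2d_sincos_pos_embed(embed_dim=456, grid_size=(48, 64))
print(pos.shape)  # (48 * 64, 456): one embedding per feature-map cell

flows = torch.randn(2, 16, 2)                       # (B*N, S, 2) displacements
emb = get_2d_embedding(flows, 64, cat_coords=True)
print(emb.shape)  # (2, 16, 130): 64 sin/cos channels per axis plus the raw (x, y)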
169
cotracker/models/core/model_utils.py
Normal file
@@ -0,0 +1,169 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch

EPS = 1e-6


def smart_cat(tensor1, tensor2, dim):
    if tensor1 is None:
        return tensor2
    return torch.cat([tensor1, tensor2], dim=dim)


def normalize_single(d):
    # d is a whatever shape torch tensor
    dmin = torch.min(d)
    dmax = torch.max(d)
    d = (d - dmin) / (EPS + (dmax - dmin))
    return d


def normalize(d):
    # d is B x whatever. normalize within each element of the batch
    out = torch.zeros(d.size())
    if d.is_cuda:
        out = out.cuda()
    B = list(d.size())[0]
    for b in list(range(B)):
        out[b] = normalize_single(d[b])
    return out


def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"):
    # returns a meshgrid sized B x Y x X

    grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device))
    grid_y = torch.reshape(grid_y, [1, Y, 1])
    grid_y = grid_y.repeat(B, 1, X)

    grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device))
    grid_x = torch.reshape(grid_x, [1, 1, X])
    grid_x = grid_x.repeat(B, Y, 1)

    if stack:
        # note we stack in xy order
        # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
        grid = torch.stack([grid_x, grid_y], dim=-1)
        return grid
    else:
        return grid_y, grid_x


def reduce_masked_mean(x, mask, dim=None, keepdim=False):
    # x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting
    # returns shape-1
    # axis can be a list of axes
    for (a, b) in zip(x.size(), mask.size()):
        assert a == b  # some shape mismatch!
    prod = x * mask
    if dim is None:
        numer = torch.sum(prod)
        denom = EPS + torch.sum(mask)
    else:
        numer = torch.sum(prod, dim=dim, keepdim=keepdim)
        denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim)

    mean = numer / denom
    return mean


def bilinear_sample2d(im, x, y, return_inbounds=False):
    # x and y are each B, N
    # output is B, C, N
    if len(im.shape) == 5:
        B, N, C, H, W = list(im.shape)
    else:
        B, C, H, W = list(im.shape)
        N = list(x.shape)[1]

    x = x.float()
    y = y.float()
    H_f = torch.tensor(H, dtype=torch.float32)
    W_f = torch.tensor(W, dtype=torch.float32)

    # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()

    max_y = (H_f - 1).int()
    max_x = (W_f - 1).int()

    x0 = torch.floor(x).int()
    x1 = x0 + 1
    y0 = torch.floor(y).int()
    y1 = y0 + 1

    x0_clip = torch.clamp(x0, 0, max_x)
    x1_clip = torch.clamp(x1, 0, max_x)
    y0_clip = torch.clamp(y0, 0, max_y)
    y1_clip = torch.clamp(y1, 0, max_y)
    dim2 = W
    dim1 = W * H

    base = torch.arange(0, B, dtype=torch.int64, device=x.device) * dim1
    base = torch.reshape(base, [B, 1]).repeat([1, N])

    base_y0 = base + y0_clip * dim2
    base_y1 = base + y1_clip * dim2

    idx_y0_x0 = base_y0 + x0_clip
    idx_y0_x1 = base_y0 + x1_clip
    idx_y1_x0 = base_y1 + x0_clip
    idx_y1_x1 = base_y1 + x1_clip

    # use the indices to lookup pixels in the flat image
    # im is B x C x H x W
    # move C out to last dim
    if len(im.shape) == 5:
        im_flat = (im.permute(0, 3, 4, 1, 2)).reshape(B * H * W, N, C)
        i_y0_x0 = torch.diagonal(im_flat[idx_y0_x0.long()], dim1=1, dim2=2).permute(
            0, 2, 1
        )
        i_y0_x1 = torch.diagonal(im_flat[idx_y0_x1.long()], dim1=1, dim2=2).permute(
            0, 2, 1
        )
        i_y1_x0 = torch.diagonal(im_flat[idx_y1_x0.long()], dim1=1, dim2=2).permute(
            0, 2, 1
        )
        i_y1_x1 = torch.diagonal(im_flat[idx_y1_x1.long()], dim1=1, dim2=2).permute(
            0, 2, 1
        )
    else:
        im_flat = (im.permute(0, 2, 3, 1)).reshape(B * H * W, C)
        i_y0_x0 = im_flat[idx_y0_x0.long()]
        i_y0_x1 = im_flat[idx_y0_x1.long()]
        i_y1_x0 = im_flat[idx_y1_x0.long()]
        i_y1_x1 = im_flat[idx_y1_x1.long()]

    # Finally calculate interpolated values.
    x0_f = x0.float()
    x1_f = x1.float()
    y0_f = y0.float()
    y1_f = y1.float()

    w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)
    w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)
    w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)
    w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)

    output = (
        w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1
    )
    # output is B*N x C
    output = output.view(B, -1, C)
    output = output.permute(0, 2, 1)
    # output is B x C x N

    if return_inbounds:
        x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
        y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
        inbounds = (x_valid & y_valid).float()
        inbounds = inbounds.reshape(
            B, N
        )  # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
        return output, inbounds

    return output  # B, C, N
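bilinear_sample2d is the workhorse used to pull per-point feature vectors (and positional embeddings) out of a map at sub-pixel coordinates. A minimal sketch of the 4-D input path, with sizes that are purely illustrative:

# Minimal bilinear_sample2d sketch (assumed sizes).
import torch
from cotracker.models.core.model_utils import bilinear_sample2d

feat = torch.randn(1, 128, 48, 64)      # B x C x H x W feature map
x = torch.tensor([[10.5, 30.2]])        # B x N x-coordinates in pixels
y = torch.tensor([[20.0, 40.7]])        # B x N y-coordinates in pixels

out = bilinear_sample2d(feat, x, y)
print(out.shape)                         # torch.Size([1, 128, 2]) -> B x C x N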
103
cotracker/models/evaluation_predictor.py
Normal file
@@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F
from typing import Tuple

from cotracker.models.core.cotracker.cotracker import CoTracker, get_points_on_a_grid


class EvaluationPredictor(torch.nn.Module):
    def __init__(
        self,
        cotracker_model: CoTracker,
        interp_shape: Tuple[int, int] = (384, 512),
        grid_size: int = 6,
        local_grid_size: int = 6,
        single_point: bool = True,
        n_iters: int = 6,
    ) -> None:
        super(EvaluationPredictor, self).__init__()
        self.grid_size = grid_size
        self.local_grid_size = local_grid_size
        self.single_point = single_point
        self.interp_shape = interp_shape
        self.n_iters = n_iters

        self.model = cotracker_model
        self.model.to("cuda")
        self.model.eval()

    def forward(self, video, queries):
        queries = queries.clone().cuda()
        B, T, C, H, W = video.shape
        B, N, D = queries.shape

        assert D == 3
        assert B == 1

        rgbs = video.reshape(B * T, C, H, W)
        rgbs = F.interpolate(rgbs, tuple(self.interp_shape), mode="bilinear")
        rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]).cuda()

        queries[:, :, 1] *= self.interp_shape[1] / W
        queries[:, :, 2] *= self.interp_shape[0] / H

        if self.single_point:
            traj_e = torch.zeros((B, T, N, 2)).cuda()
            vis_e = torch.zeros((B, T, N)).cuda()
            for pind in range((N)):
                query = queries[:, pind : pind + 1]

                t = query[0, 0, 0].long()

                traj_e_pind, vis_e_pind = self._process_one_point(rgbs, query)
                traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1]
                vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
        else:
            if self.grid_size > 0:
                xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
                xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda()  #
                queries = torch.cat([queries, xy], dim=1)  #

            traj_e, __, vis_e, __ = self.model(
                rgbs=rgbs,
                queries=queries,
                iters=self.n_iters,
            )

        traj_e[:, :, :, 0] *= W / float(self.interp_shape[1])
        traj_e[:, :, :, 1] *= H / float(self.interp_shape[0])
        return traj_e, vis_e

    def _process_one_point(self, rgbs, query):
        t = query[0, 0, 0].long()

        device = rgbs.device
        if self.local_grid_size > 0:
            xy_target = get_points_on_a_grid(
                self.local_grid_size,
                (50, 50),
                [query[0, 0, 2], query[0, 0, 1]],
            )

            xy_target = torch.cat(
                [torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2
            )  #
            query = torch.cat([query, xy_target], dim=1).to(device)  #

        if self.grid_size > 0:
            xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
            xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda()  #
            query = torch.cat([query, xy], dim=1).to(device)  #
        # crop the video to start from the queried frame
        query[0, 0, 0] = 0
        traj_e_pind, __, vis_e_pind, __ = self.model(
            rgbs=rgbs[:, t:], queries=query, iters=self.n_iters
        )

        return traj_e_pind, vis_e_pind
178
cotracker/predictor.py
Normal file
@@ -0,0 +1,178 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F

from tqdm import tqdm
from cotracker.models.core.cotracker.cotracker import get_points_on_a_grid
from cotracker.models.core.model_utils import smart_cat
from cotracker.models.build_cotracker import (
    build_cotracker,
)


class CoTrackerPredictor(torch.nn.Module):
    def __init__(
        self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth"
    ):
        super().__init__()
        self.interp_shape = (384, 512)
        self.support_grid_size = 6
        model = build_cotracker(checkpoint)

        self.model = model
        self.model.to("cuda")
        self.model.eval()

    @torch.no_grad()
    def forward(
        self,
        video,  # (1, T, 3, H, W)
        # input prompt types:
        # - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
        # *backward_tracking=True* will compute tracks in both directions.
        # - queries. Queried points of shape (1, N, 3) in format (t, x, y) for frame index and pixel coordinates.
        # - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask.
        # You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks.
        queries: torch.Tensor = None,
        segm_mask: torch.Tensor = None,  # Segmentation mask of shape (B, 1, H, W)
        grid_size: int = 0,
        grid_query_frame: int = 0,  # only for dense and regular grid tracks
        backward_tracking: bool = False,
    ):

        if queries is None and grid_size == 0:
            tracks, visibilities = self._compute_dense_tracks(
                video,
                grid_query_frame=grid_query_frame,
                backward_tracking=backward_tracking,
            )
        else:
            tracks, visibilities = self._compute_sparse_tracks(
                video,
                queries,
                segm_mask,
                grid_size,
                add_support_grid=(grid_size == 0 or segm_mask is not None),
                grid_query_frame=grid_query_frame,
                backward_tracking=backward_tracking,
            )

        return tracks, visibilities

    def _compute_dense_tracks(
        self, video, grid_query_frame, grid_size=50, backward_tracking=False
    ):
        *_, H, W = video.shape
        grid_step = W // grid_size
        grid_width = W // grid_step
        grid_height = H // grid_step
        tracks = visibilities = None
        grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to("cuda")
        grid_pts[0, :, 0] = grid_query_frame
        for offset in tqdm(range(grid_step * grid_step)):
            ox = offset % grid_step
            oy = offset // grid_step
            grid_pts[0, :, 1] = (
                torch.arange(grid_width).repeat(grid_height) * grid_step + ox
            )
            grid_pts[0, :, 2] = (
                torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
            )
            tracks_step, visibilities_step = self._compute_sparse_tracks(
                video=video,
                queries=grid_pts,
                backward_tracking=backward_tracking,
            )
            tracks = smart_cat(tracks, tracks_step, dim=2)
            visibilities = smart_cat(visibilities, visibilities_step, dim=2)

        return tracks, visibilities

    def _compute_sparse_tracks(
        self,
        video,
        queries,
        segm_mask=None,
        grid_size=0,
        add_support_grid=False,
        grid_query_frame=0,
        backward_tracking=False,
    ):
        B, T, C, H, W = video.shape
        assert B == 1

        video = video.reshape(B * T, C, H, W)
        video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").cuda()
        video = video.reshape(
            B, T, 3, self.interp_shape[0], self.interp_shape[1]
        ).cuda()

        if queries is not None:
            queries = queries.clone()
            B, N, D = queries.shape
            assert D == 3
            queries[:, :, 1] *= self.interp_shape[1] / W
            queries[:, :, 2] *= self.interp_shape[0] / H
        elif grid_size > 0:
            grid_pts = get_points_on_a_grid(grid_size, self.interp_shape)
            if segm_mask is not None:
                segm_mask = F.interpolate(
                    segm_mask, tuple(self.interp_shape), mode="nearest"
                )
                point_mask = segm_mask[0, 0][
                    (grid_pts[0, :, 1]).round().long().cpu(),
                    (grid_pts[0, :, 0]).round().long().cpu(),
                ].bool()
                grid_pts = grid_pts[:, point_mask]

            queries = torch.cat(
                [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
                dim=2,
            )

        if add_support_grid:
            grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape)
            grid_pts = torch.cat(
                [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
            )
            queries = torch.cat([queries, grid_pts], dim=1)

        tracks, __, visibilities, __ = self.model(rgbs=video, queries=queries, iters=6)

        if backward_tracking:
            tracks, visibilities = self._compute_backward_tracks(
                video, queries, tracks, visibilities
            )
            if add_support_grid:
                queries[:, -self.support_grid_size ** 2 :, 0] = T - 1
        if add_support_grid:
            tracks = tracks[:, :, : -self.support_grid_size ** 2]
            visibilities = visibilities[:, :, : -self.support_grid_size ** 2]
        thr = 0.9
        visibilities = visibilities > thr
        tracks[:, :, :, 0] *= W / float(self.interp_shape[1])
        tracks[:, :, :, 1] *= H / float(self.interp_shape[0])
        return tracks, visibilities

    def _compute_backward_tracks(self, video, queries, tracks, visibilities):
        inv_video = video.flip(1).clone()
        inv_queries = queries.clone()
        inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1

        inv_tracks, __, inv_visibilities, __ = self.model(
            rgbs=inv_video, queries=inv_queries, iters=6
        )

        inv_tracks = inv_tracks.flip(1)
        inv_visibilities = inv_visibilities.flip(1)

        mask = tracks == 0

        tracks[mask] = inv_tracks[mask]
        visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
        return tracks, visibilities
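A hedged end-to-end sketch of driving CoTrackerPredictor with a regular point grid. It assumes a CUDA device, that the checkpoint at the default path has been downloaded, and uses a random tensor as a stand-in for a real video clip; visibilities come back already thresholded at 0.9.

# Usage sketch (assumes CUDA and a downloaded checkpoint; the video tensor is a placeholder).
import torch
from cotracker.predictor import CoTrackerPredictor

video = torch.randint(0, 255, (1, 48, 3, 480, 640)).float()  # stand-in for a (1, T, 3, H, W) clip
predictor = CoTrackerPredictor(
    checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth"
)
pred_tracks, pred_visibility = predictor(video, grid_size=30)
print(pred_tracks.shape, pred_visibility.shape)  # (1, T, N, 2) and (1, T, N)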
5
cotracker/utils/__init__.py
Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
291
cotracker/utils/visualizer.py
Normal file
@@ -0,0 +1,291 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import numpy as np
import cv2
import torch
import flow_vis

from matplotlib import cm
import torch.nn.functional as F
import torchvision.transforms as transforms
from moviepy.editor import ImageSequenceClip
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt


class Visualizer:
    def __init__(
        self,
        save_dir: str = "./results",
        grayscale: bool = False,
        pad_value: int = 0,
        fps: int = 10,
        mode: str = "rainbow",  # 'cool', 'optical_flow'
        linewidth: int = 2,
        show_first_frame: int = 10,
        tracks_leave_trace: int = 0,  # -1 for infinite
    ):
        self.mode = mode
        self.save_dir = save_dir
        if mode == "rainbow":
            self.color_map = cm.get_cmap("gist_rainbow")
        elif mode == "cool":
            self.color_map = cm.get_cmap(mode)
        self.show_first_frame = show_first_frame
        self.grayscale = grayscale
        self.tracks_leave_trace = tracks_leave_trace
        self.pad_value = pad_value
        self.linewidth = linewidth
        self.fps = fps

    def visualize(
        self,
        video: torch.Tensor,  # (B,T,C,H,W)
        tracks: torch.Tensor,  # (B,T,N,2)
        gt_tracks: torch.Tensor = None,  # (B,T,N,2)
        segm_mask: torch.Tensor = None,  # (B,1,H,W)
        filename: str = "video",
        writer: SummaryWriter = None,
        step: int = 0,
        query_frame: int = 0,
        save_video: bool = True,
        compensate_for_camera_motion: bool = False,
    ):
        if compensate_for_camera_motion:
            assert segm_mask is not None
        if segm_mask is not None:
            coords = tracks[0, query_frame].round().long()
            segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()

        video = F.pad(
            video,
            (self.pad_value, self.pad_value, self.pad_value, self.pad_value),
            "constant",
            255,
        )
        tracks = tracks + self.pad_value

        if self.grayscale:
            transform = transforms.Grayscale()
            video = transform(video)
            video = video.repeat(1, 1, 3, 1, 1)

        res_video = self.draw_tracks_on_video(
            video=video,
            tracks=tracks,
            segm_mask=segm_mask,
            gt_tracks=gt_tracks,
            query_frame=query_frame,
            compensate_for_camera_motion=compensate_for_camera_motion,
        )
        if save_video:
            self.save_video(res_video, filename=filename, writer=writer, step=step)
        return res_video

    def save_video(self, video, filename, writer=None, step=0):
        if writer is not None:
            writer.add_video(
                f"{filename}_pred_track",
                video.to(torch.uint8),
                global_step=step,
                fps=self.fps,
            )
        else:
            os.makedirs(self.save_dir, exist_ok=True)
            wide_list = list(video.unbind(1))
            wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
            clip = ImageSequenceClip(wide_list[2:-1], fps=self.fps)

            # Write the video file
            save_path = os.path.join(self.save_dir, f"{filename}_pred_track.mp4")
            clip.write_videofile(save_path, codec="libx264", fps=self.fps, logger=None)

            print(f"Video saved to {save_path}")

    def draw_tracks_on_video(
        self,
        video: torch.Tensor,
        tracks: torch.Tensor,
        segm_mask: torch.Tensor = None,
        gt_tracks=None,
        query_frame: int = 0,
        compensate_for_camera_motion=False,
    ):
        B, T, C, H, W = video.shape
        _, _, N, D = tracks.shape

        assert D == 2
        assert C == 3
        video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy()  # S, H, W, C
        tracks = tracks[0].long().detach().cpu().numpy()  # S, N, 2
        if gt_tracks is not None:
            gt_tracks = gt_tracks[0].detach().cpu().numpy()

        res_video = []

        # process input video
        for rgb in video:
            res_video.append(rgb.copy())

        vector_colors = np.zeros((T, N, 3))
        if self.mode == "optical_flow":
            vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
        elif segm_mask is None:
            if self.mode == "rainbow":
                y_min, y_max = (
                    tracks[query_frame, :, 1].min(),
                    tracks[query_frame, :, 1].max(),
                )
                norm = plt.Normalize(y_min, y_max)
                for n in range(N):
                    color = self.color_map(norm(tracks[query_frame, n, 1]))
                    color = np.array(color[:3])[None] * 255
                    vector_colors[:, n] = np.repeat(color, T, axis=0)
            else:
                # color changes with time
                for t in range(T):
                    color = np.array(self.color_map(t / T)[:3])[None] * 255
                    vector_colors[t] = np.repeat(color, N, axis=0)
        else:
            if self.mode == "rainbow":
                vector_colors[:, segm_mask <= 0, :] = 255

                y_min, y_max = (
                    tracks[0, segm_mask > 0, 1].min(),
                    tracks[0, segm_mask > 0, 1].max(),
                )
                norm = plt.Normalize(y_min, y_max)
                for n in range(N):
                    if segm_mask[n] > 0:
                        color = self.color_map(norm(tracks[0, n, 1]))
                        color = np.array(color[:3])[None] * 255
                        vector_colors[:, n] = np.repeat(color, T, axis=0)

            else:
                # color changes with segm class
                segm_mask = segm_mask.cpu()
                color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
                color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
                color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
                vector_colors = np.repeat(color[None], T, axis=0)

        # draw tracks
        if self.tracks_leave_trace != 0:
            for t in range(1, T):
                first_ind = (
                    max(0, t - self.tracks_leave_trace)
                    if self.tracks_leave_trace >= 0
                    else 0
                )
                curr_tracks = tracks[first_ind : t + 1]
                curr_colors = vector_colors[first_ind : t + 1]
                if compensate_for_camera_motion:
                    diff = (
                        tracks[first_ind : t + 1, segm_mask <= 0]
                        - tracks[t : t + 1, segm_mask <= 0]
                    ).mean(1)[:, None]

                    curr_tracks = curr_tracks - diff
                    curr_tracks = curr_tracks[:, segm_mask > 0]
                    curr_colors = curr_colors[:, segm_mask > 0]

                res_video[t] = self._draw_pred_tracks(
                    res_video[t],
                    curr_tracks,
                    curr_colors,
                )
                if gt_tracks is not None:
                    res_video[t] = self._draw_gt_tracks(
                        res_video[t], gt_tracks[first_ind : t + 1]
                    )

        # draw points
        for t in range(T):
            for i in range(N):
                coord = (tracks[t, i, 0], tracks[t, i, 1])
                if coord[0] != 0 and coord[1] != 0:
                    if not compensate_for_camera_motion or (
                        compensate_for_camera_motion and segm_mask[i] > 0
                    ):
                        cv2.circle(
                            res_video[t],
                            coord,
                            int(self.linewidth * 2),
                            vector_colors[t, i].tolist(),
                            -1,
                        )

        # construct the final rgb sequence
        if self.show_first_frame > 0:
            res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
        return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()

    def _draw_pred_tracks(
        self,
        rgb: np.ndarray,  # H x W x 3
        tracks: np.ndarray,  # T x 2
        vector_colors: np.ndarray,
        alpha: float = 0.5,
    ):
        T, N, _ = tracks.shape

        for s in range(T - 1):
            vector_color = vector_colors[s]
            original = rgb.copy()
            alpha = (s / T) ** 2
            for i in range(N):
                coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
                coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
                if coord_y[0] != 0 and coord_y[1] != 0:
                    cv2.line(
                        rgb,
                        coord_y,
                        coord_x,
                        vector_color[i].tolist(),
                        self.linewidth,
                        cv2.LINE_AA,
                    )
            if self.tracks_leave_trace > 0:
                rgb = cv2.addWeighted(rgb, alpha, original, 1 - alpha, 0)
        return rgb

    def _draw_gt_tracks(
        self,
        rgb: np.ndarray,  # H x W x 3,
        gt_tracks: np.ndarray,  # T x 2
    ):
        T, N, _ = gt_tracks.shape
|
color = np.array((211.0, 0.0, 0.0))
|
||||||
|
|
||||||
|
for t in range(T):
|
||||||
|
for i in range(N):
|
||||||
|
gt_tracks = gt_tracks[t][i]
|
||||||
|
# draw a red cross
|
||||||
|
if gt_tracks[0] > 0 and gt_tracks[1] > 0:
|
||||||
|
length = self.linewidth * 3
|
||||||
|
coord_y = (int(gt_tracks[0]) + length, int(gt_tracks[1]) + length)
|
||||||
|
coord_x = (int(gt_tracks[0]) - length, int(gt_tracks[1]) - length)
|
||||||
|
cv2.line(
|
||||||
|
rgb,
|
||||||
|
coord_y,
|
||||||
|
coord_x,
|
||||||
|
color,
|
||||||
|
self.linewidth,
|
||||||
|
cv2.LINE_AA,
|
||||||
|
)
|
||||||
|
coord_y = (int(gt_tracks[0]) - length, int(gt_tracks[1]) + length)
|
||||||
|
coord_x = (int(gt_tracks[0]) + length, int(gt_tracks[1]) - length)
|
||||||
|
cv2.line(
|
||||||
|
rgb,
|
||||||
|
coord_y,
|
||||||
|
coord_x,
|
||||||
|
color,
|
||||||
|
self.linewidth,
|
||||||
|
cv2.LINE_AA,
|
||||||
|
)
|
||||||
|
return rgb
|
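The methods above are easiest to follow with a small, self-contained driver. The sketch below relies only on what is visible in this file and in demo.py further down (a Visualizer constructed with save_dir, pad_value and linewidth, and a visualize(video, tracks, query_frame=...) call); the video and track tensors are random placeholders in the expected B x T x C x H x W and B x T x N x 2 layouts, not real tracker output.

import torch
from cotracker.utils.visualizer import Visualizer

# Toy inputs: 1 clip, 24 frames, 3 channels, 256x256 pixels, 16 tracked points.
video = torch.randint(0, 255, (1, 24, 3, 256, 256)).float()  # B x T x C x H x W
tracks = torch.rand(1, 24, 16, 2) * 256                      # B x T x N x 2, (x, y) in pixels

vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
# Pads every frame by pad_value, overlays the points, and (with the default
# save_video behavior) writes an MP4 under save_dir via save_video();
# pass writer=/step= to log to a TensorBoard writer instead.
res_video = vis.visualize(video, tracks, query_frame=0)
print(res_video.shape)  # 1 x T' x 3 x (H + 2 * pad_value) x (W + 2 * pad_value), uint8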
71
demo.py
Normal file
71
demo.py
Normal file
@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch
import argparse
import numpy as np

from torchvision.io import read_video
from PIL import Image
from cotracker.utils.visualizer import Visualizer
from cotracker.predictor import CoTrackerPredictor


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--video_path",
        default="./assets/apple.mp4",
        help="path to a video",
    )
    parser.add_argument(
        "--mask_path",
        default="./assets/apple_mask.png",
        help="path to a segmentation mask",
    )
    parser.add_argument(
        "--checkpoint",
        default="./checkpoints/cotracker_stride_4_wind_8.pth",
        help="cotracker model",
    )
    parser.add_argument("--grid_size", type=int, default=0, help="Regular grid size")
    parser.add_argument(
        "--grid_query_frame",
        type=int,
        default=0,
        help="Compute dense and grid tracks starting from this frame",
    )

    parser.add_argument(
        "--backward_tracking",
        action="store_true",
        help="Compute tracks in both directions, not only forward",
    )

    args = parser.parse_args()

    # load the input video frame by frame
    video = read_video(args.video_path)
    video = video[0].permute(0, 3, 1, 2)[None].float()
    segm_mask = np.array(Image.open(os.path.join(args.mask_path)))
    segm_mask = torch.from_numpy(segm_mask)[None, None]

    model = CoTrackerPredictor(checkpoint=args.checkpoint)

    pred_tracks, pred_visibility = model(
        video,
        grid_size=args.grid_size,
        grid_query_frame=args.grid_query_frame,
        backward_tracking=args.backward_tracking,
        # segm_mask=segm_mask
    )
    print("computed")

    # save a video with predicted tracks
    seq_name = args.video_path.split("/")[-1]
    vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
    vis.visualize(video, pred_tracks, query_frame=args.grid_query_frame)
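A note on tensor layouts may help when adapting this script: after the permute above the video is B x T x 3 x H x W (float, 0-255) and the mask is 1 x 1 x H x W, and, matching the trajectory layout used in train.py below, the predicted tracks are expected to come back as B x T x N x 2 with a per-point visibility of shape B x T x N. The sketch below is illustrative only: it assumes the default checkpoint from above has been downloaded and that CoTrackerPredictor accepts the same keyword arguments demo.py passes; the random clip stands in for a real video.

import torch
from cotracker.predictor import CoTrackerPredictor

video = torch.randint(0, 255, (1, 48, 3, 384, 512)).float()  # B x T x 3 x H x W

model = CoTrackerPredictor(checkpoint="./checkpoints/cotracker_stride_4_wind_8.pth")

# Track a regular 10x10 grid of points seeded at frame 0.
pred_tracks, pred_visibility = model(video, grid_size=10, grid_query_frame=0)

B, T = video.shape[:2]
assert pred_tracks.shape[0] == B and pred_tracks.shape[-1] == 2  # B x T x N x 2
assert pred_visibility.shape == pred_tracks.shape[:3]            # B x T x N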
924
notebooks/demo.ipynb
Normal file
924
notebooks/demo.ipynb
Normal file
File diff suppressed because one or more lines are too long
18
setup.py
Normal file
18
setup.py
Normal file
@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from setuptools import find_packages, setup

setup(
    name="cotracker",
    version="1.0",
    install_requires=[],
    packages=find_packages(exclude="notebooks"),
    extras_require={
        "all": ["matplotlib", "opencv-python"],
        "dev": ["flake8", "black"],
    },
)
665
train.py
Normal file
665
train.py
Normal file
@ -0,0 +1,665 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import random
import torch
import signal
import socket
import sys
import json

import numpy as np
import argparse
import logging
from pathlib import Path
from tqdm import tqdm
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler

from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning.lite import LightningLite

from cotracker.models.evaluation_predictor import EvaluationPredictor
from cotracker.models.core.cotracker.cotracker import CoTracker
from cotracker.utils.visualizer import Visualizer
from cotracker.datasets.tap_vid_datasets import TapVidDataset
from cotracker.datasets.badja_dataset import BadjaDataset
from cotracker.datasets.fast_capture_dataset import FastCaptureDataset
from cotracker.evaluation.core.evaluator import Evaluator
from cotracker.datasets import kubric_movif_dataset
from cotracker.datasets.utils import collate_fn, collate_fn_train, dataclass_to_cuda_
from cotracker.models.core.cotracker.losses import sequence_loss, balanced_ce_loss


# define the handler function
# for training on a slurm cluster
def sig_handler(signum, frame):
    print("caught signal", signum)
    print(socket.gethostname(), "USR1 signal caught.")
    # do other stuff to cleanup here
    print("requeuing job " + os.environ["SLURM_JOB_ID"])
    os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"])
    sys.exit(-1)


def term_handler(signum, frame):
    print("bypassing sigterm", flush=True)


def fetch_optimizer(args, model):
    """Create the optimizer and learning rate scheduler"""
    optimizer = optim.AdamW(
        model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8
    )
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        args.lr,
        args.num_steps + 100,
        pct_start=0.05,
        cycle_momentum=False,
        anneal_strategy="linear",
    )

    return optimizer, scheduler


def forward_batch(batch, model, args, loss_fn=None, writer=None, step=0):
    rgbs = batch.video
    trajs_g = batch.trajectory
    vis_g = batch.visibility
    valids = batch.valid
    B, T, C, H, W = rgbs.shape
    assert C == 3
    B, T, N, D = trajs_g.shape
    device = rgbs.device

    __, first_positive_inds = torch.max(vis_g, dim=1)
    # We want to make sure that during training the model sees visible points
    # that it does not need to track just yet: they are visible but queried from a later frame
    N_rand = N // 4
    # inds of visible points in the 1st frame
    nonzero_inds = [torch.nonzero(vis_g[0, :, i]) for i in range(N)]
    rand_vis_inds = torch.cat(
        [
            nonzero_row[torch.randint(len(nonzero_row), size=(1,))]
            for nonzero_row in nonzero_inds
        ],
        dim=1,
    )
    first_positive_inds = torch.cat(
        [rand_vis_inds[:, :N_rand], first_positive_inds[:, N_rand:]], dim=1
    )
    ind_array_ = torch.arange(T, device=device)
    ind_array_ = ind_array_[None, :, None].repeat(B, 1, N)
    assert torch.allclose(
        vis_g[ind_array_ == first_positive_inds[:, None, :]],
        torch.ones_like(vis_g),
    )
    assert torch.allclose(
        vis_g[ind_array_ == rand_vis_inds[:, None, :]], torch.ones_like(vis_g)
    )

    gather = torch.gather(
        trajs_g, 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, 2)
    )
    xys = torch.diagonal(gather, dim1=1, dim2=2).permute(0, 2, 1)

    queries = torch.cat([first_positive_inds[:, :, None], xys], dim=2)

    predictions, __, visibility, train_data = model(
        rgbs=rgbs, queries=queries, iters=args.train_iters, is_train=True
    )
    vis_predictions, coord_predictions, wind_inds, sort_inds = train_data

    trajs_g = trajs_g[:, :, sort_inds]
    vis_g = vis_g[:, :, sort_inds]
    valids = valids[:, :, sort_inds]

    vis_gts = []
    traj_gts = []
    valids_gts = []

    for i, wind_idx in enumerate(wind_inds):
        ind = i * (args.sliding_window_len // 2)

        vis_gts.append(vis_g[:, ind : ind + args.sliding_window_len, :wind_idx])
        traj_gts.append(trajs_g[:, ind : ind + args.sliding_window_len, :wind_idx])
        valids_gts.append(valids[:, ind : ind + args.sliding_window_len, :wind_idx])

    seq_loss = sequence_loss(coord_predictions, traj_gts, vis_gts, valids_gts, 0.8)
    vis_loss = balanced_ce_loss(vis_predictions, vis_gts, valids_gts)

    output = {"flow": {"predictions": predictions[0].detach()}}
    output["flow"]["loss"] = seq_loss.mean()
    output["visibility"] = {
        "loss": vis_loss.mean() * 10.0,
        "predictions": visibility[0].detach(),
    }
    return output


def run_test_eval(evaluator, model, dataloaders, writer, step):
    model.eval()
    for ds_name, dataloader in dataloaders:
        predictor = EvaluationPredictor(
            model.module.module,
            grid_size=6,
            local_grid_size=0,
            single_point=False,
            n_iters=6,
        )

        metrics = evaluator.evaluate_sequence(
            model=predictor,
            test_dataloader=dataloader,
            dataset_name=ds_name,
            train_mode=True,
            writer=writer,
            step=step,
        )

        if ds_name == "badja" or ds_name == "fastcapture" or ("kubric" in ds_name):

            metrics = {
                **{
                    f"{ds_name}_avg": np.mean(
                        [v for k, v in metrics.items() if "accuracy" not in k]
                    )
                },
                **{
                    f"{ds_name}_avg_accuracy": np.mean(
                        [v for k, v in metrics.items() if "accuracy" in k]
                    )
                },
            }
            print("avg", np.mean([v for v in metrics.values()]))

        if "tapvid" in ds_name:
            metrics = {
                f"{ds_name}_avg_OA": metrics["avg"]["occlusion_accuracy"] * 100,
                f"{ds_name}_avg_delta": metrics["avg"]["average_pts_within_thresh"]
                * 100,
                f"{ds_name}_avg_Jaccard": metrics["avg"]["average_jaccard"] * 100,
            }

        writer.add_scalars(f"Eval", metrics, step)


class Logger:

    SUM_FREQ = 100

    def __init__(self, model, scheduler):
        self.model = model
        self.scheduler = scheduler
        self.total_steps = 0
        self.running_loss = {}
        self.writer = SummaryWriter(log_dir=os.path.join(args.ckpt_path, "runs"))

    def _print_training_status(self):
        metrics_data = [
            self.running_loss[k] / Logger.SUM_FREQ
            for k in sorted(self.running_loss.keys())
        ]
        training_str = "[{:6d}] ".format(self.total_steps + 1)
        metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data)

        # print the training status
        logging.info(
            f"Training Metrics ({self.total_steps}): {training_str + metrics_str}"
        )

        if self.writer is None:
            self.writer = SummaryWriter(log_dir=os.path.join(args.ckpt_path, "runs"))

        for k in self.running_loss:
            self.writer.add_scalar(
                k, self.running_loss[k] / Logger.SUM_FREQ, self.total_steps
            )
            self.running_loss[k] = 0.0

    def push(self, metrics, task):
        self.total_steps += 1

        for key in metrics:
            task_key = str(key) + "_" + task
            if task_key not in self.running_loss:
                self.running_loss[task_key] = 0.0

            self.running_loss[task_key] += metrics[key]

        if self.total_steps % Logger.SUM_FREQ == Logger.SUM_FREQ - 1:
            self._print_training_status()
            self.running_loss = {}

    def write_dict(self, results):
        if self.writer is None:
            self.writer = SummaryWriter(log_dir=os.path.join(args.ckpt_path, "runs"))

        for key in results:
            self.writer.add_scalar(key, results[key], self.total_steps)

    def close(self):
        self.writer.close()


class Lite(LightningLite):
    def run(self, args):
        def seed_everything(seed: int):
            random.seed(seed)
            os.environ["PYTHONHASHSEED"] = str(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

        seed_everything(0)

        def seed_worker(worker_id):
            worker_seed = torch.initial_seed() % 2 ** 32
            np.random.seed(worker_seed)
            random.seed(worker_seed)

        g = torch.Generator()
        g.manual_seed(0)

        eval_dataloaders = []
        if "badja" in args.eval_datasets:
            eval_dataset = BadjaDataset(
                data_root=os.path.join(args.dataset_root, "BADJA"),
                max_seq_len=args.eval_max_seq_len,
                dataset_resolution=args.crop_size,
            )
            eval_dataloader_badja = torch.utils.data.DataLoader(
                eval_dataset,
                batch_size=1,
                shuffle=False,
                num_workers=8,
                collate_fn=collate_fn,
            )
            eval_dataloaders.append(("badja", eval_dataloader_badja))

        if "fastcapture" in args.eval_datasets:
            eval_dataset = FastCaptureDataset(
                data_root=os.path.join(args.dataset_root, "fastcapture"),
                max_seq_len=min(100, args.eval_max_seq_len),
                max_num_points=40,
                dataset_resolution=args.crop_size,
            )
            eval_dataloader_fastcapture = torch.utils.data.DataLoader(
                eval_dataset,
                batch_size=1,
                shuffle=False,
                num_workers=1,
                collate_fn=collate_fn,
            )
            eval_dataloaders.append(("fastcapture", eval_dataloader_fastcapture))

        if "tapvid_davis_first" in args.eval_datasets:
            data_root = os.path.join(
                args.dataset_root, "tapvid_davis/tapvid_davis.pkl"
            )
            eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root)
            eval_dataloader_tapvid_davis = torch.utils.data.DataLoader(
                eval_dataset,
                batch_size=1,
                shuffle=False,
                num_workers=1,
                collate_fn=collate_fn,
            )
            eval_dataloaders.append(("tapvid_davis", eval_dataloader_tapvid_davis))

        evaluator = Evaluator(args.ckpt_path)

        visualizer = Visualizer(
            save_dir=args.ckpt_path,
            pad_value=80,
            fps=1,
            show_first_frame=0,
            tracks_leave_trace=0,
        )

        loss_fn = None

        if args.model_name == "cotracker":

            model = CoTracker(
                stride=args.model_stride,
                S=args.sliding_window_len,
                add_space_attn=not args.remove_space_attn,
                num_heads=args.updateformer_num_heads,
                hidden_size=args.updateformer_hidden_size,
                space_depth=args.updateformer_space_depth,
                time_depth=args.updateformer_time_depth,
            )
        else:
            raise ValueError(f"Model {args.model_name} doesn't exist")

        with open(args.ckpt_path + "/meta.json", "w") as file:
            json.dump(vars(args), file, sort_keys=True, indent=4)

        model.cuda()

        train_dataset = kubric_movif_dataset.KubricMovifDataset(
            data_root=os.path.join(args.dataset_root, "kubric_movi_f"),
            crop_size=args.crop_size,
            seq_len=args.sequence_len,
            traj_per_sample=args.traj_per_sample,
            sample_vis_1st_frame=args.sample_vis_1st_frame,
            use_augs=not args.dont_use_augs,
        )

        train_loader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            worker_init_fn=seed_worker,
            generator=g,
            pin_memory=True,
            collate_fn=collate_fn_train,
            drop_last=True,
        )

        train_loader = self.setup_dataloaders(train_loader, move_to_device=False)
        print("LEN TRAIN LOADER", len(train_loader))
        optimizer, scheduler = fetch_optimizer(args, model)

        total_steps = 0
        logger = Logger(model, scheduler)

        folder_ckpts = [
            f
            for f in os.listdir(args.ckpt_path)
            if not os.path.isdir(f) and f.endswith(".pth") and not "final" in f
        ]
        if len(folder_ckpts) > 0:
            ckpt_path = sorted(folder_ckpts)[-1]
            ckpt = self.load(os.path.join(args.ckpt_path, ckpt_path))
            logging.info(f"Loading checkpoint {ckpt_path}")
            if "model" in ckpt:
                model.load_state_dict(ckpt["model"])
            else:
                model.load_state_dict(ckpt)
            if "optimizer" in ckpt:
                logging.info("Load optimizer")
                optimizer.load_state_dict(ckpt["optimizer"])
            if "scheduler" in ckpt:
                logging.info("Load scheduler")
                scheduler.load_state_dict(ckpt["scheduler"])
            if "total_steps" in ckpt:
                total_steps = ckpt["total_steps"]
                logging.info(f"Load total_steps {total_steps}")

        elif args.restore_ckpt is not None:
            assert args.restore_ckpt.endswith(".pth") or args.restore_ckpt.endswith(
                ".pt"
            )
            logging.info("Loading checkpoint...")

            strict = True
            state_dict = self.load(args.restore_ckpt)
            if "model" in state_dict:
                state_dict = state_dict["model"]

            if list(state_dict.keys())[0].startswith("module."):
                state_dict = {
                    k.replace("module.", ""): v for k, v in state_dict.items()
                }
            model.load_state_dict(state_dict, strict=strict)

            logging.info(f"Done loading checkpoint")
        model, optimizer = self.setup(model, optimizer, move_to_device=False)
        # model.cuda()
        model.train()

        save_freq = args.save_freq
        scaler = GradScaler(enabled=args.mixed_precision)

        should_keep_training = True
        global_batch_num = 0
        epoch = -1

        while should_keep_training:
            epoch += 1
            for i_batch, batch in enumerate(tqdm(train_loader)):
                batch, gotit = batch
                if not all(gotit):
                    print("batch is None")
                    continue
                dataclass_to_cuda_(batch)

                optimizer.zero_grad()

                assert model.training

                output = forward_batch(
                    batch,
                    model,
                    args,
                    loss_fn=loss_fn,
                    writer=logger.writer,
                    step=total_steps,
                )

                loss = 0
                for k, v in output.items():
                    if "loss" in v:
                        loss += v["loss"]
                        logger.writer.add_scalar(
                            f"live_{k}_loss", v["loss"].item(), total_steps
                        )
                    if "metrics" in v:
                        logger.push(v["metrics"], k)

                if self.global_rank == 0:
                    if total_steps % save_freq == save_freq - 1:
                        if args.model_name == "motion_diffuser":
                            pred_coords = model.module.module.forward_batch_test(
                                batch, interp_shape=args.crop_size
                            )

                            output["flow"] = {"predictions": pred_coords[0].detach()}
                        visualizer.visualize(
                            video=batch.video.clone(),
                            tracks=batch.trajectory.clone(),
                            filename="train_gt_traj",
                            writer=logger.writer,
                            step=total_steps,
                        )

                        visualizer.visualize(
                            video=batch.video.clone(),
                            tracks=output["flow"]["predictions"][None],
                            filename="train_pred_traj",
                            writer=logger.writer,
                            step=total_steps,
                        )

                    if len(output) > 1:
                        logger.writer.add_scalar(
                            f"live_total_loss", loss.item(), total_steps
                        )
                    logger.writer.add_scalar(
                        f"learning_rate", optimizer.param_groups[0]["lr"], total_steps
                    )
                    global_batch_num += 1

                self.barrier()

                self.backward(scaler.scale(loss))

                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0)

                scaler.step(optimizer)
                scheduler.step()
                scaler.update()
                total_steps += 1
                if self.global_rank == 0:
                    if (i_batch >= len(train_loader) - 1) or (
                        total_steps == 1 and args.validate_at_start
                    ):
                        if (epoch + 1) % args.save_every_n_epoch == 0:
                            ckpt_iter = "0" * (6 - len(str(total_steps))) + str(
                                total_steps
                            )
                            save_path = Path(
                                f"{args.ckpt_path}/model_{args.model_name}_{ckpt_iter}.pth"
                            )

                            save_dict = {
                                "model": model.module.module.state_dict(),
                                "optimizer": optimizer.state_dict(),
                                "scheduler": scheduler.state_dict(),
                                "total_steps": total_steps,
                            }

                            logging.info(f"Saving file {save_path}")
                            self.save(save_dict, save_path)

                        if (epoch + 1) % args.evaluate_every_n_epoch == 0 or (
                            args.validate_at_start and epoch == 0
                        ):
                            run_test_eval(
                                evaluator,
                                model,
                                eval_dataloaders,
                                logger.writer,
                                total_steps,
                            )
                            model.train()
                            torch.cuda.empty_cache()

                self.barrier()
                if total_steps > args.num_steps:
                    should_keep_training = False
                    break

        print("FINISHED TRAINING")

        PATH = f"{args.ckpt_path}/{args.model_name}_final.pth"
        torch.save(model.module.module.state_dict(), PATH)
        run_test_eval(evaluator, model, eval_dataloaders, logger.writer, total_steps)
        logger.close()


if __name__ == "__main__":
    signal.signal(signal.SIGUSR1, sig_handler)
    signal.signal(signal.SIGTERM, term_handler)
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", default="cotracker", help="model name")
    parser.add_argument("--restore_ckpt", help="restore checkpoint")
    parser.add_argument("--ckpt_path", help="path to save checkpoints and logs")
    parser.add_argument(
        "--batch_size", type=int, default=4, help="batch size used during training."
    )
    parser.add_argument(
        "--num_workers", type=int, default=6, help="number of dataloader workers"
    )

    parser.add_argument(
        "--mixed_precision", action="store_true", help="use mixed precision"
    )
    parser.add_argument("--lr", type=float, default=0.0005, help="max learning rate.")
    parser.add_argument(
        "--wdecay", type=float, default=0.00001, help="Weight decay in optimizer."
    )
    parser.add_argument(
        "--num_steps", type=int, default=200000, help="length of training schedule."
    )
    parser.add_argument(
        "--evaluate_every_n_epoch",
        type=int,
        default=1,
        help="run evaluation on the eval datasets every n epochs",
    )
    parser.add_argument(
        "--save_every_n_epoch",
        type=int,
        default=1,
        help="save a checkpoint every n epochs",
    )
    parser.add_argument(
        "--validate_at_start",
        action="store_true",
        help="run evaluation before training starts",
    )
    parser.add_argument(
        "--save_freq", type=int, default=100, help="log train visualizations every save_freq steps"
    )
    parser.add_argument(
        "--traj_per_sample", type=int, default=768, help="number of trajectories per training sample"
    )
    parser.add_argument("--dataset_root", type=str, help="path to all the datasets")

    parser.add_argument(
        "--train_iters",
        type=int,
        default=4,
        help="number of track updates in each forward pass.",
    )
    parser.add_argument(
        "--sequence_len", type=int, default=8, help="train sequence length"
    )
    parser.add_argument(
        "--eval_datasets",
        nargs="+",
        default=["things", "badja", "fastcapture"],
        help="eval datasets.",
    )

    parser.add_argument(
        "--remove_space_attn", action="store_true", help="remove spatial attention from the model"
    )
    parser.add_argument(
        "--dont_use_augs", action="store_true", help="disable data augmentations"
    )
    parser.add_argument(
        "--sample_vis_1st_frame", action="store_true", help="only sample points visible in the first frame"
    )
    parser.add_argument(
        "--sliding_window_len", type=int, default=8, help="length of the model sliding window"
    )
    parser.add_argument(
        "--updateformer_hidden_size", type=int, default=384, help="hidden size of the update transformer"
    )
    parser.add_argument(
        "--updateformer_num_heads", type=int, default=8, help="number of attention heads in the update transformer"
    )
    parser.add_argument(
        "--updateformer_space_depth", type=int, default=12, help="number of spatial attention layers"
    )
    parser.add_argument(
        "--updateformer_time_depth", type=int, default=12, help="number of temporal attention layers"
    )
    parser.add_argument(
        "--model_stride", type=int, default=8, help="stride of the model feature network"
    )
    parser.add_argument(
        "--crop_size",
        type=int,
        nargs="+",
        default=[384, 512],
        help="crop size (h w) used for training and evaluation",
    )
    parser.add_argument(
        "--eval_max_seq_len", type=int, default=1000, help="maximum sequence length during evaluation"
    )
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    )

    Path(args.ckpt_path).mkdir(exist_ok=True, parents=True)
    from pytorch_lightning.strategies import DDPStrategy

    Lite(
        strategy=DDPStrategy(find_unused_parameters=True),
        devices="auto",
        accelerator="gpu",
        precision=32,
        num_nodes=4,
    ).run(args)
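The least obvious step in forward_batch above is how the training queries are built: torch.max over the time axis gives, for every point, the first frame in which it becomes visible; a quarter of those are then swapped for a randomly chosen visible frame; and a gather followed by torch.diagonal picks out each point's (x, y) at its own query frame, so every query is a (t, x, y) triple. The toy sketch below replays only the deterministic part of that recipe (first-visible frame plus the gather/diagonal step) on hand-made tensors; the shapes and names mirror the function above, but the data is made up.

import torch

B, T, N = 1, 5, 4
# Ground-truth trajectories (B x T x N x 2) and visibility (B x T x N).
trajs_g = torch.arange(B * T * N * 2, dtype=torch.float32).reshape(B, T, N, 2)
vis_g = torch.zeros(B, T, N)
vis_g[0, 2:, 0] = 1  # point 0 becomes visible at frame 2
vis_g[0, 0:, 1] = 1  # point 1 is visible from the start
vis_g[0, 1:, 2] = 1  # point 2 becomes visible at frame 1
vis_g[0, 3:, 3] = 1  # point 3 becomes visible at frame 3

# First frame in which each point is visible: B x N.
__, first_positive_inds = torch.max(vis_g, dim=1)

# Gather every point's coordinates at each query frame, then keep the diagonal,
# i.e. point n evaluated at its own query frame first_positive_inds[:, n].
gather = torch.gather(
    trajs_g, 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, 2)
)
xys = torch.diagonal(gather, dim1=1, dim2=2).permute(0, 2, 1)  # B x N x 2

queries = torch.cat([first_positive_inds[:, :, None], xys], dim=2)  # B x N x 3: (t, x, y)
print(queries)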