Initial commit
CODE_OF_CONDUCT.md (80 lines, Normal file)
							| @@ -0,0 +1,80 @@ | ||||
| # Code of Conduct | ||||
|  | ||||
| ## Our Pledge | ||||
|  | ||||
| In the interest of fostering an open and welcoming environment, we as | ||||
| contributors and maintainers pledge to make participation in our project and | ||||
| our community a harassment-free experience for everyone, regardless of age, body | ||||
| size, disability, ethnicity, sex characteristics, gender identity and expression, | ||||
| level of experience, education, socio-economic status, nationality, personal | ||||
| appearance, race, religion, or sexual identity and orientation. | ||||
|  | ||||
| ## Our Standards | ||||
|  | ||||
| Examples of behavior that contributes to creating a positive environment | ||||
| include: | ||||
|  | ||||
| * Using welcoming and inclusive language | ||||
| * Being respectful of differing viewpoints and experiences | ||||
| * Gracefully accepting constructive criticism | ||||
| * Focusing on what is best for the community | ||||
| * Showing empathy towards other community members | ||||
|  | ||||
| Examples of unacceptable behavior by participants include: | ||||
|  | ||||
| * The use of sexualized language or imagery and unwelcome sexual attention or | ||||
| advances | ||||
| * Trolling, insulting/derogatory comments, and personal or political attacks | ||||
| * Public or private harassment | ||||
| * Publishing others' private information, such as a physical or electronic | ||||
| address, without explicit permission | ||||
| * Other conduct which could reasonably be considered inappropriate in a | ||||
| professional setting | ||||
|  | ||||
| ## Our Responsibilities | ||||
|  | ||||
| Project maintainers are responsible for clarifying the standards of acceptable | ||||
| behavior and are expected to take appropriate and fair corrective action in | ||||
| response to any instances of unacceptable behavior. | ||||
|  | ||||
| Project maintainers have the right and responsibility to remove, edit, or | ||||
| reject comments, commits, code, wiki edits, issues, and other contributions | ||||
| that are not aligned to this Code of Conduct, or to ban temporarily or | ||||
| permanently any contributor for other behaviors that they deem inappropriate, | ||||
| threatening, offensive, or harmful. | ||||
|  | ||||
| ## Scope | ||||
|  | ||||
| This Code of Conduct applies within all project spaces, and it also applies when | ||||
| an individual is representing the project or its community in public spaces. | ||||
| Examples of representing a project or community include using an official | ||||
| project e-mail address, posting via an official social media account, or acting | ||||
| as an appointed representative at an online or offline event. Representation of | ||||
| a project may be further defined and clarified by project maintainers. | ||||
|  | ||||
| This Code of Conduct also applies outside the project spaces when there is a | ||||
| reasonable belief that an individual's behavior may have a negative impact on | ||||
| the project or its community. | ||||
|  | ||||
| ## Enforcement | ||||
|  | ||||
| Instances of abusive, harassing, or otherwise unacceptable behavior may be | ||||
| reported by contacting the project team at <opensource-conduct@fb.com>. All | ||||
| complaints will be reviewed and investigated and will result in a response that | ||||
| is deemed necessary and appropriate to the circumstances. The project team is | ||||
| obligated to maintain confidentiality with regard to the reporter of an incident. | ||||
| Further details of specific enforcement policies may be posted separately. | ||||
|  | ||||
| Project maintainers who do not follow or enforce the Code of Conduct in good | ||||
| faith may face temporary or permanent repercussions as determined by other | ||||
| members of the project's leadership. | ||||
|  | ||||
| ## Attribution | ||||
|  | ||||
| This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, | ||||
| available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html | ||||
|  | ||||
| [homepage]: https://www.contributor-covenant.org | ||||
|  | ||||
| For answers to common questions about this code of conduct, see | ||||
| https://www.contributor-covenant.org/faq | ||||
							
								
								
									
CONTRIBUTING.md (28 lines, Normal file)
							| @@ -0,0 +1,28 @@ | ||||
| # CoTracker | ||||
| We want to make contributing to this project as easy and transparent as possible. | ||||
|  | ||||
| ## Pull Requests | ||||
| We actively welcome your pull requests. | ||||
|  | ||||
| 1. Fork the repo and create your branch from `main`. | ||||
| 2. If you've changed APIs, update the documentation. | ||||
| 3. Make sure your code lints. | ||||
| 4. If you haven't already, complete the Contributor License Agreement ("CLA"). | ||||
|  | ||||
| ## Contributor License Agreement ("CLA") | ||||
| In order to accept your pull request, we need you to submit a CLA. You only need | ||||
| to do this once to work on any of Meta's open source projects. | ||||
|  | ||||
| Complete your CLA here: <https://code.facebook.com/cla> | ||||
|  | ||||
| ## Issues | ||||
| We use GitHub issues to track public bugs. Please ensure your description is | ||||
| clear and has sufficient instructions to be able to reproduce the issue. | ||||
|  | ||||
| Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe | ||||
| disclosure of security bugs. In those cases, please go through the process | ||||
| outlined on that page and do not file a public issue. | ||||
|  | ||||
| ## License | ||||
| By contributing to CoTracker, you agree that your contributions will be licensed | ||||
| under the LICENSE file in the root directory of this source tree. | ||||
							
								
								
									
LICENSE.md (399 lines, Normal file)
							| @@ -0,0 +1,399 @@ | ||||
| Attribution-NonCommercial 4.0 International | ||||
|  | ||||
| ======================================================================= | ||||
|  | ||||
| Creative Commons Corporation ("Creative Commons") is not a law firm and | ||||
| does not provide legal services or legal advice. Distribution of | ||||
| Creative Commons public licenses does not create a lawyer-client or | ||||
| other relationship. Creative Commons makes its licenses and related | ||||
| information available on an "as-is" basis. Creative Commons gives no | ||||
| warranties regarding its licenses, any material licensed under their | ||||
| terms and conditions, or any related information. Creative Commons | ||||
| disclaims all liability for damages resulting from their use to the | ||||
| fullest extent possible. | ||||
|  | ||||
| Using Creative Commons Public Licenses | ||||
|  | ||||
| Creative Commons public licenses provide a standard set of terms and | ||||
| conditions that creators and other rights holders may use to share | ||||
| original works of authorship and other material subject to copyright | ||||
| and certain other rights specified in the public license below. The | ||||
| following considerations are for informational purposes only, are not | ||||
| exhaustive, and do not form part of our licenses. | ||||
|  | ||||
|      Considerations for licensors: Our public licenses are | ||||
|      intended for use by those authorized to give the public | ||||
|      permission to use material in ways otherwise restricted by | ||||
|      copyright and certain other rights. Our licenses are | ||||
|      irrevocable. Licensors should read and understand the terms | ||||
|      and conditions of the license they choose before applying it. | ||||
|      Licensors should also secure all rights necessary before | ||||
|      applying our licenses so that the public can reuse the | ||||
|      material as expected. Licensors should clearly mark any | ||||
|      material not subject to the license. This includes other CC- | ||||
|      licensed material, or material used under an exception or | ||||
|      limitation to copyright. More considerations for licensors: | ||||
| 	wiki.creativecommons.org/Considerations_for_licensors | ||||
|  | ||||
|      Considerations for the public: By using one of our public | ||||
|      licenses, a licensor grants the public permission to use the | ||||
|      licensed material under specified terms and conditions. If | ||||
|      the licensor's permission is not necessary for any reason--for | ||||
|      example, because of any applicable exception or limitation to | ||||
|      copyright--then that use is not regulated by the license. Our | ||||
|      licenses grant only permissions under copyright and certain | ||||
|      other rights that a licensor has authority to grant. Use of | ||||
|      the licensed material may still be restricted for other | ||||
|      reasons, including because others have copyright or other | ||||
|      rights in the material. A licensor may make special requests, | ||||
|      such as asking that all changes be marked or described. | ||||
|      Although not required by our licenses, you are encouraged to | ||||
|      respect those requests where reasonable. More_considerations | ||||
|      for the public:  | ||||
| 	wiki.creativecommons.org/Considerations_for_licensees | ||||
|  | ||||
| ======================================================================= | ||||
|  | ||||
| Creative Commons Attribution-NonCommercial 4.0 International Public | ||||
| License | ||||
|  | ||||
| By exercising the Licensed Rights (defined below), You accept and agree | ||||
| to be bound by the terms and conditions of this Creative Commons | ||||
| Attribution-NonCommercial 4.0 International Public License ("Public | ||||
| License"). To the extent this Public License may be interpreted as a | ||||
| contract, You are granted the Licensed Rights in consideration of Your | ||||
| acceptance of these terms and conditions, and the Licensor grants You | ||||
| such rights in consideration of benefits the Licensor receives from | ||||
| making the Licensed Material available under these terms and | ||||
| conditions. | ||||
|  | ||||
| Section 1 -- Definitions. | ||||
|  | ||||
|   a. Adapted Material means material subject to Copyright and Similar | ||||
|      Rights that is derived from or based upon the Licensed Material | ||||
|      and in which the Licensed Material is translated, altered, | ||||
|      arranged, transformed, or otherwise modified in a manner requiring | ||||
|      permission under the Copyright and Similar Rights held by the | ||||
|      Licensor. For purposes of this Public License, where the Licensed | ||||
|      Material is a musical work, performance, or sound recording, | ||||
|      Adapted Material is always produced where the Licensed Material is | ||||
|      synched in timed relation with a moving image. | ||||
|  | ||||
|   b. Adapter's License means the license You apply to Your Copyright | ||||
|      and Similar Rights in Your contributions to Adapted Material in | ||||
|      accordance with the terms and conditions of this Public License. | ||||
|  | ||||
|   c. Copyright and Similar Rights means copyright and/or similar rights | ||||
|      closely related to copyright including, without limitation, | ||||
|      performance, broadcast, sound recording, and Sui Generis Database | ||||
|      Rights, without regard to how the rights are labeled or | ||||
|      categorized. For purposes of this Public License, the rights | ||||
|      specified in Section 2(b)(1)-(2) are not Copyright and Similar | ||||
|      Rights. | ||||
|   d. Effective Technological Measures means those measures that, in the | ||||
|      absence of proper authority, may not be circumvented under laws | ||||
|      fulfilling obligations under Article 11 of the WIPO Copyright | ||||
|      Treaty adopted on December 20, 1996, and/or similar international | ||||
|      agreements. | ||||
|  | ||||
|   e. Exceptions and Limitations means fair use, fair dealing, and/or | ||||
|      any other exception or limitation to Copyright and Similar Rights | ||||
|      that applies to Your use of the Licensed Material. | ||||
|  | ||||
|   f. Licensed Material means the artistic or literary work, database, | ||||
|      or other material to which the Licensor applied this Public | ||||
|      License. | ||||
|  | ||||
|   g. Licensed Rights means the rights granted to You subject to the | ||||
|      terms and conditions of this Public License, which are limited to | ||||
|      all Copyright and Similar Rights that apply to Your use of the | ||||
|      Licensed Material and that the Licensor has authority to license. | ||||
|  | ||||
|   h. Licensor means the individual(s) or entity(ies) granting rights | ||||
|      under this Public License. | ||||
|  | ||||
|   i. NonCommercial means not primarily intended for or directed towards | ||||
|      commercial advantage or monetary compensation. For purposes of | ||||
|      this Public License, the exchange of the Licensed Material for | ||||
|      other material subject to Copyright and Similar Rights by digital | ||||
|      file-sharing or similar means is NonCommercial provided there is | ||||
|      no payment of monetary compensation in connection with the | ||||
|      exchange. | ||||
|  | ||||
|   j. Share means to provide material to the public by any means or | ||||
|      process that requires permission under the Licensed Rights, such | ||||
|      as reproduction, public display, public performance, distribution, | ||||
|      dissemination, communication, or importation, and to make material | ||||
|      available to the public including in ways that members of the | ||||
|      public may access the material from a place and at a time | ||||
|      individually chosen by them. | ||||
|  | ||||
|   k. Sui Generis Database Rights means rights other than copyright | ||||
|      resulting from Directive 96/9/EC of the European Parliament and of | ||||
|      the Council of 11 March 1996 on the legal protection of databases, | ||||
|      as amended and/or succeeded, as well as other essentially | ||||
|      equivalent rights anywhere in the world. | ||||
|  | ||||
|   l. You means the individual or entity exercising the Licensed Rights | ||||
|      under this Public License. Your has a corresponding meaning. | ||||
|  | ||||
| Section 2 -- Scope. | ||||
|  | ||||
|   a. License grant. | ||||
|  | ||||
|        1. Subject to the terms and conditions of this Public License, | ||||
|           the Licensor hereby grants You a worldwide, royalty-free, | ||||
|           non-sublicensable, non-exclusive, irrevocable license to | ||||
|           exercise the Licensed Rights in the Licensed Material to: | ||||
|  | ||||
|             a. reproduce and Share the Licensed Material, in whole or | ||||
|                in part, for NonCommercial purposes only; and | ||||
|  | ||||
|             b. produce, reproduce, and Share Adapted Material for | ||||
|                NonCommercial purposes only. | ||||
|  | ||||
|        2. Exceptions and Limitations. For the avoidance of doubt, where | ||||
|           Exceptions and Limitations apply to Your use, this Public | ||||
|           License does not apply, and You do not need to comply with | ||||
|           its terms and conditions. | ||||
|  | ||||
|        3. Term. The term of this Public License is specified in Section | ||||
|           6(a). | ||||
|  | ||||
|        4. Media and formats; technical modifications allowed. The | ||||
|           Licensor authorizes You to exercise the Licensed Rights in | ||||
|           all media and formats whether now known or hereafter created, | ||||
|           and to make technical modifications necessary to do so. The | ||||
|           Licensor waives and/or agrees not to assert any right or | ||||
|           authority to forbid You from making technical modifications | ||||
|           necessary to exercise the Licensed Rights, including | ||||
|           technical modifications necessary to circumvent Effective | ||||
|           Technological Measures. For purposes of this Public License, | ||||
|           simply making modifications authorized by this Section 2(a) | ||||
|           (4) never produces Adapted Material. | ||||
|  | ||||
|        5. Downstream recipients. | ||||
|  | ||||
|             a. Offer from the Licensor -- Licensed Material. Every | ||||
|                recipient of the Licensed Material automatically | ||||
|                receives an offer from the Licensor to exercise the | ||||
|                Licensed Rights under the terms and conditions of this | ||||
|                Public License. | ||||
|  | ||||
|             b. No downstream restrictions. You may not offer or impose | ||||
|                any additional or different terms or conditions on, or | ||||
|                apply any Effective Technological Measures to, the | ||||
|                Licensed Material if doing so restricts exercise of the | ||||
|                Licensed Rights by any recipient of the Licensed | ||||
|                Material. | ||||
|  | ||||
|        6. No endorsement. Nothing in this Public License constitutes or | ||||
|           may be construed as permission to assert or imply that You | ||||
|           are, or that Your use of the Licensed Material is, connected | ||||
|           with, or sponsored, endorsed, or granted official status by, | ||||
|           the Licensor or others designated to receive attribution as | ||||
|           provided in Section 3(a)(1)(A)(i). | ||||
|  | ||||
|   b. Other rights. | ||||
|  | ||||
|        1. Moral rights, such as the right of integrity, are not | ||||
|           licensed under this Public License, nor are publicity, | ||||
|           privacy, and/or other similar personality rights; however, to | ||||
|           the extent possible, the Licensor waives and/or agrees not to | ||||
|           assert any such rights held by the Licensor to the limited | ||||
|           extent necessary to allow You to exercise the Licensed | ||||
|           Rights, but not otherwise. | ||||
|  | ||||
|        2. Patent and trademark rights are not licensed under this | ||||
|           Public License. | ||||
|  | ||||
|        3. To the extent possible, the Licensor waives any right to | ||||
|           collect royalties from You for the exercise of the Licensed | ||||
|           Rights, whether directly or through a collecting society | ||||
|           under any voluntary or waivable statutory or compulsory | ||||
|           licensing scheme. In all other cases the Licensor expressly | ||||
|           reserves any right to collect such royalties, including when | ||||
|           the Licensed Material is used other than for NonCommercial | ||||
|           purposes. | ||||
|  | ||||
| Section 3 -- License Conditions. | ||||
|  | ||||
| Your exercise of the Licensed Rights is expressly made subject to the | ||||
| following conditions. | ||||
|  | ||||
|   a. Attribution. | ||||
|  | ||||
|        1. If You Share the Licensed Material (including in modified | ||||
|           form), You must: | ||||
|  | ||||
|             a. retain the following if it is supplied by the Licensor | ||||
|                with the Licensed Material: | ||||
|  | ||||
|                  i. identification of the creator(s) of the Licensed | ||||
|                     Material and any others designated to receive | ||||
|                     attribution, in any reasonable manner requested by | ||||
|                     the Licensor (including by pseudonym if | ||||
|                     designated); | ||||
|  | ||||
|                 ii. a copyright notice; | ||||
|  | ||||
|                iii. a notice that refers to this Public License; | ||||
|  | ||||
|                 iv. a notice that refers to the disclaimer of | ||||
|                     warranties; | ||||
|  | ||||
|                  v. a URI or hyperlink to the Licensed Material to the | ||||
|                     extent reasonably practicable; | ||||
|  | ||||
|             b. indicate if You modified the Licensed Material and | ||||
|                retain an indication of any previous modifications; and | ||||
|  | ||||
|             c. indicate the Licensed Material is licensed under this | ||||
|                Public License, and include the text of, or the URI or | ||||
|                hyperlink to, this Public License. | ||||
|  | ||||
|        2. You may satisfy the conditions in Section 3(a)(1) in any | ||||
|           reasonable manner based on the medium, means, and context in | ||||
|           which You Share the Licensed Material. For example, it may be | ||||
|           reasonable to satisfy the conditions by providing a URI or | ||||
|           hyperlink to a resource that includes the required | ||||
|           information. | ||||
|  | ||||
|        3. If requested by the Licensor, You must remove any of the | ||||
|           information required by Section 3(a)(1)(A) to the extent | ||||
|           reasonably practicable. | ||||
|  | ||||
|        4. If You Share Adapted Material You produce, the Adapter's | ||||
|           License You apply must not prevent recipients of the Adapted | ||||
|           Material from complying with this Public License. | ||||
|  | ||||
| Section 4 -- Sui Generis Database Rights. | ||||
|  | ||||
| Where the Licensed Rights include Sui Generis Database Rights that | ||||
| apply to Your use of the Licensed Material: | ||||
|  | ||||
|   a. for the avoidance of doubt, Section 2(a)(1) grants You the right | ||||
|      to extract, reuse, reproduce, and Share all or a substantial | ||||
|      portion of the contents of the database for NonCommercial purposes | ||||
|      only; | ||||
|  | ||||
|   b. if You include all or a substantial portion of the database | ||||
|      contents in a database in which You have Sui Generis Database | ||||
|      Rights, then the database in which You have Sui Generis Database | ||||
|      Rights (but not its individual contents) is Adapted Material; and | ||||
|  | ||||
|   c. You must comply with the conditions in Section 3(a) if You Share | ||||
|      all or a substantial portion of the contents of the database. | ||||
|  | ||||
| For the avoidance of doubt, this Section 4 supplements and does not | ||||
| replace Your obligations under this Public License where the Licensed | ||||
| Rights include other Copyright and Similar Rights. | ||||
|  | ||||
| Section 5 -- Disclaimer of Warranties and Limitation of Liability. | ||||
|  | ||||
|   a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE | ||||
|      EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS | ||||
|      AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF | ||||
|      ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, | ||||
|      IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, | ||||
|      WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR | ||||
|      PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, | ||||
|      ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT | ||||
|      KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT | ||||
|      ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. | ||||
|  | ||||
|   b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE | ||||
|      TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, | ||||
|      NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, | ||||
|      INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, | ||||
|      COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR | ||||
|      USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN | ||||
|      ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR | ||||
|      DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR | ||||
|      IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. | ||||
|  | ||||
|   c. The disclaimer of warranties and limitation of liability provided | ||||
|      above shall be interpreted in a manner that, to the extent | ||||
|      possible, most closely approximates an absolute disclaimer and | ||||
|      waiver of all liability. | ||||
|  | ||||
| Section 6 -- Term and Termination. | ||||
|  | ||||
|   a. This Public License applies for the term of the Copyright and | ||||
|      Similar Rights licensed here. However, if You fail to comply with | ||||
|      this Public License, then Your rights under this Public License | ||||
|      terminate automatically. | ||||
|  | ||||
|   b. Where Your right to use the Licensed Material has terminated under | ||||
|      Section 6(a), it reinstates: | ||||
|  | ||||
|        1. automatically as of the date the violation is cured, provided | ||||
|           it is cured within 30 days of Your discovery of the | ||||
|           violation; or | ||||
|  | ||||
|        2. upon express reinstatement by the Licensor. | ||||
|  | ||||
|      For the avoidance of doubt, this Section 6(b) does not affect any | ||||
|      right the Licensor may have to seek remedies for Your violations | ||||
|      of this Public License. | ||||
|  | ||||
|   c. For the avoidance of doubt, the Licensor may also offer the | ||||
|      Licensed Material under separate terms or conditions or stop | ||||
|      distributing the Licensed Material at any time; however, doing so | ||||
|      will not terminate this Public License. | ||||
|  | ||||
|   d. Sections 1, 5, 6, 7, and 8 survive termination of this Public | ||||
|      License. | ||||
|  | ||||
| Section 7 -- Other Terms and Conditions. | ||||
|  | ||||
|   a. The Licensor shall not be bound by any additional or different | ||||
|      terms or conditions communicated by You unless expressly agreed. | ||||
|  | ||||
|   b. Any arrangements, understandings, or agreements regarding the | ||||
|      Licensed Material not stated herein are separate from and | ||||
|      independent of the terms and conditions of this Public License. | ||||
|  | ||||
| Section 8 -- Interpretation. | ||||
|  | ||||
|   a. For the avoidance of doubt, this Public License does not, and | ||||
|      shall not be interpreted to, reduce, limit, restrict, or impose | ||||
|      conditions on any use of the Licensed Material that could lawfully | ||||
|      be made without permission under this Public License. | ||||
|  | ||||
|   b. To the extent possible, if any provision of this Public License is | ||||
|      deemed unenforceable, it shall be automatically reformed to the | ||||
|      minimum extent necessary to make it enforceable. If the provision | ||||
|      cannot be reformed, it shall be severed from this Public License | ||||
|      without affecting the enforceability of the remaining terms and | ||||
|      conditions. | ||||
|  | ||||
|   c. No term or condition of this Public License will be waived and no | ||||
|      failure to comply consented to unless expressly agreed to by the | ||||
|      Licensor. | ||||
|  | ||||
|   d. Nothing in this Public License constitutes or may be interpreted | ||||
|      as a limitation upon, or waiver of, any privileges and immunities | ||||
|      that apply to the Licensor or You, including from the legal | ||||
|      processes of any jurisdiction or authority. | ||||
|  | ||||
| ======================================================================= | ||||
|  | ||||
| Creative Commons is not a party to its public | ||||
| licenses. Notwithstanding, Creative Commons may elect to apply one of | ||||
| its public licenses to material it publishes and in those instances | ||||
| will be considered the “Licensor.” The text of the Creative Commons | ||||
| public licenses is dedicated to the public domain under the CC0 Public | ||||
| Domain Dedication. Except for the limited purpose of indicating that | ||||
| material is shared under a Creative Commons public license or as | ||||
| otherwise permitted by the Creative Commons policies published at | ||||
| creativecommons.org/policies, Creative Commons does not authorize the | ||||
| use of the trademark "Creative Commons" or any other trademark or logo | ||||
| of Creative Commons without its prior written consent including, | ||||
| without limitation, in connection with any unauthorized modifications | ||||
| to any of its public licenses or any other arrangements, | ||||
| understandings, or agreements concerning use of licensed material. For | ||||
| the avoidance of doubt, this paragraph does not form part of the | ||||
| public licenses. | ||||
|  | ||||
| Creative Commons may be contacted at creativecommons.org. | ||||
							
								
								
									
README.md (94 lines, Normal file)
							| @@ -0,0 +1,94 @@ | ||||
| # CoTracker: It is Better to Track Together | ||||
|  | ||||
| **[Meta AI Research, FAIR](https://ai.facebook.com/research/)**; **[University of Oxford, VGG](https://www.robots.ox.ac.uk/~vgg/)** | ||||
|  | ||||
| [Nikita Karaev](https://nikitakaraevv.github.io/), [Ignacio Rocco](https://www.irocco.info/), [Benjamin Graham](https://ai.facebook.com/people/benjamin-graham/), [Natalia Neverova](https://nneverova.github.io/), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/), [Christian Rupprecht](https://chrirupp.github.io/) | ||||
|  | ||||
| [[`Paper`]()] [[`Project`](https://co-tracker.github.io/)] [[`BibTeX`](#citing-cotracker)] | ||||
|  | ||||
|  | ||||
|  | ||||
| **CoTracker** is a fast transformer-based model that can track any point in a video. It brings to tracking some of the benefits of Optical Flow. | ||||
|   | ||||
| CoTracker can track: | ||||
| - **Every pixel** within a video | ||||
| - Points sampled on a regular grid on any video frame  | ||||
| - Manually selected points | ||||
|  | ||||
| Try these tracking modes for yourself with our [Colab demo](https://github.com/facebookresearch/co-tracker/notebooks/demo.ipynb). | ||||
|  | ||||
|  | ||||
|  | ||||
| ## Installation Instructions | ||||
| Ensure you have both PyTorch and TorchVision installed on your system. Follow the instructions [here](https://pytorch.org/get-started/locally/) to install them. We strongly recommend installing both PyTorch and TorchVision with CUDA support. | ||||
|  | ||||
| ## Install CoTracker and its dependencies: | ||||
| ``` | ||||
| git clone https://github.com/facebookresearch/co-tracker | ||||
| cd co-tracker | ||||
| pip install -e . | ||||
| pip install opencv-python einops timm matplotlib moviepy flow_vis | ||||
| ``` | ||||
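|  | ||||
| As a quick sanity check (a minimal sketch, not part of the repository), you can verify that your PyTorch build sees CUDA before running the heavier steps below: | ||||
| ``` | ||||
| import torch | ||||
|  | ||||
| print(torch.__version__)          # installed PyTorch version | ||||
| print(torch.cuda.is_available())  # True only if a CUDA-enabled build and a GPU are usable | ||||
| ``` | ||||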
|  | ||||
|  | ||||
| ## Model Weights Download: | ||||
| ``` | ||||
| mkdir checkpoints | ||||
| cd checkpoints | ||||
| wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth | ||||
| wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_12.pth | ||||
| wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_8_wind_16.pth | ||||
| cd .. | ||||
| ``` | ||||
|  | ||||
|  | ||||
| ## Running the Demo: | ||||
| Try our [Colab demo](https://github.com/facebookresearch/co-tracker/notebooks/demo.ipynb) or run a local demo with a 10×10 grid of points sampled on the first frame of a video: | ||||
| ``` | ||||
| python demo.py --grid_size 10 | ||||
| ``` | ||||
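|  | ||||
| For illustration only, the sketch below (not part of the repository) shows one way such a regular grid of query points could be built as an `(N, 3)` tensor of `(frame_index, x, y)` triplets; the helper name and the exact query format expected by `demo.py` are assumptions, so treat this purely as a sketch. | ||||
| ``` | ||||
| import torch | ||||
|  | ||||
| def make_grid_queries(height, width, grid_size=10, frame_index=0): | ||||
|     # Hypothetical helper: evenly spaced x/y coordinates over one frame. | ||||
|     ys = torch.linspace(0, height - 1, grid_size) | ||||
|     xs = torch.linspace(0, width - 1, grid_size) | ||||
|     grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij") | ||||
|     frames = torch.full_like(grid_x, frame_index) | ||||
|     # Shape (grid_size * grid_size, 3): one (frame_index, x, y) row per query point. | ||||
|     return torch.stack([frames, grid_x, grid_y], dim=-1).reshape(-1, 3) | ||||
|  | ||||
| queries = make_grid_queries(384, 512, grid_size=10)  # 100 query points on frame 0 | ||||
| ``` | ||||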
|  | ||||
| ## Evaluation | ||||
| To reproduce the results presented in the paper, download the following datasets: | ||||
| - [TAP-Vid](https://github.com/deepmind/tapnet) | ||||
| - [BADJA](https://github.com/benjiebob/BADJA) | ||||
| - [ZJU-Mocap (FastCapture)](https://arxiv.org/abs/2303.11898) | ||||
|  | ||||
| And install the necessary dependencies: | ||||
| ``` | ||||
| pip install hydra-core==1.1.0 mediapy tensorboard  | ||||
| ``` | ||||
| Then, execute the following command to evaluate on BADJA: | ||||
| ``` | ||||
| python ./cotracker/evaluation/evaluate.py --config-name eval_badja exp_dir=./eval_outputs dataset_root=your/badja/path | ||||
| ``` | ||||
|  | ||||
| ## Training | ||||
| To train CoTracker as described in our paper, you first need to generate annotations for the [Google Kubric](https://github.com/google-research/kubric) MOVi-f dataset. Instructions for annotation generation can be found [here](https://github.com/deepmind/tapnet). | ||||
|  | ||||
| Once you have the annotated dataset, make sure you have followed the evaluation setup steps above, then install the training dependencies: | ||||
| ``` | ||||
| pip install pytorch_lightning==1.6.0 | ||||
| ``` | ||||
| Now you can launch training on Kubric. Our model was trained on 32 GPUs; adjust the parameters to suit your hardware setup. | ||||
| ``` | ||||
| python train.py --batch_size 1 --num_workers 28 \ | ||||
| --num_steps 50000 --ckpt_path ./ --model_name cotracker \ | ||||
| --save_freq 200 --sequence_len 24 --eval_datasets tapvid_davis_first badja \ | ||||
| --traj_per_sample 256 --sliding_window_len 8 --updateformer_space_depth 6 --updateformer_time_depth 6 \ | ||||
| --save_every_n_epoch 10 --evaluate_every_n_epoch 10 --model_stride 4 | ||||
| ``` | ||||
|  | ||||
| ## License | ||||
| The majority of CoTracker is licensed under CC-BY-NC; however, portions of the project are available under separate license terms: Particle Video Revisited is licensed under the MIT license, and TAP-Vid is licensed under the Apache 2.0 license. | ||||
|  | ||||
| ## Citing CoTracker | ||||
| If you find our repository useful, please consider giving it a star ⭐ and citing our paper in your work: | ||||
| ``` | ||||
| @article{karaev2023cotracker, | ||||
|   title={CoTracker: It is Better to Track Together}, | ||||
|   author={Nikita Karaev and Ignacio Rocco and Benjamin Graham and Natalia Neverova and Andrea Vedaldi and Christian Rupprecht}, | ||||
|   journal={arxiv}, | ||||
|   year={2023} | ||||
| } | ||||
| ``` | ||||
							
								
								
									
										
assets/apple.mp4 (BIN, Normal file; binary file not shown)
assets/apple_mask.png (BIN, Normal file; binary file not shown; 14 KiB)
							
								
								
									
										
assets/bmx-bumps.gif (BIN, Normal file; binary file not shown; 7.3 MiB)
							
								
								
									
cotracker/__init__.py (5 lines, Normal file)
							| @@ -0,0 +1,5 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
							
								
								
									
cotracker/datasets/__init__.py (5 lines, Normal file)
							| @@ -0,0 +1,5 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
							
								
								
									
cotracker/datasets/badja_dataset.py (390 lines, Normal file)
							| @@ -0,0 +1,390 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import torch | ||||
| import numpy as np | ||||
| import os | ||||
|  | ||||
| import json | ||||
| import imageio | ||||
| import cv2 | ||||
|  | ||||
| from enum import Enum | ||||
|  | ||||
| from cotracker.datasets.utils import CoTrackerData, resize_sample | ||||
|  | ||||
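| # Annotation files skipped when building the BADJA evaluation set; the | ||||
| # commented-out entries are kept for reference. | ||||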
| IGNORE_ANIMALS = [ | ||||
|     # "bear.json", | ||||
|     # "camel.json", | ||||
|     "cat_jump.json", | ||||
|     # "cows.json", | ||||
|     # "dog.json", | ||||
|     # "dog-agility.json", | ||||
|     # "horsejump-high.json", | ||||
|     # "horsejump-low.json", | ||||
|     # "impala0.json", | ||||
|     # "rs_dog.json" | ||||
|     "tiger.json" | ||||
| ] | ||||
|  | ||||
|  | ||||
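| # Joint indices from the SMAL animal model; the commented-out members are | ||||
| # joints that are not annotated in BADJA. | ||||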
| class SMALJointCatalog(Enum): | ||||
|     # body_0 = 0 | ||||
|     # body_1 = 1 | ||||
|     # body_2 = 2 | ||||
|     # body_3 = 3 | ||||
|     # body_4 = 4 | ||||
|     # body_5 = 5 | ||||
|     # body_6 = 6 | ||||
|     # upper_right_0 = 7 | ||||
|     upper_right_1 = 8 | ||||
|     upper_right_2 = 9 | ||||
|     upper_right_3 = 10 | ||||
|     # upper_left_0 = 11 | ||||
|     upper_left_1 = 12 | ||||
|     upper_left_2 = 13 | ||||
|     upper_left_3 = 14 | ||||
|     neck_lower = 15 | ||||
|     # neck_upper = 16 | ||||
|     # lower_right_0 = 17 | ||||
|     lower_right_1 = 18 | ||||
|     lower_right_2 = 19 | ||||
|     lower_right_3 = 20 | ||||
|     # lower_left_0 = 21 | ||||
|     lower_left_1 = 22 | ||||
|     lower_left_2 = 23 | ||||
|     lower_left_3 = 24 | ||||
|     tail_0 = 25 | ||||
|     # tail_1 = 26 | ||||
|     # tail_2 = 27 | ||||
|     tail_3 = 28 | ||||
|     # tail_4 = 29 | ||||
|     # tail_5 = 30 | ||||
|     tail_6 = 31 | ||||
|     jaw = 32 | ||||
|     nose = 33  # ADDED JOINT FOR VERTEX 1863 | ||||
|     # chin = 34 # ADDED JOINT FOR VERTEX 26 | ||||
|     right_ear = 35  # ADDED JOINT FOR VERTEX 149 | ||||
|     left_ear = 36  # ADDED JOINT FOR VERTEX 2124 | ||||
|  | ||||
|  | ||||
| class SMALJointInfo: | ||||
|     def __init__(self): | ||||
|         # These are the joint indices that have annotations in BADJA. | ||||
|         self.annotated_classes = np.array( | ||||
|             [ | ||||
|                 8, | ||||
|                 9, | ||||
|                 10,  # upper_right | ||||
|                 12, | ||||
|                 13, | ||||
|                 14,  # upper_left | ||||
|                 15,  # neck | ||||
|                 18, | ||||
|                 19, | ||||
|                 20,  # lower_right | ||||
|                 22, | ||||
|                 23, | ||||
|                 24,  # lower_left | ||||
|                 25, | ||||
|                 28, | ||||
|                 31,  # tail | ||||
|                 32, | ||||
|                 33,  # head | ||||
|                 35,  # right_ear | ||||
|                 36, | ||||
|             ] | ||||
|         )  # left_ear | ||||
|  | ||||
|         self.annotated_markers = np.array( | ||||
|             [ | ||||
|                 cv2.MARKER_CROSS, | ||||
|                 cv2.MARKER_STAR, | ||||
|                 cv2.MARKER_TRIANGLE_DOWN, | ||||
|                 cv2.MARKER_CROSS, | ||||
|                 cv2.MARKER_STAR, | ||||
|                 cv2.MARKER_TRIANGLE_DOWN, | ||||
|                 cv2.MARKER_CROSS, | ||||
|                 cv2.MARKER_CROSS, | ||||
|                 cv2.MARKER_STAR, | ||||
|                 cv2.MARKER_TRIANGLE_DOWN, | ||||
|                 cv2.MARKER_CROSS, | ||||
|                 cv2.MARKER_STAR, | ||||
|                 cv2.MARKER_TRIANGLE_DOWN, | ||||
|                 cv2.MARKER_CROSS, | ||||
|                 cv2.MARKER_STAR, | ||||
|                 cv2.MARKER_TRIANGLE_DOWN, | ||||
|                 cv2.MARKER_CROSS, | ||||
|                 cv2.MARKER_STAR, | ||||
|                 cv2.MARKER_CROSS, | ||||
|                 cv2.MARKER_CROSS, | ||||
|             ] | ||||
|         ) | ||||
|  | ||||
|         self.joint_regions = np.array( | ||||
|             [ | ||||
|                 0, | ||||
|                 0, | ||||
|                 0, | ||||
|                 0, | ||||
|                 0, | ||||
|                 0, | ||||
|                 0, | ||||
|                 1, | ||||
|                 1, | ||||
|                 1, | ||||
|                 1, | ||||
|                 2, | ||||
|                 2, | ||||
|                 2, | ||||
|                 2, | ||||
|                 3, | ||||
|                 3, | ||||
|                 4, | ||||
|                 4, | ||||
|                 4, | ||||
|                 4, | ||||
|                 5, | ||||
|                 5, | ||||
|                 5, | ||||
|                 5, | ||||
|                 6, | ||||
|                 6, | ||||
|                 6, | ||||
|                 6, | ||||
|                 6, | ||||
|                 6, | ||||
|                 6, | ||||
|                 7, | ||||
|                 7, | ||||
|                 7, | ||||
|                 8, | ||||
|                 9, | ||||
|             ] | ||||
|         ) | ||||
|  | ||||
|         self.annotated_joint_region = self.joint_regions[self.annotated_classes] | ||||
|         self.region_colors = np.array( | ||||
|             [ | ||||
|                 [250, 190, 190],  # body, light pink | ||||
|                 [60, 180, 75],  # upper_right, green | ||||
|                 [230, 25, 75],  # upper_left, red | ||||
|                 [128, 0, 0],  # neck, maroon | ||||
|                 [0, 130, 200],  # lower_right, blue | ||||
|                 [255, 255, 25],  # lower_left, yellow | ||||
|                 [240, 50, 230],  # tail, magenta | ||||
|                 [245, 130, 48],  # jaw / nose / chin, orange | ||||
|                 [29, 98, 115],  # right_ear, turquoise | ||||
|                 [255, 153, 204], | ||||
|             ] | ||||
|         )  # left_ear, pink | ||||
|  | ||||
|         self.joint_colors = np.array(self.region_colors)[self.annotated_joint_region] | ||||
|  | ||||
|  | ||||
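| # Loads the BADJA joint annotations together with the corresponding DAVIS | ||||
| # frames and segmentation masks, indexed per animal sequence. | ||||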
| class BADJAData: | ||||
|     def __init__(self, data_root, complete=False): | ||||
|         annotations_path = os.path.join(data_root, "joint_annotations") | ||||
|  | ||||
|         self.animal_dict = {} | ||||
|         self.animal_count = 0 | ||||
|         self.smal_joint_info = SMALJointInfo() | ||||
|         for __, animal_json in enumerate(sorted(os.listdir(annotations_path))): | ||||
|             if animal_json not in IGNORE_ANIMALS: | ||||
|                 json_path = os.path.join(annotations_path, animal_json) | ||||
|                 with open(json_path) as json_data: | ||||
|                     animal_joint_data = json.load(json_data) | ||||
|  | ||||
|                 filenames = [] | ||||
|                 segnames = [] | ||||
|                 joints = [] | ||||
|                 visible = [] | ||||
|  | ||||
|                 first_path = animal_joint_data[0]["segmentation_path"] | ||||
|                 last_path = animal_joint_data[-1]["segmentation_path"] | ||||
|                 first_frame = first_path.split("/")[-1] | ||||
|                 last_frame = last_path.split("/")[-1] | ||||
|  | ||||
|                 if "extra_videos" not in first_path: | ||||
|                     animal = first_path.split("/")[-2] | ||||
|  | ||||
|                     first_frame_int = int(first_frame.split(".")[0]) | ||||
|                     last_frame_int = int(last_frame.split(".")[0]) | ||||
|  | ||||
|                     for fr in range(first_frame_int, last_frame_int + 1): | ||||
|                         ref_file_name = os.path.join( | ||||
|                             data_root, | ||||
|                             "DAVIS/JPEGImages/Full-Resolution/%s/%05d.jpg" | ||||
|                             % (animal, fr), | ||||
|                         ) | ||||
|                         ref_seg_name = os.path.join( | ||||
|                             data_root, | ||||
|                             "DAVIS/Annotations/Full-Resolution/%s/%05d.png" | ||||
|                             % (animal, fr), | ||||
|                         ) | ||||
|  | ||||
|                         foundit = False | ||||
|                         for ind, image_annotation in enumerate(animal_joint_data): | ||||
|                             file_name = os.path.join( | ||||
|                                 data_root, image_annotation["image_path"] | ||||
|                             ) | ||||
|                             seg_name = os.path.join( | ||||
|                                 data_root, image_annotation["segmentation_path"] | ||||
|                             ) | ||||
|  | ||||
|                             if file_name == ref_file_name: | ||||
|                                 foundit = True | ||||
|                                 label_ind = ind | ||||
|  | ||||
|                         if foundit: | ||||
|                             image_annotation = animal_joint_data[label_ind] | ||||
|                             file_name = os.path.join( | ||||
|                                 data_root, image_annotation["image_path"] | ||||
|                             ) | ||||
|                             seg_name = os.path.join( | ||||
|                                 data_root, image_annotation["segmentation_path"] | ||||
|                             ) | ||||
|                             joint = np.array(image_annotation["joints"]) | ||||
|                             vis = np.array(image_annotation["visibility"]) | ||||
|                         else: | ||||
|                             file_name = ref_file_name | ||||
|                             seg_name = ref_seg_name | ||||
|                             joint = None | ||||
|                             vis = None | ||||
|  | ||||
|                         filenames.append(file_name) | ||||
|                         segnames.append(seg_name) | ||||
|                         joints.append(joint) | ||||
|                         visible.append(vis) | ||||
|  | ||||
|                 if len(filenames): | ||||
|                     self.animal_dict[self.animal_count] = ( | ||||
|                         filenames, | ||||
|                         segnames, | ||||
|                         joints, | ||||
|                         visible, | ||||
|                     ) | ||||
|                     self.animal_count += 1 | ||||
|         print("Loaded BADJA dataset") | ||||
|  | ||||
|     def get_loader(self): | ||||
|         for __ in range(int(1e6)): | ||||
|             animal_id = np.random.choice(len(self.animal_dict.keys())) | ||||
|             filenames, segnames, joints, visible = self.animal_dict[animal_id] | ||||
|  | ||||
|             image_id = np.random.randint(0, len(filenames)) | ||||
|  | ||||
|             seg_file = segnames[image_id] | ||||
|             image_file = filenames[image_id] | ||||
|  | ||||
|             joints = joints[image_id].copy() | ||||
|             joints = joints[self.smal_joint_info.annotated_classes] | ||||
|             visible = visible[image_id][self.smal_joint_info.annotated_classes] | ||||
|  | ||||
|             rgb_img = imageio.imread(image_file)  # , mode='RGB') | ||||
|             sil_img = imageio.imread(seg_file)  # , mode='RGB') | ||||
|  | ||||
|             rgb_h, rgb_w, _ = rgb_img.shape | ||||
|             sil_img = cv2.resize(sil_img, (rgb_w, rgb_h), interpolation=cv2.INTER_NEAREST) | ||||
|  | ||||
|             yield rgb_img, sil_img, joints, visible, image_file | ||||
|  | ||||
|     def get_video(self, animal_id): | ||||
|         filenames, segnames, joint, visible = self.animal_dict[animal_id] | ||||
|  | ||||
|         rgbs = [] | ||||
|         segs = [] | ||||
|         joints = [] | ||||
|         visibles = [] | ||||
|  | ||||
|         for s in range(len(filenames)): | ||||
|             image_file = filenames[s] | ||||
|             rgb_img = imageio.imread(image_file)  # , mode='RGB') | ||||
|             rgb_h, rgb_w, _ = rgb_img.shape | ||||
|  | ||||
|             seg_file = segnames[s] | ||||
|             sil_img = imageio.imread(seg_file)  # , mode='RGB') | ||||
|             sil_img = cv2.resize(sil_img, (rgb_w, rgb_h), interpolation=cv2.INTER_NEAREST) | ||||
|  | ||||
|             jo = joint[s] | ||||
|  | ||||
|             if jo is not None: | ||||
|                 joi = joint[s].copy() | ||||
|                 joi = joi[self.smal_joint_info.annotated_classes] | ||||
|                 vis = visible[s][self.smal_joint_info.annotated_classes] | ||||
|             else: | ||||
|                 joi = None | ||||
|                 vis = None | ||||
|  | ||||
|             rgbs.append(rgb_img) | ||||
|             segs.append(sil_img) | ||||
|             joints.append(joi) | ||||
|             visibles.append(vis) | ||||
|  | ||||
|         return rgbs, segs, joints, visibles, filenames[0] | ||||
|  | ||||
|  | ||||
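| # torch Dataset wrapper around BADJAData: each index returns one full video as | ||||
| # a CoTrackerData sample (frames, masks, joint trajectories, visibility). | ||||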
| class BadjaDataset(torch.utils.data.Dataset): | ||||
|     def __init__( | ||||
|         self, data_root, max_seq_len=1000, dataset_resolution=(384, 512) | ||||
|     ): | ||||
|  | ||||
|         self.data_root = data_root | ||||
|         self.badja_data = BADJAData(data_root) | ||||
|         self.max_seq_len = max_seq_len | ||||
|         self.dataset_resolution = dataset_resolution | ||||
|         print( | ||||
|             "found %d unique videos in %s" | ||||
|             % (self.badja_data.animal_count, self.data_root) | ||||
|         ) | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|  | ||||
|         rgbs, segs, joints, visibles, filename = self.badja_data.get_video(index) | ||||
|         S = len(rgbs) | ||||
|         H, W, __ = rgbs[0].shape | ||||
|         H, W, __ = segs[0].shape | ||||
|  | ||||
|         N, __ = joints[0].shape | ||||
|  | ||||
|         # let's eliminate the Nones | ||||
|         # note the first one is guaranteed present | ||||
|         for s in range(1, S): | ||||
|             if joints[s] is None: | ||||
|                 joints[s] = np.zeros_like(joints[0]) | ||||
|                 visibles[s] = np.zeros_like(visibles[0]) | ||||
|  | ||||
|         # eliminate the mystery dim | ||||
|         segs = [seg[:, :, 0] for seg in segs] | ||||
|  | ||||
|         rgbs = np.stack(rgbs, 0) | ||||
|         segs = np.stack(segs, 0) | ||||
|         trajs = np.stack(joints, 0) | ||||
|         visibles = np.stack(visibles, 0) | ||||
|  | ||||
|         rgbs = torch.from_numpy(rgbs).reshape(S, H, W, 3).permute(0, 3, 1, 2).float() | ||||
|         segs = torch.from_numpy(segs).reshape(S, 1, H, W).float() | ||||
|         trajs = torch.from_numpy(trajs).reshape(S, N, 2).float() | ||||
|         visibles = torch.from_numpy(visibles).reshape(S, N) | ||||
|  | ||||
|         rgbs = rgbs[: self.max_seq_len] | ||||
|         segs = segs[: self.max_seq_len] | ||||
|         trajs = trajs[: self.max_seq_len] | ||||
|         visibles = visibles[: self.max_seq_len] | ||||
|         # apparently the coords are in yx order | ||||
|         trajs = torch.flip(trajs, [2]) | ||||
|  | ||||
|         if "extra_videos" in filename: | ||||
|             seq_name = filename.split("/")[-3] | ||||
|         else: | ||||
|             seq_name = filename.split("/")[-2] | ||||
|  | ||||
|         rgbs, trajs, segs = resize_sample(rgbs, trajs, segs, self.dataset_resolution) | ||||
|  | ||||
|         return CoTrackerData(rgbs, segs, trajs, visibles, seq_name=seq_name) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return self.badja_data.animal_count | ||||
							
								
								
									
cotracker/datasets/fast_capture_dataset.py (72 lines, Normal file)
							| @@ -0,0 +1,72 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import os | ||||
| import torch | ||||
|  | ||||
| # from PIL import Image | ||||
| import imageio | ||||
| import numpy as np | ||||
| from cotracker.datasets.utils import CoTrackerData, resize_sample | ||||
|  | ||||
|  | ||||
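| # FastCapture (ZJU-MoCap) sequences: rendered frames stored on disk plus 2D | ||||
| # trajectories and visibility flags loaded from per-sequence .pth files. | ||||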
| class FastCaptureDataset(torch.utils.data.Dataset): | ||||
|     def __init__( | ||||
|         self, | ||||
|         data_root, | ||||
|         max_seq_len=50, | ||||
|         max_num_points=20, | ||||
|         dataset_resolution=(384, 512), | ||||
|     ): | ||||
|  | ||||
|         self.data_root = data_root | ||||
|         self.seq_names = os.listdir(os.path.join(data_root, "renders_local_rm")) | ||||
|         self.pth_dir = os.path.join(data_root, "zju_tracking") | ||||
|         self.max_seq_len = max_seq_len | ||||
|         self.max_num_points = max_num_points | ||||
|         self.dataset_resolution = dataset_resolution | ||||
|         print("found %d unique videos in %s" % (len(self.seq_names), self.data_root)) | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|         seq_name = self.seq_names[index] | ||||
|         spath = os.path.join(self.data_root, "renders_local_rm", seq_name) | ||||
|         pthpath = os.path.join(self.pth_dir, seq_name + ".pth") | ||||
|  | ||||
|         rgbs = [] | ||||
|         img_paths = sorted(os.listdir(spath)) | ||||
|         for i, img_path in enumerate(img_paths): | ||||
|             if i < self.max_seq_len: | ||||
|                 rgbs.append(imageio.imread(os.path.join(spath, img_path))) | ||||
|  | ||||
|         annot_dict = torch.load(pthpath) | ||||
|         traj_2d = annot_dict["traj_2d"][:, :, : self.max_seq_len] | ||||
|         visibility = annot_dict["visibility"][:, : self.max_seq_len] | ||||
|  | ||||
|         S = len(rgbs) | ||||
|         H, W, __ = rgbs[0].shape | ||||
|         *_, S = traj_2d.shape | ||||
|         visibile_pts_first_frame_inds = (visibility[:, 0] > 0).nonzero(as_tuple=False)[ | ||||
|             :, 0 | ||||
|         ] | ||||
|         torch.manual_seed(0) | ||||
|         point_inds = torch.randperm(len(visibile_pts_first_frame_inds))[ | ||||
|             : self.max_num_points | ||||
|         ] | ||||
|         visible_inds_sampled = visibile_pts_first_frame_inds[point_inds] | ||||
|  | ||||
|         rgbs = np.stack(rgbs, 0) | ||||
|         rgbs = torch.from_numpy(rgbs).reshape(S, H, W, 3).permute(0, 3, 1, 2).float() | ||||
|  | ||||
|         segs = torch.ones(S, 1, H, W).float() | ||||
|         trajs = traj_2d[visible_inds_sampled].permute(2, 0, 1).float() | ||||
|         visibles = visibility[visible_inds_sampled].permute(1, 0) | ||||
|  | ||||
|         rgbs, trajs, segs = resize_sample(rgbs, trajs, segs, self.dataset_resolution) | ||||
|  | ||||
|         return CoTrackerData(rgbs, segs, trajs, visibles, seq_name=seq_name) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self.seq_names) | ||||
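A minimal usage sketch for FastCaptureDataset, assuming a hypothetical ./data/fastcapture root that contains the renders_local_rm/ image folders and zju_tracking/ .pth annotations the loader expects:

from cotracker.datasets.fast_capture_dataset import FastCaptureDataset

dataset = FastCaptureDataset(
    data_root="./data/fastcapture",  # hypothetical path
    max_seq_len=50,
    max_num_points=20,
)
sample = dataset[0]                  # a CoTrackerData instance
print(sample.seq_name, sample.video.shape, sample.trajectory.shape)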
							
								
								
									
494  cotracker/datasets/kubric_movif_dataset.py  Normal file
							| @@ -0,0 +1,494 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import os | ||||
| import torch | ||||
|  | ||||
| import imageio | ||||
| import numpy as np | ||||
|  | ||||
| from cotracker.datasets.utils import CoTrackerData | ||||
| from torchvision.transforms import ColorJitter, GaussianBlur | ||||
| from PIL import Image | ||||
| import cv2 | ||||
|  | ||||
|  | ||||
| class CoTrackerDataset(torch.utils.data.Dataset): | ||||
|     def __init__( | ||||
|         self, | ||||
|         data_root, | ||||
|         crop_size=(384, 512), | ||||
|         seq_len=24, | ||||
|         traj_per_sample=768, | ||||
|         sample_vis_1st_frame=False, | ||||
|         use_augs=False, | ||||
|     ): | ||||
|         super(CoTrackerDataset, self).__init__() | ||||
|         np.random.seed(0) | ||||
|         torch.manual_seed(0) | ||||
|         self.data_root = data_root | ||||
|         self.seq_len = seq_len | ||||
|         self.traj_per_sample = traj_per_sample | ||||
|         self.sample_vis_1st_frame = sample_vis_1st_frame | ||||
|         self.use_augs = use_augs | ||||
|         self.crop_size = crop_size | ||||
|  | ||||
|         # photometric augmentation | ||||
|         self.photo_aug = ColorJitter( | ||||
|             brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14 | ||||
|         ) | ||||
|         self.blur_aug = GaussianBlur(11, sigma=(0.1, 2.0)) | ||||
|  | ||||
|         self.blur_aug_prob = 0.25 | ||||
|         self.color_aug_prob = 0.25 | ||||
|  | ||||
|         # occlusion augmentation | ||||
|         self.eraser_aug_prob = 0.5 | ||||
|         self.eraser_bounds = [2, 100] | ||||
|         self.eraser_max = 10 | ||||
|  | ||||
|         # occlusion augmentation | ||||
|         self.replace_aug_prob = 0.5 | ||||
|         self.replace_bounds = [2, 100] | ||||
|         self.replace_max = 10 | ||||
|  | ||||
|         # spatial augmentations | ||||
|         self.pad_bounds = [0, 100] | ||||
|         self.crop_size = crop_size | ||||
|         self.resize_lim = [0.25, 2.0]  # sample resizes from here | ||||
|         self.resize_delta = 0.2 | ||||
|         self.max_crop_offset = 50 | ||||
|  | ||||
|         self.do_flip = True | ||||
|         self.h_flip_prob = 0.5 | ||||
|         self.v_flip_prob = 0.5 | ||||
|  | ||||
|     def getitem_helper(self, index): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|         gotit = False | ||||
|  | ||||
|         sample, gotit = self.getitem_helper(index) | ||||
|         if not gotit: | ||||
|             print("warning: sampling failed") | ||||
|             # fake sample, so we can still collate | ||||
|             sample = CoTrackerData( | ||||
|                 video=torch.zeros( | ||||
|                     (self.seq_len, 3, self.crop_size[0], self.crop_size[1]) | ||||
|                 ), | ||||
|                 segmentation=torch.zeros( | ||||
|                     (self.seq_len, 1, self.crop_size[0], self.crop_size[1]) | ||||
|                 ), | ||||
|                 trajectory=torch.zeros((self.seq_len, self.traj_per_sample, 2)), | ||||
|                 visibility=torch.zeros((self.seq_len, self.traj_per_sample)), | ||||
|                 valid=torch.zeros((self.seq_len, self.traj_per_sample)), | ||||
|             ) | ||||
|  | ||||
|         return sample, gotit | ||||
|  | ||||
|     def add_photometric_augs(self, rgbs, trajs, visibles, eraser=True, replace=True): | ||||
|         T, N, _ = trajs.shape | ||||
|  | ||||
|         S = len(rgbs) | ||||
|         H, W = rgbs[0].shape[:2] | ||||
|         assert S == T | ||||
|  | ||||
|         if eraser: | ||||
|             ############ eraser transform (per image after the first) ############ | ||||
|             rgbs = [rgb.astype(np.float32) for rgb in rgbs] | ||||
|             for i in range(1, S): | ||||
|                 if np.random.rand() < self.eraser_aug_prob: | ||||
|                     for _ in range( | ||||
|                         np.random.randint(1, self.eraser_max + 1) | ||||
|                     ):  # number of times to occlude | ||||
|  | ||||
|                         xc = np.random.randint(0, W) | ||||
|                         yc = np.random.randint(0, H) | ||||
|                         dx = np.random.randint( | ||||
|                             self.eraser_bounds[0], self.eraser_bounds[1] | ||||
|                         ) | ||||
|                         dy = np.random.randint( | ||||
|                             self.eraser_bounds[0], self.eraser_bounds[1] | ||||
|                         ) | ||||
|                         x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32) | ||||
|                         x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32) | ||||
|                         y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32) | ||||
|                         y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32) | ||||
|  | ||||
|                         mean_color = np.mean( | ||||
|                             rgbs[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0 | ||||
|                         ) | ||||
|                         rgbs[i][y0:y1, x0:x1, :] = mean_color | ||||
|  | ||||
|                         occ_inds = np.logical_and( | ||||
|                             np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1), | ||||
|                             np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1), | ||||
|                         ) | ||||
|                         visibles[i, occ_inds] = 0 | ||||
|             rgbs = [rgb.astype(np.uint8) for rgb in rgbs] | ||||
|  | ||||
|         if replace: | ||||
|  | ||||
|             rgbs_alt = [ | ||||
|                 np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) | ||||
|                 for rgb in rgbs | ||||
|             ] | ||||
|             rgbs_alt = [ | ||||
|                 np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) | ||||
|                 for rgb in rgbs_alt | ||||
|             ] | ||||
|  | ||||
|             ############ replace transform (per image after the first) ############ | ||||
|             rgbs = [rgb.astype(np.float32) for rgb in rgbs] | ||||
|             rgbs_alt = [rgb.astype(np.float32) for rgb in rgbs_alt] | ||||
|             for i in range(1, S): | ||||
|                 if np.random.rand() < self.replace_aug_prob: | ||||
|                     for _ in range( | ||||
|                         np.random.randint(1, self.replace_max + 1) | ||||
|                     ):  # number of times to occlude | ||||
|                         xc = np.random.randint(0, W) | ||||
|                         yc = np.random.randint(0, H) | ||||
|                         dx = np.random.randint( | ||||
|                             self.replace_bounds[0], self.replace_bounds[1] | ||||
|                         ) | ||||
|                         dy = np.random.randint( | ||||
|                             self.replace_bounds[0], self.replace_bounds[1] | ||||
|                         ) | ||||
|                         x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32) | ||||
|                         x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32) | ||||
|                         y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32) | ||||
|                         y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32) | ||||
|  | ||||
|                         wid = x1 - x0 | ||||
|                         hei = y1 - y0 | ||||
|                         y00 = np.random.randint(0, H - hei) | ||||
|                         x00 = np.random.randint(0, W - wid) | ||||
|                         fr = np.random.randint(0, S) | ||||
|                         rep = rgbs_alt[fr][y00 : y00 + hei, x00 : x00 + wid, :] | ||||
|                         rgbs[i][y0:y1, x0:x1, :] = rep | ||||
|  | ||||
|                         occ_inds = np.logical_and( | ||||
|                             np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1), | ||||
|                             np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1), | ||||
|                         ) | ||||
|                         visibles[i, occ_inds] = 0 | ||||
|             rgbs = [rgb.astype(np.uint8) for rgb in rgbs] | ||||
|  | ||||
|         ############ photometric augmentation ############ | ||||
|         if np.random.rand() < self.color_aug_prob: | ||||
|             # random per-frame amount of aug | ||||
|             rgbs = [ | ||||
|                 np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) | ||||
|                 for rgb in rgbs | ||||
|             ] | ||||
|  | ||||
|         if np.random.rand() < self.blur_aug_prob: | ||||
|             # random per-frame amount of blur | ||||
|             rgbs = [ | ||||
|                 np.array(self.blur_aug(Image.fromarray(rgb)), dtype=np.uint8) | ||||
|                 for rgb in rgbs | ||||
|             ] | ||||
|  | ||||
|         return rgbs, trajs, visibles | ||||
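The eraser and replace branches above mark every track that falls inside the pasted rectangle as occluded. A toy version of that bounds test, with made-up coordinates:

import numpy as np

trajs_frame = np.array([[5.0, 5.0], [20.0, 30.0]])  # two points, (x, y)
x0, x1, y0, y1 = 0, 10, 0, 10                       # hypothetical erased box
occ_inds = np.logical_and(
    np.logical_and(trajs_frame[:, 0] >= x0, trajs_frame[:, 0] < x1),
    np.logical_and(trajs_frame[:, 1] >= y0, trajs_frame[:, 1] < y1),
)
print(occ_inds)  # [ True False] -> only the first point would have its visibility zeroed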
|  | ||||
|     def add_spatial_augs(self, rgbs, trajs, visibles): | ||||
|         T, N, __ = trajs.shape | ||||
|  | ||||
|         S = len(rgbs) | ||||
|         H, W = rgbs[0].shape[:2] | ||||
|         assert S == T | ||||
|  | ||||
|         rgbs = [rgb.astype(np.float32) for rgb in rgbs] | ||||
|  | ||||
|         ############ spatial transform ############ | ||||
|  | ||||
|         # padding | ||||
|         pad_x0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1]) | ||||
|         pad_x1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1]) | ||||
|         pad_y0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1]) | ||||
|         pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1]) | ||||
|  | ||||
|         rgbs = [ | ||||
|             np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs | ||||
|         ] | ||||
|         trajs[:, :, 0] += pad_x0 | ||||
|         trajs[:, :, 1] += pad_y0 | ||||
|         H, W = rgbs[0].shape[:2] | ||||
|  | ||||
|         # scaling + stretching | ||||
|         scale = np.random.uniform(self.resize_lim[0], self.resize_lim[1]) | ||||
|         scale_x = scale | ||||
|         scale_y = scale | ||||
|         H_new = H | ||||
|         W_new = W | ||||
|  | ||||
|         scale_delta_x = 0.0 | ||||
|         scale_delta_y = 0.0 | ||||
|  | ||||
|         rgbs_scaled = [] | ||||
|         for s in range(S): | ||||
|             if s == 1: | ||||
|                 scale_delta_x = np.random.uniform(-self.resize_delta, self.resize_delta) | ||||
|                 scale_delta_y = np.random.uniform(-self.resize_delta, self.resize_delta) | ||||
|             elif s > 1: | ||||
|                 scale_delta_x = ( | ||||
|                     scale_delta_x * 0.8 | ||||
|                     + np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2 | ||||
|                 ) | ||||
|                 scale_delta_y = ( | ||||
|                     scale_delta_y * 0.8 | ||||
|                     + np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2 | ||||
|                 ) | ||||
|             scale_x = scale_x + scale_delta_x | ||||
|             scale_y = scale_y + scale_delta_y | ||||
|  | ||||
|             # bring h/w closer | ||||
|             scale_xy = (scale_x + scale_y) * 0.5 | ||||
|             scale_x = scale_x * 0.5 + scale_xy * 0.5 | ||||
|             scale_y = scale_y * 0.5 + scale_xy * 0.5 | ||||
|  | ||||
|             # don't get too crazy | ||||
|             scale_x = np.clip(scale_x, 0.2, 2.0) | ||||
|             scale_y = np.clip(scale_y, 0.2, 2.0) | ||||
|  | ||||
|             H_new = int(H * scale_y) | ||||
|             W_new = int(W * scale_x) | ||||
|  | ||||
|             # make it at least slightly bigger than the crop area, | ||||
|             # so that the random cropping can add diversity | ||||
|             H_new = np.clip(H_new, self.crop_size[0] + 10, None) | ||||
|             W_new = np.clip(W_new, self.crop_size[1] + 10, None) | ||||
|             # recompute scale in case we clipped | ||||
|             scale_x = W_new / float(W) | ||||
|             scale_y = H_new / float(H) | ||||
|  | ||||
|             rgbs_scaled.append( | ||||
|                 cv2.resize(rgbs[s], (W_new, H_new), interpolation=cv2.INTER_LINEAR) | ||||
|             ) | ||||
|             trajs[s, :, 0] *= scale_x | ||||
|             trajs[s, :, 1] *= scale_y | ||||
|         rgbs = rgbs_scaled | ||||
|  | ||||
|         ok_inds = visibles[0, :] > 0 | ||||
|         vis_trajs = trajs[:, ok_inds]  # S,?,2 | ||||
|  | ||||
|         if vis_trajs.shape[1] > 0: | ||||
|             mid_x = np.mean(vis_trajs[0, :, 0]) | ||||
|             mid_y = np.mean(vis_trajs[0, :, 1]) | ||||
|         else: | ||||
|             mid_y = self.crop_size[0] | ||||
|             mid_x = self.crop_size[1] | ||||
|  | ||||
|         x0 = int(mid_x - self.crop_size[1] // 2) | ||||
|         y0 = int(mid_y - self.crop_size[0] // 2) | ||||
|  | ||||
|         offset_x = 0 | ||||
|         offset_y = 0 | ||||
|  | ||||
|         for s in range(S): | ||||
|             # on each frame, shift a bit more | ||||
|             if s == 1: | ||||
|                 offset_x = np.random.randint( | ||||
|                     -self.max_crop_offset, self.max_crop_offset | ||||
|                 ) | ||||
|                 offset_y = np.random.randint( | ||||
|                     -self.max_crop_offset, self.max_crop_offset | ||||
|                 ) | ||||
|             elif s > 1: | ||||
|                 offset_x = int( | ||||
|                     offset_x * 0.8 | ||||
|                     + np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) | ||||
|                     * 0.2 | ||||
|                 ) | ||||
|                 offset_y = int( | ||||
|                     offset_y * 0.8 | ||||
|                     + np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) | ||||
|                     * 0.2 | ||||
|                 ) | ||||
|             x0 = x0 + offset_x | ||||
|             y0 = y0 + offset_y | ||||
|  | ||||
|             H_new, W_new = rgbs[s].shape[:2] | ||||
|             if H_new == self.crop_size[0]: | ||||
|                 y0 = 0 | ||||
|             else: | ||||
|                 y0 = min(max(0, y0), H_new - self.crop_size[0] - 1) | ||||
|  | ||||
|             if W_new == self.crop_size[1]: | ||||
|                 x0 = 0 | ||||
|             else: | ||||
|                 x0 = min(max(0, x0), W_new - self.crop_size[1] - 1) | ||||
|  | ||||
|             rgbs[s] = rgbs[s][y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] | ||||
|             trajs[s, :, 0] -= x0 | ||||
|             trajs[s, :, 1] -= y0 | ||||
|  | ||||
|         H_new = self.crop_size[0] | ||||
|         W_new = self.crop_size[1] | ||||
|  | ||||
|         # flip | ||||
|         h_flipped = False | ||||
|         v_flipped = False | ||||
|         if self.do_flip: | ||||
|             # h flip | ||||
|             if np.random.rand() < self.h_flip_prob: | ||||
|                 h_flipped = True | ||||
|                 rgbs = [rgb[:, ::-1] for rgb in rgbs] | ||||
|             # v flip | ||||
|             if np.random.rand() < self.v_flip_prob: | ||||
|                 v_flipped = True | ||||
|                 rgbs = [rgb[::-1] for rgb in rgbs] | ||||
|         if h_flipped: | ||||
|             trajs[:, :, 0] = W_new - trajs[:, :, 0] | ||||
|         if v_flipped: | ||||
|             trajs[:, :, 1] = H_new - trajs[:, :, 1] | ||||
|  | ||||
|         return rgbs, trajs | ||||
|  | ||||
|     def crop(self, rgbs, trajs): | ||||
|         T, N, _ = trajs.shape | ||||
|  | ||||
|         S = len(rgbs) | ||||
|         H, W = rgbs[0].shape[:2] | ||||
|         assert S == T | ||||
|  | ||||
|         ############ spatial transform ############ | ||||
|  | ||||
|         H_new = H | ||||
|         W_new = W | ||||
|  | ||||
|         # simple random crop | ||||
|         y0 = ( | ||||
|             0 | ||||
|             if self.crop_size[0] >= H_new | ||||
|             else np.random.randint(0, H_new - self.crop_size[0]) | ||||
|         ) | ||||
|         x0 = ( | ||||
|             0 | ||||
|             if self.crop_size[1] >= W_new | ||||
|             else np.random.randint(0, W_new - self.crop_size[1]) | ||||
|         ) | ||||
|         rgbs = [ | ||||
|             rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] | ||||
|             for rgb in rgbs | ||||
|         ] | ||||
|  | ||||
|         trajs[:, :, 0] -= x0 | ||||
|         trajs[:, :, 1] -= y0 | ||||
|  | ||||
|         return rgbs, trajs | ||||
|  | ||||
|  | ||||
| class KubricMovifDataset(CoTrackerDataset): | ||||
|     def __init__( | ||||
|         self, | ||||
|         data_root, | ||||
|         crop_size=(384, 512), | ||||
|         seq_len=24, | ||||
|         traj_per_sample=768, | ||||
|         sample_vis_1st_frame=False, | ||||
|         use_augs=False, | ||||
|     ): | ||||
|         super(KubricMovifDataset, self).__init__( | ||||
|             data_root=data_root, | ||||
|             crop_size=crop_size, | ||||
|             seq_len=seq_len, | ||||
|             traj_per_sample=traj_per_sample, | ||||
|             sample_vis_1st_frame=sample_vis_1st_frame, | ||||
|             use_augs=use_augs, | ||||
|         ) | ||||
|  | ||||
|         self.pad_bounds = [0, 25] | ||||
|         self.resize_lim = [0.75, 1.25]  # sample resizes from here | ||||
|         self.resize_delta = 0.05 | ||||
|         self.max_crop_offset = 15 | ||||
|         self.seq_names = [ | ||||
|             fname | ||||
|             for fname in os.listdir(data_root) | ||||
|             if os.path.isdir(os.path.join(data_root, fname)) | ||||
|         ] | ||||
|         print("found %d unique videos in %s" % (len(self.seq_names), self.data_root)) | ||||
|  | ||||
|     def getitem_helper(self, index): | ||||
|         gotit = True | ||||
|         seq_name = self.seq_names[index] | ||||
|  | ||||
|         npy_path = os.path.join(self.data_root, seq_name, seq_name + ".npy") | ||||
|         rgb_path = os.path.join(self.data_root, seq_name, "frames") | ||||
|  | ||||
|         img_paths = sorted(os.listdir(rgb_path)) | ||||
|         rgbs = [] | ||||
|         for i, img_path in enumerate(img_paths): | ||||
|             rgbs.append(imageio.v2.imread(os.path.join(rgb_path, img_path))) | ||||
|  | ||||
|         rgbs = np.stack(rgbs) | ||||
|         annot_dict = np.load(npy_path, allow_pickle=True).item() | ||||
|         traj_2d = annot_dict["coords"] | ||||
|         visibility = annot_dict["visibility"] | ||||
|  | ||||
|         # random crop | ||||
|         assert self.seq_len <= len(rgbs) | ||||
|         if self.seq_len < len(rgbs): | ||||
|             start_ind = np.random.choice(len(rgbs) - self.seq_len, 1)[0] | ||||
|  | ||||
|             rgbs = rgbs[start_ind : start_ind + self.seq_len] | ||||
|             traj_2d = traj_2d[:, start_ind : start_ind + self.seq_len] | ||||
|             visibility = visibility[:, start_ind : start_ind + self.seq_len] | ||||
|  | ||||
|         traj_2d = np.transpose(traj_2d, (1, 0, 2)) | ||||
|         visibility = np.transpose(np.logical_not(visibility), (1, 0)) | ||||
|         if self.use_augs: | ||||
|             rgbs, traj_2d, visibility = self.add_photometric_augs( | ||||
|                 rgbs, traj_2d, visibility | ||||
|             ) | ||||
|             rgbs, traj_2d = self.add_spatial_augs(rgbs, traj_2d, visibility) | ||||
|         else: | ||||
|             rgbs, traj_2d = self.crop(rgbs, traj_2d) | ||||
|  | ||||
|         visibility[traj_2d[:, :, 0] > self.crop_size[1] - 1] = False | ||||
|         visibility[traj_2d[:, :, 0] < 0] = False | ||||
|         visibility[traj_2d[:, :, 1] > self.crop_size[0] - 1] = False | ||||
|         visibility[traj_2d[:, :, 1] < 0] = False | ||||
|  | ||||
|         visibility = torch.from_numpy(visibility) | ||||
|         traj_2d = torch.from_numpy(traj_2d) | ||||
|  | ||||
|         visibile_pts_first_frame_inds = (visibility[0]).nonzero(as_tuple=False)[:, 0] | ||||
|  | ||||
|         if self.sample_vis_1st_frame: | ||||
|             visibile_pts_inds = visibile_pts_first_frame_inds | ||||
|         else: | ||||
|             visibile_pts_mid_frame_inds = (visibility[self.seq_len // 2]).nonzero( | ||||
|                 as_tuple=False | ||||
|             )[:, 0] | ||||
|             visibile_pts_inds = torch.cat( | ||||
|                 (visibile_pts_first_frame_inds, visibile_pts_mid_frame_inds), dim=0 | ||||
|             ) | ||||
|         point_inds = torch.randperm(len(visibile_pts_inds))[: self.traj_per_sample] | ||||
|         if len(point_inds) < self.traj_per_sample: | ||||
|             gotit = False | ||||
|  | ||||
|         visible_inds_sampled = visibile_pts_inds[point_inds] | ||||
|  | ||||
|         trajs = traj_2d[:, visible_inds_sampled].float() | ||||
|         visibles = visibility[:, visible_inds_sampled] | ||||
|         valids = torch.ones((self.seq_len, self.traj_per_sample)) | ||||
|  | ||||
|         rgbs = torch.from_numpy(np.stack(rgbs)).permute(0, 3, 1, 2).float() | ||||
|         segs = torch.ones((self.seq_len, 1, self.crop_size[0], self.crop_size[1])) | ||||
|         sample = CoTrackerData( | ||||
|             video=rgbs, | ||||
|             segmentation=segs, | ||||
|             trajectory=trajs, | ||||
|             visibility=visibles, | ||||
|             valid=valids, | ||||
|             seq_name=seq_name, | ||||
|         ) | ||||
|         return sample, gotit | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self.seq_names) | ||||
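A minimal sketch of instantiating KubricMovifDataset, assuming a hypothetical ./data/kubric_movif root where each sequence directory contains a frames/ folder and a <seq_name>.npy annotation file, as getitem_helper expects:

from cotracker.datasets.kubric_movif_dataset import KubricMovifDataset

dataset = KubricMovifDataset(
    data_root="./data/kubric_movif",  # hypothetical path
    crop_size=(384, 512),
    seq_len=24,
    traj_per_sample=768,
    use_augs=True,
)
sample, gotit = dataset[0]  # gotit is False when too few visible tracks were found
print(gotit, sample.video.shape, sample.trajectory.shape)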
							
								
								
									
218  cotracker/datasets/tap_vid_datasets.py  Normal file
							| @@ -0,0 +1,218 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import os | ||||
| import io | ||||
| import glob | ||||
| import torch | ||||
| import pickle | ||||
| import numpy as np | ||||
| import mediapy as media | ||||
|  | ||||
| from PIL import Image | ||||
| from typing import Mapping, Tuple, Union | ||||
|  | ||||
| from cotracker.datasets.utils import CoTrackerData | ||||
|  | ||||
| DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]] | ||||
|  | ||||
|  | ||||
| def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray: | ||||
|     """Resize a video to output_size.""" | ||||
|     # If you have a GPU, consider replacing this with a GPU-enabled resize op, | ||||
|     # such as a jitted jax.image.resize.  It will make things faster. | ||||
|     return media.resize_video(video, output_size) | ||||
|  | ||||
|  | ||||
| def sample_queries_first( | ||||
|     target_occluded: np.ndarray, | ||||
|     target_points: np.ndarray, | ||||
|     frames: np.ndarray, | ||||
| ) -> Mapping[str, np.ndarray]: | ||||
|     """Package a set of frames and tracks for use in TAPNet evaluations. | ||||
|     Given a set of frames and tracks with no query points, use the first | ||||
|     visible point in each track as the query. | ||||
|     Args: | ||||
|       target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames], | ||||
|         where True indicates occluded. | ||||
|       target_points: Position, of shape [n_tracks, n_frames, 2], where each point | ||||
|         is [x,y] scaled between 0 and 1. | ||||
|       frames: Video tensor, of shape [n_frames, height, width, 3].  Scaled between | ||||
|         -1 and 1. | ||||
|     Returns: | ||||
|       A dict with the keys: | ||||
|         video: Video tensor of shape [1, n_frames, height, width, 3] | ||||
|         query_points: Query points of shape [1, n_queries, 3] where | ||||
|           each point is [t, y, x] scaled to the range [-1, 1] | ||||
|         target_points: Target points of shape [1, n_queries, n_frames, 2] where | ||||
|           each point is [x, y] scaled to the range [-1, 1] | ||||
|     """ | ||||
|     valid = np.sum(~target_occluded, axis=1) > 0 | ||||
|     target_points = target_points[valid, :] | ||||
|     target_occluded = target_occluded[valid, :] | ||||
|  | ||||
|     query_points = [] | ||||
|     for i in range(target_points.shape[0]): | ||||
|         index = np.where(target_occluded[i] == 0)[0][0] | ||||
|         x, y = target_points[i, index, 0], target_points[i, index, 1] | ||||
|         query_points.append(np.array([index, y, x]))  # [t, y, x] | ||||
|     query_points = np.stack(query_points, axis=0) | ||||
|  | ||||
|     return { | ||||
|         "video": frames[np.newaxis, ...], | ||||
|         "query_points": query_points[np.newaxis, ...], | ||||
|         "target_points": target_points[np.newaxis, ...], | ||||
|         "occluded": target_occluded[np.newaxis, ...], | ||||
|     } | ||||
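A toy check of the shapes produced by sample_queries_first, using fabricated inputs (2 fully visible tracks, 4 frames of 8x8 video):

import numpy as np
from cotracker.datasets.tap_vid_datasets import sample_queries_first

occluded = np.zeros((2, 4), dtype=bool)             # no occlusions
points = np.random.rand(2, 4, 2).astype(np.float32)
frames = np.zeros((4, 8, 8, 3), dtype=np.float32)

out = sample_queries_first(occluded, points, frames)
print(out["video"].shape)          # (1, 4, 8, 8, 3)
print(out["query_points"].shape)   # (1, 2, 3) -> one [t, y, x] query per track
print(out["target_points"].shape)  # (1, 2, 4, 2)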
|  | ||||
|  | ||||
| def sample_queries_strided( | ||||
|     target_occluded: np.ndarray, | ||||
|     target_points: np.ndarray, | ||||
|     frames: np.ndarray, | ||||
|     query_stride: int = 5, | ||||
| ) -> Mapping[str, np.ndarray]: | ||||
|     """Package a set of frames and tracks for use in TAPNet evaluations. | ||||
|  | ||||
|     Given a set of frames and tracks with no query points, sample queries | ||||
|     strided every query_stride frames, ignoring points that are not visible | ||||
|     at the selected frames. | ||||
|  | ||||
|     Args: | ||||
|       target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames], | ||||
|         where True indicates occluded. | ||||
|       target_points: Position, of shape [n_tracks, n_frames, 2], where each point | ||||
|         is [x,y] scaled between 0 and 1. | ||||
|       frames: Video tensor, of shape [n_frames, height, width, 3].  Scaled between | ||||
|         -1 and 1. | ||||
|       query_stride: When sampling query points, search for un-occluded points | ||||
|         every query_stride frames and convert each one into a query. | ||||
|  | ||||
|     Returns: | ||||
|       A dict with the keys: | ||||
|         video: Video tensor of shape [1, n_frames, height, width, 3].  The video | ||||
|           has floats scaled to the range [-1, 1]. | ||||
|         query_points: Query points of shape [1, n_queries, 3] where | ||||
|           each point is [t, y, x] scaled to the range [-1, 1]. | ||||
|         target_points: Target points of shape [1, n_queries, n_frames, 2] where | ||||
|           each point is [x, y] scaled to the range [-1, 1]. | ||||
|         trackgroup: Index of the original track that each query point was | ||||
|           sampled from.  This is useful for visualization. | ||||
|     """ | ||||
|     tracks = [] | ||||
|     occs = [] | ||||
|     queries = [] | ||||
|     trackgroups = [] | ||||
|     total = 0 | ||||
|     trackgroup = np.arange(target_occluded.shape[0]) | ||||
|     for i in range(0, target_occluded.shape[1], query_stride): | ||||
|         mask = target_occluded[:, i] == 0 | ||||
|         query = np.stack( | ||||
|             [ | ||||
|                 i * np.ones(target_occluded.shape[0:1]), | ||||
|                 target_points[:, i, 1], | ||||
|                 target_points[:, i, 0], | ||||
|             ], | ||||
|             axis=-1, | ||||
|         ) | ||||
|         queries.append(query[mask]) | ||||
|         tracks.append(target_points[mask]) | ||||
|         occs.append(target_occluded[mask]) | ||||
|         trackgroups.append(trackgroup[mask]) | ||||
|         total += np.array(np.sum(target_occluded[:, i] == 0)) | ||||
|  | ||||
|     return { | ||||
|         "video": frames[np.newaxis, ...], | ||||
|         "query_points": np.concatenate(queries, axis=0)[np.newaxis, ...], | ||||
|         "target_points": np.concatenate(tracks, axis=0)[np.newaxis, ...], | ||||
|         "occluded": np.concatenate(occs, axis=0)[np.newaxis, ...], | ||||
|         "trackgroup": np.concatenate(trackgroups, axis=0)[np.newaxis, ...], | ||||
|     } | ||||
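The strided variant issues a query at every query_stride-th frame for each track that is visible there. A fabricated example with query_stride=2 over 4 frames and 2 always-visible tracks:

import numpy as np
from cotracker.datasets.tap_vid_datasets import sample_queries_strided

occluded = np.zeros((2, 4), dtype=bool)
points = np.random.rand(2, 4, 2).astype(np.float32)
frames = np.zeros((4, 8, 8, 3), dtype=np.float32)

out = sample_queries_strided(occluded, points, frames, query_stride=2)
# frames 0 and 2 are sampled; both tracks are visible at both, giving 4 queries
print(out["query_points"].shape)  # (1, 4, 3)
print(out["trackgroup"].shape)    # (1, 4)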
|  | ||||
|  | ||||
| class TapVidDataset(torch.utils.data.Dataset): | ||||
|     def __init__( | ||||
|         self, | ||||
|         data_root, | ||||
|         dataset_type="davis", | ||||
|         resize_to_256=True, | ||||
|         queried_first=True, | ||||
|     ): | ||||
|         self.dataset_type = dataset_type | ||||
|         self.resize_to_256 = resize_to_256 | ||||
|         self.queried_first = queried_first | ||||
|         if self.dataset_type == "kinetics": | ||||
|             all_paths = glob.glob(os.path.join(data_root, "*_of_0010.pkl")) | ||||
|             points_dataset = [] | ||||
|             for pickle_path in all_paths: | ||||
|                 with open(pickle_path, "rb") as f: | ||||
|                     data = pickle.load(f) | ||||
|                     points_dataset = points_dataset + data | ||||
|             self.points_dataset = points_dataset | ||||
|         else: | ||||
|             with open(data_root, "rb") as f: | ||||
|                 self.points_dataset = pickle.load(f) | ||||
|             if self.dataset_type == "davis": | ||||
|                 self.video_names = list(self.points_dataset.keys()) | ||||
|         print("found %d unique videos in %s" % (len(self.points_dataset), data_root)) | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|         if self.dataset_type == "davis": | ||||
|             video_name = self.video_names[index] | ||||
|         else: | ||||
|             video_name = index | ||||
|         video = self.points_dataset[video_name] | ||||
|         frames = video["video"] | ||||
|  | ||||
|         if isinstance(frames[0], bytes): | ||||
|             # TAP-Vid is stored as JPEG bytes rather than `np.ndarray`s. | ||||
|             def decode(frame): | ||||
|                 byteio = io.BytesIO(frame) | ||||
|                 img = Image.open(byteio) | ||||
|                 return np.array(img) | ||||
|  | ||||
|             frames = np.array([decode(frame) for frame in frames]) | ||||
|  | ||||
|         target_points = self.points_dataset[video_name]["points"] | ||||
|         if self.resize_to_256: | ||||
|             frames = resize_video(frames, [256, 256]) | ||||
|             target_points *= np.array([256, 256]) | ||||
|         else: | ||||
|             target_points *= np.array([frames.shape[2], frames.shape[1]]) | ||||
|  | ||||
|         T, H, W, C = frames.shape | ||||
|         N, T, D = target_points.shape | ||||
|  | ||||
|         target_occ = self.points_dataset[video_name]["occluded"] | ||||
|         if self.queried_first: | ||||
|             converted = sample_queries_first(target_occ, target_points, frames) | ||||
|         else: | ||||
|             converted = sample_queries_strided(target_occ, target_points, frames) | ||||
|         assert converted["target_points"].shape[1] == converted["query_points"].shape[1] | ||||
|  | ||||
|         trajs = ( | ||||
|             torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float() | ||||
|         )  # T, N, D | ||||
|  | ||||
|         rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float() | ||||
|         segs = torch.ones(T, 1, H, W).float() | ||||
|         visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[ | ||||
|             0 | ||||
|         ].permute( | ||||
|             1, 0 | ||||
|         )  # T, N | ||||
|         query_points = torch.from_numpy(converted["query_points"])[0]  # N, 3 in [t, y, x] | ||||
|         return CoTrackerData( | ||||
|             rgbs, | ||||
|             segs, | ||||
|             trajs, | ||||
|             visibles, | ||||
|             seq_name=str(video_name), | ||||
|             query_points=query_points, | ||||
|         ) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self.points_dataset) | ||||
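A minimal usage sketch for the DAVIS split, assuming a hypothetical local path to the TAP-Vid DAVIS pickle file:

from cotracker.datasets.tap_vid_datasets import TapVidDataset

dataset = TapVidDataset(
    data_root="./data/tapvid_davis/tapvid_davis.pkl",  # hypothetical path
    dataset_type="davis",
    queried_first=True,
)
sample = dataset[0]
print(sample.seq_name, sample.video.shape, sample.query_points.shape)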
							
								
								
									
114  cotracker/datasets/utils.py  Normal file
							| @@ -0,0 +1,114 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
|  | ||||
| import torch | ||||
| import dataclasses | ||||
| import torch.nn.functional as F | ||||
| from dataclasses import dataclass | ||||
| from typing import Any, Optional | ||||
|  | ||||
|  | ||||
| @dataclass(eq=False) | ||||
| class CoTrackerData: | ||||
|     """ | ||||
|     Dataclass for storing video tracks data. | ||||
|     """ | ||||
|  | ||||
|     video: torch.Tensor  # B, S, C, H, W | ||||
|     segmentation: torch.Tensor  # B, S, 1, H, W | ||||
|     trajectory: torch.Tensor  # B, S, N, 2 | ||||
|     visibility: torch.Tensor  # B, S, N | ||||
|     # optional data | ||||
|     valid: Optional[torch.Tensor] = None  # B, S, N | ||||
|     seq_name: Optional[str] = None | ||||
|     query_points: Optional[torch.Tensor] = None  # TapVID evaluation format | ||||
|  | ||||
|  | ||||
| def collate_fn(batch): | ||||
|     """ | ||||
|     Collate function for video tracks data. | ||||
|     """ | ||||
|     video = torch.stack([b.video for b in batch], dim=0) | ||||
|     segmentation = torch.stack([b.segmentation for b in batch], dim=0) | ||||
|     trajectory = torch.stack([b.trajectory for b in batch], dim=0) | ||||
|     visibility = torch.stack([b.visibility for b in batch], dim=0) | ||||
|     query_points = None | ||||
|     if batch[0].query_points is not None: | ||||
|         query_points = torch.stack([b.query_points for b in batch], dim=0) | ||||
|     seq_name = [b.seq_name for b in batch] | ||||
|  | ||||
|     return CoTrackerData( | ||||
|         video, | ||||
|         segmentation, | ||||
|         trajectory, | ||||
|         visibility, | ||||
|         seq_name=seq_name, | ||||
|         query_points=query_points, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def collate_fn_train(batch): | ||||
|     """ | ||||
|     Collate function for video tracks data during training. | ||||
|     """ | ||||
|     gotit = [gotit for _, gotit in batch] | ||||
|     video = torch.stack([b.video for b, _ in batch], dim=0) | ||||
|     segmentation = torch.stack([b.segmentation for b, _ in batch], dim=0) | ||||
|     trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0) | ||||
|     visibility = torch.stack([b.visibility for b, _ in batch], dim=0) | ||||
|     valid = torch.stack([b.valid for b, _ in batch], dim=0) | ||||
|     seq_name = [b.seq_name for b, _ in batch] | ||||
|     return ( | ||||
|         CoTrackerData(video, segmentation, trajectory, visibility, valid, seq_name), | ||||
|         gotit, | ||||
|     ) | ||||
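A sketch of how collate_fn_train pairs with a training DataLoader: each dataset item is a (CoTrackerData, gotit) tuple, and the collated gotit flags let the loop skip batches that contain a fake fallback sample. The dataset root below is a placeholder:

import torch
from cotracker.datasets.utils import collate_fn_train
from cotracker.datasets.kubric_movif_dataset import KubricMovifDataset

dataset = KubricMovifDataset(data_root="./data/kubric_movif", use_augs=True)  # hypothetical path
loader = torch.utils.data.DataLoader(
    dataset, batch_size=4, shuffle=True, collate_fn=collate_fn_train
)
for batch, gotit in loader:
    if not all(gotit):        # at least one sample in the batch failed to load
        continue
    print(batch.video.shape)  # B, S, C, H, W
    break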
|  | ||||
|  | ||||
| def try_to_cuda(t: Any) -> Any: | ||||
|     """ | ||||
|     Try to move the input variable `t` to a cuda device. | ||||
|  | ||||
|     Args: | ||||
|         t: Input. | ||||
|  | ||||
|     Returns: | ||||
|         t_cuda: `t` moved to a cuda device, if supported. | ||||
|     """ | ||||
|     try: | ||||
|         t = t.float().cuda() | ||||
|     except AttributeError: | ||||
|         pass | ||||
|     return t | ||||
|  | ||||
|  | ||||
| def dataclass_to_cuda_(obj): | ||||
|     """ | ||||
|     Move all contents of a dataclass to cuda inplace if supported. | ||||
|  | ||||
|     Args: | ||||
|         batch: Input dataclass. | ||||
|  | ||||
|     Returns: | ||||
|         batch_cuda: `batch` moved to a cuda device, if supported. | ||||
|     """ | ||||
|     for f in dataclasses.fields(obj): | ||||
|         setattr(obj, f.name, try_to_cuda(getattr(obj, f.name))) | ||||
|     return obj | ||||
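A short sketch of moving a collated batch onto the GPU with dataclass_to_cuda_; fields without a .float() method (seq_name, the unset optionals) are left untouched because try_to_cuda swallows the AttributeError. The tensor shapes below are arbitrary:

import torch
from cotracker.datasets.utils import CoTrackerData, dataclass_to_cuda_

batch = CoTrackerData(
    video=torch.zeros(1, 8, 3, 384, 512),
    segmentation=torch.zeros(1, 8, 1, 384, 512),
    trajectory=torch.zeros(1, 8, 16, 2),
    visibility=torch.zeros(1, 8, 16),
    seq_name=["toy_sequence"],
)
if torch.cuda.is_available():
    dataclass_to_cuda_(batch)
    print(batch.video.device)  # cuda:0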
|  | ||||
|  | ||||
| def resize_sample(rgbs, trajs_g, segs, interp_shape): | ||||
|     S, C, H, W = rgbs.shape | ||||
|     S, N, D = trajs_g.shape | ||||
|  | ||||
|     assert D == 2 | ||||
|  | ||||
|     rgbs = F.interpolate(rgbs, interp_shape, mode="bilinear") | ||||
|     segs = F.interpolate(segs, interp_shape, mode="nearest") | ||||
|  | ||||
|     trajs_g[:, :, 0] *= interp_shape[1] / W | ||||
|     trajs_g[:, :, 1] *= interp_shape[0] / H | ||||
|     return rgbs, trajs_g, segs | ||||
							
								
								
									
5  cotracker/evaluation/__init__.py  Normal file
							| @@ -0,0 +1,5 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
							
								
								
									
6  cotracker/evaluation/configs/eval_badja.yaml  Normal file
							| @@ -0,0 +1,6 @@ | ||||
| defaults: | ||||
|   - default_config_eval | ||||
| exp_dir: ./outputs/cotracker | ||||
| dataset_name: badja | ||||
|  | ||||
|     | ||||
							
								
								
									
6  cotracker/evaluation/configs/eval_fastcapture.yaml  Normal file
							| @@ -0,0 +1,6 @@ | ||||
| defaults: | ||||
|   - default_config_eval | ||||
| exp_dir: ./outputs/cotracker | ||||
| dataset_name: fastcapture | ||||
|  | ||||
|     | ||||
| @@ -0,0 +1,6 @@ | ||||
| defaults: | ||||
|   - default_config_eval | ||||
| exp_dir: ./outputs/cotracker | ||||
| dataset_name: tapvid_davis_first | ||||
|  | ||||
|     | ||||
| @@ -0,0 +1,6 @@ | ||||
| defaults: | ||||
|   - default_config_eval | ||||
| exp_dir: ./outputs/cotracker | ||||
| dataset_name: tapvid_davis_strided | ||||
|  | ||||
|     | ||||
| @@ -0,0 +1,6 @@ | ||||
| defaults: | ||||
|   - default_config_eval | ||||
| exp_dir: ./outputs/cotracker | ||||
| dataset_name: tapvid_kinetics_first | ||||
|  | ||||
|     | ||||
							
								
								
									
5  cotracker/evaluation/core/__init__.py  Normal file
							| @@ -0,0 +1,5 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
							
								
								
									
144  cotracker/evaluation/core/eval_utils.py  Normal file
							| @@ -0,0 +1,144 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
| from typing import Iterable, Mapping, Tuple, Union | ||||
|  | ||||
|  | ||||
| def compute_tapvid_metrics( | ||||
|     query_points: np.ndarray, | ||||
|     gt_occluded: np.ndarray, | ||||
|     gt_tracks: np.ndarray, | ||||
|     pred_occluded: np.ndarray, | ||||
|     pred_tracks: np.ndarray, | ||||
|     query_mode: str, | ||||
| ) -> Mapping[str, np.ndarray]: | ||||
|     """Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.) | ||||
|     See the TAP-Vid paper for details on the metric computation.  All inputs are | ||||
|     given in raster coordinates.  The first three arguments should be the direct | ||||
|     outputs of the reader: the 'query_points', 'occluded', and 'target_points'. | ||||
|     The paper metrics assume these are scaled relative to 256x256 images. | ||||
|     pred_occluded and pred_tracks are your algorithm's predictions. | ||||
|     This function takes a batch of inputs, and computes metrics separately for | ||||
|     each video.  The metrics for the full benchmark are a simple mean of the | ||||
|     metrics across the full set of videos.  These numbers are between 0 and 1, | ||||
|     but the paper multiplies them by 100 to ease reading. | ||||
|     Args: | ||||
|        query_points: The query points, in the format [t, y, x].  Its size is | ||||
|          [b, n, 3], where b is the batch size and n is the number of queries | ||||
|        gt_occluded: A boolean array of shape [b, n, t], where t is the number | ||||
|          of frames.  True indicates that the point is occluded. | ||||
|        gt_tracks: The target points, of shape [b, n, t, 2].  Each point is | ||||
|          in the format [x, y] | ||||
|        pred_occluded: A boolean array of predicted occlusions, in the same | ||||
|          format as gt_occluded. | ||||
|        pred_tracks: An array of track predictions from your algorithm, in the | ||||
|          same format as gt_tracks. | ||||
|        query_mode: Either 'first' or 'strided', depending on how queries are | ||||
|          sampled.  If 'first', we assume the prior knowledge that all points | ||||
|          before the query point are occluded, and these are removed from the | ||||
|          evaluation. | ||||
|     Returns: | ||||
|         A dict with the following keys: | ||||
|         occlusion_accuracy: Accuracy at predicting occlusion. | ||||
|         pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points | ||||
|           predicted to be within the given pixel threshold, ignoring occlusion | ||||
|           prediction. | ||||
|         jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given | ||||
|           threshold | ||||
|         average_pts_within_thresh: average across pts_within_{x} | ||||
|         average_jaccard: average across jaccard_{x} | ||||
|     """ | ||||
|  | ||||
|     metrics = {} | ||||
|  | ||||
|     # Don't evaluate the query point.  Numpy doesn't have one_hot, so we | ||||
|     # replicate it by indexing into an identity matrix. | ||||
|     one_hot_eye = np.eye(gt_tracks.shape[2]) | ||||
|     query_frame = query_points[..., 0] | ||||
|     query_frame = np.round(query_frame).astype(np.int32) | ||||
|     evaluation_points = one_hot_eye[query_frame] == 0 | ||||
|  | ||||
|     # If we're using the first point on the track as a query, don't evaluate the | ||||
|     # other points. | ||||
|     if query_mode == "first": | ||||
|         for i in range(gt_occluded.shape[0]): | ||||
|             index = np.where(gt_occluded[i] == 0)[0][0] | ||||
|             evaluation_points[i, :index] = False | ||||
|     elif query_mode != "strided": | ||||
|         raise ValueError("Unknown query mode " + query_mode) | ||||
|  | ||||
|     # Occlusion accuracy is simply how often the predicted occlusion equals the | ||||
|     # ground truth. | ||||
|     occ_acc = ( | ||||
|         np.sum( | ||||
|             np.equal(pred_occluded, gt_occluded) & evaluation_points, | ||||
|             axis=(1, 2), | ||||
|         ) | ||||
|         / np.sum(evaluation_points) | ||||
|     ) | ||||
|     metrics["occlusion_accuracy"] = occ_acc | ||||
|  | ||||
|     # Next, convert the predictions and ground truth positions into pixel | ||||
|     # coordinates. | ||||
|     visible = np.logical_not(gt_occluded) | ||||
|     pred_visible = np.logical_not(pred_occluded) | ||||
|     all_frac_within = [] | ||||
|     all_jaccard = [] | ||||
|     for thresh in [1, 2, 4, 8, 16]: | ||||
|         # True positives are points that are within the threshold and where both | ||||
|         # the prediction and the ground truth are listed as visible. | ||||
|         within_dist = ( | ||||
|             np.sum( | ||||
|                 np.square(pred_tracks - gt_tracks), | ||||
|                 axis=-1, | ||||
|             ) | ||||
|             < np.square(thresh) | ||||
|         ) | ||||
|         is_correct = np.logical_and(within_dist, visible) | ||||
|  | ||||
|         # Compute the frac_within_threshold, which is the fraction of points | ||||
|         # within the threshold among points that are visible in the ground truth, | ||||
|         # ignoring whether they're predicted to be visible. | ||||
|         count_correct = np.sum( | ||||
|             is_correct & evaluation_points, | ||||
|             axis=(1, 2), | ||||
|         ) | ||||
|         count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2)) | ||||
|         frac_correct = count_correct / count_visible_points | ||||
|         metrics["pts_within_" + str(thresh)] = frac_correct | ||||
|         all_frac_within.append(frac_correct) | ||||
|  | ||||
|         true_positives = np.sum( | ||||
|             is_correct & pred_visible & evaluation_points, axis=(1, 2) | ||||
|         ) | ||||
|  | ||||
|         # The denominator of the jaccard metric is the true positives plus | ||||
|         # false positives plus false negatives.  However, note that true positives | ||||
|         # plus false negatives is simply the number of points in the ground truth | ||||
|         # which is easier to compute than trying to compute all three quantities. | ||||
|         # Thus we just add the number of points in the ground truth to the number | ||||
|         # of false positives. | ||||
|         # | ||||
|         # False positives are simply points that are predicted to be visible, | ||||
|         # but the ground truth is not visible or too far from the prediction. | ||||
|         gt_positives = np.sum(visible & evaluation_points, axis=(1, 2)) | ||||
|         false_positives = (~visible) & pred_visible | ||||
|         false_positives = false_positives | ((~within_dist) & pred_visible) | ||||
|         false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2)) | ||||
|         jaccard = true_positives / (gt_positives + false_positives) | ||||
|         metrics["jaccard_" + str(thresh)] = jaccard | ||||
|         all_jaccard.append(jaccard) | ||||
|     metrics["average_jaccard"] = np.mean( | ||||
|         np.stack(all_jaccard, axis=1), | ||||
|         axis=1, | ||||
|     ) | ||||
|     metrics["average_pts_within_thresh"] = np.mean( | ||||
|         np.stack(all_frac_within, axis=1), | ||||
|         axis=1, | ||||
|     ) | ||||
|     return metrics | ||||
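A toy invocation of compute_tapvid_metrics with perfect predictions, just to show the expected keys and values; all arrays are fabricated:

import numpy as np
from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics

B, N, T = 1, 5, 12
gt_occluded = np.zeros((B, N, T), dtype=bool)
gt_tracks = np.random.rand(B, N, T, 2) * 256
query_points = np.stack(
    [np.zeros((B, N)), gt_tracks[..., 0, 1], gt_tracks[..., 0, 0]], axis=-1
)  # [t, y, x] taken from frame 0

metrics = compute_tapvid_metrics(
    query_points, gt_occluded, gt_tracks,
    pred_occluded=gt_occluded, pred_tracks=gt_tracks,
    query_mode="first",
)
print(metrics["average_jaccard"], metrics["occlusion_accuracy"])  # [1.] [1.]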
							
								
								
									
252  cotracker/evaluation/core/evaluator.py  Normal file
							| @@ -0,0 +1,252 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| from collections import defaultdict | ||||
| import os | ||||
| from typing import Optional | ||||
| import torch | ||||
| from tqdm import tqdm | ||||
| import numpy as np | ||||
|  | ||||
| from torch.utils.tensorboard import SummaryWriter | ||||
| from cotracker.datasets.utils import dataclass_to_cuda_ | ||||
| from cotracker.utils.visualizer import Visualizer | ||||
| from cotracker.models.core.model_utils import reduce_masked_mean | ||||
| from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics | ||||
|  | ||||
| import logging | ||||
|  | ||||
|  | ||||
| class Evaluator: | ||||
|     """ | ||||
|     A class defining the CoTracker evaluator. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, exp_dir) -> None: | ||||
|         # Visualization | ||||
|         self.exp_dir = exp_dir | ||||
|         os.makedirs(exp_dir, exist_ok=True) | ||||
|         self.visualization_filepaths = defaultdict(lambda: defaultdict(list)) | ||||
|         self.visualize_dir = os.path.join(exp_dir, "visualisations") | ||||
|  | ||||
|     def compute_metrics(self, metrics, sample, pred_trajectory, dataset_name): | ||||
|         if isinstance(pred_trajectory, tuple): | ||||
|             pred_trajectory, pred_visibility = pred_trajectory | ||||
|         else: | ||||
|             pred_visibility = None | ||||
|         if dataset_name == "badja": | ||||
|             sample.segmentation = (sample.segmentation > 0).float() | ||||
|             *_, N, _ = sample.trajectory.shape | ||||
|             accs = [] | ||||
|             accs_3px = [] | ||||
|             for s1 in range(1, sample.video.shape[1]):  # target frame | ||||
|                 for n in range(N): | ||||
|                     vis = sample.visibility[0, s1, n] | ||||
|                     if vis > 0: | ||||
|                         coord_e = pred_trajectory[0, s1, n]  # 2 | ||||
|                         coord_g = sample.trajectory[0, s1, n]  # 2 | ||||
|                         dist = torch.sqrt(torch.sum((coord_e - coord_g) ** 2, dim=0)) | ||||
|                         area = torch.sum(sample.segmentation[0, s1]) | ||||
|                         # print_('0.2*sqrt(area)', 0.2*torch.sqrt(area)) | ||||
|                         thr = 0.2 * torch.sqrt(area) | ||||
|                         # correct = | ||||
|                         accs.append((dist < thr).float()) | ||||
|                         # print('thr',thr) | ||||
|                         accs_3px.append((dist < 3.0).float()) | ||||
|  | ||||
|             res = torch.mean(torch.stack(accs)) * 100.0 | ||||
|             res_3px = torch.mean(torch.stack(accs_3px)) * 100.0 | ||||
|             metrics[sample.seq_name[0]] = res.item() | ||||
|             metrics[sample.seq_name[0] + "_accuracy"] = res_3px.item() | ||||
|             print(metrics) | ||||
|             print( | ||||
|                 "avg", np.mean([v for k, v in metrics.items() if "accuracy" not in k]) | ||||
|             ) | ||||
|             print( | ||||
|                 "avg acc 3px", | ||||
|                 np.mean([v for k, v in metrics.items() if "accuracy" in k]), | ||||
|             ) | ||||
|         elif dataset_name == "fastcapture" or ("kubric" in dataset_name): | ||||
|             *_, N, _ = sample.trajectory.shape | ||||
|             accs = [] | ||||
|             for s1 in range(1, sample.video.shape[1]):  # target frame | ||||
|                 for n in range(N): | ||||
|                     vis = sample.visibility[0, s1, n] | ||||
|                     if vis > 0: | ||||
|                         coord_e = pred_trajectory[0, s1, n]  # 2 | ||||
|                         coord_g = sample.trajectory[0, s1, n]  # 2 | ||||
|                         dist = torch.sqrt(torch.sum((coord_e - coord_g) ** 2, dim=0)) | ||||
|                         thr = 3 | ||||
|                         correct = (dist < thr).float() | ||||
|                         accs.append(correct) | ||||
|  | ||||
|             res = torch.mean(torch.stack(accs)) * 100.0 | ||||
|             metrics[sample.seq_name[0] + "_accuracy"] = res.item() | ||||
|             print(metrics) | ||||
|             print("avg", np.mean([v for v in metrics.values()])) | ||||
|         elif "tapvid" in dataset_name: | ||||
|             B, T, N, D = sample.trajectory.shape | ||||
|             traj = sample.trajectory.clone() | ||||
|             thr = 0.9 | ||||
|  | ||||
|             if pred_visibility is None: | ||||
|                 logging.warning("visibility is NONE") | ||||
|                 pred_visibility = torch.zeros_like(sample.visibility) | ||||
|  | ||||
|             if not pred_visibility.dtype == torch.bool: | ||||
|                 pred_visibility = pred_visibility > thr | ||||
|  | ||||
|             # pred_trajectory | ||||
|             query_points = sample.query_points.clone().cpu().numpy() | ||||
|  | ||||
|             pred_visibility = pred_visibility[:, :, :N] | ||||
|             pred_trajectory = pred_trajectory[:, :, :N] | ||||
|  | ||||
|             gt_tracks = traj.permute(0, 2, 1, 3).cpu().numpy() | ||||
|             gt_occluded = ( | ||||
|                 torch.logical_not(sample.visibility.clone().permute(0, 2, 1)) | ||||
|                 .cpu() | ||||
|                 .numpy() | ||||
|             ) | ||||
|  | ||||
|             pred_occluded = ( | ||||
|                 torch.logical_not(pred_visibility.clone().permute(0, 2, 1)) | ||||
|                 .cpu() | ||||
|                 .numpy() | ||||
|             ) | ||||
|             pred_tracks = pred_trajectory.permute(0, 2, 1, 3).cpu().numpy() | ||||
|  | ||||
|             out_metrics = compute_tapvid_metrics( | ||||
|                 query_points, | ||||
|                 gt_occluded, | ||||
|                 gt_tracks, | ||||
|                 pred_occluded, | ||||
|                 pred_tracks, | ||||
|                 query_mode="strided" if "strided" in dataset_name else "first", | ||||
|             ) | ||||
|  | ||||
|             metrics[sample.seq_name[0]] = out_metrics | ||||
|             for metric_name in out_metrics.keys(): | ||||
|                 if "avg" not in metrics: | ||||
|                     metrics["avg"] = {} | ||||
|                 metrics["avg"][metric_name] = np.mean( | ||||
|                     [v[metric_name] for k, v in metrics.items() if k != "avg"] | ||||
|                 ) | ||||
|  | ||||
|             logging.info(f"Metrics: {out_metrics}") | ||||
|             logging.info(f"avg: {metrics['avg']}") | ||||
|             print("metrics", out_metrics) | ||||
|             print("avg", metrics["avg"]) | ||||
|         else: | ||||
|             rgbs = sample.video | ||||
|             trajs_g = sample.trajectory | ||||
|             valids = sample.valid | ||||
|             vis_g = sample.visibility | ||||
|  | ||||
|             B, S, C, H, W = rgbs.shape | ||||
|             assert C == 3 | ||||
|             B, S, N, D = trajs_g.shape | ||||
|  | ||||
|             assert torch.sum(valids) == B * S * N | ||||
|  | ||||
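|             # Treat a point as visible for the whole clip only if it is visible in | ||||
|             # at least 4 frames; this is used below to split ATE into vis/occ parts. | ||||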
|             vis_g = (torch.sum(vis_g, dim=1, keepdim=True) >= 4).float().repeat(1, S, 1) | ||||
|  | ||||
|             ate = torch.norm(pred_trajectory - trajs_g, dim=-1)  # B, S, N | ||||
|  | ||||
|             metrics["things_all"] = reduce_masked_mean(ate, valids).item() | ||||
|             metrics["things_vis"] = reduce_masked_mean(ate, valids * vis_g).item() | ||||
|             metrics["things_occ"] = reduce_masked_mean( | ||||
|                 ate, valids * (1.0 - vis_g) | ||||
|             ).item() | ||||
|  | ||||
|     @torch.no_grad() | ||||
|     def evaluate_sequence( | ||||
|         self, | ||||
|         model, | ||||
|         test_dataloader: torch.utils.data.DataLoader, | ||||
|         dataset_name: str, | ||||
|         train_mode=False, | ||||
|         writer: Optional[SummaryWriter] = None, | ||||
|         step: Optional[int] = 0, | ||||
|     ): | ||||
|         metrics = {} | ||||
|  | ||||
|         vis = Visualizer( | ||||
|             save_dir=self.exp_dir, | ||||
|             fps=7, | ||||
|         ) | ||||
|  | ||||
|         for ind, sample in enumerate(tqdm(test_dataloader)): | ||||
|             if isinstance(sample, tuple): | ||||
|                 sample, gotit = sample | ||||
|                 if not all(gotit): | ||||
|                     print("batch is None") | ||||
|                     continue | ||||
|             dataclass_to_cuda_(sample) | ||||
|  | ||||
|             if ( | ||||
|                 not train_mode | ||||
|                 and hasattr(model, "sequence_len") | ||||
|                 and (sample.visibility[:, : model.sequence_len].sum() == 0) | ||||
|             ): | ||||
|                 print(f"skipping batch {ind}") | ||||
|                 continue | ||||
|  | ||||
|             if "tapvid" in dataset_name: | ||||
|                 queries = sample.query_points.clone().float() | ||||
|  | ||||
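|                 # TAP-Vid stores query points as (t, y, x); the model expects | ||||
|                 # (t, x, y), so swap the last two channels here. | ||||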
|                 queries = torch.stack( | ||||
|                     [ | ||||
|                         queries[:, :, 0], | ||||
|                         queries[:, :, 2], | ||||
|                         queries[:, :, 1], | ||||
|                     ], | ||||
|                     dim=2, | ||||
|                 ) | ||||
|             else: | ||||
|                 queries = torch.cat( | ||||
|                     [ | ||||
|                         torch.zeros_like(sample.trajectory[:, 0, :, :1]), | ||||
|                         sample.trajectory[:, 0], | ||||
|                     ], | ||||
|                     dim=2, | ||||
|                 ) | ||||
|  | ||||
|             pred_tracks = model(sample.video, queries) | ||||
|             if "strided" in dataset_name: | ||||
|  | ||||
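|                 # For "strided" queries, also run the model on the time-reversed video | ||||
|                 # and use the backward predictions to fill in the entries that the | ||||
|                 # forward pass left at zero (i.e. frames before each query frame). | ||||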
|                 inv_video = sample.video.flip(1).clone() | ||||
|                 inv_queries = queries.clone() | ||||
|                 inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1 | ||||
|  | ||||
|                 pred_trj, pred_vsb = pred_tracks | ||||
|                 inv_pred_trj, inv_pred_vsb = model(inv_video, inv_queries) | ||||
|  | ||||
|                 inv_pred_trj = inv_pred_trj.flip(1) | ||||
|                 inv_pred_vsb = inv_pred_vsb.flip(1) | ||||
|  | ||||
|                 mask = pred_trj == 0 | ||||
|  | ||||
|                 pred_trj[mask] = inv_pred_trj[mask] | ||||
|                 pred_vsb[mask[:, :, :, 0]] = inv_pred_vsb[mask[:, :, :, 0]] | ||||
|  | ||||
|                 pred_tracks = pred_trj, pred_vsb | ||||
|  | ||||
|             if dataset_name == "badja" or dataset_name == "fastcapture": | ||||
|                 seq_name = sample.seq_name[0] | ||||
|             else: | ||||
|                 seq_name = str(ind) | ||||
|  | ||||
|             vis.visualize( | ||||
|                 sample.video, | ||||
|                 pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks, | ||||
|                 filename=dataset_name + "_" + seq_name, | ||||
|                 writer=writer, | ||||
|                 step=step, | ||||
|             ) | ||||
|  | ||||
|             self.compute_metrics(metrics, sample, pred_tracks, dataset_name) | ||||
|         return metrics | ||||
cotracker/evaluation/evaluate.py (new file, 179 lines)
							| @@ -0,0 +1,179 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import json | ||||
| import os | ||||
| from dataclasses import dataclass, field | ||||
|  | ||||
| import hydra | ||||
| import numpy as np | ||||
|  | ||||
| import torch | ||||
| from omegaconf import OmegaConf | ||||
|  | ||||
| from cotracker.datasets.badja_dataset import BadjaDataset | ||||
| from cotracker.datasets.fast_capture_dataset import FastCaptureDataset | ||||
| from cotracker.datasets.tap_vid_datasets import TapVidDataset | ||||
| from cotracker.datasets.utils import collate_fn | ||||
|  | ||||
| from cotracker.models.evaluation_predictor import EvaluationPredictor | ||||
|  | ||||
| from cotracker.evaluation.core.evaluator import Evaluator | ||||
| from cotracker.models.build_cotracker import ( | ||||
|     build_cotracker, | ||||
| ) | ||||
|  | ||||
|  | ||||
| @dataclass(eq=False) | ||||
| class DefaultConfig: | ||||
|     # Directory where all outputs of the experiment will be saved. | ||||
|     exp_dir: str = "./outputs" | ||||
|  | ||||
|     # Name of the dataset to be used for the evaluation. | ||||
|     dataset_name: str = "badja" | ||||
|     # The root directory of the dataset. | ||||
|     dataset_root: str = "./" | ||||
|  | ||||
|     # Path to the pre-trained model checkpoint to be used for the evaluation. | ||||
|     # The default value is the path to a specific CoTracker model checkpoint. | ||||
|     # Other available options are commented. | ||||
|     checkpoint: str = "./checkpoints/cotracker_stride_4_wind_8.pth" | ||||
|     # cotracker_stride_4_wind_12 | ||||
|     # cotracker_stride_8_wind_16 | ||||
|  | ||||
|     # EvaluationPredictor parameters | ||||
|     # The size (N) of the support grid used in the predictor. | ||||
|     # The total number of points is (N*N). | ||||
|     grid_size: int = 6 | ||||
|     # The size (N) of the local support grid. | ||||
|     local_grid_size: int = 6 | ||||
|     # A flag indicating whether to evaluate one ground truth point at a time. | ||||
|     single_point: bool = True | ||||
|     # The number of iterative updates for each sliding window. | ||||
|     n_iters: int = 6 | ||||
|  | ||||
|     seed: int = 0 | ||||
|     gpu_idx: int = 0 | ||||
|  | ||||
|     # Override hydra's working directory to current working dir, | ||||
|     # also disable storing the .hydra logs: | ||||
|     hydra: dict = field( | ||||
|         default_factory=lambda: { | ||||
|             "run": {"dir": "."}, | ||||
|             "output_subdir": None, | ||||
|         } | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def run_eval(cfg: DefaultConfig): | ||||
|     """ | ||||
|     The function evaluates CoTracker on a specified benchmark dataset based on a provided configuration. | ||||
|  | ||||
|     Args: | ||||
|         cfg (DefaultConfig): An instance of DefaultConfig class which includes: | ||||
|             - exp_dir (str): The directory path for the experiment. | ||||
|             - dataset_name (str): The name of the dataset to be used. | ||||
|             - dataset_root (str): The root directory of the dataset. | ||||
|             - checkpoint (str): The path to the CoTracker model's checkpoint. | ||||
|             - grid_size (int): The size (N) of the support grid used in the predictor (N*N points in total). | ||||
|             - local_grid_size (int): The size (N) of the local support grid. | ||||
|             - single_point (bool): A flag indicating whether to evaluate one ground truth point at a time. | ||||
|             - n_iters (int): The number of iterative updates for each sliding window. | ||||
|             - seed (int): The seed for setting the random state for reproducibility. | ||||
|             - gpu_idx (int): The index of the GPU to be used. | ||||
|     """ | ||||
|     # Creating the experiment directory if it doesn't exist | ||||
|     os.makedirs(cfg.exp_dir, exist_ok=True) | ||||
|  | ||||
|     # Saving the experiment configuration to a .yaml file in the experiment directory | ||||
|     cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml") | ||||
|     with open(cfg_file, "w") as f: | ||||
|         OmegaConf.save(config=cfg, f=f) | ||||
|  | ||||
|     evaluator = Evaluator(cfg.exp_dir) | ||||
|     cotracker_model = build_cotracker(cfg.checkpoint) | ||||
|  | ||||
|     # Creating the EvaluationPredictor object | ||||
|     predictor = EvaluationPredictor( | ||||
|         cotracker_model, | ||||
|         grid_size=cfg.grid_size, | ||||
|         local_grid_size=cfg.local_grid_size, | ||||
|         single_point=cfg.single_point, | ||||
|         n_iters=cfg.n_iters, | ||||
|     ) | ||||
|  | ||||
|     # Setting the random seeds | ||||
|     torch.manual_seed(cfg.seed) | ||||
|     np.random.seed(cfg.seed) | ||||
|  | ||||
|     # Constructing the specified dataset | ||||
|     curr_collate_fn = collate_fn | ||||
|     if cfg.dataset_name == "badja": | ||||
|         test_dataset = BadjaDataset(data_root=os.path.join(cfg.dataset_root, "BADJA")) | ||||
|     elif cfg.dataset_name == "fastcapture": | ||||
|         test_dataset = FastCaptureDataset( | ||||
|             data_root=os.path.join(cfg.dataset_root, "fastcapture"), | ||||
|             max_seq_len=100, | ||||
|             max_num_points=20, | ||||
|         ) | ||||
|     elif "tapvid" in cfg.dataset_name: | ||||
|         dataset_type = cfg.dataset_name.split("_")[1] | ||||
|         if dataset_type == "davis": | ||||
|             data_root = os.path.join(cfg.dataset_root, "tapvid_davis/tapvid_davis.pkl") | ||||
|         elif dataset_type == "kinetics": | ||||
|             data_root = os.path.join( | ||||
|                 cfg.dataset_root, "kinetics/kinetics-dataset/k700-2020/tapvid_kinetics" | ||||
|             ) | ||||
|         test_dataset = TapVidDataset( | ||||
|             dataset_type=dataset_type, | ||||
|             data_root=data_root, | ||||
|             queried_first="strided" not in cfg.dataset_name, | ||||
|         ) | ||||
|  | ||||
|     # Creating the DataLoader object | ||||
|     test_dataloader = torch.utils.data.DataLoader( | ||||
|         test_dataset, | ||||
|         batch_size=1, | ||||
|         shuffle=False, | ||||
|         num_workers=14, | ||||
|         collate_fn=curr_collate_fn, | ||||
|     ) | ||||
|  | ||||
|     # Timing and conducting the evaluation | ||||
|     import time | ||||
|  | ||||
|     start = time.time() | ||||
|     evaluate_result = evaluator.evaluate_sequence( | ||||
|         predictor, | ||||
|         test_dataloader, | ||||
|         dataset_name=cfg.dataset_name, | ||||
|     ) | ||||
|     end = time.time() | ||||
|     print(end - start) | ||||
|  | ||||
|     # Saving the evaluation results to a .json file | ||||
|     if "tapvid" not in cfg.dataset_name: | ||||
|         print("evaluate_result", evaluate_result) | ||||
|     else: | ||||
|         evaluate_result = evaluate_result["avg"] | ||||
|     result_file = os.path.join(cfg.exp_dir, "result_eval_.json") | ||||
|     evaluate_result["time"] = end - start | ||||
|     print(f"Dumping eval results to {result_file}.") | ||||
|     with open(result_file, "w") as f: | ||||
|         json.dump(evaluate_result, f) | ||||
|  | ||||
|  | ||||
| cs = hydra.core.config_store.ConfigStore.instance() | ||||
| cs.store(name="default_config_eval", node=DefaultConfig) | ||||
|  | ||||
|  | ||||
| @hydra.main(config_path="./configs/", config_name="default_config_eval") | ||||
| def evaluate(cfg: DefaultConfig) -> None: | ||||
|     os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" | ||||
|     os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx) | ||||
|     run_eval(cfg) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     evaluate() | ||||
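|  | ||||
|  | ||||
| # Example invocation (illustrative; the dataset name and paths below are assumptions -- | ||||
| # adjust them to your local setup). Hydra picks up key=value overrides: | ||||
| # | ||||
| #   python ./cotracker/evaluation/evaluate.py \ | ||||
| #       exp_dir=./outputs/tapvid_davis \ | ||||
| #       dataset_name=tapvid_davis_first \ | ||||
| #       dataset_root=/path/to/tapvid \ | ||||
| #       checkpoint=./checkpoints/cotracker_stride_4_wind_8.pth | ||||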
cotracker/models/__init__.py (new file, 5 lines)
							| @@ -0,0 +1,5 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
cotracker/models/build_cotracker.py (new file, 70 lines)
							| @@ -0,0 +1,70 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import torch | ||||
|  | ||||
| from cotracker.models.core.cotracker.cotracker import CoTracker | ||||
|  | ||||
|  | ||||
| def build_cotracker( | ||||
|     checkpoint: str, | ||||
| ): | ||||
|     model_name = checkpoint.split("/")[-1].split(".")[0] | ||||
|     if model_name == "cotracker_stride_4_wind_8": | ||||
|         return build_cotracker_stride_4_wind_8(checkpoint=checkpoint) | ||||
|     elif model_name == "cotracker_stride_4_wind_12": | ||||
|         return build_cotracker_stride_4_wind_12(checkpoint=checkpoint) | ||||
|     elif model_name == "cotracker_stride_8_wind_16": | ||||
|         return build_cotracker_stride_8_wind_16(checkpoint=checkpoint) | ||||
|     else: | ||||
|         raise ValueError(f"Unknown model name {model_name}") | ||||
|  | ||||
|  | ||||
| # model used to produce the results in the paper | ||||
| def build_cotracker_stride_4_wind_8(checkpoint=None): | ||||
|     return _build_cotracker( | ||||
|         stride=4, | ||||
|         sequence_len=8, | ||||
|         checkpoint=checkpoint, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def build_cotracker_stride_4_wind_12(checkpoint=None): | ||||
|     return _build_cotracker( | ||||
|         stride=4, | ||||
|         sequence_len=12, | ||||
|         checkpoint=checkpoint, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| # the fastest model | ||||
| def build_cotracker_stride_8_wind_16(checkpoint=None): | ||||
|     return _build_cotracker( | ||||
|         stride=8, | ||||
|         sequence_len=16, | ||||
|         checkpoint=checkpoint, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def _build_cotracker( | ||||
|     stride, | ||||
|     sequence_len, | ||||
|     checkpoint=None, | ||||
| ): | ||||
|     cotracker = CoTracker( | ||||
|         stride=stride, | ||||
|         S=sequence_len, | ||||
|         add_space_attn=True, | ||||
|         space_depth=6, | ||||
|         time_depth=6, | ||||
|     ) | ||||
|     if checkpoint is not None: | ||||
|         with open(checkpoint, "rb") as f: | ||||
|             state_dict = torch.load(f, map_location="cpu") | ||||
|             if "model" in state_dict: | ||||
|                 state_dict = state_dict["model"] | ||||
|         cotracker.load_state_dict(state_dict) | ||||
|     return cotracker | ||||
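|  | ||||
|  | ||||
| # Minimal usage sketch (assumes the checkpoint file exists locally). The builder | ||||
| # infers the architecture from the checkpoint filename, so the file must be named | ||||
| # after one of the variants above: | ||||
| # | ||||
| #   model = build_cotracker("./checkpoints/cotracker_stride_4_wind_8.pth") | ||||
| #   model.eval() | ||||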
cotracker/models/core/__init__.py (new file, 5 lines)
							| @@ -0,0 +1,5 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
cotracker/models/core/cotracker/__init__.py (new file, 5 lines)
							| @@ -0,0 +1,5 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
cotracker/models/core/cotracker/blocks.py (new file, 400 lines)
							| @@ -0,0 +1,400 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| from einops import rearrange | ||||
| from timm.models.vision_transformer import Attention, Mlp | ||||
|  | ||||
|  | ||||
| class ResidualBlock(nn.Module): | ||||
|     def __init__(self, in_planes, planes, norm_fn="group", stride=1): | ||||
|         super(ResidualBlock, self).__init__() | ||||
|  | ||||
|         self.conv1 = nn.Conv2d( | ||||
|             in_planes, | ||||
|             planes, | ||||
|             kernel_size=3, | ||||
|             padding=1, | ||||
|             stride=stride, | ||||
|             padding_mode="zeros", | ||||
|         ) | ||||
|         self.conv2 = nn.Conv2d( | ||||
|             planes, planes, kernel_size=3, padding=1, padding_mode="zeros" | ||||
|         ) | ||||
|         self.relu = nn.ReLU(inplace=True) | ||||
|  | ||||
|         num_groups = planes // 8 | ||||
|  | ||||
|         if norm_fn == "group": | ||||
|             self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) | ||||
|             self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) | ||||
|             if not stride == 1: | ||||
|                 self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) | ||||
|  | ||||
|         elif norm_fn == "batch": | ||||
|             self.norm1 = nn.BatchNorm2d(planes) | ||||
|             self.norm2 = nn.BatchNorm2d(planes) | ||||
|             if not stride == 1: | ||||
|                 self.norm3 = nn.BatchNorm2d(planes) | ||||
|  | ||||
|         elif norm_fn == "instance": | ||||
|             self.norm1 = nn.InstanceNorm2d(planes) | ||||
|             self.norm2 = nn.InstanceNorm2d(planes) | ||||
|             if not stride == 1: | ||||
|                 self.norm3 = nn.InstanceNorm2d(planes) | ||||
|  | ||||
|         elif norm_fn == "none": | ||||
|             self.norm1 = nn.Sequential() | ||||
|             self.norm2 = nn.Sequential() | ||||
|             if not stride == 1: | ||||
|                 self.norm3 = nn.Sequential() | ||||
|  | ||||
|         if stride == 1: | ||||
|             self.downsample = None | ||||
|  | ||||
|         else: | ||||
|             self.downsample = nn.Sequential( | ||||
|                 nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3 | ||||
|             ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         y = x | ||||
|         y = self.relu(self.norm1(self.conv1(y))) | ||||
|         y = self.relu(self.norm2(self.conv2(y))) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             x = self.downsample(x) | ||||
|  | ||||
|         return self.relu(x + y) | ||||
|  | ||||
|  | ||||
| class BasicEncoder(nn.Module): | ||||
|     def __init__( | ||||
|         self, input_dim=3, output_dim=128, stride=8, norm_fn="batch", dropout=0.0 | ||||
|     ): | ||||
|         super(BasicEncoder, self).__init__() | ||||
|         self.stride = stride | ||||
|         self.norm_fn = norm_fn | ||||
|         self.in_planes = 64 | ||||
|  | ||||
|         if self.norm_fn == "group": | ||||
|             self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes) | ||||
|             self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2) | ||||
|  | ||||
|         elif self.norm_fn == "batch": | ||||
|             self.norm1 = nn.BatchNorm2d(self.in_planes) | ||||
|             self.norm2 = nn.BatchNorm2d(output_dim * 2) | ||||
|  | ||||
|         elif self.norm_fn == "instance": | ||||
|             self.norm1 = nn.InstanceNorm2d(self.in_planes) | ||||
|             self.norm2 = nn.InstanceNorm2d(output_dim * 2) | ||||
|  | ||||
|         elif self.norm_fn == "none": | ||||
|             self.norm1 = nn.Sequential() | ||||
|  | ||||
|         self.conv1 = nn.Conv2d( | ||||
|             input_dim, | ||||
|             self.in_planes, | ||||
|             kernel_size=7, | ||||
|             stride=2, | ||||
|             padding=3, | ||||
|             padding_mode="zeros", | ||||
|         ) | ||||
|         self.relu1 = nn.ReLU(inplace=True) | ||||
|  | ||||
|         self.shallow = False | ||||
|         if self.shallow: | ||||
|             self.layer1 = self._make_layer(64, stride=1) | ||||
|             self.layer2 = self._make_layer(96, stride=2) | ||||
|             self.layer3 = self._make_layer(128, stride=2) | ||||
|             self.conv2 = nn.Conv2d(128 + 96 + 64, output_dim, kernel_size=1) | ||||
|         else: | ||||
|             self.layer1 = self._make_layer(64, stride=1) | ||||
|             self.layer2 = self._make_layer(96, stride=2) | ||||
|             self.layer3 = self._make_layer(128, stride=2) | ||||
|             self.layer4 = self._make_layer(128, stride=2) | ||||
|  | ||||
|             self.conv2 = nn.Conv2d( | ||||
|                 128 + 128 + 96 + 64, | ||||
|                 output_dim * 2, | ||||
|                 kernel_size=3, | ||||
|                 padding=1, | ||||
|                 padding_mode="zeros", | ||||
|             ) | ||||
|             self.relu2 = nn.ReLU(inplace=True) | ||||
|             self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1) | ||||
|  | ||||
|         self.dropout = None | ||||
|         if dropout > 0: | ||||
|             self.dropout = nn.Dropout2d(p=dropout) | ||||
|  | ||||
|         for m in self.modules(): | ||||
|             if isinstance(m, nn.Conv2d): | ||||
|                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") | ||||
|             elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): | ||||
|                 if m.weight is not None: | ||||
|                     nn.init.constant_(m.weight, 1) | ||||
|                 if m.bias is not None: | ||||
|                     nn.init.constant_(m.bias, 0) | ||||
|  | ||||
|     def _make_layer(self, dim, stride=1): | ||||
|         layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) | ||||
|         layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) | ||||
|         layers = (layer1, layer2) | ||||
|  | ||||
|         self.in_planes = dim | ||||
|         return nn.Sequential(*layers) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         _, _, H, W = x.shape | ||||
|  | ||||
|         x = self.conv1(x) | ||||
|         x = self.norm1(x) | ||||
|         x = self.relu1(x) | ||||
|  | ||||
|         if self.shallow: | ||||
|             a = self.layer1(x) | ||||
|             b = self.layer2(a) | ||||
|             c = self.layer3(b) | ||||
|             a = F.interpolate( | ||||
|                 a, | ||||
|                 (H // self.stride, W // self.stride), | ||||
|                 mode="bilinear", | ||||
|                 align_corners=True, | ||||
|             ) | ||||
|             b = F.interpolate( | ||||
|                 b, | ||||
|                 (H // self.stride, W // self.stride), | ||||
|                 mode="bilinear", | ||||
|                 align_corners=True, | ||||
|             ) | ||||
|             c = F.interpolate( | ||||
|                 c, | ||||
|                 (H // self.stride, W // self.stride), | ||||
|                 mode="bilinear", | ||||
|                 align_corners=True, | ||||
|             ) | ||||
|             x = self.conv2(torch.cat([a, b, c], dim=1)) | ||||
|         else: | ||||
|             a = self.layer1(x) | ||||
|             b = self.layer2(a) | ||||
|             c = self.layer3(b) | ||||
|             d = self.layer4(c) | ||||
|             a = F.interpolate( | ||||
|                 a, | ||||
|                 (H // self.stride, W // self.stride), | ||||
|                 mode="bilinear", | ||||
|                 align_corners=True, | ||||
|             ) | ||||
|             b = F.interpolate( | ||||
|                 b, | ||||
|                 (H // self.stride, W // self.stride), | ||||
|                 mode="bilinear", | ||||
|                 align_corners=True, | ||||
|             ) | ||||
|             c = F.interpolate( | ||||
|                 c, | ||||
|                 (H // self.stride, W // self.stride), | ||||
|                 mode="bilinear", | ||||
|                 align_corners=True, | ||||
|             ) | ||||
|             d = F.interpolate( | ||||
|                 d, | ||||
|                 (H // self.stride, W // self.stride), | ||||
|                 mode="bilinear", | ||||
|                 align_corners=True, | ||||
|             ) | ||||
|             x = self.conv2(torch.cat([a, b, c, d], dim=1)) | ||||
|             x = self.norm2(x) | ||||
|             x = self.relu2(x) | ||||
|             x = self.conv3(x) | ||||
|  | ||||
|         if self.training and self.dropout is not None: | ||||
|             x = self.dropout(x) | ||||
|         return x | ||||
|  | ||||
|  | ||||
| class AttnBlock(nn.Module): | ||||
|     """ | ||||
|     A standard pre-norm Transformer block (self-attention followed by an MLP), | ||||
|     adapted from the DiT block but without adaLN-Zero conditioning. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): | ||||
|         super().__init__() | ||||
|         self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) | ||||
|         self.attn = Attention( | ||||
|             hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs | ||||
|         ) | ||||
|  | ||||
|         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) | ||||
|         mlp_hidden_dim = int(hidden_size * mlp_ratio) | ||||
|         approx_gelu = lambda: nn.GELU(approximate="tanh") | ||||
|         self.mlp = Mlp( | ||||
|             in_features=hidden_size, | ||||
|             hidden_features=mlp_hidden_dim, | ||||
|             act_layer=approx_gelu, | ||||
|             drop=0, | ||||
|         ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = x + self.attn(self.norm1(x)) | ||||
|         x = x + self.mlp(self.norm2(x)) | ||||
|         return x | ||||
|  | ||||
|  | ||||
| def bilinear_sampler(img, coords, mode="bilinear", mask=False): | ||||
|     """Wrapper for grid_sample, uses pixel coordinates""" | ||||
|     H, W = img.shape[-2:] | ||||
|     xgrid, ygrid = coords.split([1, 1], dim=-1) | ||||
|     # go to 0,1 then 0,2 then -1,1 | ||||
|     xgrid = 2 * xgrid / (W - 1) - 1 | ||||
|     ygrid = 2 * ygrid / (H - 1) - 1 | ||||
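|     # e.g. with W = 64: x = 0 -> -1, x = 31.5 -> 0, x = 63 -> +1, | ||||
|     # matching grid_sample's align_corners=True convention | ||||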
|  | ||||
|     grid = torch.cat([xgrid, ygrid], dim=-1) | ||||
|     img = F.grid_sample(img, grid, align_corners=True) | ||||
|  | ||||
|     if mask: | ||||
|         mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) | ||||
|         return img, mask.float() | ||||
|  | ||||
|     return img | ||||
|  | ||||
|  | ||||
| class CorrBlock: | ||||
|     def __init__(self, fmaps, num_levels=4, radius=4): | ||||
|         B, S, C, H, W = fmaps.shape | ||||
|         self.S, self.C, self.H, self.W = S, C, H, W | ||||
|  | ||||
|         self.num_levels = num_levels | ||||
|         self.radius = radius | ||||
|         self.fmaps_pyramid = [] | ||||
|  | ||||
|         self.fmaps_pyramid.append(fmaps) | ||||
|         for i in range(self.num_levels - 1): | ||||
|             fmaps_ = fmaps.reshape(B * S, C, H, W) | ||||
|             fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2) | ||||
|             _, _, H, W = fmaps_.shape | ||||
|             fmaps = fmaps_.reshape(B, S, C, H, W) | ||||
|             self.fmaps_pyramid.append(fmaps) | ||||
|  | ||||
|     def sample(self, coords): | ||||
|         r = self.radius | ||||
|         B, S, N, D = coords.shape | ||||
|         assert D == 2 | ||||
|  | ||||
|         H, W = self.H, self.W | ||||
|         out_pyramid = [] | ||||
|         for i in range(self.num_levels): | ||||
|             corrs = self.corrs_pyramid[i]  # B, S, N, H, W | ||||
|             _, _, _, H, W = corrs.shape | ||||
|  | ||||
|             dx = torch.linspace(-r, r, 2 * r + 1) | ||||
|             dy = torch.linspace(-r, r, 2 * r + 1) | ||||
|             delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to( | ||||
|                 coords.device | ||||
|             ) | ||||
|  | ||||
|             centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i | ||||
|             delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) | ||||
|             coords_lvl = centroid_lvl + delta_lvl | ||||
|  | ||||
|             corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl) | ||||
|             corrs = corrs.view(B, S, N, -1) | ||||
|             out_pyramid.append(corrs) | ||||
|  | ||||
|         out = torch.cat(out_pyramid, dim=-1)  # B, S, N, num_levels * (2*r+1)**2 | ||||
|         return out.contiguous().float() | ||||
|  | ||||
|     def corr(self, targets): | ||||
|         B, S, N, C = targets.shape | ||||
|         assert C == self.C | ||||
|         assert S == self.S | ||||
|  | ||||
|         fmap1 = targets | ||||
|  | ||||
|         self.corrs_pyramid = [] | ||||
|         for fmaps in self.fmaps_pyramid: | ||||
|             _, _, _, H, W = fmaps.shape | ||||
|             fmap2s = fmaps.view(B, S, C, H * W) | ||||
|             corrs = torch.matmul(fmap1, fmap2s) | ||||
|             corrs = corrs.view(B, S, N, H, W) | ||||
|             corrs = corrs / torch.sqrt(torch.tensor(C).float()) | ||||
|             self.corrs_pyramid.append(corrs) | ||||
|  | ||||
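| # Usage sketch (mirrors how CoTracker.forward_iteration drives CorrBlock; shapes assumed): | ||||
| #   fcorr_fn = CorrBlock(fmaps, num_levels=4, radius=3)  # fmaps: B, S, C, H, W | ||||
| #   fcorr_fn.corr(track_feats)                           # track_feats: B, S, N, C | ||||
| #   corrs = fcorr_fn.sample(coords)                      # coords: B, S, N, 2 -> corrs: B, S, N, LRR | ||||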
|  | ||||
| class UpdateFormer(nn.Module): | ||||
|     """ | ||||
|     Transformer model that updates track estimates. | ||||
|     """ | ||||
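|     # Shape summary (inferred from forward() below): the input is (B, N, T, input_dim); | ||||
|     # time attention runs over T per track, space attention over N per frame, and | ||||
|     # flow_head maps each token to output_dim, so the output is (B, N, T, output_dim). | ||||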
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         space_depth=12, | ||||
|         time_depth=12, | ||||
|         input_dim=320, | ||||
|         hidden_size=384, | ||||
|         num_heads=8, | ||||
|         output_dim=130, | ||||
|         mlp_ratio=4.0, | ||||
|         add_space_attn=True, | ||||
|     ): | ||||
|         super().__init__() | ||||
|         self.out_channels = 2 | ||||
|         self.num_heads = num_heads | ||||
|         self.hidden_size = hidden_size | ||||
|         self.add_space_attn = add_space_attn | ||||
|         self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) | ||||
|         self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True) | ||||
|  | ||||
|         self.time_blocks = nn.ModuleList( | ||||
|             [ | ||||
|                 AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) | ||||
|                 for _ in range(time_depth) | ||||
|             ] | ||||
|         ) | ||||
|  | ||||
|         if add_space_attn: | ||||
|             self.space_blocks = nn.ModuleList( | ||||
|                 [ | ||||
|                     AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) | ||||
|                     for _ in range(space_depth) | ||||
|                 ] | ||||
|             ) | ||||
|             assert len(self.time_blocks) >= len(self.space_blocks) | ||||
|         self.initialize_weights() | ||||
|  | ||||
|     def initialize_weights(self): | ||||
|         def _basic_init(module): | ||||
|             if isinstance(module, nn.Linear): | ||||
|                 torch.nn.init.xavier_uniform_(module.weight) | ||||
|                 if module.bias is not None: | ||||
|                     nn.init.constant_(module.bias, 0) | ||||
|  | ||||
|         self.apply(_basic_init) | ||||
|  | ||||
|     def forward(self, input_tensor): | ||||
|         x = self.input_transform(input_tensor) | ||||
|  | ||||
|         j = 0 | ||||
|         for i in range(len(self.time_blocks)): | ||||
|             B, N, T, _ = x.shape | ||||
|             x_time = rearrange(x, "b n t c -> (b n) t c", b=B, t=T, n=N) | ||||
|             x_time = self.time_blocks[i](x_time) | ||||
|  | ||||
|             x = rearrange(x_time, "(b n) t c -> b n t c ", b=B, t=T, n=N) | ||||
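|             # Interleave space attention: apply one space block every | ||||
|             # len(time_blocks) // len(space_blocks) time blocks. | ||||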
|             if self.add_space_attn and ( | ||||
|                 i % (len(self.time_blocks) // len(self.space_blocks)) == 0 | ||||
|             ): | ||||
|                 x_space = rearrange(x, "b n t c -> (b t) n c ", b=B, t=T, n=N) | ||||
|                 x_space = self.space_blocks[j](x_space) | ||||
|                 x = rearrange(x_space, "(b t) n c -> b n t c  ", b=B, t=T, n=N) | ||||
|                 j += 1 | ||||
|  | ||||
|         flow = self.flow_head(x) | ||||
|         return flow | ||||
cotracker/models/core/cotracker/cotracker.py (new file, 351 lines)
							| @@ -0,0 +1,351 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from einops import rearrange | ||||
|  | ||||
| from cotracker.models.core.cotracker.blocks import ( | ||||
|     BasicEncoder, | ||||
|     CorrBlock, | ||||
|     UpdateFormer, | ||||
| ) | ||||
|  | ||||
| from cotracker.models.core.model_utils import meshgrid2d, bilinear_sample2d, smart_cat | ||||
| from cotracker.models.core.embeddings import ( | ||||
|     get_2d_embedding, | ||||
|     get_1d_sincos_pos_embed_from_grid, | ||||
|     get_2d_sincos_pos_embed, | ||||
| ) | ||||
|  | ||||
|  | ||||
| torch.manual_seed(0) | ||||
|  | ||||
|  | ||||
| def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)): | ||||
|     if grid_size == 1: | ||||
|         return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2])[ | ||||
|             None, None | ||||
|         ].cuda() | ||||
|  | ||||
|     grid_y, grid_x = meshgrid2d( | ||||
|         1, grid_size, grid_size, stack=False, norm=False, device="cuda" | ||||
|     ) | ||||
|     step = interp_shape[1] // 64 | ||||
|     if grid_center[0] != 0 or grid_center[1] != 0: | ||||
|         grid_y = grid_y - grid_size / 2.0 | ||||
|         grid_x = grid_x - grid_size / 2.0 | ||||
|     grid_y = step + grid_y.reshape(1, -1) / float(grid_size - 1) * ( | ||||
|         interp_shape[0] - step * 2 | ||||
|     ) | ||||
|     grid_x = step + grid_x.reshape(1, -1) / float(grid_size - 1) * ( | ||||
|         interp_shape[1] - step * 2 | ||||
|     ) | ||||
|  | ||||
|     grid_y = grid_y + grid_center[0] | ||||
|     grid_x = grid_x + grid_center[1] | ||||
|     xy = torch.stack([grid_x, grid_y], dim=-1).cuda() | ||||
|     return xy | ||||
|  | ||||
|  | ||||
| def sample_pos_embed(grid_size, embed_dim, coords): | ||||
|     pos_embed = get_2d_sincos_pos_embed(embed_dim=embed_dim, grid_size=grid_size) | ||||
|     pos_embed = ( | ||||
|         torch.from_numpy(pos_embed) | ||||
|         .reshape(grid_size[0], grid_size[1], embed_dim) | ||||
|         .float() | ||||
|         .unsqueeze(0) | ||||
|         .to(coords.device) | ||||
|     ) | ||||
|     sampled_pos_embed = bilinear_sample2d( | ||||
|         pos_embed.permute(0, 3, 1, 2), coords[:, 0, :, 0], coords[:, 0, :, 1] | ||||
|     ) | ||||
|     return sampled_pos_embed | ||||
|  | ||||
|  | ||||
| class CoTracker(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         S=8, | ||||
|         stride=8, | ||||
|         add_space_attn=True, | ||||
|         num_heads=8, | ||||
|         hidden_size=384, | ||||
|         space_depth=12, | ||||
|         time_depth=12, | ||||
|     ): | ||||
|         super(CoTracker, self).__init__() | ||||
|         self.S = S | ||||
|         self.stride = stride | ||||
|         self.hidden_dim = 256 | ||||
|         self.latent_dim = latent_dim = 128 | ||||
|         self.corr_levels = 4 | ||||
|         self.corr_radius = 3 | ||||
|         self.add_space_attn = add_space_attn | ||||
|         self.fnet = BasicEncoder( | ||||
|             output_dim=self.latent_dim, norm_fn="instance", dropout=0, stride=stride | ||||
|         ) | ||||
|  | ||||
|         self.updateformer = UpdateFormer( | ||||
|             space_depth=space_depth, | ||||
|             time_depth=time_depth, | ||||
|             input_dim=456, | ||||
|             hidden_size=hidden_size, | ||||
|             num_heads=num_heads, | ||||
|             output_dim=latent_dim + 2, | ||||
|             mlp_ratio=4.0, | ||||
|             add_space_attn=add_space_attn, | ||||
|         ) | ||||
|  | ||||
|         self.norm = nn.GroupNorm(1, self.latent_dim) | ||||
|         self.ffeat_updater = nn.Sequential( | ||||
|             nn.Linear(self.latent_dim, self.latent_dim), | ||||
|             nn.GELU(), | ||||
|         ) | ||||
|         self.vis_predictor = nn.Sequential( | ||||
|             nn.Linear(self.latent_dim, 1), | ||||
|         ) | ||||
|  | ||||
|     def forward_iteration( | ||||
|         self, | ||||
|         fmaps, | ||||
|         coords_init, | ||||
|         feat_init=None, | ||||
|         vis_init=None, | ||||
|         track_mask=None, | ||||
|         iters=4, | ||||
|     ): | ||||
|         B, S_init, N, D = coords_init.shape | ||||
|         assert D == 2 | ||||
|         assert B == 1 | ||||
|  | ||||
|         B, S, __, H8, W8 = fmaps.shape | ||||
|  | ||||
|         device = fmaps.device | ||||
|  | ||||
|         if S_init < S: | ||||
|             coords = torch.cat( | ||||
|                 [coords_init, coords_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1 | ||||
|             ) | ||||
|             vis_init = torch.cat( | ||||
|                 [vis_init, vis_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1 | ||||
|             ) | ||||
|         else: | ||||
|             coords = coords_init.clone() | ||||
|  | ||||
|         fcorr_fn = CorrBlock( | ||||
|             fmaps, num_levels=self.corr_levels, radius=self.corr_radius | ||||
|         ) | ||||
|  | ||||
|         ffeats = feat_init.clone() | ||||
|  | ||||
|         times_ = torch.linspace(0, S - 1, S).reshape(1, S, 1) | ||||
|  | ||||
|         pos_embed = sample_pos_embed( | ||||
|             grid_size=(H8, W8), | ||||
|             embed_dim=456, | ||||
|             coords=coords, | ||||
|         ) | ||||
|         pos_embed = rearrange(pos_embed, "b e n -> (b n) e").unsqueeze(1) | ||||
|         times_embed = ( | ||||
|             torch.from_numpy(get_1d_sincos_pos_embed_from_grid(456, times_[0]))[None] | ||||
|             .repeat(B, 1, 1) | ||||
|             .float() | ||||
|             .to(device) | ||||
|         ) | ||||
|         coord_predictions = [] | ||||
|  | ||||
|         for __ in range(iters): | ||||
|             coords = coords.detach() | ||||
|             fcorr_fn.corr(ffeats) | ||||
|  | ||||
|             fcorrs = fcorr_fn.sample(coords)  # B, S, N, LRR | ||||
|             LRR = fcorrs.shape[3] | ||||
|  | ||||
|             fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, LRR) | ||||
|             flows_ = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2) | ||||
|  | ||||
|             flows_cat = get_2d_embedding(flows_, 64, cat_coords=True) | ||||
|             ffeats_ = ffeats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) | ||||
|  | ||||
|             if track_mask.shape[1] < vis_init.shape[1]: | ||||
|                 track_mask = torch.cat( | ||||
|                     [ | ||||
|                         track_mask, | ||||
|                         torch.zeros_like(track_mask[:, 0]).repeat( | ||||
|                             1, vis_init.shape[1] - track_mask.shape[1], 1, 1 | ||||
|                         ), | ||||
|                     ], | ||||
|                     dim=1, | ||||
|                 ) | ||||
|             concat = ( | ||||
|                 torch.cat([track_mask, vis_init], dim=2) | ||||
|                 .permute(0, 2, 1, 3) | ||||
|                 .reshape(B * N, S, 2) | ||||
|             ) | ||||
|  | ||||
|             transformer_input = torch.cat([flows_cat, fcorrs_, ffeats_, concat], dim=2) | ||||
|             x = transformer_input + pos_embed + times_embed | ||||
|  | ||||
|             x = rearrange(x, "(b n) t d -> b n t d", b=B) | ||||
|  | ||||
|             delta = self.updateformer(x) | ||||
|  | ||||
|             delta = rearrange(delta, " b n t d -> (b n) t d") | ||||
|  | ||||
|             delta_coords_ = delta[:, :, :2] | ||||
|             delta_feats_ = delta[:, :, 2:] | ||||
|  | ||||
|             delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim) | ||||
|             ffeats_ = ffeats.permute(0, 2, 1, 3).reshape(B * N * S, self.latent_dim) | ||||
|  | ||||
|             ffeats_ = self.ffeat_updater(self.norm(delta_feats_)) + ffeats_ | ||||
|  | ||||
|             ffeats = ffeats_.reshape(B, N, S, self.latent_dim).permute( | ||||
|                 0, 2, 1, 3 | ||||
|             )  # B,S,N,C | ||||
|  | ||||
|             coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3) | ||||
|             coord_predictions.append(coords * self.stride) | ||||
|  | ||||
|         vis_e = self.vis_predictor(ffeats.reshape(B * S * N, self.latent_dim)).reshape( | ||||
|             B, S, N | ||||
|         ) | ||||
|         return coord_predictions, vis_e, feat_init | ||||
|  | ||||
|     def forward(self, rgbs, queries, iters=4, feat_init=None, is_train=False): | ||||
|         B, T, C, H, W = rgbs.shape | ||||
|         B, N, __ = queries.shape | ||||
|  | ||||
|         device = rgbs.device | ||||
|         assert B == 1 | ||||
|         # INIT for the first sequence | ||||
|         # We want to sort points by the first frame in which they become visible, so that they can be added to the tensor of tracked points consecutively | ||||
|         first_positive_inds = queries[:, :, 0].long() | ||||
|  | ||||
|         __, sort_inds = torch.sort(first_positive_inds[0], dim=0, descending=False) | ||||
|         inv_sort_inds = torch.argsort(sort_inds, dim=0) | ||||
|         first_positive_sorted_inds = first_positive_inds[0][sort_inds] | ||||
|  | ||||
|         assert torch.allclose( | ||||
|             first_positive_inds[0], first_positive_inds[0][sort_inds][inv_sort_inds] | ||||
|         ) | ||||
|  | ||||
|         coords_init = queries[:, :, 1:].reshape(B, 1, N, 2).repeat( | ||||
|             1, self.S, 1, 1 | ||||
|         ) / float(self.stride) | ||||
|  | ||||
|         rgbs = 2 * (rgbs / 255.0) - 1.0 | ||||
|  | ||||
|         traj_e = torch.zeros((B, T, N, 2), device=device) | ||||
|         vis_e = torch.zeros((B, T, N), device=device) | ||||
|  | ||||
|         ind_array = torch.arange(T, device=device) | ||||
|         ind_array = ind_array[None, :, None].repeat(B, 1, N) | ||||
|  | ||||
|         track_mask = (ind_array >= first_positive_inds[:, None, :]).unsqueeze(-1) | ||||
|         # these are logits, so we initialize visibility with something that would give a value close to 1 after sigmoid | ||||
|         vis_init = torch.ones((B, self.S, N, 1), device=device).float() * 10 | ||||
|  | ||||
|         ind = 0 | ||||
|  | ||||
|         track_mask_ = track_mask[:, :, sort_inds].clone() | ||||
|         coords_init_ = coords_init[:, :, sort_inds].clone() | ||||
|         vis_init_ = vis_init[:, :, sort_inds].clone() | ||||
|  | ||||
|         prev_wind_idx = 0 | ||||
|         fmaps_ = None | ||||
|         vis_predictions = [] | ||||
|         coord_predictions = [] | ||||
|         wind_inds = [] | ||||
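|         # Slide a window of S frames over the video with stride S // 2, reusing the | ||||
|         # feature maps of the overlapping half from the previous window. | ||||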
|         while ind < T - self.S // 2: | ||||
|             rgbs_seq = rgbs[:, ind : ind + self.S] | ||||
|  | ||||
|             S = S_local = rgbs_seq.shape[1] | ||||
|             if S < self.S: | ||||
|                 rgbs_seq = torch.cat( | ||||
|                     [rgbs_seq, rgbs_seq[:, -1, None].repeat(1, self.S - S, 1, 1, 1)], | ||||
|                     dim=1, | ||||
|                 ) | ||||
|                 S = rgbs_seq.shape[1] | ||||
|             rgbs_ = rgbs_seq.reshape(B * S, C, H, W) | ||||
|  | ||||
|             if fmaps_ is None: | ||||
|                 fmaps_ = self.fnet(rgbs_) | ||||
|             else: | ||||
|                 fmaps_ = torch.cat( | ||||
|                     [fmaps_[self.S // 2 :], self.fnet(rgbs_[self.S // 2 :])], dim=0 | ||||
|                 ) | ||||
|             fmaps = fmaps_.reshape( | ||||
|                 B, S, self.latent_dim, H // self.stride, W // self.stride | ||||
|             ) | ||||
|  | ||||
|             curr_wind_points = torch.nonzero(first_positive_sorted_inds < ind + self.S) | ||||
|             if curr_wind_points.shape[0] == 0: | ||||
|                 ind = ind + self.S // 2 | ||||
|                 continue | ||||
|             wind_idx = curr_wind_points[-1] + 1 | ||||
|  | ||||
|             if wind_idx - prev_wind_idx > 0: | ||||
|                 fmaps_sample = fmaps[ | ||||
|                     :, first_positive_sorted_inds[prev_wind_idx:wind_idx] - ind | ||||
|                 ] | ||||
|  | ||||
|                 feat_init_ = bilinear_sample2d( | ||||
|                     fmaps_sample, | ||||
|                     coords_init_[:, 0, prev_wind_idx:wind_idx, 0], | ||||
|                     coords_init_[:, 0, prev_wind_idx:wind_idx, 1], | ||||
|                 ).permute(0, 2, 1) | ||||
|  | ||||
|                 feat_init_ = feat_init_.unsqueeze(1).repeat(1, self.S, 1, 1) | ||||
|                 feat_init = smart_cat(feat_init, feat_init_, dim=2) | ||||
|  | ||||
|             if prev_wind_idx > 0: | ||||
|                 new_coords = coords[-1][:, self.S // 2 :] / float(self.stride) | ||||
|  | ||||
|                 coords_init_[:, : self.S // 2, :prev_wind_idx] = new_coords | ||||
|                 coords_init_[:, self.S // 2 :, :prev_wind_idx] = new_coords[ | ||||
|                     :, -1 | ||||
|                 ].repeat(1, self.S // 2, 1, 1) | ||||
|  | ||||
|                 new_vis = vis[:, self.S // 2 :].unsqueeze(-1) | ||||
|                 vis_init_[:, : self.S // 2, :prev_wind_idx] = new_vis | ||||
|                 vis_init_[:, self.S // 2 :, :prev_wind_idx] = new_vis[:, -1].repeat( | ||||
|                     1, self.S // 2, 1, 1 | ||||
|                 ) | ||||
|  | ||||
|             coords, vis, __ = self.forward_iteration( | ||||
|                 fmaps=fmaps, | ||||
|                 coords_init=coords_init_[:, :, :wind_idx], | ||||
|                 feat_init=feat_init[:, :, :wind_idx], | ||||
|                 vis_init=vis_init_[:, :, :wind_idx], | ||||
|                 track_mask=track_mask_[:, ind : ind + self.S, :wind_idx], | ||||
|                 iters=iters, | ||||
|             ) | ||||
|             if is_train: | ||||
|                 vis_predictions.append(torch.sigmoid(vis[:, :S_local])) | ||||
|                 coord_predictions.append([coord[:, :S_local] for coord in coords]) | ||||
|                 wind_inds.append(wind_idx) | ||||
|  | ||||
|             traj_e[:, ind : ind + self.S, :wind_idx] = coords[-1][:, :S_local] | ||||
|             vis_e[:, ind : ind + self.S, :wind_idx] = vis[:, :S_local] | ||||
|  | ||||
|             track_mask_[:, : ind + self.S, :wind_idx] = 0.0 | ||||
|             ind = ind + self.S // 2 | ||||
|  | ||||
|             prev_wind_idx = wind_idx | ||||
|  | ||||
|         traj_e = traj_e[:, :, inv_sort_inds] | ||||
|         vis_e = vis_e[:, :, inv_sort_inds] | ||||
|  | ||||
|         vis_e = torch.sigmoid(vis_e) | ||||
|  | ||||
|         train_data = ( | ||||
|             (vis_predictions, coord_predictions, wind_inds, sort_inds) | ||||
|             if is_train | ||||
|             else None | ||||
|         ) | ||||
|         return traj_e, feat_init, vis_e, train_data | ||||
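|  | ||||
|  | ||||
| # Minimal usage sketch (shapes and values are illustrative assumptions based on the | ||||
| # forward() signature above; note that B must be 1): | ||||
| # | ||||
| #   model = CoTracker(stride=4, S=8) | ||||
| #   video = torch.randint(0, 255, (1, 24, 3, 384, 512)).float()  # B, T, C, H, W in [0, 255] | ||||
| #   queries = torch.tensor([[[0.0, 100.0, 200.0]]])              # B, N, (t, x, y) | ||||
| #   tracks, _, visibility, _ = model(video, queries)             # tracks: B, T, N, 2 | ||||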
cotracker/models/core/cotracker/losses.py (new file, 61 lines)
							| @@ -0,0 +1,61 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| from cotracker.models.core.model_utils import reduce_masked_mean | ||||
|  | ||||
| EPS = 1e-6 | ||||
|  | ||||
|  | ||||
| def balanced_ce_loss(pred, gt, valid=None): | ||||
|     total_balanced_loss = 0.0 | ||||
|     for j in range(len(gt)): | ||||
|         B, S, N = gt[j].shape | ||||
|         # pred and gt are the same shape | ||||
|         for (a, b) in zip(pred[j].size(), gt[j].size()): | ||||
|             assert a == b  # some shape mismatch! | ||||
|         # if valid is not None: | ||||
|         for (a, b) in zip(pred[j].size(), valid[j].size()): | ||||
|             assert a == b  # some shape mismatch! | ||||
|  | ||||
|         pos = (gt[j] > 0.95).float() | ||||
|         neg = (gt[j] < 0.05).float() | ||||
|  | ||||
|         label = pos * 2.0 - 1.0 | ||||
|         a = -label * pred[j] | ||||
|         b = F.relu(a) | ||||
|         loss = b + torch.log(torch.exp(-b) + torch.exp(a - b)) | ||||
|  | ||||
|         pos_loss = reduce_masked_mean(loss, pos * valid[j]) | ||||
|         neg_loss = reduce_masked_mean(loss, neg * valid[j]) | ||||
|  | ||||
|         balanced_loss = pos_loss + neg_loss | ||||
|         total_balanced_loss += balanced_loss / float(N) | ||||
|     return total_balanced_loss | ||||
|  | ||||
|  | ||||
| def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8): | ||||
|     """Loss function defined over sequence of flow predictions""" | ||||
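|     # With gamma = 0.8 and 4 iterative predictions, the per-iteration weights are | ||||
|     # 0.8**3, 0.8**2, 0.8, 1.0 = 0.512, 0.64, 0.8, 1.0, so later refinements count more. | ||||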
|     total_flow_loss = 0.0 | ||||
|     for j in range(len(flow_gt)): | ||||
|         B, S, N, D = flow_gt[j].shape | ||||
|         assert D == 2 | ||||
|         B, S1, N = vis[j].shape | ||||
|         B, S2, N = valids[j].shape | ||||
|         assert S == S1 | ||||
|         assert S == S2 | ||||
|         n_predictions = len(flow_preds[j]) | ||||
|         flow_loss = 0.0 | ||||
|         for i in range(n_predictions): | ||||
|             i_weight = gamma ** (n_predictions - i - 1) | ||||
|             flow_pred = flow_preds[j][i] | ||||
|             i_loss = (flow_pred - flow_gt[j]).abs()  # B, S, N, 2 | ||||
|             i_loss = torch.mean(i_loss, dim=3)  # B, S, N | ||||
|             flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j]) | ||||
|         flow_loss = flow_loss / n_predictions | ||||
|         total_flow_loss += flow_loss / float(N) | ||||
|     return total_flow_loss | ||||
cotracker/models/core/embeddings.py (new file, 154 lines)
							| @@ -0,0 +1,154 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import torch | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
| def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): | ||||
|     """ | ||||
|     grid_size: int grid side length, or an (h, w) tuple of grid height and width | ||||
|     return: | ||||
|     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) | ||||
|     """ | ||||
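|     # e.g. get_2d_sincos_pos_embed(456, (48, 64)) returns an array of shape (48 * 64, 456) | ||||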
|     if isinstance(grid_size, tuple): | ||||
|         grid_size_h, grid_size_w = grid_size | ||||
|     else: | ||||
|         grid_size_h = grid_size_w = grid_size | ||||
|     grid_h = np.arange(grid_size_h, dtype=np.float32) | ||||
|     grid_w = np.arange(grid_size_w, dtype=np.float32) | ||||
|     grid = np.meshgrid(grid_w, grid_h)  # here w goes first | ||||
|     grid = np.stack(grid, axis=0) | ||||
|  | ||||
|     grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) | ||||
|     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) | ||||
|     if cls_token and extra_tokens > 0: | ||||
|         pos_embed = np.concatenate( | ||||
|             [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0 | ||||
|         ) | ||||
|     return pos_embed | ||||
|  | ||||
|  | ||||
| def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): | ||||
|     assert embed_dim % 2 == 0 | ||||
|  | ||||
|     # use half of dimensions to encode grid_h | ||||
|     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2) | ||||
|     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2) | ||||
|  | ||||
|     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D) | ||||
|     return emb | ||||
|  | ||||
|  | ||||
| def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): | ||||
|     """ | ||||
|     embed_dim: output dimension for each position | ||||
|     pos: a list of positions to be encoded: size (M,) | ||||
|     out: (M, D) | ||||
|     """ | ||||
|     assert embed_dim % 2 == 0 | ||||
|     omega = np.arange(embed_dim // 2, dtype=np.float64) | ||||
|     omega /= embed_dim / 2.0 | ||||
|     omega = 1.0 / 10000 ** omega  # (D/2,) | ||||
|  | ||||
|     pos = pos.reshape(-1)  # (M,) | ||||
|     out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product | ||||
|  | ||||
|     emb_sin = np.sin(out)  # (M, D/2) | ||||
|     emb_cos = np.cos(out)  # (M, D/2) | ||||
|  | ||||
|     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D) | ||||
|     return emb | ||||
|  | ||||
|  | ||||
| def get_2d_embedding(xy, C, cat_coords=True): | ||||
|     B, N, D = xy.shape | ||||
|     assert D == 2 | ||||
|  | ||||
|     x = xy[:, :, 0:1] | ||||
|     y = xy[:, :, 1:2] | ||||
|     div_term = ( | ||||
|         torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C) | ||||
|     ).reshape(1, 1, int(C / 2)) | ||||
|  | ||||
|     pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) | ||||
|     pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) | ||||
|  | ||||
|     pe_x[:, :, 0::2] = torch.sin(x * div_term) | ||||
|     pe_x[:, :, 1::2] = torch.cos(x * div_term) | ||||
|  | ||||
|     pe_y[:, :, 0::2] = torch.sin(y * div_term) | ||||
|     pe_y[:, :, 1::2] = torch.cos(y * div_term) | ||||
|  | ||||
|     pe = torch.cat([pe_x, pe_y], dim=2)  # B, N, C*2 | ||||
|     if cat_coords: | ||||
|         pe = torch.cat([xy, pe], dim=2)  # B, N, C*2+2 | ||||
|     return pe | ||||
|  | ||||
|  | ||||
| def get_3d_embedding(xyz, C, cat_coords=True): | ||||
|     B, N, D = xyz.shape | ||||
|     assert D == 3 | ||||
|  | ||||
|     x = xyz[:, :, 0:1] | ||||
|     y = xyz[:, :, 1:2] | ||||
|     z = xyz[:, :, 2:3] | ||||
|     div_term = ( | ||||
|         torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (1000.0 / C) | ||||
|     ).reshape(1, 1, int(C / 2)) | ||||
|  | ||||
|     pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) | ||||
|     pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) | ||||
|     pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) | ||||
|  | ||||
|     pe_x[:, :, 0::2] = torch.sin(x * div_term) | ||||
|     pe_x[:, :, 1::2] = torch.cos(x * div_term) | ||||
|  | ||||
|     pe_y[:, :, 0::2] = torch.sin(y * div_term) | ||||
|     pe_y[:, :, 1::2] = torch.cos(y * div_term) | ||||
|  | ||||
|     pe_z[:, :, 0::2] = torch.sin(z * div_term) | ||||
|     pe_z[:, :, 1::2] = torch.cos(z * div_term) | ||||
|  | ||||
|     pe = torch.cat([pe_x, pe_y, pe_z], dim=2)  # B, N, C*3 | ||||
|     if cat_coords: | ||||
|         pe = torch.cat([pe, xyz], dim=2)  # B, N, C*3+3 | ||||
|     return pe | ||||
|  | ||||
|  | ||||
| def get_4d_embedding(xyzw, C, cat_coords=True): | ||||
|     B, N, D = xyzw.shape | ||||
|     assert D == 4 | ||||
|  | ||||
|     x = xyzw[:, :, 0:1] | ||||
|     y = xyzw[:, :, 1:2] | ||||
|     z = xyzw[:, :, 2:3] | ||||
|     w = xyzw[:, :, 3:4] | ||||
|     div_term = ( | ||||
|         torch.arange(0, C, 2, device=xyzw.device, dtype=torch.float32) * (1000.0 / C) | ||||
|     ).reshape(1, 1, int(C / 2)) | ||||
|  | ||||
|     pe_x = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) | ||||
|     pe_y = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) | ||||
|     pe_z = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) | ||||
|     pe_w = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) | ||||
|  | ||||
|     pe_x[:, :, 0::2] = torch.sin(x * div_term) | ||||
|     pe_x[:, :, 1::2] = torch.cos(x * div_term) | ||||
|  | ||||
|     pe_y[:, :, 0::2] = torch.sin(y * div_term) | ||||
|     pe_y[:, :, 1::2] = torch.cos(y * div_term) | ||||
|  | ||||
|     pe_z[:, :, 0::2] = torch.sin(z * div_term) | ||||
|     pe_z[:, :, 1::2] = torch.cos(z * div_term) | ||||
|  | ||||
|     pe_w[:, :, 0::2] = torch.sin(w * div_term) | ||||
|     pe_w[:, :, 1::2] = torch.cos(w * div_term) | ||||
|  | ||||
|     pe = torch.cat([pe_x, pe_y, pe_z, pe_w], dim=2)  # B, N, C*4 | ||||
|     if cat_coords: | ||||
|         pe = torch.cat([pe, xyzw], dim=2)  # B, N, C*4+4 | ||||
|     return pe | ||||
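
A minimal sketch of how the helpers above are used, assuming this file is importable as cotracker.models.core.embeddings; the expected shapes follow the comments in the code:

    import torch

    # assumed import path, matching the file header above
    from cotracker.models.core.embeddings import get_2d_sincos_pos_embed, get_2d_embedding

    # sine-cosine table for an 8x8 feature grid with 64-dim embeddings: shape (8*8, 64)
    pos_embed = get_2d_sincos_pos_embed(embed_dim=64, grid_size=8)
    assert pos_embed.shape == (8 * 8, 64)

    # per-point sinusoidal embedding of (x, y) coordinates: B=1, N=3 points
    xy = torch.tensor([[[10.0, 20.0], [32.5, 4.0], [100.0, 7.0]]])  # (1, 3, 2)
    pe = get_2d_embedding(xy, C=64, cat_coords=True)
    assert pe.shape == (1, 3, 64 * 2 + 2)  # sin/cos of x and y, plus the raw coordinates
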
							
								
								
									
169  cotracker/models/core/model_utils.py  Normal file
									
								
							| @@ -0,0 +1,169 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import torch | ||||
|  | ||||
| EPS = 1e-6 | ||||
|  | ||||
|  | ||||
| def smart_cat(tensor1, tensor2, dim): | ||||
|     if tensor1 is None: | ||||
|         return tensor2 | ||||
|     return torch.cat([tensor1, tensor2], dim=dim) | ||||
|  | ||||
|  | ||||
| def normalize_single(d): | ||||
|     # d is a whatever shape torch tensor | ||||
|     dmin = torch.min(d) | ||||
|     dmax = torch.max(d) | ||||
|     d = (d - dmin) / (EPS + (dmax - dmin)) | ||||
|     return d | ||||
|  | ||||
|  | ||||
| def normalize(d): | ||||
|     # d is B x whatever. normalize within each element of the batch | ||||
|     out = torch.zeros(d.size()) | ||||
|     if d.is_cuda: | ||||
|         out = out.cuda() | ||||
|     B = list(d.size())[0] | ||||
|     for b in list(range(B)): | ||||
|         out[b] = normalize_single(d[b]) | ||||
|     return out | ||||
|  | ||||
|  | ||||
| def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"): | ||||
|     # returns a meshgrid sized B x Y x X | ||||
|  | ||||
|     grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device)) | ||||
|     grid_y = torch.reshape(grid_y, [1, Y, 1]) | ||||
|     grid_y = grid_y.repeat(B, 1, X) | ||||
|  | ||||
|     grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device)) | ||||
|     grid_x = torch.reshape(grid_x, [1, 1, X]) | ||||
|     grid_x = grid_x.repeat(B, Y, 1) | ||||
|  | ||||
|     if stack: | ||||
|         # note we stack in xy order | ||||
|         # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample) | ||||
|         grid = torch.stack([grid_x, grid_y], dim=-1) | ||||
|         return grid | ||||
|     else: | ||||
|         return grid_y, grid_x | ||||
|  | ||||
|  | ||||
| def reduce_masked_mean(x, mask, dim=None, keepdim=False): | ||||
|     # x and mask must have exactly the same shape (broadcasting is deliberately not allowed; the assert below enforces this) | ||||
|     # returns the mean of x over the entries where mask is nonzero | ||||
|     # dim may be a single axis or a list of axes; dim=None reduces over all elements | ||||
|     for (a, b) in zip(x.size(), mask.size()): | ||||
|         assert a == b  # some shape mismatch! | ||||
|     prod = x * mask | ||||
|     if dim is None: | ||||
|         numer = torch.sum(prod) | ||||
|         denom = EPS + torch.sum(mask) | ||||
|     else: | ||||
|         numer = torch.sum(prod, dim=dim, keepdim=keepdim) | ||||
|         denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim) | ||||
|  | ||||
|     mean = numer / denom | ||||
|     return mean | ||||
|  | ||||
|  | ||||
| def bilinear_sample2d(im, x, y, return_inbounds=False): | ||||
|     # x and y are each B, N | ||||
|     # output is B, C, N | ||||
|     if len(im.shape) == 5: | ||||
|         B, N, C, H, W = list(im.shape) | ||||
|     else: | ||||
|         B, C, H, W = list(im.shape) | ||||
|     N = list(x.shape)[1] | ||||
|  | ||||
|     x = x.float() | ||||
|     y = y.float() | ||||
|     H_f = torch.tensor(H, dtype=torch.float32) | ||||
|     W_f = torch.tensor(W, dtype=torch.float32) | ||||
|  | ||||
|     # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float() | ||||
|  | ||||
|     max_y = (H_f - 1).int() | ||||
|     max_x = (W_f - 1).int() | ||||
|  | ||||
|     x0 = torch.floor(x).int() | ||||
|     x1 = x0 + 1 | ||||
|     y0 = torch.floor(y).int() | ||||
|     y1 = y0 + 1 | ||||
|  | ||||
|     x0_clip = torch.clamp(x0, 0, max_x) | ||||
|     x1_clip = torch.clamp(x1, 0, max_x) | ||||
|     y0_clip = torch.clamp(y0, 0, max_y) | ||||
|     y1_clip = torch.clamp(y1, 0, max_y) | ||||
|     dim2 = W | ||||
|     dim1 = W * H | ||||
|  | ||||
|     base = torch.arange(0, B, dtype=torch.int64, device=x.device) * dim1 | ||||
|     base = torch.reshape(base, [B, 1]).repeat([1, N]) | ||||
|  | ||||
|     base_y0 = base + y0_clip * dim2 | ||||
|     base_y1 = base + y1_clip * dim2 | ||||
|  | ||||
|     idx_y0_x0 = base_y0 + x0_clip | ||||
|     idx_y0_x1 = base_y0 + x1_clip | ||||
|     idx_y1_x0 = base_y1 + x0_clip | ||||
|     idx_y1_x1 = base_y1 + x1_clip | ||||
|  | ||||
|     # use the indices to lookup pixels in the flat image | ||||
|     # im is B x C x H x W | ||||
|     # move C out to last dim | ||||
|     if len(im.shape) == 5: | ||||
|         im_flat = (im.permute(0, 3, 4, 1, 2)).reshape(B * H * W, N, C) | ||||
|         i_y0_x0 = torch.diagonal(im_flat[idx_y0_x0.long()], dim1=1, dim2=2).permute( | ||||
|             0, 2, 1 | ||||
|         ) | ||||
|         i_y0_x1 = torch.diagonal(im_flat[idx_y0_x1.long()], dim1=1, dim2=2).permute( | ||||
|             0, 2, 1 | ||||
|         ) | ||||
|         i_y1_x0 = torch.diagonal(im_flat[idx_y1_x0.long()], dim1=1, dim2=2).permute( | ||||
|             0, 2, 1 | ||||
|         ) | ||||
|         i_y1_x1 = torch.diagonal(im_flat[idx_y1_x1.long()], dim1=1, dim2=2).permute( | ||||
|             0, 2, 1 | ||||
|         ) | ||||
|     else: | ||||
|         im_flat = (im.permute(0, 2, 3, 1)).reshape(B * H * W, C) | ||||
|         i_y0_x0 = im_flat[idx_y0_x0.long()] | ||||
|         i_y0_x1 = im_flat[idx_y0_x1.long()] | ||||
|         i_y1_x0 = im_flat[idx_y1_x0.long()] | ||||
|         i_y1_x1 = im_flat[idx_y1_x1.long()] | ||||
|  | ||||
|     # Finally calculate interpolated values. | ||||
|     x0_f = x0.float() | ||||
|     x1_f = x1.float() | ||||
|     y0_f = y0.float() | ||||
|     y1_f = y1.float() | ||||
|  | ||||
|     w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2) | ||||
|     w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2) | ||||
|     w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2) | ||||
|     w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2) | ||||
|  | ||||
|     output = ( | ||||
|         w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1 | ||||
|     ) | ||||
|     # output is B*N x C | ||||
|     output = output.view(B, -1, C) | ||||
|     output = output.permute(0, 2, 1) | ||||
|     # output is B x C x N | ||||
|  | ||||
|     if return_inbounds: | ||||
|         x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte() | ||||
|         y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte() | ||||
|         inbounds = (x_valid & y_valid).float() | ||||
|         inbounds = inbounds.reshape( | ||||
|             B, N | ||||
|         )  # NOTE: this reshape appears to break for B > 1 (an error is raised here, or downstream if -1 is used instead) | ||||
|         return output, inbounds | ||||
|  | ||||
|     return output  # B, C, N | ||||
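
A minimal sketch of reduce_masked_mean and bilinear_sample2d in isolation, assuming this file is importable as cotracker.models.core.model_utils; the shape comments mirror the ones in the code:

    import torch

    # assumed import path, matching the file header above
    from cotracker.models.core.model_utils import reduce_masked_mean, bilinear_sample2d

    # masked mean: average x only where mask is nonzero -> (1.0 + 3.0) / 2
    x = torch.tensor([[1.0, 2.0, 3.0]])
    mask = torch.tensor([[1.0, 0.0, 1.0]])
    print(reduce_masked_mean(x, mask))  # ~2.0

    # bilinear sampling: read C-dim features at N subpixel (x, y) locations
    feat = torch.randn(1, 32, 48, 64)          # B, C, H, W
    xs = torch.tensor([[10.3, 20.7]])          # B, N
    ys = torch.tensor([[5.5, 30.1]])           # B, N
    sampled = bilinear_sample2d(feat, xs, ys)  # B, C, N
    assert sampled.shape == (1, 32, 2)
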
							
								
								
									
103  cotracker/models/evaluation_predictor.py  Normal file
									
								
							| @@ -0,0 +1,103 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| from typing import Tuple | ||||
|  | ||||
| from cotracker.models.core.cotracker.cotracker import CoTracker, get_points_on_a_grid | ||||
|  | ||||
|  | ||||
| class EvaluationPredictor(torch.nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         cotracker_model: CoTracker, | ||||
|         interp_shape: Tuple[int, int] = (384, 512), | ||||
|         grid_size: int = 6, | ||||
|         local_grid_size: int = 6, | ||||
|         single_point: bool = True, | ||||
|         n_iters: int = 6, | ||||
|     ) -> None: | ||||
|         super(EvaluationPredictor, self).__init__() | ||||
|         self.grid_size = grid_size | ||||
|         self.local_grid_size = local_grid_size | ||||
|         self.single_point = single_point | ||||
|         self.interp_shape = interp_shape | ||||
|         self.n_iters = n_iters | ||||
|  | ||||
|         self.model = cotracker_model | ||||
|         self.model.to("cuda") | ||||
|         self.model.eval() | ||||
|  | ||||
|     def forward(self, video, queries): | ||||
|         queries = queries.clone().cuda() | ||||
|         B, T, C, H, W = video.shape | ||||
|         B, N, D = queries.shape | ||||
|  | ||||
|         assert D == 3 | ||||
|         assert B == 1 | ||||
|  | ||||
|         rgbs = video.reshape(B * T, C, H, W) | ||||
|         rgbs = F.interpolate(rgbs, tuple(self.interp_shape), mode="bilinear") | ||||
|         rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]).cuda() | ||||
|  | ||||
|         queries[:, :, 1] *= self.interp_shape[1] / W | ||||
|         queries[:, :, 2] *= self.interp_shape[0] / H | ||||
|  | ||||
|         if self.single_point: | ||||
|             traj_e = torch.zeros((B, T, N, 2)).cuda() | ||||
|             vis_e = torch.zeros((B, T, N)).cuda() | ||||
|             for pind in range(N): | ||||
|                 query = queries[:, pind : pind + 1] | ||||
|  | ||||
|                 t = query[0, 0, 0].long() | ||||
|  | ||||
|                 traj_e_pind, vis_e_pind = self._process_one_point(rgbs, query) | ||||
|                 traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1] | ||||
|                 vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1] | ||||
|         else: | ||||
|             if self.grid_size > 0: | ||||
|                 xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:]) | ||||
|                 xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda() | ||||
|                 queries = torch.cat([queries, xy], dim=1) | ||||
|  | ||||
|             traj_e, __, vis_e, __ = self.model( | ||||
|                 rgbs=rgbs, | ||||
|                 queries=queries, | ||||
|                 iters=self.n_iters, | ||||
|             ) | ||||
|  | ||||
|         traj_e[:, :, :, 0] *= W / float(self.interp_shape[1]) | ||||
|         traj_e[:, :, :, 1] *= H / float(self.interp_shape[0]) | ||||
|         return traj_e, vis_e | ||||
|  | ||||
|     def _process_one_point(self, rgbs, query): | ||||
|         t = query[0, 0, 0].long() | ||||
|  | ||||
|         device = rgbs.device | ||||
|         if self.local_grid_size > 0: | ||||
|             xy_target = get_points_on_a_grid( | ||||
|                 self.local_grid_size, | ||||
|                 (50, 50), | ||||
|                 [query[0, 0, 2], query[0, 0, 1]], | ||||
|             ) | ||||
|  | ||||
|             xy_target = torch.cat( | ||||
|                 [torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2 | ||||
|             ) | ||||
|             query = torch.cat([query, xy_target], dim=1).to(device) | ||||
|  | ||||
|         if self.grid_size > 0: | ||||
|             xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:]) | ||||
|             xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda() | ||||
|             query = torch.cat([query, xy], dim=1).to(device) | ||||
|         # crop the video to start from the queried frame | ||||
|         query[0, 0, 0] = 0 | ||||
|         traj_e_pind, __, vis_e_pind, __ = self.model( | ||||
|             rgbs=rgbs[:, t:], queries=query, iters=self.n_iters | ||||
|         ) | ||||
|  | ||||
|         return traj_e_pind, vis_e_pind | ||||
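
A minimal sketch of constructing and calling this predictor, assuming a CUDA device (the module moves the model and its inputs to "cuda") and assuming build_cotracker, the checkpoint loader referenced elsewhere in this commit, is available:

    import torch

    # assumed import paths; build_cotracker is the loader used by cotracker/predictor.py
    from cotracker.models.build_cotracker import build_cotracker
    from cotracker.models.evaluation_predictor import EvaluationPredictor

    model = build_cotracker("./checkpoints/cotracker_stride_4_wind_8.pth")
    predictor = EvaluationPredictor(model, grid_size=6, single_point=False, n_iters=6)

    video = torch.randn(1, 24, 3, 256, 256)           # B, T, C, H, W
    queries = torch.tensor([[[0.0, 128.0, 64.0]]])    # B, N, 3 in (t, x, y) format
    tracks, visibilities = predictor(video, queries)  # tracks: B, T, N, 2
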
							
								
								
									
178  cotracker/predictor.py  Normal file
									
								
							| @@ -0,0 +1,178 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| from tqdm import tqdm | ||||
| from cotracker.models.core.cotracker.cotracker import get_points_on_a_grid | ||||
| from cotracker.models.core.model_utils import smart_cat | ||||
| from cotracker.models.build_cotracker import ( | ||||
|     build_cotracker, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class CoTrackerPredictor(torch.nn.Module): | ||||
|     def __init__( | ||||
|         self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth" | ||||
|     ): | ||||
|         super().__init__() | ||||
|         self.interp_shape = (384, 512) | ||||
|         self.support_grid_size = 6 | ||||
|         model = build_cotracker(checkpoint) | ||||
|  | ||||
|         self.model = model | ||||
|         self.model.to("cuda") | ||||
|         self.model.eval() | ||||
|  | ||||
|     @torch.no_grad() | ||||
|     def forward( | ||||
|         self, | ||||
|         video,  # (1, T, 3, H, W) | ||||
|         # input prompt types: | ||||
|         # - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame. | ||||
|         # *backward_tracking=True* will compute tracks in both directions. | ||||
|         # - queries. Queried points of shape (1, N, 3) in format (t, x, y) for frame index and pixel coordinates. | ||||
|         # - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask. | ||||
|         # You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks. | ||||
|         queries: torch.Tensor = None, | ||||
|         segm_mask: torch.Tensor = None,  # Segmentation mask of shape (B, 1, H, W) | ||||
|         grid_size: int = 0, | ||||
|         grid_query_frame: int = 0,  # only for dense and regular grid tracks | ||||
|         backward_tracking: bool = False, | ||||
|     ): | ||||
|  | ||||
|         if queries is None and grid_size == 0: | ||||
|             tracks, visibilities = self._compute_dense_tracks( | ||||
|                 video, | ||||
|                 grid_query_frame=grid_query_frame, | ||||
|                 backward_tracking=backward_tracking, | ||||
|             ) | ||||
|         else: | ||||
|             tracks, visibilities = self._compute_sparse_tracks( | ||||
|                 video, | ||||
|                 queries, | ||||
|                 segm_mask, | ||||
|                 grid_size, | ||||
|                 add_support_grid=(grid_size == 0 or segm_mask is not None), | ||||
|                 grid_query_frame=grid_query_frame, | ||||
|                 backward_tracking=backward_tracking, | ||||
|             ) | ||||
|  | ||||
|         return tracks, visibilities | ||||
|  | ||||
|     def _compute_dense_tracks( | ||||
|         self, video, grid_query_frame, grid_size=50, backward_tracking=False | ||||
|     ): | ||||
|         *_, H, W = video.shape | ||||
|         grid_step = W // grid_size | ||||
|         grid_width = W // grid_step | ||||
|         grid_height = H // grid_step | ||||
|         tracks = visibilities = None | ||||
|         grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to("cuda") | ||||
|         grid_pts[0, :, 0] = grid_query_frame | ||||
|         for offset in tqdm(range(grid_step * grid_step)): | ||||
|             ox = offset % grid_step | ||||
|             oy = offset // grid_step | ||||
|             grid_pts[0, :, 1] = ( | ||||
|                 torch.arange(grid_width).repeat(grid_height) * grid_step + ox | ||||
|             ) | ||||
|             grid_pts[0, :, 2] = ( | ||||
|                 torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy | ||||
|             ) | ||||
|             tracks_step, visibilities_step = self._compute_sparse_tracks( | ||||
|                 video=video, | ||||
|                 queries=grid_pts, | ||||
|                 backward_tracking=backward_tracking, | ||||
|             ) | ||||
|             tracks = smart_cat(tracks, tracks_step, dim=2) | ||||
|             visibilities = smart_cat(visibilities, visibilities_step, dim=2) | ||||
|  | ||||
|         return tracks, visibilities | ||||
|  | ||||
|     def _compute_sparse_tracks( | ||||
|         self, | ||||
|         video, | ||||
|         queries, | ||||
|         segm_mask=None, | ||||
|         grid_size=0, | ||||
|         add_support_grid=False, | ||||
|         grid_query_frame=0, | ||||
|         backward_tracking=False, | ||||
|     ): | ||||
|         B, T, C, H, W = video.shape | ||||
|         assert B == 1 | ||||
|  | ||||
|         video = video.reshape(B * T, C, H, W) | ||||
|         video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").cuda() | ||||
|         video = video.reshape( | ||||
|             B, T, 3, self.interp_shape[0], self.interp_shape[1] | ||||
|         ).cuda() | ||||
|  | ||||
|         if queries is not None: | ||||
|             queries = queries.clone() | ||||
|             B, N, D = queries.shape | ||||
|             assert D == 3 | ||||
|             queries[:, :, 1] *= self.interp_shape[1] / W | ||||
|             queries[:, :, 2] *= self.interp_shape[0] / H | ||||
|         elif grid_size > 0: | ||||
|             grid_pts = get_points_on_a_grid(grid_size, self.interp_shape) | ||||
|             if segm_mask is not None: | ||||
|                 segm_mask = F.interpolate( | ||||
|                     segm_mask, tuple(self.interp_shape), mode="nearest" | ||||
|                 ) | ||||
|                 point_mask = segm_mask[0, 0][ | ||||
|                     (grid_pts[0, :, 1]).round().long().cpu(), | ||||
|                     (grid_pts[0, :, 0]).round().long().cpu(), | ||||
|                 ].bool() | ||||
|                 grid_pts = grid_pts[:, point_mask] | ||||
|  | ||||
|             queries = torch.cat( | ||||
|                 [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts], | ||||
|                 dim=2, | ||||
|             ) | ||||
|  | ||||
|         if add_support_grid: | ||||
|             grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape) | ||||
|             grid_pts = torch.cat( | ||||
|                 [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2 | ||||
|             ) | ||||
|             queries = torch.cat([queries, grid_pts], dim=1) | ||||
|  | ||||
|         tracks, __, visibilities, __ = self.model(rgbs=video, queries=queries, iters=6) | ||||
|  | ||||
|         if backward_tracking: | ||||
|             tracks, visibilities = self._compute_backward_tracks( | ||||
|                 video, queries, tracks, visibilities | ||||
|             ) | ||||
|             if add_support_grid: | ||||
|                 queries[:, -self.support_grid_size ** 2 :, 0] = T - 1 | ||||
|         if add_support_grid: | ||||
|             tracks = tracks[:, :, : -self.support_grid_size ** 2] | ||||
|             visibilities = visibilities[:, :, : -self.support_grid_size ** 2] | ||||
|         thr = 0.9 | ||||
|         visibilities = visibilities > thr | ||||
|         tracks[:, :, :, 0] *= W / float(self.interp_shape[1]) | ||||
|         tracks[:, :, :, 1] *= H / float(self.interp_shape[0]) | ||||
|         return tracks, visibilities | ||||
|  | ||||
|     def _compute_backward_tracks(self, video, queries, tracks, visibilities): | ||||
|         inv_video = video.flip(1).clone() | ||||
|         inv_queries = queries.clone() | ||||
|         inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1 | ||||
|  | ||||
|         inv_tracks, __, inv_visibilities, __ = self.model( | ||||
|             rgbs=inv_video, queries=inv_queries, iters=6 | ||||
|         ) | ||||
|  | ||||
|         inv_tracks = inv_tracks.flip(1) | ||||
|         inv_visibilities = inv_visibilities.flip(1) | ||||
|  | ||||
|         mask = tracks == 0 | ||||
|  | ||||
|         tracks[mask] = inv_tracks[mask] | ||||
|         visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]] | ||||
|         return tracks, visibilities | ||||
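
A minimal end-to-end sketch of the predictor API above, assuming a CUDA device (the model and video are moved to "cuda" internally) and the checkpoint path used by demo.py in this commit:

    import torch

    from cotracker.predictor import CoTrackerPredictor

    model = CoTrackerPredictor(checkpoint="./checkpoints/cotracker_stride_4_wind_8.pth")
    video = torch.randn(1, 48, 3, 480, 640)  # B, T, C, H, W; real input is RGB in [0, 255]

    # a regular 10x10 grid of points tracked from frame 0
    pred_tracks, pred_visibility = model(video, grid_size=10)

    # or explicit queries in (t, x, y) format, tracked in both directions
    queries = torch.tensor([[[0.0, 400.0, 350.0], [10.0, 600.0, 500.0]]]).cuda()
    pred_tracks, pred_visibility = model(video, queries=queries, backward_tracking=True)
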
							
								
								
									
5  cotracker/utils/__init__.py  Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
							
								
								
									
291  cotracker/utils/visualizer.py  Normal file
									
								
							| @@ -0,0 +1,291 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import os | ||||
| import numpy as np | ||||
| import cv2 | ||||
| import torch | ||||
| import flow_vis | ||||
|  | ||||
| from matplotlib import cm | ||||
| import torch.nn.functional as F | ||||
| import torchvision.transforms as transforms | ||||
| from moviepy.editor import ImageSequenceClip | ||||
| from torch.utils.tensorboard import SummaryWriter | ||||
| import matplotlib.pyplot as plt | ||||
|  | ||||
|  | ||||
| class Visualizer: | ||||
|     def __init__( | ||||
|         self, | ||||
|         save_dir: str = "./results", | ||||
|         grayscale: bool = False, | ||||
|         pad_value: int = 0, | ||||
|         fps: int = 10, | ||||
|         mode: str = "rainbow",  # 'cool', 'optical_flow' | ||||
|         linewidth: int = 2, | ||||
|         show_first_frame: int = 10, | ||||
|         tracks_leave_trace: int = 0,  # -1 for infinite | ||||
|     ): | ||||
|         self.mode = mode | ||||
|         self.save_dir = save_dir | ||||
|         if mode == "rainbow": | ||||
|             self.color_map = cm.get_cmap("gist_rainbow") | ||||
|         elif mode == "cool": | ||||
|             self.color_map = cm.get_cmap(mode) | ||||
|         self.show_first_frame = show_first_frame | ||||
|         self.grayscale = grayscale | ||||
|         self.tracks_leave_trace = tracks_leave_trace | ||||
|         self.pad_value = pad_value | ||||
|         self.linewidth = linewidth | ||||
|         self.fps = fps | ||||
|  | ||||
|     def visualize( | ||||
|         self, | ||||
|         video: torch.Tensor,  # (B,T,C,H,W) | ||||
|         tracks: torch.Tensor,  # (B,T,N,2) | ||||
|         gt_tracks: torch.Tensor = None,  # (B,T,N,2) | ||||
|         segm_mask: torch.Tensor = None,  # (B,1,H,W) | ||||
|         filename: str = "video", | ||||
|         writer: SummaryWriter = None, | ||||
|         step: int = 0, | ||||
|         query_frame: int = 0, | ||||
|         save_video: bool = True, | ||||
|         compensate_for_camera_motion: bool = False, | ||||
|     ): | ||||
|         if compensate_for_camera_motion: | ||||
|             assert segm_mask is not None | ||||
|         if segm_mask is not None: | ||||
|             coords = tracks[0, query_frame].round().long() | ||||
|             segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long() | ||||
|  | ||||
|         video = F.pad( | ||||
|             video, | ||||
|             (self.pad_value, self.pad_value, self.pad_value, self.pad_value), | ||||
|             "constant", | ||||
|             255, | ||||
|         ) | ||||
|         tracks = tracks + self.pad_value | ||||
|  | ||||
|         if self.grayscale: | ||||
|             transform = transforms.Grayscale() | ||||
|             video = transform(video) | ||||
|             video = video.repeat(1, 1, 3, 1, 1) | ||||
|  | ||||
|         res_video = self.draw_tracks_on_video( | ||||
|             video=video, | ||||
|             tracks=tracks, | ||||
|             segm_mask=segm_mask, | ||||
|             gt_tracks=gt_tracks, | ||||
|             query_frame=query_frame, | ||||
|             compensate_for_camera_motion=compensate_for_camera_motion, | ||||
|         ) | ||||
|         if save_video: | ||||
|             self.save_video(res_video, filename=filename, writer=writer, step=step) | ||||
|         return res_video | ||||
|  | ||||
|     def save_video(self, video, filename, writer=None, step=0): | ||||
|         if writer is not None: | ||||
|             writer.add_video( | ||||
|                 f"{filename}_pred_track", | ||||
|                 video.to(torch.uint8), | ||||
|                 global_step=step, | ||||
|                 fps=self.fps, | ||||
|             ) | ||||
|         else: | ||||
|             os.makedirs(self.save_dir, exist_ok=True) | ||||
|             wide_list = list(video.unbind(1)) | ||||
|             wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list] | ||||
|             clip = ImageSequenceClip(wide_list[2:-1], fps=self.fps) | ||||
|  | ||||
|             # Write the video file | ||||
|             save_path = os.path.join(self.save_dir, f"{filename}_pred_track.mp4") | ||||
|             clip.write_videofile(save_path, codec="libx264", fps=self.fps, logger=None) | ||||
|  | ||||
|             print(f"Video saved to {save_path}") | ||||
|  | ||||
|     def draw_tracks_on_video( | ||||
|         self, | ||||
|         video: torch.Tensor, | ||||
|         tracks: torch.Tensor, | ||||
|         segm_mask: torch.Tensor = None, | ||||
|         gt_tracks=None, | ||||
|         query_frame: int = 0, | ||||
|         compensate_for_camera_motion=False, | ||||
|     ): | ||||
|         B, T, C, H, W = video.shape | ||||
|         _, _, N, D = tracks.shape | ||||
|  | ||||
|         assert D == 2 | ||||
|         assert C == 3 | ||||
|         video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy()  # S, H, W, C | ||||
|         tracks = tracks[0].long().detach().cpu().numpy()  # S, N, 2 | ||||
|         if gt_tracks is not None: | ||||
|             gt_tracks = gt_tracks[0].detach().cpu().numpy() | ||||
|  | ||||
|         res_video = [] | ||||
|  | ||||
|         # process input video | ||||
|         for rgb in video: | ||||
|             res_video.append(rgb.copy()) | ||||
|  | ||||
|         vector_colors = np.zeros((T, N, 3)) | ||||
|         if self.mode == "optical_flow": | ||||
|             vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None]) | ||||
|         elif segm_mask is None: | ||||
|             if self.mode == "rainbow": | ||||
|                 y_min, y_max = ( | ||||
|                     tracks[query_frame, :, 1].min(), | ||||
|                     tracks[query_frame, :, 1].max(), | ||||
|                 ) | ||||
|                 norm = plt.Normalize(y_min, y_max) | ||||
|                 for n in range(N): | ||||
|                     color = self.color_map(norm(tracks[query_frame, n, 1])) | ||||
|                     color = np.array(color[:3])[None] * 255 | ||||
|                     vector_colors[:, n] = np.repeat(color, T, axis=0) | ||||
|             else: | ||||
|                 # color changes with time | ||||
|                 for t in range(T): | ||||
|                     color = np.array(self.color_map(t / T)[:3])[None] * 255 | ||||
|                     vector_colors[t] = np.repeat(color, N, axis=0) | ||||
|         else: | ||||
|             if self.mode == "rainbow": | ||||
|                 vector_colors[:, segm_mask <= 0, :] = 255 | ||||
|  | ||||
|                 y_min, y_max = ( | ||||
|                     tracks[0, segm_mask > 0, 1].min(), | ||||
|                     tracks[0, segm_mask > 0, 1].max(), | ||||
|                 ) | ||||
|                 norm = plt.Normalize(y_min, y_max) | ||||
|                 for n in range(N): | ||||
|                     if segm_mask[n] > 0: | ||||
|                         color = self.color_map(norm(tracks[0, n, 1])) | ||||
|                         color = np.array(color[:3])[None] * 255 | ||||
|                         vector_colors[:, n] = np.repeat(color, T, axis=0) | ||||
|  | ||||
|             else: | ||||
|                 # color changes with segm class | ||||
|                 segm_mask = segm_mask.cpu() | ||||
|                 color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32) | ||||
|                 color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0 | ||||
|                 color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0 | ||||
|                 vector_colors = np.repeat(color[None], T, axis=0) | ||||
|  | ||||
|         #  draw tracks | ||||
|         if self.tracks_leave_trace != 0: | ||||
|             for t in range(1, T): | ||||
|                 first_ind = ( | ||||
|                     max(0, t - self.tracks_leave_trace) | ||||
|                     if self.tracks_leave_trace >= 0 | ||||
|                     else 0 | ||||
|                 ) | ||||
|                 curr_tracks = tracks[first_ind : t + 1] | ||||
|                 curr_colors = vector_colors[first_ind : t + 1] | ||||
|                 if compensate_for_camera_motion: | ||||
|                     diff = ( | ||||
|                         tracks[first_ind : t + 1, segm_mask <= 0] | ||||
|                         - tracks[t : t + 1, segm_mask <= 0] | ||||
|                     ).mean(1)[:, None] | ||||
|  | ||||
|                     curr_tracks = curr_tracks - diff | ||||
|                     curr_tracks = curr_tracks[:, segm_mask > 0] | ||||
|                     curr_colors = curr_colors[:, segm_mask > 0] | ||||
|  | ||||
|                 res_video[t] = self._draw_pred_tracks( | ||||
|                     res_video[t], | ||||
|                     curr_tracks, | ||||
|                     curr_colors, | ||||
|                 ) | ||||
|                 if gt_tracks is not None: | ||||
|                     res_video[t] = self._draw_gt_tracks( | ||||
|                         res_video[t], gt_tracks[first_ind : t + 1] | ||||
|                     ) | ||||
|  | ||||
|         #  draw points | ||||
|         for t in range(T): | ||||
|             for i in range(N): | ||||
|                 coord = (tracks[t, i, 0], tracks[t, i, 1]) | ||||
|                 if coord[0] != 0 and coord[1] != 0: | ||||
|                     if not compensate_for_camera_motion or ( | ||||
|                         compensate_for_camera_motion and segm_mask[i] > 0 | ||||
|                     ): | ||||
|                         cv2.circle( | ||||
|                             res_video[t], | ||||
|                             coord, | ||||
|                             int(self.linewidth * 2), | ||||
|                             vector_colors[t, i].tolist(), | ||||
|                             -1, | ||||
|                         ) | ||||
|  | ||||
|         #  construct the final rgb sequence | ||||
|         if self.show_first_frame > 0: | ||||
|             res_video = [res_video[0]] * self.show_first_frame + res_video[1:] | ||||
|         return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte() | ||||
|  | ||||
|     def _draw_pred_tracks( | ||||
|         self, | ||||
|         rgb: np.ndarray,  # H x W x 3 | ||||
|         tracks: np.ndarray,  # T x N x 2 | ||||
|         vector_colors: np.ndarray, | ||||
|         alpha: float = 0.5, | ||||
|     ): | ||||
|         T, N, _ = tracks.shape | ||||
|  | ||||
|         for s in range(T - 1): | ||||
|             vector_color = vector_colors[s] | ||||
|             original = rgb.copy() | ||||
|             alpha = (s / T) ** 2 | ||||
|             for i in range(N): | ||||
|                 coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1])) | ||||
|                 coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1])) | ||||
|                 if coord_y[0] != 0 and coord_y[1] != 0: | ||||
|                     cv2.line( | ||||
|                         rgb, | ||||
|                         coord_y, | ||||
|                         coord_x, | ||||
|                         vector_color[i].tolist(), | ||||
|                         self.linewidth, | ||||
|                         cv2.LINE_AA, | ||||
|                     ) | ||||
|             if self.tracks_leave_trace > 0: | ||||
|                 rgb = cv2.addWeighted(rgb, alpha, original, 1 - alpha, 0) | ||||
|         return rgb | ||||
|  | ||||
|     def _draw_gt_tracks( | ||||
|         self, | ||||
|         rgb: np.ndarray,  # H x W x 3 | ||||
|         gt_tracks: np.ndarray,  # T x N x 2 | ||||
|     ): | ||||
|         T, N, _ = gt_tracks.shape | ||||
|         color = np.array((211.0, 0.0, 0.0)) | ||||
|  | ||||
|         for t in range(T): | ||||
|             for i in range(N): | ||||
|                 gt_track = gt_tracks[t][i]  # use a local name so gt_tracks is not overwritten | ||||
|                 #  draw a red cross | ||||
|                 if gt_track[0] > 0 and gt_track[1] > 0: | ||||
|                     length = self.linewidth * 3 | ||||
|                     coord_y = (int(gt_track[0]) + length, int(gt_track[1]) + length) | ||||
|                     coord_x = (int(gt_track[0]) - length, int(gt_track[1]) - length) | ||||
|                     cv2.line( | ||||
|                         rgb, | ||||
|                         coord_y, | ||||
|                         coord_x, | ||||
|                         color, | ||||
|                         self.linewidth, | ||||
|                         cv2.LINE_AA, | ||||
|                     ) | ||||
|                     coord_y = (int(gt_track[0]) - length, int(gt_track[1]) + length) | ||||
|                     coord_x = (int(gt_track[0]) + length, int(gt_track[1]) - length) | ||||
|                     cv2.line( | ||||
|                         rgb, | ||||
|                         coord_y, | ||||
|                         coord_x, | ||||
|                         color, | ||||
|                         self.linewidth, | ||||
|                         cv2.LINE_AA, | ||||
|                     ) | ||||
|         return rgb | ||||
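
A minimal sketch of driving this visualizer with predictor output; the video and pred_tracks tensors are assumed to come from CoTrackerPredictor as in the sketch above:

    from cotracker.utils.visualizer import Visualizer

    vis = Visualizer(save_dir="./saved_videos", mode="rainbow", linewidth=2, tracks_leave_trace=10)
    vis.visualize(video, pred_tracks, filename="example", query_frame=0)
    # writes ./saved_videos/example_pred_track.mp4, or logs to TensorBoard when a SummaryWriter is passed
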
							
								
								
									
71  demo.py  Normal file
									
								
							| @@ -0,0 +1,71 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import os | ||||
| import torch | ||||
| import argparse | ||||
| import numpy as np | ||||
|  | ||||
| from torchvision.io import read_video | ||||
| from PIL import Image | ||||
| from cotracker.utils.visualizer import Visualizer | ||||
| from cotracker.predictor import CoTrackerPredictor | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser() | ||||
|     parser.add_argument( | ||||
|         "--video_path", | ||||
|         default="./assets/apple.mp4", | ||||
|         help="path to a video", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--mask_path", | ||||
|         default="./assets/apple_mask.png", | ||||
|         help="path to a segmentation mask", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--checkpoint", | ||||
|         default="./checkpoints/cotracker_stride_4_wind_8.pth", | ||||
|         help="cotracker model", | ||||
|     ) | ||||
|     parser.add_argument("--grid_size", type=int, default=0, help="Regular grid size") | ||||
|     parser.add_argument( | ||||
|         "--grid_query_frame", | ||||
|         type=int, | ||||
|         default=0, | ||||
|         help="Compute dense and grid tracks starting from this frame ", | ||||
|     ) | ||||
|  | ||||
|     parser.add_argument( | ||||
|         "--backward_tracking", | ||||
|         action="store_true", | ||||
|         help="Compute tracks in both directions, not only forward", | ||||
|     ) | ||||
|  | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     # load the input video and reshape it to a (B, T, C, H, W) float tensor | ||||
|     video = read_video(args.video_path) | ||||
|     video = video[0].permute(0, 3, 1, 2)[None].float() | ||||
|     segm_mask = np.array(Image.open(os.path.join(args.mask_path))) | ||||
|     segm_mask = torch.from_numpy(segm_mask)[None, None] | ||||
|  | ||||
|     model = CoTrackerPredictor(checkpoint=args.checkpoint) | ||||
|  | ||||
|     pred_tracks, pred_visibility = model( | ||||
|         video, | ||||
|         grid_size=args.grid_size, | ||||
|         grid_query_frame=args.grid_query_frame, | ||||
|         backward_tracking=args.backward_tracking, | ||||
|         # segm_mask=segm_mask | ||||
|     ) | ||||
|     print("computed") | ||||
|  | ||||
|     # save a video with predicted tracks | ||||
|     seq_name = args.video_path.split("/")[-1] | ||||
|     vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3) | ||||
|     vis.visualize(video, pred_tracks, query_frame=args.grid_query_frame) | ||||
							
								
								
									
924  notebooks/demo.ipynb  Normal file
File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
18  setup.py  Normal file
									
								
							| @@ -0,0 +1,18 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| from setuptools import find_packages, setup | ||||
|  | ||||
| setup( | ||||
|     name="cotracker", | ||||
|     version="1.0", | ||||
|     install_requires=[], | ||||
|     packages=find_packages(exclude="notebooks"), | ||||
|     extras_require={ | ||||
|         "all": ["matplotlib", "opencv-python"], | ||||
|         "dev": ["flake8", "black"], | ||||
|     }, | ||||
| ) | ||||
							
								
								
									
665  train.py  Normal file
									
								
							| @@ -0,0 +1,665 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import os | ||||
| import random | ||||
| import torch | ||||
| import signal | ||||
| import socket | ||||
| import sys | ||||
| import json | ||||
|  | ||||
| import numpy as np | ||||
| import argparse | ||||
| import logging | ||||
| from pathlib import Path | ||||
| from tqdm import tqdm | ||||
| import torch.optim as optim | ||||
| from torch.utils.data import DataLoader | ||||
| from torch.cuda.amp import GradScaler | ||||
|  | ||||
| from torch.utils.tensorboard import SummaryWriter | ||||
| from pytorch_lightning.lite import LightningLite | ||||
|  | ||||
| from cotracker.models.evaluation_predictor import EvaluationPredictor | ||||
| from cotracker.models.core.cotracker.cotracker import CoTracker | ||||
| from cotracker.utils.visualizer import Visualizer | ||||
| from cotracker.datasets.tap_vid_datasets import TapVidDataset | ||||
| from cotracker.datasets.badja_dataset import BadjaDataset | ||||
| from cotracker.datasets.fast_capture_dataset import FastCaptureDataset | ||||
| from cotracker.evaluation.core.evaluator import Evaluator | ||||
| from cotracker.datasets import kubric_movif_dataset | ||||
| from cotracker.datasets.utils import collate_fn, collate_fn_train, dataclass_to_cuda_ | ||||
| from cotracker.models.core.cotracker.losses import sequence_loss, balanced_ce_loss | ||||
|  | ||||
|  | ||||
| # define the handler function | ||||
| # for training on a slurm cluster | ||||
| def sig_handler(signum, frame): | ||||
|     print("caught signal", signum) | ||||
|     print(socket.gethostname(), "USR1 signal caught.") | ||||
|     # do other stuff to cleanup here | ||||
|     print("requeuing job " + os.environ["SLURM_JOB_ID"]) | ||||
|     os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"]) | ||||
|     sys.exit(-1) | ||||
|  | ||||
|  | ||||
| def term_handler(signum, frame): | ||||
|     print("bypassing sigterm", flush=True) | ||||
|  | ||||
|  | ||||
| def fetch_optimizer(args, model): | ||||
|     """Create the optimizer and learning rate scheduler""" | ||||
|     optimizer = optim.AdamW( | ||||
|         model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8 | ||||
|     ) | ||||
|     scheduler = optim.lr_scheduler.OneCycleLR( | ||||
|         optimizer, | ||||
|         args.lr, | ||||
|         args.num_steps + 100, | ||||
|         pct_start=0.05, | ||||
|         cycle_momentum=False, | ||||
|         anneal_strategy="linear", | ||||
|     ) | ||||
|  | ||||
|     return optimizer, scheduler | ||||
|  | ||||
|  | ||||
| def forward_batch(batch, model, args, loss_fn=None, writer=None, step=0): | ||||
|     rgbs = batch.video | ||||
|     trajs_g = batch.trajectory | ||||
|     vis_g = batch.visibility | ||||
|     valids = batch.valid | ||||
|     B, T, C, H, W = rgbs.shape | ||||
|     assert C == 3 | ||||
|     B, T, N, D = trajs_g.shape | ||||
|     device = rgbs.device | ||||
|  | ||||
|     __, first_positive_inds = torch.max(vis_g, dim=1) | ||||
|     # We want to make sure that during training the model sees visible points | ||||
|     # that it does not need to track just yet: they are visible but queried from a later frame | ||||
|     N_rand = N // 4 | ||||
|     # for each point, the frame indices in which it is visible (batch element 0) | ||||
|     nonzero_inds = [torch.nonzero(vis_g[0, :, i]) for i in range(N)] | ||||
|     rand_vis_inds = torch.cat( | ||||
|         [ | ||||
|             nonzero_row[torch.randint(len(nonzero_row), size=(1,))] | ||||
|             for nonzero_row in nonzero_inds | ||||
|         ], | ||||
|         dim=1, | ||||
|     ) | ||||
|     first_positive_inds = torch.cat( | ||||
|         [rand_vis_inds[:, :N_rand], first_positive_inds[:, N_rand:]], dim=1 | ||||
|     ) | ||||
|     ind_array_ = torch.arange(T, device=device) | ||||
|     ind_array_ = ind_array_[None, :, None].repeat(B, 1, N) | ||||
|     assert torch.allclose( | ||||
|         vis_g[ind_array_ == first_positive_inds[:, None, :]], | ||||
|         torch.ones_like(vis_g), | ||||
|     ) | ||||
|     assert torch.allclose( | ||||
|         vis_g[ind_array_ == rand_vis_inds[:, None, :]], torch.ones_like(vis_g) | ||||
|     ) | ||||
|  | ||||
|     gather = torch.gather( | ||||
|         trajs_g, 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, 2) | ||||
|     ) | ||||
|     xys = torch.diagonal(gather, dim1=1, dim2=2).permute(0, 2, 1) | ||||
|  | ||||
|     queries = torch.cat([first_positive_inds[:, :, None], xys], dim=2) | ||||
|  | ||||
|     predictions, __, visibility, train_data = model( | ||||
|         rgbs=rgbs, queries=queries, iters=args.train_iters, is_train=True | ||||
|     ) | ||||
|     vis_predictions, coord_predictions, wind_inds, sort_inds = train_data | ||||
|  | ||||
|     trajs_g = trajs_g[:, :, sort_inds] | ||||
|     vis_g = vis_g[:, :, sort_inds] | ||||
|     valids = valids[:, :, sort_inds] | ||||
|  | ||||
|     vis_gts = [] | ||||
|     traj_gts = [] | ||||
|     valids_gts = [] | ||||
|  | ||||
|     for i, wind_idx in enumerate(wind_inds): | ||||
|         ind = i * (args.sliding_window_len // 2) | ||||
|  | ||||
|         vis_gts.append(vis_g[:, ind : ind + args.sliding_window_len, :wind_idx]) | ||||
|         traj_gts.append(trajs_g[:, ind : ind + args.sliding_window_len, :wind_idx]) | ||||
|         valids_gts.append(valids[:, ind : ind + args.sliding_window_len, :wind_idx]) | ||||
|  | ||||
|     seq_loss = sequence_loss(coord_predictions, traj_gts, vis_gts, valids_gts, 0.8) | ||||
|     vis_loss = balanced_ce_loss(vis_predictions, vis_gts, valids_gts) | ||||
|  | ||||
|     output = {"flow": {"predictions": predictions[0].detach()}} | ||||
|     output["flow"]["loss"] = seq_loss.mean() | ||||
|     output["visibility"] = { | ||||
|         "loss": vis_loss.mean() * 10.0, | ||||
|         "predictions": visibility[0].detach(), | ||||
|     } | ||||
|     return output | ||||
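
The slicing above assumes one prediction set per sliding window, with consecutive windows overlapping by half the window length; as a small illustration with a hypothetical sliding_window_len of 8 and three windows, the ground-truth slices start at frames 0, 4 and 8:

    S = 8                                      # hypothetical args.sliding_window_len
    starts = [i * (S // 2) for i in range(3)]  # one entry per window returned by the model
    print(starts)                              # [0, 4, 8] -> time slices [0:8], [4:12], [8:16]
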
|  | ||||
|  | ||||
| def run_test_eval(evaluator, model, dataloaders, writer, step): | ||||
|     model.eval() | ||||
|     for ds_name, dataloader in dataloaders: | ||||
|         predictor = EvaluationPredictor( | ||||
|             model.module.module, | ||||
|             grid_size=6, | ||||
|             local_grid_size=0, | ||||
|             single_point=False, | ||||
|             n_iters=6, | ||||
|         ) | ||||
|  | ||||
|         metrics = evaluator.evaluate_sequence( | ||||
|             model=predictor, | ||||
|             test_dataloader=dataloader, | ||||
|             dataset_name=ds_name, | ||||
|             train_mode=True, | ||||
|             writer=writer, | ||||
|             step=step, | ||||
|         ) | ||||
|  | ||||
|         if ds_name == "badja" or ds_name == "fastcapture" or ("kubric" in ds_name): | ||||
|  | ||||
|             metrics = { | ||||
|                 **{ | ||||
|                     f"{ds_name}_avg": np.mean( | ||||
|                         [v for k, v in metrics.items() if "accuracy" not in k] | ||||
|                     ) | ||||
|                 }, | ||||
|                 **{ | ||||
|                     f"{ds_name}_avg_accuracy": np.mean( | ||||
|                         [v for k, v in metrics.items() if "accuracy" in k] | ||||
|                     ) | ||||
|                 }, | ||||
|             } | ||||
|             print("avg", np.mean([v for v in metrics.values()])) | ||||
|  | ||||
|         if "tapvid" in ds_name: | ||||
|             metrics = { | ||||
|                 f"{ds_name}_avg_OA": metrics["avg"]["occlusion_accuracy"] * 100, | ||||
|                 f"{ds_name}_avg_delta": metrics["avg"]["average_pts_within_thresh"] | ||||
|                 * 100, | ||||
|                 f"{ds_name}_avg_Jaccard": metrics["avg"]["average_jaccard"] * 100, | ||||
|             } | ||||
|  | ||||
|         writer.add_scalars(f"Eval", metrics, step) | ||||
|  | ||||
|  | ||||
| class Logger: | ||||
|  | ||||
|     SUM_FREQ = 100 | ||||
|  | ||||
|     def __init__(self, model, scheduler): | ||||
|         self.model = model | ||||
|         self.scheduler = scheduler | ||||
|         self.total_steps = 0 | ||||
|         self.running_loss = {} | ||||
|         self.writer = SummaryWriter(log_dir=os.path.join(args.ckpt_path, "runs")) | ||||
|  | ||||
|     def _print_training_status(self): | ||||
|         metrics_data = [ | ||||
|             self.running_loss[k] / Logger.SUM_FREQ | ||||
|             for k in sorted(self.running_loss.keys()) | ||||
|         ] | ||||
|         training_str = "[{:6d}] ".format(self.total_steps + 1) | ||||
|         metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data) | ||||
|  | ||||
|         # print the training status | ||||
|         logging.info( | ||||
|             f"Training Metrics ({self.total_steps}): {training_str + metrics_str}" | ||||
|         ) | ||||
|  | ||||
|         if self.writer is None: | ||||
|             self.writer = SummaryWriter(log_dir=os.path.join(args.ckpt_path, "runs")) | ||||
|  | ||||
|         for k in self.running_loss: | ||||
|             self.writer.add_scalar( | ||||
|                 k, self.running_loss[k] / Logger.SUM_FREQ, self.total_steps | ||||
|             ) | ||||
|             self.running_loss[k] = 0.0 | ||||
|  | ||||
|     def push(self, metrics, task): | ||||
|         self.total_steps += 1 | ||||
|  | ||||
|         for key in metrics: | ||||
|             task_key = str(key) + "_" + task | ||||
|             if task_key not in self.running_loss: | ||||
|                 self.running_loss[task_key] = 0.0 | ||||
|  | ||||
|             self.running_loss[task_key] += metrics[key] | ||||
|  | ||||
|         if self.total_steps % Logger.SUM_FREQ == Logger.SUM_FREQ - 1: | ||||
|             self._print_training_status() | ||||
|             self.running_loss = {} | ||||
|  | ||||
|     def write_dict(self, results): | ||||
|         if self.writer is None: | ||||
|             self.writer = SummaryWriter(log_dir=os.path.join(args.ckpt_path, "runs")) | ||||
|  | ||||
|         for key in results: | ||||
|             self.writer.add_scalar(key, results[key], self.total_steps) | ||||
|  | ||||
|     def close(self): | ||||
|         self.writer.close() | ||||
|  | ||||
|  | ||||
| class Lite(LightningLite): | ||||
|     def run(self, args): | ||||
|         def seed_everything(seed: int): | ||||
|             random.seed(seed) | ||||
|             os.environ["PYTHONHASHSEED"] = str(seed) | ||||
|             np.random.seed(seed) | ||||
|             torch.manual_seed(seed) | ||||
|             torch.cuda.manual_seed(seed) | ||||
|             torch.backends.cudnn.deterministic = True | ||||
|             torch.backends.cudnn.benchmark = False | ||||
|  | ||||
|         seed_everything(0) | ||||
|  | ||||
|         def seed_worker(worker_id): | ||||
|             worker_seed = torch.initial_seed() % 2 ** 32 | ||||
|             np.random.seed(worker_seed) | ||||
|             random.seed(worker_seed) | ||||
|  | ||||
|         g = torch.Generator() | ||||
|         g.manual_seed(0) | ||||
|  | ||||
|         eval_dataloaders = [] | ||||
|         if "badja" in args.eval_datasets: | ||||
|             eval_dataset = BadjaDataset( | ||||
|                 data_root=os.path.join(args.dataset_root, "BADJA"), | ||||
|                 max_seq_len=args.eval_max_seq_len, | ||||
|                 dataset_resolution=args.crop_size, | ||||
|             ) | ||||
|             eval_dataloader_badja = torch.utils.data.DataLoader( | ||||
|                 eval_dataset, | ||||
|                 batch_size=1, | ||||
|                 shuffle=False, | ||||
|                 num_workers=8, | ||||
|                 collate_fn=collate_fn, | ||||
|             ) | ||||
|             eval_dataloaders.append(("badja", eval_dataloader_badja)) | ||||
|  | ||||
|         if "fastcapture" in args.eval_datasets: | ||||
|             eval_dataset = FastCaptureDataset( | ||||
|                 data_root=os.path.join(args.dataset_root, "fastcapture"), | ||||
|                 max_seq_len=min(100, args.eval_max_seq_len), | ||||
|                 max_num_points=40, | ||||
|                 dataset_resolution=args.crop_size, | ||||
|             ) | ||||
|             eval_dataloader_fastcapture = torch.utils.data.DataLoader( | ||||
|                 eval_dataset, | ||||
|                 batch_size=1, | ||||
|                 shuffle=False, | ||||
|                 num_workers=1, | ||||
|                 collate_fn=collate_fn, | ||||
|             ) | ||||
|             eval_dataloaders.append(("fastcapture", eval_dataloader_fastcapture)) | ||||
|  | ||||
|         if "tapvid_davis_first" in args.eval_datasets: | ||||
|             data_root = os.path.join( | ||||
|                 args.dataset_root, "tapvid_davis/tapvid_davis.pkl" | ||||
|             ) | ||||
|             eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root) | ||||
|             eval_dataloader_tapvid_davis = torch.utils.data.DataLoader( | ||||
|                 eval_dataset, | ||||
|                 batch_size=1, | ||||
|                 shuffle=False, | ||||
|                 num_workers=1, | ||||
|                 collate_fn=collate_fn, | ||||
|             ) | ||||
|             eval_dataloaders.append(("tapvid_davis", eval_dataloader_tapvid_davis)) | ||||
|  | ||||
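|         # Evaluator computes benchmark metrics; Visualizer renders predicted tracks | ||||
|         # and logs them through the TensorBoard writer. | ||||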
|         evaluator = Evaluator(args.ckpt_path) | ||||
|  | ||||
|         visualizer = Visualizer( | ||||
|             save_dir=args.ckpt_path, | ||||
|             pad_value=80, | ||||
|             fps=1, | ||||
|             show_first_frame=0, | ||||
|             tracks_leave_trace=0, | ||||
|         ) | ||||
|  | ||||
|         loss_fn = None | ||||
|  | ||||
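|         # Build the model; only the CoTracker architecture is supported here. | ||||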
|         if args.model_name == "cotracker": | ||||
|  | ||||
|             model = CoTracker( | ||||
|                 stride=args.model_stride, | ||||
|                 S=args.sliding_window_len, | ||||
|                 add_space_attn=not args.remove_space_attn, | ||||
|                 num_heads=args.updateformer_num_heads, | ||||
|                 hidden_size=args.updateformer_hidden_size, | ||||
|                 space_depth=args.updateformer_space_depth, | ||||
|                 time_depth=args.updateformer_time_depth, | ||||
|             ) | ||||
|         else: | ||||
|             raise ValueError(f"Model {args.model_name} doesn't exist") | ||||
|  | ||||
|         with open(os.path.join(args.ckpt_path, "meta.json"), "w") as file: | ||||
|             json.dump(vars(args), file, sort_keys=True, indent=4) | ||||
|  | ||||
|         model.cuda() | ||||
|  | ||||
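|         # Training data: Kubric MOVi-F clips with ground-truth point trajectories. | ||||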
|         train_dataset = kubric_movif_dataset.KubricMovifDataset( | ||||
|             data_root=os.path.join(args.dataset_root, "kubric_movi_f"), | ||||
|             crop_size=args.crop_size, | ||||
|             seq_len=args.sequence_len, | ||||
|             traj_per_sample=args.traj_per_sample, | ||||
|             sample_vis_1st_frame=args.sample_vis_1st_frame, | ||||
|             use_augs=not args.dont_use_augs, | ||||
|         ) | ||||
|  | ||||
|         train_loader = DataLoader( | ||||
|             train_dataset, | ||||
|             batch_size=args.batch_size, | ||||
|             shuffle=True, | ||||
|             num_workers=args.num_workers, | ||||
|             worker_init_fn=seed_worker, | ||||
|             generator=g, | ||||
|             pin_memory=True, | ||||
|             collate_fn=collate_fn_train, | ||||
|             drop_last=True, | ||||
|         ) | ||||
|  | ||||
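|         # Let Lite shard the train loader across processes; batches are moved to | ||||
|         # the GPU manually later (move_to_device=False). | ||||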
|         train_loader = self.setup_dataloaders(train_loader, move_to_device=False) | ||||
|         print("LEN TRAIN LOADER", len(train_loader)) | ||||
|         optimizer, scheduler = fetch_optimizer(args, model) | ||||
|  | ||||
|         total_steps = 0 | ||||
|         logger = Logger(model, scheduler) | ||||
|  | ||||
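|         # Resume from the newest checkpoint in ckpt_path if one exists, | ||||
|         # otherwise optionally warm-start weights from --restore_ckpt. | ||||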
|         folder_ckpts = [ | ||||
|             f | ||||
|             for f in os.listdir(args.ckpt_path) | ||||
|             if not os.path.isdir(os.path.join(args.ckpt_path, f)) | ||||
|             and f.endswith(".pth") and "final" not in f | ||||
|         ] | ||||
|         if len(folder_ckpts) > 0: | ||||
|             ckpt_path = sorted(folder_ckpts)[-1] | ||||
|             ckpt = self.load(os.path.join(args.ckpt_path, ckpt_path)) | ||||
|             logging.info(f"Loading checkpoint {ckpt_path}") | ||||
|             if "model" in ckpt: | ||||
|                 model.load_state_dict(ckpt["model"]) | ||||
|             else: | ||||
|                 model.load_state_dict(ckpt) | ||||
|             if "optimizer" in ckpt: | ||||
|                 logging.info("Load optimizer") | ||||
|                 optimizer.load_state_dict(ckpt["optimizer"]) | ||||
|             if "scheduler" in ckpt: | ||||
|                 logging.info("Load scheduler") | ||||
|                 scheduler.load_state_dict(ckpt["scheduler"]) | ||||
|             if "total_steps" in ckpt: | ||||
|                 total_steps = ckpt["total_steps"] | ||||
|                 logging.info(f"Load total_steps {total_steps}") | ||||
|  | ||||
|         elif args.restore_ckpt is not None: | ||||
|             assert args.restore_ckpt.endswith(".pth") or args.restore_ckpt.endswith( | ||||
|                 ".pt" | ||||
|             ) | ||||
|             logging.info("Loading checkpoint...") | ||||
|  | ||||
|             strict = True | ||||
|             state_dict = self.load(args.restore_ckpt) | ||||
|             if "model" in state_dict: | ||||
|                 state_dict = state_dict["model"] | ||||
|  | ||||
|             if list(state_dict.keys())[0].startswith("module."): | ||||
|                 state_dict = { | ||||
|                     k.replace("module.", ""): v for k, v in state_dict.items() | ||||
|                 } | ||||
|             model.load_state_dict(state_dict, strict=strict) | ||||
|  | ||||
|             logging.info("Done loading checkpoint") | ||||
|         model, optimizer = self.setup(model, optimizer, move_to_device=False) | ||||
|         # model.cuda() | ||||
|         model.train() | ||||
|  | ||||
|         save_freq = args.save_freq | ||||
|         scaler = GradScaler(enabled=args.mixed_precision) | ||||
|  | ||||
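|         # Main training loop: iterate over epochs until num_steps is exceeded. | ||||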
|         should_keep_training = True | ||||
|         global_batch_num = 0 | ||||
|         epoch = -1 | ||||
|  | ||||
|         while should_keep_training: | ||||
|             epoch += 1 | ||||
|             for i_batch, batch in enumerate(tqdm(train_loader)): | ||||
|                 batch, gotit = batch | ||||
|                 if not all(gotit): | ||||
|                     print("invalid batch, skipping") | ||||
|                     continue | ||||
|                 dataclass_to_cuda_(batch) | ||||
|  | ||||
|                 optimizer.zero_grad() | ||||
|  | ||||
|                 assert model.training | ||||
|  | ||||
|                 output = forward_batch( | ||||
|                     batch, | ||||
|                     model, | ||||
|                     args, | ||||
|                     loss_fn=loss_fn, | ||||
|                     writer=logger.writer, | ||||
|                     step=total_steps, | ||||
|                 ) | ||||
|  | ||||
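|                 # Sum the per-task losses and push their metrics to the logger. | ||||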
|                 loss = 0 | ||||
|                 for k, v in output.items(): | ||||
|                     if "loss" in v: | ||||
|                         loss += v["loss"] | ||||
|                         logger.writer.add_scalar( | ||||
|                             f"live_{k}_loss", v["loss"].item(), total_steps | ||||
|                         ) | ||||
|                     if "metrics" in v: | ||||
|                         logger.push(v["metrics"], k) | ||||
|  | ||||
|                 if self.global_rank == 0: | ||||
|                     if total_steps % save_freq == save_freq - 1: | ||||
|                         if args.model_name == "motion_diffuser": | ||||
|                             pred_coords = model.module.module.forward_batch_test( | ||||
|                                 batch, interp_shape=args.crop_size | ||||
|                             ) | ||||
|  | ||||
|                             output["flow"] = {"predictions": pred_coords[0].detach()} | ||||
|                         visualizer.visualize( | ||||
|                             video=batch.video.clone(), | ||||
|                             tracks=batch.trajectory.clone(), | ||||
|                             filename="train_gt_traj", | ||||
|                             writer=logger.writer, | ||||
|                             step=total_steps, | ||||
|                         ) | ||||
|  | ||||
|                         visualizer.visualize( | ||||
|                             video=batch.video.clone(), | ||||
|                             tracks=output["flow"]["predictions"][None], | ||||
|                             filename="train_pred_traj", | ||||
|                             writer=logger.writer, | ||||
|                             step=total_steps, | ||||
|                         ) | ||||
|  | ||||
|                     if len(output) > 1: | ||||
|                         logger.writer.add_scalar( | ||||
|                             "live_total_loss", loss.item(), total_steps | ||||
|                         ) | ||||
|                     logger.writer.add_scalar( | ||||
|                         "learning_rate", optimizer.param_groups[0]["lr"], total_steps | ||||
|                     ) | ||||
|                     global_batch_num += 1 | ||||
|  | ||||
|                 self.barrier() | ||||
|  | ||||
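|                 # Mixed-precision backward: scale the loss, unscale before gradient | ||||
|                 # clipping, then step the optimizer/scheduler and update the scaler. | ||||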
|                 self.backward(scaler.scale(loss)) | ||||
|  | ||||
|                 scaler.unscale_(optimizer) | ||||
|                 torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0) | ||||
|  | ||||
|                 scaler.step(optimizer) | ||||
|                 scheduler.step() | ||||
|                 scaler.update() | ||||
|                 total_steps += 1 | ||||
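|                 # At the end of each epoch (or at step 1 with --validate_at_start), | ||||
|                 # rank 0 saves a checkpoint and runs evaluation. | ||||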
|                 if self.global_rank == 0: | ||||
|                     if (i_batch >= len(train_loader) - 1) or ( | ||||
|                         total_steps == 1 and args.validate_at_start | ||||
|                     ): | ||||
|                         if (epoch + 1) % args.save_every_n_epoch == 0: | ||||
|                             ckpt_iter = f"{total_steps:06d}" | ||||
|                             save_path = Path( | ||||
|                                 f"{args.ckpt_path}/model_{args.model_name}_{ckpt_iter}.pth" | ||||
|                             ) | ||||
|  | ||||
|                             save_dict = { | ||||
|                                 "model": model.module.module.state_dict(), | ||||
|                                 "optimizer": optimizer.state_dict(), | ||||
|                                 "scheduler": scheduler.state_dict(), | ||||
|                                 "total_steps": total_steps, | ||||
|                             } | ||||
|  | ||||
|                             logging.info(f"Saving file {save_path}") | ||||
|                             self.save(save_dict, save_path) | ||||
|  | ||||
|                         if (epoch + 1) % args.evaluate_every_n_epoch == 0 or ( | ||||
|                             args.validate_at_start and epoch == 0 | ||||
|                         ): | ||||
|                             run_test_eval( | ||||
|                                 evaluator, | ||||
|                                 model, | ||||
|                                 eval_dataloaders, | ||||
|                                 logger.writer, | ||||
|                                 total_steps, | ||||
|                             ) | ||||
|                             model.train() | ||||
|                             torch.cuda.empty_cache() | ||||
|  | ||||
|                 self.barrier() | ||||
|                 if total_steps > args.num_steps: | ||||
|                     should_keep_training = False | ||||
|                     break | ||||
|  | ||||
|         print("FINISHED TRAINING") | ||||
|  | ||||
|         PATH = f"{args.ckpt_path}/{args.model_name}_final.pth" | ||||
|         torch.save(model.module.module.state_dict(), PATH) | ||||
|         run_test_eval(evaluator, model, eval_dataloaders, logger.writer, total_steps) | ||||
|         logger.close() | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     signal.signal(signal.SIGUSR1, sig_handler) | ||||
|     signal.signal(signal.SIGTERM, term_handler) | ||||
|     parser = argparse.ArgumentParser() | ||||
|     parser.add_argument("--model_name", default="cotracker", help="model name") | ||||
|     parser.add_argument("--restore_ckpt", help="restore checkpoint") | ||||
|     parser.add_argument("--ckpt_path", help="path to save checkpoints and logs") | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=4, help="batch size used during training." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--num_workers", type=int, default=6, help="number of dataloader workers" | ||||
|     ) | ||||
|  | ||||
|     parser.add_argument( | ||||
|         "--mixed_precision", action="store_true", help="use mixed precision" | ||||
|     ) | ||||
|     parser.add_argument("--lr", type=float, default=0.0005, help="max learning rate.") | ||||
|     parser.add_argument( | ||||
|         "--wdecay", type=float, default=0.00001, help="Weight decay in optimizer." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--num_steps", type=int, default=200000, help="length of training schedule." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--evaluate_every_n_epoch", | ||||
|         type=int, | ||||
|         default=1, | ||||
|         help="run evaluation every n-th epoch", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_every_n_epoch", | ||||
|         type=int, | ||||
|         default=1, | ||||
|         help="save a checkpoint every n-th epoch", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--validate_at_start", | ||||
|         action="store_true", | ||||
|         help="run evaluation before training starts", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_freq", type=int, default=100, help="visualization frequency in steps" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--traj_per_sample", type=int, default=768, help="trajectories per sample" | ||||
|     ) | ||||
|     parser.add_argument("--dataset_root", type=str, help="path to all the datasets") | ||||
|  | ||||
|     parser.add_argument( | ||||
|         "--train_iters", | ||||
|         type=int, | ||||
|         default=4, | ||||
|         help="number of iterative track updates in each forward pass.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--sequence_len", type=int, default=8, help="train sequence length" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--eval_datasets", | ||||
|         nargs="+", | ||||
|         default=["things", "badja", "fastcapture"], | ||||
|         help="eval datasets.", | ||||
|     ) | ||||
|  | ||||
|     parser.add_argument( | ||||
|         "--remove_space_attn", action="store_true", help="remove spatial attention" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--dont_use_augs", action="store_true", help="disable training augmentations" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--sample_vis_1st_frame", | ||||
|         action="store_true", | ||||
|         help="only sample points visible in the first frame", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--sliding_window_len", type=int, default=8, help="sliding window length" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--updateformer_hidden_size", | ||||
|         type=int, | ||||
|         default=384, | ||||
|         help="hidden size of the update transformer", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--updateformer_num_heads", | ||||
|         type=int, | ||||
|         default=8, | ||||
|         help="number of attention heads in the update transformer", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--updateformer_space_depth", | ||||
|         type=int, | ||||
|         default=12, | ||||
|         help="number of spatial attention layers", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--updateformer_time_depth", | ||||
|         type=int, | ||||
|         default=12, | ||||
|         help="number of temporal attention layers", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--model_stride", type=int, default=8, help="stride of the feature network" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--crop_size", | ||||
|         type=int, | ||||
|         nargs="+", | ||||
|         default=[384, 512], | ||||
|         help="crop videos to this resolution during training", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--eval_max_seq_len", type=int, default=1000, help="max eval sequence length" | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     logging.basicConfig( | ||||
|         level=logging.INFO, | ||||
|         format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s", | ||||
|     ) | ||||
|  | ||||
|     Path(args.ckpt_path).mkdir(exist_ok=True, parents=True) | ||||
|     from pytorch_lightning.strategies import DDPStrategy | ||||
|  | ||||
|     Lite( | ||||
|         strategy=DDPStrategy(find_unused_parameters=True), | ||||
|         devices="auto", | ||||
|         accelerator="gpu", | ||||
|         precision=32, | ||||
|         num_nodes=4, | ||||
|     ).run(args) | ||||