Initial commit

CODE_OF_CONDUCT.md (new file, 80 lines)
							| @@ -0,0 +1,80 @@ | |||||||
|  | # Code of Conduct | ||||||
|  |  | ||||||
|  | ## Our Pledge | ||||||
|  |  | ||||||
|  | In the interest of fostering an open and welcoming environment, we as | ||||||
|  | contributors and maintainers pledge to make participation in our project and | ||||||
|  | our community a harassment-free experience for everyone, regardless of age, body | ||||||
|  | size, disability, ethnicity, sex characteristics, gender identity and expression, | ||||||
|  | level of experience, education, socio-economic status, nationality, personal | ||||||
|  | appearance, race, religion, or sexual identity and orientation. | ||||||
|  |  | ||||||
|  | ## Our Standards | ||||||
|  |  | ||||||
|  | Examples of behavior that contributes to creating a positive environment | ||||||
|  | include: | ||||||
|  |  | ||||||
|  | * Using welcoming and inclusive language | ||||||
|  | * Being respectful of differing viewpoints and experiences | ||||||
|  | * Gracefully accepting constructive criticism | ||||||
|  | * Focusing on what is best for the community | ||||||
|  | * Showing empathy towards other community members | ||||||
|  |  | ||||||
|  | Examples of unacceptable behavior by participants include: | ||||||
|  |  | ||||||
|  | * The use of sexualized language or imagery and unwelcome sexual attention or | ||||||
|  | advances | ||||||
|  | * Trolling, insulting/derogatory comments, and personal or political attacks | ||||||
|  | * Public or private harassment | ||||||
|  | * Publishing others' private information, such as a physical or electronic | ||||||
|  | address, without explicit permission | ||||||
|  | * Other conduct which could reasonably be considered inappropriate in a | ||||||
|  | professional setting | ||||||
|  |  | ||||||
|  | ## Our Responsibilities | ||||||
|  |  | ||||||
|  | Project maintainers are responsible for clarifying the standards of acceptable | ||||||
|  | behavior and are expected to take appropriate and fair corrective action in | ||||||
|  | response to any instances of unacceptable behavior. | ||||||
|  |  | ||||||
|  | Project maintainers have the right and responsibility to remove, edit, or | ||||||
|  | reject comments, commits, code, wiki edits, issues, and other contributions | ||||||
|  | that are not aligned to this Code of Conduct, or to ban temporarily or | ||||||
|  | permanently any contributor for other behaviors that they deem inappropriate, | ||||||
|  | threatening, offensive, or harmful. | ||||||
|  |  | ||||||
|  | ## Scope | ||||||
|  |  | ||||||
|  | This Code of Conduct applies within all project spaces, and it also applies when | ||||||
|  | an individual is representing the project or its community in public spaces. | ||||||
|  | Examples of representing a project or community include using an official | ||||||
|  | project e-mail address, posting via an official social media account, or acting | ||||||
|  | as an appointed representative at an online or offline event. Representation of | ||||||
|  | a project may be further defined and clarified by project maintainers. | ||||||
|  |  | ||||||
|  | This Code of Conduct also applies outside the project spaces when there is a | ||||||
|  | reasonable belief that an individual's behavior may have a negative impact on | ||||||
|  | the project or its community. | ||||||
|  |  | ||||||
|  | ## Enforcement | ||||||
|  |  | ||||||
|  | Instances of abusive, harassing, or otherwise unacceptable behavior may be | ||||||
|  | reported by contacting the project team at <opensource-conduct@fb.com>. All | ||||||
|  | complaints will be reviewed and investigated and will result in a response that | ||||||
|  | is deemed necessary and appropriate to the circumstances. The project team is | ||||||
|  | obligated to maintain confidentiality with regard to the reporter of an incident. | ||||||
|  | Further details of specific enforcement policies may be posted separately. | ||||||
|  |  | ||||||
|  | Project maintainers who do not follow or enforce the Code of Conduct in good | ||||||
|  | faith may face temporary or permanent repercussions as determined by other | ||||||
|  | members of the project's leadership. | ||||||
|  |  | ||||||
|  | ## Attribution | ||||||
|  |  | ||||||
|  | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, | ||||||
|  | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html | ||||||
|  |  | ||||||
|  | [homepage]: https://www.contributor-covenant.org | ||||||
|  |  | ||||||
|  | For answers to common questions about this code of conduct, see | ||||||
|  | https://www.contributor-covenant.org/faq | ||||||
							
								
								
									
CONTRIBUTING.md (new file, 28 lines)
							| @@ -0,0 +1,28 @@ | |||||||
|  | # CoTracker | ||||||
|  | We want to make contributing to this project as easy and transparent as possible. | ||||||
|  |  | ||||||
|  | ## Pull Requests | ||||||
|  | We actively welcome your pull requests. | ||||||
|  |  | ||||||
|  | 1. Fork the repo and create your branch from `main`. | ||||||
|  | 2. If you've changed APIs, update the documentation. | ||||||
|  | 3. Make sure your code lints. | ||||||
|  | 4. If you haven't already, complete the Contributor License Agreement ("CLA"). | ||||||
|  |  | ||||||
|  | ## Contributor License Agreement ("CLA") | ||||||
|  | In order to accept your pull request, we need you to submit a CLA. You only need | ||||||
|  | to do this once to work on any of Meta's open source projects. | ||||||
|  |  | ||||||
|  | Complete your CLA here: <https://code.facebook.com/cla> | ||||||
|  |  | ||||||
|  | ## Issues | ||||||
|  | We use GitHub issues to track public bugs. Please ensure your description is | ||||||
|  | clear and has sufficient instructions to be able to reproduce the issue. | ||||||
|  |  | ||||||
|  | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe | ||||||
|  | disclosure of security bugs. In those cases, please go through the process | ||||||
|  | outlined on that page and do not file a public issue. | ||||||
|  |  | ||||||
|  | ## License | ||||||
|  | By contributing to CoTracker, you agree that your contributions will be licensed | ||||||
|  | under the LICENSE file in the root directory of this source tree. | ||||||
							
								
								
									
LICENSE.md (new file, 399 lines)
							| @@ -0,0 +1,399 @@ | |||||||
|  | Attribution-NonCommercial 4.0 International | ||||||
|  |  | ||||||
|  | ======================================================================= | ||||||
|  |  | ||||||
|  | Creative Commons Corporation ("Creative Commons") is not a law firm and | ||||||
|  | does not provide legal services or legal advice. Distribution of | ||||||
|  | Creative Commons public licenses does not create a lawyer-client or | ||||||
|  | other relationship. Creative Commons makes its licenses and related | ||||||
|  | information available on an "as-is" basis. Creative Commons gives no | ||||||
|  | warranties regarding its licenses, any material licensed under their | ||||||
|  | terms and conditions, or any related information. Creative Commons | ||||||
|  | disclaims all liability for damages resulting from their use to the | ||||||
|  | fullest extent possible. | ||||||
|  |  | ||||||
|  | Using Creative Commons Public Licenses | ||||||
|  |  | ||||||
|  | Creative Commons public licenses provide a standard set of terms and | ||||||
|  | conditions that creators and other rights holders may use to share | ||||||
|  | original works of authorship and other material subject to copyright | ||||||
|  | and certain other rights specified in the public license below. The | ||||||
|  | following considerations are for informational purposes only, are not | ||||||
|  | exhaustive, and do not form part of our licenses. | ||||||
|  |  | ||||||
|  |      Considerations for licensors: Our public licenses are | ||||||
|  |      intended for use by those authorized to give the public | ||||||
|  |      permission to use material in ways otherwise restricted by | ||||||
|  |      copyright and certain other rights. Our licenses are | ||||||
|  |      irrevocable. Licensors should read and understand the terms | ||||||
|  |      and conditions of the license they choose before applying it. | ||||||
|  |      Licensors should also secure all rights necessary before | ||||||
|  |      applying our licenses so that the public can reuse the | ||||||
|  |      material as expected. Licensors should clearly mark any | ||||||
|  |      material not subject to the license. This includes other CC- | ||||||
|  |      licensed material, or material used under an exception or | ||||||
|  |      limitation to copyright. More considerations for licensors: | ||||||
|  | 	wiki.creativecommons.org/Considerations_for_licensors | ||||||
|  |  | ||||||
|  |      Considerations for the public: By using one of our public | ||||||
|  |      licenses, a licensor grants the public permission to use the | ||||||
|  |      licensed material under specified terms and conditions. If | ||||||
|  |      the licensor's permission is not necessary for any reason--for | ||||||
|  |      example, because of any applicable exception or limitation to | ||||||
|  |      copyright--then that use is not regulated by the license. Our | ||||||
|  |      licenses grant only permissions under copyright and certain | ||||||
|  |      other rights that a licensor has authority to grant. Use of | ||||||
|  |      the licensed material may still be restricted for other | ||||||
|  |      reasons, including because others have copyright or other | ||||||
|  |      rights in the material. A licensor may make special requests, | ||||||
|  |      such as asking that all changes be marked or described. | ||||||
|  |      Although not required by our licenses, you are encouraged to | ||||||
|  |      respect those requests where reasonable. More_considerations | ||||||
|  |      for the public:  | ||||||
|  | 	wiki.creativecommons.org/Considerations_for_licensees | ||||||
|  |  | ||||||
|  | ======================================================================= | ||||||
|  |  | ||||||
|  | Creative Commons Attribution-NonCommercial 4.0 International Public | ||||||
|  | License | ||||||
|  |  | ||||||
|  | By exercising the Licensed Rights (defined below), You accept and agree | ||||||
|  | to be bound by the terms and conditions of this Creative Commons | ||||||
|  | Attribution-NonCommercial 4.0 International Public License ("Public | ||||||
|  | License"). To the extent this Public License may be interpreted as a | ||||||
|  | contract, You are granted the Licensed Rights in consideration of Your | ||||||
|  | acceptance of these terms and conditions, and the Licensor grants You | ||||||
|  | such rights in consideration of benefits the Licensor receives from | ||||||
|  | making the Licensed Material available under these terms and | ||||||
|  | conditions. | ||||||
|  |  | ||||||
|  | Section 1 -- Definitions. | ||||||
|  |  | ||||||
|  |   a. Adapted Material means material subject to Copyright and Similar | ||||||
|  |      Rights that is derived from or based upon the Licensed Material | ||||||
|  |      and in which the Licensed Material is translated, altered, | ||||||
|  |      arranged, transformed, or otherwise modified in a manner requiring | ||||||
|  |      permission under the Copyright and Similar Rights held by the | ||||||
|  |      Licensor. For purposes of this Public License, where the Licensed | ||||||
|  |      Material is a musical work, performance, or sound recording, | ||||||
|  |      Adapted Material is always produced where the Licensed Material is | ||||||
|  |      synched in timed relation with a moving image. | ||||||
|  |  | ||||||
|  |   b. Adapter's License means the license You apply to Your Copyright | ||||||
|  |      and Similar Rights in Your contributions to Adapted Material in | ||||||
|  |      accordance with the terms and conditions of this Public License. | ||||||
|  |  | ||||||
|  |   c. Copyright and Similar Rights means copyright and/or similar rights | ||||||
|  |      closely related to copyright including, without limitation, | ||||||
|  |      performance, broadcast, sound recording, and Sui Generis Database | ||||||
|  |      Rights, without regard to how the rights are labeled or | ||||||
|  |      categorized. For purposes of this Public License, the rights | ||||||
|  |      specified in Section 2(b)(1)-(2) are not Copyright and Similar | ||||||
|  |      Rights. | ||||||
|  |   d. Effective Technological Measures means those measures that, in the | ||||||
|  |      absence of proper authority, may not be circumvented under laws | ||||||
|  |      fulfilling obligations under Article 11 of the WIPO Copyright | ||||||
|  |      Treaty adopted on December 20, 1996, and/or similar international | ||||||
|  |      agreements. | ||||||
|  |  | ||||||
|  |   e. Exceptions and Limitations means fair use, fair dealing, and/or | ||||||
|  |      any other exception or limitation to Copyright and Similar Rights | ||||||
|  |      that applies to Your use of the Licensed Material. | ||||||
|  |  | ||||||
|  |   f. Licensed Material means the artistic or literary work, database, | ||||||
|  |      or other material to which the Licensor applied this Public | ||||||
|  |      License. | ||||||
|  |  | ||||||
|  |   g. Licensed Rights means the rights granted to You subject to the | ||||||
|  |      terms and conditions of this Public License, which are limited to | ||||||
|  |      all Copyright and Similar Rights that apply to Your use of the | ||||||
|  |      Licensed Material and that the Licensor has authority to license. | ||||||
|  |  | ||||||
|  |   h. Licensor means the individual(s) or entity(ies) granting rights | ||||||
|  |      under this Public License. | ||||||
|  |  | ||||||
|  |   i. NonCommercial means not primarily intended for or directed towards | ||||||
|  |      commercial advantage or monetary compensation. For purposes of | ||||||
|  |      this Public License, the exchange of the Licensed Material for | ||||||
|  |      other material subject to Copyright and Similar Rights by digital | ||||||
|  |      file-sharing or similar means is NonCommercial provided there is | ||||||
|  |      no payment of monetary compensation in connection with the | ||||||
|  |      exchange. | ||||||
|  |  | ||||||
|  |   j. Share means to provide material to the public by any means or | ||||||
|  |      process that requires permission under the Licensed Rights, such | ||||||
|  |      as reproduction, public display, public performance, distribution, | ||||||
|  |      dissemination, communication, or importation, and to make material | ||||||
|  |      available to the public including in ways that members of the | ||||||
|  |      public may access the material from a place and at a time | ||||||
|  |      individually chosen by them. | ||||||
|  |  | ||||||
|  |   k. Sui Generis Database Rights means rights other than copyright | ||||||
|  |      resulting from Directive 96/9/EC of the European Parliament and of | ||||||
|  |      the Council of 11 March 1996 on the legal protection of databases, | ||||||
|  |      as amended and/or succeeded, as well as other essentially | ||||||
|  |      equivalent rights anywhere in the world. | ||||||
|  |  | ||||||
|  |   l. You means the individual or entity exercising the Licensed Rights | ||||||
|  |      under this Public License. Your has a corresponding meaning. | ||||||
|  |  | ||||||
|  | Section 2 -- Scope. | ||||||
|  |  | ||||||
|  |   a. License grant. | ||||||
|  |  | ||||||
|  |        1. Subject to the terms and conditions of this Public License, | ||||||
|  |           the Licensor hereby grants You a worldwide, royalty-free, | ||||||
|  |           non-sublicensable, non-exclusive, irrevocable license to | ||||||
|  |           exercise the Licensed Rights in the Licensed Material to: | ||||||
|  |  | ||||||
|  |             a. reproduce and Share the Licensed Material, in whole or | ||||||
|  |                in part, for NonCommercial purposes only; and | ||||||
|  |  | ||||||
|  |             b. produce, reproduce, and Share Adapted Material for | ||||||
|  |                NonCommercial purposes only. | ||||||
|  |  | ||||||
|  |        2. Exceptions and Limitations. For the avoidance of doubt, where | ||||||
|  |           Exceptions and Limitations apply to Your use, this Public | ||||||
|  |           License does not apply, and You do not need to comply with | ||||||
|  |           its terms and conditions. | ||||||
|  |  | ||||||
|  |        3. Term. The term of this Public License is specified in Section | ||||||
|  |           6(a). | ||||||
|  |  | ||||||
|  |        4. Media and formats; technical modifications allowed. The | ||||||
|  |           Licensor authorizes You to exercise the Licensed Rights in | ||||||
|  |           all media and formats whether now known or hereafter created, | ||||||
|  |           and to make technical modifications necessary to do so. The | ||||||
|  |           Licensor waives and/or agrees not to assert any right or | ||||||
|  |           authority to forbid You from making technical modifications | ||||||
|  |           necessary to exercise the Licensed Rights, including | ||||||
|  |           technical modifications necessary to circumvent Effective | ||||||
|  |           Technological Measures. For purposes of this Public License, | ||||||
|  |           simply making modifications authorized by this Section 2(a) | ||||||
|  |           (4) never produces Adapted Material. | ||||||
|  |  | ||||||
|  |        5. Downstream recipients. | ||||||
|  |  | ||||||
|  |             a. Offer from the Licensor -- Licensed Material. Every | ||||||
|  |                recipient of the Licensed Material automatically | ||||||
|  |                receives an offer from the Licensor to exercise the | ||||||
|  |                Licensed Rights under the terms and conditions of this | ||||||
|  |                Public License. | ||||||
|  |  | ||||||
|  |             b. No downstream restrictions. You may not offer or impose | ||||||
|  |                any additional or different terms or conditions on, or | ||||||
|  |                apply any Effective Technological Measures to, the | ||||||
|  |                Licensed Material if doing so restricts exercise of the | ||||||
|  |                Licensed Rights by any recipient of the Licensed | ||||||
|  |                Material. | ||||||
|  |  | ||||||
|  |        6. No endorsement. Nothing in this Public License constitutes or | ||||||
|  |           may be construed as permission to assert or imply that You | ||||||
|  |           are, or that Your use of the Licensed Material is, connected | ||||||
|  |           with, or sponsored, endorsed, or granted official status by, | ||||||
|  |           the Licensor or others designated to receive attribution as | ||||||
|  |           provided in Section 3(a)(1)(A)(i). | ||||||
|  |  | ||||||
|  |   b. Other rights. | ||||||
|  |  | ||||||
|  |        1. Moral rights, such as the right of integrity, are not | ||||||
|  |           licensed under this Public License, nor are publicity, | ||||||
|  |           privacy, and/or other similar personality rights; however, to | ||||||
|  |           the extent possible, the Licensor waives and/or agrees not to | ||||||
|  |           assert any such rights held by the Licensor to the limited | ||||||
|  |           extent necessary to allow You to exercise the Licensed | ||||||
|  |           Rights, but not otherwise. | ||||||
|  |  | ||||||
|  |        2. Patent and trademark rights are not licensed under this | ||||||
|  |           Public License. | ||||||
|  |  | ||||||
|  |        3. To the extent possible, the Licensor waives any right to | ||||||
|  |           collect royalties from You for the exercise of the Licensed | ||||||
|  |           Rights, whether directly or through a collecting society | ||||||
|  |           under any voluntary or waivable statutory or compulsory | ||||||
|  |           licensing scheme. In all other cases the Licensor expressly | ||||||
|  |           reserves any right to collect such royalties, including when | ||||||
|  |           the Licensed Material is used other than for NonCommercial | ||||||
|  |           purposes. | ||||||
|  |  | ||||||
|  | Section 3 -- License Conditions. | ||||||
|  |  | ||||||
|  | Your exercise of the Licensed Rights is expressly made subject to the | ||||||
|  | following conditions. | ||||||
|  |  | ||||||
|  |   a. Attribution. | ||||||
|  |  | ||||||
|  |        1. If You Share the Licensed Material (including in modified | ||||||
|  |           form), You must: | ||||||
|  |  | ||||||
|  |             a. retain the following if it is supplied by the Licensor | ||||||
|  |                with the Licensed Material: | ||||||
|  |  | ||||||
|  |                  i. identification of the creator(s) of the Licensed | ||||||
|  |                     Material and any others designated to receive | ||||||
|  |                     attribution, in any reasonable manner requested by | ||||||
|  |                     the Licensor (including by pseudonym if | ||||||
|  |                     designated); | ||||||
|  |  | ||||||
|  |                 ii. a copyright notice; | ||||||
|  |  | ||||||
|  |                iii. a notice that refers to this Public License; | ||||||
|  |  | ||||||
|  |                 iv. a notice that refers to the disclaimer of | ||||||
|  |                     warranties; | ||||||
|  |  | ||||||
|  |                  v. a URI or hyperlink to the Licensed Material to the | ||||||
|  |                     extent reasonably practicable; | ||||||
|  |  | ||||||
|  |             b. indicate if You modified the Licensed Material and | ||||||
|  |                retain an indication of any previous modifications; and | ||||||
|  |  | ||||||
|  |             c. indicate the Licensed Material is licensed under this | ||||||
|  |                Public License, and include the text of, or the URI or | ||||||
|  |                hyperlink to, this Public License. | ||||||
|  |  | ||||||
|  |        2. You may satisfy the conditions in Section 3(a)(1) in any | ||||||
|  |           reasonable manner based on the medium, means, and context in | ||||||
|  |           which You Share the Licensed Material. For example, it may be | ||||||
|  |           reasonable to satisfy the conditions by providing a URI or | ||||||
|  |           hyperlink to a resource that includes the required | ||||||
|  |           information. | ||||||
|  |  | ||||||
|  |        3. If requested by the Licensor, You must remove any of the | ||||||
|  |           information required by Section 3(a)(1)(A) to the extent | ||||||
|  |           reasonably practicable. | ||||||
|  |  | ||||||
|  |        4. If You Share Adapted Material You produce, the Adapter's | ||||||
|  |           License You apply must not prevent recipients of the Adapted | ||||||
|  |           Material from complying with this Public License. | ||||||
|  |  | ||||||
|  | Section 4 -- Sui Generis Database Rights. | ||||||
|  |  | ||||||
|  | Where the Licensed Rights include Sui Generis Database Rights that | ||||||
|  | apply to Your use of the Licensed Material: | ||||||
|  |  | ||||||
|  |   a. for the avoidance of doubt, Section 2(a)(1) grants You the right | ||||||
|  |      to extract, reuse, reproduce, and Share all or a substantial | ||||||
|  |      portion of the contents of the database for NonCommercial purposes | ||||||
|  |      only; | ||||||
|  |  | ||||||
|  |   b. if You include all or a substantial portion of the database | ||||||
|  |      contents in a database in which You have Sui Generis Database | ||||||
|  |      Rights, then the database in which You have Sui Generis Database | ||||||
|  |      Rights (but not its individual contents) is Adapted Material; and | ||||||
|  |  | ||||||
|  |   c. You must comply with the conditions in Section 3(a) if You Share | ||||||
|  |      all or a substantial portion of the contents of the database. | ||||||
|  |  | ||||||
|  | For the avoidance of doubt, this Section 4 supplements and does not | ||||||
|  | replace Your obligations under this Public License where the Licensed | ||||||
|  | Rights include other Copyright and Similar Rights. | ||||||
|  |  | ||||||
|  | Section 5 -- Disclaimer of Warranties and Limitation of Liability. | ||||||
|  |  | ||||||
|  |   a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE | ||||||
|  |      EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS | ||||||
|  |      AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF | ||||||
|  |      ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, | ||||||
|  |      IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, | ||||||
|  |      WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR | ||||||
|  |      PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, | ||||||
|  |      ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT | ||||||
|  |      KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT | ||||||
|  |      ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. | ||||||
|  |  | ||||||
|  |   b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE | ||||||
|  |      TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, | ||||||
|  |      NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, | ||||||
|  |      INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, | ||||||
|  |      COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR | ||||||
|  |      USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN | ||||||
|  |      ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR | ||||||
|  |      DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR | ||||||
|  |      IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. | ||||||
|  |  | ||||||
|  |   c. The disclaimer of warranties and limitation of liability provided | ||||||
|  |      above shall be interpreted in a manner that, to the extent | ||||||
|  |      possible, most closely approximates an absolute disclaimer and | ||||||
|  |      waiver of all liability. | ||||||
|  |  | ||||||
|  | Section 6 -- Term and Termination. | ||||||
|  |  | ||||||
|  |   a. This Public License applies for the term of the Copyright and | ||||||
|  |      Similar Rights licensed here. However, if You fail to comply with | ||||||
|  |      this Public License, then Your rights under this Public License | ||||||
|  |      terminate automatically. | ||||||
|  |  | ||||||
|  |   b. Where Your right to use the Licensed Material has terminated under | ||||||
|  |      Section 6(a), it reinstates: | ||||||
|  |  | ||||||
|  |        1. automatically as of the date the violation is cured, provided | ||||||
|  |           it is cured within 30 days of Your discovery of the | ||||||
|  |           violation; or | ||||||
|  |  | ||||||
|  |        2. upon express reinstatement by the Licensor. | ||||||
|  |  | ||||||
|  |      For the avoidance of doubt, this Section 6(b) does not affect any | ||||||
|  |      right the Licensor may have to seek remedies for Your violations | ||||||
|  |      of this Public License. | ||||||
|  |  | ||||||
|  |   c. For the avoidance of doubt, the Licensor may also offer the | ||||||
|  |      Licensed Material under separate terms or conditions or stop | ||||||
|  |      distributing the Licensed Material at any time; however, doing so | ||||||
|  |      will not terminate this Public License. | ||||||
|  |  | ||||||
|  |   d. Sections 1, 5, 6, 7, and 8 survive termination of this Public | ||||||
|  |      License. | ||||||
|  |  | ||||||
|  | Section 7 -- Other Terms and Conditions. | ||||||
|  |  | ||||||
|  |   a. The Licensor shall not be bound by any additional or different | ||||||
|  |      terms or conditions communicated by You unless expressly agreed. | ||||||
|  |  | ||||||
|  |   b. Any arrangements, understandings, or agreements regarding the | ||||||
|  |      Licensed Material not stated herein are separate from and | ||||||
|  |      independent of the terms and conditions of this Public License. | ||||||
|  |  | ||||||
|  | Section 8 -- Interpretation. | ||||||
|  |  | ||||||
|  |   a. For the avoidance of doubt, this Public License does not, and | ||||||
|  |      shall not be interpreted to, reduce, limit, restrict, or impose | ||||||
|  |      conditions on any use of the Licensed Material that could lawfully | ||||||
|  |      be made without permission under this Public License. | ||||||
|  |  | ||||||
|  |   b. To the extent possible, if any provision of this Public License is | ||||||
|  |      deemed unenforceable, it shall be automatically reformed to the | ||||||
|  |      minimum extent necessary to make it enforceable. If the provision | ||||||
|  |      cannot be reformed, it shall be severed from this Public License | ||||||
|  |      without affecting the enforceability of the remaining terms and | ||||||
|  |      conditions. | ||||||
|  |  | ||||||
|  |   c. No term or condition of this Public License will be waived and no | ||||||
|  |      failure to comply consented to unless expressly agreed to by the | ||||||
|  |      Licensor. | ||||||
|  |  | ||||||
|  |   d. Nothing in this Public License constitutes or may be interpreted | ||||||
|  |      as a limitation upon, or waiver of, any privileges and immunities | ||||||
|  |      that apply to the Licensor or You, including from the legal | ||||||
|  |      processes of any jurisdiction or authority. | ||||||
|  |  | ||||||
|  | ======================================================================= | ||||||
|  |  | ||||||
|  | Creative Commons is not a party to its public | ||||||
|  | licenses. Notwithstanding, Creative Commons may elect to apply one of | ||||||
|  | its public licenses to material it publishes and in those instances | ||||||
|  | will be considered the “Licensor.” The text of the Creative Commons | ||||||
|  | public licenses is dedicated to the public domain under the CC0 Public | ||||||
|  | Domain Dedication. Except for the limited purpose of indicating that | ||||||
|  | material is shared under a Creative Commons public license or as | ||||||
|  | otherwise permitted by the Creative Commons policies published at | ||||||
|  | creativecommons.org/policies, Creative Commons does not authorize the | ||||||
|  | use of the trademark "Creative Commons" or any other trademark or logo | ||||||
|  | of Creative Commons without its prior written consent including, | ||||||
|  | without limitation, in connection with any unauthorized modifications | ||||||
|  | to any of its public licenses or any other arrangements, | ||||||
|  | understandings, or agreements concerning use of licensed material. For | ||||||
|  | the avoidance of doubt, this paragraph does not form part of the | ||||||
|  | public licenses. | ||||||
|  |  | ||||||
|  | Creative Commons may be contacted at creativecommons.org. | ||||||
							
								
								
									
README.md (new file, 94 lines)
							| @@ -0,0 +1,94 @@ | |||||||
|  | # CoTracker: It is Better to Track Together | ||||||
|  |  | ||||||
|  | **[Meta AI Research, FAIR](https://ai.facebook.com/research/)**; **[University of Oxford, VGG](https://www.robots.ox.ac.uk/~vgg/)** | ||||||
|  |  | ||||||
|  | [Nikita Karaev](https://nikitakaraevv.github.io/), [Ignacio Rocco](https://www.irocco.info/), [Benjamin Graham](https://ai.facebook.com/people/benjamin-graham/), [Natalia Neverova](https://nneverova.github.io/), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/), [Christian Rupprecht](https://chrirupp.github.io/) | ||||||
|  |  | ||||||
|  | [[`Paper`]()] [[`Project`](https://co-tracker.github.io/)] [[`BibTeX`](#citing-cotracker)] | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | **CoTracker** is a fast transformer-based model that can track any point in a video. It brings to tracking some of the benefits of Optical Flow. | ||||||
|  |   | ||||||
|  | CoTracker can track: | ||||||
|  | - **Every pixel** within a video | ||||||
|  | - Points sampled on a regular grid on any video frame  | ||||||
|  | - Manually selected points | ||||||
|  |  | ||||||
|  | Try these tracking modes for yourself with our [Colab demo](https://github.com/facebookresearch/co-tracker/notebooks/demo.ipynb). | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | ## Installation Instructions | ||||||
|  | Ensure you have both PyTorch and TorchVision installed on your system. Follow the instructions [here](https://pytorch.org/get-started/locally/) for the installation. We strongly recommend installing both PyTorch and TorchVision with CUDA support. | ||||||
|  |  | ||||||
|  | ## Steps to Install CoTracker and its dependencies: | ||||||
|  | ``` | ||||||
|  | git clone https://github.com/facebookresearch/co-tracker | ||||||
|  | cd co-tracker | ||||||
|  | pip install -e . | ||||||
|  | pip install opencv-python einops timm matplotlib moviepy flow_vis | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  |  | ||||||
|  | ## Model Weights Download: | ||||||
|  | ``` | ||||||
|  | mkdir checkpoints | ||||||
|  | cd checkpoints | ||||||
|  | wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth | ||||||
|  | wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_12.pth | ||||||
|  | wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_8_wind_16.pth | ||||||
|  | cd .. | ||||||
|  | ``` | ||||||
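
The commands above only fetch the weight files. As a quick sanity check you can open one directly with `torch.load`; the snippet below is a hedged sketch, since whether each `.pth` file is a raw state dict or a wrapper dict (e.g. with a `"model"` or `"state_dict"` key) is an assumption, so the code handles both cases.

```
# Sanity-check a downloaded checkpoint without building the model.
# Assumption: the file is either a raw state dict or a dict wrapping one
# under a "model"/"state_dict" key; adjust if the layout differs.
import torch

ckpt = torch.load("checkpoints/cotracker_stride_4_wind_8.pth", map_location="cpu")
state = ckpt
if isinstance(ckpt, dict):
    state = ckpt.get("model", ckpt.get("state_dict", ckpt))
print(f"{len(state)} tensors in checkpoint")
for name, tensor in list(state.items())[:5]:
    print(name, tuple(tensor.shape))
```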
|  |  | ||||||
|  |  | ||||||
|  | ## Running the Demo: | ||||||
|  | Try our [Colab demo](https://github.com/facebookresearch/co-tracker/notebooks/demo.ipynb) or run a local demo with a 10×10 grid of points sampled on the first frame of a video: | ||||||
|  | ``` | ||||||
|  | python demo.py --grid_size 10 | ||||||
|  | ``` | ||||||
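
Here `--grid_size 10` corresponds to 100 query points laid out on a regular 10×10 grid on the first frame. The sketch below shows how such a grid of queries can be built with plain PyTorch; the `(frame, x, y)` ordering, the helper name, and the example frame size are illustrative assumptions rather than the demo's actual internals.

```
# Illustrative sketch: a grid_size x grid_size set of query points on frame 0,
# one (t, x, y) triple per point. The demo's own helper may use another layout.
import torch

def make_grid_queries(grid_size: int, height: int, width: int) -> torch.Tensor:
    ys = torch.linspace(0, height - 1, grid_size)
    xs = torch.linspace(0, width - 1, grid_size)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    t = torch.zeros_like(grid_x)  # every query starts on the first frame
    return torch.stack([t, grid_x, grid_y], dim=-1).reshape(-1, 3)

queries = make_grid_queries(10, 480, 856)  # shape (100, 3)
```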
|  |  | ||||||
|  | ## Evaluation | ||||||
|  | To reproduce the results presented in the paper, download the following datasets: | ||||||
|  | - [TAP-Vid](https://github.com/deepmind/tapnet) | ||||||
|  | - [BADJA](https://github.com/benjiebob/BADJA) | ||||||
|  | - [ZJU-Mocap (FastCapture)](https://arxiv.org/abs/2303.11898) | ||||||
|  |  | ||||||
|  | And install the necessary dependencies: | ||||||
|  | ``` | ||||||
|  | pip install hydra-core==1.1.0 mediapy tensorboard  | ||||||
|  | ``` | ||||||
|  | Then, execute the following command to evaluate on BADJA: | ||||||
|  | ``` | ||||||
|  | python ./cotracker/evaluation/evaluate.py --config-name eval_badja exp_dir=./eval_outputs dataset_root=your/badja/path | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ## Training | ||||||
|  | To train CoTracker as described in our paper, you first need to generate annotations for the [Google Kubric](https://github.com/google-research/kubric) MOVi-f dataset. Instructions for annotation generation can be found [here](https://github.com/deepmind/tapnet). | ||||||
|  |  | ||||||
|  | Once you have the annotated dataset, make sure you have completed the evaluation setup above and install the training dependencies: | ||||||
|  | ``` | ||||||
|  | pip install pytorch_lightning==1.6.0 | ||||||
|  | ``` | ||||||
|  | Then launch training on Kubric. Our model was trained on 32 GPUs; adjust the parameters to suit your hardware setup. | ||||||
|  | ``` | ||||||
|  | python train.py --batch_size 1 --num_workers 28 \ | ||||||
|  | --num_steps 50000 --ckpt_path ./ --model_name cotracker \ | ||||||
|  | --save_freq 200 --sequence_len 24 --eval_datasets tapvid_davis_first badja \ | ||||||
|  | --traj_per_sample 256 --sliding_window_len 8 --updateformer_space_depth 6 --updateformer_time_depth 6 \ | ||||||
|  | --save_every_n_epoch 10 --evaluate_every_n_epoch 10 --model_stride 4 | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ## License | ||||||
|  | The majority of CoTracker is licensed under CC-BY-NC; however, portions of the project are available under separate license terms: Particle Video Revisited is licensed under the MIT license, and TAP-Vid is licensed under the Apache 2.0 license. | ||||||
|  |  | ||||||
|  | ## Citing CoTracker | ||||||
|  | If you find our repository useful, please consider giving it a star ⭐ and citing our paper in your work: | ||||||
|  | ``` | ||||||
|  | @article{karaev2023cotracker, | ||||||
|  |   title={CoTracker: It is Better to Track Together}, | ||||||
|  |   author={Nikita Karaev and Ignacio Rocco and Benjamin Graham and Natalia Neverova and Andrea Vedaldi and Christian Rupprecht}, | ||||||
|  |   journal={arxiv}, | ||||||
|  |   year={2023} | ||||||
|  | } | ||||||
|  | ``` | ||||||
							
								
								
									
										
assets/apple.mp4 (new file, binary, not shown)
							
								
								
									
										
assets/apple_mask.png (new file, binary, 14 KiB, not shown)
							
								
								
									
										
assets/bmx-bumps.gif (new file, binary, 7.3 MiB, not shown)
							
								
								
									
cotracker/__init__.py (new file, 5 lines)
							| @@ -0,0 +1,5 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
							
								
								
									
cotracker/datasets/__init__.py (new file, 5 lines)
							| @@ -0,0 +1,5 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
							
								
								
									
cotracker/datasets/badja_dataset.py (new file, 390 lines)
							| @@ -0,0 +1,390 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | import numpy as np | ||||||
|  | import os | ||||||
|  |  | ||||||
|  | import json | ||||||
|  | import imageio | ||||||
|  | import cv2 | ||||||
|  |  | ||||||
|  | from enum import Enum | ||||||
|  |  | ||||||
|  | from cotracker.datasets.utils import CoTrackerData, resize_sample | ||||||
|  |  | ||||||
|  | IGNORE_ANIMALS = [ | ||||||
|  |     # "bear.json", | ||||||
|  |     # "camel.json", | ||||||
|  |     "cat_jump.json" | ||||||
|  |     # "cows.json", | ||||||
|  |     # "dog.json", | ||||||
|  |     # "dog-agility.json", | ||||||
|  |     # "horsejump-high.json", | ||||||
|  |     # "horsejump-low.json", | ||||||
|  |     # "impala0.json", | ||||||
|  |     # "rs_dog.json" | ||||||
|  |     "tiger.json" | ||||||
|  | ] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SMALJointCatalog(Enum): | ||||||
|  |     # body_0 = 0 | ||||||
|  |     # body_1 = 1 | ||||||
|  |     # body_2 = 2 | ||||||
|  |     # body_3 = 3 | ||||||
|  |     # body_4 = 4 | ||||||
|  |     # body_5 = 5 | ||||||
|  |     # body_6 = 6 | ||||||
|  |     # upper_right_0 = 7 | ||||||
|  |     upper_right_1 = 8 | ||||||
|  |     upper_right_2 = 9 | ||||||
|  |     upper_right_3 = 10 | ||||||
|  |     # upper_left_0 = 11 | ||||||
|  |     upper_left_1 = 12 | ||||||
|  |     upper_left_2 = 13 | ||||||
|  |     upper_left_3 = 14 | ||||||
|  |     neck_lower = 15 | ||||||
|  |     # neck_upper = 16 | ||||||
|  |     # lower_right_0 = 17 | ||||||
|  |     lower_right_1 = 18 | ||||||
|  |     lower_right_2 = 19 | ||||||
|  |     lower_right_3 = 20 | ||||||
|  |     # lower_left_0 = 21 | ||||||
|  |     lower_left_1 = 22 | ||||||
|  |     lower_left_2 = 23 | ||||||
|  |     lower_left_3 = 24 | ||||||
|  |     tail_0 = 25 | ||||||
|  |     # tail_1 = 26 | ||||||
|  |     # tail_2 = 27 | ||||||
|  |     tail_3 = 28 | ||||||
|  |     # tail_4 = 29 | ||||||
|  |     # tail_5 = 30 | ||||||
|  |     tail_6 = 31 | ||||||
|  |     jaw = 32 | ||||||
|  |     nose = 33  # ADDED JOINT FOR VERTEX 1863 | ||||||
|  |     # chin = 34 # ADDED JOINT FOR VERTEX 26 | ||||||
|  |     right_ear = 35  # ADDED JOINT FOR VERTEX 149 | ||||||
|  |     left_ear = 36  # ADDED JOINT FOR VERTEX 2124 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SMALJointInfo: | ||||||
|  |     def __init__(self): | ||||||
|  |         # Indices of the SMAL joints that carry BADJA annotations. | ||||||
|  |         self.annotated_classes = np.array( | ||||||
|  |             [ | ||||||
|  |                 8, | ||||||
|  |                 9, | ||||||
|  |                 10,  # upper_right | ||||||
|  |                 12, | ||||||
|  |                 13, | ||||||
|  |                 14,  # upper_left | ||||||
|  |                 15,  # neck | ||||||
|  |                 18, | ||||||
|  |                 19, | ||||||
|  |                 20,  # lower_right | ||||||
|  |                 22, | ||||||
|  |                 23, | ||||||
|  |                 24,  # lower_left | ||||||
|  |                 25, | ||||||
|  |                 28, | ||||||
|  |                 31,  # tail | ||||||
|  |                 32, | ||||||
|  |                 33,  # head | ||||||
|  |                 35,  # right_ear | ||||||
|  |                 36, | ||||||
|  |             ] | ||||||
|  |         )  # left_ear | ||||||
|  |  | ||||||
|  |         self.annotated_markers = np.array( | ||||||
|  |             [ | ||||||
|  |                 cv2.MARKER_CROSS, | ||||||
|  |                 cv2.MARKER_STAR, | ||||||
|  |                 cv2.MARKER_TRIANGLE_DOWN, | ||||||
|  |                 cv2.MARKER_CROSS, | ||||||
|  |                 cv2.MARKER_STAR, | ||||||
|  |                 cv2.MARKER_TRIANGLE_DOWN, | ||||||
|  |                 cv2.MARKER_CROSS, | ||||||
|  |                 cv2.MARKER_CROSS, | ||||||
|  |                 cv2.MARKER_STAR, | ||||||
|  |                 cv2.MARKER_TRIANGLE_DOWN, | ||||||
|  |                 cv2.MARKER_CROSS, | ||||||
|  |                 cv2.MARKER_STAR, | ||||||
|  |                 cv2.MARKER_TRIANGLE_DOWN, | ||||||
|  |                 cv2.MARKER_CROSS, | ||||||
|  |                 cv2.MARKER_STAR, | ||||||
|  |                 cv2.MARKER_TRIANGLE_DOWN, | ||||||
|  |                 cv2.MARKER_CROSS, | ||||||
|  |                 cv2.MARKER_STAR, | ||||||
|  |                 cv2.MARKER_CROSS, | ||||||
|  |                 cv2.MARKER_CROSS, | ||||||
|  |             ] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         self.joint_regions = np.array( | ||||||
|  |             [ | ||||||
|  |                 0, | ||||||
|  |                 0, | ||||||
|  |                 0, | ||||||
|  |                 0, | ||||||
|  |                 0, | ||||||
|  |                 0, | ||||||
|  |                 0, | ||||||
|  |                 1, | ||||||
|  |                 1, | ||||||
|  |                 1, | ||||||
|  |                 1, | ||||||
|  |                 2, | ||||||
|  |                 2, | ||||||
|  |                 2, | ||||||
|  |                 2, | ||||||
|  |                 3, | ||||||
|  |                 3, | ||||||
|  |                 4, | ||||||
|  |                 4, | ||||||
|  |                 4, | ||||||
|  |                 4, | ||||||
|  |                 5, | ||||||
|  |                 5, | ||||||
|  |                 5, | ||||||
|  |                 5, | ||||||
|  |                 6, | ||||||
|  |                 6, | ||||||
|  |                 6, | ||||||
|  |                 6, | ||||||
|  |                 6, | ||||||
|  |                 6, | ||||||
|  |                 6, | ||||||
|  |                 7, | ||||||
|  |                 7, | ||||||
|  |                 7, | ||||||
|  |                 8, | ||||||
|  |                 9, | ||||||
|  |             ] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         self.annotated_joint_region = self.joint_regions[self.annotated_classes] | ||||||
|  |         self.region_colors = np.array( | ||||||
|  |             [ | ||||||
|  |                 [250, 190, 190],  # body, light pink | ||||||
|  |                 [60, 180, 75],  # upper_right, green | ||||||
|  |                 [230, 25, 75],  # upper_left, red | ||||||
|  |                 [128, 0, 0],  # neck, maroon | ||||||
|  |                 [0, 130, 200],  # lower_right, blue | ||||||
|  |                 [255, 255, 25],  # lower_left, yellow | ||||||
|  |                 [240, 50, 230],  # tail, magenta | ||||||
|  |                 [245, 130, 48],  # jaw / nose / chin, orange | ||||||
|  |                 [29, 98, 115],  # right_ear, turquoise | ||||||
|  |                 [255, 153, 204], | ||||||
|  |             ] | ||||||
|  |         )  # left_ear, pink | ||||||
|  |  | ||||||
|  |         self.joint_colors = np.array(self.region_colors)[self.annotated_joint_region] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class BADJAData: | ||||||
|  |     def __init__(self, data_root, complete=False): | ||||||
|  |         annotations_path = os.path.join(data_root, "joint_annotations") | ||||||
|  |  | ||||||
|  |         self.animal_dict = {} | ||||||
|  |         self.animal_count = 0 | ||||||
|  |         self.smal_joint_info = SMALJointInfo() | ||||||
|  |         for __, animal_json in enumerate(sorted(os.listdir(annotations_path))): | ||||||
|  |             if animal_json not in IGNORE_ANIMALS: | ||||||
|  |                 json_path = os.path.join(annotations_path, animal_json) | ||||||
|  |                 with open(json_path) as json_data: | ||||||
|  |                     animal_joint_data = json.load(json_data) | ||||||
|  |  | ||||||
|  |                 filenames = [] | ||||||
|  |                 segnames = [] | ||||||
|  |                 joints = [] | ||||||
|  |                 visible = [] | ||||||
|  |  | ||||||
|  |                 first_path = animal_joint_data[0]["segmentation_path"] | ||||||
|  |                 last_path = animal_joint_data[-1]["segmentation_path"] | ||||||
|  |                 first_frame = first_path.split("/")[-1] | ||||||
|  |                 last_frame = last_path.split("/")[-1] | ||||||
|  |  | ||||||
|  |                 if not "extra_videos" in first_path: | ||||||
|  |                     animal = first_path.split("/")[-2] | ||||||
|  |  | ||||||
|  |                     first_frame_int = int(first_frame.split(".")[0]) | ||||||
|  |                     last_frame_int = int(last_frame.split(".")[0]) | ||||||
|  |  | ||||||
|  |                     for fr in range(first_frame_int, last_frame_int + 1): | ||||||
|  |                         ref_file_name = os.path.join( | ||||||
|  |                             data_root, | ||||||
|  |                             "DAVIS/JPEGImages/Full-Resolution/%s/%05d.jpg" | ||||||
|  |                             % (animal, fr), | ||||||
|  |                         ) | ||||||
|  |                         ref_seg_name = os.path.join( | ||||||
|  |                             data_root, | ||||||
|  |                             "DAVIS/Annotations/Full-Resolution/%s/%05d.png" | ||||||
|  |                             % (animal, fr), | ||||||
|  |                         ) | ||||||
|  |  | ||||||
|  |                         foundit = False | ||||||
|  |                         for ind, image_annotation in enumerate(animal_joint_data): | ||||||
|  |                             file_name = os.path.join( | ||||||
|  |                                 data_root, image_annotation["image_path"] | ||||||
|  |                             ) | ||||||
|  |                             seg_name = os.path.join( | ||||||
|  |                                 data_root, image_annotation["segmentation_path"] | ||||||
|  |                             ) | ||||||
|  |  | ||||||
|  |                             if file_name == ref_file_name: | ||||||
|  |                                 foundit = True | ||||||
|  |                                 label_ind = ind | ||||||
|  |  | ||||||
|  |                         if foundit: | ||||||
|  |                             image_annotation = animal_joint_data[label_ind] | ||||||
|  |                             file_name = os.path.join( | ||||||
|  |                                 data_root, image_annotation["image_path"] | ||||||
|  |                             ) | ||||||
|  |                             seg_name = os.path.join( | ||||||
|  |                                 data_root, image_annotation["segmentation_path"] | ||||||
|  |                             ) | ||||||
|  |                             joint = np.array(image_annotation["joints"]) | ||||||
|  |                             vis = np.array(image_annotation["visibility"]) | ||||||
|  |                         else: | ||||||
|  |                             file_name = ref_file_name | ||||||
|  |                             seg_name = ref_seg_name | ||||||
|  |                             joint = None | ||||||
|  |                             vis = None | ||||||
|  |  | ||||||
|  |                         filenames.append(file_name) | ||||||
|  |                         segnames.append(seg_name) | ||||||
|  |                         joints.append(joint) | ||||||
|  |                         visible.append(vis) | ||||||
|  |  | ||||||
|  |                 if len(filenames): | ||||||
|  |                     self.animal_dict[self.animal_count] = ( | ||||||
|  |                         filenames, | ||||||
|  |                         segnames, | ||||||
|  |                         joints, | ||||||
|  |                         visible, | ||||||
|  |                     ) | ||||||
|  |                     self.animal_count += 1 | ||||||
|  |         print("Loaded BADJA dataset") | ||||||
|  |  | ||||||
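|  |     # Random sampler: repeatedly picks a random annotated sequence and a | ||||||
|  |     # random frame from it, yielding the RGB frame, its segmentation mask | ||||||
|  |     # (resized to the frame resolution), the annotated joints, their | ||||||
|  |     # visibility flags, and the source image path. | ||||||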
|  |     def get_loader(self): | ||||||
|  |         for __ in range(int(1e6)): | ||||||
|  |             animal_id = np.random.choice(len(self.animal_dict.keys())) | ||||||
|  |             filenames, segnames, joints, visible = self.animal_dict[animal_id] | ||||||
|  |  | ||||||
|  |             image_id = np.random.randint(0, len(filenames)) | ||||||
|  |  | ||||||
|  |             seg_file = segnames[image_id] | ||||||
|  |             image_file = filenames[image_id] | ||||||
|  |  | ||||||
|  |             joints = joints[image_id].copy() | ||||||
|  |             joints = joints[self.smal_joint_info.annotated_classes] | ||||||
|  |             visible = visible[image_id][self.smal_joint_info.annotated_classes] | ||||||
|  |  | ||||||
|  |             rgb_img = imageio.imread(image_file)  # , mode='RGB') | ||||||
|  |             sil_img = imageio.imread(seg_file)  # , mode='RGB') | ||||||
|  |  | ||||||
|  |             rgb_h, rgb_w, _ = rgb_img.shape | ||||||
|  |             sil_img = cv2.resize(sil_img, (rgb_w, rgb_h), interpolation=cv2.INTER_NEAREST) | ||||||
|  |  | ||||||
|  |             yield rgb_img, sil_img, joints, visible, image_file | ||||||
|  |  | ||||||
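|  |     # Load one full sequence for a given animal: all frames, segmentations | ||||||
|  |     # resized to the frame resolution, and per-frame joints/visibility, which | ||||||
|  |     # stay None for frames without annotations. | ||||||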
|  |     def get_video(self, animal_id): | ||||||
|  |         filenames, segnames, joint, visible = self.animal_dict[animal_id] | ||||||
|  |  | ||||||
|  |         rgbs = [] | ||||||
|  |         segs = [] | ||||||
|  |         joints = [] | ||||||
|  |         visibles = [] | ||||||
|  |  | ||||||
|  |         for s in range(len(filenames)): | ||||||
|  |             image_file = filenames[s] | ||||||
|  |             rgb_img = imageio.imread(image_file)  # , mode='RGB') | ||||||
|  |             rgb_h, rgb_w, _ = rgb_img.shape | ||||||
|  |  | ||||||
|  |             seg_file = segnames[s] | ||||||
|  |             sil_img = imageio.imread(seg_file)  # , mode='RGB') | ||||||
|  |             sil_img = cv2.resize(sil_img, (rgb_w, rgb_h), interpolation=cv2.INTER_NEAREST) | ||||||
|  |  | ||||||
|  |             jo = joint[s] | ||||||
|  |  | ||||||
|  |             if jo is not None: | ||||||
|  |                 joi = joint[s].copy() | ||||||
|  |                 joi = joi[self.smal_joint_info.annotated_classes] | ||||||
|  |                 vis = visible[s][self.smal_joint_info.annotated_classes] | ||||||
|  |             else: | ||||||
|  |                 joi = None | ||||||
|  |                 vis = None | ||||||
|  |  | ||||||
|  |             rgbs.append(rgb_img) | ||||||
|  |             segs.append(sil_img) | ||||||
|  |             joints.append(joi) | ||||||
|  |             visibles.append(vis) | ||||||
|  |  | ||||||
|  |         return rgbs, segs, joints, visibles, filenames[0] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class BadjaDataset(torch.utils.data.Dataset): | ||||||
|  |     def __init__( | ||||||
|  |         self, data_root, max_seq_len=1000, dataset_resolution=(384, 512) | ||||||
|  |     ): | ||||||
|  |  | ||||||
|  |         self.data_root = data_root | ||||||
|  |         self.badja_data = BADJAData(data_root) | ||||||
|  |         self.max_seq_len = max_seq_len | ||||||
|  |         self.dataset_resolution = dataset_resolution | ||||||
|  |         print( | ||||||
|  |             "found %d unique videos in %s" | ||||||
|  |             % (self.badja_data.animal_count, self.data_root) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def __getitem__(self, index): | ||||||
|  |  | ||||||
|  |         rgbs, segs, joints, visibles, filename = self.badja_data.get_video(index) | ||||||
|  |         S = len(rgbs) | ||||||
|  |         H, W, __ = rgbs[0].shape | ||||||
|  |         H, W, __ = segs[0].shape | ||||||
|  |  | ||||||
|  |         N, __ = joints[0].shape | ||||||
|  |  | ||||||
|  |         # let's eliminate the Nones | ||||||
|  |         # note the first one is guaranteed present | ||||||
|  |         for s in range(1, S): | ||||||
|  |             if joints[s] is None: | ||||||
|  |                 joints[s] = np.zeros_like(joints[0]) | ||||||
|  |                 visibles[s] = np.zeros_like(visibles[0]) | ||||||
|  |  | ||||||
|  |         # eliminate the mystery dim | ||||||
|  |         segs = [seg[:, :, 0] for seg in segs] | ||||||
|  |  | ||||||
|  |         rgbs = np.stack(rgbs, 0) | ||||||
|  |         segs = np.stack(segs, 0) | ||||||
|  |         trajs = np.stack(joints, 0) | ||||||
|  |         visibles = np.stack(visibles, 0) | ||||||
|  |  | ||||||
|  |         rgbs = torch.from_numpy(rgbs).reshape(S, H, W, 3).permute(0, 3, 1, 2).float() | ||||||
|  |         segs = torch.from_numpy(segs).reshape(S, 1, H, W).float() | ||||||
|  |         trajs = torch.from_numpy(trajs).reshape(S, N, 2).float() | ||||||
|  |         visibles = torch.from_numpy(visibles).reshape(S, N) | ||||||
|  |  | ||||||
|  |         rgbs = rgbs[: self.max_seq_len] | ||||||
|  |         segs = segs[: self.max_seq_len] | ||||||
|  |         trajs = trajs[: self.max_seq_len] | ||||||
|  |         visibles = visibles[: self.max_seq_len] | ||||||
|  |         # apparently the coords are in yx order | ||||||
|  |         trajs = torch.flip(trajs, [2]) | ||||||
|  |  | ||||||
|  |         if "extra_videos" in filename: | ||||||
|  |             seq_name = filename.split("/")[-3] | ||||||
|  |         else: | ||||||
|  |             seq_name = filename.split("/")[-2] | ||||||
|  |  | ||||||
|  |         rgbs, trajs, segs = resize_sample(rgbs, trajs, segs, self.dataset_resolution) | ||||||
|  |  | ||||||
|  |         return CoTrackerData(rgbs, segs, trajs, visibles, seq_name=seq_name) | ||||||
|  |  | ||||||
|  |     def __len__(self): | ||||||
|  |         return self.badja_data.animal_count | ||||||
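A minimal usage sketch (not part of the commit) showing how BadjaDataset could be paired with the collate_fn defined later in this commit in cotracker/datasets/utils.py; the import path and data_root are assumptions.

import torch
from cotracker.datasets.utils import collate_fn
# assumed module path for this file; adjust to wherever BadjaDataset actually lives
from cotracker.datasets.badja_dataset import BadjaDataset

dataset = BadjaDataset(data_root="./data/badja", dataset_resolution=(384, 512))
loader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=collate_fn)
for sample in loader:
    # sample.video: (B, S, 3, H, W); sample.trajectory: (B, S, N, 2), xy order
    print(sample.seq_name, sample.video.shape, sample.trajectory.shape)
    break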
							
								
								
									
72  cotracker/datasets/fast_capture_dataset.py  Normal file
							| @@ -0,0 +1,72 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
|  |  | ||||||
|  | import os | ||||||
|  | import torch | ||||||
|  |  | ||||||
|  | # from PIL import Image | ||||||
|  | import imageio | ||||||
|  | import numpy as np | ||||||
|  | from cotracker.datasets.utils import CoTrackerData, resize_sample | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class FastCaptureDataset(torch.utils.data.Dataset): | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         data_root, | ||||||
|  |         max_seq_len=50, | ||||||
|  |         max_num_points=20, | ||||||
|  |         dataset_resolution=(384, 512), | ||||||
|  |     ): | ||||||
|  |  | ||||||
|  |         self.data_root = data_root | ||||||
|  |         self.seq_names = os.listdir(os.path.join(data_root, "renders_local_rm")) | ||||||
|  |         self.pth_dir = os.path.join(data_root, "zju_tracking") | ||||||
|  |         self.max_seq_len = max_seq_len | ||||||
|  |         self.max_num_points = max_num_points | ||||||
|  |         self.dataset_resolution = dataset_resolution | ||||||
|  |         print("found %d unique videos in %s" % (len(self.seq_names), self.data_root)) | ||||||
|  |  | ||||||
|  |     def __getitem__(self, index): | ||||||
|  |         seq_name = self.seq_names[index] | ||||||
|  |         spath = os.path.join(self.data_root, "renders_local_rm", seq_name) | ||||||
|  |         pthpath = os.path.join(self.pth_dir, seq_name + ".pth") | ||||||
|  |  | ||||||
|  |         rgbs = [] | ||||||
|  |         img_paths = sorted(os.listdir(spath)) | ||||||
|  |         for i, img_path in enumerate(img_paths): | ||||||
|  |             if i < self.max_seq_len: | ||||||
|  |                 rgbs.append(imageio.imread(os.path.join(spath, img_path))) | ||||||
|  |  | ||||||
|  |         annot_dict = torch.load(pthpath) | ||||||
|  |         traj_2d = annot_dict["traj_2d"][:, :, : self.max_seq_len] | ||||||
|  |         visibility = annot_dict["visibility"][:, : self.max_seq_len] | ||||||
|  |  | ||||||
|  |         S = len(rgbs) | ||||||
|  |         H, W, __ = rgbs[0].shape | ||||||
|  |         *_, S = traj_2d.shape | ||||||
|  |         visibile_pts_first_frame_inds = (visibility[:, 0] > 0).nonzero(as_tuple=False)[ | ||||||
|  |             :, 0 | ||||||
|  |         ] | ||||||
|  |         torch.manual_seed(0) | ||||||
|  |         point_inds = torch.randperm(len(visibile_pts_first_frame_inds))[ | ||||||
|  |             : self.max_num_points | ||||||
|  |         ] | ||||||
|  |         visible_inds_sampled = visibile_pts_first_frame_inds[point_inds] | ||||||
|  |  | ||||||
|  |         rgbs = np.stack(rgbs, 0) | ||||||
|  |         rgbs = torch.from_numpy(rgbs).reshape(S, H, W, 3).permute(0, 3, 1, 2).float() | ||||||
|  |  | ||||||
|  |         segs = torch.ones(S, 1, H, W).float() | ||||||
|  |         trajs = traj_2d[visible_inds_sampled].permute(2, 0, 1).float() | ||||||
|  |         visibles = visibility[visible_inds_sampled].permute(1, 0) | ||||||
|  |  | ||||||
|  |         rgbs, trajs, segs = resize_sample(rgbs, trajs, segs, self.dataset_resolution) | ||||||
|  |  | ||||||
|  |         return CoTrackerData(rgbs, segs, trajs, visibles, seq_name=seq_name) | ||||||
|  |  | ||||||
|  |     def __len__(self): | ||||||
|  |         return len(self.seq_names) | ||||||
							
								
								
									
494  cotracker/datasets/kubric_movif_dataset.py  Normal file
							| @@ -0,0 +1,494 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
|  |  | ||||||
|  | import os | ||||||
|  | import torch | ||||||
|  |  | ||||||
|  | import imageio | ||||||
|  | import numpy as np | ||||||
|  |  | ||||||
|  | from cotracker.datasets.utils import CoTrackerData | ||||||
|  | from torchvision.transforms import ColorJitter, GaussianBlur | ||||||
|  | from PIL import Image | ||||||
|  | import cv2 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class CoTrackerDataset(torch.utils.data.Dataset): | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         data_root, | ||||||
|  |         crop_size=(384, 512), | ||||||
|  |         seq_len=24, | ||||||
|  |         traj_per_sample=768, | ||||||
|  |         sample_vis_1st_frame=False, | ||||||
|  |         use_augs=False, | ||||||
|  |     ): | ||||||
|  |         super(CoTrackerDataset, self).__init__() | ||||||
|  |         np.random.seed(0) | ||||||
|  |         torch.manual_seed(0) | ||||||
|  |         self.data_root = data_root | ||||||
|  |         self.seq_len = seq_len | ||||||
|  |         self.traj_per_sample = traj_per_sample | ||||||
|  |         self.sample_vis_1st_frame = sample_vis_1st_frame | ||||||
|  |         self.use_augs = use_augs | ||||||
|  |         self.crop_size = crop_size | ||||||
|  |  | ||||||
|  |         # photometric augmentation | ||||||
|  |         self.photo_aug = ColorJitter( | ||||||
|  |             brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14 | ||||||
|  |         ) | ||||||
|  |         self.blur_aug = GaussianBlur(11, sigma=(0.1, 2.0)) | ||||||
|  |  | ||||||
|  |         self.blur_aug_prob = 0.25 | ||||||
|  |         self.color_aug_prob = 0.25 | ||||||
|  |  | ||||||
|  |         # occlusion augmentation | ||||||
|  |         self.eraser_aug_prob = 0.5 | ||||||
|  |         self.eraser_bounds = [2, 100] | ||||||
|  |         self.eraser_max = 10 | ||||||
|  |  | ||||||
|  |         # occlusion augmentation | ||||||
|  |         self.replace_aug_prob = 0.5 | ||||||
|  |         self.replace_bounds = [2, 100] | ||||||
|  |         self.replace_max = 10 | ||||||
|  |  | ||||||
|  |         # spatial augmentations | ||||||
|  |         self.pad_bounds = [0, 100] | ||||||
|  |         self.crop_size = crop_size | ||||||
|  |         self.resize_lim = [0.25, 2.0]  # sample resizes from here | ||||||
|  |         self.resize_delta = 0.2 | ||||||
|  |         self.max_crop_offset = 50 | ||||||
|  |  | ||||||
|  |         self.do_flip = True | ||||||
|  |         self.h_flip_prob = 0.5 | ||||||
|  |         self.v_flip_prob = 0.5 | ||||||
|  |  | ||||||
|  |     def getitem_helper(self, index): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def __getitem__(self, index): | ||||||
|  |         gotit = False | ||||||
|  |  | ||||||
|  |         sample, gotit = self.getitem_helper(index) | ||||||
|  |         if not gotit: | ||||||
|  |             print("warning: sampling failed") | ||||||
|  |             # fake sample, so we can still collate | ||||||
|  |             sample = CoTrackerData( | ||||||
|  |                 video=torch.zeros( | ||||||
|  |                     (self.seq_len, 3, self.crop_size[0], self.crop_size[1]) | ||||||
|  |                 ), | ||||||
|  |                 segmentation=torch.zeros( | ||||||
|  |                     (self.seq_len, 1, self.crop_size[0], self.crop_size[1]) | ||||||
|  |                 ), | ||||||
|  |                 trajectory=torch.zeros((self.seq_len, self.traj_per_sample, 2)), | ||||||
|  |                 visibility=torch.zeros((self.seq_len, self.traj_per_sample)), | ||||||
|  |                 valid=torch.zeros((self.seq_len, self.traj_per_sample)), | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         return sample, gotit | ||||||
|  |  | ||||||
|  |     def add_photometric_augs(self, rgbs, trajs, visibles, eraser=True, replace=True): | ||||||
|  |         T, N, _ = trajs.shape | ||||||
|  |  | ||||||
|  |         S = len(rgbs) | ||||||
|  |         H, W = rgbs[0].shape[:2] | ||||||
|  |         assert S == T | ||||||
|  |  | ||||||
|  |         if eraser: | ||||||
|  |             ############ eraser transform (per image after the first) ############ | ||||||
|  |             rgbs = [rgb.astype(np.float32) for rgb in rgbs] | ||||||
|  |             for i in range(1, S): | ||||||
|  |                 if np.random.rand() < self.eraser_aug_prob: | ||||||
|  |                     for _ in range( | ||||||
|  |                         np.random.randint(1, self.eraser_max + 1) | ||||||
|  |                     ):  # number of times to occlude | ||||||
|  |  | ||||||
|  |                         xc = np.random.randint(0, W) | ||||||
|  |                         yc = np.random.randint(0, H) | ||||||
|  |                         dx = np.random.randint( | ||||||
|  |                             self.eraser_bounds[0], self.eraser_bounds[1] | ||||||
|  |                         ) | ||||||
|  |                         dy = np.random.randint( | ||||||
|  |                             self.eraser_bounds[0], self.eraser_bounds[1] | ||||||
|  |                         ) | ||||||
|  |                         x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32) | ||||||
|  |                         x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32) | ||||||
|  |                         y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32) | ||||||
|  |                         y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32) | ||||||
|  |  | ||||||
|  |                         mean_color = np.mean( | ||||||
|  |                             rgbs[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0 | ||||||
|  |                         ) | ||||||
|  |                         rgbs[i][y0:y1, x0:x1, :] = mean_color | ||||||
|  |  | ||||||
|  |                         occ_inds = np.logical_and( | ||||||
|  |                             np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1), | ||||||
|  |                             np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1), | ||||||
|  |                         ) | ||||||
|  |                         visibles[i, occ_inds] = 0 | ||||||
|  |             rgbs = [rgb.astype(np.uint8) for rgb in rgbs] | ||||||
|  |  | ||||||
|  |         if replace: | ||||||
|  |  | ||||||
|  |             rgbs_alt = [ | ||||||
|  |                 np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) | ||||||
|  |                 for rgb in rgbs | ||||||
|  |             ] | ||||||
|  |             rgbs_alt = [ | ||||||
|  |                 np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) | ||||||
|  |                 for rgb in rgbs_alt | ||||||
|  |             ] | ||||||
|  |  | ||||||
|  |             ############ replace transform (per image after the first) ############ | ||||||
|  |             rgbs = [rgb.astype(np.float32) for rgb in rgbs] | ||||||
|  |             rgbs_alt = [rgb.astype(np.float32) for rgb in rgbs_alt] | ||||||
|  |             for i in range(1, S): | ||||||
|  |                 if np.random.rand() < self.replace_aug_prob: | ||||||
|  |                     for _ in range( | ||||||
|  |                         np.random.randint(1, self.replace_max + 1) | ||||||
|  |                     ):  # number of times to occlude | ||||||
|  |                         xc = np.random.randint(0, W) | ||||||
|  |                         yc = np.random.randint(0, H) | ||||||
|  |                         dx = np.random.randint( | ||||||
|  |                             self.replace_bounds[0], self.replace_bounds[1] | ||||||
|  |                         ) | ||||||
|  |                         dy = np.random.randint( | ||||||
|  |                             self.replace_bounds[0], self.replace_bounds[1] | ||||||
|  |                         ) | ||||||
|  |                         x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32) | ||||||
|  |                         x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32) | ||||||
|  |                         y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32) | ||||||
|  |                         y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32) | ||||||
|  |  | ||||||
|  |                         wid = x1 - x0 | ||||||
|  |                         hei = y1 - y0 | ||||||
|  |                         y00 = np.random.randint(0, H - hei) | ||||||
|  |                         x00 = np.random.randint(0, W - wid) | ||||||
|  |                         fr = np.random.randint(0, S) | ||||||
|  |                         rep = rgbs_alt[fr][y00 : y00 + hei, x00 : x00 + wid, :] | ||||||
|  |                         rgbs[i][y0:y1, x0:x1, :] = rep | ||||||
|  |  | ||||||
|  |                         occ_inds = np.logical_and( | ||||||
|  |                             np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1), | ||||||
|  |                             np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1), | ||||||
|  |                         ) | ||||||
|  |                         visibles[i, occ_inds] = 0 | ||||||
|  |             rgbs = [rgb.astype(np.uint8) for rgb in rgbs] | ||||||
|  |  | ||||||
|  |         ############ photometric augmentation ############ | ||||||
|  |         if np.random.rand() < self.color_aug_prob: | ||||||
|  |             # random per-frame amount of aug | ||||||
|  |             rgbs = [ | ||||||
|  |                 np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) | ||||||
|  |                 for rgb in rgbs | ||||||
|  |             ] | ||||||
|  |  | ||||||
|  |         if np.random.rand() < self.blur_aug_prob: | ||||||
|  |             # random per-frame amount of blur | ||||||
|  |             rgbs = [ | ||||||
|  |                 np.array(self.blur_aug(Image.fromarray(rgb)), dtype=np.uint8) | ||||||
|  |                 for rgb in rgbs | ||||||
|  |             ] | ||||||
|  |  | ||||||
|  |         return rgbs, trajs, visibles | ||||||
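The eraser and replace transforms above share one bookkeeping rule: any track point that falls inside a repainted rectangle is marked occluded for that frame. A standalone sketch of that step, for illustration only:

import numpy as np

def erase_patch(rgb, trajs_t, visibles_t, x0, y0, x1, y1):
    # rgb: (H, W, 3) float array; trajs_t: (N, 2) xy points for this frame;
    # visibles_t: (N,) visibility flags, updated in place.
    mean_color = rgb[y0:y1, x0:x1].reshape(-1, 3).mean(axis=0)
    rgb[y0:y1, x0:x1] = mean_color          # paint the patch with its mean color
    inside = (
        (trajs_t[:, 0] >= x0) & (trajs_t[:, 0] < x1)
        & (trajs_t[:, 1] >= y0) & (trajs_t[:, 1] < y1)
    )
    visibles_t[inside] = 0                  # points under the patch become occluded
    return rgb, visibles_t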
|  |  | ||||||
|  |     def add_spatial_augs(self, rgbs, trajs, visibles): | ||||||
|  |         T, N, __ = trajs.shape | ||||||
|  |  | ||||||
|  |         S = len(rgbs) | ||||||
|  |         H, W = rgbs[0].shape[:2] | ||||||
|  |         assert S == T | ||||||
|  |  | ||||||
|  |         rgbs = [rgb.astype(np.float32) for rgb in rgbs] | ||||||
|  |  | ||||||
|  |         ############ spatial transform ############ | ||||||
|  |  | ||||||
|  |         # padding | ||||||
|  |         pad_x0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1]) | ||||||
|  |         pad_x1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1]) | ||||||
|  |         pad_y0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1]) | ||||||
|  |         pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1]) | ||||||
|  |  | ||||||
|  |         rgbs = [ | ||||||
|  |             np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs | ||||||
|  |         ] | ||||||
|  |         trajs[:, :, 0] += pad_x0 | ||||||
|  |         trajs[:, :, 1] += pad_y0 | ||||||
|  |         H, W = rgbs[0].shape[:2] | ||||||
|  |  | ||||||
|  |         # scaling + stretching | ||||||
|  |         scale = np.random.uniform(self.resize_lim[0], self.resize_lim[1]) | ||||||
|  |         scale_x = scale | ||||||
|  |         scale_y = scale | ||||||
|  |         H_new = H | ||||||
|  |         W_new = W | ||||||
|  |  | ||||||
|  |         scale_delta_x = 0.0 | ||||||
|  |         scale_delta_y = 0.0 | ||||||
|  |  | ||||||
|  |         rgbs_scaled = [] | ||||||
|  |         for s in range(S): | ||||||
|  |             if s == 1: | ||||||
|  |                 scale_delta_x = np.random.uniform(-self.resize_delta, self.resize_delta) | ||||||
|  |                 scale_delta_y = np.random.uniform(-self.resize_delta, self.resize_delta) | ||||||
|  |             elif s > 1: | ||||||
|  |                 scale_delta_x = ( | ||||||
|  |                     scale_delta_x * 0.8 | ||||||
|  |                     + np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2 | ||||||
|  |                 ) | ||||||
|  |                 scale_delta_y = ( | ||||||
|  |                     scale_delta_y * 0.8 | ||||||
|  |                     + np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2 | ||||||
|  |                 ) | ||||||
|  |             scale_x = scale_x + scale_delta_x | ||||||
|  |             scale_y = scale_y + scale_delta_y | ||||||
|  |  | ||||||
|  |             # bring h/w closer | ||||||
|  |             scale_xy = (scale_x + scale_y) * 0.5 | ||||||
|  |             scale_x = scale_x * 0.5 + scale_xy * 0.5 | ||||||
|  |             scale_y = scale_y * 0.5 + scale_xy * 0.5 | ||||||
|  |  | ||||||
|  |             # don't get too crazy | ||||||
|  |             scale_x = np.clip(scale_x, 0.2, 2.0) | ||||||
|  |             scale_y = np.clip(scale_y, 0.2, 2.0) | ||||||
|  |  | ||||||
|  |             H_new = int(H * scale_y) | ||||||
|  |             W_new = int(W * scale_x) | ||||||
|  |  | ||||||
|  |             # make it at least slightly bigger than the crop area, | ||||||
|  |             # so that the random cropping can add diversity | ||||||
|  |             H_new = np.clip(H_new, self.crop_size[0] + 10, None) | ||||||
|  |             W_new = np.clip(W_new, self.crop_size[1] + 10, None) | ||||||
|  |             # recompute scale in case we clipped | ||||||
|  |             scale_x = W_new / float(W) | ||||||
|  |             scale_y = H_new / float(H) | ||||||
|  |  | ||||||
|  |             rgbs_scaled.append( | ||||||
|  |                 cv2.resize(rgbs[s], (W_new, H_new), interpolation=cv2.INTER_LINEAR) | ||||||
|  |             ) | ||||||
|  |             trajs[s, :, 0] *= scale_x | ||||||
|  |             trajs[s, :, 1] *= scale_y | ||||||
|  |         rgbs = rgbs_scaled | ||||||
|  |  | ||||||
|  |         ok_inds = visibles[0, :] > 0 | ||||||
|  |         vis_trajs = trajs[:, ok_inds]  # S,?,2 | ||||||
|  |  | ||||||
|  |         if vis_trajs.shape[1] > 0: | ||||||
|  |             mid_x = np.mean(vis_trajs[0, :, 0]) | ||||||
|  |             mid_y = np.mean(vis_trajs[0, :, 1]) | ||||||
|  |         else: | ||||||
|  |             mid_y = self.crop_size[0] | ||||||
|  |             mid_x = self.crop_size[1] | ||||||
|  |  | ||||||
|  |         x0 = int(mid_x - self.crop_size[1] // 2) | ||||||
|  |         y0 = int(mid_y - self.crop_size[0] // 2) | ||||||
|  |  | ||||||
|  |         offset_x = 0 | ||||||
|  |         offset_y = 0 | ||||||
|  |  | ||||||
|  |         for s in range(S): | ||||||
|  |             # on each frame, shift a bit more | ||||||
|  |             if s == 1: | ||||||
|  |                 offset_x = np.random.randint( | ||||||
|  |                     -self.max_crop_offset, self.max_crop_offset | ||||||
|  |                 ) | ||||||
|  |                 offset_y = np.random.randint( | ||||||
|  |                     -self.max_crop_offset, self.max_crop_offset | ||||||
|  |                 ) | ||||||
|  |             elif s > 1: | ||||||
|  |                 offset_x = int( | ||||||
|  |                     offset_x * 0.8 | ||||||
|  |                     + np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) | ||||||
|  |                     * 0.2 | ||||||
|  |                 ) | ||||||
|  |                 offset_y = int( | ||||||
|  |                     offset_y * 0.8 | ||||||
|  |                     + np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) | ||||||
|  |                     * 0.2 | ||||||
|  |                 ) | ||||||
|  |             x0 = x0 + offset_x | ||||||
|  |             y0 = y0 + offset_y | ||||||
|  |  | ||||||
|  |             H_new, W_new = rgbs[s].shape[:2] | ||||||
|  |             if H_new == self.crop_size[0]: | ||||||
|  |                 y0 = 0 | ||||||
|  |             else: | ||||||
|  |                 y0 = min(max(0, y0), H_new - self.crop_size[0] - 1) | ||||||
|  |  | ||||||
|  |             if W_new == self.crop_size[1]: | ||||||
|  |                 x0 = 0 | ||||||
|  |             else: | ||||||
|  |                 x0 = min(max(0, x0), W_new - self.crop_size[1] - 1) | ||||||
|  |  | ||||||
|  |             rgbs[s] = rgbs[s][y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] | ||||||
|  |             trajs[s, :, 0] -= x0 | ||||||
|  |             trajs[s, :, 1] -= y0 | ||||||
|  |  | ||||||
|  |         H_new = self.crop_size[0] | ||||||
|  |         W_new = self.crop_size[1] | ||||||
|  |  | ||||||
|  |         # flip | ||||||
|  |         h_flipped = False | ||||||
|  |         v_flipped = False | ||||||
|  |         if self.do_flip: | ||||||
|  |             # h flip | ||||||
|  |             if np.random.rand() < self.h_flip_prob: | ||||||
|  |                 h_flipped = True | ||||||
|  |                 rgbs = [rgb[:, ::-1] for rgb in rgbs] | ||||||
|  |             # v flip | ||||||
|  |             if np.random.rand() < self.v_flip_prob: | ||||||
|  |                 v_flipped = True | ||||||
|  |                 rgbs = [rgb[::-1] for rgb in rgbs] | ||||||
|  |         if h_flipped: | ||||||
|  |             trajs[:, :, 0] = W_new - trajs[:, :, 0] | ||||||
|  |         if v_flipped: | ||||||
|  |             trajs[:, :, 1] = H_new - trajs[:, :, 1] | ||||||
|  |  | ||||||
|  |         return rgbs, trajs | ||||||
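Each spatial augmentation above keeps the track coordinates consistent with the pixels; the flip at the end is the simplest case. A small sketch of the horizontal-flip bookkeeping, mirroring the logic above:

import numpy as np

def hflip_frames_and_tracks(rgbs, trajs, width):
    # rgbs: list of (H, W, 3) frames; trajs: (S, N, 2) xy coordinates.
    rgbs = [rgb[:, ::-1] for rgb in rgbs]    # reverse the columns of every frame
    trajs = trajs.copy()
    trajs[:, :, 0] = width - trajs[:, :, 0]  # reflect x about the frame width
    return rgbs, trajs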
|  |  | ||||||
|  |     def crop(self, rgbs, trajs): | ||||||
|  |         T, N, _ = trajs.shape | ||||||
|  |  | ||||||
|  |         S = len(rgbs) | ||||||
|  |         H, W = rgbs[0].shape[:2] | ||||||
|  |         assert S == T | ||||||
|  |  | ||||||
|  |         ############ spatial transform ############ | ||||||
|  |  | ||||||
|  |         H_new = H | ||||||
|  |         W_new = W | ||||||
|  |  | ||||||
|  |         # simple random crop | ||||||
|  |         y0 = ( | ||||||
|  |             0 | ||||||
|  |             if self.crop_size[0] >= H_new | ||||||
|  |             else np.random.randint(0, H_new - self.crop_size[0]) | ||||||
|  |         ) | ||||||
|  |         x0 = ( | ||||||
|  |             0 | ||||||
|  |             if self.crop_size[1] >= W_new | ||||||
|  |             else np.random.randint(0, W_new - self.crop_size[1]) | ||||||
|  |         ) | ||||||
|  |         rgbs = [ | ||||||
|  |             rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] | ||||||
|  |             for rgb in rgbs | ||||||
|  |         ] | ||||||
|  |  | ||||||
|  |         trajs[:, :, 0] -= x0 | ||||||
|  |         trajs[:, :, 1] -= y0 | ||||||
|  |  | ||||||
|  |         return rgbs, trajs | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class KubricMovifDataset(CoTrackerDataset): | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         data_root, | ||||||
|  |         crop_size=(384, 512), | ||||||
|  |         seq_len=24, | ||||||
|  |         traj_per_sample=768, | ||||||
|  |         sample_vis_1st_frame=False, | ||||||
|  |         use_augs=False, | ||||||
|  |     ): | ||||||
|  |         super(KubricMovifDataset, self).__init__( | ||||||
|  |             data_root=data_root, | ||||||
|  |             crop_size=crop_size, | ||||||
|  |             seq_len=seq_len, | ||||||
|  |             traj_per_sample=traj_per_sample, | ||||||
|  |             sample_vis_1st_frame=sample_vis_1st_frame, | ||||||
|  |             use_augs=use_augs, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         self.pad_bounds = [0, 25] | ||||||
|  |         self.resize_lim = [0.75, 1.25]  # sample resizes from here | ||||||
|  |         self.resize_delta = 0.05 | ||||||
|  |         self.max_crop_offset = 15 | ||||||
|  |         self.seq_names = [ | ||||||
|  |             fname | ||||||
|  |             for fname in os.listdir(data_root) | ||||||
|  |             if os.path.isdir(os.path.join(data_root, fname)) | ||||||
|  |         ] | ||||||
|  |         print("found %d unique videos in %s" % (len(self.seq_names), self.data_root)) | ||||||
|  |  | ||||||
|  |     def getitem_helper(self, index): | ||||||
|  |         gotit = True | ||||||
|  |         seq_name = self.seq_names[index] | ||||||
|  |  | ||||||
|  |         npy_path = os.path.join(self.data_root, seq_name, seq_name + ".npy") | ||||||
|  |         rgb_path = os.path.join(self.data_root, seq_name, "frames") | ||||||
|  |  | ||||||
|  |         img_paths = sorted(os.listdir(rgb_path)) | ||||||
|  |         rgbs = [] | ||||||
|  |         for i, img_path in enumerate(img_paths): | ||||||
|  |             rgbs.append(imageio.v2.imread(os.path.join(rgb_path, img_path))) | ||||||
|  |  | ||||||
|  |         rgbs = np.stack(rgbs) | ||||||
|  |         annot_dict = np.load(npy_path, allow_pickle=True).item() | ||||||
|  |         traj_2d = annot_dict["coords"] | ||||||
|  |         visibility = annot_dict["visibility"] | ||||||
|  |  | ||||||
|  |         # random crop | ||||||
|  |         assert self.seq_len <= len(rgbs) | ||||||
|  |         if self.seq_len < len(rgbs): | ||||||
|  |             start_ind = np.random.choice(len(rgbs) - self.seq_len, 1)[0] | ||||||
|  |  | ||||||
|  |             rgbs = rgbs[start_ind : start_ind + self.seq_len] | ||||||
|  |             traj_2d = traj_2d[:, start_ind : start_ind + self.seq_len] | ||||||
|  |             visibility = visibility[:, start_ind : start_ind + self.seq_len] | ||||||
|  |  | ||||||
|  |         traj_2d = np.transpose(traj_2d, (1, 0, 2)) | ||||||
|  |         visibility = np.transpose(np.logical_not(visibility), (1, 0)) | ||||||
|  |         if self.use_augs: | ||||||
|  |             rgbs, traj_2d, visibility = self.add_photometric_augs( | ||||||
|  |                 rgbs, traj_2d, visibility | ||||||
|  |             ) | ||||||
|  |             rgbs, traj_2d = self.add_spatial_augs(rgbs, traj_2d, visibility) | ||||||
|  |         else: | ||||||
|  |             rgbs, traj_2d = self.crop(rgbs, traj_2d) | ||||||
|  |  | ||||||
|  |         visibility[traj_2d[:, :, 0] > self.crop_size[1] - 1] = False | ||||||
|  |         visibility[traj_2d[:, :, 0] < 0] = False | ||||||
|  |         visibility[traj_2d[:, :, 1] > self.crop_size[0] - 1] = False | ||||||
|  |         visibility[traj_2d[:, :, 1] < 0] = False | ||||||
|  |  | ||||||
|  |         visibility = torch.from_numpy(visibility) | ||||||
|  |         traj_2d = torch.from_numpy(traj_2d) | ||||||
|  |  | ||||||
|  |         visibile_pts_first_frame_inds = (visibility[0]).nonzero(as_tuple=False)[:, 0] | ||||||
|  |  | ||||||
|  |         if self.sample_vis_1st_frame: | ||||||
|  |             visibile_pts_inds = visibile_pts_first_frame_inds | ||||||
|  |         else: | ||||||
|  |             visibile_pts_mid_frame_inds = (visibility[self.seq_len // 2]).nonzero( | ||||||
|  |                 as_tuple=False | ||||||
|  |             )[:, 0] | ||||||
|  |             visibile_pts_inds = torch.cat( | ||||||
|  |                 (visibile_pts_first_frame_inds, visibile_pts_mid_frame_inds), dim=0 | ||||||
|  |             ) | ||||||
|  |         point_inds = torch.randperm(len(visibile_pts_inds))[: self.traj_per_sample] | ||||||
|  |         if len(point_inds) < self.traj_per_sample: | ||||||
|  |             gotit = False | ||||||
|  |  | ||||||
|  |         visible_inds_sampled = visibile_pts_inds[point_inds] | ||||||
|  |  | ||||||
|  |         trajs = traj_2d[:, visible_inds_sampled].float() | ||||||
|  |         visibles = visibility[:, visible_inds_sampled] | ||||||
|  |         valids = torch.ones((self.seq_len, self.traj_per_sample)) | ||||||
|  |  | ||||||
|  |         rgbs = torch.from_numpy(np.stack(rgbs)).permute(0, 3, 1, 2).float() | ||||||
|  |         segs = torch.ones((self.seq_len, 1, self.crop_size[0], self.crop_size[1])) | ||||||
|  |         sample = CoTrackerData( | ||||||
|  |             video=rgbs, | ||||||
|  |             segmentation=segs, | ||||||
|  |             trajectory=trajs, | ||||||
|  |             visibility=visibles, | ||||||
|  |             valid=valids, | ||||||
|  |             seq_name=seq_name, | ||||||
|  |         ) | ||||||
|  |         return sample, gotit | ||||||
|  |  | ||||||
|  |     def __len__(self): | ||||||
|  |         return len(self.seq_names) | ||||||
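Because __getitem__ returns a (sample, gotit) pair, KubricMovifDataset pairs with collate_fn_train from cotracker/datasets/utils.py rather than collate_fn. A hypothetical training-loader sketch (the data_root path is a placeholder):

import torch
from cotracker.datasets.kubric_movif_dataset import KubricMovifDataset
from cotracker.datasets.utils import collate_fn_train

train_set = KubricMovifDataset(
    data_root="./data/kubric_movif",  # placeholder path
    seq_len=24,
    traj_per_sample=768,
    use_augs=True,
)
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=1, shuffle=True, collate_fn=collate_fn_train
)
batch, gotit = next(iter(train_loader))
print(batch.video.shape, batch.trajectory.shape, gotit)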
							
								
								
									
218  cotracker/datasets/tap_vid_datasets.py  Normal file
							| @@ -0,0 +1,218 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
|  |  | ||||||
|  | import os | ||||||
|  | import io | ||||||
|  | import glob | ||||||
|  | import torch | ||||||
|  | import pickle | ||||||
|  | import numpy as np | ||||||
|  | import mediapy as media | ||||||
|  |  | ||||||
|  | from PIL import Image | ||||||
|  | from typing import Mapping, Tuple, Union | ||||||
|  |  | ||||||
|  | from cotracker.datasets.utils import CoTrackerData | ||||||
|  |  | ||||||
|  | DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray: | ||||||
|  |     """Resize a video to output_size.""" | ||||||
|  |     # If you have a GPU, consider replacing this with a GPU-enabled resize op, | ||||||
|  |     # such as a jitted jax.image.resize.  It will make things faster. | ||||||
|  |     return media.resize_video(video, output_size) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def sample_queries_first( | ||||||
|  |     target_occluded: np.ndarray, | ||||||
|  |     target_points: np.ndarray, | ||||||
|  |     frames: np.ndarray, | ||||||
|  | ) -> Mapping[str, np.ndarray]: | ||||||
|  |     """Package a set of frames and tracks for use in TAPNet evaluations. | ||||||
|  |     Given a set of frames and tracks with no query points, use the first | ||||||
|  |     visible point in each track as the query. | ||||||
|  |     Args: | ||||||
|  |       target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames], | ||||||
|  |         where True indicates occluded. | ||||||
|  |       target_points: Position, of shape [n_tracks, n_frames, 2], where each point | ||||||
|  |         is [x,y] scaled between 0 and 1. | ||||||
|  |       frames: Video tensor, of shape [n_frames, height, width, 3].  Scaled between | ||||||
|  |         -1 and 1. | ||||||
|  |     Returns: | ||||||
|  |       A dict with the keys: | ||||||
|  |         video: Video tensor of shape [1, n_frames, height, width, 3] | ||||||
|  |         query_points: Query points of shape [1, n_queries, 3] where | ||||||
|  |           each point is [t, y, x] scaled to the range [-1, 1] | ||||||
|  |         target_points: Target points of shape [1, n_queries, n_frames, 2] where | ||||||
|  |           each point is [x, y] scaled to the range [-1, 1] | ||||||
|  |     """ | ||||||
|  |     valid = np.sum(~target_occluded, axis=1) > 0 | ||||||
|  |     target_points = target_points[valid, :] | ||||||
|  |     target_occluded = target_occluded[valid, :] | ||||||
|  |  | ||||||
|  |     query_points = [] | ||||||
|  |     for i in range(target_points.shape[0]): | ||||||
|  |         index = np.where(target_occluded[i] == 0)[0][0] | ||||||
|  |         x, y = target_points[i, index, 0], target_points[i, index, 1] | ||||||
|  |         query_points.append(np.array([index, y, x]))  # [t, y, x] | ||||||
|  |     query_points = np.stack(query_points, axis=0) | ||||||
|  |  | ||||||
|  |     return { | ||||||
|  |         "video": frames[np.newaxis, ...], | ||||||
|  |         "query_points": query_points[np.newaxis, ...], | ||||||
|  |         "target_points": target_points[np.newaxis, ...], | ||||||
|  |         "occluded": target_occluded[np.newaxis, ...], | ||||||
|  |     } | ||||||
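A tiny synthetic check (illustration only) of what sample_queries_first returns: with two tracks over three frames, the first visible frame of each track becomes its [t, y, x] query.

import numpy as np
from cotracker.datasets.tap_vid_datasets import sample_queries_first

occluded = np.array([[True, False, False],
                     [False, False, True]])           # (n_tracks, n_frames)
points = np.random.rand(2, 3, 2).astype(np.float32)   # (n_tracks, n_frames, 2), [x, y]
frames = np.zeros((3, 8, 8, 3), dtype=np.float32)     # (n_frames, H, W, 3)

out = sample_queries_first(occluded, points, frames)
print(out["query_points"].shape)   # (1, 2, 3): one [t, y, x] query per track
print(out["target_points"].shape)  # (1, 2, 3, 2)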
|  |  | ||||||
|  |  | ||||||
|  | def sample_queries_strided( | ||||||
|  |     target_occluded: np.ndarray, | ||||||
|  |     target_points: np.ndarray, | ||||||
|  |     frames: np.ndarray, | ||||||
|  |     query_stride: int = 5, | ||||||
|  | ) -> Mapping[str, np.ndarray]: | ||||||
|  |     """Package a set of frames and tracks for use in TAPNet evaluations. | ||||||
|  |  | ||||||
|  |     Given a set of frames and tracks with no query points, sample queries | ||||||
|  |     strided every query_stride frames, ignoring points that are not visible | ||||||
|  |     at the selected frames. | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |       target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames], | ||||||
|  |         where True indicates occluded. | ||||||
|  |       target_points: Position, of shape [n_tracks, n_frames, 2], where each point | ||||||
|  |         is [x,y] scaled between 0 and 1. | ||||||
|  |       frames: Video tensor, of shape [n_frames, height, width, 3].  Scaled between | ||||||
|  |         -1 and 1. | ||||||
|  |       query_stride: When sampling query points, search for un-occluded points | ||||||
|  |         every query_stride frames and convert each one into a query. | ||||||
|  |  | ||||||
|  |     Returns: | ||||||
|  |       A dict with the keys: | ||||||
|  |         video: Video tensor of shape [1, n_frames, height, width, 3].  The video | ||||||
|  |           has floats scaled to the range [-1, 1]. | ||||||
|  |         query_points: Query points of shape [1, n_queries, 3] where | ||||||
|  |           each point is [t, y, x] scaled to the range [-1, 1]. | ||||||
|  |         target_points: Target points of shape [1, n_queries, n_frames, 2] where | ||||||
|  |           each point is [x, y] scaled to the range [-1, 1]. | ||||||
|  |         trackgroup: Index of the original track that each query point was | ||||||
|  |           sampled from.  This is useful for visualization. | ||||||
|  |     """ | ||||||
|  |     tracks = [] | ||||||
|  |     occs = [] | ||||||
|  |     queries = [] | ||||||
|  |     trackgroups = [] | ||||||
|  |     total = 0 | ||||||
|  |     trackgroup = np.arange(target_occluded.shape[0]) | ||||||
|  |     for i in range(0, target_occluded.shape[1], query_stride): | ||||||
|  |         mask = target_occluded[:, i] == 0 | ||||||
|  |         query = np.stack( | ||||||
|  |             [ | ||||||
|  |                 i * np.ones(target_occluded.shape[0:1]), | ||||||
|  |                 target_points[:, i, 1], | ||||||
|  |                 target_points[:, i, 0], | ||||||
|  |             ], | ||||||
|  |             axis=-1, | ||||||
|  |         ) | ||||||
|  |         queries.append(query[mask]) | ||||||
|  |         tracks.append(target_points[mask]) | ||||||
|  |         occs.append(target_occluded[mask]) | ||||||
|  |         trackgroups.append(trackgroup[mask]) | ||||||
|  |         total += np.array(np.sum(target_occluded[:, i] == 0)) | ||||||
|  |  | ||||||
|  |     return { | ||||||
|  |         "video": frames[np.newaxis, ...], | ||||||
|  |         "query_points": np.concatenate(queries, axis=0)[np.newaxis, ...], | ||||||
|  |         "target_points": np.concatenate(tracks, axis=0)[np.newaxis, ...], | ||||||
|  |         "occluded": np.concatenate(occs, axis=0)[np.newaxis, ...], | ||||||
|  |         "trackgroup": np.concatenate(trackgroups, axis=0)[np.newaxis, ...], | ||||||
|  |     } | ||||||
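In contrast to the first-frame variant, strided sampling turns every visible (track, frame) pair at frames 0, query_stride, 2*query_stride, ... into a query, so the number of queries can exceed the number of tracks. Illustration only:

import numpy as np
from cotracker.datasets.tap_vid_datasets import sample_queries_strided

occluded = np.zeros((2, 6), dtype=bool)               # two tracks, six frames, all visible
points = np.random.rand(2, 6, 2).astype(np.float32)
frames = np.zeros((6, 8, 8, 3), dtype=np.float32)

out = sample_queries_strided(occluded, points, frames, query_stride=5)
# queries are drawn at frames 0 and 5 -> 2 tracks x 2 frames = 4 queries
print(out["query_points"].shape)  # (1, 4, 3)
print(out["trackgroup"][0])       # original track index of each query: [0 1 0 1]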
|  |  | ||||||
|  |  | ||||||
|  | class TapVidDataset(torch.utils.data.Dataset): | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         data_root, | ||||||
|  |         dataset_type="davis", | ||||||
|  |         resize_to_256=True, | ||||||
|  |         queried_first=True, | ||||||
|  |     ): | ||||||
|  |         self.dataset_type = dataset_type | ||||||
|  |         self.resize_to_256 = resize_to_256 | ||||||
|  |         self.queried_first = queried_first | ||||||
|  |         if self.dataset_type == "kinetics": | ||||||
|  |             all_paths = glob.glob(os.path.join(data_root, "*_of_0010.pkl")) | ||||||
|  |             points_dataset = [] | ||||||
|  |             for pickle_path in all_paths: | ||||||
|  |                 with open(pickle_path, "rb") as f: | ||||||
|  |                     data = pickle.load(f) | ||||||
|  |                     points_dataset = points_dataset + data | ||||||
|  |             self.points_dataset = points_dataset | ||||||
|  |         else: | ||||||
|  |             with open(data_root, "rb") as f: | ||||||
|  |                 self.points_dataset = pickle.load(f) | ||||||
|  |             if self.dataset_type == "davis": | ||||||
|  |                 self.video_names = list(self.points_dataset.keys()) | ||||||
|  |         print("found %d unique videos in %s" % (len(self.points_dataset), data_root)) | ||||||
|  |  | ||||||
|  |     def __getitem__(self, index): | ||||||
|  |         if self.dataset_type == "davis": | ||||||
|  |             video_name = self.video_names[index] | ||||||
|  |         else: | ||||||
|  |             video_name = index | ||||||
|  |         video = self.points_dataset[video_name] | ||||||
|  |         frames = video["video"] | ||||||
|  |  | ||||||
|  |         if isinstance(frames[0], bytes): | ||||||
|  |             # TAP-Vid is stored as JPEG bytes rather than `np.ndarray`s. | ||||||
|  |             def decode(frame): | ||||||
|  |                 byteio = io.BytesIO(frame) | ||||||
|  |                 img = Image.open(byteio) | ||||||
|  |                 return np.array(img) | ||||||
|  |  | ||||||
|  |             frames = np.array([decode(frame) for frame in frames]) | ||||||
|  |  | ||||||
|  |         target_points = self.points_dataset[video_name]["points"] | ||||||
|  |         if self.resize_to_256: | ||||||
|  |             frames = resize_video(frames, [256, 256]) | ||||||
|  |             target_points *= np.array([256, 256]) | ||||||
|  |         else: | ||||||
|  |             target_points *= np.array([frames.shape[2], frames.shape[1]]) | ||||||
|  |  | ||||||
|  |         T, H, W, C = frames.shape | ||||||
|  |         N, T, D = target_points.shape | ||||||
|  |  | ||||||
|  |         target_occ = self.points_dataset[video_name]["occluded"] | ||||||
|  |         if self.queried_first: | ||||||
|  |             converted = sample_queries_first(target_occ, target_points, frames) | ||||||
|  |         else: | ||||||
|  |             converted = sample_queries_strided(target_occ, target_points, frames) | ||||||
|  |         assert converted["target_points"].shape[1] == converted["query_points"].shape[1] | ||||||
|  |  | ||||||
|  |         trajs = ( | ||||||
|  |             torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float() | ||||||
|  |         )  # T, N, D | ||||||
|  |  | ||||||
|  |         rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float() | ||||||
|  |         segs = torch.ones(T, 1, H, W).float() | ||||||
|  |         visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[ | ||||||
|  |             0 | ||||||
|  |         ].permute( | ||||||
|  |             1, 0 | ||||||
|  |         )  # T, N | ||||||
|  |         query_points = torch.from_numpy(converted["query_points"])[0]  # T, N | ||||||
|  |         return CoTrackerData( | ||||||
|  |             rgbs, | ||||||
|  |             segs, | ||||||
|  |             trajs, | ||||||
|  |             visibles, | ||||||
|  |             seq_name=str(video_name), | ||||||
|  |             query_points=query_points, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def __len__(self): | ||||||
|  |         return len(self.points_dataset) | ||||||
							
								
								
									
114  cotracker/datasets/utils.py  Normal file
							| @@ -0,0 +1,114 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
|  |  | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | import dataclasses | ||||||
|  | import torch.nn.functional as F | ||||||
|  | from dataclasses import dataclass | ||||||
|  | from typing import Any, Optional | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @dataclass(eq=False) | ||||||
|  | class CoTrackerData: | ||||||
|  |     """ | ||||||
|  |     Dataclass for storing video tracks data. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     video: torch.Tensor  # B, S, C, H, W | ||||||
|  |     segmentation: torch.Tensor  # B, S, 1, H, W | ||||||
|  |     trajectory: torch.Tensor  # B, S, N, 2 | ||||||
|  |     visibility: torch.Tensor  # B, S, N | ||||||
|  |     # optional data | ||||||
|  |     valid: Optional[torch.Tensor] = None  # B, S, N | ||||||
|  |     seq_name: Optional[str] = None | ||||||
|  |     query_points: Optional[torch.Tensor] = None  # TapVID evaluation format | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def collate_fn(batch): | ||||||
|  |     """ | ||||||
|  |     Collate function for video tracks data. | ||||||
|  |     """ | ||||||
|  |     video = torch.stack([b.video for b in batch], dim=0) | ||||||
|  |     segmentation = torch.stack([b.segmentation for b in batch], dim=0) | ||||||
|  |     trajectory = torch.stack([b.trajectory for b in batch], dim=0) | ||||||
|  |     visibility = torch.stack([b.visibility for b in batch], dim=0) | ||||||
|  |     query_points = None | ||||||
|  |     if batch[0].query_points is not None: | ||||||
|  |         query_points = torch.stack([b.query_points for b in batch], dim=0) | ||||||
|  |     seq_name = [b.seq_name for b in batch] | ||||||
|  |  | ||||||
|  |     return CoTrackerData( | ||||||
|  |         video, | ||||||
|  |         segmentation, | ||||||
|  |         trajectory, | ||||||
|  |         visibility, | ||||||
|  |         seq_name=seq_name, | ||||||
|  |         query_points=query_points, | ||||||
|  |     ) | ||||||
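collate_fn simply stacks each per-sample tensor along a new leading batch dimension, so every sample in a batch must share S, H, W and N. A self-contained sketch (illustration only):

import torch
from cotracker.datasets.utils import CoTrackerData, collate_fn

S, H, W, N = 4, 32, 48, 5
def dummy_sample(name):
    return CoTrackerData(
        video=torch.zeros(S, 3, H, W),
        segmentation=torch.ones(S, 1, H, W),
        trajectory=torch.zeros(S, N, 2),
        visibility=torch.ones(S, N),
        seq_name=name,
    )

batch = collate_fn([dummy_sample("a"), dummy_sample("b")])
print(batch.video.shape)       # torch.Size([2, 4, 3, 32, 48])
print(batch.trajectory.shape)  # torch.Size([2, 4, 5, 2])
print(batch.seq_name)          # ['a', 'b']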
|  |  | ||||||
|  |  | ||||||
|  | def collate_fn_train(batch): | ||||||
|  |     """ | ||||||
|  |     Collate function for video tracks data during training. | ||||||
|  |     """ | ||||||
|  |     gotit = [gotit for _, gotit in batch] | ||||||
|  |     video = torch.stack([b.video for b, _ in batch], dim=0) | ||||||
|  |     segmentation = torch.stack([b.segmentation for b, _ in batch], dim=0) | ||||||
|  |     trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0) | ||||||
|  |     visibility = torch.stack([b.visibility for b, _ in batch], dim=0) | ||||||
|  |     valid = torch.stack([b.valid for b, _ in batch], dim=0) | ||||||
|  |     seq_name = [b.seq_name for b, _ in batch] | ||||||
|  |     return ( | ||||||
|  |         CoTrackerData(video, segmentation, trajectory, visibility, valid, seq_name), | ||||||
|  |         gotit, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def try_to_cuda(t: Any) -> Any: | ||||||
|  |     """ | ||||||
|  |     Try to move the input variable `t` to a cuda device. | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         t: Input. | ||||||
|  |  | ||||||
|  |     Returns: | ||||||
|  |         t_cuda: `t` moved to a cuda device, if supported. | ||||||
|  |     """ | ||||||
|  |     try: | ||||||
|  |         t = t.float().cuda() | ||||||
|  |     except AttributeError: | ||||||
|  |         pass | ||||||
|  |     return t | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def dataclass_to_cuda_(obj): | ||||||
|  |     """ | ||||||
|  |     Move all contents of a dataclass to cuda inplace if supported. | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         batch: Input dataclass. | ||||||
|  |  | ||||||
|  |     Returns: | ||||||
|  |         batch_cuda: `batch` moved to a cuda device, if supported. | ||||||
|  |     """ | ||||||
|  |     for f in dataclasses.fields(obj): | ||||||
|  |         setattr(obj, f.name, try_to_cuda(getattr(obj, f.name))) | ||||||
|  |     return obj | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def resize_sample(rgbs, trajs_g, segs, interp_shape): | ||||||
|  |     S, C, H, W = rgbs.shape | ||||||
|  |     S, N, D = trajs_g.shape | ||||||
|  |  | ||||||
|  |     assert D == 2 | ||||||
|  |  | ||||||
|  |     rgbs = F.interpolate(rgbs, interp_shape, mode="bilinear") | ||||||
|  |     segs = F.interpolate(segs, interp_shape, mode="nearest") | ||||||
|  |  | ||||||
|  |     trajs_g[:, :, 0] *= interp_shape[1] / W | ||||||
|  |     trajs_g[:, :, 1] *= interp_shape[0] / H | ||||||
|  |     return rgbs, trajs_g, segs | ||||||
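resize_sample rescales the track coordinates by the same factors it applies to the frames, and it edits trajs_g in place. A quick numeric check (illustration only):

import torch
from cotracker.datasets.utils import resize_sample

rgbs = torch.zeros(4, 3, 128, 256)     # S, C, H, W
segs = torch.ones(4, 1, 128, 256)
trajs = torch.tensor([[[256.0, 128.0]]]).repeat(4, 1, 1)  # S, N=1, (x, y)

rgbs_r, trajs_r, segs_r = resize_sample(rgbs, trajs, segs, interp_shape=(64, 128))
print(rgbs_r.shape)   # torch.Size([4, 3, 64, 128])
print(trajs_r[0, 0])  # tensor([128., 64.]) -- both coordinates halved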
							
								
								
									
5  cotracker/evaluation/__init__.py  Normal file
							| @@ -0,0 +1,5 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
							
								
								
									
6  cotracker/evaluation/configs/eval_badja.yaml  Normal file
							| @@ -0,0 +1,6 @@ | |||||||
|  | defaults: | ||||||
|  |   - default_config_eval | ||||||
|  | exp_dir: ./outputs/cotracker | ||||||
|  | dataset_name: badja | ||||||
|  |  | ||||||
|  |     | ||||||
							
								
								
									
6  cotracker/evaluation/configs/eval_fastcapture.yaml  Normal file
							| @@ -0,0 +1,6 @@ | |||||||
|  | defaults: | ||||||
|  |   - default_config_eval | ||||||
|  | exp_dir: ./outputs/cotracker | ||||||
|  | dataset_name: fastcapture | ||||||
|  |  | ||||||
|  |     | ||||||
| @@ -0,0 +1,6 @@ | |||||||
|  | defaults: | ||||||
|  |   - default_config_eval | ||||||
|  | exp_dir: ./outputs/cotracker | ||||||
|  | dataset_name: tapvid_davis_first | ||||||
|  |  | ||||||
|  |     | ||||||
| @@ -0,0 +1,6 @@ | |||||||
|  | defaults: | ||||||
|  |   - default_config_eval | ||||||
|  | exp_dir: ./outputs/cotracker | ||||||
|  | dataset_name: tapvid_davis_strided | ||||||
|  |  | ||||||
|  |     | ||||||
| @@ -0,0 +1,6 @@ | |||||||
|  | defaults: | ||||||
|  |   - default_config_eval | ||||||
|  | exp_dir: ./outputs/cotracker | ||||||
|  | dataset_name: tapvid_kinetics_first | ||||||
|  |  | ||||||
|  |     | ||||||
							
								
								
									
5  cotracker/evaluation/core/__init__.py  Normal file
							| @@ -0,0 +1,5 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
							
								
								
									
144  cotracker/evaluation/core/eval_utils.py  Normal file
							| @@ -0,0 +1,144 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
|  |  | ||||||
|  | import numpy as np | ||||||
|  |  | ||||||
|  | from typing import Iterable, Mapping, Tuple, Union | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def compute_tapvid_metrics( | ||||||
|  |     query_points: np.ndarray, | ||||||
|  |     gt_occluded: np.ndarray, | ||||||
|  |     gt_tracks: np.ndarray, | ||||||
|  |     pred_occluded: np.ndarray, | ||||||
|  |     pred_tracks: np.ndarray, | ||||||
|  |     query_mode: str, | ||||||
|  | ) -> Mapping[str, np.ndarray]: | ||||||
|  |     """Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.) | ||||||
|  |     See the TAP-Vid paper for details on the metric computation.  All inputs are | ||||||
|  |     given in raster coordinates.  The first three arguments should be the direct | ||||||
|  |     outputs of the reader: the 'query_points', 'occluded', and 'target_points'. | ||||||
|  |     The paper metrics assume these are scaled relative to 256x256 images. | ||||||
|  |     pred_occluded and pred_tracks are your algorithm's predictions. | ||||||
|  |     This function takes a batch of inputs, and computes metrics separately for | ||||||
|  |     each video.  The metrics for the full benchmark are a simple mean of the | ||||||
|  |     metrics across the full set of videos.  These numbers are between 0 and 1, | ||||||
|  |     but the paper multiplies them by 100 to ease reading. | ||||||
|  |     Args: | ||||||
|  |        query_points: The query points, in the format [t, y, x].  Its size is | ||||||
|  |          [b, n, 3], where b is the batch size and n is the number of queries. | ||||||
|  |        gt_occluded: A boolean array of shape [b, n, t], where t is the number | ||||||
|  |          of frames.  True indicates that the point is occluded. | ||||||
|  |        gt_tracks: The target points, of shape [b, n, t, 2].  Each point is | ||||||
|  |          in the format [x, y] | ||||||
|  |        pred_occluded: A boolean array of predicted occlusions, in the same | ||||||
|  |          format as gt_occluded. | ||||||
|  |        pred_tracks: An array of track predictions from your algorithm, in the | ||||||
|  |          same format as gt_tracks. | ||||||
|  |        query_mode: Either 'first' or 'strided', depending on how queries are | ||||||
|  |          sampled.  If 'first', we assume the prior knowledge that all points | ||||||
|  |          before the query point are occluded, and these are removed from the | ||||||
|  |          evaluation. | ||||||
|  |     Returns: | ||||||
|  |         A dict with the following keys: | ||||||
|  |         occlusion_accuracy: Accuracy at predicting occlusion. | ||||||
|  |         pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points | ||||||
|  |           predicted to be within the given pixel threshold, ignoring occlusion | ||||||
|  |           prediction. | ||||||
|  |         jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given | ||||||
|  |           threshold | ||||||
|  |         average_pts_within_thresh: average across pts_within_{x} | ||||||
|  |         average_jaccard: average across jaccard_{x} | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     metrics = {} | ||||||
|  |  | ||||||
|  |     # Don't evaluate the query point.  Numpy doesn't have one_hot, so we | ||||||
|  |     # replicate it by indexing into an identity matrix. | ||||||
|  |     one_hot_eye = np.eye(gt_tracks.shape[2]) | ||||||
|  |     query_frame = query_points[..., 0] | ||||||
|  |     query_frame = np.round(query_frame).astype(np.int32) | ||||||
|  |     evaluation_points = one_hot_eye[query_frame] == 0 | ||||||
|  |  | ||||||
|  |     # If we're using the first point on the track as a query, don't evaluate the | ||||||
|  |     # other points. | ||||||
|  |     if query_mode == "first": | ||||||
|  |         for i in range(gt_occluded.shape[0]): | ||||||
|  |             index = np.where(gt_occluded[i] == 0)[0][0] | ||||||
|  |             evaluation_points[i, :index] = False | ||||||
|  |     elif query_mode != "strided": | ||||||
|  |         raise ValueError("Unknown query mode " + query_mode) | ||||||
|  |  | ||||||
|  |     # Occlusion accuracy is simply how often the predicted occlusion equals the | ||||||
|  |     # ground truth. | ||||||
|  |     occ_acc = np.sum( | ||||||
|  |         np.equal(pred_occluded, gt_occluded) & evaluation_points, | ||||||
|  |         axis=(1, 2), | ||||||
|  |     ) / np.sum(evaluation_points, axis=(1, 2)) | ||||||
|  |     metrics["occlusion_accuracy"] = occ_acc | ||||||
|  |  | ||||||
|  |     # Next, convert the predictions and ground truth positions into pixel | ||||||
|  |     # coordinates. | ||||||
|  |     visible = np.logical_not(gt_occluded) | ||||||
|  |     pred_visible = np.logical_not(pred_occluded) | ||||||
|  |     all_frac_within = [] | ||||||
|  |     all_jaccard = [] | ||||||
|  |     for thresh in [1, 2, 4, 8, 16]: | ||||||
|  |         # True positives are points that are within the threshold and where both | ||||||
|  |         # the prediction and the ground truth are listed as visible. | ||||||
|  |         within_dist = ( | ||||||
|  |             np.sum( | ||||||
|  |                 np.square(pred_tracks - gt_tracks), | ||||||
|  |                 axis=-1, | ||||||
|  |             ) | ||||||
|  |             < np.square(thresh) | ||||||
|  |         ) | ||||||
|  |         is_correct = np.logical_and(within_dist, visible) | ||||||
|  |  | ||||||
|  |         # Compute the frac_within_threshold, which is the fraction of points | ||||||
|  |         # within the threshold among points that are visible in the ground truth, | ||||||
|  |         # ignoring whether they're predicted to be visible. | ||||||
|  |         count_correct = np.sum( | ||||||
|  |             is_correct & evaluation_points, | ||||||
|  |             axis=(1, 2), | ||||||
|  |         ) | ||||||
|  |         count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2)) | ||||||
|  |         frac_correct = count_correct / count_visible_points | ||||||
|  |         metrics["pts_within_" + str(thresh)] = frac_correct | ||||||
|  |         all_frac_within.append(frac_correct) | ||||||
|  |  | ||||||
|  |         true_positives = np.sum( | ||||||
|  |             is_correct & pred_visible & evaluation_points, axis=(1, 2) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         # The denominator of the jaccard metric is the true positives plus | ||||||
|  |         # false positives plus false negatives.  However, note that true positives | ||||||
|  |         # plus false negatives is simply the number of points in the ground truth | ||||||
|  |         # which is easier to compute than trying to compute all three quantities. | ||||||
|  |         # Thus we just add the number of points in the ground truth to the number | ||||||
|  |         # of false positives. | ||||||
|  |         # | ||||||
|  |         # False positives are simply points that are predicted to be visible, | ||||||
|  |         # but the ground truth is not visible or too far from the prediction. | ||||||
|  |         gt_positives = np.sum(visible & evaluation_points, axis=(1, 2)) | ||||||
|  |         false_positives = (~visible) & pred_visible | ||||||
|  |         false_positives = false_positives | ((~within_dist) & pred_visible) | ||||||
|  |         false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2)) | ||||||
|  |         jaccard = true_positives / (gt_positives + false_positives) | ||||||
|  |         metrics["jaccard_" + str(thresh)] = jaccard | ||||||
|  |         all_jaccard.append(jaccard) | ||||||
|  |     metrics["average_jaccard"] = np.mean( | ||||||
|  |         np.stack(all_jaccard, axis=1), | ||||||
|  |         axis=1, | ||||||
|  |     ) | ||||||
|  |     metrics["average_pts_within_thresh"] = np.mean( | ||||||
|  |         np.stack(all_frac_within, axis=1), | ||||||
|  |         axis=1, | ||||||
|  |     ) | ||||||
|  |     return metrics | ||||||
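A minimal usage sketch for compute_tapvid_metrics, with synthetic arrays whose shapes follow the docstring (a batch of one video, five queries at frame 0, fully visible ground truth; all sizes and values are illustrative only):

    import numpy as np
    from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics

    b, n, t = 1, 5, 24
    query_points = np.zeros((b, n, 3))                      # every query at frame 0, position (0, 0)
    gt_occluded = np.zeros((b, n, t), dtype=bool)           # ground truth: always visible
    gt_tracks = np.random.rand(b, n, t, 2) * 256.0          # [x, y] targets in pixels
    pred_occluded = gt_occluded.copy()
    pred_tracks = gt_tracks + np.random.randn(b, n, t, 2)   # predictions with ~1 px noise
    metrics = compute_tapvid_metrics(
        query_points, gt_occluded, gt_tracks,
        pred_occluded, pred_tracks, query_mode="first",
    )
    print(metrics["occlusion_accuracy"], metrics["average_jaccard"])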
							
								
								
									
252  cotracker/evaluation/core/evaluator.py  Normal file
							| @@ -0,0 +1,252 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
|  |  | ||||||
|  | from collections import defaultdict | ||||||
|  | import os | ||||||
|  | from typing import Optional | ||||||
|  | import torch | ||||||
|  | from tqdm import tqdm | ||||||
|  | import numpy as np | ||||||
|  |  | ||||||
|  | from torch.utils.tensorboard import SummaryWriter | ||||||
|  | from cotracker.datasets.utils import dataclass_to_cuda_ | ||||||
|  | from cotracker.utils.visualizer import Visualizer | ||||||
|  | from cotracker.models.core.model_utils import reduce_masked_mean | ||||||
|  | from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics | ||||||
|  |  | ||||||
|  | import logging | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Evaluator: | ||||||
|  |     """ | ||||||
|  |     A class defining the CoTracker evaluator. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, exp_dir) -> None: | ||||||
|  |         # Visualization | ||||||
|  |         self.exp_dir = exp_dir | ||||||
|  |         os.makedirs(exp_dir, exist_ok=True) | ||||||
|  |         self.visualization_filepaths = defaultdict(lambda: defaultdict(list)) | ||||||
|  |         self.visualize_dir = os.path.join(exp_dir, "visualisations") | ||||||
|  |  | ||||||
|  |     def compute_metrics(self, metrics, sample, pred_trajectory, dataset_name): | ||||||
|  |         if isinstance(pred_trajectory, tuple): | ||||||
|  |             pred_trajectory, pred_visibility = pred_trajectory | ||||||
|  |         else: | ||||||
|  |             pred_visibility = None | ||||||
|  |         if dataset_name == "badja": | ||||||
|  |             sample.segmentation = (sample.segmentation > 0).float() | ||||||
|  |             *_, N, _ = sample.trajectory.shape | ||||||
|  |             accs = [] | ||||||
|  |             accs_3px = [] | ||||||
|  |             for s1 in range(1, sample.video.shape[1]):  # target frame | ||||||
|  |                 for n in range(N): | ||||||
|  |                     vis = sample.visibility[0, s1, n] | ||||||
|  |                     if vis > 0: | ||||||
|  |                         coord_e = pred_trajectory[0, s1, n]  # 2 | ||||||
|  |                         coord_g = sample.trajectory[0, s1, n]  # 2 | ||||||
|  |                         dist = torch.sqrt(torch.sum((coord_e - coord_g) ** 2, dim=0)) | ||||||
|  |                         area = torch.sum(sample.segmentation[0, s1]) | ||||||
|  |                         # BADJA counts a point as correct if it lands within | ||||||
|  |                         # 0.2 * sqrt(segmentation area) of the ground truth. | ||||||
|  |                         thr = 0.2 * torch.sqrt(area) | ||||||
|  |                         accs.append((dist < thr).float()) | ||||||
|  |                         accs_3px.append((dist < 3.0).float()) | ||||||
|  |  | ||||||
|  |             res = torch.mean(torch.stack(accs)) * 100.0 | ||||||
|  |             res_3px = torch.mean(torch.stack(accs_3px)) * 100.0 | ||||||
|  |             metrics[sample.seq_name[0]] = res.item() | ||||||
|  |             metrics[sample.seq_name[0] + "_accuracy"] = res_3px.item() | ||||||
|  |             print(metrics) | ||||||
|  |             print( | ||||||
|  |                 "avg", np.mean([v for k, v in metrics.items() if "accuracy" not in k]) | ||||||
|  |             ) | ||||||
|  |             print( | ||||||
|  |                 "avg acc 3px", | ||||||
|  |                 np.mean([v for k, v in metrics.items() if "accuracy" in k]), | ||||||
|  |             ) | ||||||
|  |         elif dataset_name == "fastcapture" or ("kubric" in dataset_name): | ||||||
|  |             *_, N, _ = sample.trajectory.shape | ||||||
|  |             accs = [] | ||||||
|  |             for s1 in range(1, sample.video.shape[1]):  # target frame | ||||||
|  |                 for n in range(N): | ||||||
|  |                     vis = sample.visibility[0, s1, n] | ||||||
|  |                     if vis > 0: | ||||||
|  |                         coord_e = pred_trajectory[0, s1, n]  # 2 | ||||||
|  |                         coord_g = sample.trajectory[0, s1, n]  # 2 | ||||||
|  |                         dist = torch.sqrt(torch.sum((coord_e - coord_g) ** 2, dim=0)) | ||||||
|  |                         thr = 3 | ||||||
|  |                         correct = (dist < thr).float() | ||||||
|  |                         accs.append(correct) | ||||||
|  |  | ||||||
|  |             res = torch.mean(torch.stack(accs)) * 100.0 | ||||||
|  |             metrics[sample.seq_name[0] + "_accuracy"] = res.item() | ||||||
|  |             print(metrics) | ||||||
|  |             print("avg", np.mean([v for v in metrics.values()])) | ||||||
|  |         elif "tapvid" in dataset_name: | ||||||
|  |             B, T, N, D = sample.trajectory.shape | ||||||
|  |             traj = sample.trajectory.clone() | ||||||
|  |             thr = 0.9 | ||||||
|  |  | ||||||
|  |             if pred_visibility is None: | ||||||
|  |                 logging.warning("visibility is NONE") | ||||||
|  |                 pred_visibility = torch.zeros_like(sample.visibility) | ||||||
|  |  | ||||||
|  |             if pred_visibility.dtype != torch.bool: | ||||||
|  |                 pred_visibility = pred_visibility > thr | ||||||
|  |  | ||||||
|  |             # pred_trajectory | ||||||
|  |             query_points = sample.query_points.clone().cpu().numpy() | ||||||
|  |  | ||||||
|  |             pred_visibility = pred_visibility[:, :, :N] | ||||||
|  |             pred_trajectory = pred_trajectory[:, :, :N] | ||||||
|  |  | ||||||
|  |             gt_tracks = traj.permute(0, 2, 1, 3).cpu().numpy() | ||||||
|  |             gt_occluded = ( | ||||||
|  |                 torch.logical_not(sample.visibility.clone().permute(0, 2, 1)) | ||||||
|  |                 .cpu() | ||||||
|  |                 .numpy() | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             pred_occluded = ( | ||||||
|  |                 torch.logical_not(pred_visibility.clone().permute(0, 2, 1)) | ||||||
|  |                 .cpu() | ||||||
|  |                 .numpy() | ||||||
|  |             ) | ||||||
|  |             pred_tracks = pred_trajectory.permute(0, 2, 1, 3).cpu().numpy() | ||||||
|  |  | ||||||
|  |             out_metrics = compute_tapvid_metrics( | ||||||
|  |                 query_points, | ||||||
|  |                 gt_occluded, | ||||||
|  |                 gt_tracks, | ||||||
|  |                 pred_occluded, | ||||||
|  |                 pred_tracks, | ||||||
|  |                 query_mode="strided" if "strided" in dataset_name else "first", | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             metrics[sample.seq_name[0]] = out_metrics | ||||||
|  |             for metric_name in out_metrics.keys(): | ||||||
|  |                 if "avg" not in metrics: | ||||||
|  |                     metrics["avg"] = {} | ||||||
|  |                 metrics["avg"][metric_name] = np.mean( | ||||||
|  |                     [v[metric_name] for k, v in metrics.items() if k != "avg"] | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |             logging.info(f"Metrics: {out_metrics}") | ||||||
|  |             logging.info(f"avg: {metrics['avg']}") | ||||||
|  |             print("metrics", out_metrics) | ||||||
|  |             print("avg", metrics["avg"]) | ||||||
|  |         else: | ||||||
|  |             rgbs = sample.video | ||||||
|  |             trajs_g = sample.trajectory | ||||||
|  |             valids = sample.valid | ||||||
|  |             vis_g = sample.visibility | ||||||
|  |  | ||||||
|  |             B, S, C, H, W = rgbs.shape | ||||||
|  |             assert C == 3 | ||||||
|  |             B, S, N, D = trajs_g.shape | ||||||
|  |  | ||||||
|  |             assert torch.sum(valids) == B * S * N | ||||||
|  |  | ||||||
|  |             vis_g = (torch.sum(vis_g, dim=1, keepdim=True) >= 4).float().repeat(1, S, 1) | ||||||
|  |  | ||||||
|  |             ate = torch.norm(pred_trajectory - trajs_g, dim=-1)  # B, S, N | ||||||
|  |  | ||||||
|  |             metrics["things_all"] = reduce_masked_mean(ate, valids).item() | ||||||
|  |             metrics["things_vis"] = reduce_masked_mean(ate, valids * vis_g).item() | ||||||
|  |             metrics["things_occ"] = reduce_masked_mean( | ||||||
|  |                 ate, valids * (1.0 - vis_g) | ||||||
|  |             ).item() | ||||||
|  |  | ||||||
|  |     @torch.no_grad() | ||||||
|  |     def evaluate_sequence( | ||||||
|  |         self, | ||||||
|  |         model, | ||||||
|  |         test_dataloader: torch.utils.data.DataLoader, | ||||||
|  |         dataset_name: str, | ||||||
|  |         train_mode=False, | ||||||
|  |         writer: Optional[SummaryWriter] = None, | ||||||
|  |         step: Optional[int] = 0, | ||||||
|  |     ): | ||||||
|  |         metrics = {} | ||||||
|  |  | ||||||
|  |         vis = Visualizer( | ||||||
|  |             save_dir=self.exp_dir, | ||||||
|  |             fps=7, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         for ind, sample in enumerate(tqdm(test_dataloader)): | ||||||
|  |             if isinstance(sample, tuple): | ||||||
|  |                 sample, gotit = sample | ||||||
|  |                 if not all(gotit): | ||||||
|  |                     print("batch is None") | ||||||
|  |                     continue | ||||||
|  |             dataclass_to_cuda_(sample) | ||||||
|  |  | ||||||
|  |             if ( | ||||||
|  |                 not train_mode | ||||||
|  |                 and hasattr(model, "sequence_len") | ||||||
|  |                 and (sample.visibility[:, : model.sequence_len].sum() == 0) | ||||||
|  |             ): | ||||||
|  |                 print(f"skipping batch {ind}") | ||||||
|  |                 continue | ||||||
|  |  | ||||||
|  |             if "tapvid" in dataset_name: | ||||||
|  |                 queries = sample.query_points.clone().float() | ||||||
|  |  | ||||||
|  |                 queries = torch.stack( | ||||||
|  |                     [ | ||||||
|  |                         queries[:, :, 0], | ||||||
|  |                         queries[:, :, 2], | ||||||
|  |                         queries[:, :, 1], | ||||||
|  |                     ], | ||||||
|  |                     dim=2, | ||||||
|  |                 ) | ||||||
|  |             else: | ||||||
|  |                 queries = torch.cat( | ||||||
|  |                     [ | ||||||
|  |                         torch.zeros_like(sample.trajectory[:, 0, :, :1]), | ||||||
|  |                         sample.trajectory[:, 0], | ||||||
|  |                     ], | ||||||
|  |                     dim=2, | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |             pred_tracks = model(sample.video, queries) | ||||||
|  |             if "strided" in dataset_name: | ||||||
|  |  | ||||||
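|  |                 # Queries in "strided" mode can start mid-video; entries the | ||||||
|  |                 # forward pass left at zero are filled in from a second pass | ||||||
|  |                 # over the time-reversed video below. | ||||||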
|  |                 inv_video = sample.video.flip(1).clone() | ||||||
|  |                 inv_queries = queries.clone() | ||||||
|  |                 inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1 | ||||||
|  |  | ||||||
|  |                 pred_trj, pred_vsb = pred_tracks | ||||||
|  |                 inv_pred_trj, inv_pred_vsb = model(inv_video, inv_queries) | ||||||
|  |  | ||||||
|  |                 inv_pred_trj = inv_pred_trj.flip(1) | ||||||
|  |                 inv_pred_vsb = inv_pred_vsb.flip(1) | ||||||
|  |  | ||||||
|  |                 mask = pred_trj == 0 | ||||||
|  |  | ||||||
|  |                 pred_trj[mask] = inv_pred_trj[mask] | ||||||
|  |                 pred_vsb[mask[:, :, :, 0]] = inv_pred_vsb[mask[:, :, :, 0]] | ||||||
|  |  | ||||||
|  |                 pred_tracks = pred_trj, pred_vsb | ||||||
|  |  | ||||||
|  |             if dataset_name == "badja" or dataset_name == "fastcapture": | ||||||
|  |                 seq_name = sample.seq_name[0] | ||||||
|  |             else: | ||||||
|  |                 seq_name = str(ind) | ||||||
|  |  | ||||||
|  |             vis.visualize( | ||||||
|  |                 sample.video, | ||||||
|  |                 pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks, | ||||||
|  |                 filename=dataset_name + "_" + seq_name, | ||||||
|  |                 writer=writer, | ||||||
|  |                 step=step, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             self.compute_metrics(metrics, sample, pred_tracks, dataset_name) | ||||||
|  |         return metrics | ||||||
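A rough sketch of how Evaluator is typically driven; the checkpoint path and dataloader are placeholders, and the model is assumed to return a (tracks, visibility) tuple as the tapvid branch above expects:

    from cotracker.evaluation.core.evaluator import Evaluator
    from cotracker.models.build_cotracker import build_cotracker
    from cotracker.models.evaluation_predictor import EvaluationPredictor

    model = build_cotracker("./checkpoints/cotracker_stride_4_wind_8.pth")
    predictor = EvaluationPredictor(
        model, grid_size=6, local_grid_size=6, single_point=True, n_iters=6
    )
    evaluator = Evaluator(exp_dir="./outputs")
    # test_dataloader: a TAP-Vid DataLoader built as in evaluate.py (placeholder here)
    metrics = evaluator.evaluate_sequence(
        predictor, test_dataloader, dataset_name="tapvid_davis_first"
    )
    print(metrics["avg"])   # per-metric averages over the evaluated videos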
							
								
								
									
179  cotracker/evaluation/evaluate.py  Normal file
							| @@ -0,0 +1,179 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
|  |  | ||||||
|  | import json | ||||||
|  | import os | ||||||
|  | from dataclasses import dataclass, field | ||||||
|  |  | ||||||
|  | import hydra | ||||||
|  | import numpy as np | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | from omegaconf import OmegaConf | ||||||
|  |  | ||||||
|  | from cotracker.datasets.badja_dataset import BadjaDataset | ||||||
|  | from cotracker.datasets.fast_capture_dataset import FastCaptureDataset | ||||||
|  | from cotracker.datasets.tap_vid_datasets import TapVidDataset | ||||||
|  | from cotracker.datasets.utils import collate_fn | ||||||
|  |  | ||||||
|  | from cotracker.models.evaluation_predictor import EvaluationPredictor | ||||||
|  |  | ||||||
|  | from cotracker.evaluation.core.evaluator import Evaluator | ||||||
|  | from cotracker.models.build_cotracker import ( | ||||||
|  |     build_cotracker, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @dataclass(eq=False) | ||||||
|  | class DefaultConfig: | ||||||
|  |     # Directory where all outputs of the experiment will be saved. | ||||||
|  |     exp_dir: str = "./outputs" | ||||||
|  |  | ||||||
|  |     # Name of the dataset to be used for the evaluation. | ||||||
|  |     dataset_name: str = "badja" | ||||||
|  |     # The root directory of the dataset. | ||||||
|  |     dataset_root: str = "./" | ||||||
|  |  | ||||||
|  |     # Path to the pre-trained model checkpoint to be used for the evaluation. | ||||||
|  |     # The default value is the path to a specific CoTracker model checkpoint. | ||||||
|  |     # Other available options are commented. | ||||||
|  |     checkpoint: str = "./checkpoints/cotracker_stride_4_wind_8.pth" | ||||||
|  |     # cotracker_stride_4_wind_12 | ||||||
|  |     # cotracker_stride_8_wind_16 | ||||||
|  |  | ||||||
|  |     # EvaluationPredictor parameters | ||||||
|  |     # The size (N) of the support grid used in the predictor. | ||||||
|  |     # The total number of points is (N*N). | ||||||
|  |     grid_size: int = 6 | ||||||
|  |     # The size (N) of the local support grid. | ||||||
|  |     local_grid_size: int = 6 | ||||||
|  |     # A flag indicating whether to evaluate one ground truth point at a time. | ||||||
|  |     single_point: bool = True | ||||||
|  |     # The number of iterative updates for each sliding window. | ||||||
|  |     n_iters: int = 6 | ||||||
|  |  | ||||||
|  |     seed: int = 0 | ||||||
|  |     gpu_idx: int = 0 | ||||||
|  |  | ||||||
|  |     # Override hydra's working directory to current working dir, | ||||||
|  |     # also disable storing the .hydra logs: | ||||||
|  |     hydra: dict = field( | ||||||
|  |         default_factory=lambda: { | ||||||
|  |             "run": {"dir": "."}, | ||||||
|  |             "output_subdir": None, | ||||||
|  |         } | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def run_eval(cfg: DefaultConfig): | ||||||
|  |     """ | ||||||
|  |     The function evaluates CoTracker on a specified benchmark dataset based on a provided configuration. | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         cfg (DefaultConfig): An instance of DefaultConfig class which includes: | ||||||
|  |             - exp_dir (str): The directory path for the experiment. | ||||||
|  |             - dataset_name (str): The name of the dataset to be used. | ||||||
|  |             - dataset_root (str): The root directory of the dataset. | ||||||
|  |             - checkpoint (str): The path to the CoTracker model's checkpoint. | ||||||
|  |             - single_point (bool): A flag indicating whether to evaluate one ground truth point at a time. | ||||||
|  |             - n_iters (int): The number of iterative updates for each sliding window. | ||||||
|  |             - seed (int): The seed for setting the random state for reproducibility. | ||||||
|  |             - gpu_idx (int): The index of the GPU to be used. | ||||||
|  |     """ | ||||||
|  |     # Creating the experiment directory if it doesn't exist | ||||||
|  |     os.makedirs(cfg.exp_dir, exist_ok=True) | ||||||
|  |  | ||||||
|  |     # Saving the experiment configuration to a .yaml file in the experiment directory | ||||||
|  |     cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml") | ||||||
|  |     with open(cfg_file, "w") as f: | ||||||
|  |         OmegaConf.save(config=cfg, f=f) | ||||||
|  |  | ||||||
|  |     evaluator = Evaluator(cfg.exp_dir) | ||||||
|  |     cotracker_model = build_cotracker(cfg.checkpoint) | ||||||
|  |  | ||||||
|  |     # Creating the EvaluationPredictor object | ||||||
|  |     predictor = EvaluationPredictor( | ||||||
|  |         cotracker_model, | ||||||
|  |         grid_size=cfg.grid_size, | ||||||
|  |         local_grid_size=cfg.local_grid_size, | ||||||
|  |         single_point=cfg.single_point, | ||||||
|  |         n_iters=cfg.n_iters, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     # Setting the random seeds | ||||||
|  |     torch.manual_seed(cfg.seed) | ||||||
|  |     np.random.seed(cfg.seed) | ||||||
|  |  | ||||||
|  |     # Constructing the specified dataset | ||||||
|  |     curr_collate_fn = collate_fn | ||||||
|  |     if cfg.dataset_name == "badja": | ||||||
|  |         test_dataset = BadjaDataset(data_root=os.path.join(cfg.dataset_root, "BADJA")) | ||||||
|  |     elif cfg.dataset_name == "fastcapture": | ||||||
|  |         test_dataset = FastCaptureDataset( | ||||||
|  |             data_root=os.path.join(cfg.dataset_root, "fastcapture"), | ||||||
|  |             max_seq_len=100, | ||||||
|  |             max_num_points=20, | ||||||
|  |         ) | ||||||
|  |     elif "tapvid" in cfg.dataset_name: | ||||||
|  |         dataset_type = cfg.dataset_name.split("_")[1] | ||||||
|  |         if dataset_type == "davis": | ||||||
|  |             data_root = os.path.join(cfg.dataset_root, "tapvid_davis/tapvid_davis.pkl") | ||||||
|  |         elif dataset_type == "kinetics": | ||||||
|  |             data_root = os.path.join( | ||||||
|  |                 cfg.dataset_root, "kinetics/kinetics-dataset/k700-2020/tapvid_kinetics" | ||||||
|  |             ) | ||||||
|  |         test_dataset = TapVidDataset( | ||||||
|  |             dataset_type=dataset_type, | ||||||
|  |             data_root=data_root, | ||||||
|  |             queried_first="strided" not in cfg.dataset_name, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     # Creating the DataLoader object | ||||||
|  |     test_dataloader = torch.utils.data.DataLoader( | ||||||
|  |         test_dataset, | ||||||
|  |         batch_size=1, | ||||||
|  |         shuffle=False, | ||||||
|  |         num_workers=14, | ||||||
|  |         collate_fn=curr_collate_fn, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     # Timing and conducting the evaluation | ||||||
|  |     import time | ||||||
|  |  | ||||||
|  |     start = time.time() | ||||||
|  |     evaluate_result = evaluator.evaluate_sequence( | ||||||
|  |         predictor, | ||||||
|  |         test_dataloader, | ||||||
|  |         dataset_name=cfg.dataset_name, | ||||||
|  |     ) | ||||||
|  |     end = time.time() | ||||||
|  |     print(end - start) | ||||||
|  |  | ||||||
|  |     # Saving the evaluation results to a .json file | ||||||
|  |     if "tapvid" not in cfg.dataset_name: | ||||||
|  |         print("evaluate_result", evaluate_result) | ||||||
|  |     else: | ||||||
|  |         evaluate_result = evaluate_result["avg"] | ||||||
|  |     result_file = os.path.join(cfg.exp_dir, "result_eval_.json") | ||||||
|  |     evaluate_result["time"] = end - start | ||||||
|  |     print(f"Dumping eval results to {result_file}.") | ||||||
|  |     with open(result_file, "w") as f: | ||||||
|  |         json.dump(evaluate_result, f) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | cs = hydra.core.config_store.ConfigStore.instance() | ||||||
|  | cs.store(name="default_config_eval", node=DefaultConfig) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @hydra.main(config_path="./configs/", config_name="default_config_eval") | ||||||
|  | def evaluate(cfg: DefaultConfig) -> None: | ||||||
|  |     os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" | ||||||
|  |     os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx) | ||||||
|  |     run_eval(cfg) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     evaluate() | ||||||
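For reference, the same evaluation can be launched without the Hydra entry point by constructing the config directly; the dataset root below is a placeholder path:

    from cotracker.evaluation.evaluate import DefaultConfig, run_eval

    cfg = DefaultConfig(
        exp_dir="./outputs",
        dataset_name="tapvid_davis_first",
        dataset_root="/path/to/tapvid",
        checkpoint="./checkpoints/cotracker_stride_4_wind_8.pth",
    )
    run_eval(cfg)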
							
								
								
									
5  cotracker/models/__init__.py  Normal file
							| @@ -0,0 +1,5 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
							
								
								
									
70  cotracker/models/build_cotracker.py  Normal file
							| @@ -0,0 +1,70 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  |  | ||||||
|  | from cotracker.models.core.cotracker.cotracker import CoTracker | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def build_cotracker( | ||||||
|  |     checkpoint: str, | ||||||
|  | ): | ||||||
|  |     model_name = checkpoint.split("/")[-1].split(".")[0] | ||||||
|  |     if model_name == "cotracker_stride_4_wind_8": | ||||||
|  |         return build_cotracker_stride_4_wind_8(checkpoint=checkpoint) | ||||||
|  |     elif model_name == "cotracker_stride_4_wind_12": | ||||||
|  |         return build_cotracker_stride_4_wind_12(checkpoint=checkpoint) | ||||||
|  |     elif model_name == "cotracker_stride_8_wind_16": | ||||||
|  |         return build_cotracker_stride_8_wind_16(checkpoint=checkpoint) | ||||||
|  |     else: | ||||||
|  |         raise ValueError(f"Unknown model name {model_name}") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # model used to produce the results in the paper | ||||||
|  | def build_cotracker_stride_4_wind_8(checkpoint=None): | ||||||
|  |     return _build_cotracker( | ||||||
|  |         stride=4, | ||||||
|  |         sequence_len=8, | ||||||
|  |         checkpoint=checkpoint, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def build_cotracker_stride_4_wind_12(checkpoint=None): | ||||||
|  |     return _build_cotracker( | ||||||
|  |         stride=4, | ||||||
|  |         sequence_len=12, | ||||||
|  |         checkpoint=checkpoint, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # the fastest model | ||||||
|  | def build_cotracker_stride_8_wind_16(checkpoint=None): | ||||||
|  |     return _build_cotracker( | ||||||
|  |         stride=8, | ||||||
|  |         sequence_len=16, | ||||||
|  |         checkpoint=checkpoint, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _build_cotracker( | ||||||
|  |     stride, | ||||||
|  |     sequence_len, | ||||||
|  |     checkpoint=None, | ||||||
|  | ): | ||||||
|  |     cotracker = CoTracker( | ||||||
|  |         stride=stride, | ||||||
|  |         S=sequence_len, | ||||||
|  |         add_space_attn=True, | ||||||
|  |         space_depth=6, | ||||||
|  |         time_depth=6, | ||||||
|  |     ) | ||||||
|  |     if checkpoint is not None: | ||||||
|  |         with open(checkpoint, "rb") as f: | ||||||
|  |             state_dict = torch.load(f, map_location="cpu") | ||||||
|  |             if "model" in state_dict: | ||||||
|  |                 state_dict = state_dict["model"] | ||||||
|  |         cotracker.load_state_dict(state_dict) | ||||||
|  |     return cotracker | ||||||
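A small usage sketch; the checkpoint path matches the default used in the evaluation config above and is assumed to exist locally:

    from cotracker.models.build_cotracker import build_cotracker

    model = build_cotracker("./checkpoints/cotracker_stride_4_wind_8.pth")
    model.eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"cotracker_stride_4_wind_8: {n_params} parameters")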
							
								
								
									
5  cotracker/models/core/__init__.py  Normal file
							| @@ -0,0 +1,5 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
							
								
								
									
5  cotracker/models/core/cotracker/__init__.py  Normal file
							| @@ -0,0 +1,5 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
							
								
								
									
400  cotracker/models/core/cotracker/blocks.py  Normal file
							| @@ -0,0 +1,400 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  | import torch.nn.functional as F | ||||||
|  |  | ||||||
|  | from einops import rearrange | ||||||
|  | from timm.models.vision_transformer import Attention, Mlp | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ResidualBlock(nn.Module): | ||||||
|  |     def __init__(self, in_planes, planes, norm_fn="group", stride=1): | ||||||
|  |         super(ResidualBlock, self).__init__() | ||||||
|  |  | ||||||
|  |         self.conv1 = nn.Conv2d( | ||||||
|  |             in_planes, | ||||||
|  |             planes, | ||||||
|  |             kernel_size=3, | ||||||
|  |             padding=1, | ||||||
|  |             stride=stride, | ||||||
|  |             padding_mode="zeros", | ||||||
|  |         ) | ||||||
|  |         self.conv2 = nn.Conv2d( | ||||||
|  |             planes, planes, kernel_size=3, padding=1, padding_mode="zeros" | ||||||
|  |         ) | ||||||
|  |         self.relu = nn.ReLU(inplace=True) | ||||||
|  |  | ||||||
|  |         num_groups = planes // 8 | ||||||
|  |  | ||||||
|  |         if norm_fn == "group": | ||||||
|  |             self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) | ||||||
|  |             self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) | ||||||
|  |             if not stride == 1: | ||||||
|  |                 self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) | ||||||
|  |  | ||||||
|  |         elif norm_fn == "batch": | ||||||
|  |             self.norm1 = nn.BatchNorm2d(planes) | ||||||
|  |             self.norm2 = nn.BatchNorm2d(planes) | ||||||
|  |             if not stride == 1: | ||||||
|  |                 self.norm3 = nn.BatchNorm2d(planes) | ||||||
|  |  | ||||||
|  |         elif norm_fn == "instance": | ||||||
|  |             self.norm1 = nn.InstanceNorm2d(planes) | ||||||
|  |             self.norm2 = nn.InstanceNorm2d(planes) | ||||||
|  |             if not stride == 1: | ||||||
|  |                 self.norm3 = nn.InstanceNorm2d(planes) | ||||||
|  |  | ||||||
|  |         elif norm_fn == "none": | ||||||
|  |             self.norm1 = nn.Sequential() | ||||||
|  |             self.norm2 = nn.Sequential() | ||||||
|  |             if not stride == 1: | ||||||
|  |                 self.norm3 = nn.Sequential() | ||||||
|  |  | ||||||
|  |         if stride == 1: | ||||||
|  |             self.downsample = None | ||||||
|  |  | ||||||
|  |         else: | ||||||
|  |             self.downsample = nn.Sequential( | ||||||
|  |                 nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3 | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def forward(self, x): | ||||||
|  |         y = x | ||||||
|  |         y = self.relu(self.norm1(self.conv1(y))) | ||||||
|  |         y = self.relu(self.norm2(self.conv2(y))) | ||||||
|  |  | ||||||
|  |         if self.downsample is not None: | ||||||
|  |             x = self.downsample(x) | ||||||
|  |  | ||||||
|  |         return self.relu(x + y) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class BasicEncoder(nn.Module): | ||||||
|  |     def __init__( | ||||||
|  |         self, input_dim=3, output_dim=128, stride=8, norm_fn="batch", dropout=0.0 | ||||||
|  |     ): | ||||||
|  |         super(BasicEncoder, self).__init__() | ||||||
|  |         self.stride = stride | ||||||
|  |         self.norm_fn = norm_fn | ||||||
|  |         self.in_planes = 64 | ||||||
|  |  | ||||||
|  |         if self.norm_fn == "group": | ||||||
|  |             self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes) | ||||||
|  |             self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2) | ||||||
|  |  | ||||||
|  |         elif self.norm_fn == "batch": | ||||||
|  |             self.norm1 = nn.BatchNorm2d(self.in_planes) | ||||||
|  |             self.norm2 = nn.BatchNorm2d(output_dim * 2) | ||||||
|  |  | ||||||
|  |         elif self.norm_fn == "instance": | ||||||
|  |             self.norm1 = nn.InstanceNorm2d(self.in_planes) | ||||||
|  |             self.norm2 = nn.InstanceNorm2d(output_dim * 2) | ||||||
|  |  | ||||||
|  |         elif self.norm_fn == "none": | ||||||
|  |             self.norm1 = nn.Sequential() | ||||||
|  |             self.norm2 = nn.Sequential() | ||||||
|  |  | ||||||
|  |         self.conv1 = nn.Conv2d( | ||||||
|  |             input_dim, | ||||||
|  |             self.in_planes, | ||||||
|  |             kernel_size=7, | ||||||
|  |             stride=2, | ||||||
|  |             padding=3, | ||||||
|  |             padding_mode="zeros", | ||||||
|  |         ) | ||||||
|  |         self.relu1 = nn.ReLU(inplace=True) | ||||||
|  |  | ||||||
|  |         self.shallow = False | ||||||
|  |         if self.shallow: | ||||||
|  |             self.layer1 = self._make_layer(64, stride=1) | ||||||
|  |             self.layer2 = self._make_layer(96, stride=2) | ||||||
|  |             self.layer3 = self._make_layer(128, stride=2) | ||||||
|  |             self.conv2 = nn.Conv2d(128 + 96 + 64, output_dim, kernel_size=1) | ||||||
|  |         else: | ||||||
|  |             self.layer1 = self._make_layer(64, stride=1) | ||||||
|  |             self.layer2 = self._make_layer(96, stride=2) | ||||||
|  |             self.layer3 = self._make_layer(128, stride=2) | ||||||
|  |             self.layer4 = self._make_layer(128, stride=2) | ||||||
|  |  | ||||||
|  |             self.conv2 = nn.Conv2d( | ||||||
|  |                 128 + 128 + 96 + 64, | ||||||
|  |                 output_dim * 2, | ||||||
|  |                 kernel_size=3, | ||||||
|  |                 padding=1, | ||||||
|  |                 padding_mode="zeros", | ||||||
|  |             ) | ||||||
|  |             self.relu2 = nn.ReLU(inplace=True) | ||||||
|  |             self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1) | ||||||
|  |  | ||||||
|  |         self.dropout = None | ||||||
|  |         if dropout > 0: | ||||||
|  |             self.dropout = nn.Dropout2d(p=dropout) | ||||||
|  |  | ||||||
|  |         for m in self.modules(): | ||||||
|  |             if isinstance(m, nn.Conv2d): | ||||||
|  |                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") | ||||||
|  |             elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): | ||||||
|  |                 if m.weight is not None: | ||||||
|  |                     nn.init.constant_(m.weight, 1) | ||||||
|  |                 if m.bias is not None: | ||||||
|  |                     nn.init.constant_(m.bias, 0) | ||||||
|  |  | ||||||
|  |     def _make_layer(self, dim, stride=1): | ||||||
|  |         layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) | ||||||
|  |         layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) | ||||||
|  |         layers = (layer1, layer2) | ||||||
|  |  | ||||||
|  |         self.in_planes = dim | ||||||
|  |         return nn.Sequential(*layers) | ||||||
|  |  | ||||||
|  |     def forward(self, x): | ||||||
|  |         _, _, H, W = x.shape | ||||||
|  |  | ||||||
|  |         x = self.conv1(x) | ||||||
|  |         x = self.norm1(x) | ||||||
|  |         x = self.relu1(x) | ||||||
|  |  | ||||||
|  |         if self.shallow: | ||||||
|  |             a = self.layer1(x) | ||||||
|  |             b = self.layer2(a) | ||||||
|  |             c = self.layer3(b) | ||||||
|  |             a = F.interpolate( | ||||||
|  |                 a, | ||||||
|  |                 (H // self.stride, W // self.stride), | ||||||
|  |                 mode="bilinear", | ||||||
|  |                 align_corners=True, | ||||||
|  |             ) | ||||||
|  |             b = F.interpolate( | ||||||
|  |                 b, | ||||||
|  |                 (H // self.stride, W // self.stride), | ||||||
|  |                 mode="bilinear", | ||||||
|  |                 align_corners=True, | ||||||
|  |             ) | ||||||
|  |             c = F.interpolate( | ||||||
|  |                 c, | ||||||
|  |                 (H // self.stride, W // self.stride), | ||||||
|  |                 mode="bilinear", | ||||||
|  |                 align_corners=True, | ||||||
|  |             ) | ||||||
|  |             x = self.conv2(torch.cat([a, b, c], dim=1)) | ||||||
|  |         else: | ||||||
|  |             a = self.layer1(x) | ||||||
|  |             b = self.layer2(a) | ||||||
|  |             c = self.layer3(b) | ||||||
|  |             d = self.layer4(c) | ||||||
|  |             a = F.interpolate( | ||||||
|  |                 a, | ||||||
|  |                 (H // self.stride, W // self.stride), | ||||||
|  |                 mode="bilinear", | ||||||
|  |                 align_corners=True, | ||||||
|  |             ) | ||||||
|  |             b = F.interpolate( | ||||||
|  |                 b, | ||||||
|  |                 (H // self.stride, W // self.stride), | ||||||
|  |                 mode="bilinear", | ||||||
|  |                 align_corners=True, | ||||||
|  |             ) | ||||||
|  |             c = F.interpolate( | ||||||
|  |                 c, | ||||||
|  |                 (H // self.stride, W // self.stride), | ||||||
|  |                 mode="bilinear", | ||||||
|  |                 align_corners=True, | ||||||
|  |             ) | ||||||
|  |             d = F.interpolate( | ||||||
|  |                 d, | ||||||
|  |                 (H // self.stride, W // self.stride), | ||||||
|  |                 mode="bilinear", | ||||||
|  |                 align_corners=True, | ||||||
|  |             ) | ||||||
|  |             x = self.conv2(torch.cat([a, b, c, d], dim=1)) | ||||||
|  |             x = self.norm2(x) | ||||||
|  |             x = self.relu2(x) | ||||||
|  |             x = self.conv3(x) | ||||||
|  |  | ||||||
|  |         if self.training and self.dropout is not None: | ||||||
|  |             x = self.dropout(x) | ||||||
|  |         return x | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class AttnBlock(nn.Module): | ||||||
|  |     """ | ||||||
|  |     A pre-norm Transformer block (LayerNorm + Attention, LayerNorm + MLP), | ||||||
|  |     adapted from the DiT block but without the adaLN-Zero conditioning. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): | ||||||
|  |         super().__init__() | ||||||
|  |         self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) | ||||||
|  |         self.attn = Attention( | ||||||
|  |             hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) | ||||||
|  |         mlp_hidden_dim = int(hidden_size * mlp_ratio) | ||||||
|  |         approx_gelu = lambda: nn.GELU(approximate="tanh") | ||||||
|  |         self.mlp = Mlp( | ||||||
|  |             in_features=hidden_size, | ||||||
|  |             hidden_features=mlp_hidden_dim, | ||||||
|  |             act_layer=approx_gelu, | ||||||
|  |             drop=0, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def forward(self, x): | ||||||
|  |         x = x + self.attn(self.norm1(x)) | ||||||
|  |         x = x + self.mlp(self.norm2(x)) | ||||||
|  |         return x | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def bilinear_sampler(img, coords, mode="bilinear", mask=False): | ||||||
|  |     """Wrapper for grid_sample, uses pixel coordinates""" | ||||||
|  |     H, W = img.shape[-2:] | ||||||
|  |     xgrid, ygrid = coords.split([1, 1], dim=-1) | ||||||
|  |     # go to 0,1 then 0,2 then -1,1 | ||||||
|  |     xgrid = 2 * xgrid / (W - 1) - 1 | ||||||
|  |     ygrid = 2 * ygrid / (H - 1) - 1 | ||||||
|  |  | ||||||
|  |     grid = torch.cat([xgrid, ygrid], dim=-1) | ||||||
|  |     img = F.grid_sample(img, grid, align_corners=True) | ||||||
|  |  | ||||||
|  |     if mask: | ||||||
|  |         mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) | ||||||
|  |         return img, mask.float() | ||||||
|  |  | ||||||
|  |     return img | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class CorrBlock: | ||||||
|  |     def __init__(self, fmaps, num_levels=4, radius=4): | ||||||
|  |         B, S, C, H, W = fmaps.shape | ||||||
|  |         self.S, self.C, self.H, self.W = S, C, H, W | ||||||
|  |  | ||||||
|  |         self.num_levels = num_levels | ||||||
|  |         self.radius = radius | ||||||
|  |         self.fmaps_pyramid = [] | ||||||
|  |  | ||||||
|  |         self.fmaps_pyramid.append(fmaps) | ||||||
|  |         for i in range(self.num_levels - 1): | ||||||
|  |             fmaps_ = fmaps.reshape(B * S, C, H, W) | ||||||
|  |             fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2) | ||||||
|  |             _, _, H, W = fmaps_.shape | ||||||
|  |             fmaps = fmaps_.reshape(B, S, C, H, W) | ||||||
|  |             self.fmaps_pyramid.append(fmaps) | ||||||
|  |  | ||||||
|  |     def sample(self, coords): | ||||||
|  |         r = self.radius | ||||||
|  |         B, S, N, D = coords.shape | ||||||
|  |         assert D == 2 | ||||||
|  |  | ||||||
|  |         H, W = self.H, self.W | ||||||
|  |         out_pyramid = [] | ||||||
|  |         for i in range(self.num_levels): | ||||||
|  |             corrs = self.corrs_pyramid[i]  # B, S, N, H, W | ||||||
|  |             _, _, _, H, W = corrs.shape | ||||||
|  |  | ||||||
|  |             dx = torch.linspace(-r, r, 2 * r + 1) | ||||||
|  |             dy = torch.linspace(-r, r, 2 * r + 1) | ||||||
|  |             delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to( | ||||||
|  |                 coords.device | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i | ||||||
|  |             delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) | ||||||
|  |             coords_lvl = centroid_lvl + delta_lvl | ||||||
|  |  | ||||||
|  |             corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl) | ||||||
|  |             corrs = corrs.view(B, S, N, -1) | ||||||
|  |             out_pyramid.append(corrs) | ||||||
|  |  | ||||||
|  |         out = torch.cat(out_pyramid, dim=-1)  # B, S, N, num_levels * (2*r+1)**2 | ||||||
|  |         return out.contiguous().float() | ||||||
|  |  | ||||||
|  |     def corr(self, targets): | ||||||
|  |         B, S, N, C = targets.shape | ||||||
|  |         assert C == self.C | ||||||
|  |         assert S == self.S | ||||||
|  |  | ||||||
|  |         fmap1 = targets | ||||||
|  |  | ||||||
|  |         self.corrs_pyramid = [] | ||||||
|  |         for fmaps in self.fmaps_pyramid: | ||||||
|  |             _, _, _, H, W = fmaps.shape | ||||||
|  |             fmap2s = fmaps.view(B, S, C, H * W) | ||||||
|  |             corrs = torch.matmul(fmap1, fmap2s) | ||||||
|  |             corrs = corrs.view(B, S, N, H, W) | ||||||
|  |             corrs = corrs / torch.sqrt(torch.tensor(C).float()) | ||||||
|  |             self.corrs_pyramid.append(corrs) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class UpdateFormer(nn.Module): | ||||||
|  |     """ | ||||||
|  |     Transformer model that updates track estimates. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         space_depth=12, | ||||||
|  |         time_depth=12, | ||||||
|  |         input_dim=320, | ||||||
|  |         hidden_size=384, | ||||||
|  |         num_heads=8, | ||||||
|  |         output_dim=130, | ||||||
|  |         mlp_ratio=4.0, | ||||||
|  |         add_space_attn=True, | ||||||
|  |     ): | ||||||
|  |         super().__init__() | ||||||
|  |         self.out_channels = 2 | ||||||
|  |         self.num_heads = num_heads | ||||||
|  |         self.hidden_size = hidden_size | ||||||
|  |         self.add_space_attn = add_space_attn | ||||||
|  |         self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) | ||||||
|  |         self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True) | ||||||
|  |  | ||||||
|  |         self.time_blocks = nn.ModuleList( | ||||||
|  |             [ | ||||||
|  |                 AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) | ||||||
|  |                 for _ in range(time_depth) | ||||||
|  |             ] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         if add_space_attn: | ||||||
|  |             self.space_blocks = nn.ModuleList( | ||||||
|  |                 [ | ||||||
|  |                     AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) | ||||||
|  |                     for _ in range(space_depth) | ||||||
|  |                 ] | ||||||
|  |             ) | ||||||
|  |             assert len(self.time_blocks) >= len(self.space_blocks) | ||||||
|  |         self.initialize_weights() | ||||||
|  |  | ||||||
|  |     def initialize_weights(self): | ||||||
|  |         def _basic_init(module): | ||||||
|  |             if isinstance(module, nn.Linear): | ||||||
|  |                 torch.nn.init.xavier_uniform_(module.weight) | ||||||
|  |                 if module.bias is not None: | ||||||
|  |                     nn.init.constant_(module.bias, 0) | ||||||
|  |  | ||||||
|  |         self.apply(_basic_init) | ||||||
|  |  | ||||||
|  |     def forward(self, input_tensor): | ||||||
|  |         x = self.input_transform(input_tensor) | ||||||
|  |  | ||||||
|  |         j = 0 | ||||||
|  |         for i in range(len(self.time_blocks)): | ||||||
|  |             B, N, T, _ = x.shape | ||||||
|  |             x_time = rearrange(x, "b n t c -> (b n) t c", b=B, t=T, n=N) | ||||||
|  |             x_time = self.time_blocks[i](x_time) | ||||||
|  |  | ||||||
|  |             x = rearrange(x_time, "(b n) t c -> b n t c ", b=B, t=T, n=N) | ||||||
|  |             if self.add_space_attn and ( | ||||||
|  |                 i % (len(self.time_blocks) // len(self.space_blocks)) == 0 | ||||||
|  |             ): | ||||||
|  |                 x_space = rearrange(x, "b n t c -> (b t) n c ", b=B, t=T, n=N) | ||||||
|  |                 x_space = self.space_blocks[j](x_space) | ||||||
|  |                 x = rearrange(x_space, "(b t) n c -> b n t c  ", b=B, t=T, n=N) | ||||||
|  |                 j += 1 | ||||||
|  |  | ||||||
|  |         flow = self.flow_head(x) | ||||||
|  |         return flow | ||||||
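A shape-level sketch of the CorrBlock correlation pyramid defined above; the tensor sizes are illustrative rather than the values used in training:

    import torch
    from cotracker.models.core.cotracker.blocks import CorrBlock

    B, S, N, C, H, W = 1, 8, 16, 128, 48, 64
    fmaps = torch.randn(B, S, C, H, W)                  # per-frame feature maps
    corr_fn = CorrBlock(fmaps, num_levels=4, radius=3)
    targets = torch.randn(B, S, N, C)                   # one feature vector per track and frame
    coords = torch.rand(B, S, N, 2) * torch.tensor([W, H]).float()
    corr_fn.corr(targets)                               # build the per-level correlation volumes
    corrs = corr_fn.sample(coords)                      # (B, S, N, num_levels * (2*radius+1)**2)
    print(corrs.shape)                                  # torch.Size([1, 8, 16, 196])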
							
								
								
									
351  cotracker/models/core/cotracker/cotracker.py  Normal file
							| @@ -0,0 +1,351 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  | from einops import rearrange | ||||||
|  |  | ||||||
|  | from cotracker.models.core.cotracker.blocks import ( | ||||||
|  |     BasicEncoder, | ||||||
|  |     CorrBlock, | ||||||
|  |     UpdateFormer, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | from cotracker.models.core.model_utils import meshgrid2d, bilinear_sample2d, smart_cat | ||||||
|  | from cotracker.models.core.embeddings import ( | ||||||
|  |     get_2d_embedding, | ||||||
|  |     get_1d_sincos_pos_embed_from_grid, | ||||||
|  |     get_2d_sincos_pos_embed, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | torch.manual_seed(0) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)): | ||||||
|  |     if grid_size == 1: | ||||||
|  |         return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2])[ | ||||||
|  |             None, None | ||||||
|  |         ].cuda() | ||||||
|  |  | ||||||
|  |     grid_y, grid_x = meshgrid2d( | ||||||
|  |         1, grid_size, grid_size, stack=False, norm=False, device="cuda" | ||||||
|  |     ) | ||||||
|  |     step = interp_shape[1] // 64 | ||||||
|  |     if grid_center[0] != 0 or grid_center[1] != 0: | ||||||
|  |         grid_y = grid_y - grid_size / 2.0 | ||||||
|  |         grid_x = grid_x - grid_size / 2.0 | ||||||
|  |     grid_y = step + grid_y.reshape(1, -1) / float(grid_size - 1) * ( | ||||||
|  |         interp_shape[0] - step * 2 | ||||||
|  |     ) | ||||||
|  |     grid_x = step + grid_x.reshape(1, -1) / float(grid_size - 1) * ( | ||||||
|  |         interp_shape[1] - step * 2 | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     grid_y = grid_y + grid_center[0] | ||||||
|  |     grid_x = grid_x + grid_center[1] | ||||||
|  |     xy = torch.stack([grid_x, grid_y], dim=-1).cuda() | ||||||
|  |     return xy | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def sample_pos_embed(grid_size, embed_dim, coords): | ||||||
|  |     pos_embed = get_2d_sincos_pos_embed(embed_dim=embed_dim, grid_size=grid_size) | ||||||
|  |     pos_embed = ( | ||||||
|  |         torch.from_numpy(pos_embed) | ||||||
|  |         .reshape(grid_size[0], grid_size[1], embed_dim) | ||||||
|  |         .float() | ||||||
|  |         .unsqueeze(0) | ||||||
|  |         .to(coords.device) | ||||||
|  |     ) | ||||||
|  |     sampled_pos_embed = bilinear_sample2d( | ||||||
|  |         pos_embed.permute(0, 3, 1, 2), coords[:, 0, :, 0], coords[:, 0, :, 1] | ||||||
|  |     ) | ||||||
|  |     return sampled_pos_embed | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class CoTracker(nn.Module): | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         S=8, | ||||||
|  |         stride=8, | ||||||
|  |         add_space_attn=True, | ||||||
|  |         num_heads=8, | ||||||
|  |         hidden_size=384, | ||||||
|  |         space_depth=12, | ||||||
|  |         time_depth=12, | ||||||
|  |     ): | ||||||
|  |         super(CoTracker, self).__init__() | ||||||
|  |         self.S = S | ||||||
|  |         self.stride = stride | ||||||
|  |         self.hidden_dim = 256 | ||||||
|  |         self.latent_dim = latent_dim = 128 | ||||||
|  |         self.corr_levels = 4 | ||||||
|  |         self.corr_radius = 3 | ||||||
|  |         self.add_space_attn = add_space_attn | ||||||
|  |         self.fnet = BasicEncoder( | ||||||
|  |             output_dim=self.latent_dim, norm_fn="instance", dropout=0, stride=stride | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         self.updateformer = UpdateFormer( | ||||||
|  |             space_depth=space_depth, | ||||||
|  |             time_depth=time_depth, | ||||||
|  |             input_dim=456, | ||||||
|  |             hidden_size=hidden_size, | ||||||
|  |             num_heads=num_heads, | ||||||
|  |             output_dim=latent_dim + 2, | ||||||
|  |             mlp_ratio=4.0, | ||||||
|  |             add_space_attn=add_space_attn, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         self.norm = nn.GroupNorm(1, self.latent_dim) | ||||||
|  |         self.ffeat_updater = nn.Sequential( | ||||||
|  |             nn.Linear(self.latent_dim, self.latent_dim), | ||||||
|  |             nn.GELU(), | ||||||
|  |         ) | ||||||
|  |         self.vis_predictor = nn.Sequential( | ||||||
|  |             nn.Linear(self.latent_dim, 1), | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def forward_iteration( | ||||||
|  |         self, | ||||||
|  |         fmaps, | ||||||
|  |         coords_init, | ||||||
|  |         feat_init=None, | ||||||
|  |         vis_init=None, | ||||||
|  |         track_mask=None, | ||||||
|  |         iters=4, | ||||||
|  |     ): | ||||||
|  |         B, S_init, N, D = coords_init.shape | ||||||
|  |         assert D == 2 | ||||||
|  |         assert B == 1 | ||||||
|  |  | ||||||
|  |         B, S, __, H8, W8 = fmaps.shape | ||||||
|  |  | ||||||
|  |         device = fmaps.device | ||||||
|  |  | ||||||
|  |         if S_init < S: | ||||||
|  |             coords = torch.cat( | ||||||
|  |                 [coords_init, coords_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1 | ||||||
|  |             ) | ||||||
|  |             vis_init = torch.cat( | ||||||
|  |                 [vis_init, vis_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1 | ||||||
|  |             ) | ||||||
|  |         else: | ||||||
|  |             coords = coords_init.clone() | ||||||
|  |  | ||||||
|  |         fcorr_fn = CorrBlock( | ||||||
|  |             fmaps, num_levels=self.corr_levels, radius=self.corr_radius | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         ffeats = feat_init.clone() | ||||||
|  |  | ||||||
|  |         times_ = torch.linspace(0, S - 1, S).reshape(1, S, 1) | ||||||
|  |  | ||||||
|  |         pos_embed = sample_pos_embed( | ||||||
|  |             grid_size=(H8, W8), | ||||||
|  |             embed_dim=456, | ||||||
|  |             coords=coords, | ||||||
|  |         ) | ||||||
|  |         pos_embed = rearrange(pos_embed, "b e n -> (b n) e").unsqueeze(1) | ||||||
|  |         times_embed = ( | ||||||
|  |             torch.from_numpy(get_1d_sincos_pos_embed_from_grid(456, times_[0]))[None] | ||||||
|  |             .repeat(B, 1, 1) | ||||||
|  |             .float() | ||||||
|  |             .to(device) | ||||||
|  |         ) | ||||||
|  |         coord_predictions = [] | ||||||
|  |  | ||||||
|  |         for __ in range(iters): | ||||||
|  |             coords = coords.detach() | ||||||
|  |             fcorr_fn.corr(ffeats) | ||||||
|  |  | ||||||
|  |             fcorrs = fcorr_fn.sample(coords)  # B, S, N, LRR | ||||||
|  |             LRR = fcorrs.shape[3] | ||||||
|  |  | ||||||
|  |             fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, LRR) | ||||||
|  |             flows_ = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2) | ||||||
|  |  | ||||||
|  |             flows_cat = get_2d_embedding(flows_, 64, cat_coords=True) | ||||||
|  |             ffeats_ = ffeats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) | ||||||
|  |  | ||||||
|  |             if track_mask.shape[1] < vis_init.shape[1]: | ||||||
|  |                 track_mask = torch.cat( | ||||||
|  |                     [ | ||||||
|  |                         track_mask, | ||||||
|  |                         torch.zeros_like(track_mask[:, 0]).repeat( | ||||||
|  |                             1, vis_init.shape[1] - track_mask.shape[1], 1, 1 | ||||||
|  |                         ), | ||||||
|  |                     ], | ||||||
|  |                     dim=1, | ||||||
|  |                 ) | ||||||
|  |             concat = ( | ||||||
|  |                 torch.cat([track_mask, vis_init], dim=2) | ||||||
|  |                 .permute(0, 2, 1, 3) | ||||||
|  |                 .reshape(B * N, S, 2) | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             transformer_input = torch.cat([flows_cat, fcorrs_, ffeats_, concat], dim=2) | ||||||
|  |             x = transformer_input + pos_embed + times_embed | ||||||
|  |  | ||||||
|  |             x = rearrange(x, "(b n) t d -> b n t d", b=B) | ||||||
|  |  | ||||||
|  |             delta = self.updateformer(x) | ||||||
|  |  | ||||||
|  |             delta = rearrange(delta, "b n t d -> (b n) t d") | ||||||
|  |  | ||||||
|  |             delta_coords_ = delta[:, :, :2] | ||||||
|  |             delta_feats_ = delta[:, :, 2:] | ||||||
|  |  | ||||||
|  |             delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim) | ||||||
|  |             ffeats_ = ffeats.permute(0, 2, 1, 3).reshape(B * N * S, self.latent_dim) | ||||||
|  |  | ||||||
|  |             ffeats_ = self.ffeat_updater(self.norm(delta_feats_)) + ffeats_ | ||||||
|  |  | ||||||
|  |             ffeats = ffeats_.reshape(B, N, S, self.latent_dim).permute( | ||||||
|  |                 0, 2, 1, 3 | ||||||
|  |             )  # B,S,N,C | ||||||
|  |  | ||||||
|  |             coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3) | ||||||
|  |             coord_predictions.append(coords * self.stride) | ||||||
|  |  | ||||||
|  |         vis_e = self.vis_predictor(ffeats.reshape(B * S * N, self.latent_dim)).reshape( | ||||||
|  |             B, S, N | ||||||
|  |         ) | ||||||
|  |         return coord_predictions, vis_e, feat_init | ||||||
|  |  | ||||||
|  |     def forward(self, rgbs, queries, iters=4, feat_init=None, is_train=False): | ||||||
|  |         B, T, C, H, W = rgbs.shape | ||||||
|  |         B, N, __ = queries.shape | ||||||
|  |  | ||||||
|  |         device = rgbs.device | ||||||
|  |         assert B == 1 | ||||||
|  |         # INIT for the first sequence | ||||||
|  |         # We want to sort points by the first frame in which they become visible, so they can be added to the tensor of tracked points consecutively | ||||||
|  |         first_positive_inds = queries[:, :, 0].long() | ||||||
|  |  | ||||||
|  |         __, sort_inds = torch.sort(first_positive_inds[0], dim=0, descending=False) | ||||||
|  |         inv_sort_inds = torch.argsort(sort_inds, dim=0) | ||||||
|  |         first_positive_sorted_inds = first_positive_inds[0][sort_inds] | ||||||
|  |  | ||||||
|  |         assert torch.allclose( | ||||||
|  |             first_positive_inds[0], first_positive_inds[0][sort_inds][inv_sort_inds] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         coords_init = queries[:, :, 1:].reshape(B, 1, N, 2).repeat( | ||||||
|  |             1, self.S, 1, 1 | ||||||
|  |         ) / float(self.stride) | ||||||
|  |  | ||||||
|  |         rgbs = 2 * (rgbs / 255.0) - 1.0 | ||||||
|  |  | ||||||
|  |         traj_e = torch.zeros((B, T, N, 2), device=device) | ||||||
|  |         vis_e = torch.zeros((B, T, N), device=device) | ||||||
|  |  | ||||||
|  |         ind_array = torch.arange(T, device=device) | ||||||
|  |         ind_array = ind_array[None, :, None].repeat(B, 1, N) | ||||||
|  |  | ||||||
|  |         track_mask = (ind_array >= first_positive_inds[:, None, :]).unsqueeze(-1) | ||||||
|  |         # these are logits, so we initialize visibility with a large value (sigmoid(10) is ~0.99995, i.e. points start out as visible) | ||||||
|  |         vis_init = torch.ones((B, self.S, N, 1), device=device).float() * 10 | ||||||
|  |  | ||||||
|  |         ind = 0 | ||||||
|  |  | ||||||
|  |         track_mask_ = track_mask[:, :, sort_inds].clone() | ||||||
|  |         coords_init_ = coords_init[:, :, sort_inds].clone() | ||||||
|  |         vis_init_ = vis_init[:, :, sort_inds].clone() | ||||||
|  |  | ||||||
|  |         prev_wind_idx = 0 | ||||||
|  |         fmaps_ = None | ||||||
|  |         vis_predictions = [] | ||||||
|  |         coord_predictions = [] | ||||||
|  |         wind_inds = [] | ||||||
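|  |         # Process the video in overlapping temporal windows of S frames with stride S // 2; | ||||||
|  |         # each new window is re-initialized from the second half of the previous one so that | ||||||
|  |         # tracks are carried across window boundaries. | ||||||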
|  |         while ind < T - self.S // 2: | ||||||
|  |             rgbs_seq = rgbs[:, ind : ind + self.S] | ||||||
|  |  | ||||||
|  |             S = S_local = rgbs_seq.shape[1] | ||||||
|  |             if S < self.S: | ||||||
|  |                 rgbs_seq = torch.cat( | ||||||
|  |                     [rgbs_seq, rgbs_seq[:, -1, None].repeat(1, self.S - S, 1, 1, 1)], | ||||||
|  |                     dim=1, | ||||||
|  |                 ) | ||||||
|  |                 S = rgbs_seq.shape[1] | ||||||
|  |             rgbs_ = rgbs_seq.reshape(B * S, C, H, W) | ||||||
|  |  | ||||||
|  |             if fmaps_ is None: | ||||||
|  |                 fmaps_ = self.fnet(rgbs_) | ||||||
|  |             else: | ||||||
|  |                 fmaps_ = torch.cat( | ||||||
|  |                     [fmaps_[self.S // 2 :], self.fnet(rgbs_[self.S // 2 :])], dim=0 | ||||||
|  |                 ) | ||||||
|  |             fmaps = fmaps_.reshape( | ||||||
|  |                 B, S, self.latent_dim, H // self.stride, W // self.stride | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             curr_wind_points = torch.nonzero(first_positive_sorted_inds < ind + self.S) | ||||||
|  |             if curr_wind_points.shape[0] == 0: | ||||||
|  |                 ind = ind + self.S // 2 | ||||||
|  |                 continue | ||||||
|  |             wind_idx = curr_wind_points[-1] + 1 | ||||||
|  |  | ||||||
|  |             if wind_idx - prev_wind_idx > 0: | ||||||
|  |                 fmaps_sample = fmaps[ | ||||||
|  |                     :, first_positive_sorted_inds[prev_wind_idx:wind_idx] - ind | ||||||
|  |                 ] | ||||||
|  |  | ||||||
|  |                 feat_init_ = bilinear_sample2d( | ||||||
|  |                     fmaps_sample, | ||||||
|  |                     coords_init_[:, 0, prev_wind_idx:wind_idx, 0], | ||||||
|  |                     coords_init_[:, 0, prev_wind_idx:wind_idx, 1], | ||||||
|  |                 ).permute(0, 2, 1) | ||||||
|  |  | ||||||
|  |                 feat_init_ = feat_init_.unsqueeze(1).repeat(1, self.S, 1, 1) | ||||||
|  |                 feat_init = smart_cat(feat_init, feat_init_, dim=2) | ||||||
|  |  | ||||||
|  |             if prev_wind_idx > 0: | ||||||
|  |                 new_coords = coords[-1][:, self.S // 2 :] / float(self.stride) | ||||||
|  |  | ||||||
|  |                 coords_init_[:, : self.S // 2, :prev_wind_idx] = new_coords | ||||||
|  |                 coords_init_[:, self.S // 2 :, :prev_wind_idx] = new_coords[ | ||||||
|  |                     :, -1 | ||||||
|  |                 ].repeat(1, self.S // 2, 1, 1) | ||||||
|  |  | ||||||
|  |                 new_vis = vis[:, self.S // 2 :].unsqueeze(-1) | ||||||
|  |                 vis_init_[:, : self.S // 2, :prev_wind_idx] = new_vis | ||||||
|  |                 vis_init_[:, self.S // 2 :, :prev_wind_idx] = new_vis[:, -1].repeat( | ||||||
|  |                     1, self.S // 2, 1, 1 | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |             coords, vis, __ = self.forward_iteration( | ||||||
|  |                 fmaps=fmaps, | ||||||
|  |                 coords_init=coords_init_[:, :, :wind_idx], | ||||||
|  |                 feat_init=feat_init[:, :, :wind_idx], | ||||||
|  |                 vis_init=vis_init_[:, :, :wind_idx], | ||||||
|  |                 track_mask=track_mask_[:, ind : ind + self.S, :wind_idx], | ||||||
|  |                 iters=iters, | ||||||
|  |             ) | ||||||
|  |             if is_train: | ||||||
|  |                 vis_predictions.append(torch.sigmoid(vis[:, :S_local])) | ||||||
|  |                 coord_predictions.append([coord[:, :S_local] for coord in coords]) | ||||||
|  |                 wind_inds.append(wind_idx) | ||||||
|  |  | ||||||
|  |             traj_e[:, ind : ind + self.S, :wind_idx] = coords[-1][:, :S_local] | ||||||
|  |             vis_e[:, ind : ind + self.S, :wind_idx] = vis[:, :S_local] | ||||||
|  |  | ||||||
|  |             track_mask_[:, : ind + self.S, :wind_idx] = 0.0 | ||||||
|  |             ind = ind + self.S // 2 | ||||||
|  |  | ||||||
|  |             prev_wind_idx = wind_idx | ||||||
|  |  | ||||||
|  |         traj_e = traj_e[:, :, inv_sort_inds] | ||||||
|  |         vis_e = vis_e[:, :, inv_sort_inds] | ||||||
|  |  | ||||||
|  |         vis_e = torch.sigmoid(vis_e) | ||||||
|  |  | ||||||
|  |         train_data = ( | ||||||
|  |             (vis_predictions, coord_predictions, wind_inds, sort_inds) | ||||||
|  |             if is_train | ||||||
|  |             else None | ||||||
|  |         ) | ||||||
|  |         return traj_e, feat_init, vis_e, train_data | ||||||
							
								
								
									
61  cotracker/models/core/cotracker/losses.py  Normal file
							| @@ -0,0 +1,61 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | import torch.nn.functional as F | ||||||
|  | from cotracker.models.core.model_utils import reduce_masked_mean | ||||||
|  |  | ||||||
|  | EPS = 1e-6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def balanced_ce_loss(pred, gt, valid=None): | ||||||
|  |     total_balanced_loss = 0.0 | ||||||
|  |     for j in range(len(gt)): | ||||||
|  |         B, S, N = gt[j].shape | ||||||
|  |         # pred, gt, and valid must all have the same shape | ||||||
|  |         for (a, b) in zip(pred[j].size(), gt[j].size()): | ||||||
|  |             assert a == b  # some shape mismatch! | ||||||
|  |         for (a, b) in zip(pred[j].size(), valid[j].size()): | ||||||
|  |             assert a == b  # some shape mismatch! | ||||||
|  |  | ||||||
|  |         pos = (gt[j] > 0.95).float() | ||||||
|  |         neg = (gt[j] < 0.05).float() | ||||||
|  |  | ||||||
|  |         label = pos * 2.0 - 1.0 | ||||||
|  |         a = -label * pred[j] | ||||||
|  |         b = F.relu(a) | ||||||
|  |         loss = b + torch.log(torch.exp(-b) + torch.exp(a - b)) | ||||||
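|  |         # the three lines above compute log(1 + exp(-label * pred[j])) in a numerically | ||||||
|  |         # stable way, i.e. binary cross-entropy on logits with +/-1 labels | ||||||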
|  |  | ||||||
|  |         pos_loss = reduce_masked_mean(loss, pos * valid[j]) | ||||||
|  |         neg_loss = reduce_masked_mean(loss, neg * valid[j]) | ||||||
|  |  | ||||||
|  |         balanced_loss = pos_loss + neg_loss | ||||||
|  |         total_balanced_loss += balanced_loss / float(N) | ||||||
|  |     return total_balanced_loss | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8): | ||||||
|  |     """Loss function defined over sequence of flow predictions""" | ||||||
|  |     total_flow_loss = 0.0 | ||||||
|  |     for j in range(len(flow_gt)): | ||||||
|  |         B, S, N, D = flow_gt[j].shape | ||||||
|  |         assert D == 2 | ||||||
|  |         B, S1, N = vis[j].shape | ||||||
|  |         B, S2, N = valids[j].shape | ||||||
|  |         assert S == S1 | ||||||
|  |         assert S == S2 | ||||||
|  |         n_predictions = len(flow_preds[j]) | ||||||
|  |         flow_loss = 0.0 | ||||||
|  |         for i in range(n_predictions): | ||||||
|  |             i_weight = gamma ** (n_predictions - i - 1) | ||||||
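|  |             # later refinement iterations receive exponentially larger weights (the last one gets weight 1) | ||||||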
|  |             flow_pred = flow_preds[j][i] | ||||||
|  |             i_loss = (flow_pred - flow_gt[j]).abs()  # B, S, N, 2 | ||||||
|  |             i_loss = torch.mean(i_loss, dim=3)  # B, S, N | ||||||
|  |             flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j]) | ||||||
|  |         flow_loss = flow_loss / n_predictions | ||||||
|  |         total_flow_loss += flow_loss / float(N) | ||||||
|  |     return total_flow_loss | ||||||
							
								
								
									
154  cotracker/models/core/embeddings.py  Normal file
							| @@ -0,0 +1,154 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | import numpy as np | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): | ||||||
|  |     """ | ||||||
|  |     grid_size: int of the grid height and width | ||||||
|  |     return: | ||||||
|  |     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) | ||||||
|  |     """ | ||||||
|  |     if isinstance(grid_size, tuple): | ||||||
|  |         grid_size_h, grid_size_w = grid_size | ||||||
|  |     else: | ||||||
|  |         grid_size_h = grid_size_w = grid_size | ||||||
|  |     grid_h = np.arange(grid_size_h, dtype=np.float32) | ||||||
|  |     grid_w = np.arange(grid_size_w, dtype=np.float32) | ||||||
|  |     grid = np.meshgrid(grid_w, grid_h)  # here w goes first | ||||||
|  |     grid = np.stack(grid, axis=0) | ||||||
|  |  | ||||||
|  |     grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) | ||||||
|  |     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) | ||||||
|  |     if cls_token and extra_tokens > 0: | ||||||
|  |         pos_embed = np.concatenate( | ||||||
|  |             [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0 | ||||||
|  |         ) | ||||||
|  |     return pos_embed | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): | ||||||
|  |     assert embed_dim % 2 == 0 | ||||||
|  |  | ||||||
|  |     # use half of dimensions to encode grid_h | ||||||
|  |     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2) | ||||||
|  |     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2) | ||||||
|  |  | ||||||
|  |     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D) | ||||||
|  |     return emb | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): | ||||||
|  |     """ | ||||||
|  |     embed_dim: output dimension for each position | ||||||
|  |     pos: a list of positions to be encoded: size (M,) | ||||||
|  |     out: (M, D) | ||||||
|  |     """ | ||||||
|  |     assert embed_dim % 2 == 0 | ||||||
|  |     omega = np.arange(embed_dim // 2, dtype=np.float64) | ||||||
|  |     omega /= embed_dim / 2.0 | ||||||
|  |     omega = 1.0 / 10000 ** omega  # (D/2,) | ||||||
|  |  | ||||||
|  |     pos = pos.reshape(-1)  # (M,) | ||||||
|  |     out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product | ||||||
|  |  | ||||||
|  |     emb_sin = np.sin(out)  # (M, D/2) | ||||||
|  |     emb_cos = np.cos(out)  # (M, D/2) | ||||||
|  |  | ||||||
|  |     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D) | ||||||
|  |     return emb | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_2d_embedding(xy, C, cat_coords=True): | ||||||
|  |     B, N, D = xy.shape | ||||||
|  |     assert D == 2 | ||||||
|  |  | ||||||
|  |     x = xy[:, :, 0:1] | ||||||
|  |     y = xy[:, :, 1:2] | ||||||
|  |     div_term = ( | ||||||
|  |         torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C) | ||||||
|  |     ).reshape(1, 1, int(C / 2)) | ||||||
|  |  | ||||||
|  |     pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) | ||||||
|  |     pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) | ||||||
|  |  | ||||||
|  |     pe_x[:, :, 0::2] = torch.sin(x * div_term) | ||||||
|  |     pe_x[:, :, 1::2] = torch.cos(x * div_term) | ||||||
|  |  | ||||||
|  |     pe_y[:, :, 0::2] = torch.sin(y * div_term) | ||||||
|  |     pe_y[:, :, 1::2] = torch.cos(y * div_term) | ||||||
|  |  | ||||||
|  |     pe = torch.cat([pe_x, pe_y], dim=2)  # B, N, C*2 | ||||||
|  |     if cat_coords: | ||||||
|  |         pe = torch.cat([xy, pe], dim=2)  # B, N, C*2+2 | ||||||
|  |     return pe | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_3d_embedding(xyz, C, cat_coords=True): | ||||||
|  |     B, N, D = xyz.shape | ||||||
|  |     assert D == 3 | ||||||
|  |  | ||||||
|  |     x = xyz[:, :, 0:1] | ||||||
|  |     y = xyz[:, :, 1:2] | ||||||
|  |     z = xyz[:, :, 2:3] | ||||||
|  |     div_term = ( | ||||||
|  |         torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (1000.0 / C) | ||||||
|  |     ).reshape(1, 1, int(C / 2)) | ||||||
|  |  | ||||||
|  |     pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) | ||||||
|  |     pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) | ||||||
|  |     pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) | ||||||
|  |  | ||||||
|  |     pe_x[:, :, 0::2] = torch.sin(x * div_term) | ||||||
|  |     pe_x[:, :, 1::2] = torch.cos(x * div_term) | ||||||
|  |  | ||||||
|  |     pe_y[:, :, 0::2] = torch.sin(y * div_term) | ||||||
|  |     pe_y[:, :, 1::2] = torch.cos(y * div_term) | ||||||
|  |  | ||||||
|  |     pe_z[:, :, 0::2] = torch.sin(z * div_term) | ||||||
|  |     pe_z[:, :, 1::2] = torch.cos(z * div_term) | ||||||
|  |  | ||||||
|  |     pe = torch.cat([pe_x, pe_y, pe_z], dim=2)  # B, N, C*3 | ||||||
|  |     if cat_coords: | ||||||
|  |         pe = torch.cat([pe, xyz], dim=2)  # B, N, C*3+3 | ||||||
|  |     return pe | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_4d_embedding(xyzw, C, cat_coords=True): | ||||||
|  |     B, N, D = xyzw.shape | ||||||
|  |     assert D == 4 | ||||||
|  |  | ||||||
|  |     x = xyzw[:, :, 0:1] | ||||||
|  |     y = xyzw[:, :, 1:2] | ||||||
|  |     z = xyzw[:, :, 2:3] | ||||||
|  |     w = xyzw[:, :, 3:4] | ||||||
|  |     div_term = ( | ||||||
|  |         torch.arange(0, C, 2, device=xyzw.device, dtype=torch.float32) * (1000.0 / C) | ||||||
|  |     ).reshape(1, 1, int(C / 2)) | ||||||
|  |  | ||||||
|  |     pe_x = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) | ||||||
|  |     pe_y = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) | ||||||
|  |     pe_z = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) | ||||||
|  |     pe_w = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) | ||||||
|  |  | ||||||
|  |     pe_x[:, :, 0::2] = torch.sin(x * div_term) | ||||||
|  |     pe_x[:, :, 1::2] = torch.cos(x * div_term) | ||||||
|  |  | ||||||
|  |     pe_y[:, :, 0::2] = torch.sin(y * div_term) | ||||||
|  |     pe_y[:, :, 1::2] = torch.cos(y * div_term) | ||||||
|  |  | ||||||
|  |     pe_z[:, :, 0::2] = torch.sin(z * div_term) | ||||||
|  |     pe_z[:, :, 1::2] = torch.cos(z * div_term) | ||||||
|  |  | ||||||
|  |     pe_w[:, :, 0::2] = torch.sin(w * div_term) | ||||||
|  |     pe_w[:, :, 1::2] = torch.cos(w * div_term) | ||||||
|  |  | ||||||
|  |     pe = torch.cat([pe_x, pe_y, pe_z, pe_w], dim=2)  # B, N, C*4 | ||||||
|  |     if cat_coords: | ||||||
|  |         pe = torch.cat([pe, xyzw], dim=2)  # B, N, C*4+4 | ||||||
|  |     return pe | ||||||
							
								
								
									
169  cotracker/models/core/model_utils.py  Normal file
							| @@ -0,0 +1,169 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  |  | ||||||
|  | EPS = 1e-6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def smart_cat(tensor1, tensor2, dim): | ||||||
|  |     if tensor1 is None: | ||||||
|  |         return tensor2 | ||||||
|  |     return torch.cat([tensor1, tensor2], dim=dim) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def normalize_single(d): | ||||||
|  |     # d is a whatever shape torch tensor | ||||||
|  |     dmin = torch.min(d) | ||||||
|  |     dmax = torch.max(d) | ||||||
|  |     d = (d - dmin) / (EPS + (dmax - dmin)) | ||||||
|  |     return d | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def normalize(d): | ||||||
|  |     # d is B x whatever. normalize within each element of the batch | ||||||
|  |     out = torch.zeros(d.size()) | ||||||
|  |     if d.is_cuda: | ||||||
|  |         out = out.cuda() | ||||||
|  |     B = list(d.size())[0] | ||||||
|  |     for b in list(range(B)): | ||||||
|  |         out[b] = normalize_single(d[b]) | ||||||
|  |     return out | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"): | ||||||
|  |     # returns a meshgrid sized B x Y x X | ||||||
|  |  | ||||||
|  |     grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device)) | ||||||
|  |     grid_y = torch.reshape(grid_y, [1, Y, 1]) | ||||||
|  |     grid_y = grid_y.repeat(B, 1, X) | ||||||
|  |  | ||||||
|  |     grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device)) | ||||||
|  |     grid_x = torch.reshape(grid_x, [1, 1, X]) | ||||||
|  |     grid_x = grid_x.repeat(B, Y, 1) | ||||||
|  |  | ||||||
|  |     if stack: | ||||||
|  |         # note we stack in xy order | ||||||
|  |         # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample) | ||||||
|  |         grid = torch.stack([grid_x, grid_y], dim=-1) | ||||||
|  |         return grid | ||||||
|  |     else: | ||||||
|  |         return grid_y, grid_x | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def reduce_masked_mean(x, mask, dim=None, keepdim=False): | ||||||
|  |     # x and mask must have the same shape (broadcasting is deliberately not supported) | ||||||
|  |     # returns the mean of x over the entries where mask is nonzero | ||||||
|  |     # dim can be a single dim or a list of dims | ||||||
|  |     for (a, b) in zip(x.size(), mask.size()): | ||||||
|  |         assert a == b  # some shape mismatch! | ||||||
|  |     prod = x * mask | ||||||
|  |     if dim is None: | ||||||
|  |         numer = torch.sum(prod) | ||||||
|  |         denom = EPS + torch.sum(mask) | ||||||
|  |     else: | ||||||
|  |         numer = torch.sum(prod, dim=dim, keepdim=keepdim) | ||||||
|  |         denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim) | ||||||
|  |  | ||||||
|  |     mean = numer / denom | ||||||
|  |     return mean | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def bilinear_sample2d(im, x, y, return_inbounds=False): | ||||||
|  |     # x and y are each B, N | ||||||
|  |     # output is B, C, N | ||||||
|  |     if len(im.shape) == 5: | ||||||
|  |         B, N, C, H, W = list(im.shape) | ||||||
|  |     else: | ||||||
|  |         B, C, H, W = list(im.shape) | ||||||
|  |     N = list(x.shape)[1] | ||||||
|  |  | ||||||
|  |     x = x.float() | ||||||
|  |     y = y.float() | ||||||
|  |     H_f = torch.tensor(H, dtype=torch.float32) | ||||||
|  |     W_f = torch.tensor(W, dtype=torch.float32) | ||||||
|  |  | ||||||
|  |     # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float() | ||||||
|  |  | ||||||
|  |     max_y = (H_f - 1).int() | ||||||
|  |     max_x = (W_f - 1).int() | ||||||
|  |  | ||||||
|  |     x0 = torch.floor(x).int() | ||||||
|  |     x1 = x0 + 1 | ||||||
|  |     y0 = torch.floor(y).int() | ||||||
|  |     y1 = y0 + 1 | ||||||
|  |  | ||||||
|  |     x0_clip = torch.clamp(x0, 0, max_x) | ||||||
|  |     x1_clip = torch.clamp(x1, 0, max_x) | ||||||
|  |     y0_clip = torch.clamp(y0, 0, max_y) | ||||||
|  |     y1_clip = torch.clamp(y1, 0, max_y) | ||||||
|  |     dim2 = W | ||||||
|  |     dim1 = W * H | ||||||
|  |  | ||||||
|  |     base = torch.arange(0, B, dtype=torch.int64, device=x.device) * dim1 | ||||||
|  |     base = torch.reshape(base, [B, 1]).repeat([1, N]) | ||||||
|  |  | ||||||
|  |     base_y0 = base + y0_clip * dim2 | ||||||
|  |     base_y1 = base + y1_clip * dim2 | ||||||
|  |  | ||||||
|  |     idx_y0_x0 = base_y0 + x0_clip | ||||||
|  |     idx_y0_x1 = base_y0 + x1_clip | ||||||
|  |     idx_y1_x0 = base_y1 + x0_clip | ||||||
|  |     idx_y1_x1 = base_y1 + x1_clip | ||||||
|  |  | ||||||
|  |     # use the indices to lookup pixels in the flat image | ||||||
|  |     # im is B x C x H x W | ||||||
|  |     # move C out to last dim | ||||||
|  |     if len(im.shape) == 5: | ||||||
|  |         im_flat = (im.permute(0, 3, 4, 1, 2)).reshape(B * H * W, N, C) | ||||||
|  |         i_y0_x0 = torch.diagonal(im_flat[idx_y0_x0.long()], dim1=1, dim2=2).permute( | ||||||
|  |             0, 2, 1 | ||||||
|  |         ) | ||||||
|  |         i_y0_x1 = torch.diagonal(im_flat[idx_y0_x1.long()], dim1=1, dim2=2).permute( | ||||||
|  |             0, 2, 1 | ||||||
|  |         ) | ||||||
|  |         i_y1_x0 = torch.diagonal(im_flat[idx_y1_x0.long()], dim1=1, dim2=2).permute( | ||||||
|  |             0, 2, 1 | ||||||
|  |         ) | ||||||
|  |         i_y1_x1 = torch.diagonal(im_flat[idx_y1_x1.long()], dim1=1, dim2=2).permute( | ||||||
|  |             0, 2, 1 | ||||||
|  |         ) | ||||||
|  |     else: | ||||||
|  |         im_flat = (im.permute(0, 2, 3, 1)).reshape(B * H * W, C) | ||||||
|  |         i_y0_x0 = im_flat[idx_y0_x0.long()] | ||||||
|  |         i_y0_x1 = im_flat[idx_y0_x1.long()] | ||||||
|  |         i_y1_x0 = im_flat[idx_y1_x0.long()] | ||||||
|  |         i_y1_x1 = im_flat[idx_y1_x1.long()] | ||||||
|  |  | ||||||
|  |     # Finally calculate interpolated values. | ||||||
|  |     x0_f = x0.float() | ||||||
|  |     x1_f = x1.float() | ||||||
|  |     y0_f = y0.float() | ||||||
|  |     y1_f = y1.float() | ||||||
|  |  | ||||||
|  |     w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2) | ||||||
|  |     w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2) | ||||||
|  |     w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2) | ||||||
|  |     w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2) | ||||||
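|  |     # standard bilinear weights: each corner is weighted by the area of the rectangle | ||||||
|  |     # between the sampling location and the diagonally opposite corner | ||||||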
|  |  | ||||||
|  |     output = ( | ||||||
|  |         w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1 | ||||||
|  |     ) | ||||||
|  |     # output is B*N x C | ||||||
|  |     output = output.view(B, -1, C) | ||||||
|  |     output = output.permute(0, 2, 1) | ||||||
|  |     # output is B x C x N | ||||||
|  |  | ||||||
|  |     if return_inbounds: | ||||||
|  |         x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte() | ||||||
|  |         y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte() | ||||||
|  |         inbounds = (x_valid & y_valid).float() | ||||||
|  |         inbounds = inbounds.reshape( | ||||||
|  |             B, N | ||||||
|  |         )  # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1) | ||||||
|  |         return output, inbounds | ||||||
|  |  | ||||||
|  |     return output  # B, C, N | ||||||
							
								
								
									
103  cotracker/models/evaluation_predictor.py  Normal file
							| @@ -0,0 +1,103 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | import torch.nn.functional as F | ||||||
|  | from typing import Tuple | ||||||
|  |  | ||||||
|  | from cotracker.models.core.cotracker.cotracker import CoTracker, get_points_on_a_grid | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class EvaluationPredictor(torch.nn.Module): | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         cotracker_model: CoTracker, | ||||||
|  |         interp_shape: Tuple[int, int] = (384, 512), | ||||||
|  |         grid_size: int = 6, | ||||||
|  |         local_grid_size: int = 6, | ||||||
|  |         single_point: bool = True, | ||||||
|  |         n_iters: int = 6, | ||||||
|  |     ) -> None: | ||||||
|  |         super(EvaluationPredictor, self).__init__() | ||||||
|  |         self.grid_size = grid_size | ||||||
|  |         self.local_grid_size = local_grid_size | ||||||
|  |         self.single_point = single_point | ||||||
|  |         self.interp_shape = interp_shape | ||||||
|  |         self.n_iters = n_iters | ||||||
|  |  | ||||||
|  |         self.model = cotracker_model | ||||||
|  |         self.model.to("cuda") | ||||||
|  |         self.model.eval() | ||||||
|  |  | ||||||
|  |     def forward(self, video, queries): | ||||||
|  |         queries = queries.clone().cuda() | ||||||
|  |         B, T, C, H, W = video.shape | ||||||
|  |         B, N, D = queries.shape | ||||||
|  |  | ||||||
|  |         assert D == 3 | ||||||
|  |         assert B == 1 | ||||||
|  |  | ||||||
|  |         rgbs = video.reshape(B * T, C, H, W) | ||||||
|  |         rgbs = F.interpolate(rgbs, tuple(self.interp_shape), mode="bilinear") | ||||||
|  |         rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]).cuda() | ||||||
|  |  | ||||||
|  |         queries[:, :, 1] *= self.interp_shape[1] / W | ||||||
|  |         queries[:, :, 2] *= self.interp_shape[0] / H | ||||||
|  |  | ||||||
|  |         if self.single_point: | ||||||
|  |             traj_e = torch.zeros((B, T, N, 2)).cuda() | ||||||
|  |             vis_e = torch.zeros((B, T, N)).cuda() | ||||||
|  |             for pind in range((N)): | ||||||
|  |                 query = queries[:, pind : pind + 1] | ||||||
|  |  | ||||||
|  |                 t = query[0, 0, 0].long() | ||||||
|  |  | ||||||
|  |                 traj_e_pind, vis_e_pind = self._process_one_point(rgbs, query) | ||||||
|  |                 traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1] | ||||||
|  |                 vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1] | ||||||
|  |         else: | ||||||
|  |             if self.grid_size > 0: | ||||||
|  |                 xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:]) | ||||||
|  |                 xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda()  # | ||||||
|  |                 queries = torch.cat([queries, xy], dim=1)  # | ||||||
|  |  | ||||||
|  |             traj_e, __, vis_e, __ = self.model( | ||||||
|  |                 rgbs=rgbs, | ||||||
|  |                 queries=queries, | ||||||
|  |                 iters=self.n_iters, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         traj_e[:, :, :, 0] *= W / float(self.interp_shape[1]) | ||||||
|  |         traj_e[:, :, :, 1] *= H / float(self.interp_shape[0]) | ||||||
|  |         return traj_e, vis_e | ||||||
|  |  | ||||||
|  |     def _process_one_point(self, rgbs, query): | ||||||
|  |         t = query[0, 0, 0].long() | ||||||
|  |  | ||||||
|  |         device = rgbs.device | ||||||
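|  |         # The query point is tracked jointly with a local grid centered on it and, optionally, | ||||||
|  |         # a global support grid; the caller keeps only the first returned track. | ||||||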
|  |         if self.local_grid_size > 0: | ||||||
|  |             xy_target = get_points_on_a_grid( | ||||||
|  |                 self.local_grid_size, | ||||||
|  |                 (50, 50), | ||||||
|  |                 [query[0, 0, 2], query[0, 0, 1]], | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             xy_target = torch.cat( | ||||||
|  |                 [torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2 | ||||||
|  |             )  # | ||||||
|  |             query = torch.cat([query, xy_target], dim=1).to(device)  # | ||||||
|  |  | ||||||
|  |         if self.grid_size > 0: | ||||||
|  |             xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:]) | ||||||
|  |             xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda()  # | ||||||
|  |             query = torch.cat([query, xy], dim=1).to(device)  # | ||||||
|  |         # crop the video to start from the queried frame | ||||||
|  |         query[0, 0, 0] = 0 | ||||||
|  |         traj_e_pind, __, vis_e_pind, __ = self.model( | ||||||
|  |             rgbs=rgbs[:, t:], queries=query, iters=self.n_iters | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         return traj_e_pind, vis_e_pind | ||||||
							
								
								
									
178  cotracker/predictor.py  Normal file
							| @@ -0,0 +1,178 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | import torch.nn.functional as F | ||||||
|  |  | ||||||
|  | from tqdm import tqdm | ||||||
|  | from cotracker.models.core.cotracker.cotracker import get_points_on_a_grid | ||||||
|  | from cotracker.models.core.model_utils import smart_cat | ||||||
|  | from cotracker.models.build_cotracker import ( | ||||||
|  |     build_cotracker, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class CoTrackerPredictor(torch.nn.Module): | ||||||
|  |     def __init__( | ||||||
|  |         self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth" | ||||||
|  |     ): | ||||||
|  |         super().__init__() | ||||||
|  |         self.interp_shape = (384, 512) | ||||||
|  |         self.support_grid_size = 6 | ||||||
|  |         model = build_cotracker(checkpoint) | ||||||
|  |  | ||||||
|  |         self.model = model | ||||||
|  |         self.model.to("cuda") | ||||||
|  |         self.model.eval() | ||||||
|  |  | ||||||
|  |     @torch.no_grad() | ||||||
|  |     def forward( | ||||||
|  |         self, | ||||||
|  |         video,  # (1, T, 3, H, W) | ||||||
|  |         # input prompt types: | ||||||
|  |         # - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame. | ||||||
|  |         # *backward_tracking=True* will compute tracks in both directions. | ||||||
|  |         # - queries. Queried points of shape (1, N, 3) in format (t, x, y) for frame index and pixel coordinates. | ||||||
|  |         # - grid_size. Grid of N*N points from the first frame. If segm_mask is provided, the grid is computed only within the mask. | ||||||
|  |         # You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks. | ||||||
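|  |         # A minimal call sketch (illustrative only; it assumes the default checkpoint above is | ||||||
|  |         # available and that `video` is a (1, T, 3, H, W) float tensor on a CUDA device): | ||||||
|  |         #   predictor = CoTrackerPredictor() | ||||||
|  |         #   pred_tracks, pred_visibility = predictor(video, grid_size=10) | ||||||
|  |         #   # pred_tracks: (1, T, N, 2) in original-resolution pixel coordinates | ||||||
|  |         #   # pred_visibility: (1, T, N) boolean per-point visibility | ||||||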
|  |         queries: torch.Tensor = None, | ||||||
|  |         segm_mask: torch.Tensor = None,  # Segmentation mask of shape (B, 1, H, W) | ||||||
|  |         grid_size: int = 0, | ||||||
|  |         grid_query_frame: int = 0,  # only for dense and regular grid tracks | ||||||
|  |         backward_tracking: bool = False, | ||||||
|  |     ): | ||||||
|  |  | ||||||
|  |         if queries is None and grid_size == 0: | ||||||
|  |             tracks, visibilities = self._compute_dense_tracks( | ||||||
|  |                 video, | ||||||
|  |                 grid_query_frame=grid_query_frame, | ||||||
|  |                 backward_tracking=backward_tracking, | ||||||
|  |             ) | ||||||
|  |         else: | ||||||
|  |             tracks, visibilities = self._compute_sparse_tracks( | ||||||
|  |                 video, | ||||||
|  |                 queries, | ||||||
|  |                 segm_mask, | ||||||
|  |                 grid_size, | ||||||
|  |                 add_support_grid=(grid_size == 0 or segm_mask is not None), | ||||||
|  |                 grid_query_frame=grid_query_frame, | ||||||
|  |                 backward_tracking=backward_tracking, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         return tracks, visibilities | ||||||
|  |  | ||||||
|  |     def _compute_dense_tracks( | ||||||
|  |         self, video, grid_query_frame, grid_size=50, backward_tracking=False | ||||||
|  |     ): | ||||||
|  |         *_, H, W = video.shape | ||||||
|  |         grid_step = W // grid_size | ||||||
|  |         grid_width = W // grid_step | ||||||
|  |         grid_height = H // grid_step | ||||||
|  |         tracks = visibilities = None | ||||||
|  |         grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to("cuda") | ||||||
|  |         grid_pts[0, :, 0] = grid_query_frame | ||||||
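|  |         # Sweep the coarse grid over every (ox, oy) sub-offset: across grid_step ** 2 passes | ||||||
|  |         # this queries (nearly) every pixel of the frame, which yields dense tracks. | ||||||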
|  |         for offset in tqdm(range(grid_step * grid_step)): | ||||||
|  |             ox = offset % grid_step | ||||||
|  |             oy = offset // grid_step | ||||||
|  |             grid_pts[0, :, 1] = ( | ||||||
|  |                 torch.arange(grid_width).repeat(grid_height) * grid_step + ox | ||||||
|  |             ) | ||||||
|  |             grid_pts[0, :, 2] = ( | ||||||
|  |                 torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy | ||||||
|  |             ) | ||||||
|  |             tracks_step, visibilities_step = self._compute_sparse_tracks( | ||||||
|  |                 video=video, | ||||||
|  |                 queries=grid_pts, | ||||||
|  |                 backward_tracking=backward_tracking, | ||||||
|  |             ) | ||||||
|  |             tracks = smart_cat(tracks, tracks_step, dim=2) | ||||||
|  |             visibilities = smart_cat(visibilities, visibilities_step, dim=2) | ||||||
|  |  | ||||||
|  |         return tracks, visibilities | ||||||
|  |  | ||||||
|  |     def _compute_sparse_tracks( | ||||||
|  |         self, | ||||||
|  |         video, | ||||||
|  |         queries, | ||||||
|  |         segm_mask=None, | ||||||
|  |         grid_size=0, | ||||||
|  |         add_support_grid=False, | ||||||
|  |         grid_query_frame=0, | ||||||
|  |         backward_tracking=False, | ||||||
|  |     ): | ||||||
|  |         B, T, C, H, W = video.shape | ||||||
|  |         assert B == 1 | ||||||
|  |  | ||||||
|  |         video = video.reshape(B * T, C, H, W) | ||||||
|  |         video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").cuda() | ||||||
|  |         video = video.reshape( | ||||||
|  |             B, T, 3, self.interp_shape[0], self.interp_shape[1] | ||||||
|  |         ).cuda() | ||||||
|  |  | ||||||
|  |         if queries is not None: | ||||||
|  |             queries = queries.clone() | ||||||
|  |             B, N, D = queries.shape | ||||||
|  |             assert D == 3 | ||||||
|  |             queries[:, :, 1] *= self.interp_shape[1] / W | ||||||
|  |             queries[:, :, 2] *= self.interp_shape[0] / H | ||||||
|  |         elif grid_size > 0: | ||||||
|  |             grid_pts = get_points_on_a_grid(grid_size, self.interp_shape) | ||||||
|  |             if segm_mask is not None: | ||||||
|  |                 segm_mask = F.interpolate( | ||||||
|  |                     segm_mask, tuple(self.interp_shape), mode="nearest" | ||||||
|  |                 ) | ||||||
|  |                 point_mask = segm_mask[0, 0][ | ||||||
|  |                     (grid_pts[0, :, 1]).round().long().cpu(), | ||||||
|  |                     (grid_pts[0, :, 0]).round().long().cpu(), | ||||||
|  |                 ].bool() | ||||||
|  |                 grid_pts = grid_pts[:, point_mask] | ||||||
|  |  | ||||||
|  |             queries = torch.cat( | ||||||
|  |                 [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts], | ||||||
|  |                 dim=2, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         if add_support_grid: | ||||||
|  |             grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape) | ||||||
|  |             grid_pts = torch.cat( | ||||||
|  |                 [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2 | ||||||
|  |             ) | ||||||
|  |             queries = torch.cat([queries, grid_pts], dim=1) | ||||||
|  |  | ||||||
|  |         tracks, __, visibilities, __ = self.model(rgbs=video, queries=queries, iters=6) | ||||||
|  |  | ||||||
|  |         if backward_tracking: | ||||||
|  |             tracks, visibilities = self._compute_backward_tracks( | ||||||
|  |                 video, queries, tracks, visibilities | ||||||
|  |             ) | ||||||
|  |             if add_support_grid: | ||||||
|  |                 queries[:, -self.support_grid_size ** 2 :, 0] = T - 1 | ||||||
|  |         if add_support_grid: | ||||||
|  |             tracks = tracks[:, :, : -self.support_grid_size ** 2] | ||||||
|  |             visibilities = visibilities[:, :, : -self.support_grid_size ** 2] | ||||||
|  |         thr = 0.9 | ||||||
|  |         visibilities = visibilities > thr | ||||||
|  |         tracks[:, :, :, 0] *= W / float(self.interp_shape[1]) | ||||||
|  |         tracks[:, :, :, 1] *= H / float(self.interp_shape[0]) | ||||||
|  |         return tracks, visibilities | ||||||
|  |  | ||||||
|  |     def _compute_backward_tracks(self, video, queries, tracks, visibilities): | ||||||
|  |         inv_video = video.flip(1).clone() | ||||||
|  |         inv_queries = queries.clone() | ||||||
|  |         inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1 | ||||||
|  |  | ||||||
|  |         inv_tracks, __, inv_visibilities, __ = self.model( | ||||||
|  |             rgbs=inv_video, queries=inv_queries, iters=6 | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         inv_tracks = inv_tracks.flip(1) | ||||||
|  |         inv_visibilities = inv_visibilities.flip(1) | ||||||
|  |  | ||||||
|  |         mask = tracks == 0 | ||||||
|  |  | ||||||
|  |         tracks[mask] = inv_tracks[mask] | ||||||
|  |         visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]] | ||||||
|  |         return tracks, visibilities | ||||||
							
								
								
									
5  cotracker/utils/__init__.py  Normal file
							| @@ -0,0 +1,5 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
							
								
								
									
291  cotracker/utils/visualizer.py  Normal file
							| @@ -0,0 +1,291 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
|  |  | ||||||
|  | import os | ||||||
|  | import numpy as np | ||||||
|  | import cv2 | ||||||
|  | import torch | ||||||
|  | import flow_vis | ||||||
|  |  | ||||||
|  | from matplotlib import cm | ||||||
|  | import torch.nn.functional as F | ||||||
|  | import torchvision.transforms as transforms | ||||||
|  | from moviepy.editor import ImageSequenceClip | ||||||
|  | from torch.utils.tensorboard import SummaryWriter | ||||||
|  | import matplotlib.pyplot as plt | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Visualizer: | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         save_dir: str = "./results", | ||||||
|  |         grayscale: bool = False, | ||||||
|  |         pad_value: int = 0, | ||||||
|  |         fps: int = 10, | ||||||
|  |         mode: str = "rainbow",  # 'cool', 'optical_flow' | ||||||
|  |         linewidth: int = 2, | ||||||
|  |         show_first_frame: int = 10, | ||||||
|  |         tracks_leave_trace: int = 0,  # -1 for infinite | ||||||
|  |     ): | ||||||
|  |         self.mode = mode | ||||||
|  |         self.save_dir = save_dir | ||||||
|  |         if mode == "rainbow": | ||||||
|  |             self.color_map = cm.get_cmap("gist_rainbow") | ||||||
|  |         elif mode == "cool": | ||||||
|  |             self.color_map = cm.get_cmap(mode) | ||||||
|  |         self.show_first_frame = show_first_frame | ||||||
|  |         self.grayscale = grayscale | ||||||
|  |         self.tracks_leave_trace = tracks_leave_trace | ||||||
|  |         self.pad_value = pad_value | ||||||
|  |         self.linewidth = linewidth | ||||||
|  |         self.fps = fps | ||||||
|  |  | ||||||
|  |     def visualize( | ||||||
|  |         self, | ||||||
|  |         video: torch.Tensor,  # (B,T,C,H,W) | ||||||
|  |         tracks: torch.Tensor,  # (B,T,N,2) | ||||||
|  |         gt_tracks: torch.Tensor = None,  # (B,T,N,2) | ||||||
|  |         segm_mask: torch.Tensor = None,  # (B,1,H,W) | ||||||
|  |         filename: str = "video", | ||||||
|  |         writer: SummaryWriter = None, | ||||||
|  |         step: int = 0, | ||||||
|  |         query_frame: int = 0, | ||||||
|  |         save_video: bool = True, | ||||||
|  |         compensate_for_camera_motion: bool = False, | ||||||
|  |     ): | ||||||
|  |         if compensate_for_camera_motion: | ||||||
|  |             assert segm_mask is not None | ||||||
|  |         if segm_mask is not None: | ||||||
|  |             coords = tracks[0, query_frame].round().long() | ||||||
|  |             segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long() | ||||||
|  |  | ||||||
|  |         video = F.pad( | ||||||
|  |             video, | ||||||
|  |             (self.pad_value, self.pad_value, self.pad_value, self.pad_value), | ||||||
|  |             "constant", | ||||||
|  |             255, | ||||||
|  |         ) | ||||||
|  |         tracks = tracks + self.pad_value | ||||||
|  |  | ||||||
|  |         if self.grayscale: | ||||||
|  |             transform = transforms.Grayscale() | ||||||
|  |             video = transform(video) | ||||||
|  |             video = video.repeat(1, 1, 3, 1, 1) | ||||||
|  |  | ||||||
|  |         res_video = self.draw_tracks_on_video( | ||||||
|  |             video=video, | ||||||
|  |             tracks=tracks, | ||||||
|  |             segm_mask=segm_mask, | ||||||
|  |             gt_tracks=gt_tracks, | ||||||
|  |             query_frame=query_frame, | ||||||
|  |             compensate_for_camera_motion=compensate_for_camera_motion, | ||||||
|  |         ) | ||||||
|  |         if save_video: | ||||||
|  |             self.save_video(res_video, filename=filename, writer=writer, step=step) | ||||||
|  |         return res_video | ||||||
|  |  | ||||||
|  |     def save_video(self, video, filename, writer=None, step=0): | ||||||
|  |         if writer is not None: | ||||||
|  |             writer.add_video( | ||||||
|  |                 f"{filename}_pred_track", | ||||||
|  |                 video.to(torch.uint8), | ||||||
|  |                 global_step=step, | ||||||
|  |                 fps=self.fps, | ||||||
|  |             ) | ||||||
|  |         else: | ||||||
|  |             os.makedirs(self.save_dir, exist_ok=True) | ||||||
|  |             wide_list = list(video.unbind(1)) | ||||||
|  |             wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list] | ||||||
|  |             clip = ImageSequenceClip(wide_list[2:-1], fps=self.fps) | ||||||
|  |  | ||||||
|  |             # Write the video file | ||||||
|  |             save_path = os.path.join(self.save_dir, f"{filename}_pred_track.mp4") | ||||||
|  |             clip.write_videofile(save_path, codec="libx264", fps=self.fps, logger=None) | ||||||
|  |  | ||||||
|  |             print(f"Video saved to {save_path}") | ||||||
|  |  | ||||||
|  |     def draw_tracks_on_video( | ||||||
|  |         self, | ||||||
|  |         video: torch.Tensor, | ||||||
|  |         tracks: torch.Tensor, | ||||||
|  |         segm_mask: torch.Tensor = None, | ||||||
|  |         gt_tracks=None, | ||||||
|  |         query_frame: int = 0, | ||||||
|  |         compensate_for_camera_motion=False, | ||||||
|  |     ): | ||||||
|  |         B, T, C, H, W = video.shape | ||||||
|  |         _, _, N, D = tracks.shape | ||||||
|  |  | ||||||
|  |         assert D == 2 | ||||||
|  |         assert C == 3 | ||||||
|  |         video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy()  # S, H, W, C | ||||||
|  |         tracks = tracks[0].long().detach().cpu().numpy()  # S, N, 2 | ||||||
|  |         if gt_tracks is not None: | ||||||
|  |             gt_tracks = gt_tracks[0].detach().cpu().numpy() | ||||||
|  |  | ||||||
|  |         res_video = [] | ||||||
|  |  | ||||||
|  |         # process input video | ||||||
|  |         for rgb in video: | ||||||
|  |             res_video.append(rgb.copy()) | ||||||
|  |  | ||||||
|  |         vector_colors = np.zeros((T, N, 3)) | ||||||
|  |         if self.mode == "optical_flow": | ||||||
|  |             vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None]) | ||||||
|  |         elif segm_mask is None: | ||||||
|  |             if self.mode == "rainbow": | ||||||
|  |                 y_min, y_max = ( | ||||||
|  |                     tracks[query_frame, :, 1].min(), | ||||||
|  |                     tracks[query_frame, :, 1].max(), | ||||||
|  |                 ) | ||||||
|  |                 norm = plt.Normalize(y_min, y_max) | ||||||
|  |                 for n in range(N): | ||||||
|  |                     color = self.color_map(norm(tracks[query_frame, n, 1])) | ||||||
|  |                     color = np.array(color[:3])[None] * 255 | ||||||
|  |                     vector_colors[:, n] = np.repeat(color, T, axis=0) | ||||||
|  |             else: | ||||||
|  |                 # color changes with time | ||||||
|  |                 for t in range(T): | ||||||
|  |                     color = np.array(self.color_map(t / T)[:3])[None] * 255 | ||||||
|  |                     vector_colors[t] = np.repeat(color, N, axis=0) | ||||||
|  |         else: | ||||||
|  |             if self.mode == "rainbow": | ||||||
|  |                 vector_colors[:, segm_mask <= 0, :] = 255 | ||||||
|  |  | ||||||
|  |                 y_min, y_max = ( | ||||||
|  |                     tracks[0, segm_mask > 0, 1].min(), | ||||||
|  |                     tracks[0, segm_mask > 0, 1].max(), | ||||||
|  |                 ) | ||||||
|  |                 norm = plt.Normalize(y_min, y_max) | ||||||
|  |                 for n in range(N): | ||||||
|  |                     if segm_mask[n] > 0: | ||||||
|  |                         color = self.color_map(norm(tracks[0, n, 1])) | ||||||
|  |                         color = np.array(color[:3])[None] * 255 | ||||||
|  |                         vector_colors[:, n] = np.repeat(color, T, axis=0) | ||||||
|  |  | ||||||
|  |             else: | ||||||
|  |                 # color changes with segm class | ||||||
|  |                 segm_mask = segm_mask.cpu() | ||||||
|  |                 color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32) | ||||||
|  |                 color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0 | ||||||
|  |                 color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0 | ||||||
|  |                 vector_colors = np.repeat(color[None], T, axis=0) | ||||||
|  |  | ||||||
|  |         #  draw tracks | ||||||
|  |         if self.tracks_leave_trace != 0: | ||||||
|  |             for t in range(1, T): | ||||||
|  |                 first_ind = ( | ||||||
|  |                     max(0, t - self.tracks_leave_trace) | ||||||
|  |                     if self.tracks_leave_trace >= 0 | ||||||
|  |                     else 0 | ||||||
|  |                 ) | ||||||
|  |                 curr_tracks = tracks[first_ind : t + 1] | ||||||
|  |                 curr_colors = vector_colors[first_ind : t + 1] | ||||||
|  |                 if compensate_for_camera_motion: | ||||||
|  |                     diff = ( | ||||||
|  |                         tracks[first_ind : t + 1, segm_mask <= 0] | ||||||
|  |                         - tracks[t : t + 1, segm_mask <= 0] | ||||||
|  |                     ).mean(1)[:, None] | ||||||
|  |  | ||||||
|  |                     curr_tracks = curr_tracks - diff | ||||||
|  |                     curr_tracks = curr_tracks[:, segm_mask > 0] | ||||||
|  |                     curr_colors = curr_colors[:, segm_mask > 0] | ||||||
|  |  | ||||||
|  |                 res_video[t] = self._draw_pred_tracks( | ||||||
|  |                     res_video[t], | ||||||
|  |                     curr_tracks, | ||||||
|  |                     curr_colors, | ||||||
|  |                 ) | ||||||
|  |                 if gt_tracks is not None: | ||||||
|  |                     res_video[t] = self._draw_gt_tracks( | ||||||
|  |                         res_video[t], gt_tracks[first_ind : t + 1] | ||||||
|  |                     ) | ||||||
|  |  | ||||||
|  |         #  draw points | ||||||
|  |         for t in range(T): | ||||||
|  |             for i in range(N): | ||||||
|  |                 coord = (tracks[t, i, 0], tracks[t, i, 1]) | ||||||
|  |                 if coord[0] != 0 and coord[1] != 0: | ||||||
|  |                     if not compensate_for_camera_motion or ( | ||||||
|  |                         compensate_for_camera_motion and segm_mask[i] > 0 | ||||||
|  |                     ): | ||||||
|  |                         cv2.circle( | ||||||
|  |                             res_video[t], | ||||||
|  |                             coord, | ||||||
|  |                             int(self.linewidth * 2), | ||||||
|  |                             vector_colors[t, i].tolist(), | ||||||
|  |                             -1, | ||||||
|  |                         ) | ||||||
|  |  | ||||||
|  |         #  construct the final rgb sequence | ||||||
|  |         if self.show_first_frame > 0: | ||||||
|  |             res_video = [res_video[0]] * self.show_first_frame + res_video[1:] | ||||||
|  |         return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte() | ||||||
|  |  | ||||||
|  |     def _draw_pred_tracks( | ||||||
|  |         self, | ||||||
|  |         rgb: np.ndarray,  # H x W x 3 | ||||||
|  |         tracks: np.ndarray,  # T x N x 2 | ||||||
|  |         vector_colors: np.ndarray, | ||||||
|  |         alpha: float = 0.5, | ||||||
|  |     ): | ||||||
|  |         T, N, _ = tracks.shape | ||||||
|  |  | ||||||
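|  |         # draw the trace one segment at a time; when traces are kept, older | ||||||
|  |         # segments are blended back towards the pre-draw frame (alpha grows | ||||||
|  |         # with s), so the oldest part of a trace appears the most faded | ||||||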
|  |         for s in range(T - 1): | ||||||
|  |             vector_color = vector_colors[s] | ||||||
|  |             original = rgb.copy() | ||||||
|  |             alpha = (s / T) ** 2 | ||||||
|  |             for i in range(N): | ||||||
|  |                 coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1])) | ||||||
|  |                 coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1])) | ||||||
|  |                 if coord_y[0] != 0 and coord_y[1] != 0: | ||||||
|  |                     cv2.line( | ||||||
|  |                         rgb, | ||||||
|  |                         coord_y, | ||||||
|  |                         coord_x, | ||||||
|  |                         vector_color[i].tolist(), | ||||||
|  |                         self.linewidth, | ||||||
|  |                         cv2.LINE_AA, | ||||||
|  |                     ) | ||||||
|  |             if self.tracks_leave_trace > 0: | ||||||
|  |                 rgb = cv2.addWeighted(rgb, alpha, original, 1 - alpha, 0) | ||||||
|  |         return rgb | ||||||
|  |  | ||||||
|  |     def _draw_gt_tracks( | ||||||
|  |         self, | ||||||
|  |         rgb: np.ndarray,  # H x W x 3 | ||||||
|  |         gt_tracks: np.ndarray,  # T x N x 2 | ||||||
|  |     ): | ||||||
|  |         T, N, _ = gt_tracks.shape | ||||||
|  |         color = np.array((211.0, 0.0, 0.0)) | ||||||
|  |  | ||||||
|  |         for t in range(T): | ||||||
|  |             for i in range(N): | ||||||
|  |                 gt_point = gt_tracks[t][i] | ||||||
|  |                 #  draw a red cross | ||||||
|  |                 if gt_point[0] > 0 and gt_point[1] > 0: | ||||||
|  |                     length = self.linewidth * 3 | ||||||
|  |                     coord_y = (int(gt_point[0]) + length, int(gt_point[1]) + length) | ||||||
|  |                     coord_x = (int(gt_point[0]) - length, int(gt_point[1]) - length) | ||||||
|  |                     cv2.line( | ||||||
|  |                         rgb, | ||||||
|  |                         coord_y, | ||||||
|  |                         coord_x, | ||||||
|  |                         color, | ||||||
|  |                         self.linewidth, | ||||||
|  |                         cv2.LINE_AA, | ||||||
|  |                     ) | ||||||
|  |                     coord_y = (int(gt_point[0]) - length, int(gt_point[1]) + length) | ||||||
|  |                     coord_x = (int(gt_point[0]) + length, int(gt_point[1]) - length) | ||||||
|  |                     cv2.line( | ||||||
|  |                         rgb, | ||||||
|  |                         coord_y, | ||||||
|  |                         coord_x, | ||||||
|  |                         color, | ||||||
|  |                         self.linewidth, | ||||||
|  |                         cv2.LINE_AA, | ||||||
|  |                     ) | ||||||
|  |         return rgb | ||||||
							
								
								
									
71  demo.py  Normal file
							| @@ -0,0 +1,71 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
|  |  | ||||||
|  | import os | ||||||
|  | import torch | ||||||
|  | import argparse | ||||||
|  | import numpy as np | ||||||
|  |  | ||||||
|  | from torchvision.io import read_video | ||||||
|  | from PIL import Image | ||||||
|  | from cotracker.utils.visualizer import Visualizer | ||||||
|  | from cotracker.predictor import CoTrackerPredictor | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     parser = argparse.ArgumentParser() | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--video_path", | ||||||
|  |         default="./assets/apple.mp4", | ||||||
|  |         help="path to a video", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--mask_path", | ||||||
|  |         default="./assets/apple_mask.png", | ||||||
|  |         help="path to a segmentation mask", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--checkpoint", | ||||||
|  |         default="./checkpoints/cotracker_stride_4_wind_8.pth", | ||||||
|  |         help="cotracker model", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument("--grid_size", type=int, default=0, help="Regular grid size") | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--grid_query_frame", | ||||||
|  |         type=int, | ||||||
|  |         default=0, | ||||||
|  |         help="Compute dense and grid tracks starting from this frame", | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--backward_tracking", | ||||||
|  |         action="store_true", | ||||||
|  |         help="Compute tracks in both directions, not only forward", | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     args = parser.parse_args() | ||||||
|  |  | ||||||
|  |     # load the input video and reshape it to a B T C H W float tensor | ||||||
|  |     video = read_video(args.video_path) | ||||||
|  |     video = video[0].permute(0, 3, 1, 2)[None].float() | ||||||
|  |     segm_mask = np.array(Image.open(os.path.join(args.mask_path))) | ||||||
|  |     segm_mask = torch.from_numpy(segm_mask)[None, None] | ||||||
|  |  | ||||||
|  |     model = CoTrackerPredictor(checkpoint=args.checkpoint) | ||||||
|  |  | ||||||
|  |     pred_tracks, pred_visibility = model( | ||||||
|  |         video, | ||||||
|  |         grid_size=args.grid_size, | ||||||
|  |         grid_query_frame=args.grid_query_frame, | ||||||
|  |         backward_tracking=args.backward_tracking, | ||||||
|  |         # segm_mask=segm_mask | ||||||
|  |     ) | ||||||
|  |     print("computed") | ||||||
|  |  | ||||||
|  |     # save a video with predicted tracks | ||||||
|  |     seq_name = args.video_path.split("/")[-1] | ||||||
|  |     vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3) | ||||||
|  |     vis.visualize(video, pred_tracks, query_frame=args.grid_query_frame) | ||||||
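|  |  | ||||||
|  |     # Example invocation (assumes the default --video_path, --mask_path and | ||||||
|  |     # --checkpoint files are available locally): | ||||||
|  |     #   python demo.py --grid_size 10 --backward_tracking | ||||||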
							
								
								
									
924  notebooks/demo.ipynb  Normal file
File diff suppressed because one or more lines are too long

18  setup.py  Normal file
							| @@ -0,0 +1,18 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
|  |  | ||||||
|  | from setuptools import find_packages, setup | ||||||
|  |  | ||||||
|  | setup( | ||||||
|  |     name="cotracker", | ||||||
|  |     version="1.0", | ||||||
|  |     install_requires=[], | ||||||
|  |     packages=find_packages(exclude=["notebooks"]), | ||||||
|  |     extras_require={ | ||||||
|  |         "all": ["matplotlib", "opencv-python"], | ||||||
|  |         "dev": ["flake8", "black"], | ||||||
|  |     }, | ||||||
|  | ) | ||||||
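|  |  | ||||||
|  | # With the "all" extra above, an editable install that also pulls in the | ||||||
|  | # visualization dependencies would be (standard pip/setuptools usage assumed): | ||||||
|  | #   pip install -e ".[all]" | ||||||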
							
								
								
									
665  train.py  Normal file
							| @@ -0,0 +1,665 @@ | |||||||
|  | # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  |  | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
|  |  | ||||||
|  | import os | ||||||
|  | import random | ||||||
|  | import torch | ||||||
|  | import signal | ||||||
|  | import socket | ||||||
|  | import sys | ||||||
|  | import json | ||||||
|  |  | ||||||
|  | import numpy as np | ||||||
|  | import argparse | ||||||
|  | import logging | ||||||
|  | from pathlib import Path | ||||||
|  | from tqdm import tqdm | ||||||
|  | import torch.optim as optim | ||||||
|  | from torch.utils.data import DataLoader | ||||||
|  | from torch.cuda.amp import GradScaler | ||||||
|  |  | ||||||
|  | from torch.utils.tensorboard import SummaryWriter | ||||||
|  | from pytorch_lightning.lite import LightningLite | ||||||
|  |  | ||||||
|  | from cotracker.models.evaluation_predictor import EvaluationPredictor | ||||||
|  | from cotracker.models.core.cotracker.cotracker import CoTracker | ||||||
|  | from cotracker.utils.visualizer import Visualizer | ||||||
|  | from cotracker.datasets.tap_vid_datasets import TapVidDataset | ||||||
|  | from cotracker.datasets.badja_dataset import BadjaDataset | ||||||
|  | from cotracker.datasets.fast_capture_dataset import FastCaptureDataset | ||||||
|  | from cotracker.evaluation.core.evaluator import Evaluator | ||||||
|  | from cotracker.datasets import kubric_movif_dataset | ||||||
|  | from cotracker.datasets.utils import collate_fn, collate_fn_train, dataclass_to_cuda_ | ||||||
|  | from cotracker.models.core.cotracker.losses import sequence_loss, balanced_ce_loss | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # define the handler function | ||||||
|  | # for training on a slurm cluster | ||||||
|  | def sig_handler(signum, frame): | ||||||
|  |     print("caught signal", signum) | ||||||
|  |     print(socket.gethostname(), "USR1 signal caught.") | ||||||
|  |     # do other stuff to cleanup here | ||||||
|  |     print("requeuing job " + os.environ["SLURM_JOB_ID"]) | ||||||
|  |     os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"]) | ||||||
|  |     sys.exit(-1) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def term_handler(signum, frame): | ||||||
|  |     print("bypassing sigterm", flush=True) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def fetch_optimizer(args, model): | ||||||
|  |     """Create the optimizer and learning rate scheduler""" | ||||||
|  |     optimizer = optim.AdamW( | ||||||
|  |         model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8 | ||||||
|  |     ) | ||||||
|  |     scheduler = optim.lr_scheduler.OneCycleLR( | ||||||
|  |         optimizer, | ||||||
|  |         args.lr, | ||||||
|  |         args.num_steps + 100, | ||||||
|  |         pct_start=0.05, | ||||||
|  |         cycle_momentum=False, | ||||||
|  |         anneal_strategy="linear", | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     return optimizer, scheduler | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def forward_batch(batch, model, args, loss_fn=None, writer=None, step=0): | ||||||
|  |     rgbs = batch.video | ||||||
|  |     trajs_g = batch.trajectory | ||||||
|  |     vis_g = batch.visibility | ||||||
|  |     valids = batch.valid | ||||||
|  |     B, T, C, H, W = rgbs.shape | ||||||
|  |     assert C == 3 | ||||||
|  |     B, T, N, D = trajs_g.shape | ||||||
|  |     device = rgbs.device | ||||||
|  |  | ||||||
|  |     __, first_positive_inds = torch.max(vis_g, dim=1) | ||||||
|  |     # We want to make sure that during training the model sees visible points | ||||||
|  |     # that it does not need to track just yet: they are visible but queried from a later frame | ||||||
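|  |     # Concretely, the first N // 4 tracks below are queried from a random frame | ||||||
|  |     # in which they are visible, while the remaining tracks keep their first | ||||||
|  |     # visible frame as the query frame. | ||||||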
|  |     N_rand = N // 4 | ||||||
|  |     # frame indices at which each point is visible (first element of the batch) | ||||||
|  |     nonzero_inds = [torch.nonzero(vis_g[0, :, i]) for i in range(N)] | ||||||
|  |     rand_vis_inds = torch.cat( | ||||||
|  |         [ | ||||||
|  |             nonzero_row[torch.randint(len(nonzero_row), size=(1,))] | ||||||
|  |             for nonzero_row in nonzero_inds | ||||||
|  |         ], | ||||||
|  |         dim=1, | ||||||
|  |     ) | ||||||
|  |     first_positive_inds = torch.cat( | ||||||
|  |         [rand_vis_inds[:, :N_rand], first_positive_inds[:, N_rand:]], dim=1 | ||||||
|  |     ) | ||||||
|  |     ind_array_ = torch.arange(T, device=device) | ||||||
|  |     ind_array_ = ind_array_[None, :, None].repeat(B, 1, N) | ||||||
|  |     assert torch.allclose( | ||||||
|  |         vis_g[ind_array_ == first_positive_inds[:, None, :]], | ||||||
|  |         torch.ones_like(vis_g), | ||||||
|  |     ) | ||||||
|  |     assert torch.allclose( | ||||||
|  |         vis_g[ind_array_ == rand_vis_inds[:, None, :]], torch.ones_like(vis_g) | ||||||
|  |     ) | ||||||
|  |  | ||||||
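|  |     # look up the (x, y) position of every track at its own query frame: gather | ||||||
|  |     # pulls the trajectories at the query frames of all tracks, and the diagonal | ||||||
|  |     # keeps entry n of the gathered block for track n | ||||||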
|  |     gather = torch.gather( | ||||||
|  |         trajs_g, 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, 2) | ||||||
|  |     ) | ||||||
|  |     xys = torch.diagonal(gather, dim1=1, dim2=2).permute(0, 2, 1) | ||||||
|  |  | ||||||
|  |     queries = torch.cat([first_positive_inds[:, :, None], xys], dim=2) | ||||||
|  |  | ||||||
|  |     predictions, __, visibility, train_data = model( | ||||||
|  |         rgbs=rgbs, queries=queries, iters=args.train_iters, is_train=True | ||||||
|  |     ) | ||||||
|  |     vis_predictions, coord_predictions, wind_inds, sort_inds = train_data | ||||||
|  |  | ||||||
|  |     trajs_g = trajs_g[:, :, sort_inds] | ||||||
|  |     vis_g = vis_g[:, :, sort_inds] | ||||||
|  |     valids = valids[:, :, sort_inds] | ||||||
|  |  | ||||||
|  |     vis_gts = [] | ||||||
|  |     traj_gts = [] | ||||||
|  |     valids_gts = [] | ||||||
|  |  | ||||||
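|  |     # the model processes the video in sliding windows that overlap by half the | ||||||
|  |     # window length; slice the ground truth to match the tracks of each window | ||||||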
|  |     for i, wind_idx in enumerate(wind_inds): | ||||||
|  |         ind = i * (args.sliding_window_len // 2) | ||||||
|  |  | ||||||
|  |         vis_gts.append(vis_g[:, ind : ind + args.sliding_window_len, :wind_idx]) | ||||||
|  |         traj_gts.append(trajs_g[:, ind : ind + args.sliding_window_len, :wind_idx]) | ||||||
|  |         valids_gts.append(valids[:, ind : ind + args.sliding_window_len, :wind_idx]) | ||||||
|  |  | ||||||
|  |     seq_loss = sequence_loss(coord_predictions, traj_gts, vis_gts, valids_gts, 0.8) | ||||||
|  |     vis_loss = balanced_ce_loss(vis_predictions, vis_gts, valids_gts) | ||||||
|  |  | ||||||
|  |     output = {"flow": {"predictions": predictions[0].detach()}} | ||||||
|  |     output["flow"]["loss"] = seq_loss.mean() | ||||||
|  |     output["visibility"] = { | ||||||
|  |         "loss": vis_loss.mean() * 10.0, | ||||||
|  |         "predictions": visibility[0].detach(), | ||||||
|  |     } | ||||||
|  |     return output | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def run_test_eval(evaluator, model, dataloaders, writer, step): | ||||||
|  |     model.eval() | ||||||
|  |     for ds_name, dataloader in dataloaders: | ||||||
|  |         predictor = EvaluationPredictor( | ||||||
|  |             model.module.module, | ||||||
|  |             grid_size=6, | ||||||
|  |             local_grid_size=0, | ||||||
|  |             single_point=False, | ||||||
|  |             n_iters=6, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         metrics = evaluator.evaluate_sequence( | ||||||
|  |             model=predictor, | ||||||
|  |             test_dataloader=dataloader, | ||||||
|  |             dataset_name=ds_name, | ||||||
|  |             train_mode=True, | ||||||
|  |             writer=writer, | ||||||
|  |             step=step, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         if ds_name == "badja" or ds_name == "fastcapture" or ("kubric" in ds_name): | ||||||
|  |  | ||||||
|  |             metrics = { | ||||||
|  |                 **{ | ||||||
|  |                     f"{ds_name}_avg": np.mean( | ||||||
|  |                         [v for k, v in metrics.items() if "accuracy" not in k] | ||||||
|  |                     ) | ||||||
|  |                 }, | ||||||
|  |                 **{ | ||||||
|  |                     f"{ds_name}_avg_accuracy": np.mean( | ||||||
|  |                         [v for k, v in metrics.items() if "accuracy" in k] | ||||||
|  |                     ) | ||||||
|  |                 }, | ||||||
|  |             } | ||||||
|  |             print("avg", np.mean([v for v in metrics.values()])) | ||||||
|  |  | ||||||
|  |         if "tapvid" in ds_name: | ||||||
|  |             metrics = { | ||||||
|  |                 f"{ds_name}_avg_OA": metrics["avg"]["occlusion_accuracy"] * 100, | ||||||
|  |                 f"{ds_name}_avg_delta": metrics["avg"]["average_pts_within_thresh"] | ||||||
|  |                 * 100, | ||||||
|  |                 f"{ds_name}_avg_Jaccard": metrics["avg"]["average_jaccard"] * 100, | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |         writer.add_scalars(f"Eval", metrics, step) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Logger: | ||||||
|  |  | ||||||
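|  |     # running losses are averaged over SUM_FREQ optimization steps before | ||||||
|  |     # being written to TensorBoard | ||||||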
|  |     SUM_FREQ = 100 | ||||||
|  |  | ||||||
|  |     def __init__(self, model, scheduler): | ||||||
|  |         self.model = model | ||||||
|  |         self.scheduler = scheduler | ||||||
|  |         self.total_steps = 0 | ||||||
|  |         self.running_loss = {} | ||||||
|  |         self.writer = SummaryWriter(log_dir=os.path.join(args.ckpt_path, "runs")) | ||||||
|  |  | ||||||
|  |     def _print_training_status(self): | ||||||
|  |         metrics_data = [ | ||||||
|  |             self.running_loss[k] / Logger.SUM_FREQ | ||||||
|  |             for k in sorted(self.running_loss.keys()) | ||||||
|  |         ] | ||||||
|  |         training_str = "[{:6d}] ".format(self.total_steps + 1) | ||||||
|  |         metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data) | ||||||
|  |  | ||||||
|  |         # print the training status | ||||||
|  |         logging.info( | ||||||
|  |             f"Training Metrics ({self.total_steps}): {training_str + metrics_str}" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         if self.writer is None: | ||||||
|  |             self.writer = SummaryWriter(log_dir=os.path.join(args.ckpt_path, "runs")) | ||||||
|  |  | ||||||
|  |         for k in self.running_loss: | ||||||
|  |             self.writer.add_scalar( | ||||||
|  |                 k, self.running_loss[k] / Logger.SUM_FREQ, self.total_steps | ||||||
|  |             ) | ||||||
|  |             self.running_loss[k] = 0.0 | ||||||
|  |  | ||||||
|  |     def push(self, metrics, task): | ||||||
|  |         self.total_steps += 1 | ||||||
|  |  | ||||||
|  |         for key in metrics: | ||||||
|  |             task_key = str(key) + "_" + task | ||||||
|  |             if task_key not in self.running_loss: | ||||||
|  |                 self.running_loss[task_key] = 0.0 | ||||||
|  |  | ||||||
|  |             self.running_loss[task_key] += metrics[key] | ||||||
|  |  | ||||||
|  |         if self.total_steps % Logger.SUM_FREQ == Logger.SUM_FREQ - 1: | ||||||
|  |             self._print_training_status() | ||||||
|  |             self.running_loss = {} | ||||||
|  |  | ||||||
|  |     def write_dict(self, results): | ||||||
|  |         if self.writer is None: | ||||||
|  |             self.writer = SummaryWriter(log_dir=os.path.join(args.ckpt_path, "runs")) | ||||||
|  |  | ||||||
|  |         for key in results: | ||||||
|  |             self.writer.add_scalar(key, results[key], self.total_steps) | ||||||
|  |  | ||||||
|  |     def close(self): | ||||||
|  |         self.writer.close() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Lite(LightningLite): | ||||||
|  |     def run(self, args): | ||||||
|  |         def seed_everything(seed: int): | ||||||
|  |             random.seed(seed) | ||||||
|  |             os.environ["PYTHONHASHSEED"] = str(seed) | ||||||
|  |             np.random.seed(seed) | ||||||
|  |             torch.manual_seed(seed) | ||||||
|  |             torch.cuda.manual_seed(seed) | ||||||
|  |             torch.backends.cudnn.deterministic = True | ||||||
|  |             torch.backends.cudnn.benchmark = False | ||||||
|  |  | ||||||
|  |         seed_everything(0) | ||||||
|  |  | ||||||
|  |         def seed_worker(worker_id): | ||||||
|  |             worker_seed = torch.initial_seed() % 2 ** 32 | ||||||
|  |             np.random.seed(worker_seed) | ||||||
|  |             random.seed(worker_seed) | ||||||
|  |  | ||||||
|  |         g = torch.Generator() | ||||||
|  |         g.manual_seed(0) | ||||||
|  |  | ||||||
|  |         eval_dataloaders = [] | ||||||
|  |         if "badja" in args.eval_datasets: | ||||||
|  |             eval_dataset = BadjaDataset( | ||||||
|  |                 data_root=os.path.join(args.dataset_root, "BADJA"), | ||||||
|  |                 max_seq_len=args.eval_max_seq_len, | ||||||
|  |                 dataset_resolution=args.crop_size, | ||||||
|  |             ) | ||||||
|  |             eval_dataloader_badja = torch.utils.data.DataLoader( | ||||||
|  |                 eval_dataset, | ||||||
|  |                 batch_size=1, | ||||||
|  |                 shuffle=False, | ||||||
|  |                 num_workers=8, | ||||||
|  |                 collate_fn=collate_fn, | ||||||
|  |             ) | ||||||
|  |             eval_dataloaders.append(("badja", eval_dataloader_badja)) | ||||||
|  |  | ||||||
|  |         if "fastcapture" in args.eval_datasets: | ||||||
|  |             eval_dataset = FastCaptureDataset( | ||||||
|  |                 data_root=os.path.join(args.dataset_root, "fastcapture"), | ||||||
|  |                 max_seq_len=min(100, args.eval_max_seq_len), | ||||||
|  |                 max_num_points=40, | ||||||
|  |                 dataset_resolution=args.crop_size, | ||||||
|  |             ) | ||||||
|  |             eval_dataloader_fastcapture = torch.utils.data.DataLoader( | ||||||
|  |                 eval_dataset, | ||||||
|  |                 batch_size=1, | ||||||
|  |                 shuffle=False, | ||||||
|  |                 num_workers=1, | ||||||
|  |                 collate_fn=collate_fn, | ||||||
|  |             ) | ||||||
|  |             eval_dataloaders.append(("fastcapture", eval_dataloader_fastcapture)) | ||||||
|  |  | ||||||
|  |         if "tapvid_davis_first" in args.eval_datasets: | ||||||
|  |             data_root = os.path.join( | ||||||
|  |                 args.dataset_root, "tapvid_davis/tapvid_davis.pkl" | ||||||
|  |             ) | ||||||
|  |             eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root) | ||||||
|  |             eval_dataloader_tapvid_davis = torch.utils.data.DataLoader( | ||||||
|  |                 eval_dataset, | ||||||
|  |                 batch_size=1, | ||||||
|  |                 shuffle=False, | ||||||
|  |                 num_workers=1, | ||||||
|  |                 collate_fn=collate_fn, | ||||||
|  |             ) | ||||||
|  |             eval_dataloaders.append(("tapvid_davis", eval_dataloader_tapvid_davis)) | ||||||
|  |  | ||||||
|  |         evaluator = Evaluator(args.ckpt_path) | ||||||
|  |  | ||||||
|  |         visualizer = Visualizer( | ||||||
|  |             save_dir=args.ckpt_path, | ||||||
|  |             pad_value=80, | ||||||
|  |             fps=1, | ||||||
|  |             show_first_frame=0, | ||||||
|  |             tracks_leave_trace=0, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         loss_fn = None | ||||||
|  |  | ||||||
|  |         if args.model_name == "cotracker": | ||||||
|  |  | ||||||
|  |             model = CoTracker( | ||||||
|  |                 stride=args.model_stride, | ||||||
|  |                 S=args.sliding_window_len, | ||||||
|  |                 add_space_attn=not args.remove_space_attn, | ||||||
|  |                 num_heads=args.updateformer_num_heads, | ||||||
|  |                 hidden_size=args.updateformer_hidden_size, | ||||||
|  |                 space_depth=args.updateformer_space_depth, | ||||||
|  |                 time_depth=args.updateformer_time_depth, | ||||||
|  |             ) | ||||||
|  |         else: | ||||||
|  |             raise ValueError(f"Model {args.model_name} doesn't exist") | ||||||
|  |  | ||||||
|  |         with open(args.ckpt_path + "/meta.json", "w") as file: | ||||||
|  |             json.dump(vars(args), file, sort_keys=True, indent=4) | ||||||
|  |  | ||||||
|  |         model.cuda() | ||||||
|  |  | ||||||
|  |         train_dataset = kubric_movif_dataset.KubricMovifDataset( | ||||||
|  |             data_root=os.path.join(args.dataset_root, "kubric_movi_f"), | ||||||
|  |             crop_size=args.crop_size, | ||||||
|  |             seq_len=args.sequence_len, | ||||||
|  |             traj_per_sample=args.traj_per_sample, | ||||||
|  |             sample_vis_1st_frame=args.sample_vis_1st_frame, | ||||||
|  |             use_augs=not args.dont_use_augs, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         train_loader = DataLoader( | ||||||
|  |             train_dataset, | ||||||
|  |             batch_size=args.batch_size, | ||||||
|  |             shuffle=True, | ||||||
|  |             num_workers=args.num_workers, | ||||||
|  |             worker_init_fn=seed_worker, | ||||||
|  |             generator=g, | ||||||
|  |             pin_memory=True, | ||||||
|  |             collate_fn=collate_fn_train, | ||||||
|  |             drop_last=True, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         train_loader = self.setup_dataloaders(train_loader, move_to_device=False) | ||||||
|  |         print("LEN TRAIN LOADER", len(train_loader)) | ||||||
|  |         optimizer, scheduler = fetch_optimizer(args, model) | ||||||
|  |  | ||||||
|  |         total_steps = 0 | ||||||
|  |         logger = Logger(model, scheduler) | ||||||
|  |  | ||||||
|  |         folder_ckpts = [ | ||||||
|  |             f | ||||||
|  |             for f in os.listdir(args.ckpt_path) | ||||||
|  |             if f.endswith(".pth") | ||||||
|  |             and not os.path.isdir(os.path.join(args.ckpt_path, f)) | ||||||
|  |             and "final" not in f | ||||||
|  |         ] | ||||||
|  |         if len(folder_ckpts) > 0: | ||||||
|  |             ckpt_path = sorted(folder_ckpts)[-1] | ||||||
|  |             ckpt = self.load(os.path.join(args.ckpt_path, ckpt_path)) | ||||||
|  |             logging.info(f"Loading checkpoint {ckpt_path}") | ||||||
|  |             if "model" in ckpt: | ||||||
|  |                 model.load_state_dict(ckpt["model"]) | ||||||
|  |             else: | ||||||
|  |                 model.load_state_dict(ckpt) | ||||||
|  |             if "optimizer" in ckpt: | ||||||
|  |                 logging.info("Load optimizer") | ||||||
|  |                 optimizer.load_state_dict(ckpt["optimizer"]) | ||||||
|  |             if "scheduler" in ckpt: | ||||||
|  |                 logging.info("Load scheduler") | ||||||
|  |                 scheduler.load_state_dict(ckpt["scheduler"]) | ||||||
|  |             if "total_steps" in ckpt: | ||||||
|  |                 total_steps = ckpt["total_steps"] | ||||||
|  |                 logging.info(f"Load total_steps {total_steps}") | ||||||
|  |  | ||||||
|  |         elif args.restore_ckpt is not None: | ||||||
|  |             assert args.restore_ckpt.endswith(".pth") or args.restore_ckpt.endswith( | ||||||
|  |                 ".pt" | ||||||
|  |             ) | ||||||
|  |             logging.info("Loading checkpoint...") | ||||||
|  |  | ||||||
|  |             strict = True | ||||||
|  |             state_dict = self.load(args.restore_ckpt) | ||||||
|  |             if "model" in state_dict: | ||||||
|  |                 state_dict = state_dict["model"] | ||||||
|  |  | ||||||
|  |             if list(state_dict.keys())[0].startswith("module."): | ||||||
|  |                 state_dict = { | ||||||
|  |                     k.replace("module.", ""): v for k, v in state_dict.items() | ||||||
|  |                 } | ||||||
|  |             model.load_state_dict(state_dict, strict=strict) | ||||||
|  |  | ||||||
|  |             logging.info(f"Done loading checkpoint") | ||||||
|  |         model, optimizer = self.setup(model, optimizer, move_to_device=False) | ||||||
|  |         # model.cuda() | ||||||
|  |         model.train() | ||||||
|  |  | ||||||
|  |         save_freq = args.save_freq | ||||||
|  |         scaler = GradScaler(enabled=args.mixed_precision) | ||||||
|  |  | ||||||
|  |         should_keep_training = True | ||||||
|  |         global_batch_num = 0 | ||||||
|  |         epoch = -1 | ||||||
|  |  | ||||||
|  |         while should_keep_training: | ||||||
|  |             epoch += 1 | ||||||
|  |             for i_batch, batch in enumerate(tqdm(train_loader)): | ||||||
|  |                 batch, gotit = batch | ||||||
|  |                 if not all(gotit): | ||||||
|  |                     print("batch is None") | ||||||
|  |                     continue | ||||||
|  |                 dataclass_to_cuda_(batch) | ||||||
|  |  | ||||||
|  |                 optimizer.zero_grad() | ||||||
|  |  | ||||||
|  |                 assert model.training | ||||||
|  |  | ||||||
|  |                 output = forward_batch( | ||||||
|  |                     batch, | ||||||
|  |                     model, | ||||||
|  |                     args, | ||||||
|  |                     loss_fn=loss_fn, | ||||||
|  |                     writer=logger.writer, | ||||||
|  |                     step=total_steps, | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |                 loss = 0 | ||||||
|  |                 for k, v in output.items(): | ||||||
|  |                     if "loss" in v: | ||||||
|  |                         loss += v["loss"] | ||||||
|  |                         logger.writer.add_scalar( | ||||||
|  |                             f"live_{k}_loss", v["loss"].item(), total_steps | ||||||
|  |                         ) | ||||||
|  |                     if "metrics" in v: | ||||||
|  |                         logger.push(v["metrics"], k) | ||||||
|  |  | ||||||
|  |                 if self.global_rank == 0: | ||||||
|  |                     if total_steps % save_freq == save_freq - 1: | ||||||
|  |                         if args.model_name == "motion_diffuser": | ||||||
|  |                             pred_coords = model.module.module.forward_batch_test( | ||||||
|  |                                 batch, interp_shape=args.crop_size | ||||||
|  |                             ) | ||||||
|  |  | ||||||
|  |                             output["flow"] = {"predictions": pred_coords[0].detach()} | ||||||
|  |                         visualizer.visualize( | ||||||
|  |                             video=batch.video.clone(), | ||||||
|  |                             tracks=batch.trajectory.clone(), | ||||||
|  |                             filename="train_gt_traj", | ||||||
|  |                             writer=logger.writer, | ||||||
|  |                             step=total_steps, | ||||||
|  |                         ) | ||||||
|  |  | ||||||
|  |                         visualizer.visualize( | ||||||
|  |                             video=batch.video.clone(), | ||||||
|  |                             tracks=output["flow"]["predictions"][None], | ||||||
|  |                             filename="train_pred_traj", | ||||||
|  |                             writer=logger.writer, | ||||||
|  |                             step=total_steps, | ||||||
|  |                         ) | ||||||
|  |  | ||||||
|  |                     if len(output) > 1: | ||||||
|  |                         logger.writer.add_scalar( | ||||||
|  |                             f"live_total_loss", loss.item(), total_steps | ||||||
|  |                         ) | ||||||
|  |                     logger.writer.add_scalar( | ||||||
|  |                         f"learning_rate", optimizer.param_groups[0]["lr"], total_steps | ||||||
|  |                     ) | ||||||
|  |                     global_batch_num += 1 | ||||||
|  |  | ||||||
|  |                 self.barrier() | ||||||
|  |  | ||||||
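|  |                 # standard AMP flow: backpropagate the scaled loss, unscale the | ||||||
|  |                 # gradients so clipping sees their true magnitudes, then step the | ||||||
|  |                 # optimizer through the scaler and update its scale factor | ||||||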
|  |                 self.backward(scaler.scale(loss)) | ||||||
|  |  | ||||||
|  |                 scaler.unscale_(optimizer) | ||||||
|  |                 torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0) | ||||||
|  |  | ||||||
|  |                 scaler.step(optimizer) | ||||||
|  |                 scheduler.step() | ||||||
|  |                 scaler.update() | ||||||
|  |                 total_steps += 1 | ||||||
|  |                 if self.global_rank == 0: | ||||||
|  |                     if (i_batch >= len(train_loader) - 1) or ( | ||||||
|  |                         total_steps == 1 and args.validate_at_start | ||||||
|  |                     ): | ||||||
|  |                         if (epoch + 1) % args.save_every_n_epoch == 0: | ||||||
|  |                             ckpt_iter = "0" * (6 - len(str(total_steps))) + str( | ||||||
|  |                                 total_steps | ||||||
|  |                             ) | ||||||
|  |                             save_path = Path( | ||||||
|  |                                 f"{args.ckpt_path}/model_{args.model_name}_{ckpt_iter}.pth" | ||||||
|  |                             ) | ||||||
|  |  | ||||||
|  |                             save_dict = { | ||||||
|  |                                 "model": model.module.module.state_dict(), | ||||||
|  |                                 "optimizer": optimizer.state_dict(), | ||||||
|  |                                 "scheduler": scheduler.state_dict(), | ||||||
|  |                                 "total_steps": total_steps, | ||||||
|  |                             } | ||||||
|  |  | ||||||
|  |                             logging.info(f"Saving file {save_path}") | ||||||
|  |                             self.save(save_dict, save_path) | ||||||
|  |  | ||||||
|  |                         if (epoch + 1) % args.evaluate_every_n_epoch == 0 or ( | ||||||
|  |                             args.validate_at_start and epoch == 0 | ||||||
|  |                         ): | ||||||
|  |                             run_test_eval( | ||||||
|  |                                 evaluator, | ||||||
|  |                                 model, | ||||||
|  |                                 eval_dataloaders, | ||||||
|  |                                 logger.writer, | ||||||
|  |                                 total_steps, | ||||||
|  |                             ) | ||||||
|  |                             model.train() | ||||||
|  |                             torch.cuda.empty_cache() | ||||||
|  |  | ||||||
|  |                 self.barrier() | ||||||
|  |                 if total_steps > args.num_steps: | ||||||
|  |                     should_keep_training = False | ||||||
|  |                     break | ||||||
|  |  | ||||||
|  |         print("FINISHED TRAINING") | ||||||
|  |  | ||||||
|  |         PATH = f"{args.ckpt_path}/{args.model_name}_final.pth" | ||||||
|  |         torch.save(model.module.module.state_dict(), PATH) | ||||||
|  |         run_test_eval(evaluator, model, eval_dataloaders, logger.writer, total_steps) | ||||||
|  |         logger.close() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     signal.signal(signal.SIGUSR1, sig_handler) | ||||||
|  |     signal.signal(signal.SIGTERM, term_handler) | ||||||
|  |     parser = argparse.ArgumentParser() | ||||||
|  |     parser.add_argument("--model_name", default="cotracker", help="model name") | ||||||
|  |     parser.add_argument("--restore_ckpt", help="restore checkpoint") | ||||||
|  |     parser.add_argument("--ckpt_path", help="path to save checkpoints") | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--batch_size", type=int, default=4, help="batch size used during training." | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--num_workers", type=int, default=6, help="number of dataloader workers" | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--mixed_precision", action="store_true", help="use mixed precision" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument("--lr", type=float, default=0.0005, help="max learning rate.") | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--wdecay", type=float, default=0.00001, help="Weight decay in optimizer." | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--num_steps", type=int, default=200000, help="length of training schedule." | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--evaluate_every_n_epoch", | ||||||
|  |         type=int, | ||||||
|  |         default=1, | ||||||
|  |         help="evaluate on the eval datasets every n-th epoch", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--save_every_n_epoch", | ||||||
|  |         type=int, | ||||||
|  |         default=1, | ||||||
|  |         help="save a training checkpoint every n-th epoch", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--validate_at_start", action="store_true", help="run evaluation before training starts" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument("--save_freq", type=int, default=100, help="visualize training batches every save_freq steps") | ||||||
|  |     parser.add_argument("--traj_per_sample", type=int, default=768, help="number of trajectories sampled per training example") | ||||||
|  |     parser.add_argument("--dataset_root", type=str, help="path to all the datasets") | ||||||
|  |  | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--train_iters", | ||||||
|  |         type=int, | ||||||
|  |         default=4, | ||||||
|  |         help="number of iterative updates to the predicted tracks in each forward pass.", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--sequence_len", type=int, default=8, help="train sequence length" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--eval_datasets", | ||||||
|  |         nargs="+", | ||||||
|  |         default=["things", "badja", "fastcapture"], | ||||||
|  |         help="eval datasets.", | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--remove_space_attn", action="store_true", help="remove spatial attention from the model" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--dont_use_augs", action="store_true", help="do not apply augmentations during training" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--sample_vis_1st_frame", action="store_true", help="only sample trajectories visible in the first frame" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--sliding_window_len", type=int, default=8, help="length of the model's sliding window" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--updateformer_hidden_size", type=int, default=384, help="hidden size of the update transformer" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--updateformer_num_heads", type=int, default=8, help="number of attention heads in the update transformer" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--updateformer_space_depth", type=int, default=12, help="number of spatial attention layers in the update transformer" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--updateformer_time_depth", type=int, default=12, help="number of temporal attention layers in the update transformer" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--model_stride", type=int, default=8, help="stride of the model's feature network" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--crop_size", | ||||||
|  |         type=int, | ||||||
|  |         nargs="+", | ||||||
|  |         default=[384, 512], | ||||||
|  |         help="crop videos to this resolution during training", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--eval_max_seq_len", type=int, default=1000, help="maximum sequence length for evaluation videos" | ||||||
|  |     ) | ||||||
|  |     args = parser.parse_args() | ||||||
|  |  | ||||||
|  |     logging.basicConfig( | ||||||
|  |         level=logging.INFO, | ||||||
|  |         format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s", | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     Path(args.ckpt_path).mkdir(exist_ok=True, parents=True) | ||||||
|  |     from pytorch_lightning.strategies import DDPStrategy | ||||||
|  |  | ||||||
|  |     Lite( | ||||||
|  |         strategy=DDPStrategy(find_unused_parameters=True), | ||||||
|  |         devices="auto", | ||||||
|  |         accelerator="gpu", | ||||||
|  |         precision=32, | ||||||
|  |         num_nodes=4, | ||||||
|  |     ).run(args) | ||||||
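|  |  | ||||||
|  |     # Note: num_nodes=4 matches a four-node launch (e.g. under SLURM, which the | ||||||
|  |     # signal handlers above assume); for a single machine set num_nodes=1. | ||||||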