├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── datasets ├── augmentor.py ├── dynamic_stereo_datasets.py └── frame_utils.py ├── evaluation ├── configs │ ├── eval_dynamic_replica_150_frames.yaml │ ├── eval_dynamic_replica_40_frames.yaml │ ├── eval_real_data.yaml │ ├── eval_sintel_clean.yaml │ └── eval_sintel_final.yaml ├── core │ └── evaluator.py ├── evaluate.py └── utils │ ├── eval_utils.py │ └── utils.py ├── models ├── core │ ├── attention.py │ ├── corr.py │ ├── dynamic_stereo.py │ ├── extractor.py │ ├── model_zoo.py │ ├── update.py │ └── utils │ │ ├── config.py │ │ └── utils.py ├── dynamic_stereo_model.py └── raft_stereo_model.py ├── notebooks └── Dynamic_Replica_demo.ipynb ├── requirements.txt ├── scripts ├── checksum_check.py ├── download_dynamic_replica.py ├── download_utils.py └── dr_sha256.json ├── train.py └── train_utils ├── logger.py ├── losses.py └── utils.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. 
Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Dynamic Stereo 2 | We want to make contributing to this project as easy and transparent as possible. 3 | 4 | ## Pull Requests 5 | We actively welcome your pull requests. 6 | 7 | 1. Fork the repo and create your branch from `main`. 8 | 2. If you've changed APIs, update the documentation. 9 | 3. Make sure your code lints. 10 | 4. If you haven't already, complete the Contributor License Agreement ("CLA"). 11 | 12 | ## Contributor License Agreement ("CLA") 13 | In order to accept your pull request, we need you to submit a CLA. You only need 14 | to do this once to work on any of Meta's open source projects. 15 | 16 | Complete your CLA here: 17 | 18 | ## Issues 19 | We use GitHub issues to track public bugs. Please ensure your description is 20 | clear and has sufficient instructions to be able to reproduce the issue. 21 | 22 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 23 | disclosure of security bugs. In those cases, please go through the process 24 | outlined on that page and do not file a public issue. 25 | 26 | ## Coding Style 27 | * all files are processed with the `black` auto-formatter before pushing, e.g. 28 | ``` 29 | python -m black eval_demo.py 30 | ``` 31 | 32 | ## License 33 | By contributing to DynamicStereo, you agree that your contributions will be licensed 34 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. 
Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | Section 1 -- Definitions. 71 | 72 | a. 
Adapted Material means material subject to Copyright and Similar 73 | Rights that is derived from or based upon the Licensed Material 74 | and in which the Licensed Material is translated, altered, 75 | arranged, transformed, or otherwise modified in a manner requiring 76 | permission under the Copyright and Similar Rights held by the 77 | Licensor. For purposes of this Public License, where the Licensed 78 | Material is a musical work, performance, or sound recording, 79 | Adapted Material is always produced where the Licensed Material is 80 | synched in timed relation with a moving image. 81 | 82 | b. Adapter's License means the license You apply to Your Copyright 83 | and Similar Rights in Your contributions to Adapted Material in 84 | accordance with the terms and conditions of this Public License. 85 | 86 | c. Copyright and Similar Rights means copyright and/or similar rights 87 | closely related to copyright including, without limitation, 88 | performance, broadcast, sound recording, and Sui Generis Database 89 | Rights, without regard to how the rights are labeled or 90 | categorized. For purposes of this Public License, the rights 91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 92 | Rights. 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. NonCommercial means not primarily intended for or directed towards 116 | commercial advantage or monetary compensation. For purposes of 117 | this Public License, the exchange of the Licensed Material for 118 | other material subject to Copyright and Similar Rights by digital 119 | file-sharing or similar means is NonCommercial provided there is 120 | no payment of monetary compensation in connection with the 121 | exchange. 122 | 123 | j. Share means to provide material to the public by any means or 124 | process that requires permission under the Licensed Rights, such 125 | as reproduction, public display, public performance, distribution, 126 | dissemination, communication, or importation, and to make material 127 | available to the public including in ways that members of the 128 | public may access the material from a place and at a time 129 | individually chosen by them. 130 | 131 | k. Sui Generis Database Rights means rights other than copyright 132 | resulting from Directive 96/9/EC of the European Parliament and of 133 | the Council of 11 March 1996 on the legal protection of databases, 134 | as amended and/or succeeded, as well as other essentially 135 | equivalent rights anywhere in the world. 
136 | 137 | l. You means the individual or entity exercising the Licensed Rights 138 | under this Public License. Your has a corresponding meaning. 139 | 140 | Section 2 -- Scope. 141 | 142 | a. License grant. 143 | 144 | 1. Subject to the terms and conditions of this Public License, 145 | the Licensor hereby grants You a worldwide, royalty-free, 146 | non-sublicensable, non-exclusive, irrevocable license to 147 | exercise the Licensed Rights in the Licensed Material to: 148 | 149 | a. reproduce and Share the Licensed Material, in whole or 150 | in part, for NonCommercial purposes only; and 151 | 152 | b. produce, reproduce, and Share Adapted Material for 153 | NonCommercial purposes only. 154 | 155 | 2. Exceptions and Limitations. For the avoidance of doubt, where 156 | Exceptions and Limitations apply to Your use, this Public 157 | License does not apply, and You do not need to comply with 158 | its terms and conditions. 159 | 160 | 3. Term. The term of this Public License is specified in Section 161 | 6(a). 162 | 163 | 4. Media and formats; technical modifications allowed. The 164 | Licensor authorizes You to exercise the Licensed Rights in 165 | all media and formats whether now known or hereafter created, 166 | and to make technical modifications necessary to do so. The 167 | Licensor waives and/or agrees not to assert any right or 168 | authority to forbid You from making technical modifications 169 | necessary to exercise the Licensed Rights, including 170 | technical modifications necessary to circumvent Effective 171 | Technological Measures. For purposes of this Public License, 172 | simply making modifications authorized by this Section 2(a) 173 | (4) never produces Adapted Material. 174 | 175 | 5. Downstream recipients. 176 | 177 | a. Offer from the Licensor -- Licensed Material. Every 178 | recipient of the Licensed Material automatically 179 | receives an offer from the Licensor to exercise the 180 | Licensed Rights under the terms and conditions of this 181 | Public License. 182 | 183 | b. No downstream restrictions. You may not offer or impose 184 | any additional or different terms or conditions on, or 185 | apply any Effective Technological Measures to, the 186 | Licensed Material if doing so restricts exercise of the 187 | Licensed Rights by any recipient of the Licensed 188 | Material. 189 | 190 | 6. No endorsement. Nothing in this Public License constitutes or 191 | may be construed as permission to assert or imply that You 192 | are, or that Your use of the Licensed Material is, connected 193 | with, or sponsored, endorsed, or granted official status by, 194 | the Licensor or others designated to receive attribution as 195 | provided in Section 3(a)(1)(A)(i). 196 | 197 | b. Other rights. 198 | 199 | 1. Moral rights, such as the right of integrity, are not 200 | licensed under this Public License, nor are publicity, 201 | privacy, and/or other similar personality rights; however, to 202 | the extent possible, the Licensor waives and/or agrees not to 203 | assert any such rights held by the Licensor to the limited 204 | extent necessary to allow You to exercise the Licensed 205 | Rights, but not otherwise. 206 | 207 | 2. Patent and trademark rights are not licensed under this 208 | Public License. 209 | 210 | 3. 
To the extent possible, the Licensor waives any right to 211 | collect royalties from You for the exercise of the Licensed 212 | Rights, whether directly or through a collecting society 213 | under any voluntary or waivable statutory or compulsory 214 | licensing scheme. In all other cases the Licensor expressly 215 | reserves any right to collect such royalties, including when 216 | the Licensed Material is used other than for NonCommercial 217 | purposes. 218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material (including in modified 227 | form), You must: 228 | 229 | a. retain the following if it is supplied by the Licensor 230 | with the Licensed Material: 231 | 232 | i. identification of the creator(s) of the Licensed 233 | Material and any others designated to receive 234 | attribution, in any reasonable manner requested by 235 | the Licensor (including by pseudonym if 236 | designated); 237 | 238 | ii. a copyright notice; 239 | 240 | iii. a notice that refers to this Public License; 241 | 242 | iv. a notice that refers to the disclaimer of 243 | warranties; 244 | 245 | v. a URI or hyperlink to the Licensed Material to the 246 | extent reasonably practicable; 247 | 248 | b. indicate if You modified the Licensed Material and 249 | retain an indication of any previous modifications; and 250 | 251 | c. indicate the Licensed Material is licensed under this 252 | Public License, and include the text of, or the URI or 253 | hyperlink to, this Public License. 254 | 255 | 2. You may satisfy the conditions in Section 3(a)(1) in any 256 | reasonable manner based on the medium, means, and context in 257 | which You Share the Licensed Material. For example, it may be 258 | reasonable to satisfy the conditions by providing a URI or 259 | hyperlink to a resource that includes the required 260 | information. 261 | 262 | 3. If requested by the Licensor, You must remove any of the 263 | information required by Section 3(a)(1)(A) to the extent 264 | reasonably practicable. 265 | 266 | 4. If You Share Adapted Material You produce, the Adapter's 267 | License You apply must not prevent recipients of the Adapted 268 | Material from complying with this Public License. 269 | 270 | Section 4 -- Sui Generis Database Rights. 271 | 272 | Where the Licensed Rights include Sui Generis Database Rights that 273 | apply to Your use of the Licensed Material: 274 | 275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 276 | to extract, reuse, reproduce, and Share all or a substantial 277 | portion of the contents of the database for NonCommercial purposes 278 | only; 279 | 280 | b. if You include all or a substantial portion of the database 281 | contents in a database in which You have Sui Generis Database 282 | Rights, then the database in which You have Sui Generis Database 283 | Rights (but not its individual contents) is Adapted Material; and 284 | 285 | c. You must comply with the conditions in Section 3(a) if You Share 286 | all or a substantial portion of the contents of the database. 287 | 288 | For the avoidance of doubt, this Section 4 supplements and does not 289 | replace Your obligations under this Public License where the Licensed 290 | Rights include other Copyright and Similar Rights. 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. 
UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | Section 6 -- Term and Termination. 321 | 322 | a. This Public License applies for the term of the Copyright and 323 | Similar Rights licensed here. However, if You fail to comply with 324 | this Public License, then Your rights under this Public License 325 | terminate automatically. 326 | 327 | b. Where Your right to use the Licensed Material has terminated under 328 | Section 6(a), it reinstates: 329 | 330 | 1. automatically as of the date the violation is cured, provided 331 | it is cured within 30 days of Your discovery of the 332 | violation; or 333 | 334 | 2. upon express reinstatement by the Licensor. 335 | 336 | For the avoidance of doubt, this Section 6(b) does not affect any 337 | right the Licensor may have to seek remedies for Your violations 338 | of this Public License. 339 | 340 | c. For the avoidance of doubt, the Licensor may also offer the 341 | Licensed Material under separate terms or conditions or stop 342 | distributing the Licensed Material at any time; however, doing so 343 | will not terminate this Public License. 344 | 345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 346 | License. 347 | 348 | Section 7 -- Other Terms and Conditions. 349 | 350 | a. The Licensor shall not be bound by any additional or different 351 | terms or conditions communicated by You unless expressly agreed. 352 | 353 | b. Any arrangements, understandings, or agreements regarding the 354 | Licensed Material not stated herein are separate from and 355 | independent of the terms and conditions of this Public License. 356 | 357 | Section 8 -- Interpretation. 358 | 359 | a. For the avoidance of doubt, this Public License does not, and 360 | shall not be interpreted to, reduce, limit, restrict, or impose 361 | conditions on any use of the Licensed Material that could lawfully 362 | be made without permission under this Public License. 363 | 364 | b. 
To the extent possible, if any provision of this Public License is 365 | deemed unenforceable, it shall be automatically reformed to the 366 | minimum extent necessary to make it enforceable. If the provision 367 | cannot be reformed, it shall be severed from this Public License 368 | without affecting the enforceability of the remaining terms and 369 | conditions. 370 | 371 | c. No term or condition of this Public License will be waived and no 372 | failure to comply consented to unless expressly agreed to by the 373 | Licensor. 374 | 375 | d. Nothing in this Public License constitutes or may be interpreted 376 | as a limitation upon, or waiver of, any privileges and immunities 377 | that apply to the Licensor or You, including from the legal 378 | processes of any jurisdiction or authority. 379 | 380 | ======================================================================= 381 | 382 | Creative Commons is not a party to its public 383 | licenses. Notwithstanding, Creative Commons may elect to apply one of 384 | its public licenses to material it publishes and in those instances 385 | will be considered the “Licensor.” The text of the Creative Commons 386 | public licenses is dedicated to the public domain under the CC0 Public 387 | Domain Dedication. Except for the limited purpose of indicating that 388 | material is shared under a Creative Commons public license or as 389 | otherwise permitted by the Creative Commons policies published at 390 | creativecommons.org/policies, Creative Commons does not authorize the 391 | use of the trademark "Creative Commons" or any other trademark or logo 392 | of Creative Commons without its prior written consent including, 393 | without limitation, in connection with any unauthorized modifications 394 | to any of its public licenses or any other arrangements, 395 | understandings, or agreements concerning use of licensed material. For 396 | the avoidance of doubt, this paragraph does not form part of the 397 | public licenses. 398 | 399 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [CVPR 2023] DynamicStereo: Consistent Dynamic Depth from Stereo Videos. 2 | 3 | **[Meta AI Research, FAIR](https://ai.facebook.com/research/)**; **[University of Oxford, VGG](https://www.robots.ox.ac.uk/~vgg/)** 4 | 5 | [Nikita Karaev](https://nikitakaraevv.github.io/), [Ignacio Rocco](https://www.irocco.info/), [Benjamin Graham](https://ai.facebook.com/people/benjamin-graham/), [Natalia Neverova](https://nneverova.github.io/), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/), [Christian Rupprecht](https://chrirupp.github.io/) 6 | 7 | [[`Paper`](https://research.facebook.com/publications/dynamicstereo-consistent-dynamic-depth-from-stereo-videos/)] [[`Project`](https://dynamic-stereo.github.io/)] [[`BibTeX`](#citing-dynamicstereo)] 8 | 9 | ![nikita-reading](https://user-images.githubusercontent.com/37815420/236242052-e72d5605-1ab2-426c-ae8d-5c8a86d5252c.gif) 10 | 11 | **DynamicStereo** is a transformer-based architecture for temporally consistent depth estimation from stereo videos. It has been trained on a combination of two datasets: [SceneFlow](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) and **Dynamic Replica** that we present below. 
12 | 13 | ## Dataset 14 | 15 | https://user-images.githubusercontent.com/37815420/236239579-7877623c-716b-4074-a14e-944d095f1419.mp4 16 | 17 | The dataset consists of 145,200 *stereo* frames (524 videos) with humans and animals in motion. 18 | 19 | We provide annotations for both the *left and right* views; see [this notebook](https://github.com/facebookresearch/dynamic_stereo/blob/main/notebooks/Dynamic_Replica_demo.ipynb): 20 | - camera intrinsics and extrinsics 21 | - image depth (can be converted to disparity with the intrinsics) 22 | - instance segmentation masks 23 | - binary foreground / background segmentation masks 24 | - optical flow (released!) 25 | - long-range pixel trajectories (released!) 26 | 27 | 28 | ### Download the Dynamic Replica dataset 29 | Download `links.json` from the *data* tab on the [project website](https://dynamic-stereo.github.io/) after accepting the license agreement. 30 | ``` 31 | git clone https://github.com/facebookresearch/dynamic_stereo 32 | cd dynamic_stereo 33 | export PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH 34 | ``` 35 | Add the downloaded `links.json` file to the project folder. Use the `--download_splits` flag to choose the dataset splits you want to download: 36 | ``` 37 | python ./scripts/download_dynamic_replica.py --link_list_file links.json \ 38 | --download_folder ./dynamic_replica_data --download_splits real valid test train 39 | ``` 40 | 41 | Disk space required per dataset split after unpacking (with all annotations): 42 | - train - 1.8T 43 | - test - 328G 44 | - valid - 106G 45 | - real - 152M 46 | 47 | You can use [this PyTorch dataset class](https://github.com/facebookresearch/dynamic_stereo/blob/dfe2907faf41b810e6bb0c146777d81cb48cb4f5/datasets/dynamic_stereo_datasets.py#L287) to iterate over the dataset. 48 | 49 | ## Installation 50 | 51 | The instructions below install DynamicStereo with the latest PyTorch3D, PyTorch 1.12.1, and CUDA 11.3. 52 | 53 | ### Set up the root for all source files: 54 | ``` 55 | git clone https://github.com/facebookresearch/dynamic_stereo 56 | cd dynamic_stereo 57 | export PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH 58 | ``` 59 | ### Create a conda env: 60 | ``` 61 | conda create -n dynamicstereo python=3.8 62 | conda activate dynamicstereo 63 | ``` 64 | ### Install requirements 65 | ``` 66 | conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch 67 | # It will require some time to install PyTorch3D. In the meantime, you may want to take a break and enjoy a cup of coffee. 68 | pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable" 69 | pip install -r requirements.txt 70 | ``` 71 | 72 | ### (Optional) Install RAFT-Stereo 73 | ``` 74 | mkdir third_party 75 | cd third_party 76 | git clone https://github.com/princeton-vl/RAFT-Stereo 77 | cd RAFT-Stereo 78 | bash download_models.sh 79 | cd ../.. 80 | ``` 81 | 82 | 83 | 84 | ## Evaluation 85 | To download the checkpoints, run: 86 | ``` 87 | mkdir checkpoints 88 | cd checkpoints 89 | wget https://dl.fbaipublicfiles.com/dynamic_replica_v1/dynamic_stereo_sf.pth 90 | wget https://dl.fbaipublicfiles.com/dynamic_replica_v1/dynamic_stereo_dr_sf.pth 91 | cd .. 92 | ``` 93 | You can also download the checkpoints manually via the links listed below, right after the snippet. Copy the checkpoints to `./dynamic_stereo/checkpoints`.
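Evaluation on *Dynamic Replica* needs both the checkpoints above and the dataset from the [Dataset](#dataset) section. As a quick smoke test of the dataset download, you can iterate over a split with the PyTorch dataset class mentioned there. The sketch below is not part of the repo's tooling; it only assumes the constructor arguments visible in `datasets/dynamic_stereo_datasets.py`, and the commented output shape follows from that code:
```
from dynamic_stereo.datasets.dynamic_stereo_datasets import DynamicReplicaDataset

# "valid" is far smaller than "train"; root must match --download_folder above
dataset = DynamicReplicaDataset(
    root="./dynamic_replica_data",
    split="valid",
    sample_len=20,            # frames per returned clip
    only_first_n_samples=1,   # one clip per sequence is enough for a smoke test
)
sample = dataset[0]
print(sample["img"].shape)    # [sample_len, 2 (left/right cameras), 3, H, W]
```
If this prints a five-dimensional shape, the split was unpacked correctly.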
94 | 95 | - [DynamicStereo](https://dl.fbaipublicfiles.com/dynamic_replica_v1/dynamic_stereo_sf.pth) trained on SceneFlow 96 | - [DynamicStereo](https://dl.fbaipublicfiles.com/dynamic_replica_v1/dynamic_stereo_dr_sf.pth) trained on SceneFlow and *Dynamic Replica* 97 | 98 | To evaluate DynamicStereo: 99 | ``` 100 | python ./evaluation/evaluate.py --config-name eval_dynamic_replica_40_frames \ 101 | MODEL.model_name=DynamicStereoModel exp_dir=./outputs/test_dynamic_replica_ds \ 102 | MODEL.DynamicStereoModel.model_weights=./checkpoints/dynamic_stereo_sf.pth 103 | ``` 104 | Due to the high image resolution, evaluation on *Dynamic Replica* requires a 32GB GPU. If you don't have enough GPU memory, you can decrease `kernel_size` from 20 to 10 by adding `MODEL.DynamicStereoModel.kernel_size=10` to the above command. Another option is to decrease the dataset resolution. 105 | 106 | With the default `kernel_size=20`, you should reproduce the numbers from *Table 5* in the [paper](https://arxiv.org/pdf/2305.02296.pdf). 107 | 108 | Reconstructions of all the *Dynamic Replica* splits (including *real*) will be visualized and saved to `exp_dir`. 109 | 110 | If you installed [RAFT-Stereo](https://github.com/princeton-vl/RAFT-Stereo), you can run: 111 | ``` 112 | python ./evaluation/evaluate.py --config-name eval_dynamic_replica_40_frames \ 113 | MODEL.model_name=RAFTStereoModel exp_dir=./outputs/test_dynamic_replica_raft 114 | ``` 115 | 116 | Other public datasets we use: 117 | - [SceneFlow](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) 118 | - [Sintel](http://sintel.is.tue.mpg.de/stereo) 119 | - [Middlebury](https://vision.middlebury.edu/stereo/data/) 120 | - [ETH3D](https://www.eth3d.net/datasets#low-res-two-view-training-data) 121 | - [KITTI 2015](http://www.cvlibs.net/datasets/kitti/eval_stereo.php) 122 | 123 | ## Training 124 | Training requires a 32GB GPU. You can decrease `image_size` and/or `sample_len` if you don't have enough GPU memory. 125 | You need to download SceneFlow before training. Alternatively, you can train on *Dynamic Replica* only. 126 | ``` 127 | python train.py --batch_size 1 \ 128 | --spatial_scale -0.2 0.4 --image_size 384 512 --saturation_range 0 1.4 --num_steps 200000 \ 129 | --ckpt_path dynamicstereo_sf_dr \ 130 | --sample_len 5 --lr 0.0003 --train_iters 10 --valid_iters 20 \ 131 | --num_workers 28 --save_freq 100 --update_block_3d --different_update_blocks \ 132 | --attention_type self_stereo_temporal_update_time_update_space --train_datasets dynamic_replica things monkaa driving 133 | ``` 134 | If you want to train on SceneFlow only, remove `dynamic_replica` from `--train_datasets`. 135 | 136 | 137 | 138 | ## License 139 | The majority of dynamic_stereo is licensed under CC-BY-NC; however, portions of the project are available under separate license terms: [RAFT-Stereo](https://github.com/princeton-vl/RAFT-Stereo) is licensed under the MIT license, and [LoFTR](https://github.com/zju3dv/LoFTR) and [CREStereo](https://github.com/megvii-research/CREStereo) are licensed under the Apache 2.0 license. 140 | 141 | 142 | ## Citing DynamicStereo 143 | If you use DynamicStereo or Dynamic Replica in your research, please use the following BibTeX entry.
144 | ``` 145 | @article{karaev2023dynamicstereo, 146 | title={DynamicStereo: Consistent Dynamic Depth from Stereo Videos}, 147 | author={Nikita Karaev and Ignacio Rocco and Benjamin Graham and Natalia Neverova and Andrea Vedaldi and Christian Rupprecht}, 148 | journal={CVPR}, 149 | year={2023} 150 | } 151 | ``` 152 | -------------------------------------------------------------------------------- /datasets/augmentor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import random 9 | from PIL import Image 10 | 11 | import cv2 12 | 13 | cv2.setNumThreads(0) 14 | cv2.ocl.setUseOpenCL(False) 15 | 16 | from torchvision.transforms import ColorJitter, functional, Compose 17 | 18 | 19 | class AdjustGamma(object): 20 | def __init__(self, gamma_min, gamma_max, gain_min=1.0, gain_max=1.0): 21 | self.gamma_min, self.gamma_max, self.gain_min, self.gain_max = ( 22 | gamma_min, 23 | gamma_max, 24 | gain_min, 25 | gain_max, 26 | ) 27 | 28 | def __call__(self, sample): 29 | gain = random.uniform(self.gain_min, self.gain_max) 30 | gamma = random.uniform(self.gamma_min, self.gamma_max) 31 | return functional.adjust_gamma(sample, gamma, gain) 32 | 33 | def __repr__(self): 34 | return f"Adjust Gamma {self.gamma_min}, ({self.gamma_max}) and Gain ({self.gain_min}, {self.gain_max})" 35 | 36 | 37 | class SequenceDispFlowAugmentor: 38 | def __init__( 39 | self, 40 | crop_size, 41 | min_scale=-0.2, 42 | max_scale=0.5, 43 | do_flip=True, 44 | yjitter=False, 45 | saturation_range=[0.6, 1.4], 46 | gamma=[1, 1, 1, 1], 47 | ): 48 | # spatial augmentation params 49 | self.crop_size = crop_size 50 | self.min_scale = min_scale 51 | self.max_scale = max_scale 52 | self.spatial_aug_prob = 1.0 53 | self.stretch_prob = 0.8 54 | self.max_stretch = 0.2 55 | 56 | # flip augmentation params 57 | self.yjitter = yjitter 58 | self.do_flip = do_flip 59 | self.h_flip_prob = 0.5 60 | self.v_flip_prob = 0.1 61 | 62 | # photometric augmentation params 63 | self.photo_aug = Compose( 64 | [ 65 | ColorJitter( 66 | brightness=0.4, 67 | contrast=0.4, 68 | saturation=saturation_range, 69 | hue=0.5 / 3.14, 70 | ), 71 | AdjustGamma(*gamma), 72 | ] 73 | ) 74 | self.asymmetric_color_aug_prob = 0.2 75 | self.eraser_aug_prob = 0.5 76 | 77 | def color_transform(self, seq): 78 | """Photometric augmentation""" 79 | 80 | # asymmetric 81 | if np.random.rand() < self.asymmetric_color_aug_prob: 82 | for i in range(len(seq)): 83 | for cam in (0, 1): 84 | seq[i][cam] = np.array( 85 | self.photo_aug(Image.fromarray(seq[i][cam])), dtype=np.uint8 86 | ) 87 | # symmetric 88 | else: 89 | image_stack = np.concatenate( 90 | [seq[i][cam] for i in range(len(seq)) for cam in (0, 1)], axis=0 91 | ) 92 | image_stack = np.array( 93 | self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8 94 | ) 95 | split = np.split(image_stack, len(seq) * 2, axis=0) 96 | for i in range(len(seq)): 97 | seq[i][0] = split[2 * i] 98 | seq[i][1] = split[2 * i + 1] 99 | return seq 100 | 101 | def eraser_transform(self, seq, bounds=[50, 100]): 102 | """Occlusion augmentation""" 103 | ht, wd = seq[0][0].shape[:2] 104 | for i in range(len(seq)): 105 | for cam in (0, 1): 106 | if np.random.rand() < self.eraser_aug_prob: 107 | mean_color = np.mean(seq[0][0].reshape(-1, 3), axis=0) 108 | for _ in 
range(np.random.randint(1, 3)): 109 | x0 = np.random.randint(0, wd) 110 | y0 = np.random.randint(0, ht) 111 | dx = np.random.randint(bounds[0], bounds[1]) 112 | dy = np.random.randint(bounds[0], bounds[1]) 113 | seq[i][cam][y0 : y0 + dy, x0 : x0 + dx, :] = mean_color 114 | 115 | return seq 116 | 117 | def spatial_transform(self, img, disp): 118 | # randomly sample scale 119 | ht, wd = img[0][0].shape[:2] 120 | min_scale = np.maximum( 121 | (self.crop_size[0] + 8) / float(ht), (self.crop_size[1] + 8) / float(wd) 122 | ) 123 | 124 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 125 | scale_x = scale 126 | scale_y = scale 127 | if np.random.rand() < self.stretch_prob: 128 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 129 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 130 | 131 | scale_x = np.clip(scale_x, min_scale, None) 132 | scale_y = np.clip(scale_y, min_scale, None) 133 | 134 | if np.random.rand() < self.spatial_aug_prob: 135 | # rescale the images 136 | for i in range(len(img)): 137 | for cam in (0, 1): 138 | img[i][cam] = cv2.resize( 139 | img[i][cam], 140 | None, 141 | fx=scale_x, 142 | fy=scale_y, 143 | interpolation=cv2.INTER_LINEAR, 144 | ) 145 | if len(disp[i]) > 0: 146 | disp[i][cam] = cv2.resize( 147 | disp[i][cam], 148 | None, 149 | fx=scale_x, 150 | fy=scale_y, 151 | interpolation=cv2.INTER_LINEAR, 152 | ) 153 | disp[i][cam] = disp[i][cam] * [scale_x, scale_y] 154 | 155 | if self.yjitter: 156 | y0 = np.random.randint(2, img[0][0].shape[0] - self.crop_size[0] - 2) 157 | x0 = np.random.randint(2, img[0][0].shape[1] - self.crop_size[1] - 2) 158 | 159 | for i in range(len(img)): 160 | y1 = y0 + np.random.randint(-2, 2 + 1) 161 | img[i][0] = img[i][0][ 162 | y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1] 163 | ] 164 | img[i][1] = img[i][1][ 165 | y1 : y1 + self.crop_size[0], x0 : x0 + self.crop_size[1] 166 | ] 167 | if len(disp[i]) > 0: 168 | disp[i][0] = disp[i][0][ 169 | y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1] 170 | ] 171 | disp[i][1] = disp[i][1][ 172 | y1 : y1 + self.crop_size[0], x0 : x0 + self.crop_size[1] 173 | ] 174 | else: 175 | y0 = np.random.randint(0, img[0][0].shape[0] - self.crop_size[0]) 176 | x0 = np.random.randint(0, img[0][0].shape[1] - self.crop_size[1]) 177 | for i in range(len(img)): 178 | for cam in (0, 1): 179 | img[i][cam] = img[i][cam][ 180 | y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1] 181 | ] 182 | if len(disp[i]) > 0: 183 | disp[i][cam] = disp[i][cam][ 184 | y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1] 185 | ] 186 | 187 | return img, disp 188 | 189 | def __call__(self, img, disp): 190 | img = self.color_transform(img) 191 | img = self.eraser_transform(img) 192 | img, disp = self.spatial_transform(img, disp) 193 | 194 | for i in range(len(img)): 195 | for cam in (0, 1): 196 | img[i][cam] = np.ascontiguousarray(img[i][cam]) 197 | if len(disp[i]) > 0: 198 | disp[i][cam] = np.ascontiguousarray(disp[i][cam]) 199 | 200 | return img, disp 201 | -------------------------------------------------------------------------------- /datasets/dynamic_stereo_datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
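# Note on augmentation: StereoSequenceDataset (below) builds a SequenceDispFlowAugmentor
# from datasets/augmentor.py whenever aug_params contains a "crop_size" entry, so the
# photometric / eraser / spatial augmentations defined above are only active when such a
# dict is passed in. An illustrative configuration (values are examples, not taken from
# train.py):
#   aug_params = {"crop_size": [384, 512], "min_scale": -0.2, "max_scale": 0.4,
#                 "saturation_range": [0.6, 1.4], "yjitter": True}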
6 | # # Data loading based on https://github.com/NVIDIA/flownet2-pytorch 7 | 8 | 9 | import os 10 | import copy 11 | import gzip 12 | import logging 13 | import torch 14 | import numpy as np 15 | import torch.utils.data as data 16 | import torch.nn.functional as F 17 | import os.path as osp 18 | from glob import glob 19 | 20 | from collections import defaultdict 21 | from PIL import Image 22 | from dataclasses import dataclass 23 | from typing import List, Optional 24 | from pytorch3d.renderer.cameras import PerspectiveCameras 25 | from pytorch3d.implicitron.dataset.types import ( 26 | FrameAnnotation as ImplicitronFrameAnnotation, 27 | load_dataclass, 28 | ) 29 | 30 | from dynamic_stereo.datasets import frame_utils 31 | from dynamic_stereo.evaluation.utils.eval_utils import depth2disparity_scale 32 | from dynamic_stereo.datasets.augmentor import SequenceDispFlowAugmentor 33 | 34 | 35 | @dataclass 36 | class DynamicReplicaFrameAnnotation(ImplicitronFrameAnnotation): 37 | """A dataclass used to load annotations from json.""" 38 | 39 | camera_name: Optional[str] = None 40 | 41 | 42 | class StereoSequenceDataset(data.Dataset): 43 | def __init__(self, aug_params=None, sparse=False, reader=None): 44 | self.augmentor = None 45 | self.sparse = sparse 46 | self.img_pad = ( 47 | aug_params.pop("img_pad", None) if aug_params is not None else None 48 | ) 49 | if aug_params is not None and "crop_size" in aug_params: 50 | if sparse: 51 | raise ValueError("Sparse augmentor is not implemented") 52 | else: 53 | self.augmentor = SequenceDispFlowAugmentor(**aug_params) 54 | 55 | if reader is None: 56 | self.disparity_reader = frame_utils.read_gen 57 | else: 58 | self.disparity_reader = reader 59 | self.depth_reader = self._load_16big_png_depth 60 | self.is_test = False 61 | self.sample_list = [] 62 | self.extra_info = [] 63 | self.depth_eps = 1e-5 64 | 65 | def _load_16big_png_depth(self, depth_png): 66 | with Image.open(depth_png) as depth_pil: 67 | # the image is stored with 16-bit depth but PIL reads it as I (32 bit). 68 | # we cast it to uint16, then reinterpret as float16, then cast to float32 69 | depth = ( 70 | np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16) 71 | .astype(np.float32) 72 | .reshape((depth_pil.size[1], depth_pil.size[0])) 73 | ) 74 | return depth 75 | 76 | def _get_pytorch3d_camera( 77 | self, entry_viewpoint, image_size, scale: float 78 | ) -> PerspectiveCameras: 79 | assert entry_viewpoint is not None 80 | # principal point and focal length 81 | principal_point = torch.tensor( 82 | entry_viewpoint.principal_point, dtype=torch.float 83 | ) 84 | focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float) 85 | 86 | half_image_size_wh_orig = ( 87 | torch.tensor(list(reversed(image_size)), dtype=torch.float) / 2.0 88 | ) 89 | 90 | # first, we convert from the dataset's NDC convention to pixels 91 | format = entry_viewpoint.intrinsics_format 92 | if format.lower() == "ndc_norm_image_bounds": 93 | # this is e.g. 
currently used in CO3D for storing intrinsics 94 | rescale = half_image_size_wh_orig 95 | elif format.lower() == "ndc_isotropic": 96 | rescale = half_image_size_wh_orig.min() 97 | else: 98 | raise ValueError(f"Unknown intrinsics format: {format}") 99 | 100 | # principal point and focal length in pixels 101 | principal_point_px = half_image_size_wh_orig - principal_point * rescale 102 | focal_length_px = focal_length * rescale 103 | 104 | # now, convert from pixels to PyTorch3D v0.5+ NDC convention 105 | # if self.image_height is None or self.image_width is None: 106 | out_size = list(reversed(image_size)) 107 | 108 | half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0 109 | half_min_image_size_output = half_image_size_output.min() 110 | 111 | # rescaled principal point and focal length in ndc 112 | principal_point = ( 113 | half_image_size_output - principal_point_px * scale 114 | ) / half_min_image_size_output 115 | focal_length = focal_length_px * scale / half_min_image_size_output 116 | 117 | return PerspectiveCameras( 118 | focal_length=focal_length[None], 119 | principal_point=principal_point[None], 120 | R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None], 121 | T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None], 122 | ) 123 | 124 | def _get_output_tensor(self, sample): 125 | output_tensor = defaultdict(list) 126 | sample_size = len(sample["image"]["left"]) 127 | output_tensor_keys = ["img", "disp", "valid_disp", "mask"] 128 | add_keys = ["viewpoint", "metadata"] 129 | for add_key in add_keys: 130 | if add_key in sample: 131 | output_tensor_keys.append(add_key) 132 | 133 | for key in output_tensor_keys: 134 | output_tensor[key] = [[] for _ in range(sample_size)] 135 | 136 | if "viewpoint" in sample: 137 | viewpoint_left = self._get_pytorch3d_camera( 138 | sample["viewpoint"]["left"][0], 139 | sample["metadata"]["left"][0][1], 140 | scale=1.0, 141 | ) 142 | viewpoint_right = self._get_pytorch3d_camera( 143 | sample["viewpoint"]["right"][0], 144 | sample["metadata"]["right"][0][1], 145 | scale=1.0, 146 | ) 147 | depth2disp_scale = depth2disparity_scale( 148 | viewpoint_left, 149 | viewpoint_right, 150 | torch.Tensor(sample["metadata"]["left"][0][1])[None], 151 | ) 152 | 153 | for i in range(sample_size): 154 | for cam in ["left", "right"]: 155 | if "mask" in sample and cam in sample["mask"]: 156 | mask = frame_utils.read_gen(sample["mask"][cam][i]) 157 | mask = np.array(mask) / 255.0 158 | output_tensor["mask"][i].append(mask) 159 | 160 | if "viewpoint" in sample and cam in sample["viewpoint"]: 161 | viewpoint = self._get_pytorch3d_camera( 162 | sample["viewpoint"][cam][i], 163 | sample["metadata"][cam][i][1], 164 | scale=1.0, 165 | ) 166 | output_tensor["viewpoint"][i].append(viewpoint) 167 | 168 | if "metadata" in sample and cam in sample["metadata"]: 169 | metadata = sample["metadata"][cam][i] 170 | output_tensor["metadata"][i].append(metadata) 171 | 172 | if cam in sample["image"]: 173 | 174 | img = frame_utils.read_gen(sample["image"][cam][i]) 175 | img = np.array(img).astype(np.uint8) 176 | 177 | # grayscale images 178 | if len(img.shape) == 2: 179 | img = np.tile(img[..., None], (1, 1, 3)) 180 | else: 181 | img = img[..., :3] 182 | output_tensor["img"][i].append(img) 183 | 184 | if cam in sample["disparity"]: 185 | disp = self.disparity_reader(sample["disparity"][cam][i]) 186 | if isinstance(disp, tuple): 187 | disp, valid_disp = disp 188 | else: 189 | valid_disp = disp < 512 190 | disp = np.array(disp).astype(np.float32) 191 | 192 | disp = 
np.stack([-disp, np.zeros_like(disp)], axis=-1) 193 | 194 | output_tensor["disp"][i].append(disp) 195 | output_tensor["valid_disp"][i].append(valid_disp) 196 | 197 | elif "depth" in sample and cam in sample["depth"]: 198 | depth = self.depth_reader(sample["depth"][cam][i]) 199 | 200 | depth_mask = depth < self.depth_eps 201 | depth[depth_mask] = self.depth_eps 202 | 203 | disp = depth2disp_scale / depth 204 | disp[depth_mask] = 0 205 | valid_disp = (disp < 512) * (1 - depth_mask) 206 | 207 | disp = np.array(disp).astype(np.float32) 208 | disp = np.stack([-disp, np.zeros_like(disp)], axis=-1) 209 | output_tensor["disp"][i].append(disp) 210 | output_tensor["valid_disp"][i].append(valid_disp) 211 | 212 | return output_tensor 213 | 214 | def __getitem__(self, index): 215 | im_tensor = {"img"} 216 | sample = self.sample_list[index] 217 | if self.is_test: 218 | sample_size = len(sample["image"]["left"]) 219 | im_tensor["img"] = [[] for _ in range(sample_size)] 220 | for i in range(sample_size): 221 | for cam in ["left", "right"]: 222 | img = frame_utils.read_gen(sample["image"][cam][i]) 223 | img = np.array(img).astype(np.uint8)[..., :3] 224 | img = torch.from_numpy(img).permute(2, 0, 1).float() 225 | im_tensor["img"][i].append(img) 226 | im_tensor["img"] = torch.stack(im_tensor["img"]) 227 | return im_tensor, self.extra_info[index] 228 | 229 | index = index % len(self.sample_list) 230 | 231 | try: 232 | output_tensor = self._get_output_tensor(sample) 233 | except: 234 | logging.warning(f"Exception in loading sample {index}!") 235 | index = np.random.randint(len(self.sample_list)) 236 | logging.info(f"New index is {index}") 237 | sample = self.sample_list[index] 238 | output_tensor = self._get_output_tensor(sample) 239 | sample_size = len(sample["image"]["left"]) 240 | 241 | if self.augmentor is not None: 242 | output_tensor["img"], output_tensor["disp"] = self.augmentor( 243 | output_tensor["img"], output_tensor["disp"] 244 | ) 245 | for i in range(sample_size): 246 | for cam in (0, 1): 247 | if cam < len(output_tensor["img"][i]): 248 | img = ( 249 | torch.from_numpy(output_tensor["img"][i][cam]) 250 | .permute(2, 0, 1) 251 | .float() 252 | ) 253 | if self.img_pad is not None: 254 | padH, padW = self.img_pad 255 | img = F.pad(img, [padW] * 2 + [padH] * 2) 256 | output_tensor["img"][i][cam] = img 257 | 258 | if cam < len(output_tensor["disp"][i]): 259 | disp = ( 260 | torch.from_numpy(output_tensor["disp"][i][cam]) 261 | .permute(2, 0, 1) 262 | .float() 263 | ) 264 | 265 | if self.sparse: 266 | valid_disp = torch.from_numpy( 267 | output_tensor["valid_disp"][i][cam] 268 | ) 269 | else: 270 | valid_disp = ( 271 | (disp[0].abs() < 512) 272 | & (disp[1].abs() < 512) 273 | & (disp[0].abs() != 0) 274 | ) 275 | disp = disp[:1] 276 | 277 | output_tensor["disp"][i][cam] = disp 278 | output_tensor["valid_disp"][i][cam] = valid_disp.float() 279 | 280 | if "mask" in output_tensor and cam < len(output_tensor["mask"][i]): 281 | mask = torch.from_numpy(output_tensor["mask"][i][cam]).float() 282 | output_tensor["mask"][i][cam] = mask 283 | 284 | if "viewpoint" in output_tensor and cam < len( 285 | output_tensor["viewpoint"][i] 286 | ): 287 | viewpoint = output_tensor["viewpoint"][i][cam] 288 | output_tensor["viewpoint"][i][cam] = viewpoint 289 | 290 | res = {} 291 | if "viewpoint" in output_tensor and self.split != "train": 292 | res["viewpoint"] = output_tensor["viewpoint"] 293 | if "metadata" in output_tensor and self.split != "train": 294 | res["metadata"] = output_tensor["metadata"] 295 | 296 | for k, v in 
output_tensor.items(): 297 | if k != "viewpoint" and k != "metadata": 298 | for i in range(len(v)): 299 | if len(v[i]) > 0: 300 | v[i] = torch.stack(v[i]) 301 | if len(v) > 0 and (len(v[0]) > 0): 302 | res[k] = torch.stack(v) 303 | return res 304 | 305 | def __mul__(self, v): 306 | copy_of_self = copy.deepcopy(self) 307 | copy_of_self.sample_list = v * copy_of_self.sample_list 308 | copy_of_self.extra_info = v * copy_of_self.extra_info 309 | return copy_of_self 310 | 311 | def __len__(self): 312 | return len(self.sample_list) 313 | 314 | 315 | class DynamicReplicaDataset(StereoSequenceDataset): 316 | def __init__( 317 | self, 318 | aug_params=None, 319 | root="./dynamic_replica_data", 320 | split="train", 321 | sample_len=-1, 322 | only_first_n_samples=-1, 323 | ): 324 | super(DynamicReplicaDataset, self).__init__(aug_params) 325 | self.root = root 326 | self.sample_len = sample_len 327 | self.split = split 328 | 329 | frame_annotations_file = f"frame_annotations_{split}.jgz" 330 | 331 | with gzip.open( 332 | osp.join(root, split, frame_annotations_file), "rt", encoding="utf8" 333 | ) as zipfile: 334 | frame_annots_list = load_dataclass( 335 | zipfile, List[DynamicReplicaFrameAnnotation] 336 | ) 337 | seq_annot = defaultdict(lambda: defaultdict(list)) 338 | for frame_annot in frame_annots_list: 339 | seq_annot[frame_annot.sequence_name][frame_annot.camera_name].append( 340 | frame_annot 341 | ) 342 | 343 | for seq_name in seq_annot.keys(): 344 | try: 345 | filenames = defaultdict(lambda: defaultdict(list)) 346 | for cam in ["left", "right"]: 347 | for framedata in seq_annot[seq_name][cam]: 348 | im_path = osp.join(root, split, framedata.image.path) 349 | depth_path = osp.join(root, split, framedata.depth.path) 350 | mask_path = osp.join(root, split, framedata.mask.path) 351 | 352 | assert os.path.isfile(im_path), im_path 353 | assert os.path.isfile(depth_path), depth_path 354 | assert os.path.isfile(mask_path), mask_path 355 | 356 | filenames["image"][cam].append(im_path) 357 | filenames["depth"][cam].append(depth_path) 358 | filenames["mask"][cam].append(mask_path) 359 | 360 | filenames["viewpoint"][cam].append(framedata.viewpoint) 361 | filenames["metadata"][cam].append( 362 | [framedata.sequence_name, framedata.image.size] 363 | ) 364 | 365 | for k in filenames.keys(): 366 | assert ( 367 | len(filenames[k][cam]) 368 | == len(filenames["image"][cam]) 369 | > 0 370 | ), framedata.sequence_name 371 | 372 | seq_len = len(filenames["image"][cam]) 373 | 374 | print("seq_len", seq_name, seq_len) 375 | if split == "train": 376 | for ref_idx in range(0, seq_len, 3): 377 | step = 1 if self.sample_len == 1 else np.random.randint(1, 6) 378 | if ref_idx + step * self.sample_len < seq_len: 379 | sample = defaultdict(lambda: defaultdict(list)) 380 | for cam in ["left", "right"]: 381 | for idx in range( 382 | ref_idx, ref_idx + step * self.sample_len, step 383 | ): 384 | for k in filenames.keys(): 385 | if "mask" not in k: 386 | sample[k][cam].append( 387 | filenames[k][cam][idx] 388 | ) 389 | 390 | self.sample_list.append(sample) 391 | else: 392 | step = self.sample_len if self.sample_len > 0 else seq_len 393 | counter = 0 394 | 395 | for ref_idx in range(0, seq_len, step): 396 | sample = defaultdict(lambda: defaultdict(list)) 397 | for cam in ["left", "right"]: 398 | for idx in range(ref_idx, ref_idx + step): 399 | for k in filenames.keys(): 400 | sample[k][cam].append(filenames[k][cam][idx]) 401 | 402 | self.sample_list.append(sample) 403 | counter += 1 404 | if only_first_n_samples > 0 and counter 
>= only_first_n_samples: 405 | break 406 | except Exception as e: 407 | print(e) 408 | print("Skipping sequence", seq_name) 409 | 410 | assert len(self.sample_list) > 0, "No samples found" 411 | print(f"Added {len(self.sample_list)} from Dynamic Replica {split}") 412 | logging.info(f"Added {len(self.sample_list)} from Dynamic Replica {split}") 413 | 414 | 415 | class SequenceSceneFlowDataset(StereoSequenceDataset): 416 | def __init__( 417 | self, 418 | aug_params=None, 419 | root="./datasets", 420 | dstype="frames_cleanpass", 421 | sample_len=1, 422 | things_test=False, 423 | add_things=True, 424 | add_monkaa=True, 425 | add_driving=True, 426 | ): 427 | super(SequenceSceneFlowDataset, self).__init__(aug_params) 428 | self.root = root 429 | self.dstype = dstype 430 | self.sample_len = sample_len 431 | if things_test: 432 | self._add_things("TEST") 433 | else: 434 | if add_things: 435 | self._add_things("TRAIN") 436 | if add_monkaa: 437 | self._add_monkaa() 438 | if add_driving: 439 | self._add_driving() 440 | 441 | def _add_things(self, split="TRAIN"): 442 | """Add FlyingThings3D data""" 443 | 444 | original_length = len(self.sample_list) 445 | root = osp.join(self.root, "FlyingThings3D") 446 | image_paths = defaultdict(list) 447 | disparity_paths = defaultdict(list) 448 | 449 | for cam in ["left", "right"]: 450 | image_paths[cam] = sorted( 451 | glob(osp.join(root, self.dstype, split, f"*/*/{cam}/")) 452 | ) 453 | disparity_paths[cam] = [ 454 | path.replace(self.dstype, "disparity") for path in image_paths[cam] 455 | ] 456 | 457 | # Choose a random subset of 400 images for validation 458 | state = np.random.get_state() 459 | np.random.seed(1000) 460 | val_idxs = set(np.random.permutation(len(image_paths["left"]))[:40]) 461 | np.random.set_state(state) 462 | np.random.seed(0) 463 | num_seq = len(image_paths["left"]) 464 | 465 | for seq_idx in range(num_seq): 466 | if (split == "TEST" and seq_idx in val_idxs) or ( 467 | split == "TRAIN" and not seq_idx in val_idxs 468 | ): 469 | images, disparities = defaultdict(list), defaultdict(list) 470 | for cam in ["left", "right"]: 471 | images[cam] = sorted( 472 | glob(osp.join(image_paths[cam][seq_idx], "*.png")) 473 | ) 474 | disparities[cam] = sorted( 475 | glob(osp.join(disparity_paths[cam][seq_idx], "*.pfm")) 476 | ) 477 | 478 | self._append_sample(images, disparities) 479 | 480 | assert len(self.sample_list) > 0, "No samples found" 481 | print( 482 | f"Added {len(self.sample_list) - original_length} from FlyingThings {self.dstype}" 483 | ) 484 | logging.info( 485 | f"Added {len(self.sample_list) - original_length} from FlyingThings {self.dstype}" 486 | ) 487 | 488 | def _add_monkaa(self): 489 | """Add FlyingThings3D data""" 490 | 491 | original_length = len(self.sample_list) 492 | root = osp.join(self.root, "Monkaa") 493 | image_paths = defaultdict(list) 494 | disparity_paths = defaultdict(list) 495 | 496 | for cam in ["left", "right"]: 497 | image_paths[cam] = sorted(glob(osp.join(root, self.dstype, f"*/{cam}/"))) 498 | disparity_paths[cam] = [ 499 | path.replace(self.dstype, "disparity") for path in image_paths[cam] 500 | ] 501 | 502 | num_seq = len(image_paths["left"]) 503 | 504 | for seq_idx in range(num_seq): 505 | images, disparities = defaultdict(list), defaultdict(list) 506 | for cam in ["left", "right"]: 507 | images[cam] = sorted(glob(osp.join(image_paths[cam][seq_idx], "*.png"))) 508 | disparities[cam] = sorted( 509 | glob(osp.join(disparity_paths[cam][seq_idx], "*.pfm")) 510 | ) 511 | 512 | self._append_sample(images, disparities) 
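# Note: _append_sample (defined below) adds two clips per starting index -- the forward
# clip and a time-reversed clip taken from the mirrored position at the end of the
# sequence -- doubling the number of SceneFlow training samples.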
513 | 514 | assert len(self.sample_list) > 0, "No samples found" 515 | print( 516 | f"Added {len(self.sample_list) - original_length} from Monkaa {self.dstype}" 517 | ) 518 | logging.info( 519 | f"Added {len(self.sample_list) - original_length} from Monkaa {self.dstype}" 520 | ) 521 | 522 | def _add_driving(self): 523 | """Add Driving data""" 524 | 525 | original_length = len(self.sample_list) 526 | root = osp.join(self.root, "Driving") 527 | image_paths = defaultdict(list) 528 | disparity_paths = defaultdict(list) 529 | 530 | for cam in ["left", "right"]: 531 | image_paths[cam] = sorted( 532 | glob(osp.join(root, self.dstype, f"*/*/*/{cam}/")) 533 | ) 534 | disparity_paths[cam] = [ 535 | path.replace(self.dstype, "disparity") for path in image_paths[cam] 536 | ] 537 | 538 | num_seq = len(image_paths["left"]) 539 | for seq_idx in range(num_seq): 540 | images, disparities = defaultdict(list), defaultdict(list) 541 | for cam in ["left", "right"]: 542 | images[cam] = sorted(glob(osp.join(image_paths[cam][seq_idx], "*.png"))) 543 | disparities[cam] = sorted( 544 | glob(osp.join(disparity_paths[cam][seq_idx], "*.pfm")) 545 | ) 546 | 547 | self._append_sample(images, disparities) 548 | 549 | assert len(self.sample_list) > 0, "No samples found" 550 | print( 551 | f"Added {len(self.sample_list) - original_length} from Driving {self.dstype}" 552 | ) 553 | logging.info( 554 | f"Added {len(self.sample_list) - original_length} from Driving {self.dstype}" 555 | ) 556 | 557 | def _append_sample(self, images, disparities): 558 | seq_len = len(images["left"]) 559 | for ref_idx in range(0, seq_len - self.sample_len): 560 | sample = defaultdict(lambda: defaultdict(list)) 561 | for cam in ["left", "right"]: 562 | for idx in range(ref_idx, ref_idx + self.sample_len): 563 | sample["image"][cam].append(images[cam][idx]) 564 | sample["disparity"][cam].append(disparities[cam][idx]) 565 | self.sample_list.append(sample) 566 | 567 | sample = defaultdict(lambda: defaultdict(list)) 568 | for cam in ["left", "right"]: 569 | for idx in range(ref_idx, ref_idx + self.sample_len): 570 | sample["image"][cam].append(images[cam][seq_len - idx - 1]) 571 | sample["disparity"][cam].append(disparities[cam][seq_len - idx - 1]) 572 | self.sample_list.append(sample) 573 | 574 | 575 | class SequenceSintelStereo(StereoSequenceDataset): 576 | def __init__( 577 | self, 578 | dstype="clean", 579 | aug_params=None, 580 | root="./datasets", 581 | ): 582 | super().__init__( 583 | aug_params, sparse=True, reader=frame_utils.readDispSintelStereo 584 | ) 585 | self.dstype = dstype 586 | original_length = len(self.sample_list) 587 | image_root = osp.join(root, "sintel_stereo", "training") 588 | 589 | image_paths = defaultdict(list) 590 | disparity_paths = defaultdict(list) 591 | 592 | for cam in ["left", "right"]: 593 | image_paths[cam] = sorted( 594 | glob(osp.join(image_root, f"{self.dstype}_{cam}/*")) 595 | ) 596 | 597 | cam = "left" 598 | disparity_paths[cam] = [ 599 | path.replace(f"{self.dstype}_{cam}", "disparities") 600 | for path in image_paths[cam] 601 | ] 602 | 603 | num_seq = len(image_paths["left"]) 604 | # for each sequence 605 | for seq_idx in range(num_seq): 606 | sample = defaultdict(lambda: defaultdict(list)) 607 | for cam in ["left", "right"]: 608 | sample["image"][cam] = sorted( 609 | glob(osp.join(image_paths[cam][seq_idx], "*.png")) 610 | ) 611 | cam = "left" 612 | sample["disparity"][cam] = sorted( 613 | glob(osp.join(disparity_paths[cam][seq_idx], "*.png")) 614 | ) 615 | for im1, disp in zip(sample["image"][cam],
sample["disparity"][cam]): 616 | assert ( 617 | im1.split("/")[-1].split(".")[0] 618 | == disp.split("/")[-1].split(".")[0] 619 | ), (im1.split("/")[-1].split(".")[0], disp.split("/")[-1].split(".")[0]) 620 | self.sample_list.append(sample) 621 | 622 | logging.info( 623 | f"Added {len(self.sample_list) - original_length} from SintelStereo {self.dstype}" 624 | ) 625 | 626 | 627 | def fetch_dataloader(args): 628 | """Create the data loader for the corresponding trainign set""" 629 | 630 | aug_params = { 631 | "crop_size": args.image_size, 632 | "min_scale": args.spatial_scale[0], 633 | "max_scale": args.spatial_scale[1], 634 | "do_flip": False, 635 | "yjitter": not args.noyjitter, 636 | } 637 | if hasattr(args, "saturation_range") and args.saturation_range is not None: 638 | aug_params["saturation_range"] = args.saturation_range 639 | if hasattr(args, "img_gamma") and args.img_gamma is not None: 640 | aug_params["gamma"] = args.img_gamma 641 | if hasattr(args, "do_flip") and args.do_flip is not None: 642 | aug_params["do_flip"] = args.do_flip 643 | 644 | train_dataset = None 645 | 646 | add_monkaa = "monkaa" in args.train_datasets 647 | add_driving = "driving" in args.train_datasets 648 | add_things = "things" in args.train_datasets 649 | add_dynamic_replica = "dynamic_replica" in args.train_datasets 650 | 651 | new_dataset = None 652 | 653 | if add_monkaa or add_driving or add_things: 654 | clean_dataset = SequenceSceneFlowDataset( 655 | aug_params, 656 | dstype="frames_cleanpass", 657 | sample_len=args.sample_len, 658 | add_monkaa=add_monkaa, 659 | add_driving=add_driving, 660 | add_things=add_things, 661 | ) 662 | 663 | final_dataset = SequenceSceneFlowDataset( 664 | aug_params, 665 | dstype="frames_finalpass", 666 | sample_len=args.sample_len, 667 | add_monkaa=add_monkaa, 668 | add_driving=add_driving, 669 | add_things=add_things, 670 | ) 671 | 672 | new_dataset = clean_dataset + final_dataset 673 | 674 | if add_dynamic_replica: 675 | dr_dataset = DynamicReplicaDataset( 676 | aug_params, split="train", sample_len=args.sample_len 677 | ) 678 | if new_dataset is None: 679 | new_dataset = dr_dataset 680 | else: 681 | new_dataset = new_dataset + dr_dataset 682 | 683 | logging.info(f"Adding {len(new_dataset)} samples from SceneFlow") 684 | train_dataset = ( 685 | new_dataset if train_dataset is None else train_dataset + new_dataset 686 | ) 687 | 688 | train_loader = data.DataLoader( 689 | train_dataset, 690 | batch_size=args.batch_size, 691 | pin_memory=True, 692 | shuffle=True, 693 | num_workers=args.num_workers, 694 | drop_last=True, 695 | ) 696 | 697 | logging.info("Training with %d image pairs" % len(train_dataset)) 698 | return train_loader 699 | -------------------------------------------------------------------------------- /datasets/frame_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import numpy as np 8 | from PIL import Image 9 | from os.path import * 10 | import re 11 | import imageio 12 | import cv2 13 | 14 | cv2.setNumThreads(0) 15 | cv2.ocl.setUseOpenCL(False) 16 | 17 | TAG_CHAR = np.array([202021.25], np.float32) 18 | 19 | 20 | def readFlow(fn): 21 | """Read .flo file in Middlebury format""" 22 | # Code adapted from: 23 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 24 | 25 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 26 | # print 'fn = %s'%(fn) 27 | with open(fn, "rb") as f: 28 | magic = np.fromfile(f, np.float32, count=1) 29 | if 202021.25 != magic: 30 | print("Magic number incorrect. Invalid .flo file") 31 | return None 32 | else: 33 | w = np.fromfile(f, np.int32, count=1) 34 | h = np.fromfile(f, np.int32, count=1) 35 | # print 'Reading %d x %d flo file\n' % (w, h) 36 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) 37 | # Reshape data into 3D array (columns, rows, bands) 38 | # The reshape here is for visualization, the original code is (w,h,2) 39 | return np.resize(data, (int(h), int(w), 2)) 40 | 41 | 42 | def readPFM(file): 43 | file = open(file, "rb") 44 | 45 | color = None 46 | width = None 47 | height = None 48 | scale = None 49 | endian = None 50 | 51 | header = file.readline().rstrip() 52 | if header == b"PF": 53 | color = True 54 | elif header == b"Pf": 55 | color = False 56 | else: 57 | raise Exception("Not a PFM file.") 58 | 59 | dim_match = re.match(rb"^(\d+)\s(\d+)\s$", file.readline()) 60 | if dim_match: 61 | width, height = map(int, dim_match.groups()) 62 | else: 63 | raise Exception("Malformed PFM header.") 64 | 65 | scale = float(file.readline().rstrip()) 66 | if scale < 0: # little-endian 67 | endian = "<" 68 | scale = -scale 69 | else: 70 | endian = ">" # big-endian 71 | 72 | data = np.fromfile(file, endian + "f") 73 | shape = (height, width, 3) if color else (height, width) 74 | 75 | data = np.reshape(data, shape) 76 | data = np.flipud(data) 77 | return data 78 | 79 | 80 | def readDispSintelStereo(file_name): 81 | """Return disparity read from filename.""" 82 | f_in = np.array(Image.open(file_name)) 83 | d_r = f_in[:, :, 0].astype("float64") 84 | d_g = f_in[:, :, 1].astype("float64") 85 | d_b = f_in[:, :, 2].astype("float64") 86 | 87 | disp = d_r * 4 + d_g / (2 ** 6) + d_b / (2 ** 14) 88 | mask = np.array(Image.open(file_name.replace("disparities", "occlusions"))) 89 | valid = (mask == 0) & (disp > 0) 90 | return disp, valid 91 | 92 | 93 | def readDispMiddlebury(file_name): 94 | assert basename(file_name) == "disp0GT.pfm" 95 | disp = readPFM(file_name).astype(np.float32) 96 | assert len(disp.shape) == 2 97 | nocc_pix = file_name.replace("disp0GT.pfm", "mask0nocc.png") 98 | assert exists(nocc_pix) 99 | nocc_pix = imageio.imread(nocc_pix) == 255 100 | assert np.any(nocc_pix) 101 | return disp, nocc_pix 102 | 103 | 104 | def read_gen(file_name, pil=False): 105 | ext = splitext(file_name)[-1] 106 | if ext == ".png" or ext == ".jpeg" or ext == ".ppm" or ext == ".jpg": 107 | return Image.open(file_name) 108 | elif ext == ".bin" or ext == ".raw": 109 | return np.load(file_name) 110 | elif ext == ".flo": 111 | return readFlow(file_name).astype(np.float32) 112 | elif ext == ".pfm": 113 | flow = readPFM(file_name).astype(np.float32) 114 | if len(flow.shape) == 2: 115 | return flow 116 | else: 117 | return flow[:, :, :-1] 118 | return [] 119 | -------------------------------------------------------------------------------- 
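A note on the Sintel disparity encoding handled by readDispSintelStereo above: the value is packed into the three 8-bit colour channels, with the coarse disparity in R and progressively finer fractions in G and B. A minimal sketch of the same arithmetic on a single made-up pixel (the channel values below are purely illustrative):

```
# Hypothetical channel values of one pixel from a Sintel "disparities" PNG.
d_r, d_g, d_b = 100.0, 32.0, 8.0

# Same decoding as readDispSintelStereo: R carries 4-px steps,
# G adds 1/64-px steps, B adds 1/16384-px steps.
disp = d_r * 4 + d_g / (2 ** 6) + d_b / (2 ** 14)
print(disp)  # 400 + 0.5 + 0.00048828125 = 400.50048828125
```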
/evaluation/configs/eval_dynamic_replica_150_frames.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default_config_eval 3 | visualize_interval: 0 4 | exp_dir: ./outputs/dynamic_stereo_DR 5 | sample_len: 150 6 | MODEL: 7 | model_name: DynamicStereoModel 8 | 9 | -------------------------------------------------------------------------------- /evaluation/configs/eval_dynamic_replica_40_frames.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default_config_eval 3 | visualize_interval: 0 4 | exp_dir: ./outputs/dynamic_stereo_DR 5 | sample_len: 40 6 | MODEL: 7 | model_name: DynamicStereoModel 8 | 9 | -------------------------------------------------------------------------------- /evaluation/configs/eval_real_data.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default_config_eval 3 | visualize_interval: 1 4 | exp_dir: ./outputs/dynamic_stereo_real 5 | dataset_name: real 6 | sample_len: 40 7 | MODEL: 8 | model_name: DynamicStereoModel 9 | 10 | -------------------------------------------------------------------------------- /evaluation/configs/eval_sintel_clean.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default_config_eval 3 | visualize_interval: -1 4 | exp_dir: ./outputs/dynamic_stereo_sintel_clean 5 | sample_len: 30 6 | dataset_name: sintel 7 | dstype: clean 8 | MODEL: 9 | model_name: DynamicStereoModel 10 | -------------------------------------------------------------------------------- /evaluation/configs/eval_sintel_final.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default_config_eval 3 | visualize_interval: -1 4 | exp_dir: ./outputs/dynamic_stereo_sintel_final 5 | sample_len: 30 6 | dataset_name: sintel 7 | dstype: final 8 | MODEL: 9 | model_name: DynamicStereoModel -------------------------------------------------------------------------------- /evaluation/core/evaluator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | from collections import defaultdict 9 | import torch.nn.functional as F 10 | import torch 11 | from tqdm import tqdm 12 | from omegaconf import DictConfig 13 | from pytorch3d.implicitron.tools.config import Configurable 14 | 15 | from dynamic_stereo.evaluation.utils.eval_utils import depth2disparity_scale, eval_batch 16 | from dynamic_stereo.evaluation.utils.utils import ( 17 | PerceptionPrediction, 18 | pretty_print_perception_metrics, 19 | visualize_batch, 20 | ) 21 | 22 | 23 | class Evaluator(Configurable): 24 | """ 25 | A class defining the DynamicStereo evaluator. 26 | 27 | Args: 28 | eps: Threshold for converting disparity to depth. 
29 | """ 30 | 31 | eps = 1e-5 32 | 33 | def setup_visualization(self, cfg: DictConfig) -> None: 34 | # Visualization 35 | self.visualize_interval = cfg.visualize_interval 36 | self.exp_dir = cfg.exp_dir 37 | if self.visualize_interval > 0: 38 | self.visualize_dir = os.path.join(cfg.exp_dir, "visualisations") 39 | 40 | @torch.no_grad() 41 | def evaluate_sequence( 42 | self, 43 | model, 44 | test_dataloader: torch.utils.data.DataLoader, 45 | is_real_data: bool = False, 46 | step=None, 47 | writer=None, 48 | train_mode=False, 49 | interp_shape=None, 50 | ): 51 | model.eval() 52 | per_batch_eval_results = [] 53 | 54 | if self.visualize_interval > 0: 55 | os.makedirs(self.visualize_dir, exist_ok=True) 56 | 57 | for batch_idx, sequence in enumerate(tqdm(test_dataloader)): 58 | batch_dict = defaultdict(list) 59 | batch_dict["stereo_video"] = sequence["img"] 60 | if not is_real_data: 61 | batch_dict["disparity"] = sequence["disp"][:, 0].abs() 62 | batch_dict["disparity_mask"] = sequence["valid_disp"][:, :1] 63 | 64 | if "mask" in sequence: 65 | batch_dict["fg_mask"] = sequence["mask"][:, :1] 66 | else: 67 | batch_dict["fg_mask"] = torch.ones_like( 68 | batch_dict["disparity_mask"] 69 | ) 70 | elif interp_shape is not None: 71 | left_video = batch_dict["stereo_video"][:, 0] 72 | left_video = F.interpolate( 73 | left_video, tuple(interp_shape), mode="bilinear" 74 | ) 75 | right_video = batch_dict["stereo_video"][:, 1] 76 | right_video = F.interpolate( 77 | right_video, tuple(interp_shape), mode="bilinear" 78 | ) 79 | batch_dict["stereo_video"] = torch.stack([left_video, right_video], 1) 80 | 81 | if train_mode: 82 | predictions = model.forward_batch_test(batch_dict) 83 | else: 84 | predictions = model(batch_dict) 85 | 86 | assert "disparity" in predictions 87 | predictions["disparity"] = predictions["disparity"][:, :1].clone().cpu() 88 | 89 | if not is_real_data: 90 | predictions["disparity"] = predictions["disparity"] * ( 91 | batch_dict["disparity_mask"].round() 92 | ) 93 | 94 | batch_eval_result, seq_length = eval_batch(batch_dict, predictions) 95 | 96 | per_batch_eval_results.append((batch_eval_result, seq_length)) 97 | pretty_print_perception_metrics(batch_eval_result) 98 | 99 | if (self.visualize_interval > 0) and ( 100 | batch_idx % self.visualize_interval == 0 101 | ): 102 | perception_prediction = PerceptionPrediction() 103 | 104 | pred_disp = predictions["disparity"] 105 | pred_disp[pred_disp < self.eps] = self.eps 106 | 107 | scale = depth2disparity_scale( 108 | sequence["viewpoint"][0][0], 109 | sequence["viewpoint"][0][1], 110 | torch.tensor([pred_disp.shape[2], pred_disp.shape[3]])[None], 111 | ) 112 | 113 | perception_prediction.depth_map = (scale / pred_disp).cuda() 114 | perspective_cameras = [] 115 | for cam in sequence["viewpoint"]: 116 | perspective_cameras.append(cam[0]) 117 | 118 | perception_prediction.perspective_cameras = perspective_cameras 119 | 120 | if "stereo_original_video" in batch_dict: 121 | batch_dict["stereo_video"] = batch_dict[ 122 | "stereo_original_video" 123 | ].clone() 124 | 125 | for k, v in batch_dict.items(): 126 | if isinstance(v, torch.Tensor): 127 | batch_dict[k] = v.cuda() 128 | 129 | visualize_batch( 130 | batch_dict, 131 | perception_prediction, 132 | output_dir=self.visualize_dir, 133 | sequence_name=sequence["metadata"][0][0][0], 134 | step=step, 135 | writer=writer, 136 | ) 137 | return per_batch_eval_results 138 | -------------------------------------------------------------------------------- /evaluation/evaluate.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import json 8 | import os 9 | from dataclasses import dataclass, field 10 | from typing import Any, Dict, Optional 11 | 12 | import hydra 13 | import numpy as np 14 | 15 | import torch 16 | from omegaconf import OmegaConf 17 | 18 | from dynamic_stereo.evaluation.utils.utils import aggregate_and_print_results 19 | 20 | import dynamic_stereo.datasets.dynamic_stereo_datasets as datasets 21 | 22 | from dynamic_stereo.models.core.model_zoo import ( 23 | get_all_model_default_configs, 24 | model_zoo, 25 | ) 26 | from pytorch3d.implicitron.tools.config import get_default_args_field 27 | from dynamic_stereo.evaluation.core.evaluator import Evaluator 28 | 29 | 30 | @dataclass(eq=False) 31 | class DefaultConfig: 32 | exp_dir: str = "./outputs" 33 | 34 | # one of [sintel, dynamicreplica, things, real] 35 | dataset_name: str = "dynamicreplica" 36 | 37 | sample_len: int = -1 38 | # one of [clean, final] 39 | dstype: Optional[str] = None 40 | MODEL: Dict[str, Any] = field( 41 | default_factory=lambda: get_all_model_default_configs() 42 | ) 43 | EVALUATOR: Dict[str, Any] = get_default_args_field(Evaluator) 44 | 45 | seed: int = 42 46 | gpu_idx: int = 0 47 | 48 | visualize_interval: int = 0 # Use 0 for no visualization 49 | 50 | # Override hydra's working directory to current working dir, 51 | # also disable storing the .hydra logs: 52 | hydra: dict = field( 53 | default_factory=lambda: { 54 | "run": {"dir": "."}, 55 | "output_subdir": None, 56 | } 57 | ) 58 | 59 | 60 | def run_eval(cfg: DefaultConfig): 61 | """ 62 | Evaluates disparity estimation metrics of a specified model 63 | on a benchmark dataset.
64 | """ 65 | # make the experiment directory 66 | os.makedirs(cfg.exp_dir, exist_ok=True) 67 | 68 | # dump the exp cofig to the exp_dir 69 | cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml") 70 | with open(cfg_file, "w") as f: 71 | OmegaConf.save(config=cfg, f=f) 72 | 73 | torch.manual_seed(cfg.seed) 74 | np.random.seed(cfg.seed) 75 | evaluator = Evaluator(**cfg.EVALUATOR) 76 | 77 | model = model_zoo(**cfg.MODEL) 78 | model.cuda(0) 79 | evaluator.setup_visualization(cfg) 80 | 81 | if cfg.dataset_name == "dynamicreplica": 82 | test_dataloader = datasets.DynamicReplicaDataset( 83 | split="valid", sample_len=cfg.sample_len, only_first_n_samples=1 84 | ) 85 | elif cfg.dataset_name == "sintel": 86 | test_dataloader = datasets.SequenceSintelStereo(dstype=cfg.dstype) 87 | elif cfg.dataset_name == "things": 88 | test_dataloader = datasets.SequenceSceneFlowDatasets( 89 | {}, 90 | dstype=cfg.dstype, 91 | sample_len=cfg.sample_len, 92 | add_monkaa=False, 93 | add_driving=False, 94 | things_test=True, 95 | ) 96 | elif cfg.dataset_name == "real": 97 | for real_sequence_name in ["teddy_static", "ignacio_waving", "nikita_reading"]: 98 | ds_path = f"./dynamic_replica_data/real/{real_sequence_name}" 99 | # seq_len_real = 20 100 | real_dataset = datasets.DynamicReplicaDataset( 101 | split="test", 102 | sample_len=cfg.sample_len, 103 | root=ds_path, 104 | only_first_n_samples=1, 105 | ) 106 | 107 | evaluator.evaluate_sequence( 108 | model=model, 109 | test_dataloader=real_dataset, 110 | is_real_data=True, 111 | train_mode=False, 112 | ) 113 | return 114 | 115 | print() 116 | 117 | evaluate_result = evaluator.evaluate_sequence( 118 | model, 119 | test_dataloader, 120 | ) 121 | 122 | aggreegate_result = aggregate_and_print_results(evaluate_result) 123 | 124 | result_file = os.path.join(cfg.exp_dir, f"result_eval.json") 125 | 126 | print(f"Dumping eval results to {result_file}.") 127 | with open(result_file, "w") as f: 128 | json.dump(aggreegate_result, f) 129 | 130 | 131 | cs = hydra.core.config_store.ConfigStore.instance() 132 | cs.store(name="default_config_eval", node=DefaultConfig) 133 | 134 | 135 | @hydra.main(config_path="./configs/", config_name="default_config_eval") 136 | def evaluate(cfg: DefaultConfig) -> None: 137 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 138 | os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx) 139 | run_eval(cfg) 140 | 141 | 142 | if __name__ == "__main__": 143 | evaluate() 144 | -------------------------------------------------------------------------------- /evaluation/utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from dataclasses import dataclass 8 | from typing import Dict, Optional, Union 9 | 10 | import torch 11 | from pytorch3d.utils import opencv_from_cameras_projection 12 | 13 | 14 | @dataclass(eq=True, frozen=True) 15 | class PerceptionMetric: 16 | metric: str 17 | depth_scaling_norm: Optional[str] = None 18 | suffix: str = "" 19 | index: str = "" 20 | 21 | def __str__(self): 22 | return ( 23 | self.metric 24 | + self.index 25 | + ( 26 | ("_norm_" + self.depth_scaling_norm) 27 | if self.depth_scaling_norm is not None 28 | else "" 29 | ) 30 | + self.suffix 31 | ) 32 | 33 | 34 | def eval_endpoint_error_sequence( 35 | x: torch.Tensor, 36 | y: torch.Tensor, 37 | mask: torch.Tensor, 38 | crop: int = 0, 39 | mask_thr: float = 0.5, 40 | clamp_thr: float = 1e-5, 41 | ) -> Dict[str, torch.Tensor]: 42 | 43 | assert len(x.shape) == len(y.shape) == len(mask.shape) == 4, ( 44 | x.shape, 45 | y.shape, 46 | mask.shape, 47 | ) 48 | assert x.shape[0] == y.shape[0] == mask.shape[0], (x.shape, y.shape, mask.shape) 49 | 50 | # chuck out the border 51 | if crop > 0: 52 | if crop > min(y.shape[2:]) - crop: 53 | raise ValueError("Incorrect crop size.") 54 | y = y[:, :, crop:-crop, crop:-crop] 55 | x = x[:, :, crop:-crop, crop:-crop] 56 | mask = mask[:, :, crop:-crop, crop:-crop] 57 | 58 | y = y * (mask > mask_thr).float() 59 | x = x * (mask > mask_thr).float() 60 | y[torch.isnan(y)] = 0 61 | 62 | results = {} 63 | for epe_name in ("epe", "temp_epe"): 64 | if epe_name == "epe": 65 | endpoint_error = (mask * (x - y) ** 2).sum(dim=1).sqrt() 66 | elif epe_name == "temp_epe": 67 | delta_mask = mask[:-1] * mask[1:] 68 | endpoint_error = ( 69 | (delta_mask * ((x[:-1] - x[1:]) - (y[:-1] - y[1:])) ** 2) 70 | .sum(dim=1) 71 | .sqrt() 72 | ) 73 | 74 | # epe_nonzero = endpoint_error != 0 75 | nonzero = torch.count_nonzero(endpoint_error) 76 | 77 | epe_mean = endpoint_error.sum() / torch.clamp( 78 | nonzero, clamp_thr 79 | ) # average error for all the sequence pixels 80 | epe_inv_accuracy_05px = (endpoint_error > 0.5).sum() / torch.clamp( 81 | nonzero, clamp_thr 82 | ) 83 | epe_inv_accuracy_1px = (endpoint_error > 1).sum() / torch.clamp( 84 | nonzero, clamp_thr 85 | ) 86 | epe_inv_accuracy_2px = (endpoint_error > 2).sum() / torch.clamp( 87 | nonzero, clamp_thr 88 | ) 89 | epe_inv_accuracy_3px = (endpoint_error > 3).sum() / torch.clamp( 90 | nonzero, clamp_thr 91 | ) 92 | 93 | results[f"{epe_name}_mean"] = epe_mean[None] 94 | results[f"{epe_name}_bad_0.5px"] = epe_inv_accuracy_05px[None] * 100 95 | results[f"{epe_name}_bad_1px"] = epe_inv_accuracy_1px[None] * 100 96 | results[f"{epe_name}_bad_2px"] = epe_inv_accuracy_2px[None] * 100 97 | results[f"{epe_name}_bad_3px"] = epe_inv_accuracy_3px[None] * 100 98 | return results 99 | 100 | 101 | def depth2disparity_scale(left_camera, right_camera, image_size_tensor): 102 | # # opencv camera matrices 103 | (_, T1, K1), (_, T2, _) = [ 104 | opencv_from_cameras_projection( 105 | f, 106 | image_size_tensor, 107 | ) 108 | for f in (left_camera, right_camera) 109 | ] 110 | fix_baseline = T1[0][0] - T2[0][0] 111 | focal_length_px = K1[0][0][0] 112 | # following this https://github.com/princeton-vl/RAFT-Stereo#converting-disparity-to-depth 113 | return focal_length_px * fix_baseline 114 | 115 | 116 | def depth_to_pcd( 117 | depth_map, 118 | img, 119 | focal_length, 120 | cx, 121 | cy, 122 | step: int = None, 123 | inv_extrinsic=None, 124 | mask=None, 125 | filter=False, 126 | ): 127 | __, w, __ = img.shape 128 | if step is None: 129 | step = int(w / 100) 130 | Z = depth_map[::step, 
::step] 131 | colors = img[::step, ::step, :] 132 | 133 | Pixels_Y = torch.arange(Z.shape[0]).to(Z.device) * step 134 | Pixels_X = torch.arange(Z.shape[1]).to(Z.device) * step 135 | 136 | X = (Pixels_X[None, :] - cx) * Z / focal_length 137 | Y = (Pixels_Y[:, None] - cy) * Z / focal_length 138 | 139 | inds = Z > 0 140 | 141 | if mask is not None: 142 | inds = inds * (mask[::step, ::step] > 0) 143 | 144 | X = X[inds].reshape(-1) 145 | Y = Y[inds].reshape(-1) 146 | Z = Z[inds].reshape(-1) 147 | colors = colors[inds] 148 | pcd = torch.stack([X, Y, Z]).T 149 | 150 | if inv_extrinsic is not None: 151 | pcd_ext = torch.vstack([pcd.T, torch.ones((1, pcd.shape[0])).to(Z.device)]) 152 | pcd = (inv_extrinsic @ pcd_ext)[:3, :].T 153 | 154 | if filter: 155 | pcd, filt_inds = filter_outliers(pcd) 156 | colors = colors[filt_inds] 157 | return pcd, colors 158 | 159 | 160 | def filter_outliers(pcd, sigma=3): 161 | mean = pcd.mean(0) 162 | std = pcd.std(0) 163 | inds = ((pcd - mean).abs() < sigma * std)[:, 2] 164 | pcd = pcd[inds] 165 | return pcd, inds 166 | 167 | 168 | def eval_batch(batch_dict, predictions) -> Dict[str, Union[float, torch.Tensor]]: 169 | """ 170 | Produce performance metrics for a single batch of perception 171 | predictions. 172 | Args: 173 | batch_dict: A dictionary with the input stereo video, ground-truth 174 | disparity and validity masks. 175 | predictions: A dictionary with the predicted disparity. 176 | Returns: 177 | results: A dictionary holding evaluation metrics. 178 | """ 179 | results = {} 180 | 181 | assert "disparity" in predictions 182 | mask_now = torch.ones_like(batch_dict["fg_mask"]) 183 | 184 | mask_now = mask_now * batch_dict["disparity_mask"] 185 | 186 | eval_flow_traj_output = eval_endpoint_error_sequence( 187 | predictions["disparity"], batch_dict["disparity"], mask_now 188 | ) 189 | for epe_name in ("epe", "temp_epe"): 190 | results[PerceptionMetric(f"disp_{epe_name}_mean")] = eval_flow_traj_output[ 191 | f"{epe_name}_mean" 192 | ] 193 | 194 | results[PerceptionMetric(f"disp_{epe_name}_bad_3px")] = eval_flow_traj_output[ 195 | f"{epe_name}_bad_3px" 196 | ] 197 | 198 | results[PerceptionMetric(f"disp_{epe_name}_bad_2px")] = eval_flow_traj_output[ 199 | f"{epe_name}_bad_2px" 200 | ] 201 | 202 | results[PerceptionMetric(f"disp_{epe_name}_bad_1px")] = eval_flow_traj_output[ 203 | f"{epe_name}_bad_1px" 204 | ] 205 | 206 | results[PerceptionMetric(f"disp_{epe_name}_bad_0.5px")] = eval_flow_traj_output[ 207 | f"{epe_name}_bad_0.5px" 208 | ] 209 | if "endpoint_error_per_pixel" in eval_flow_traj_output: 210 | results["disp_endpoint_error_per_pixel"] = eval_flow_traj_output[ 211 | "endpoint_error_per_pixel" 212 | ] 213 | return (results, len(predictions["disparity"])) 214 | -------------------------------------------------------------------------------- /evaluation/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | 7 | from collections import defaultdict 8 | import configparser 9 | import os 10 | import math 11 | from typing import Optional, List 12 | import torch 13 | import cv2 14 | import numpy as np 15 | from dataclasses import dataclass 16 | from tabulate import tabulate 17 | 18 | 19 | from pytorch3d.structures import Pointclouds 20 | from pytorch3d.transforms import RotateAxisAngle 21 | from pytorch3d.utils import ( 22 | opencv_from_cameras_projection, 23 | ) 24 | from pytorch3d.renderer import ( 25 | AlphaCompositor, 26 | PointsRasterizationSettings, 27 | PointsRasterizer, 28 | PointsRenderer, 29 | ) 30 | from dynamic_stereo.evaluation.utils.eval_utils import depth_to_pcd 31 | 32 | 33 | @dataclass 34 | class PerceptionPrediction: 35 | """ 36 | Holds the tensors that describe a result of any perception module. 37 | """ 38 | 39 | depth_map: Optional[torch.Tensor] = None 40 | disparity: Optional[torch.Tensor] = None 41 | image_rgb: Optional[torch.Tensor] = None 42 | fg_probability: Optional[torch.Tensor] = None 43 | 44 | 45 | def aggregate_eval_results(per_batch_eval_results, reduction="mean"): 46 | 47 | total_length = 0 48 | aggregate_results = defaultdict(list) 49 | for result in per_batch_eval_results: 50 | if isinstance(result, tuple): 51 | reduction = "sum" 52 | length = result[1] 53 | total_length += length 54 | result = result[0] 55 | for metric, val in result.items(): 56 | if reduction == "sum": 57 | aggregate_results[metric].append(val * length) 58 | 59 | if reduction == "mean": 60 | return {k: torch.cat(v).mean().item() for k, v in aggregate_results.items()} 61 | elif reduction == "sum": 62 | return { 63 | k: torch.cat(v).sum().item() / float(total_length) 64 | for k, v in aggregate_results.items() 65 | } 66 | 67 | 68 | def aggregate_and_print_results( 69 | per_batch_eval_results: List[dict], 70 | ): 71 | print("") 72 | result = aggregate_eval_results( 73 | per_batch_eval_results, 74 | ) 75 | pretty_print_perception_metrics(result) 76 | result = {str(k): v for k, v in result.items()} 77 | 78 | print("") 79 | return result 80 | 81 | 82 | def pretty_print_perception_metrics(results): 83 | 84 | metrics = sorted(list(results.keys()), key=lambda x: x.metric) 85 | 86 | print("===== Perception results =====") 87 | print( 88 | tabulate( 89 | [[metric, results[metric]] for metric in metrics], 90 | ) 91 | ) 92 | 93 | 94 | def read_calibration(calibration_file, resolution_string): 95 | # ported from https://github.com/stereolabs/zed-open-capture/ 96 | # blob/dfa0aee51ccd2297782230a05ca59e697df496b2/examples/include/calibration.hpp#L4172 97 | 98 | zed_resolutions = { 99 | "2K": (1242, 2208), 100 | "FHD": (1080, 1920), 101 | "HD": (720, 1280), 102 | # "qHD": (540, 960), 103 | "VGA": (376, 672), 104 | } 105 | assert resolution_string in zed_resolutions.keys() 106 | image_height, image_width = zed_resolutions[resolution_string] 107 | 108 | # Open camera configuration file 109 | assert os.path.isfile(calibration_file) 110 | calib = configparser.ConfigParser() 111 | calib.read(calibration_file) 112 | 113 | # Get translations 114 | T = np.zeros((3, 1)) 115 | T[0] = float(calib["STEREO"]["baseline"]) 116 | T[1] = float(calib["STEREO"]["ty"]) 117 | T[2] = float(calib["STEREO"]["tz"]) 118 | 119 | baseline = T[0] 120 | 121 | # Get left parameters 122 | left_cam_cx = float(calib[f"LEFT_CAM_{resolution_string}"]["cx"]) 123 | left_cam_cy = float(calib[f"LEFT_CAM_{resolution_string}"]["cy"]) 124 | left_cam_fx = float(calib[f"LEFT_CAM_{resolution_string}"]["fx"]) 125 | left_cam_fy = 
float(calib[f"LEFT_CAM_{resolution_string}"]["fy"]) 126 | left_cam_k1 = float(calib[f"LEFT_CAM_{resolution_string}"]["k1"]) 127 | left_cam_k2 = float(calib[f"LEFT_CAM_{resolution_string}"]["k2"]) 128 | left_cam_p1 = float(calib[f"LEFT_CAM_{resolution_string}"]["p1"]) 129 | left_cam_p2 = float(calib[f"LEFT_CAM_{resolution_string}"]["p2"]) 130 | left_cam_k3 = float(calib[f"LEFT_CAM_{resolution_string}"]["k3"]) 131 | 132 | # Get right parameters 133 | right_cam_cx = float(calib[f"RIGHT_CAM_{resolution_string}"]["cx"]) 134 | right_cam_cy = float(calib[f"RIGHT_CAM_{resolution_string}"]["cy"]) 135 | right_cam_fx = float(calib[f"RIGHT_CAM_{resolution_string}"]["fx"]) 136 | right_cam_fy = float(calib[f"RIGHT_CAM_{resolution_string}"]["fy"]) 137 | right_cam_k1 = float(calib[f"RIGHT_CAM_{resolution_string}"]["k1"]) 138 | right_cam_k2 = float(calib[f"RIGHT_CAM_{resolution_string}"]["k2"]) 139 | right_cam_p1 = float(calib[f"RIGHT_CAM_{resolution_string}"]["p1"]) 140 | right_cam_p2 = float(calib[f"RIGHT_CAM_{resolution_string}"]["p2"]) 141 | right_cam_k3 = float(calib[f"RIGHT_CAM_{resolution_string}"]["k3"]) 142 | 143 | # Get rotations 144 | R_zed = np.zeros(3) 145 | R_zed[0] = float(calib["STEREO"][f"rx_{resolution_string.lower()}"]) 146 | R_zed[1] = float(calib["STEREO"][f"cv_{resolution_string.lower()}"]) 147 | R_zed[2] = float(calib["STEREO"][f"rz_{resolution_string.lower()}"]) 148 | 149 | R = cv2.Rodrigues(R_zed)[0] 150 | 151 | # Left 152 | cameraMatrix_left = np.array( 153 | [[left_cam_fx, 0, left_cam_cx], [0, left_cam_fy, left_cam_cy], [0, 0, 1]] 154 | ) 155 | distCoeffs_left = np.array( 156 | [left_cam_k1, left_cam_k2, left_cam_p1, left_cam_p2, left_cam_k3] 157 | ) 158 | 159 | # Right 160 | cameraMatrix_right = np.array( 161 | [ 162 | [right_cam_fx, 0, right_cam_cx], 163 | [0, right_cam_fy, right_cam_cy], 164 | [0, 0, 1], 165 | ] 166 | ) 167 | distCoeffs_right = np.array( 168 | [right_cam_k1, right_cam_k2, right_cam_p1, right_cam_p2, right_cam_k3] 169 | ) 170 | 171 | # Stereo 172 | R1, R2, P1, P2, Q = cv2.stereoRectify( 173 | cameraMatrix1=cameraMatrix_left, 174 | distCoeffs1=distCoeffs_left, 175 | cameraMatrix2=cameraMatrix_right, 176 | distCoeffs2=distCoeffs_right, 177 | imageSize=(image_width, image_height), 178 | R=R, 179 | T=T, 180 | flags=cv2.CALIB_ZERO_DISPARITY, 181 | newImageSize=(image_width, image_height), 182 | alpha=0, 183 | )[:5] 184 | 185 | # Precompute maps for cv::remap() 186 | map_left_x, map_left_y = cv2.initUndistortRectifyMap( 187 | cameraMatrix_left, 188 | distCoeffs_left, 189 | R1, 190 | P1, 191 | (image_width, image_height), 192 | cv2.CV_32FC1, 193 | ) 194 | map_right_x, map_right_y = cv2.initUndistortRectifyMap( 195 | cameraMatrix_right, 196 | distCoeffs_right, 197 | R2, 198 | P2, 199 | (image_width, image_height), 200 | cv2.CV_32FC1, 201 | ) 202 | 203 | zed_calib = { 204 | "map_left_x": map_left_x, 205 | "map_left_y": map_left_y, 206 | "map_right_x": map_right_x, 207 | "map_right_y": map_right_y, 208 | "pose_left": P1, 209 | "pose_right": P2, 210 | "baseline": baseline, 211 | "image_width": image_width, 212 | "image_height": image_height, 213 | } 214 | 215 | return zed_calib 216 | 217 | 218 | def visualize_batch( 219 | batch_dict: dict, 220 | preds: PerceptionPrediction, 221 | output_dir: str, 222 | ref_frame: int = 0, 223 | only_foreground=False, 224 | step=0, 225 | sequence_name=None, 226 | writer=None, 227 | ): 228 | os.makedirs(output_dir, exist_ok=True) 229 | 230 | outputs = {} 231 | 232 | if preds.depth_map is not None: 233 | device = preds.depth_map.device 234 | 
235 | pcd_global_seq = [] 236 | H, W = batch_dict["stereo_video"].shape[3:] 237 | 238 | for i in range(len(batch_dict["stereo_video"])): 239 | R, T, K = opencv_from_cameras_projection( 240 | preds.perspective_cameras[i], 241 | torch.tensor([H, W])[None].to(device), 242 | ) 243 | 244 | extrinsic_3x4_0 = torch.cat([R[0], T[0, :, None]], dim=1) 245 | 246 | extr_matrix = torch.cat( 247 | [ 248 | extrinsic_3x4_0, 249 | torch.Tensor([[0, 0, 0, 1]]).to(extrinsic_3x4_0.device), 250 | ], 251 | dim=0, 252 | ) 253 | 254 | inv_extr_matrix = extr_matrix.inverse().to(device) 255 | pcd, colors = depth_to_pcd( 256 | preds.depth_map[i, 0], 257 | batch_dict["stereo_video"][i][0].permute(1, 2, 0), 258 | K[0][0][0], 259 | K[0][0][2], 260 | K[0][1][2], 261 | step=1, 262 | inv_extrinsic=inv_extr_matrix, 263 | mask=batch_dict["fg_mask"][i, 0] if only_foreground else None, 264 | filter=False, 265 | ) 266 | 267 | R, T = inv_extr_matrix[None, :3, :3], inv_extr_matrix[None, :3, 3] 268 | pcd_global_seq.append((pcd, colors, (R, T, preds.perspective_cameras[i]))) 269 | 270 | raster_settings = PointsRasterizationSettings( 271 | image_size=[H, W], radius=0.003, points_per_pixel=10 272 | ) 273 | R, T, cam_ = pcd_global_seq[ref_frame][2] 274 | 275 | median_depth = preds.depth_map.median() 276 | cam_.cuda() 277 | 278 | for mode in ["angle_15", "angle_-15", "changing_angle"]: 279 | res = [] 280 | 281 | for t, (pcd, color, __) in enumerate(pcd_global_seq): 282 | 283 | if mode == "changing_angle": 284 | angle = math.cos((math.pi) * (t / 15)) * 15 285 | elif mode == "angle_15": 286 | angle = 15 287 | elif mode == "angle_-15": 288 | angle = -15 289 | 290 | delta_x = median_depth * math.sin(math.radians(angle)) 291 | delta_z = median_depth * (1 - math.cos(math.radians(angle))) 292 | 293 | cam = cam_.clone() 294 | cam.R = torch.bmm( 295 | cam.R, 296 | RotateAxisAngle(angle=angle, axis="Y", device=device).get_matrix()[ 297 | :, :3, :3 298 | ], 299 | ) 300 | cam.T[0, 0] = cam.T[0, 0] - delta_x 301 | cam.T[0, 2] = cam.T[0, 2] - delta_z + median_depth / 2.0 302 | 303 | rasterizer = PointsRasterizer( 304 | cameras=cam, raster_settings=raster_settings 305 | ) 306 | renderer = PointsRenderer( 307 | rasterizer=rasterizer, 308 | compositor=AlphaCompositor(background_color=(1, 1, 1)), 309 | ) 310 | pcd_copy = pcd.clone() 311 | 312 | point_cloud = Pointclouds(points=[pcd_copy], features=[color / 255.0]) 313 | images = renderer(point_cloud) 314 | res.append(images[0, ..., :3].cpu()) 315 | res = torch.stack(res) 316 | 317 | video = (res * 255).numpy().astype(np.uint8) 318 | save_name = f"{sequence_name}_reconstruction_{step}_mode_{mode}_" 319 | if writer is None: 320 | outputs[mode] = video 321 | if only_foreground: 322 | save_name += "fg_only" 323 | else: 324 | save_name += "full_scene" 325 | video_out = cv2.VideoWriter( 326 | os.path.join( 327 | output_dir, 328 | f"{save_name}.mp4", 329 | ), 330 | cv2.VideoWriter_fourcc(*"mp4v"), 331 | fps=10, 332 | frameSize=(res.shape[2], res.shape[1]), 333 | isColor=True, 334 | ) 335 | 336 | for i in range(len(video)): 337 | video_out.write(cv2.cvtColor(video[i], cv2.COLOR_BGR2RGB)) 338 | video_out.release() 339 | 340 | if writer is not None: 341 | writer.add_video( 342 | f"{sequence_name}_reconstruction_mode_{mode}", 343 | (res * 255).permute(0, 3, 1, 2).to(torch.uint8)[None], 344 | global_step=step, 345 | fps=8, 346 | ) 347 | 348 | return outputs 349 | -------------------------------------------------------------------------------- /models/core/attention.py: 
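Before moving on to the model code, a toy check of the length-weighted aggregation performed by aggregate_eval_results in the file above (the metric values and sequence lengths are invented for illustration, and the snippet mirrors the reduction="sum" path rather than calling the repository function):

```
import torch

# Two (metrics, seq_length) tuples, as returned by Evaluator.evaluate_sequence.
per_batch = [
    ({"disp_epe_mean": torch.tensor([1.0])}, 30),
    ({"disp_epe_mean": torch.tensor([2.0])}, 10),
]

total_length = sum(length for _, length in per_batch)
weighted = sum(res["disp_epe_mean"] * length for res, length in per_batch) / total_length
print(weighted.item())  # (1.0 * 30 + 2.0 * 10) / 40 = 1.25
```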
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | import copy 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn import Module, Dropout 12 | 13 | """ 14 | Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" 15 | Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py 16 | """ 17 | 18 | 19 | def elu_feature_map(x): 20 | return torch.nn.functional.elu(x) + 1 21 | 22 | 23 | class PositionEncodingSine(nn.Module): 24 | """ 25 | This is a sinusoidal position encoding that generalized to 2-dimensional images 26 | """ 27 | 28 | def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True): 29 | """ 30 | Args: 31 | max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels 32 | temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41), 33 | the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact 34 | on the final performance. For now, we keep both impls for backward compatability. 35 | We will remove the buggy impl after re-training all variants of our released models. 36 | """ 37 | super().__init__() 38 | pe = torch.zeros((d_model, *max_shape)) 39 | y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0) 40 | x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0) 41 | if temp_bug_fix: 42 | div_term = torch.exp( 43 | torch.arange(0, d_model // 2, 2).float() 44 | * (-math.log(10000.0) / (d_model // 2)) 45 | ) 46 | else: # a buggy implementation (for backward compatability only) 47 | div_term = torch.exp( 48 | torch.arange(0, d_model // 2, 2).float() 49 | * (-math.log(10000.0) / d_model // 2) 50 | ) 51 | div_term = div_term[:, None, None] # [C//4, 1, 1] 52 | pe[0::4, :, :] = torch.sin(x_position * div_term) 53 | pe[1::4, :, :] = torch.cos(x_position * div_term) 54 | pe[2::4, :, :] = torch.sin(y_position * div_term) 55 | pe[3::4, :, :] = torch.cos(y_position * div_term) 56 | 57 | self.register_buffer("pe", pe.unsqueeze(0), persistent=False) # [1, C, H, W] 58 | 59 | def forward(self, x): 60 | """ 61 | Args: 62 | x: [N, C, H, W] 63 | """ 64 | return x + self.pe[:, :, : x.size(2), : x.size(3)].to(x.device) 65 | 66 | 67 | class LinearAttention(Module): 68 | def __init__(self, eps=1e-6): 69 | super().__init__() 70 | self.feature_map = elu_feature_map 71 | self.eps = eps 72 | 73 | def forward(self, queries, keys, values, q_mask=None, kv_mask=None): 74 | """Multi-Head linear attention proposed in "Transformers are RNNs" 75 | Args: 76 | queries: [N, L, H, D] 77 | keys: [N, S, H, D] 78 | values: [N, S, H, D] 79 | q_mask: [N, L] 80 | kv_mask: [N, S] 81 | Returns: 82 | queried_values: (N, L, H, D) 83 | """ 84 | Q = self.feature_map(queries) 85 | K = self.feature_map(keys) 86 | 87 | # set padded position to zero 88 | if q_mask is not None: 89 | Q = Q * q_mask[:, :, None, None] 90 | if kv_mask is not None: 91 | K = K * kv_mask[:, :, None, None] 92 | values = values * kv_mask[:, :, None, None] 93 | 94 | v_length = values.size(1) 95 | values = values / v_length # prevent fp16 overflow 96 | KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V 97 | Z = 1 / (torch.einsum("nlhd,nhd->nlh", 
Q, K.sum(dim=1)) + self.eps) 98 | queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length 99 | 100 | return queried_values.contiguous() 101 | 102 | 103 | class FullAttention(Module): 104 | def __init__(self, use_dropout=False, attention_dropout=0.1): 105 | super().__init__() 106 | self.use_dropout = use_dropout 107 | self.dropout = Dropout(attention_dropout) 108 | 109 | def forward(self, queries, keys, values, q_mask=None, kv_mask=None): 110 | """Multi-head scaled dot-product attention, a.k.a full attention. 111 | Args: 112 | queries: [N, L, H, D] 113 | keys: [N, S, H, D] 114 | values: [N, S, H, D] 115 | q_mask: [N, L] 116 | kv_mask: [N, S] 117 | Returns: 118 | queried_values: (N, L, H, D) 119 | """ 120 | 121 | # Compute the unnormalized attention and apply the masks 122 | QK = torch.einsum("nlhd,nshd->nlsh", queries, keys) 123 | if kv_mask is not None: 124 | QK.masked_fill_( 125 | ~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float("-inf") 126 | ) 127 | 128 | # Compute the attention and the weighted average 129 | softmax_temp = 1.0 / queries.size(3) ** 0.5 # sqrt(D) 130 | A = torch.softmax(softmax_temp * QK, dim=2) 131 | if self.use_dropout: 132 | A = self.dropout(A) 133 | 134 | queried_values = torch.einsum("nlsh,nshd->nlhd", A, values) 135 | 136 | return queried_values.contiguous() 137 | 138 | 139 | # Ref: https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py 140 | class LoFTREncoderLayer(nn.Module): 141 | def __init__(self, d_model, nhead, attention="linear"): 142 | super(LoFTREncoderLayer, self).__init__() 143 | 144 | self.dim = d_model // nhead 145 | self.nhead = nhead 146 | 147 | # multi-head attention 148 | self.q_proj = nn.Linear(d_model, d_model, bias=False) 149 | self.k_proj = nn.Linear(d_model, d_model, bias=False) 150 | self.v_proj = nn.Linear(d_model, d_model, bias=False) 151 | self.attention = LinearAttention() if attention == "linear" else FullAttention() 152 | self.merge = nn.Linear(d_model, d_model, bias=False) 153 | 154 | # feed-forward network 155 | self.mlp = nn.Sequential( 156 | nn.Linear(d_model * 2, d_model * 2, bias=False), 157 | nn.ReLU(), 158 | nn.Linear(d_model * 2, d_model, bias=False), 159 | ) 160 | 161 | # norm and dropout 162 | self.norm1 = nn.LayerNorm(d_model) 163 | self.norm2 = nn.LayerNorm(d_model) 164 | 165 | def forward(self, x, source, x_mask=None, source_mask=None): 166 | """ 167 | Args: 168 | x (torch.Tensor): [N, L, C] 169 | source (torch.Tensor): [N, S, C] 170 | x_mask (torch.Tensor): [N, L] (optional) 171 | source_mask (torch.Tensor): [N, S] (optional) 172 | """ 173 | bs = x.size(0) 174 | query, key, value = x, source, source 175 | 176 | # multi-head attention 177 | query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)] 178 | key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)] 179 | value = self.v_proj(value).view(bs, -1, self.nhead, self.dim) 180 | message = self.attention( 181 | query, key, value, q_mask=x_mask, kv_mask=source_mask 182 | ) # [N, L, (H, D)] 183 | message = self.merge(message.view(bs, -1, self.nhead * self.dim)) # [N, L, C] 184 | message = self.norm1(message) 185 | 186 | # feed-forward network 187 | message = self.mlp(torch.cat([x, message], dim=2)) 188 | message = self.norm2(message) 189 | 190 | return x + message 191 | 192 | 193 | class LocalFeatureTransformer(nn.Module): 194 | """A Local Feature Transformer (LoFTR) module.""" 195 | 196 | def __init__(self, d_model, nhead, layer_names, attention): 197 | 
super(LocalFeatureTransformer, self).__init__() 198 | 199 | self.d_model = d_model 200 | self.nhead = nhead 201 | self.layer_names = layer_names 202 | encoder_layer = LoFTREncoderLayer(d_model, nhead, attention) 203 | self.layers = nn.ModuleList( 204 | [copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))] 205 | ) 206 | self._reset_parameters() 207 | 208 | def _reset_parameters(self): 209 | for p in self.parameters(): 210 | if p.dim() > 1: 211 | nn.init.xavier_uniform_(p) 212 | 213 | def forward(self, feat0, feat1, mask0=None, mask1=None): 214 | """ 215 | Args: 216 | feat0 (torch.Tensor): [N, L, C] 217 | feat1 (torch.Tensor): [N, S, C] 218 | mask0 (torch.Tensor): [N, L] (optional) 219 | mask1 (torch.Tensor): [N, S] (optional) 220 | """ 221 | assert self.d_model == feat0.size( 222 | 2 223 | ), "the feature number of src and transformer must be equal" 224 | 225 | for layer, name in zip(self.layers, self.layer_names): 226 | 227 | if name == "self": 228 | feat0 = layer(feat0, feat0, mask0, mask0) 229 | feat1 = layer(feat1, feat1, mask1, mask1) 230 | elif name == "cross": 231 | feat0 = layer(feat0, feat1, mask0, mask1) 232 | feat1 = layer(feat1, feat0, mask1, mask0) 233 | else: 234 | raise KeyError 235 | 236 | return feat0, feat1 237 | -------------------------------------------------------------------------------- /models/core/corr.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | 11 | def bilinear_sampler(img, coords, mode="bilinear", mask=False, stereo=True): 12 | """Wrapper for grid_sample, uses pixel coordinates""" 13 | H, W = img.shape[-2:] 14 | xgrid, ygrid = coords.split([1, 1], dim=-1) 15 | xgrid = 2 * xgrid / (W - 1) - 1 16 | if not stereo: 17 | ygrid = 2 * ygrid / (H - 1) - 1 18 | else: 19 | assert torch.unique(ygrid).numel() == 1 and H == 1 # This is a stereo problem 20 | img = img.contiguous() 21 | grid = torch.cat([xgrid, ygrid], dim=-1).contiguous() 22 | img = F.grid_sample(img, grid, align_corners=True) 23 | 24 | if mask: 25 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 26 | return img, mask.float() 27 | 28 | return img 29 | 30 | 31 | def coords_grid(batch, ht, wd, device): 32 | coords = torch.meshgrid( 33 | torch.arange(ht, device=device), torch.arange(wd, device=device), indexing="ij" 34 | ) 35 | coords = torch.stack(coords[::-1], dim=0).float() 36 | return coords[None].repeat(batch, 1, 1, 1) 37 | 38 | 39 | class CorrBlock1D: 40 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 41 | self.num_levels = num_levels 42 | self.radius = radius 43 | self.corr_pyramid = [] 44 | self.coords = coords_grid( 45 | fmap1.shape[0], fmap1.shape[2], fmap1.shape[3], fmap1.device 46 | ) 47 | # all pairs correlation 48 | corr = CorrBlock1D.corr(fmap1, fmap2) 49 | 50 | batch, h1, w1, dim, w2 = corr.shape 51 | corr = corr.reshape(batch * h1 * w1, dim, 1, w2) 52 | 53 | self.corr_pyramid.append(corr) 54 | for i in range(self.num_levels): 55 | corr = F.avg_pool2d(corr, [1, 2], stride=[1, 2]) 56 | self.corr_pyramid.append(corr) 57 | 58 | def __call__(self, flow): 59 | r = self.radius 60 | coords = self.coords + flow 61 | coords = coords[:, :1].permute(0, 2, 3, 1) 62 | batch, h1, w1, _ = coords.shape 63 | 64 | out_pyramid = [] 65 | for i in 
range(self.num_levels): 66 | corr = self.corr_pyramid[i] 67 | dx = torch.linspace(-r, r, 2 * r + 1) 68 | dx = dx.view(1, 1, 2 * r + 1, 1).to(coords.device) 69 | x0 = dx + coords.reshape(batch * h1 * w1, 1, 1, 1) / 2 ** i 70 | y0 = torch.zeros_like(x0) 71 | 72 | coords_lvl = torch.cat([x0, y0], dim=-1) 73 | corr = bilinear_sampler(corr, coords_lvl) 74 | corr = corr.view(batch, h1, w1, -1) 75 | out_pyramid.append(corr) 76 | 77 | out = torch.cat(out_pyramid, dim=-1) 78 | return out.permute(0, 3, 1, 2).contiguous().float() 79 | 80 | @staticmethod 81 | def corr(fmap1, fmap2): 82 | B, D, H, W1 = fmap1.shape 83 | _, _, _, W2 = fmap2.shape 84 | fmap1 = fmap1.view(B, D, H, W1) 85 | fmap2 = fmap2.view(B, D, H, W2) 86 | corr = torch.einsum("aijk,aijh->ajkh", fmap1, fmap2) 87 | corr = corr.reshape(B, H, W1, 1, W2).contiguous() 88 | return corr / torch.sqrt(torch.tensor(D).float()) 89 | -------------------------------------------------------------------------------- /models/core/dynamic_stereo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Dict, List 8 | from einops import rearrange 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from collections import defaultdict 13 | 14 | 15 | from dynamic_stereo.models.core.update import ( 16 | BasicUpdateBlock, 17 | SequenceUpdateBlock3D, 18 | TimeAttnBlock, 19 | ) 20 | from dynamic_stereo.models.core.extractor import BasicEncoder 21 | from dynamic_stereo.models.core.corr import CorrBlock1D 22 | 23 | from dynamic_stereo.models.core.attention import ( 24 | PositionEncodingSine, 25 | LocalFeatureTransformer, 26 | ) 27 | from dynamic_stereo.models.core.utils.utils import InputPadder, interp 28 | 29 | autocast = torch.cuda.amp.autocast 30 | 31 | 32 | class DynamicStereo(nn.Module): 33 | def __init__( 34 | self, 35 | max_disp: int = 192, 36 | mixed_precision: bool = False, 37 | num_frames: int = 5, 38 | attention_type: str = None, 39 | use_3d_update_block: bool = False, 40 | different_update_blocks: bool = False, 41 | ): 42 | super(DynamicStereo, self).__init__() 43 | 44 | self.max_flow = max_disp 45 | self.mixed_precision = mixed_precision 46 | 47 | self.hidden_dim = 128 48 | self.context_dim = 128 49 | dim = 256 50 | self.dim = dim 51 | self.dropout = 0 52 | self.use_3d_update_block = use_3d_update_block 53 | self.fnet = BasicEncoder( 54 | output_dim=dim, norm_fn="instance", dropout=self.dropout 55 | ) 56 | self.different_update_blocks = different_update_blocks 57 | cor_planes = 4 * 9 58 | self.depth = 4 59 | self.attention_type = attention_type 60 | # attention_type is a combination of the following attention types: 61 | # self_stereo, temporal, update_time, update_space 62 | # for example, self_stereo_temporal_update_time_update_space 63 | 64 | if self.use_3d_update_block: 65 | if self.different_update_blocks: 66 | self.update_block08 = SequenceUpdateBlock3D( 67 | hidden_dim=self.hidden_dim, cor_planes=cor_planes, mask_size=4 68 | ) 69 | self.update_block16 = SequenceUpdateBlock3D( 70 | hidden_dim=self.hidden_dim, 71 | cor_planes=cor_planes, 72 | mask_size=4, 73 | attention_type=attention_type, 74 | ) 75 | self.update_block04 = SequenceUpdateBlock3D( 76 | hidden_dim=self.hidden_dim, cor_planes=cor_planes, mask_size=4 77 | ) 78 | else: 79 | 
self.update_block = SequenceUpdateBlock3D( 80 | hidden_dim=self.hidden_dim, cor_planes=cor_planes, mask_size=4 81 | ) 82 | else: 83 | if self.different_update_blocks: 84 | self.update_block08 = BasicUpdateBlock( 85 | hidden_dim=self.hidden_dim, cor_planes=cor_planes, mask_size=4 86 | ) 87 | self.update_block16 = BasicUpdateBlock( 88 | hidden_dim=self.hidden_dim, 89 | cor_planes=cor_planes, 90 | mask_size=4, 91 | attention_type=attention_type, 92 | ) 93 | self.update_block04 = BasicUpdateBlock( 94 | hidden_dim=self.hidden_dim, cor_planes=cor_planes, mask_size=4 95 | ) 96 | else: 97 | self.update_block = BasicUpdateBlock( 98 | hidden_dim=self.hidden_dim, cor_planes=cor_planes, mask_size=4 99 | ) 100 | 101 | if attention_type is not None: 102 | if ("update_time" in attention_type) or ("temporal" in attention_type): 103 | self.time_embed = nn.Parameter(torch.zeros(1, num_frames, dim)) 104 | if "temporal" in attention_type: 105 | self.time_attn_blocks = nn.ModuleList( 106 | [TimeAttnBlock(dim=dim, num_heads=8) for _ in range(self.depth)] 107 | ) 108 | 109 | if "self_stereo" in attention_type: 110 | self.self_attn_blocks = nn.ModuleList( 111 | [ 112 | LocalFeatureTransformer( 113 | d_model=dim, 114 | nhead=8, 115 | layer_names=["self"] * 1, 116 | attention="linear", 117 | ) 118 | for _ in range(self.depth) 119 | ] 120 | ) 121 | 122 | self.cross_attn_blocks = nn.ModuleList( 123 | [ 124 | LocalFeatureTransformer( 125 | d_model=dim, 126 | nhead=8, 127 | layer_names=["cross"] * 1, 128 | attention="linear", 129 | ) 130 | for _ in range(self.depth) 131 | ] 132 | ) 133 | 134 | self.num_frames = num_frames 135 | 136 | @torch.jit.ignore 137 | def no_weight_decay(self): 138 | return {"time_embed"} 139 | 140 | def freeze_bn(self): 141 | for m in self.modules(): 142 | if isinstance(m, nn.BatchNorm2d): 143 | m.eval() 144 | 145 | def convex_upsample(self, flow: torch.Tensor, mask: torch.Tensor, rate: int = 4): 146 | """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination""" 147 | N, _, H, W = flow.shape 148 | mask = mask.view(N, 1, 9, rate, rate, H, W) 149 | mask = torch.softmax(mask, dim=2) 150 | 151 | up_flow = F.unfold(rate * flow, [3, 3], padding=1) 152 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 153 | 154 | up_flow = torch.sum(mask * up_flow, dim=2) 155 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 156 | return up_flow.reshape(N, 2, rate * H, rate * W) 157 | 158 | def zero_init(self, fmap: torch.Tensor): 159 | N, _, H, W = fmap.shape 160 | _x = torch.zeros([N, 1, H, W], dtype=torch.float32) 161 | _y = torch.zeros([N, 1, H, W], dtype=torch.float32) 162 | zero_flow = torch.cat((_x, _y), dim=1).to(fmap.device) 163 | return zero_flow 164 | 165 | def forward_batch_test( 166 | self, batch_dict: Dict, kernel_size: int = 14, iters: int = 20 167 | ): 168 | stride = kernel_size // 2 169 | predictions = defaultdict(list) 170 | 171 | disp_preds = [] 172 | video = batch_dict["stereo_video"] 173 | num_ims = len(video) 174 | print("video", video.shape) 175 | 176 | for i in range(0, num_ims, stride): 177 | left_ims = video[i : min(i + kernel_size, num_ims), 0] 178 | padder = InputPadder(left_ims.shape, divis_by=32) 179 | 180 | right_ims = video[i : min(i + kernel_size, num_ims), 1] 181 | left_ims, right_ims = padder.pad(left_ims, right_ims) 182 | 183 | with autocast(enabled=self.mixed_precision): 184 | disparities_forw = self.forward( 185 | left_ims[None].cuda(), 186 | right_ims[None].cuda(), 187 | iters=iters, 188 | test_mode=True, 189 | ) 190 | 191 | disparities_forw = 
padder.unpad(disparities_forw[:, 0])[:, None].cpu() 192 | 193 | if len(disp_preds) > 0 and len(disparities_forw) >= stride: 194 | 195 | if len(disparities_forw) < kernel_size: 196 | disp_preds.append(disparities_forw[stride // 2 :]) 197 | else: 198 | disp_preds.append(disparities_forw[stride // 2 : -stride // 2]) 199 | 200 | elif len(disp_preds) == 0: 201 | disp_preds.append(disparities_forw[: -stride // 2]) 202 | 203 | predictions["disparity"] = (torch.cat(disp_preds).squeeze(1).abs())[:, :1] 204 | print(predictions["disparity"].shape) 205 | 206 | return predictions 207 | 208 | def forward_sst_block( 209 | self, fmap1_dw16: torch.Tensor, fmap2_dw16: torch.Tensor, T: int 210 | ): 211 | *_, h, w = fmap1_dw16.shape 212 | 213 | # positional encoding and self-attention 214 | pos_encoding_fn_small = PositionEncodingSine(d_model=self.dim, max_shape=(h, w)) 215 | # 'n c h w -> n (h w) c' 216 | fmap1_dw16 = pos_encoding_fn_small(fmap1_dw16) 217 | # 'n c h w -> n (h w) c' 218 | fmap2_dw16 = pos_encoding_fn_small(fmap2_dw16) 219 | 220 | if self.attention_type is not None: 221 | # add time embeddings 222 | if ( 223 | "temporal" in self.attention_type 224 | or "update_time" in self.attention_type 225 | ): 226 | fmap1_dw16 = rearrange( 227 | fmap1_dw16, "(b t) m h w -> (b h w) t m", t=T, h=h, w=w 228 | ) 229 | fmap2_dw16 = rearrange( 230 | fmap2_dw16, "(b t) m h w -> (b h w) t m", t=T, h=h, w=w 231 | ) 232 | 233 | # interpolate if video length doesn't match 234 | if T != self.num_frames: 235 | time_embed = self.time_embed.transpose(1, 2) 236 | new_time_embed = F.interpolate(time_embed, size=(T), mode="nearest") 237 | new_time_embed = new_time_embed.transpose(1, 2).contiguous() 238 | else: 239 | new_time_embed = self.time_embed 240 | 241 | fmap1_dw16 = fmap1_dw16 + new_time_embed 242 | fmap2_dw16 = fmap2_dw16 + new_time_embed 243 | 244 | fmap1_dw16 = rearrange( 245 | fmap1_dw16, "(b h w) t m -> (b t) m h w", t=T, h=h, w=w 246 | ) 247 | fmap2_dw16 = rearrange( 248 | fmap2_dw16, "(b h w) t m -> (b t) m h w", t=T, h=h, w=w 249 | ) 250 | 251 | if ("self_stereo" in self.attention_type) or ( 252 | "temporal" in self.attention_type 253 | ): 254 | for att_ind in range(self.depth): 255 | if "self_stereo" in self.attention_type: 256 | fmap1_dw16 = rearrange( 257 | fmap1_dw16, "(b t) m h w -> (b t) (h w) m", t=T, h=h, w=w 258 | ) 259 | fmap2_dw16 = rearrange( 260 | fmap2_dw16, "(b t) m h w -> (b t) (h w) m", t=T, h=h, w=w 261 | ) 262 | 263 | fmap1_dw16, fmap2_dw16 = self.self_attn_blocks[att_ind]( 264 | fmap1_dw16, fmap2_dw16 265 | ) 266 | fmap1_dw16, fmap2_dw16 = self.cross_attn_blocks[att_ind]( 267 | fmap1_dw16, fmap2_dw16 268 | ) 269 | 270 | fmap1_dw16 = rearrange( 271 | fmap1_dw16, "(b t) (h w) m -> (b t) m h w ", t=T, h=h, w=w 272 | ) 273 | fmap2_dw16 = rearrange( 274 | fmap2_dw16, "(b t) (h w) m -> (b t) m h w ", t=T, h=h, w=w 275 | ) 276 | 277 | if "temporal" in self.attention_type: 278 | fmap1_dw16 = self.time_attn_blocks[att_ind](fmap1_dw16, T=T) 279 | fmap2_dw16 = self.time_attn_blocks[att_ind](fmap2_dw16, T=T) 280 | return fmap1_dw16, fmap2_dw16 281 | 282 | def forward_update_block( 283 | self, 284 | update_block: nn.Module, 285 | corr_fn: CorrBlock1D, 286 | flow: torch.Tensor, 287 | net: torch.Tensor, 288 | inp: torch.Tensor, 289 | predictions: List, 290 | iters: int, 291 | interp_scale: float, 292 | t: int, 293 | ): 294 | for _ in range(iters): 295 | flow = flow.detach() 296 | out_corrs = corr_fn(flow) 297 | with autocast(enabled=self.mixed_precision): 298 | net, up_mask, delta_flow = 
update_block(net, inp, out_corrs, flow, t=t) 299 | 300 | flow = flow + delta_flow 301 | flow_up = flow_out = self.convex_upsample(flow, up_mask, rate=4) 302 | if interp_scale > 1: 303 | flow_up = interp_scale * interp( 304 | flow_out, 305 | ( 306 | interp_scale * flow_out.shape[2], 307 | interp_scale * flow_out.shape[3], 308 | ), 309 | ) 310 | flow_up = flow_up[:, :1] 311 | predictions.append(flow_up) 312 | return flow_out, net 313 | 314 | def forward(self, image1, image2, flow_init=None, iters=10, test_mode=False): 315 | """Estimate optical flow between pair of frames""" 316 | # if input is list, 317 | image1 = 2 * (image1 / 255.0) - 1.0 318 | image2 = 2 * (image2 / 255.0) - 1.0 319 | 320 | b, T, *_ = image1.shape 321 | 322 | image1 = image1.contiguous() 323 | image2 = image2.contiguous() 324 | 325 | hdim = self.hidden_dim 326 | 327 | image1 = rearrange(image1, "b t c h w -> (b t) c h w") 328 | image2 = rearrange(image2, "b t c h w -> (b t) c h w") 329 | 330 | with autocast(enabled=self.mixed_precision): 331 | fmap1, fmap2 = self.fnet([image1, image2]) 332 | 333 | net, inp = torch.split(fmap1, [hdim, hdim], dim=1) 334 | net = torch.tanh(net) 335 | inp = F.relu(inp) 336 | *_, h, w = fmap1.shape 337 | # 1/4 -> 1/16 338 | # feature 339 | fmap1_dw16 = F.avg_pool2d(fmap1, 4, stride=4) 340 | fmap2_dw16 = F.avg_pool2d(fmap2, 4, stride=4) 341 | 342 | fmap1_dw16, fmap2_dw16 = self.forward_sst_block(fmap1_dw16, fmap2_dw16, T=T) 343 | 344 | net_dw16, inp_dw16 = torch.split(fmap1_dw16, [hdim, hdim], dim=1) 345 | net_dw16 = torch.tanh(net_dw16) 346 | inp_dw16 = F.relu(inp_dw16) 347 | 348 | fmap1_dw8 = ( 349 | F.avg_pool2d(fmap1, 2, stride=2) + interp(fmap1_dw16, (h // 2, w // 2)) 350 | ) / 2.0 351 | fmap2_dw8 = ( 352 | F.avg_pool2d(fmap2, 2, stride=2) + interp(fmap2_dw16, (h // 2, w // 2)) 353 | ) / 2.0 354 | 355 | net_dw8, inp_dw8 = torch.split(fmap1_dw8, [hdim, hdim], dim=1) 356 | net_dw8 = torch.tanh(net_dw8) 357 | inp_dw8 = F.relu(inp_dw8) 358 | # Cascaded refinement (1/16 + 1/8 + 1/4) 359 | predictions = [] 360 | flow = None 361 | flow_up = None 362 | if flow_init is not None: 363 | scale = h / flow_init.shape[2] 364 | flow = -scale * interp(flow_init, (h, w)) 365 | else: 366 | # zero initialization 367 | flow_dw16 = self.zero_init(fmap1_dw16) 368 | 369 | # Recurrent Update Module 370 | # Update 1/16 371 | update_block = ( 372 | self.update_block16 373 | if self.different_update_blocks 374 | else self.update_block 375 | ) 376 | 377 | corr_fn_att_dw16 = CorrBlock1D(fmap1_dw16, fmap2_dw16) 378 | flow, net_dw16 = self.forward_update_block( 379 | update_block=update_block, 380 | corr_fn=corr_fn_att_dw16, 381 | flow=flow_dw16, 382 | net=net_dw16, 383 | inp=inp_dw16, 384 | predictions=predictions, 385 | iters=iters // 2, 386 | interp_scale=4, 387 | t=T, 388 | ) 389 | 390 | scale = fmap1_dw8.shape[2] / flow.shape[2] 391 | flow_dw8 = -scale * interp(flow, (fmap1_dw8.shape[2], fmap1_dw8.shape[3])) 392 | 393 | net_dw8 = ( 394 | net_dw8 395 | + interp(net_dw16, (2 * net_dw16.shape[2], 2 * net_dw16.shape[3])) 396 | ) / 2.0 397 | # Update 1/8 398 | 399 | update_block = ( 400 | self.update_block08 401 | if self.different_update_blocks 402 | else self.update_block 403 | ) 404 | 405 | corr_fn_dw8 = CorrBlock1D(fmap1_dw8, fmap2_dw8) 406 | flow, net_dw8 = self.forward_update_block( 407 | update_block=update_block, 408 | corr_fn=corr_fn_dw8, 409 | flow=flow_dw8, 410 | net=net_dw8, 411 | inp=inp_dw8, 412 | predictions=predictions, 413 | iters=iters // 2, 414 | interp_scale=2, 415 | t=T, 416 | ) 417 | 418 | scale = h 
/ flow.shape[2] 419 | flow = -scale * interp(flow, (h, w)) 420 | 421 | net = ( 422 | net + interp(net_dw8, (2 * net_dw8.shape[2], 2 * net_dw8.shape[3])) 423 | ) / 2.0 424 | # Update 1/4 425 | update_block = ( 426 | self.update_block04 if self.different_update_blocks else self.update_block 427 | ) 428 | corr_fn = CorrBlock1D(fmap1, fmap2) 429 | flow, __ = self.forward_update_block( 430 | update_block=update_block, 431 | corr_fn=corr_fn, 432 | flow=flow, 433 | net=net, 434 | inp=inp, 435 | predictions=predictions, 436 | iters=iters, 437 | interp_scale=1, 438 | t=T, 439 | ) 440 | 441 | predictions = torch.stack(predictions) 442 | 443 | predictions = rearrange(predictions, "d (b t) c h w -> d t b c h w", b=b, t=T) 444 | flow_up = predictions[-1] 445 | 446 | if test_mode: 447 | return flow_up 448 | 449 | return predictions 450 | -------------------------------------------------------------------------------- /models/core/extractor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class ResidualBlock(nn.Module): 12 | def __init__(self, in_planes, planes, norm_fn="group", stride=1): 13 | super(ResidualBlock, self).__init__() 14 | 15 | self.conv1 = nn.Conv2d( 16 | in_planes, planes, kernel_size=3, padding=1, stride=stride 17 | ) 18 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 19 | self.relu = nn.ReLU(inplace=True) 20 | 21 | num_groups = planes // 8 22 | 23 | if norm_fn == "group": 24 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 25 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 26 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 27 | 28 | elif norm_fn == "batch": 29 | self.norm1 = nn.BatchNorm2d(planes) 30 | self.norm2 = nn.BatchNorm2d(planes) 31 | self.norm3 = nn.BatchNorm2d(planes) 32 | 33 | elif norm_fn == "instance": 34 | self.norm1 = nn.InstanceNorm2d(planes, affine=False) 35 | self.norm2 = nn.InstanceNorm2d(planes, affine=False) 36 | self.norm3 = nn.InstanceNorm2d(planes, affine=False) 37 | 38 | elif norm_fn == "none": 39 | self.norm1 = nn.Sequential() 40 | self.norm2 = nn.Sequential() 41 | self.norm3 = nn.Sequential() 42 | 43 | self.downsample = nn.Sequential( 44 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3 45 | ) 46 | 47 | def forward(self, x): 48 | y = x 49 | y = self.relu(self.norm1(self.conv1(y))) 50 | y = self.relu(self.norm2(self.conv2(y))) 51 | 52 | x = self.downsample(x) 53 | 54 | return self.relu(x + y) 55 | 56 | 57 | class BasicEncoder(nn.Module): 58 | def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0): 59 | super(BasicEncoder, self).__init__() 60 | self.norm_fn = norm_fn 61 | 62 | if self.norm_fn == "group": 63 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 64 | 65 | elif self.norm_fn == "batch": 66 | self.norm1 = nn.BatchNorm2d(64) 67 | 68 | elif self.norm_fn == "instance": 69 | self.norm1 = nn.InstanceNorm2d(64, affine=False) 70 | 71 | elif self.norm_fn == "none": 72 | self.norm1 = nn.Sequential() 73 | 74 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 75 | self.relu1 = nn.ReLU(inplace=True) 76 | 77 | self.in_planes = 64 78 | self.layer1 = self._make_layer(64, stride=1) 79 | self.layer2 = 
self._make_layer(96, stride=2) 80 | self.layer3 = self._make_layer(128, stride=1) 81 | 82 | # output convolution 83 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 84 | 85 | self.dropout = None 86 | if dropout > 0: 87 | self.dropout = nn.Dropout2d(p=dropout) 88 | 89 | for m in self.modules(): 90 | if isinstance(m, nn.Conv2d): 91 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 92 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 93 | if m.weight is not None: 94 | nn.init.constant_(m.weight, 1) 95 | if m.bias is not None: 96 | nn.init.constant_(m.bias, 0) 97 | 98 | def _make_layer(self, dim, stride=1): 99 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 100 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 101 | layers = (layer1, layer2) 102 | 103 | self.in_planes = dim 104 | return nn.Sequential(*layers) 105 | 106 | def forward(self, x): 107 | 108 | # if input is list, combine batch dimension 109 | is_list = isinstance(x, tuple) or isinstance(x, list) 110 | if is_list: 111 | batch_dim = x[0].shape[0] 112 | x = torch.cat(x, dim=0) 113 | 114 | x = self.conv1(x) 115 | x = self.norm1(x) 116 | x = self.relu1(x) 117 | 118 | x = self.layer1(x) 119 | x = self.layer2(x) 120 | x = self.layer3(x) 121 | 122 | x = self.conv2(x) 123 | 124 | if self.dropout is not None: 125 | x = self.dropout(x) 126 | 127 | if is_list: 128 | x = torch.split(x, x.shape[0] // 2, dim=0) 129 | 130 | return x 131 | -------------------------------------------------------------------------------- /models/core/model_zoo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import copy 8 | from dynamic_stereo.models.dynamic_stereo_model import DynamicStereoModel 9 | 10 | from pytorch3d.implicitron.tools.config import get_default_args 11 | 12 | try: 13 | from dynamic_stereo.models.raft_stereo_model import RAFTStereoModel 14 | 15 | MODELS = [RAFTStereoModel, DynamicStereoModel] 16 | except: 17 | MODELS = [DynamicStereoModel] 18 | 19 | _MODEL_NAME_TO_MODEL = {model_cls.__name__: model_cls for model_cls in MODELS} 20 | _MODEL_CONFIG_NAME_TO_DEFAULT_CONFIG = {} 21 | for model_cls in MODELS: 22 | _MODEL_CONFIG_NAME_TO_DEFAULT_CONFIG[ 23 | model_cls.MODEL_CONFIG_NAME 24 | ] = get_default_args(model_cls) 25 | MODEL_NAME_NONE = "NONE" 26 | 27 | 28 | def model_zoo(model_name: str, **kwargs): 29 | if model_name.upper() == MODEL_NAME_NONE: 30 | return None 31 | 32 | model_cls = _MODEL_NAME_TO_MODEL.get(model_name) 33 | 34 | if model_cls is None: 35 | raise ValueError(f"No such model name: {model_name}") 36 | 37 | model_cls_params = {} 38 | if "model_zoo" in getattr(model_cls, "__dataclass_fields__", []): 39 | model_cls_params["model_zoo"] = model_zoo 40 | print( 41 | f"{model_cls.MODEL_CONFIG_NAME} model configs:", 42 | kwargs.get(model_cls.MODEL_CONFIG_NAME), 43 | ) 44 | return model_cls(**model_cls_params, **kwargs.get(model_cls.MODEL_CONFIG_NAME, {})) 45 | 46 | 47 | def get_all_model_default_configs(): 48 | return copy.deepcopy(_MODEL_CONFIG_NAME_TO_DEFAULT_CONFIG) 49 | -------------------------------------------------------------------------------- /models/core/update.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from einops import rearrange 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from dynamic_stereo.models.core.attention import LoFTREncoderLayer 13 | 14 | 15 | # Ref: https://github.com/princeton-vl/RAFT/blob/master/core/update.py 16 | class FlowHead(nn.Module): 17 | def __init__(self, input_dim=128, hidden_dim=256): 18 | super(FlowHead, self).__init__() 19 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 20 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 21 | self.relu = nn.ReLU(inplace=True) 22 | 23 | def forward(self, x): 24 | return self.conv2(self.relu(self.conv1(x))) 25 | 26 | 27 | class SepConvGRU(nn.Module): 28 | def __init__(self, hidden_dim=128, input_dim=192 + 128): 29 | super(SepConvGRU, self).__init__() 30 | self.convz1 = nn.Conv2d( 31 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) 32 | ) 33 | self.convr1 = nn.Conv2d( 34 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) 35 | ) 36 | self.convq1 = nn.Conv2d( 37 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) 38 | ) 39 | 40 | self.convz2 = nn.Conv2d( 41 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) 42 | ) 43 | self.convr2 = nn.Conv2d( 44 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) 45 | ) 46 | self.convq2 = nn.Conv2d( 47 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) 48 | ) 49 | 50 | def forward(self, h, x): 51 | # horizontal 52 | hx = torch.cat([h, x], dim=1) 53 | z = torch.sigmoid(self.convz1(hx)) 54 | r = torch.sigmoid(self.convr1(hx)) 55 | q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) 56 | h = (1 - z) * h + z * q 57 | 58 | # vertical 59 | hx = torch.cat([h, x], dim=1) 60 | z = torch.sigmoid(self.convz2(hx)) 61 | r = torch.sigmoid(self.convr2(hx)) 62 | q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) 63 | h = (1 - z) * h + z * q 64 | 65 | return h 66 | 67 | 68 | class ConvGRU(nn.Module): 69 | def __init__(self, hidden_dim, input_dim, kernel_size=3): 70 | super(ConvGRU, self).__init__() 71 | self.convz = nn.Conv2d( 72 | hidden_dim + input_dim, hidden_dim, kernel_size, padding=kernel_size // 2 73 | ) 74 | self.convr = nn.Conv2d( 75 | hidden_dim + input_dim, hidden_dim, kernel_size, padding=kernel_size // 2 76 | ) 77 | self.convq = nn.Conv2d( 78 | hidden_dim + input_dim, hidden_dim, kernel_size, padding=kernel_size // 2 79 | ) 80 | 81 | def forward(self, h, x): 82 | hx = torch.cat([h, x], dim=1) 83 | 84 | z = torch.sigmoid(self.convz(hx)) 85 | r = torch.sigmoid(self.convr(hx)) 86 | q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) 87 | 88 | h = (1 - z) * h + z * q 89 | return h 90 | 91 | 92 | class SepConvGRU3D(nn.Module): 93 | def __init__(self, hidden_dim=128, input_dim=192 + 128): 94 | super(SepConvGRU3D, self).__init__() 95 | self.convz1 = nn.Conv3d( 96 | hidden_dim + input_dim, hidden_dim, (1, 1, 5), padding=(0, 0, 2) 97 | ) 98 | self.convr1 = nn.Conv3d( 99 | hidden_dim + input_dim, hidden_dim, (1, 1, 5), padding=(0, 0, 2) 100 | ) 101 | self.convq1 = nn.Conv3d( 102 | hidden_dim + input_dim, hidden_dim, (1, 1, 5), padding=(0, 0, 2) 103 | ) 104 | 105 | self.convz2 = nn.Conv3d( 106 | hidden_dim + input_dim, hidden_dim, (1, 5, 1), padding=(0, 2, 0) 107 | ) 108 | self.convr2 = nn.Conv3d( 109 | hidden_dim + input_dim, hidden_dim, (1, 5, 1), padding=(0, 2, 0) 110 | ) 111 | self.convq2 
= nn.Conv3d( 112 | hidden_dim + input_dim, hidden_dim, (1, 5, 1), padding=(0, 2, 0) 113 | ) 114 | 115 | self.convz3 = nn.Conv3d( 116 | hidden_dim + input_dim, hidden_dim, (5, 1, 1), padding=(2, 0, 0) 117 | ) 118 | self.convr3 = nn.Conv3d( 119 | hidden_dim + input_dim, hidden_dim, (5, 1, 1), padding=(2, 0, 0) 120 | ) 121 | self.convq3 = nn.Conv3d( 122 | hidden_dim + input_dim, hidden_dim, (5, 1, 1), padding=(2, 0, 0) 123 | ) 124 | 125 | def forward(self, h, x): 126 | hx = torch.cat([h, x], dim=1) 127 | z = torch.sigmoid(self.convz1(hx)) 128 | r = torch.sigmoid(self.convr1(hx)) 129 | q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) 130 | h = (1 - z) * h + z * q 131 | 132 | # vertical 133 | hx = torch.cat([h, x], dim=1) 134 | z = torch.sigmoid(self.convz2(hx)) 135 | r = torch.sigmoid(self.convr2(hx)) 136 | q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) 137 | h = (1 - z) * h + z * q 138 | 139 | # time 140 | hx = torch.cat([h, x], dim=1) 141 | z = torch.sigmoid(self.convz3(hx)) 142 | r = torch.sigmoid(self.convr3(hx)) 143 | q = torch.tanh(self.convq3(torch.cat([r * h, x], dim=1))) 144 | h = (1 - z) * h + z * q 145 | 146 | return h 147 | 148 | 149 | class BasicMotionEncoder(nn.Module): 150 | def __init__(self, cor_planes): 151 | super(BasicMotionEncoder, self).__init__() 152 | 153 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 154 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 155 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 156 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 157 | self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1) 158 | 159 | def forward(self, flow, corr): 160 | cor = F.relu(self.convc1(corr)) 161 | cor = F.relu(self.convc2(cor)) 162 | flo = F.relu(self.convf1(flow)) 163 | flo = F.relu(self.convf2(flo)) 164 | 165 | cor_flo = torch.cat([cor, flo], dim=1) 166 | out = F.relu(self.conv(cor_flo)) 167 | return torch.cat([out, flow], dim=1) 168 | 169 | 170 | class Attention(nn.Module): 171 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None): 172 | super().__init__() 173 | self.num_heads = num_heads 174 | head_dim = dim // num_heads 175 | self.scale = qk_scale or head_dim ** -0.5 176 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 177 | self.proj = nn.Linear(dim, dim) 178 | 179 | def forward(self, x): 180 | B, N, C = x.shape 181 | qkv = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 182 | q, k, v = qkv, qkv, qkv 183 | 184 | attn = (q @ k.transpose(-2, -1)) * self.scale 185 | 186 | attn = attn.softmax(dim=-1) 187 | 188 | x = (attn @ v).transpose(1, 2).reshape(B, N, C).contiguous() 189 | x = self.proj(x) 190 | return x 191 | 192 | 193 | class Mlp(nn.Module): 194 | def __init__( 195 | self, 196 | in_features, 197 | hidden_features=None, 198 | out_features=None, 199 | act_layer=nn.GELU, 200 | drop=0.0, 201 | ): 202 | super().__init__() 203 | out_features = out_features or in_features 204 | hidden_features = hidden_features or in_features 205 | self.fc1 = nn.Linear(in_features, hidden_features) 206 | self.act = act_layer() 207 | self.fc2 = nn.Linear(hidden_features, out_features) 208 | self.drop = nn.Dropout(drop) 209 | 210 | def forward(self, x): 211 | x = self.fc1(x) 212 | x = self.act(x) 213 | x = self.drop(x) 214 | x = self.fc2(x) 215 | x = self.drop(x) 216 | return x 217 | 218 | 219 | class TimeAttnBlock(nn.Module): 220 | def __init__(self, dim=256, num_heads=8): 221 | super(TimeAttnBlock, self).__init__() 222 | self.temporal_attn = Attention(dim, num_heads=8, qkv_bias=False, qk_scale=None) 223 | 
self.temporal_fc = nn.Linear(dim, dim) 224 | self.temporal_norm1 = nn.LayerNorm(dim) 225 | 226 | nn.init.constant_(self.temporal_fc.weight, 0) 227 | nn.init.constant_(self.temporal_fc.bias, 0) 228 | 229 | def forward(self, x, T=1): 230 | _, _, h, w = x.shape 231 | 232 | x = rearrange(x, "(b t) m h w -> (b h w) t m", h=h, w=w, t=T) 233 | res_temporal1 = self.temporal_attn(self.temporal_norm1(x)) 234 | res_temporal1 = rearrange( 235 | res_temporal1, "(b h w) t m -> b (h w t) m", h=h, w=w, t=T 236 | ) 237 | res_temporal1 = self.temporal_fc(res_temporal1) 238 | res_temporal1 = rearrange( 239 | res_temporal1, " b (h w t) m -> b t m h w", h=h, w=w, t=T 240 | ) 241 | x = rearrange(x, "(b h w) t m -> b t m h w", h=h, w=w, t=T) 242 | x = x + res_temporal1 243 | x = rearrange(x, "b t m h w -> (b t) m h w", h=h, w=w, t=T) 244 | return x 245 | 246 | 247 | class SpaceAttnBlock(nn.Module): 248 | def __init__(self, dim=256, num_heads=8): 249 | super(SpaceAttnBlock, self).__init__() 250 | self.encoder_layer = LoFTREncoderLayer(dim, nhead=num_heads, attention="linear") 251 | 252 | def forward(self, x, T=1): 253 | _, _, h, w = x.shape 254 | x = rearrange(x, "(b t) m h w -> (b t) (h w) m", h=h, w=w, t=T) 255 | x = self.encoder_layer(x, x) 256 | x = rearrange(x, "(b t) (h w) m -> (b t) m h w", h=h, w=w, t=T) 257 | return x 258 | 259 | 260 | class BasicUpdateBlock(nn.Module): 261 | def __init__(self, hidden_dim, cor_planes, mask_size=8, attention_type=None): 262 | super(BasicUpdateBlock, self).__init__() 263 | self.attention_type = attention_type 264 | if attention_type is not None: 265 | if "update_time" in attention_type: 266 | self.time_attn = TimeAttnBlock(dim=256, num_heads=8) 267 | 268 | if "update_space" in attention_type: 269 | self.space_attn = SpaceAttnBlock(dim=256, num_heads=8) 270 | 271 | self.encoder = BasicMotionEncoder(cor_planes) 272 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim) 273 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 274 | 275 | self.mask = nn.Sequential( 276 | nn.Conv2d(128, 256, 3, padding=1), 277 | nn.ReLU(inplace=True), 278 | nn.Conv2d(256, mask_size ** 2 * 9, 1, padding=0), 279 | ) 280 | 281 | def forward(self, net, inp, corr, flow, upsample=True, t=1): 282 | motion_features = self.encoder(flow, corr) 283 | inp = torch.cat((inp, motion_features), dim=1) 284 | if self.attention_type is not None: 285 | if "update_time" in self.attention_type: 286 | inp = self.time_attn(inp, T=t) 287 | if "update_space" in self.attention_type: 288 | inp = self.space_attn(inp, T=t) 289 | net = self.gru(net, inp) 290 | delta_flow = self.flow_head(net) 291 | 292 | # scale mask to balence gradients 293 | mask = 0.25 * self.mask(net) 294 | return net, mask, delta_flow 295 | 296 | 297 | class FlowHead3D(nn.Module): 298 | def __init__(self, input_dim=128, hidden_dim=256): 299 | super(FlowHead3D, self).__init__() 300 | self.conv1 = nn.Conv3d(input_dim, hidden_dim, 3, padding=1) 301 | self.conv2 = nn.Conv3d(hidden_dim, 2, 3, padding=1) 302 | self.relu = nn.ReLU(inplace=True) 303 | 304 | def forward(self, x): 305 | return self.conv2(self.relu(self.conv1(x))) 306 | 307 | 308 | class SequenceUpdateBlock3D(nn.Module): 309 | def __init__(self, hidden_dim, cor_planes, mask_size=8, attention_type=None): 310 | super(SequenceUpdateBlock3D, self).__init__() 311 | 312 | self.encoder = BasicMotionEncoder(cor_planes) 313 | self.gru = SepConvGRU3D(hidden_dim=hidden_dim, input_dim=128 + hidden_dim) 314 | self.flow_head = FlowHead3D(hidden_dim, hidden_dim=256) 315 | self.mask = 
nn.Sequential( 316 | nn.Conv2d(hidden_dim, hidden_dim + 128, 3, padding=1), 317 | nn.ReLU(inplace=True), 318 | nn.Conv2d(hidden_dim + 128, (mask_size ** 2) * 9, 1, padding=0), 319 | ) 320 | self.attention_type = attention_type 321 | if attention_type is not None: 322 | if "update_time" in attention_type: 323 | self.time_attn = TimeAttnBlock(dim=256, num_heads=8) 324 | if "update_space" in attention_type: 325 | self.space_attn = SpaceAttnBlock(dim=256, num_heads=8) 326 | 327 | def forward(self, net, inp, corrs, flows, t, upsample=True): 328 | inp_tensor = [] 329 | 330 | motion_features = self.encoder(flows, corrs) 331 | inp_tensor = torch.cat([inp, motion_features], dim=1) 332 | 333 | if self.attention_type is not None: 334 | if "update_time" in self.attention_type: 335 | inp_tensor = self.time_attn(inp_tensor, T=t) 336 | if "update_space" in self.attention_type: 337 | inp_tensor = self.space_attn(inp_tensor, T=t) 338 | 339 | net = rearrange(net, "(b t) c h w -> b c t h w", t=t) 340 | inp_tensor = rearrange(inp_tensor, "(b t) c h w -> b c t h w", t=t) 341 | 342 | net = self.gru(net, inp_tensor) 343 | 344 | delta_flow = self.flow_head(net) 345 | 346 | # scale mask to balance gradients 347 | net = rearrange(net, " b c t h w -> (b t) c h w") 348 | mask = 0.25 * self.mask(net) 349 | 350 | delta_flow = rearrange(delta_flow, " b c t h w -> (b t) c h w") 351 | return net, mask, delta_flow 352 | -------------------------------------------------------------------------------- /models/core/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch.nn.functional as F 8 | 9 | 10 | def interp(tensor, size): 11 | return F.interpolate( 12 | tensor, 13 | size=size, 14 | mode="bilinear", 15 | align_corners=True, 16 | ) 17 | 18 | 19 | class InputPadder: 20 | """Pads images such that dimensions are divisible by 8""" 21 | 22 | def __init__(self, dims, mode="sintel", divis_by=8): 23 | self.ht, self.wd = dims[-2:] 24 | pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by 25 | pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by 26 | if mode == "sintel": 27 | self._pad = [ 28 | pad_wd // 2, 29 | pad_wd - pad_wd // 2, 30 | pad_ht // 2, 31 | pad_ht - pad_ht // 2, 32 | ] 33 | else: 34 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht] 35 | 36 | def pad(self, *inputs): 37 | assert all((x.ndim == 4) for x in inputs) 38 | return [F.pad(x, self._pad, mode="replicate") for x in inputs] 39 | 40 | def unpad(self, x): 41 | assert x.ndim == 4 42 | ht, wd = x.shape[-2:] 43 | c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] 44 | return x[..., c[0] : c[1], c[2] : c[3]] 45 | -------------------------------------------------------------------------------- /models/dynamic_stereo_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
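# Illustrative usage sketch for the wrapper defined below -- a minimal example
# with assumptions noted inline: construction goes through model_zoo (which
# expands the Configurable defaults), the default checkpoint is assumed to be
# present on disk, and tensor shapes follow DynamicStereo.forward_batch_test.
#
#   from dynamic_stereo.models.core.model_zoo import model_zoo
#   model = model_zoo("DynamicStereoModel")   # loads the default checkpoint, moves to CUDA, eval()
#   # `video`: float tensor of shape (T, 2, 3, H, W), values in [0, 255],
#   # left/right frames stacked on dim 1; how it is loaded is up to the caller.
#   with torch.no_grad():
#       preds = model({"stereo_video": video}, iters=20)
#   disp = preds["disparity"]                 # (T, 1, H, W) absolute disparities on CPU;
#                                             # frames are processed in overlapping windows
#                                             # of `kernel_size` frames (default 20).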
6 | 7 | from typing import ClassVar 8 | 9 | import torch 10 | from pytorch3d.implicitron.tools.config import Configurable 11 | 12 | from dynamic_stereo.models.core.dynamic_stereo import DynamicStereo 13 | 14 | 15 | class DynamicStereoModel(Configurable, torch.nn.Module): 16 | 17 | MODEL_CONFIG_NAME: ClassVar[str] = "DynamicStereoModel" 18 | 19 | # model_weights: str = "./checkpoints/dynamic_stereo_sf.pth" 20 | model_weights: str = "./checkpoints/dynamic_stereo_dr_sf.pth" 21 | kernel_size: int = 20 22 | 23 | def __post_init__(self): 24 | super().__init__() 25 | 26 | self.mixed_precision = False 27 | model = DynamicStereo( 28 | mixed_precision=self.mixed_precision, 29 | num_frames=5, 30 | attention_type="self_stereo_temporal_update_time_update_space", 31 | use_3d_update_block=True, 32 | different_update_blocks=True, 33 | ) 34 | 35 | state_dict = torch.load(self.model_weights, map_location="cpu") 36 | if "model" in state_dict: 37 | state_dict = state_dict["model"] 38 | if "state_dict" in state_dict: 39 | state_dict = state_dict["state_dict"] 40 | state_dict = {"module." + k: v for k, v in state_dict.items()} 41 | model.load_state_dict(state_dict, strict=False) 42 | 43 | self.model = model 44 | self.model.to("cuda") 45 | self.model.eval() 46 | 47 | def forward(self, batch_dict, iters=20): 48 | return self.model.forward_batch_test( 49 | batch_dict, kernel_size=self.kernel_size, iters=iters 50 | ) 51 | -------------------------------------------------------------------------------- /models/raft_stereo_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from collections import defaultdict 8 | from types import SimpleNamespace 9 | from typing import ClassVar 10 | 11 | import torch 12 | from pytorch3d.implicitron.tools.config import Configurable 13 | 14 | import importlib 15 | import sys 16 | 17 | sys.path.append("third_party/RAFT-Stereo") 18 | raft_stereo = importlib.import_module( 19 | "dynamic_stereo.third_party.RAFT-Stereo.core.raft_stereo" 20 | ) 21 | raft_stereo_utils = importlib.import_module( 22 | "dynamic_stereo.third_party.RAFT-Stereo.core.utils.utils" 23 | ) 24 | autocast = torch.cuda.amp.autocast 25 | 26 | 27 | class RAFTStereoModel(Configurable, torch.nn.Module): 28 | MODEL_CONFIG_NAME: ClassVar[str] = "RAFTStereoModel" 29 | model_weights: str = "./third_party/RAFT-Stereo/models/raftstereo-middlebury.pth" 30 | 31 | def __post_init__(self): 32 | super().__init__() 33 | 34 | model_args = SimpleNamespace( 35 | hidden_dims=[128] * 3, 36 | corr_implementation="reg", 37 | shared_backbone=False, 38 | corr_levels=4, 39 | corr_radius=4, 40 | n_downsample=2, 41 | slow_fast_gru=False, 42 | n_gru_layers=3, 43 | mixed_precision=False, 44 | context_norm="batch", 45 | ) 46 | self.args = model_args 47 | model = torch.nn.DataParallel( 48 | raft_stereo.RAFTStereo(model_args), device_ids=[0] 49 | ) 50 | 51 | state_dict = torch.load(self.model_weights, map_location="cpu") 52 | if "state_dict" in state_dict: 53 | state_dict = state_dict["state_dict"] 54 | state_dict = {"module." 
+ k: v for k, v in state_dict.items()} 55 | model.load_state_dict(state_dict) 56 | 57 | self.model = model.module 58 | self.model.to("cuda") 59 | self.model.eval() 60 | 61 | def forward(self, batch_dict, iters=32): 62 | predictions = defaultdict(list) 63 | for stereo_pair in batch_dict["stereo_video"]: 64 | left_image_rgb = stereo_pair[None, 0].cuda() 65 | right_image_rgb = stereo_pair[None, 1].cuda() 66 | 67 | padder = raft_stereo_utils.InputPadder(left_image_rgb.shape, divis_by=32) 68 | left_image_rgb, right_image_rgb = padder.pad( 69 | left_image_rgb, right_image_rgb 70 | ) 71 | 72 | with autocast(enabled=self.args.mixed_precision): 73 | _, flow_up = self.model.forward( 74 | left_image_rgb, 75 | right_image_rgb, 76 | iters=iters, 77 | test_mode=True, 78 | ) 79 | flow_up = padder.unpad(flow_up) 80 | predictions["disparity"].append(flow_up) 81 | predictions["disparity"] = ( 82 | torch.stack(predictions["disparity"]).squeeze(1).abs() 83 | ) 84 | return predictions 85 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | hydra-core==1.1 2 | einops==0.4.1 3 | flow_vis==0.1 4 | imageio==2.21.1 5 | matplotlib==3.5.3 6 | munch==2.5.0 7 | numpy==1.23.5 8 | omegaconf==2.1.0 9 | opencv_python==4.6.0.66 10 | opt_einsum==3.3.0 11 | Pillow==9.5.0 12 | pytorch_lightning==1.6.0 13 | requests 14 | scikit_image==0.19.2 15 | scipy==1.10.0 16 | setuptools==65.6.3 17 | tabulate==0.8.10 18 | tqdm==4.64.1 19 | moviepy 20 | -------------------------------------------------------------------------------- /scripts/checksum_check.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
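# Usage sketch -- the flags mirror the argparse definitions at the bottom of
# this file; the folder and file paths are placeholders:
#
#   python scripts/checksum_check.py \
#       --download_folder /path/to/dynamic_replica_zips --num_workers 8
#
# Single archives can also be checked programmatically; check_dr_sha256 raises
# an AssertionError when the file's hash does not match its entry (keyed by the
# zip basename) in the SHA256 JSON file:
#
#   from dynamic_stereo.scripts.checksum_check import check_dr_sha256
#   check_dr_sha256("/path/to/valid_000.zip", sha256s_file="scripts/dr_sha256.json")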
6 | 7 | import os 8 | import glob 9 | import argparse 10 | import hashlib 11 | import json 12 | 13 | from typing import Optional 14 | from multiprocessing import Pool 15 | from tqdm import tqdm 16 | 17 | 18 | DEFAULT_SHA256S_FILE = os.path.join(__file__.rsplit(os.sep, 2)[0], "dr_sha256.json") 19 | BLOCKSIZE = 65536 20 | 21 | 22 | def main( 23 | download_folder: str, 24 | sha256s_file: str, 25 | dump: bool = False, 26 | n_sha256_workers: int = 4 27 | ): 28 | if not os.path.isfile(sha256s_file): 29 | raise ValueError(f"The SHA256 file does not exist ({sha256s_file}).") 30 | 31 | expected_sha256s = get_expected_sha256s( 32 | sha256s_file=sha256s_file 33 | ) 34 | 35 | zipfiles = sorted(glob.glob(os.path.join(download_folder, "*.zip"))) 36 | print(f"Extracting SHA256 hashes for {len(zipfiles)} files in {download_folder}.") 37 | extracted_sha256s_list = [] 38 | with Pool(processes=n_sha256_workers) as sha_pool: 39 | for extracted_hash in tqdm( 40 | sha_pool.imap(_sha256_file_and_print, zipfiles), 41 | total=len(zipfiles), 42 | ): 43 | extracted_sha256s_list.append(extracted_hash) 44 | pass 45 | 46 | extracted_sha256s = dict( 47 | zip([os.path.split(z)[-1] for z in zipfiles], extracted_sha256s_list) 48 | ) 49 | 50 | if dump: 51 | print(extracted_sha256s) 52 | with open(sha256s_file, "w") as f: 53 | json.dump(extracted_sha256s, f, indent=2) 54 | 55 | 56 | missing_keys, invalid_keys = [], [] 57 | for k in expected_sha256s.keys(): 58 | if k not in extracted_sha256s: 59 | print(f"{k} missing!") 60 | missing_keys.append(k) 61 | elif expected_sha256s[k] != extracted_sha256s[k]: 62 | print( 63 | f"'{k}' does not match!" 64 | + f" ({expected_sha256s[k]} != {extracted_sha256s[k]})" 65 | ) 66 | invalid_keys.append(k) 67 | if len(invalid_keys) + len(missing_keys) > 0: 68 | raise ValueError( 69 | f"Checksum checker failed!" 70 | + f" Non-matching checksums: {str(invalid_keys)};" 71 | + f" missing files: {str(missing_keys)}." 72 | ) 73 | 74 | 75 | def get_expected_sha256s( 76 | sha256s_file: str 77 | ): 78 | with open(sha256s_file, "r") as f: 79 | expected_sha256s = json.load(f) 80 | return expected_sha256s 81 | 82 | 83 | def check_dr_sha256( 84 | path: str, 85 | sha256s_file: str, 86 | expected_sha256s: Optional[dict] = None, 87 | do_assertion: bool = True, 88 | ): 89 | zipname = os.path.split(path)[-1] 90 | if expected_sha256s is None: 91 | expected_sha256s = get_expected_sha256s( 92 | sha256s_file=sha256s_file, 93 | ) 94 | extracted_hash = sha256_file(path) 95 | if do_assertion: 96 | assert ( 97 | extracted_hash == expected_sha256s[zipname] 98 | ), f"{zipname}: ({extracted_hash} != {expected_sha256s[zipname]})" 99 | else: 100 | return extracted_hash == expected_sha256s[zipname] 101 | 102 | 103 | def sha256_file(path: str): 104 | sha256_hash = hashlib.sha256() 105 | with open(path, "rb") as f: 106 | file_buffer = f.read(BLOCKSIZE) 107 | while len(file_buffer) > 0: 108 | sha256_hash.update(file_buffer) 109 | file_buffer = f.read(BLOCKSIZE) 110 | digest_ = sha256_hash.hexdigest() 111 | return digest_ 112 | 113 | 114 | def _sha256_file_and_print(path: str): 115 | digest_ = sha256_file(path) 116 | print(f"{path}: {digest_}") 117 | return digest_ 118 | 119 | 120 | 121 | if __name__ == "__main__": 122 | parser = argparse.ArgumentParser( 123 | description="Check SHA256 hashes of the Dynamic Replica dataset." 
124 | ) 125 | parser.add_argument( 126 | "--download_folder", 127 | type=str, 128 | help="A local target folder for downloading the the dataset files.", 129 | ) 130 | parser.add_argument( 131 | "--sha256s_file", 132 | type=str, 133 | help="A local target folder for downloading the the dataset files.", 134 | default=DEFAULT_SHA256S_FILE, 135 | ) 136 | parser.add_argument( 137 | "--num_workers", 138 | type=int, 139 | default=4, 140 | help="The number of sha256 extraction workers.", 141 | ) 142 | parser.add_argument( 143 | "--dump_sha256s", 144 | action="store_true", 145 | help="Store sha256s hashes.", 146 | ) 147 | 148 | args = parser.parse_args() 149 | main( 150 | str(args.download_folder), 151 | dump=bool(args.dump_sha256s), 152 | n_sha256_workers=int(args.num_workers), 153 | sha256s_file=str(args.sha256s_file), 154 | ) -------------------------------------------------------------------------------- /scripts/download_dynamic_replica.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | 9 | from dynamic_stereo.scripts.download_utils import build_arg_parser, download_dataset 10 | 11 | 12 | DEFAULT_LINK_LIST_FILE = os.path.join(os.path.dirname(__file__), "links.json") 13 | DEFAULT_SHA256S_FILE = os.path.join(os.path.dirname(__file__), "dr_sha256.json") 14 | 15 | 16 | if __name__ == "__main__": 17 | parser = build_arg_parser( 18 | "dynamic_replica", DEFAULT_LINK_LIST_FILE, DEFAULT_SHA256S_FILE 19 | ) 20 | 21 | args = parser.parse_args() 22 | os.makedirs(args.download_folder, exist_ok=True) 23 | download_dataset( 24 | str(args.link_list_file), 25 | str(args.download_folder), 26 | n_download_workers=int(args.n_download_workers), 27 | n_extract_workers=int(args.n_extract_workers), 28 | download_splits=args.download_splits, 29 | checksum_check=bool(args.checksum_check), 30 | clear_archives_after_unpacking=bool(args.clear_archives_after_unpacking), 31 | sha256s_file=str(args.sha256_file), 32 | skip_downloaded_archives=not bool(args.redownload_existing_archives), 33 | ) 34 | -------------------------------------------------------------------------------- /scripts/download_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import shutil 9 | import requests 10 | import functools 11 | import json 12 | import warnings 13 | 14 | from argparse import ArgumentParser 15 | from typing import List, Optional 16 | from multiprocessing import Pool 17 | from tqdm import tqdm 18 | 19 | 20 | from dynamic_stereo.scripts.checksum_check import check_dr_sha256 21 | 22 | 23 | def download_dataset( 24 | link_list_file: str, 25 | download_folder: str, 26 | n_download_workers: int = 4, 27 | n_extract_workers: int = 4, 28 | download_splits: List[str] = ['real', 'valid', 'test', 'train'], 29 | checksum_check: bool = False, 30 | clear_archives_after_unpacking: bool = False, 31 | skip_downloaded_archives: bool = True, 32 | sha256s_file: Optional[str] = None, 33 | ): 34 | """ 35 | Downloads and unpacks the dataset in CO3D format. 
36 | Note: The script will make a folder `/_in_progress`, which 37 | stores files whose download is in progress. The folder can be safely deleted 38 | the download is finished. 39 | Args: 40 | link_list_file: A text file with the list of zip file download links. 41 | download_folder: A local target folder for downloading the 42 | the dataset files. 43 | n_download_workers: The number of parallel workers 44 | for downloading the dataset files. 45 | n_extract_workers: The number of parallel workers 46 | for extracting the dataset files. 47 | download_splits: A list of data splits to download. 48 | Must be in ['real', 'valid', 'test', 'train']. 49 | checksum_check: Enable validation of the downloaded file's checksum before 50 | extraction. 51 | clear_archives_after_unpacking: Delete the unnecessary downloaded archive files 52 | after unpacking. 53 | skip_downloaded_archives: Skip re-downloading already downloaded archives. 54 | """ 55 | 56 | if checksum_check and not sha256s_file: 57 | raise ValueError( 58 | "checksum_check is requested but ground-truth SHA256 file not provided!" 59 | ) 60 | 61 | if not os.path.isfile(link_list_file): 62 | raise ValueError( 63 | "Please specify `link_list_file` with a valid path to a json" 64 | " with zip file download links." 65 | # " The file is stored in the DynamicStereo github:" 66 | # " https://github.com/facebookresearch/dynamic_stereo/blob/main/dynamic_stereo/links.json" 67 | ) 68 | 69 | if not os.path.isdir(download_folder): 70 | raise ValueError( 71 | "Please specify `download_folder` with a valid path to a target folder" 72 | + " for downloading the dataset." 73 | + f" {download_folder} does not exist." 74 | ) 75 | 76 | # read the link file 77 | with open(link_list_file, "r") as f: 78 | links = json.load(f) 79 | 80 | for split in download_splits: 81 | if split not in ['real', 'valid', 'test', 'train']: 82 | raise ValueError( 83 | f"Download split {str(split)} is not valid" 84 | ) 85 | 86 | data_links = [] 87 | for split_name, urls in links.items(): 88 | if split_name in download_splits: 89 | for url in urls: 90 | link_name = os.path.split(url)[-1] 91 | data_links.append((split_name, link_name, url)) 92 | 93 | 94 | with Pool(processes=n_download_workers) as download_pool: 95 | download_ok = {} 96 | for link_name, ok in tqdm( 97 | download_pool.imap( 98 | functools.partial( 99 | _download_split_file, 100 | download_folder, 101 | checksum_check, 102 | sha256s_file, 103 | skip_downloaded_archives, 104 | ), 105 | data_links, 106 | ), 107 | total=len(data_links), 108 | ): 109 | download_ok[link_name] = ok 110 | 111 | with Pool(processes=n_extract_workers) as extract_pool: 112 | for _ in tqdm( 113 | extract_pool.imap( 114 | functools.partial( 115 | _unpack_split_file, 116 | download_folder, 117 | clear_archives_after_unpacking, 118 | ), 119 | data_links, 120 | ), 121 | total=len(data_links), 122 | ): 123 | pass 124 | print("Done") 125 | 126 | 127 | 128 | def build_arg_parser( 129 | dataset_name: str, 130 | default_link_list_file: str, 131 | default_sha256_file: str, 132 | ) -> ArgumentParser: 133 | parser = ArgumentParser(description=f"Download the {dataset_name} dataset.") 134 | parser.add_argument( 135 | "--download_folder", 136 | type=str, 137 | required=True, 138 | help="A local target folder for downloading the the dataset files.", 139 | ) 140 | parser.add_argument( 141 | "--n_download_workers", 142 | type=int, 143 | default=4, 144 | help="The number of parallel workers for downloading the dataset files.", 145 | ) 146 | parser.add_argument( 147 | 
"--n_extract_workers", 148 | type=int, 149 | default=4, 150 | help="The number of parallel workers for extracting the dataset files.", 151 | ) 152 | parser.add_argument( 153 | "--download_splits", 154 | default=['real', 'valid', 'test', 'train'], 155 | nargs='+', 156 | help=f"A comma-separated list of {dataset_name} splits to download.", 157 | ) 158 | parser.add_argument( 159 | "--link_list_file", 160 | type=str, 161 | default=default_link_list_file, 162 | help=( 163 | f"The file with html links to the {dataset_name} dataset files." 164 | + " In most cases the default local file `links.json` should be used." 165 | ), 166 | ) 167 | parser.add_argument( 168 | "--sha256_file", 169 | type=str, 170 | default=default_sha256_file, 171 | help=( 172 | f"The file with SHA256 hashes of {dataset_name} dataset files." 173 | + " In most cases the default local file `dr_sha256.json` should be used." 174 | ), 175 | ) 176 | parser.add_argument( 177 | "--checksum_check", 178 | action="store_true", 179 | default=True, 180 | help="Check the SHA256 checksum of each downloaded file before extraction.", 181 | ) 182 | parser.add_argument( 183 | "--no_checksum_check", 184 | action="store_false", 185 | dest="checksum_check", 186 | default=False, 187 | help="Does not check the SHA256 checksum of each downloaded file before extraction.", 188 | ) 189 | parser.set_defaults(checksum_check=True) 190 | parser.add_argument( 191 | "--clear_archives_after_unpacking", 192 | action="store_true", 193 | default=False, 194 | help="Delete the unnecessary downloaded archive files after unpacking.", 195 | ) 196 | parser.add_argument( 197 | "--redownload_existing_archives", 198 | action="store_true", 199 | default=False, 200 | help="Redownload the already-downloaded archives.", 201 | ) 202 | 203 | return parser 204 | 205 | def _unpack_split_file( 206 | download_folder: str, 207 | clear_archive: bool, 208 | link: str, 209 | ): 210 | split, link_name, url = link 211 | local_fl = os.path.join(download_folder, link_name) 212 | print(f"Unpacking dataset file {local_fl} ({link_name}) to {download_folder}.") 213 | 214 | download_folder_split = os.path.join(download_folder, split) 215 | # os.makedirs(download_folder_split, exist_ok=True) 216 | shutil.unpack_archive(local_fl, download_folder_split) 217 | if clear_archive: 218 | os.remove(local_fl) 219 | 220 | def _download_split_file( 221 | download_folder: str, 222 | checksum_check: bool, 223 | sha256s_file: Optional[str], 224 | skip_downloaded_files: bool, 225 | link: str, 226 | ): 227 | __, link_name, url = link 228 | local_fl_final = os.path.join(download_folder, link_name) 229 | 230 | if skip_downloaded_files and os.path.isfile(local_fl_final): 231 | print(f"Skipping {local_fl_final}, already downloaded!") 232 | return link_name, True 233 | 234 | in_progress_folder = os.path.join(download_folder, "_in_progress") 235 | os.makedirs(in_progress_folder, exist_ok=True) 236 | local_fl = os.path.join(in_progress_folder, link_name) 237 | 238 | print(f"Downloading dataset file {link_name} ({url}) to {local_fl}.") 239 | _download_with_progress_bar(url, local_fl, link_name) 240 | if checksum_check: 241 | print(f"Checking SHA256 for {local_fl}.") 242 | try: 243 | check_dr_sha256( 244 | local_fl, 245 | sha256s_file=sha256s_file, 246 | ) 247 | except AssertionError: 248 | warnings.warn( 249 | f"Checksums for {local_fl} did not match!" 250 | + " This is likely due to a network failure," 251 | + " please restart the download script." 
252 | ) 253 | return link_name, False 254 | 255 | os.rename(local_fl, local_fl_final) 256 | return link_name, True 257 | 258 | 259 | def _download_with_progress_bar(url: str, fname: str, filename: str): 260 | 261 | # taken from https://stackoverflow.com/a/62113293/986477 262 | resp = requests.get(url, stream=True) 263 | print(url) 264 | total = int(resp.headers.get("content-length", 0)) 265 | with open(fname, "wb") as file, tqdm( 266 | desc=fname, 267 | total=total, 268 | unit="iB", 269 | unit_scale=True, 270 | unit_divisor=1024, 271 | ) as bar: 272 | for datai, data in enumerate(resp.iter_content(chunk_size=1024)): 273 | size = file.write(data) 274 | bar.update(size) 275 | if datai % max((max(total // 1024, 1) // 20), 1) == 0: 276 | print(f"{filename}: Downloaded {100.0*(float(bar.n)/max(total, 1)):3.1f}%.") 277 | print(bar) -------------------------------------------------------------------------------- /scripts/dr_sha256.json: -------------------------------------------------------------------------------- 1 | { 2 | "real_000.zip": "e5c2aac04146d783c64f76d0ef7a9e8d49d80ffac99d2a795563517f15943a6f", 3 | "valid_000.zip": "0f35bee47030ae1a30289beb92ba69c5336491e0f07aab0a05cb5505173d1faf", 4 | "valid_001.zip": "cb37d3b1f643118ae22840b4212b00c00a8fe137099d3730a07796a5fefab24a", 5 | "valid_002.zip": "5535f2a98e06c68cf97e3259e962e3d44465a1820369e4425c4ef2a719b01ad0", 6 | "valid_003.zip": "e19db94514d22829743aa363698f407ecfd98d8f08eab037289a420939ef5143", 7 | "valid_004.zip": "953328f24ba0c3e8709df3829cce238305a8998bf7ae938c80069fab6f513862", 8 | "valid_005.zip": "27ce4c7424292dcf3e8e0b370fbbc848bd6d73ae28ea5832fddfa8e9c17d6011", 9 | "test_000.zip": "a56fa676a7a3dc52b33f1571d41fb0221e289735acccb7b9ad42dfb13fdac68c", 10 | "test_001.zip": "43580e89331826182f41d2ce9f06f62da46617fea9e612a16b2610de8ffdc10b", 11 | "test_002.zip": "33551fb68979d3d2f20e1976d9169a84ad58658c459aba4d7a2671c8d66904b9", 12 | "test_003.zip": "45ad28d7555e3579225d26dfcb8244b65de0d1ee749560cc6dd84f121b4b40de", 13 | "test_004.zip": "d736b56fe15410525deda1c16c0b8da4497383480a4328da92bc0ddb64a62d52", 14 | "test_005.zip": "3ae331047019a39c6306a17407c72e40dc15b5113f6f9ef72aba2da0b859ea7d", 15 | "test_006.zip": "94341c8ac8ed1d7f11816ad121e6c5821a751fdc3d3122a653c86f7b5845ca80", 16 | "test_007.zip": "4e18facbd507e16fc41d90d5c2ce1b44c511d3e2986e1ccdf5d264748d5d7e15", 17 | "test_008.zip": "e4d5aa0c25eb01863bbced477e17fddd9d8217d23d238bb06b0b688a0f6ed8e3", 18 | "test_009.zip": "5a413411cfc376078ed0357708084f71949159c13119aabb5c9ae1ffde33b6b7", 19 | "test_010.zip": "82ea42c7544385aa2d41271e63399534a398dbbef8a06cb990c8bb34296928c8", 20 | "train_000.zip": "e9fd9af579b0d08d538551c0ab6f7231a1fd162139667803e672cc0dc8b98b03", 21 | "train_001.zip": "65cb438c7a48567f85db8e54109db6c25d2a621fcbd3267c542a8a640e1dad56", 22 | "train_002.zip": "c3d9a76a955dd9feb0275837a17133a1d7ee76c963f1c6fa7630deb0ca8209b2", 23 | "train_003.zip": "13e108f78c7da1f1c1469dd87fab55a6e4ec79f1fcb6d7d16cc9006a933979f4", 24 | "train_004.zip": "171b92a62b46a68f1d89c2326ba67b6433faf087bc1eecc7a947c19d0f90d3e6", 25 | "train_005.zip": "75461ffe13cfbd87b4f0f9ffc83002b8381f5a0a212ece18b8961012f865a46e", 26 | "train_006.zip": "7546f94817814031a738082e6b30858d0057710af052a88fa505a961b6699886", 27 | "train_007.zip": "371dd100b215bcd41129def1c8fd07f974af11a9b3d3b9966ce5d9700b9929ad", 28 | "train_008.zip": "313f5c2089c6afc1691edf054e8b2af9eb8b2d91f791153763758c8d91abee48", 29 | "train_009.zip": "9cbb9f44bb6b7dcc74f00a51db4d2a8797c95a0d880d63ef1612d3883b16b995", 30 | "train_010.zip": 
"eb158fccc23a4b41358ec94be203f49a677f86626af7a88f0e649454c409c706", 31 | "train_011.zip": "f8b3f8c738cdcdbbdf346a4dd78b99883b5d4ab74c11b64ec7b4f8ccd3b68ffc", 32 | "train_012.zip": "b364ba9d35d7e55019d3554cf65b295d2358859c222b3b847b0f2cced948cfce", 33 | "train_013.zip": "c8a50efbd93e6e422eabf1846dac2d75e81dfcfcd4d785fe18b01526af9695f6", 34 | "train_014.zip": "52a768ce76310861cf1fc990ebb8d16f0c27fceff02c12b11638d36ca1c3a927", 35 | "train_015.zip": "67bf0ba775948997f5ab3cc810b6d0e8149758334210ace6f5cdfc529fe7d26e", 36 | "train_016.zip": "d5b9a26736421d8f330fd5e531d26071531501a88609d29d580b9d56b6bc17a3", 37 | "train_017.zip": "5f2d2c93e7944baf1e6d3dee671b12abb7476a75cbd6f572af86fe5c22472fa6", 38 | "train_018.zip": "77aa801b6b0359b970466329e4a05b937df94b650228cf4797a2a029606b8e5b", 39 | "train_019.zip": "30934c91cc0ae69acef6a89e4a5180686bd04080e2384a8bde5877cbaaadc575", 40 | "train_020.zip": "901d5c08705a70053a3e865354a4e7149c35f026b6ed166fee029d829d88c124", 41 | "train_021.zip": "f27019ff58e54a004ed2cf2106ed459a31b010ed82d32028b0e196dd365b8b0e", 42 | "train_022.zip": "0600346a2ce162f7e9824e90c553b69a656d4731c86d903e300d932ec8ba7600", 43 | "train_023.zip": "660d768e4b1bfe742a42ae6ee84f5e91c930789488a7c7f118e5d0edd1f1a010", 44 | "train_024.zip": "1f8792002baceaba8f93f93be1bee7c83a48c677e4b2d025b6f0047a796e94cd", 45 | "train_025.zip": "0b92b3f41c18fded8fcb7aba44e7d8738750b8155c907924200fdf4dc1718794", 46 | "train_026.zip": "4dc401639317527231abfef07221b8d7db2d0950008828104cd1f72092325d05", 47 | "train_027.zip": "e8313eaa21163f9dd2ff4558d16b1c9cf4962c2e4c0403d6a315955660a98b14", 48 | "train_028.zip": "d73edf1c500b4311795aaae0a03b3bc04a2c266e2a20b27ba9b6e72fb27fd277", 49 | "train_029.zip": "c5e4d302c62e693626445aba19638711108049235b0075558e7949b189050c56", 50 | "train_030.zip": "506b9ba7a740b0bf84159546f797437a48a24e468cb949f2189e51cf404c6170", 51 | "train_031.zip": "f36bb4b77fdb255dae2050884cf59cd3f8e46e77ea2984b4b219b799c4aac089", 52 | "train_032.zip": "fddca4efc40ed8d05adf9d519e4fb5b486ac77e8fa08c98d5c4be15867fda8a0", 53 | "train_033.zip": "c24d2b5c04f3e90b265fd0762e7ae19fb01a7c1948a4c09451383a9eec9f640f", 54 | "train_034.zip": "5828fbd615c4476f6107fe844cbf81632eff2f9c75194cb84d749630d9359e14", 55 | "train_035.zip": "7b60fe125fd1a9ba7991e2accd0f2b212968983b4631d43eccff9836a0c35ba8", 56 | "train_036.zip": "0f4eaf464a2afc62447a802159b3844487b80e9d1c9b0a7d324b0d5914514d60", 57 | "train_037.zip": "ba85a6692d86e48c4c787b334d4384c08b914e4cee7f3d2692dcae1bbac55878", 58 | "train_038.zip": "c67b0f5305560d8089bdc2f6212c05256c044e50a715d59b864fbef705bc6b5c", 59 | "train_039.zip": "f4b66c9e1360a8d6d8337c94eefb1132d865c2735c6b78ba726a590073174aad", 60 | "train_040.zip": "2c64b76d028fcc153f267925b79a24cf3bb0e42cc7716773df2139f5cec5e319", 61 | "train_041.zip": "22b1c0ab99a7f8bd0d36c2d2511d3d469cc390776c38132d1e8f1ad7aae5d4ff", 62 | "train_042.zip": "8f2afaecb9f90947c9071111fde9c015acfceb432ae0bf94deff3ecd581b26c8", 63 | "train_043.zip": "adf7ea7c356339b10b797c49163252704b4e6b0cebcc741d3374f8c9467f6b43", 64 | "train_044.zip": "3d0fe4a85fd22ff9c8ed468ca8173d93406a72fadf800d9e6bbf209348cf8965", 65 | "train_045.zip": "70874eca6bce66cb7681092755d066968e9c8fc32a266d7c0d2f29c01b2b2669", 66 | "train_046.zip": "01adcdbba0a25383e2281ce02a946f6bc824e1b8e16cf88e85a4ad275203884c", 67 | "train_047.zip": "50ed632ae330acf60c1b2e22b28fbfab5ccf0e8f23320b2911dcc2d43db048b6", 68 | "train_048.zip": "f302984f486df60d7a281e2b0a9b6d32456fc6042eb596cb5ef54ee919ccd7bb", 69 | "train_049.zip": 
"8e8e0a426796f76dfb2d29cb855894fd01cc954b017aa1d06ae1a121fb310088", 70 | "train_050.zip": "051f0dd8e612e7073dd20585c42681daeff853a6ee0de6f2e8ff4581cdf4f83b", 71 | "train_051.zip": "3f39b3732c32b960aef4bf3f152b1a72195dc4ab4bbc10116a05875ca8d40417", 72 | "train_052.zip": "361b9bcd3364c63c8f2814dfacf91489b79c9cedf03ffcb03b3dacfb77cee3a1", 73 | "train_053.zip": "f6afe23b3005b1889f76ea9c10ac42f7c4f07cefbe737781229640b834f8ede2", 74 | "train_054.zip": "ef993bd657104770df8e07a9d7c8ac1d1c3ac57b91f66796bea97f03e5a01df2", 75 | "train_055.zip": "ec0dea8199e1db7bd8e19f85b0d1a9ab9e8fc2be2c5da5b3455f96e074ad7f22", 76 | "train_056.zip": "44259829f6832c3dc14b893d5f5b7b6f784a09570f26e9cc9749807a1b05b21e", 77 | "train_057.zip": "263b712fe2ded353cb248324305f831d8b14aa0858f005067bb27e88decd7f32", 78 | "train_058.zip": "c44fb44365bc4cd8c4c9bb13d70fa9bb290708b7d3fe44fd79c6eed42702ed70", 79 | "train_059.zip": "43dd65609afb3992273f914b4d0108187f85eaf1f252f85556f10e40816d5e6c", 80 | "train_060.zip": "97b2abe90259f4629d7c1c1cec2427f155252403f5dcfea563e2d1338ae63150", 81 | "train_061.zip": "9d8c790d1806659617ddd6dd99ae56388b5eb9f311c47a079ac8fa5df8f44f57", 82 | "train_062.zip": "5b4398d6a8709ddf1b050b03b19dfe8aacf3378a4879402f457f12bd97ab99df", 83 | "train_063.zip": "05024f1b0671cb3026db0b9e801c9aab000b828784839f970a8ad0bc23125435", 84 | "train_064.zip": "b9bba3999971745ea2cdce69c00c49b109ba02c9f3169614d1d229e468bebc68", 85 | "train_065.zip": "ff4084dd7c017478b872fd7c9152df5271a7088489d3b86cc21968db272356ef", 86 | "train_066.zip": "9d8158fd6691065c1cb76ac36c3be90b065e8848856a66b10475b11e1261dd4d", 87 | "train_067.zip": "3e4b9ebef2bdecab5774a72037d9f1f7c40359e6a2d00851c0c40bdd686373c5", 88 | "train_068.zip": "a89d53ce7c79af32a659a2a59138568ada1395c56c6063f4f49c1d4e052cf9cd", 89 | "train_069.zip": "3f66206486af3f0bfa04ce8f664b6af6aa7fd2ad8ebadd5c75039de8c5ffea91", 90 | "train_070.zip": "e8a95aad5f81e7185a7dacb9031a5c27010ec17302e2e35f7f1de3dc88e02a7b", 91 | "train_071.zip": "677bf42f8d576c79189cd5af2abf420990368d9c7d768a21a10fc0939dde121f", 92 | "train_072.zip": "f8d5ea223dc13663bbaae6c5bbd732db15f1c249e7fe2da44b5a6ba5b7dbf505", 93 | "train_073.zip": "3057bda88ebd5bffb0da030d1126e1fb4fed4b5fbfc547dc0be669ece39979c1", 94 | "train_074.zip": "f3a01d19e6fedd44679d76ee93051b91b616a55b6b22861db126b8d2bfdba7ce", 95 | "train_075.zip": "0faa29f3f712f744e003da29b249896cc770fb9b357e8a4c447eeb6ad2798ce2", 96 | "train_076.zip": "d9943f9b72be89dd8f1273bd02133ab24b81e3c3f794e13362a96b0826518696", 97 | "train_077.zip": "cfab28d27c1532a91980b65baa4d40c8e13144788b9ae7a4c36ce8b909e51e55", 98 | "train_078.zip": "b06277baadbe60b2019d0f7b6ed637b23957b6320797bf4b6b9099dc4df0cc7e", 99 | "train_079.zip": "2163ef05752f7a8813fa9cd5661547bc280239fd3bd903b94a8aef37182e9645", 100 | "train_080.zip": "13ae6b86afe4aa00ce19f4f7a8df24d11742340c5775fca02f6e1f70cd9a3be7", 101 | "train_081.zip": "a2512084c16220e0acd207f5e330dd319a30c3445b5034f2c14f9a65111628a3", 102 | "train_082.zip": "d9615ac989465bc85cf990167ce176af55b8affeebb58d5021c215c1f7235c8a", 103 | "train_083.zip": "539710fcc33b043dd24499d3987852a35c8a1c5fb75f7530a9caebf57fd5f324", 104 | "train_084.zip": "33232eb1d68e493a25126f22e31326b7c1195ea511c332a1413e83a0245bdae6", 105 | "train_085.zip": "13e575f24a77278b7de25e3d186f6201692b3e45ed4701b071d5a770c0e1d590" 106 | } -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import logging 9 | from pathlib import Path 10 | from tqdm import tqdm 11 | import os 12 | import torch 13 | import torch.optim as optim 14 | 15 | from munch import DefaultMunch 16 | import json 17 | from pytorch_lightning.lite import LightningLite 18 | from torch.cuda.amp import GradScaler 19 | 20 | from dynamic_stereo.train_utils.utils import ( 21 | run_test_eval, 22 | save_ims_to_tb, 23 | count_parameters, 24 | ) 25 | from dynamic_stereo.train_utils.logger import Logger 26 | from dynamic_stereo.models.core.dynamic_stereo import DynamicStereo 27 | from dynamic_stereo.evaluation.core.evaluator import Evaluator 28 | from dynamic_stereo.train_utils.losses import sequence_loss 29 | import dynamic_stereo.datasets.dynamic_stereo_datasets as datasets 30 | 31 | 32 | def fetch_optimizer(args, model): 33 | """Create the optimizer and learning rate scheduler""" 34 | optimizer = optim.AdamW( 35 | model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8 36 | ) 37 | scheduler = optim.lr_scheduler.OneCycleLR( 38 | optimizer, 39 | args.lr, 40 | args.num_steps + 100, 41 | pct_start=0.01, 42 | cycle_momentum=False, 43 | anneal_strategy="linear", 44 | ) 45 | return optimizer, scheduler 46 | 47 | 48 | def forward_batch(batch, model, args): 49 | output = {} 50 | disparities = model( 51 | batch["img"][:, :, 0], 52 | batch["img"][:, :, 1], 53 | iters=args.train_iters, 54 | test_mode=False, 55 | ) 56 | num_traj = len(batch["disp"][0]) 57 | for i in range(num_traj): 58 | seq_loss, metrics = sequence_loss( 59 | disparities[:, i], batch["disp"][:, i, 0], batch["valid_disp"][:, i, 0] 60 | ) 61 | 62 | output[f"disp_{i}"] = {"loss": seq_loss / num_traj, "metrics": metrics} 63 | output["disparity"] = { 64 | "predictions": torch.cat( 65 | [disparities[-1, i, 0] for i in range(num_traj)], dim=1 66 | ).detach(), 67 | } 68 | return output 69 | 70 | 71 | class Lite(LightningLite): 72 | def run(self, args): 73 | self.seed_everything(0) 74 | eval_dataloader_dr = datasets.DynamicReplicaDataset( 75 | split="valid", sample_len=40, only_first_n_samples=1 76 | ) 77 | eval_dataloader_sintel_clean = datasets.SequenceSintelStereo(dstype="clean") 78 | eval_dataloader_sintel_final = datasets.SequenceSintelStereo(dstype="final") 79 | 80 | eval_dataloaders = [ 81 | ("sintel_clean", eval_dataloader_sintel_clean), 82 | ("sintel_final", eval_dataloader_sintel_final), 83 | ("dynamic_replica", eval_dataloader_dr), 84 | ] 85 | 86 | evaluator = Evaluator() 87 | 88 | eval_vis_cfg = { 89 | "visualize_interval": 1, # Use 0 for no visualization 90 | "exp_dir": args.ckpt_path, 91 | } 92 | eval_vis_cfg = DefaultMunch.fromDict(eval_vis_cfg, object()) 93 | evaluator.setup_visualization(eval_vis_cfg) 94 | 95 | model = DynamicStereo( 96 | max_disp=256, 97 | mixed_precision=args.mixed_precision, 98 | num_frames=args.sample_len, 99 | attention_type=args.attention_type, 100 | use_3d_update_block=args.update_block_3d, 101 | different_update_blocks=args.different_update_blocks, 102 | ) 103 | 104 | with open(args.ckpt_path + "/meta.json", "w") as file: 105 | json.dump(vars(args), file, sort_keys=True, indent=4) 106 | 107 | model.cuda() 108 | 109 | logging.info(f"Parameter Count: {count_parameters(model)}") 110 | 111 | train_loader = datasets.fetch_dataloader(args) 112 | train_loader = self.setup_dataloaders(train_loader, move_to_device=False) 113 | 114 | 
logging.info(f"Train loader size: {len(train_loader)}") 115 | 116 | optimizer, scheduler = fetch_optimizer(args, model) 117 | 118 | total_steps = 0 119 | logger = Logger(model, scheduler, args.ckpt_path) 120 | 121 | folder_ckpts = [ 122 | f 123 | for f in os.listdir(args.ckpt_path) 124 | if not os.path.isdir(f) and f.endswith(".pth") and not "final" in f 125 | ] 126 | if len(folder_ckpts) > 0: 127 | ckpt_path = sorted(folder_ckpts)[-1] 128 | ckpt = self.load(os.path.join(args.ckpt_path, ckpt_path)) 129 | logging.info(f"Loading checkpoint {ckpt_path}") 130 | if "model" in ckpt: 131 | model.load_state_dict(ckpt["model"]) 132 | else: 133 | model.load_state_dict(ckpt) 134 | if "optimizer" in ckpt: 135 | logging.info("Load optimizer") 136 | optimizer.load_state_dict(ckpt["optimizer"]) 137 | if "scheduler" in ckpt: 138 | logging.info("Load scheduler") 139 | scheduler.load_state_dict(ckpt["scheduler"]) 140 | if "total_steps" in ckpt: 141 | total_steps = ckpt["total_steps"] 142 | logging.info(f"Load total_steps {total_steps}") 143 | 144 | elif args.restore_ckpt is not None: 145 | assert args.restore_ckpt.endswith(".pth") or args.restore_ckpt.endswith( 146 | ".pt" 147 | ) 148 | logging.info("Loading checkpoint...") 149 | strict = True 150 | 151 | state_dict = self.load(args.restore_ckpt) 152 | if "model" in state_dict: 153 | state_dict = state_dict["model"] 154 | if list(state_dict.keys())[0].startswith("module."): 155 | state_dict = { 156 | k.replace("module.", ""): v for k, v in state_dict.items() 157 | } 158 | model.load_state_dict(state_dict, strict=strict) 159 | 160 | logging.info(f"Done loading checkpoint") 161 | model, optimizer = self.setup(model, optimizer, move_to_device=False) 162 | model.cuda() 163 | model.train() 164 | model.module.module.freeze_bn() # We keep BatchNorm frozen 165 | 166 | save_freq = args.save_freq 167 | scaler = GradScaler(enabled=args.mixed_precision) 168 | 169 | should_keep_training = True 170 | global_batch_num = 0 171 | epoch = -1 172 | while should_keep_training: 173 | epoch += 1 174 | 175 | for i_batch, batch in enumerate(tqdm(train_loader)): 176 | optimizer.zero_grad() 177 | if batch is None: 178 | print("batch is None") 179 | continue 180 | for k, v in batch.items(): 181 | batch[k] = v.cuda() 182 | 183 | assert model.training 184 | 185 | output = forward_batch(batch, model, args) 186 | 187 | loss = 0 188 | logger.update() 189 | for k, v in output.items(): 190 | if "loss" in v: 191 | loss += v["loss"] 192 | logger.writer.add_scalar( 193 | f"live_{k}_loss", v["loss"].item(), total_steps 194 | ) 195 | if "metrics" in v: 196 | logger.push(v["metrics"], k) 197 | 198 | if self.global_rank == 0: 199 | if total_steps % save_freq == save_freq - 1: 200 | save_ims_to_tb(logger.writer, batch, output, total_steps) 201 | if len(output) > 1: 202 | logger.writer.add_scalar( 203 | f"live_total_loss", loss.item(), total_steps 204 | ) 205 | logger.writer.add_scalar( 206 | f"learning_rate", optimizer.param_groups[0]["lr"], total_steps 207 | ) 208 | global_batch_num += 1 209 | self.barrier() 210 | 211 | self.backward(scaler.scale(loss)) 212 | scaler.unscale_(optimizer) 213 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 214 | 215 | scaler.step(optimizer) 216 | scheduler.step() 217 | scaler.update() 218 | total_steps += 1 219 | 220 | if self.global_rank == 0: 221 | 222 | if (i_batch >= len(train_loader) - 1) or ( 223 | total_steps == 1 and args.validate_at_start 224 | ): 225 | ckpt_iter = "0" * (6 - len(str(total_steps))) + str(total_steps) 226 | save_path = Path( 227 | 
f"{args.ckpt_path}/model_{args.name}_{ckpt_iter}.pth" 228 | ) 229 | 230 | save_dict = { 231 | "model": model.module.module.state_dict(), 232 | "optimizer": optimizer.state_dict(), 233 | "scheduler": scheduler.state_dict(), 234 | "total_steps": total_steps, 235 | } 236 | 237 | logging.info(f"Saving file {save_path}") 238 | self.save(save_dict, save_path) 239 | 240 | if epoch % args.evaluate_every_n_epoch == 0: 241 | logging.info(f"Evaluation at epoch {epoch}") 242 | run_test_eval( 243 | args.ckpt_path, 244 | "valid", 245 | evaluator, 246 | model, 247 | eval_dataloaders, 248 | logger.writer, 249 | total_steps, 250 | ) 251 | 252 | model.train() 253 | model.module.module.freeze_bn() 254 | 255 | self.barrier() 256 | if total_steps > args.num_steps: 257 | should_keep_training = False 258 | break 259 | 260 | logger.close() 261 | PATH = f"{args.ckpt_path}/{args.name}_final.pth" 262 | torch.save(model.module.module.state_dict(), PATH) 263 | 264 | test_dataloader_dr = datasets.DynamicStereoDataset( 265 | split="test", sample_len=150, only_first_n_samples=1 266 | ) 267 | test_dataloaders = [ 268 | ("sintel_clean", eval_dataloader_sintel_clean), 269 | ("sintel_final", eval_dataloader_sintel_final), 270 | ("dynamic_replica", test_dataloader_dr), 271 | ] 272 | run_test_eval( 273 | args.ckpt_path, 274 | "test", 275 | evaluator, 276 | model, 277 | test_dataloaders, 278 | logger.writer, 279 | total_steps, 280 | ) 281 | 282 | 283 | if __name__ == "__main__": 284 | parser = argparse.ArgumentParser() 285 | parser.add_argument("--name", default="dynamic-stereo", help="name your experiment") 286 | parser.add_argument("--restore_ckpt", help="restore checkpoint") 287 | parser.add_argument("--ckpt_path", help="path to save checkpoints") 288 | parser.add_argument( 289 | "--mixed_precision", action="store_true", help="use mixed precision" 290 | ) 291 | 292 | # Training parameters 293 | parser.add_argument( 294 | "--batch_size", type=int, default=6, help="batch size used during training." 295 | ) 296 | parser.add_argument( 297 | "--train_datasets", 298 | nargs="+", 299 | default=["things", "monkaa", "driving"], 300 | help="training datasets.", 301 | ) 302 | parser.add_argument("--lr", type=float, default=0.0002, help="max learning rate.") 303 | 304 | parser.add_argument( 305 | "--num_steps", type=int, default=100000, help="length of training schedule." 306 | ) 307 | parser.add_argument( 308 | "--image_size", 309 | type=int, 310 | nargs="+", 311 | default=[320, 720], 312 | help="size of the random image crops used during training.", 313 | ) 314 | parser.add_argument( 315 | "--train_iters", 316 | type=int, 317 | default=16, 318 | help="number of updates to the disparity field in each forward pass.", 319 | ) 320 | parser.add_argument( 321 | "--wdecay", type=float, default=0.00001, help="Weight decay in optimizer." 322 | ) 323 | 324 | parser.add_argument( 325 | "--sample_len", type=int, default=2, help="length of training video samples" 326 | ) 327 | parser.add_argument( 328 | "--validate_at_start", action="store_true", help="validate the model at start" 329 | ) 330 | parser.add_argument("--save_freq", type=int, default=100, help="save frequency") 331 | 332 | parser.add_argument( 333 | "--evaluate_every_n_epoch", 334 | type=int, 335 | default=1, 336 | help="evaluate every n epoch", 337 | ) 338 | 339 | parser.add_argument( 340 | "--num_workers", type=int, default=6, help="number of dataloader workers." 
341 | ) 342 | # Validation parameters 343 | parser.add_argument( 344 | "--valid_iters", 345 | type=int, 346 | default=32, 347 | help="number of updates to the disparity field in each forward pass during validation.", 348 | ) 349 | # Architecture choices 350 | parser.add_argument( 351 | "--different_update_blocks", 352 | action="store_true", 353 | help="use different update blocks for each resolution", 354 | ) 355 | parser.add_argument( 356 | "--attention_type", 357 | type=str, 358 | help="attention type of the SST and update blocks. \ 359 | Any combination of 'self_stereo', 'temporal', 'update_time', 'update_space' connected by an underscore.", 360 | ) 361 | parser.add_argument( 362 | "--update_block_3d", action="store_true", help="use Conv3D update block" 363 | ) 364 | # Data augmentation 365 | parser.add_argument( 366 | "--img_gamma", type=float, nargs="+", default=None, help="gamma range" 367 | ) 368 | parser.add_argument( 369 | "--saturation_range", 370 | type=float, 371 | nargs="+", 372 | default=None, 373 | help="color saturation", 374 | ) 375 | parser.add_argument( 376 | "--do_flip", 377 | default=False, 378 | choices=["h", "v"], 379 | help="flip the images horizontally or vertically", 380 | ) 381 | parser.add_argument( 382 | "--spatial_scale", 383 | type=float, 384 | nargs="+", 385 | default=[0, 0], 386 | help="re-scale the images randomly", 387 | ) 388 | parser.add_argument( 389 | "--noyjitter", 390 | action="store_true", 391 | help="don't simulate imperfect rectification", 392 | ) 393 | args = parser.parse_args() 394 | 395 | logging.basicConfig( 396 | level=logging.INFO, 397 | format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s", 398 | ) 399 | 400 | Path(args.ckpt_path).mkdir(exist_ok=True, parents=True) 401 | from pytorch_lightning.strategies import DDPStrategy 402 | 403 | Lite( 404 | strategy=DDPStrategy(find_unused_parameters=True), 405 | devices="auto", 406 | accelerator="gpu", 407 | precision=32, 408 | ).run(args) 409 | -------------------------------------------------------------------------------- /train_utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
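# Module overview:
# Logger wraps a TensorBoard SummaryWriter stored under <ckpt_path>/runs.
# push(metrics, task) accumulates metric values into per-task running sums;
# update() advances the step counter and, every SUM_FREQ steps, averages the
# accumulated values, logs them to TensorBoard and the Python logger, and
# resets them; write_dict(results) logs a dict of scalars at the current step.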
6 | 7 | import logging 8 | import os 9 | 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | 13 | class Logger: 14 | 15 | SUM_FREQ = 100 16 | 17 | def __init__(self, model, scheduler, ckpt_path): 18 | self.model = model 19 | self.scheduler = scheduler 20 | self.total_steps = 0 21 | self.running_loss = {} 22 | self.ckpt_path = ckpt_path 23 | self.writer = SummaryWriter(log_dir=os.path.join(self.ckpt_path, "runs")) 24 | 25 | def _print_training_status(self): 26 | metrics_data = [ 27 | self.running_loss[k] / Logger.SUM_FREQ 28 | for k in sorted(self.running_loss.keys()) 29 | ] 30 | training_str = "[{:6d}] ".format(self.total_steps + 1) 31 | metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data) 32 | 33 | # print the training status 34 | logging.info( 35 | f"Training Metrics ({self.total_steps}): {training_str + metrics_str}" 36 | ) 37 | 38 | if self.writer is None: 39 | self.writer = SummaryWriter(log_dir=os.path.join(self.ckpt_path, "runs")) 40 | for k in self.running_loss: 41 | self.writer.add_scalar( 42 | k, self.running_loss[k] / Logger.SUM_FREQ, self.total_steps 43 | ) 44 | self.running_loss[k] = 0.0 45 | 46 | def push(self, metrics, task): 47 | for key in metrics: 48 | task_key = str(key) + "_" + task 49 | if task_key not in self.running_loss: 50 | self.running_loss[task_key] = 0.0 51 | self.running_loss[task_key] += metrics[key] 52 | 53 | def update(self): 54 | self.total_steps += 1 55 | if self.total_steps % Logger.SUM_FREQ == Logger.SUM_FREQ - 1: 56 | self._print_training_status() 57 | self.running_loss = {} 58 | 59 | def write_dict(self, results): 60 | if self.writer is None: 61 | self.writer = SummaryWriter(log_dir=os.path.join(self.ckpt_path, "runs")) 62 | 63 | for key in results: 64 | self.writer.add_scalar(key, results[key], self.total_steps) 65 | 66 | def close(self): 67 | self.writer.close() 68 | -------------------------------------------------------------------------------- /train_utils/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
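# Module overview:
# sequence_loss() computes a weighted L1 loss over the per-iteration
# disparity/flow predictions of the recurrent model. Predictions are weighted
# exponentially so that later refinement iterations contribute more:
#     i_weight = adjusted_loss_gamma ** (n_predictions - i - 1),
#     adjusted_loss_gamma = loss_gamma ** (15 / (n_predictions - 1)).
# With this adjustment the first-to-last weight ratio is loss_gamma ** 15
# (about 0.206 for the default loss_gamma = 0.9) regardless of the number of
# iterations: 16 predictions give weights 0.9**15 ... 0.9**0, while 4
# predictions give (0.9**5)**3 ... (0.9**5)**0, i.e. the same overall range.
# Pixels with valid < 0.5 or ground-truth magnitude >= max_flow are excluded.
# The returned metrics (epe, 1px, 3px, 5px) are computed from the final
# prediction only.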
6 | 7 | import torch 8 | 9 | def sequence_loss(flow_preds, flow_gt, valid, loss_gamma=0.9, max_flow=700): 10 | """Loss function defined over sequence of flow predictions""" 11 | n_predictions = len(flow_preds) 12 | assert n_predictions >= 1 13 | flow_loss = 0.0 14 | # exclude invalid pixels and extremely large displacements 15 | mag = torch.sum(flow_gt ** 2, dim=1).sqrt().unsqueeze(1) 16 | 17 | if len(valid.shape) != len(flow_gt.shape): 18 | valid = valid.unsqueeze(1) 19 | 20 | valid = (valid >= 0.5) & (mag < max_flow) 21 | 22 | if valid.shape != flow_gt.shape: 23 | valid = torch.cat([valid, valid], dim=1) 24 | assert valid.shape == flow_gt.shape, [valid.shape, flow_gt.shape] 25 | assert not torch.isinf(flow_gt[valid.bool()]).any() 26 | 27 | for i in range(n_predictions): 28 | assert ( 29 | not torch.isnan(flow_preds[i]).any() 30 | and not torch.isinf(flow_preds[i]).any() 31 | ) 32 | 33 | if n_predictions == 1: 34 | i_weight = 1 35 | else: 36 | # We adjust the loss_gamma so it is consistent for any number of iterations 37 | adjusted_loss_gamma = loss_gamma ** (15 / (n_predictions - 1)) 38 | i_weight = adjusted_loss_gamma ** (n_predictions - i - 1) 39 | 40 | flow_pred = flow_preds[i].clone() 41 | if valid.shape[1] == 1 and flow_preds[i].shape[1] == 2: 42 | flow_pred = flow_pred[:, :1] 43 | 44 | i_loss = (flow_pred - flow_gt).abs() 45 | 46 | assert i_loss.shape == valid.shape, [ 47 | i_loss.shape, 48 | valid.shape, 49 | flow_gt.shape, 50 | flow_pred.shape, 51 | ] 52 | flow_loss += i_weight * i_loss[valid.bool()].mean() 53 | 54 | epe = torch.sum((flow_preds[-1] - flow_gt) ** 2, dim=1).sqrt() 55 | 56 | valid = valid[:, 0] 57 | epe = epe.view(-1) 58 | epe = epe[valid.reshape(epe.shape)] 59 | 60 | metrics = { 61 | "epe": epe.mean().item(), 62 | "1px": (epe < 1).float().mean().item(), 63 | "3px": (epe < 3).float().mean().item(), 64 | "5px": (epe < 5).float().mean().item(), 65 | } 66 | return flow_loss, metrics 67 | -------------------------------------------------------------------------------- /train_utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
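# Module overview:
# count_parameters() returns the number of trainable model parameters.
# run_test_eval() evaluates the (DDP-wrapped) model on three real-world
# dynamic_replica sequences and on the provided dataloaders, writes selected
# disparity/flow metrics to TensorBoard, and dumps the aggregated results to
# result_<dataset>_<eval_type>_<step>_mimo.json inside ckpt_path.
# fig2data() rasterizes a matplotlib figure into an (H, W, 3) RGB numpy array.
# save_ims_to_tb() logs the stereo inputs, the validity-masked ground-truth
# disparity and the per-output predictions to TensorBoard; "disparity" outputs
# are rendered with matplotlib, other outputs with the flow_vis color wheel.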
6 | 7 | import numpy as np 8 | import os 9 | import torch 10 | 11 | import json 12 | import flow_vis 13 | import matplotlib.pyplot as plt 14 | 15 | import dynamic_stereo.datasets.dynamic_stereo_datasets as datasets 16 | from dynamic_stereo.evaluation.utils.utils import aggregate_and_print_results 17 | 18 | 19 | def count_parameters(model): 20 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 21 | 22 | 23 | def run_test_eval(ckpt_path, eval_type, evaluator, model, dataloaders, writer, step): 24 | for real_sequence_name in ["teddy_static", "ignacio_waving", "nikita_reading"]: 25 | seq_len_real = 50 26 | ds_path = f"./dynamic_replica_data/real/{real_sequence_name}" 27 | real_dataset = datasets.DynamicReplicaDataset( 28 | split="test", root=ds_path, sample_len=seq_len_real, only_first_n_samples=1 29 | ) 30 | 31 | evaluator.evaluate_sequence( 32 | model=model.module.module, 33 | test_dataloader=real_dataset, 34 | writer=writer, 35 | step=step, 36 | train_mode=True, 37 | ) 38 | 39 | for ds_name, dataloader in dataloaders: 40 | evaluator.visualize_interval = 1 if not "sintel" in ds_name else 0 41 | 42 | evaluate_result = evaluator.evaluate_sequence( 43 | model=model.module.module, 44 | test_dataloader=dataloader, 45 | writer=writer if not "sintel" in ds_name else None, 46 | step=step, 47 | train_mode=True, 48 | ) 49 | 50 | aggregate_result = aggregate_and_print_results( 51 | evaluate_result, 52 | ) 53 | 54 | save_metrics = [ 55 | "flow_mean_accuracy_5px", 56 | "flow_mean_accuracy_3px", 57 | "flow_mean_accuracy_1px", 58 | "flow_epe_traj_mean", 59 | ] 60 | for epe_name in ("epe", "temp_epe", "temp_epe_r"): 61 | for m in [ 62 | f"disp_{epe_name}_bad_0.5px", 63 | f"disp_{epe_name}_bad_1px", 64 | f"disp_{epe_name}_bad_2px", 65 | f"disp_{epe_name}_bad_3px", 66 | f"disp_{epe_name}_mean", 67 | ]: 68 | save_metrics.append(m) 69 | 70 | for k, v in aggregate_result.items(): 71 | if k in save_metrics: 72 | writer.add_scalars( 73 | f"{ds_name}_{k.rsplit('_', 1)[0]}", 74 | {f"{ds_name}_{k}": v}, 75 | step, 76 | ) 77 | 78 | result_file = os.path.join( 79 | ckpt_path, 80 | f"result_{ds_name}_{eval_type}_{step}_mimo.json", 81 | ) 82 | print(f"Dumping {eval_type} results to {result_file}.") 83 | with open(result_file, "w") as f: 84 | json.dump(aggregate_result, f) 85 | 86 | 87 | def fig2data(fig): 88 | """ 89 | fig = plt.figure() 90 | image = fig2data(fig) 91 | @brief Convert a Matplotlib figure to a 4D numpy array with RGBA channels and return it 92 | @param fig a matplotlib figure 93 | @return a numpy 3D array of RGBA values 94 | """ 95 | import PIL.Image as Image 96 | 97 | # draw the renderer 98 | fig.canvas.draw() 99 | 100 | # Get the RGBA buffer from the figure 101 | w, h = fig.canvas.get_width_height() 102 | buf = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 103 | buf.shape = (w, h, 3) 104 | 105 | image = Image.frombytes("RGB", (w, h), buf.tobytes()) 106 | image = np.asarray(image) 107 | return image 108 | 109 | 110 | def save_ims_to_tb(writer, batch, output, total_steps): 111 | writer.add_image( 112 | "train_im", 113 | torch.cat([torch.cat([im[0], im[1]], dim=-1) for im in batch["img"][0]], dim=-2) 114 | / 255.0, 115 | total_steps, 116 | dataformats="CHW", 117 | ) 118 | if "disp" in batch and len(batch["disp"]) > 0: 119 | disp_im = [ 120 | (torch.cat([im[0], im[1]], dim=-1) * torch.cat([val[0], val[1]], dim=-1)) 121 | for im, val in zip(batch["disp"][0], batch["valid_disp"][0]) 122 | ] 123 | 124 | disp_im = torch.cat(disp_im, dim=1) 125 | 126 | figure = plt.figure() 127 | 
plt.imshow(disp_im.cpu()[0]) 128 | disp_im = fig2data(figure).copy() 129 | 130 | writer.add_image( 131 | "train_disp", 132 | disp_im, 133 | total_steps, 134 | dataformats="HWC", 135 | ) 136 | 137 | for k, v in output.items(): 138 | if "predictions" in v: 139 | pred = v["predictions"] 140 | if k == "disparity": 141 | figure = plt.figure() 142 | plt.imshow(pred.cpu()[0]) 143 | pred = fig2data(figure).copy() 144 | dataformat = "HWC" 145 | else: 146 | pred = torch.tensor( 147 | flow_vis.flow_to_color( 148 | pred.permute(1, 2, 0).cpu().numpy(), convert_to_bgr=False 149 | ) 150 | / 255.0 151 | ) 152 | dataformat = "HWC" 153 | writer.add_image( 154 | f"pred_{k}", 155 | pred, 156 | total_steps, 157 | dataformats=dataformat, 158 | ) 159 | if "gt" in v: 160 | gt = v["gt"] 161 | gt = torch.tensor( 162 | flow_vis.flow_to_color( 163 | gt.permute(1, 2, 0).cpu().numpy(), convert_to_bgr=False 164 | ) 165 | / 255.0 166 | ) 167 | dataformat = "HWC" 168 | writer.add_image( 169 | f"gt_{k}", 170 | gt, 171 | total_steps, 172 | dataformats=dataformat, 173 | ) 174 | --------------------------------------------------------------------------------
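Usage sketch (not part of the repository): a minimal, self-contained example of calling sequence_loss from train_utils/losses.py on dummy tensors. The tensor shapes are assumptions chosen only to satisfy the shape checks inside the function, and the dynamic_stereo.train_utils.losses import path mirrors how train.py imports the package.

import torch

from dynamic_stereo.train_utils.losses import sequence_loss

# Dummy data: four refinement iterations of a [batch, 1, H, W] disparity field.
n_iters, batch, h, w = 4, 2, 32, 64
flow_preds = [torch.rand(batch, 1, h, w) for _ in range(n_iters)]  # per-iteration predictions
flow_gt = torch.rand(batch, 1, h, w) * 10.0                        # ground-truth disparity
valid = torch.ones(batch, 1, h, w)                                 # all pixels marked valid

loss, metrics = sequence_loss(flow_preds, flow_gt, valid)
print(loss.item())  # weighted L1 loss over all four predictions
print(metrics)      # {'epe': ..., '1px': ..., '3px': ..., '5px': ...}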