├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── assets └── teaser.png ├── requirements └── data_engine.yaml └── spatial_engine ├── camera_movement ├── TEMPLATES.py ├── calculate_frames_relations.py └── camera_movement_engine_train_val.py ├── depth_perception ├── depth_comparison_coor_engine.py ├── depth_comparison_dot_engine.py ├── depth_estimation_coor_engine.py └── depth_estimation_dot_engine.py ├── object_movement ├── single_object_movement_engine_coord.py └── single_object_movement_engine_dot.py ├── object_perception ├── compute_object_visibility.py ├── find_object_coverage.sh ├── merge_object_coverage.py ├── single_object_coverage_finder.py └── single_object_perception_engine.py ├── utils └── scannet_utils │ ├── README.md │ ├── batch_load_scannet_data.py │ ├── extract_posed_images.py │ ├── handler │ ├── info_handler.py │ └── ops.py │ ├── make_visibility_info.py │ ├── scannet_utils.py │ └── update_info_file_with_images.py └── visual_correspondence ├── visual_correspondence_qa_engine_coor_2_coor.py └── visual_correspondence_qa_engine_dot_2_multichoice.py /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | __pycache__ 3 | *.pkl 4 | *.parquet 5 | 6 | evaluation_data 7 | training_data 8 | tmp 9 | visualization_results 10 | Book2.csv 11 | group_acc_results.csv 12 | 13 | training_data 14 | evaluation_data 15 | 16 | .vscode 17 | .env 18 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 
39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Multi-SpatialMLLM 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Meta's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 
28 | 29 | ## License 30 | By contributing to Multi-SpatialMLLM, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. 
More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. 
For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. 
Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. 
for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. 
For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the "Licensor." The text of the Creative Commons 394 | public licenses is dedicated to the public domain under the CC0 Public 395 | Domain Dedication. Except for the limited purpose of indicating that 396 | material is shared under a Creative Commons public license or as 397 | otherwise permitted by the Creative Commons policies published at 398 | creativecommons.org/policies, Creative Commons does not authorize the 399 | use of the trademark "Creative Commons" or any other trademark or logo 400 | of Creative Commons without its prior written consent including, 401 | without limitation, in connection with any unauthorized modifications 402 | to any of its public licenses or any other arrangements, 403 | understandings, or agreements concerning use of licensed material. For 404 | the avoidance of doubt, this paragraph does not form part of the 405 | public licenses. 406 | 407 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | # Multi-SpatialMLLM: Multi-Frame Spatial Understanding with Multi-Modal Large Language Models
3 |
4 | Runsen Xu, Weiyao Wang, Hao Tang, Xingyu Chen, Xiaodong Wang, Fu-Jen Chu,
11 | Dahua Lin, Matt Feiszli, Kevin J. Liang
15 | FAIR, Meta; The Chinese University of Hong Kong
16 |
34 | ## 🏠 About
35 |
37 | ![Dialogue_Teaser](assets/teaser.png)
38 |
39 | We present Multi-SpatialMLLM to equip MLLMs with robust multi-frame spatial understanding by integrating depth perception, visual correspondence, and dynamic perception. Central to our approach is the MultiSPA dataset, a novel, large-scale collection of more than 27 million samples spanning diverse 3D and 4D scenes. Alongside MultiSPA, we introduce a comprehensive benchmark that tests a wide spectrum of spatial tasks under uniform metrics. Our model achieves significant gains over baselines and proprietary systems, demonstrating scalable, generalizable multi-frame reasoning. We further observe multi-task benefits and early indications of emergent capabilities in challenging scenarios, and showcase our model can serve as a multi-frame reward annotator for robotics. 40 | 41 | 42 | ## 🔥 News 43 | - [2025-05-22] We release the [paper](https://arxiv.org/abs/2505.17015) of Multi-SpatialMLLM and the codes of our data engine. 🎉 44 | 45 | ## 📋 Contents 46 | - [📦 Data Engine](#-data-engine) 47 | - [🏋️‍♂️ Model Training](#-model-training) 48 | - [🔗 Citation](#-citation) 49 | - [📄 License](#-license) 50 | - [👥 Contributing](#-contributing) 51 | 52 | ## Data Engine 53 | ### Environment 54 | To set up the Conda environment for our data engine, please follow these steps: 55 | 56 | 1. Ensure you have [Anaconda](https://www.anaconda.com/products/distribution) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html) installed. 57 | 2. Clone this repository. 58 | ```bash 59 | git clone https://github.com/facebookresearch/Multi-SpatialMLLM.git 60 | cd Multi-SpatialMLLM 61 | ``` 62 | 3. Create the Conda environment using the provided YAML file: 63 | ```bash 64 | conda env create -f requirements/data_engine.yaml 65 | ``` 66 | 4. Activate the newly created environment: 67 | ```bash 68 | conda activate data_engine 69 | ``` 70 | 71 | ### Data Preparation 72 | #### ScanNet 73 | Please follow [spatial_engine/utils/scannet_utils/README.md](spatial_engine/utils/scannet_utils/README.md) to download and process the ScanNet data. 74 | 75 | #### TAPVid-3D 76 | Follow [TAPVid-3D](https://github.com/google-deepmind/tapnet/tree/main/tapnet/tapvid3d) to download the data. We only use the ADT and PStudio subsets. You need to download the version with camera extrinsics annotation according to [this](https://github.com/google-deepmind/tapnet/issues/115). 77 |
78 | Here are some notes for your reference, but simply following the official script is enough. 79 | 80 | - Change the download url in `tapnet/tapnet/tapvid3d/annotation_generation/gcs_utils.py` from `https://storage.googleapis.com/dm-tapnet/tapvid3d/release_files/rc4` to `"https://storage.googleapis.com/dm-tapnet/tapvid3d/release_files/rc5"`. Also, modify `tapnet/tapnet/tapvid3d/annotation_generation/adt_utils.py` to store the extrinsics_w2c in the npz file like below. 81 | ```python 82 | # also add this for warning 83 | sequence_path = os.path.join(input_adt_path, adt_v2_name) 84 | # if the sequence_path does not exist, write to a warning file and exit 85 | if not os.path.exists(sequence_path): 86 | with open(f"adt_warning.txt", "a") as f: 87 | f.write(f"Sequence {seq_name} does not exist.") 88 | return 89 | ... 90 | ... 91 | 92 | queries_xyt = in_npz["queries_xyt"] 93 | trajectories = in_npz["tracks_XYZ"] 94 | visibilities = in_npz["visibility"] 95 | extrinsics_w2c = in_npz["extrinsics_w2c"] # add this 96 | 97 | # Verify video means. 98 | video_means = np.stack([np.mean(x, axis=(0, 1)) for x in rgb_ims], axis=0) 99 | assert np.allclose(video_means, in_npz["video_means"], atol=1e-3) 100 | 101 | example = { 102 | "images_jpeg_bytes": rgb_jpegs, 103 | "queries_xyt": queries_xyt, 104 | "tracks_XYZ": trajectories, 105 | "visibility": visibilities, 106 | "fx_fy_cx_cy": np.array( 107 | [FOCAL_LENGTH, FOCAL_LENGTH, WIDTH / 2, HEIGHT / 2] 108 | ), 109 | "extrinsics_w2c": extrinsics_w2c # add this 110 | } 111 | ``` 112 | - Each npz file from TAPVid-3D contains the following keys: 113 | ```python 114 | """ 115 | Each *.npz file contains: 116 | 117 | images_jpeg_bytes: tensor of shape [# of frames, height, width, 3], each frame stored as JPEG bytes that must be decoded 118 | intrinsics: (fx, fy, cx, cy) camera intrinsics of the video 119 | tracks_xyz: tensor of shape (# of frames, # of point tracks, 3), representing the 3D point trajectories and the last dimension is the (x, y, z) point position in meters. They are in camera coordinates. 120 | visibility: tensor of shape (# of frames, # of point tracks), representing the visibility of each point along its trajectory 121 | queries_xyt: tensor of shape (# of point tracks, 3), representing the query point used in the benchmark as the initial given point to track. The last dimension is given in (x, y, t), where x,y are the pixel location of the query point and t is the query frame. 122 | extrinsics_w2c: tensor of shape (#, 4, 4) 123 | """ 124 | ``` 125 | - For Pstudio, after running the official script, you will have a `tmp` folder inside, which is used to store the original video (image) from the pstudio dataset. You can just omit this folder. 126 | - For ADT 127 | - Need to download the original files from the Project Aria website and place the data in `data/projectaria_tools_adt_data`. 128 | ``` 129 | pip install projectaria-tools'[all]' 130 | 131 | # get a adt_download_urls.json from the official website 132 | mkdir data/projectaria_tools_adt_data 133 | mv adt_download_urls.json data/projectaria_tools_adt_data 134 | 135 | # download the data with all the types, it costs 1.4 T in total. 136 | aria_dataset_downloader -c data/projectaria_tools_adt_data/adt_download_urls.json -o data/projectaria_tools_adt_data/ -l all 137 | ``` 138 | - Then run the official script to download the query points and postprocess them to store image info inside the npz file. 
139 | ``` 140 | cd tapnet 141 | ADT_OUTPUT_DIRECTORY="tapvid3d_dataset/adt/"; mkdir -p $ADT_OUTPUT_DIRECTORY 142 | PYTHON_DEBUG="False" 143 | conda activate projectaria # if applicable, use a new env 144 | python3 -m tapnet.tapvid3d.annotation_generation.generate_adt --output_dir=$ADT_OUTPUT_DIRECTORY --debug=$PYTHON_DEBUG --split=all --adt_base_path data/projectaria_tools_adt_data 145 | ``` 146 | Specifically, in `tapnet/tapnet/tapvid3d/annotation_generation/generate_adt.py`: 147 | ``` 148 | gcs_utils.download_tapvid3d_files(tmp_adt_dir, _SPLIT.value, "adt", _DEBUG.value) 149 | ``` 150 | This function downloads the npz files from the given URL to `tmp` (about 11 GB), and the following function merges the images/videos from `adt_base_path` into the npz files. 151 | ``` 152 | generate_adt_npz(_ADT_BASE_PATH.value, tmp_adt_dir, _OUTPUT_DIR.value) 153 | ``` 154 | A short sanity check for the resulting npz files is sketched below. 155
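To confirm that a processed npz file matches the keys listed above (including the `extrinsics_w2c` entry added by the `adt_utils.py` modification), you can run a minimal check like the one below. This is only a sketch: the file path is a placeholder, and the expected key names and shapes are taken from the example dictionary and key list shown earlier.

```python
# Sanity-check one processed TAPVid-3D npz file (sketch; the path is a placeholder).
import numpy as np

npz_path = "data/tapvid3d_dataset/adt/scene_id.npz"  # replace scene_id with a real file name
with np.load(npz_path, allow_pickle=True) as f:
    num_frames = len(f["images_jpeg_bytes"])
    num_tracks = f["tracks_XYZ"].shape[1]
    print(f"{num_frames} frames, {num_tracks} point tracks")
    print("fx_fy_cx_cy:", f["fx_fy_cx_cy"])
    # Shapes as documented in the key list above.
    assert f["tracks_XYZ"].shape == (num_frames, num_tracks, 3)
    assert f["visibility"].shape == (num_frames, num_tracks)
    assert f["queries_xyt"].shape == (num_tracks, 3)
    # Present only if adt_utils.py was modified to store the extrinsics.
    assert f["extrinsics_w2c"].shape == (num_frames, 4, 4)
```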
156 | 157 | Finally, we assume the structure of the data is like: 158 | ```bash 159 | data/tapvid3d_dataset 160 | ├── adt 161 | │ ├── "scene_id".npz 162 | ├── pstudio 163 | │ ├── "scene_id".npz 164 | ``` 165 | 166 | ### Data Generation 167 | We generate the data based on the conversation format of [InternVL](https://internvl.readthedocs.io/en/latest/get_started/chat_data_format.html#multi-image-data). You could easily change the generated jsonl file to the format of your own. 168 | #### Camera Movement 169 | 1. Run `python spatial_engine/camera_movement/calculate_frames_relations.py` to calculate the spatial relations between frames, e.g. their overlap ratios. After running this script, a parquet containing this spatial information will be generated in `training_data/camera_movement` and `evaluation_data/camera_movement`. 170 | 171 | 2. Then run `python spatial_engine/camera_movement/camera_movement_engine_train_val.py` to generate the training and evaluation data. 172 | 173 | #### Depth Perception 174 | 1. Run `python spatial_engine/utils/scannet_utils/make_visibility_info.py` to compute the visibility information for each frame. Note that when using this information, loading the file takes a long time, about several minutes. 175 | 2. Run `python spatial_engine/depth_perception/depth_estimation_dot_engine.py` to generate the training and evaluation data for visual-based depth estimation. 176 | 3. Run `python spatial_engine/depth_perception/depth_estimation_coor_engine.py` to generate the training and evaluation data for coordinate-based depth estimation. 177 | 4. Run `python spatial_engine/depth_perception/depth_comparison_dot_engine.py` to generate the training and evaluation data for visual-based depth comparison. 178 | 5. Run `python spatial_engine/depth_perception/depth_comparison_coor_engine.py` to generate the training and evaluation data for coordinate-based depth comparison. 179 | 180 | #### Visual Correspondence 181 | 1. Run `python spatial_engine/visual_correspondence/visual_correspondence_qa_engine_dot_2_multichoice.py` to generate the training and evaluation data for visual correspondence in dot-based multichoice format. 182 | 2. Run `python spatial_engine/visual_correspondence/visual_correspondence_qa_engine_coor_2_coor.py` to generate the training and evaluation data for visual correspondence in coordinate-based format. 183 | 184 | #### Object Perception 185 | 1. Run `python spatial_engine/object_perception/compute_object_visibility.py` to compute the visibility information for each object. After running this script, a pkl file containing this visibility information will be saved to `training_data/object_perception` and `evaluation_data/object_perception`. 186 | 2. Run `bash find_object_coverage.sh` to compute the coverage information for each object in each scene. You could check `spatial_engine/object_perception/single_object_coverage_finder.py` to modify the parameters and run it with several processes. 187 | 3. After generating all the coverage information for each scene, run `python spatial_engine/object_perception/merge_object_coverage.py` to merge the coverage information. 188 | 4. Run `python spatial_engine/object_perception/single_object_perception_engine.py` to generate the training and evaluation data for object perception. 189 | 190 | #### Object Movement 191 | 1. Run `python spatial_engine/object_movement/single_object_movement_engine_coord.py` to generate the training and evaluation data for object movement in coordinate-based format. 
After running this script, images will be extracted from the npz file and saved to `data/my_tapvid3d_images`. 192 | 2. Run `python spatial_engine/object_movement/single_object_movement_engine_dot.py` to generate the training and evaluation data for object movement in dot-based format. 193 | 194 | ## 🏋️‍♂️ Model Training 195 | 196 | We use the [InternVL-2](https://internvl.readthedocs.io/en/latest/internvl2.0/finetune.html#start-2nd-fine-tuning) models for experiments in our paper. You could follow their official instructions to easily fine-tune the models with the generated data and reproduce our results. Other VLMs can also be used. Below are some training details used in our experiments, and more can be found in our paper. 197 | - All images should be resized to `H*W=1296*968` for training. 198 | - Different from the original InternVL setting of dynamically allocating 12 image tiles to all images, we make sure each image can use up to 6 image tiles for training and evaluation. Please change this [line](https://github.com/OpenGVLab/InternVL/blob/dd3635206874c92386185d586fffeda1026d3a76/internvl_chat/internvl/train/internvl_chat_finetune.py#L488) to `max_num=self.max_dynamic_patch`. Pay attention to GPU OOM issues, and you may change the `--max_seq_length` to `8192`. 199 | - The training config used for our main paper is in `data/configs/mix3M.json`. Note that this config only uses 3M training samples, and we use LoRA training for research efficiency. You could use more data and fully fine-tune the whole model to get much better performance. 200 | - To preserve the original ability of the model, some general instruction-following data should be added to the training data. 201 | 202 | ## 🔗 Citation 203 | 204 | If you find our work and this codebase helpful, please consider starring this repo 🌟 and cite: 205 | 206 | ```bibtex 207 | @article{xu2025multi, 208 | title={Multi-SpatialMLLM: Multi-Frame Spatial Understanding with Multi-Modal Large Language Models}, 209 | author={Xu, Runsen and Wang, Weiyao and Tang, Hao and Chen, Xingyu and Wang, Xiaodong and Chu, Fu-Jen and Lin, Dahua and Feiszli, Matt and Liang, Kevin J.}, 210 | journal={arXiv preprint arXiv:2505.17015}, 211 | year={2025} 212 | } 213 | ``` 214 | 215 | ## 📄 License 216 | 217 | Shield: [![CC BY-NC 4.0][cc-by-nc-shield]][cc-by-nc] 218 | 219 | This work is licensed under a 220 | [Creative Commons Attribution-NonCommercial 4.0 International License][cc-by-nc]. 221 | 222 | [![CC BY-NC 4.0][cc-by-nc-image]][cc-by-nc] 223 | 224 | [cc-by-nc]: https://creativecommons.org/licenses/by-nc/4.0/ 225 | [cc-by-nc-image]: https://licensebuttons.net/l/by-nc/4.0/88x31.png 226 | [cc-by-nc-shield]: https://img.shields.io/badge/License-CC%20BY--NC%204.0-lightgrey.svg 227 | 228 | ## 👥 Contributing 229 | 230 | See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md). 
231 | -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/Multi-SpatialMLLM/2f2eda329f9aa047428413d77e925271db5d4f33/assets/teaser.png -------------------------------------------------------------------------------- /requirements/data_engine.yaml: -------------------------------------------------------------------------------- 1 | name: data_engine 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=conda_forge 7 | - _openmp_mutex=4.5=2_gnu 8 | - asttokens=3.0.0=pyhd8ed1ab_1 9 | - bzip2=1.0.8=h5eee18b_6 10 | - ca-certificates=2024.12.31=h06a4308_0 11 | - comm=0.2.2=pyhd8ed1ab_1 12 | - debugpy=1.8.12=py311hfdbb021_0 13 | - decorator=5.1.1=pyhd8ed1ab_1 14 | - exceptiongroup=1.2.2=pyhd8ed1ab_1 15 | - executing=2.1.0=pyhd8ed1ab_1 16 | - importlib-metadata=8.6.1=pyha770c72_0 17 | - ipykernel=6.29.5=pyh3099207_0 18 | - ipython=8.31.0=pyh707e725_0 19 | - jedi=0.19.2=pyhd8ed1ab_1 20 | - jupyter_client=8.6.3=pyhd8ed1ab_1 21 | - jupyter_core=5.7.2=pyh31011fe_1 22 | - krb5=1.21.3=h143b758_0 23 | - ld_impl_linux-64=2.40=h12ee557_0 24 | - libedit=3.1.20230828=h5eee18b_0 25 | - libffi=3.4.4=h6a678d5_1 26 | - libgcc=14.2.0=h77fa898_1 27 | - libgcc-ng=14.2.0=h69a702a_1 28 | - libgomp=14.2.0=h77fa898_1 29 | - libsodium=1.0.20=h4ab18f5_0 30 | - libstdcxx=14.2.0=hc0a3c3a_1 31 | - libstdcxx-ng=11.2.0=h1234567_1 32 | - libuuid=1.41.5=h5eee18b_0 33 | - matplotlib-inline=0.1.7=pyhd8ed1ab_1 34 | - ncurses=6.4=h6a678d5_0 35 | - nest-asyncio=1.6.0=pyhd8ed1ab_1 36 | - openssl=3.4.0=h7b32b05_1 37 | - packaging=24.2=pyhd8ed1ab_2 38 | - parso=0.8.4=pyhd8ed1ab_1 39 | - pexpect=4.9.0=pyhd8ed1ab_1 40 | - pickleshare=0.7.5=pyhd8ed1ab_1004 41 | - pip=24.2=py311h06a4308_0 42 | - platformdirs=4.3.6=pyhd8ed1ab_1 43 | - prompt-toolkit=3.0.50=pyha770c72_0 44 | - psutil=6.1.1=py311h9ecbd09_0 45 | - ptyprocess=0.7.0=pyhd8ed1ab_1 46 | - pure_eval=0.2.3=pyhd8ed1ab_1 47 | - pygments=2.19.1=pyhd8ed1ab_0 48 | - python=3.11.11=he870216_0 49 | - python-dateutil=2.9.0.post0=pyhff2d567_1 50 | - python_abi=3.11=2_cp311 51 | - pyzmq=26.2.0=py311h7deb3e3_3 52 | - readline=8.2=h5eee18b_0 53 | - setuptools=75.1.0=py311h06a4308_0 54 | - six=1.17.0=pyhd8ed1ab_0 55 | - sqlite=3.45.3=h5eee18b_0 56 | - stack_data=0.6.3=pyhd8ed1ab_1 57 | - tk=8.6.14=h39e8969_0 58 | - tornado=6.4.2=py311h9ecbd09_0 59 | - traitlets=5.14.3=pyhd8ed1ab_1 60 | - typing_extensions=4.12.2=pyha770c72_1 61 | - wcwidth=0.2.13=pyhd8ed1ab_1 62 | - wheel=0.44.0=py311h06a4308_0 63 | - xz=5.4.6=h5eee18b_1 64 | - zeromq=4.3.5=h3b0a872_7 65 | - zipp=3.21.0=pyhd8ed1ab_1 66 | - zlib=1.2.13=h5eee18b_1 67 | - pip: 68 | - absl-py==2.1.0 69 | - accelerate==0.34.2 70 | - addict==2.4.0 71 | - aiofiles==24.1.0 72 | - aiohappyeyeballs==2.4.4 73 | - aiohttp==3.11.11 74 | - aiosignal==1.3.2 75 | - altair==5.5.0 76 | - annotated-types==0.7.0 77 | - anyio==4.8.0 78 | - attrs==24.3.0 79 | - beautifulsoup4==4.12.3 80 | - bitsandbytes==0.41.0 81 | - blinker==1.9.0 82 | - cachetools==5.5.1 83 | - certifi==2024.12.14 84 | - charset-normalizer==3.4.1 85 | - chex==0.1.88 86 | - click==8.1.8 87 | - colorama==0.4.6 88 | - contourpy==1.3.1 89 | - cycler==0.12.1 90 | - decord==0.6.0 91 | - deepspeed==0.13.5 92 | - defusedxml==0.7.1 93 | - distro==1.9.0 94 | - docker-pycreds==0.4.0 95 | - einops==0.6.1 96 | - einops-exts==0.0.4 97 | - einshape==1.0 98 | - et-xmlfile==2.0.0 99 | - fastapi==0.115.7 
100 | - ffmpy==0.5.0 101 | - filelock==3.17.0 102 | - flow-vis==0.1 103 | - fonttools==4.55.5 104 | - frozenlist==1.5.0 105 | - fsspec==2024.12.0 106 | - future==1.0.0 107 | - gdown==5.2.0 108 | - gitdb==4.0.12 109 | - gitpython==3.1.44 110 | - gradio==3.35.2 111 | - gradio-client==0.2.9 112 | - greenlet==3.1.1 113 | - grpcio==1.70.0 114 | - h11==0.14.0 115 | - h5py==3.13.0 116 | - hjson==3.1.0 117 | - httpcore==0.17.3 118 | - httpx==0.24.0 119 | - huggingface-hub==0.27.1 120 | - idna==3.10 121 | - imageio==2.37.0 122 | - ipdb==0.13.13 123 | - ipywidgets==8.1.5 124 | - jax==0.5.0 125 | - jaxlib==0.5.0 126 | - jinja2==3.1.5 127 | - jiter==0.8.2 128 | - joblib==1.4.2 129 | - jsonlines==4.0.0 130 | - jsonpatch==1.33 131 | - jsonpointer==3.0.0 132 | - jsonschema==4.23.0 133 | - jsonschema-specifications==2024.10.1 134 | - jupyterlab-widgets==3.0.13 135 | - kiwisolver==1.4.8 136 | - langchain==0.3.0 137 | - langchain-core==0.3.6 138 | - langchain-openai==0.2.0 139 | - langchain-text-splitters==0.3.0 140 | - langsmith==0.1.129 141 | - latex2mathml==3.77.0 142 | - linkify-it-py==2.0.3 143 | - markdown==3.7 144 | - markdown-it-py==2.2.0 145 | - markdown2==2.5.2 146 | - markupsafe==3.0.2 147 | - matplotlib==3.10.0 148 | - mdit-py-plugins==0.3.3 149 | - mdurl==0.1.2 150 | - mediapy==1.2.2 151 | - ml-dtypes==0.5.1 152 | - mmcls==0.25.0 153 | - mmcv-full==1.6.2 154 | - mmengine==0.10.6 155 | - mmsegmentation==0.30.0 156 | - model-index==0.1.11 157 | - mpmath==1.3.0 158 | - multidict==6.1.0 159 | - narwhals==1.23.0 160 | - networkx==3.4.2 161 | - ninja==1.11.1.3 162 | - numpy==1.26.4 163 | - nvidia-cublas-cu11==11.11.3.6 164 | - nvidia-cublas-cu12==12.4.5.8 165 | - nvidia-cuda-cupti-cu11==11.8.87 166 | - nvidia-cuda-cupti-cu12==12.4.127 167 | - nvidia-cuda-nvrtc-cu11==11.8.89 168 | - nvidia-cuda-nvrtc-cu12==12.4.127 169 | - nvidia-cuda-runtime-cu11==11.8.89 170 | - nvidia-cuda-runtime-cu12==12.4.127 171 | - nvidia-cudnn-cu11==8.7.0.84 172 | - nvidia-cudnn-cu12==9.1.0.70 173 | - nvidia-cufft-cu11==10.9.0.58 174 | - nvidia-cufft-cu12==11.2.1.3 175 | - nvidia-curand-cu11==10.3.0.86 176 | - nvidia-curand-cu12==10.3.5.147 177 | - nvidia-cusolver-cu11==11.4.1.48 178 | - nvidia-cusolver-cu12==11.6.1.9 179 | - nvidia-cusparse-cu11==11.7.5.86 180 | - nvidia-cusparse-cu12==12.3.1.170 181 | - nvidia-ml-py==12.560.30 182 | - nvidia-nccl-cu11==2.20.5 183 | - nvidia-nccl-cu12==2.21.5 184 | - nvidia-nvjitlink-cu12==12.4.127 185 | - nvidia-nvtx-cu11==11.8.86 186 | - nvidia-nvtx-cu12==12.4.127 187 | - openai==1.60.1 188 | - opencv-python==4.11.0.86 189 | - opendatalab==0.0.10 190 | - openmim==0.3.9 191 | - openpyxl==3.1.5 192 | - openxlab==0.0.11 193 | - opt-einsum==3.4.0 194 | - ordered-set==4.1.0 195 | - orjson==3.10.15 196 | - pandas==2.2.3 197 | - peft==0.12.0 198 | - pillow==11.1.0 199 | - prettytable==3.12.0 200 | - propcache==0.2.1 201 | - protobuf==5.29.3 202 | - py-cpuinfo==9.0.0 203 | - pyarrow==19.0.0 204 | - pycocoevalcap==1.2 205 | - pycocotools==2.0.8 206 | - pycryptodome==3.21.0 207 | - pydantic==2.10.6 208 | - pydantic-core==2.27.2 209 | - pydeck==0.9.1 210 | - pydub==0.25.1 211 | - pynvml==12.0.0 212 | - pyparsing==3.2.1 213 | - pysocks==1.7.1 214 | - python-dotenv==1.0.1 215 | - python-multipart==0.0.20 216 | - pytz==2024.2 217 | - pyyaml==6.0.2 218 | - referencing==0.36.1 219 | - regex==2024.11.6 220 | - requests==2.32.3 221 | - requests-toolbelt==1.0.0 222 | - rich==13.9.4 223 | - rpds-py==0.22.3 224 | - safetensors==0.5.2 225 | - scenepic==1.1.1 226 | - scikit-learn==1.6.1 227 | - scipy==1.15.1 228 
| - seaborn==0.13.2 229 | - semantic-version==2.10.0 230 | - sentencepiece==0.1.99 231 | - sentry-sdk==2.20.0 232 | - setproctitle==1.3.4 233 | - shortuuid==1.0.13 234 | - smmap==5.0.2 235 | - sniffio==1.3.1 236 | - soupsieve==2.6 237 | - sqlalchemy==2.0.37 238 | - starlette==0.45.2 239 | - streamlit==1.41.1 240 | - streamlit-image-select==0.6.0 241 | - supervision==0.25.1 242 | - svgwrite==1.4.3 243 | - sympy==1.13.1 244 | - tabulate==0.9.0 245 | - tenacity==8.5.0 246 | - tensorboard==2.18.0 247 | - tensorboard-data-server==0.7.2 248 | - tensorboardx==2.6.2.2 249 | - termcolor==2.5.0 250 | - threadpoolctl==3.5.0 251 | - tiktoken==0.8.0 252 | - timm==0.9.12 253 | - tokenizers==0.15.1 254 | - toml==0.10.2 255 | - tomli==2.2.1 256 | - toolz==1.0.0 257 | - torch==2.3.1+cu118 258 | - torchaudio==2.3.1+cu118 259 | - torchvision==0.18.1+cu118 260 | - tqdm==4.67.1 261 | - transformers==4.37.2 262 | - triton==2.3.1 263 | - tzdata==2025.1 264 | - uc-micro-py==1.0.3 265 | - ultralytics==8.3.67 266 | - ultralytics-thop==2.0.14 267 | - urllib3==2.3.0 268 | - uvicorn==0.34.0 269 | - wandb==0.19.4 270 | - watchdog==6.0.0 271 | - wavedrom==2.0.3.post3 272 | - websockets==14.2 273 | - werkzeug==3.1.3 274 | - widgetsnbextension==4.0.13 275 | - yacs==0.1.8 276 | - yapf==0.40.1 277 | - yarl==1.18.3 278 | - zstandard==0.23.0 279 | prefix: /home/runsenxu/miniconda3/envs/data_engine 280 | -------------------------------------------------------------------------------- /spatial_engine/camera_movement/calculate_frames_relations.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | In total, with D5 scene info of ScanNet 9 | for train, we have 82M records, with 35M nonzero records. (82654914 and 35063709) 10 | for val, we have 24M records, with 10M nonzero records. (24117039 and 10156707) 11 | """ 12 | 13 | import numpy as np 14 | from tqdm import tqdm 15 | import torch 16 | import mmengine 17 | 18 | from spatial_engine.utils.scannet_utils.handler.info_handler import SceneInfoHandler 19 | from multiprocessing import Pool 20 | import pandas as pd 21 | import os 22 | 23 | # set random seed 24 | np.random.seed(0) 25 | 26 | DEBUG = False # Set to True for debug mode. 27 | 28 | def save_overlap_info(overlap_info, parquet_file): 29 | """ 30 | Convert a nested dictionary into a DataFrame and save it as Parquet. 31 | The nested dictionary is expected in the form: 32 | 33 | overlap_info[scene_id][(image_id1, image_id2)] = { 34 | 'overlap': ..., 35 | 'distance': ..., 36 | 'yaw': ..., 37 | 'pitch': ... 38 | } 39 | """ 40 | rows = [] 41 | for scene_id, pair_dict in overlap_info.items(): 42 | for (img1, img2), vals in pair_dict.items(): 43 | rows.append({ 44 | 'scene_id': scene_id, 45 | 'image_id1': img1, 46 | 'image_id2': img2, 47 | 'overlap': vals['overlap'], 48 | 'distance': vals['distance'], 49 | 'yaw': vals['yaw'], 50 | 'pitch': vals['pitch'] 51 | }) 52 | if not rows: 53 | print(f"[save_overlap_info] Nothing to save to {parquet_file}.") 54 | return 55 | df = pd.DataFrame(rows) 56 | df.to_parquet(parquet_file, index=False) 57 | print(f"[save_overlap_info] Saved {len(df)} records to {parquet_file}.") 58 | 59 | def save_overlap_info_nonzero(overlap_info, parquet_file_nonzero): 60 | """ 61 | Similar to save_overlap_info, but filter out any rows where overlap == 0. 
62 | """ 63 | rows = [] 64 | for scene_id, pair_dict in overlap_info.items(): 65 | for (img1, img2), vals in pair_dict.items(): 66 | # Only collect if overlap != 0 67 | if vals['overlap'] != 0.0: 68 | rows.append({ 69 | 'scene_id': scene_id, 70 | 'image_id1': img1, 71 | 'image_id2': img2, 72 | 'overlap': vals['overlap'], 73 | 'distance': vals['distance'], 74 | 'yaw': vals['yaw'], 75 | 'pitch': vals['pitch'] 76 | }) 77 | 78 | if not rows: 79 | print(f"[save_overlap_info_nonzero] No nonzero-overlap pairs to save.") 80 | return 81 | 82 | df = pd.DataFrame(rows) 83 | df.to_parquet(parquet_file_nonzero, index=False) 84 | print(f"[save_overlap_info_nonzero] Saved {len(df)} records to {parquet_file_nonzero}.") 85 | 86 | def extract_yaw_pitch(R): 87 | """ 88 | Extract the yaw and pitch angles from a rotation matrix. 89 | 90 | :param R: A 4x4 or 3x3 matrix. If 4x4, only the top-left 3x3 is used. 91 | :return: A tuple (yaw, pitch) in degrees. 92 | """ 93 | R3 = R[:3, :3] if R.shape == (4, 4) else R 94 | rotated_z_axis = R3[:, 2] 95 | 96 | # Yaw: arctan2 of (y, x) 97 | yaw = np.degrees(np.arctan2(rotated_z_axis[1], rotated_z_axis[0])) 98 | # Pitch: arcsin of z component 99 | pitch = np.degrees(np.arcsin(rotated_z_axis[2] / np.linalg.norm(rotated_z_axis))) 100 | return yaw, pitch 101 | 102 | def calculate_camera_overlap(in_bounds_dict, image_id1, image_id2, use_cuda=False): 103 | """ 104 | Calculate the percentage of overlap in the field of view of two cameras using precomputed in-bounds points. 105 | 106 | :param in_bounds_dict: Dictionary containing in-bounds information for each image ID 107 | :param image_id1: Image ID of the first camera 108 | :param image_id2: Image ID of the second camera 109 | :return: Percentage of overlap 110 | """ 111 | in_bounds1 = in_bounds_dict[image_id1] 112 | in_bounds2 = in_bounds_dict[image_id2] 113 | 114 | if torch.cuda.is_available() and use_cuda: 115 | # Move data to GPU 116 | in_bounds1 = torch.from_numpy(in_bounds1).to('cuda') 117 | in_bounds2 = torch.from_numpy(in_bounds2).to('cuda') 118 | 119 | # Points that are visible in at least one of the cameras 120 | visible_points_union = torch.logical_or(in_bounds1, in_bounds2) 121 | 122 | # Points that are visible in both cameras 123 | overlap_points = torch.logical_and(in_bounds1, in_bounds2) 124 | 125 | # Calculate the overlap percentage 126 | overlap_percentage = torch.sum(overlap_points).float() / torch.sum(visible_points_union).float() * 100 127 | return overlap_percentage.item() # Move result back to CPU and convert to Python float 128 | else: 129 | # Points that are visible in at least one of the cameras 130 | visible_points_union = np.logical_or(in_bounds1, in_bounds2) 131 | 132 | # Points that are visible in both cameras 133 | overlap_points = np.logical_and(in_bounds1, in_bounds2) 134 | 135 | # Calculate the overlap percentage 136 | overlap_percentage = np.sum(overlap_points) / np.sum(visible_points_union) * 100 137 | return overlap_percentage 138 | 139 | def process_scene(scene_id, scene_infos: SceneInfoHandler, warning_file): 140 | print(f"Start processing {scene_id}.") 141 | image_ids = scene_infos.get_all_extrinsic_valid_image_ids(scene_id) 142 | 143 | # get all points in the scene 144 | scene_points = scene_infos.get_scene_points_align(scene_id) 145 | scene_points = scene_points[:, :3] 146 | 147 | # Precompute in-bounds points for each image 148 | in_bounds_dict = {} 149 | yaw_dict = {} 150 | pitch_dict = {} 151 | positions_dict = {} 152 | for image_id in image_ids: 153 | E = 
scene_infos.get_extrinsic_matrix_align(scene_id, image_id) 154 | 155 | # project points to camera 156 | scene_points_2d, scene_points_depth = scene_infos.project_3d_point_to_image(scene_id, image_id, scene_points) 157 | in_bounds_points = scene_infos.check_point_visibility(scene_id, image_id, scene_points_2d, scene_points_depth) # * a True or False mask 158 | 159 | if np.sum(in_bounds_points) == 0: 160 | with open(warning_file, 'a') as f: 161 | f.write(f"{scene_id}: {image_id} has no in bound points\n") 162 | 163 | in_bounds_dict[image_id] = in_bounds_points 164 | 165 | # get yaw and pitch 166 | yaw, pitch = extract_yaw_pitch(E) 167 | 168 | yaw_dict[image_id] = yaw 169 | pitch_dict[image_id] = pitch 170 | 171 | # positions 172 | positions_dict[image_id] = E[:3, 3] 173 | 174 | scene_overlap_info = {} 175 | 176 | for i, image_id1 in enumerate(image_ids): 177 | for j in range(i + 1, len(image_ids)): 178 | image_id2 = image_ids[j] 179 | overlap_percentage = calculate_camera_overlap(in_bounds_dict, image_id1, image_id2) 180 | 181 | yaw_diff = yaw_dict[image_id2] - yaw_dict[image_id1] 182 | pitch_diff = pitch_dict[image_id2] - pitch_dict[image_id1] 183 | distance = np.linalg.norm(positions_dict[image_id2] - positions_dict[image_id1]) 184 | 185 | scene_overlap_info[(image_id1, image_id2)] = {} 186 | scene_overlap_info[(image_id1, image_id2)]['overlap'] = overlap_percentage 187 | scene_overlap_info[(image_id1, image_id2)]['distance'] = distance 188 | scene_overlap_info[(image_id1, image_id2)]['yaw'] = yaw_diff 189 | scene_overlap_info[(image_id1, image_id2)]['pitch'] = pitch_diff 190 | 191 | if np.any(np.isnan(list(scene_overlap_info[(image_id1, image_id2)].values()))) or \ 192 | np.any(np.isinf(list(scene_overlap_info[(image_id1, image_id2)].values()))): 193 | with open(warning_file, 'a') as f: 194 | f.write(f"{scene_id}: {(image_id1, image_id2)} has something wrong {list(scene_overlap_info[(image_id1, image_id2)].values())}. \n") 195 | 196 | print(f"Finished scene {scene_id}.") 197 | return scene_id, scene_overlap_info 198 | 199 | 200 | def run_split(scene_info_path, output_parquet, warning_file, num_workers=15, save_interval=20): 201 | """ 202 | 1. Instantiate SceneInfoHandler for the given `scene_info_path`. 203 | 2. Load existing overlap info from `output_parquet`. 204 | 3. Process each scene in parallel and accumulate results in a dictionary. 205 | 4. Periodically save to parquet + nonzero parquet, then final save at the end. 
206 | """ 207 | scene_infos = SceneInfoHandler(scene_info_path) 208 | overlap_info = {} 209 | 210 | # Gather all scenes 211 | all_scene_ids = scene_infos.get_all_scene_ids() 212 | print(f"[run_split] Found {len(all_scene_ids)} scenes in {scene_info_path}.") 213 | 214 | # If we're debugging, only process the first scene 215 | if DEBUG and len(all_scene_ids) > 1: 216 | all_scene_ids = all_scene_ids[:1] 217 | print("[run_split] DEBUG mode: processing only the first scene.") 218 | 219 | # Prepare arguments 220 | args = [(scene_id, scene_infos, warning_file) for scene_id in all_scene_ids] 221 | 222 | with Pool(num_workers) as pool: 223 | results = [] 224 | for arg in args: 225 | results.append(pool.apply_async(process_scene, arg)) 226 | 227 | for count, r in enumerate(tqdm(results, desc=f"Processing {scene_info_path}")): 228 | scene_id, scene_overlap_info = r.get() 229 | overlap_info[scene_id] = scene_overlap_info 230 | 231 | # Save partial results every 'save_interval' scenes 232 | if (count + 1) % save_interval == 0: 233 | save_overlap_info(overlap_info, output_parquet) 234 | 235 | # Also save nonzero 236 | nonzero_parquet = output_parquet.replace(".parquet", "_nonzero.parquet") 237 | save_overlap_info_nonzero(overlap_info, nonzero_parquet) 238 | 239 | print(f"[run_split] Saved partial results for {count + 1} scenes to {output_parquet}") 240 | 241 | # Final save 242 | save_overlap_info(overlap_info, output_parquet) 243 | nonzero_parquet = output_parquet.replace(".parquet", "_nonzero.parquet") 244 | save_overlap_info_nonzero(overlap_info, nonzero_parquet) 245 | print(f"[run_split] Final save to {output_parquet} complete.") 246 | # print total number of records 247 | total_records = sum(len(v) for v in overlap_info.values()) 248 | print(f"[run_split] Total number of records: {total_records}") 249 | 250 | print(f"[run_split] Nonzero overlap also saved to {nonzero_parquet}.") 251 | # need to iterate over the overlap_info to get the number of nonzero records 252 | nonzero_records = sum(1 for scene in overlap_info.values() for pair in scene.values() if pair['overlap'] != 0.0) 253 | print(f"[run_split] Total number of nonzero records: {nonzero_records}") 254 | 255 | def main(): 256 | # Adjust these paths if needed 257 | train_info_path = "data/scannet/scannet_instance_data/scenes_train_info_i_D5.pkl" 258 | val_info_path = "data/scannet/scannet_instance_data/scenes_val_info_i_D5.pkl" 259 | 260 | train_output_dir = "training_data/camera_movement" 261 | val_output_dir = "evaluation_data/camera_movement" 262 | 263 | mmengine.mkdir_or_exist(train_output_dir) 264 | mmengine.mkdir_or_exist(val_output_dir) 265 | 266 | train_output_file = os.path.join(train_output_dir, "train_camera_info_D5.parquet") 267 | val_output_file = os.path.join(val_output_dir, "val_camera_info_D5.parquet") 268 | 269 | # Warnings 270 | train_warning_file = os.path.join(train_output_dir, "train_warning_D5.txt") 271 | val_warning_file = os.path.join(val_output_dir, "val_warning_D5.txt") 272 | 273 | # If DEBUG is true, add a suffix to file names 274 | if DEBUG: 275 | train_output_file = train_output_file.replace(".parquet", "_debug.parquet") 276 | val_output_file = val_output_file.replace(".parquet", "_debug.parquet") 277 | train_warning_file = train_warning_file.replace(".txt", "_debug.txt") 278 | val_warning_file = val_warning_file.replace(".txt", "_debug.txt") 279 | 280 | num_workers = 25 281 | 282 | print(f"[main] DEBUG mode: {DEBUG}") 283 | print(f"[main] Processing train split -> {train_output_file}") 284 | 
run_split(train_info_path, train_output_file, train_warning_file, num_workers=num_workers, save_interval=20) 285 | 286 | print(f"[main] Processing val split -> {val_output_file}") 287 | run_split(val_info_path, val_output_file, val_warning_file, num_workers=num_workers, save_interval=20) 288 | 289 | if __name__ == "__main__": 290 | main() 291 | -------------------------------------------------------------------------------- /spatial_engine/camera_movement/camera_movement_engine_train_val.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Multi-round QA about qualitative, quantitative, and vector of movements (distance, and two angles). 9 | 10 | For questions, we have 30 templates. For answer and task description, we have 10 templates. 11 | """ 12 | 13 | import pandas as pd 14 | from tqdm import tqdm 15 | import numpy as np 16 | import random 17 | random.seed(0) 18 | np.random.seed(0) 19 | import json 20 | from spatial_engine.utils.scannet_utils.handler.info_handler import SceneInfoHandler 21 | import os 22 | import mmengine 23 | import sys 24 | 25 | sys.path.append("spatial_engine/camera_movement") 26 | from TEMPLATES import QUESTION_TEMPLATES, ANSWER_TEMPLATES, TASK_DESCRIPTION 27 | 28 | 29 | def sample_dataframe(df, all_overlap_samples, non_overlap_samples, 30 | overlap_min=0, overlap_max=100, interval=1): 31 | """ 32 | Sample from the input DataFrame, aiming to collect a total of all_overlap_samples 33 | from bins where (overlap != 0). If a bin doesn't have enough samples, 34 | the remaining quota will be passed to the next bin, and so on. 
35 | 36 | :param df: Input DataFrame, must contain an 'overlap' column 37 | :param all_overlap_samples: Total desired samples from overlap>0 bins 38 | :param non_overlap_samples: Number of samples to take where overlap == 0 39 | :param overlap_min: Minimum overlap value, default is 0 40 | :param overlap_max: Maximum overlap value, default is 100 41 | :param interval: Bin interval step size, default is 1 42 | :return: Sampled DataFrame 43 | """ 44 | 45 | # ----------------------------------------------------------- 46 | # 1) overlap == 0: Sample these separately first 47 | # ----------------------------------------------------------- 48 | non_overlap_df = df[df["overlap"] == 0].copy() 49 | if len(non_overlap_df) <= non_overlap_samples: 50 | sampled_non_overlap_df = non_overlap_df 51 | else: 52 | sampled_non_overlap_df = non_overlap_df.sample(n=non_overlap_samples) 53 | 54 | # Remaining data (overlap != 0) 55 | remaining_df = df[df["overlap"] != 0].copy() 56 | 57 | # ----------------------------------------------------------- 58 | # 2) Group by overlap into bins 59 | # ----------------------------------------------------------- 60 | bins = np.arange(overlap_min, overlap_max + interval, interval) 61 | remaining_df["overlap_group"] = pd.cut( 62 | remaining_df["overlap"], 63 | bins=bins, 64 | include_lowest=True 65 | ) 66 | 67 | # Exclude rows not assigned to any bin (NaN) 68 | remaining_df = remaining_df.dropna(subset=["overlap_group"]) 69 | 70 | # Collect data for each bin 71 | bin_dfs = [] 72 | for interval_bin, group_df in remaining_df.groupby("overlap_group"): 73 | bin_dfs.append((interval_bin, group_df)) 74 | 75 | if len(bin_dfs) == 0: 76 | # If no bins exist, return only the overlap==0 samples 77 | final_sampled_df = sampled_non_overlap_df 78 | if "overlap_group" in final_sampled_df.columns: 79 | final_sampled_df.drop(columns=["overlap_group"], inplace=True) 80 | return final_sampled_df 81 | 82 | # ----------------------------------------------------------- 83 | # 3) Distribute all_overlap_samples evenly across bins 84 | # and handle the remainder 85 | # ----------------------------------------------------------- 86 | N = len(bin_dfs) 87 | base_quota = all_overlap_samples // N 88 | remainder = all_overlap_samples % N 89 | 90 | # bin_quotas[i] = base_quota, with first remainder bins getting +1 91 | bin_quotas = [base_quota] * N 92 | for i in range(remainder): 93 | bin_quotas[i] += 1 94 | 95 | # ----------------------------------------------------------- 96 | # 4) Sort bins by data volume (small to large), not by overlap interval 97 | # ----------------------------------------------------------- 98 | # Key for sorting bin_dfs: len(group_df) 99 | # bin_dfs[i] = (interval_bin, group_df) 100 | # bin_quotas[i] = base_quota +/- 1 101 | # To maintain one-to-one correspondence, we zip bin_dfs and bin_quotas together and sort 102 | bin_data = [] 103 | for i, (interval_bin, group_df) in enumerate(bin_dfs): 104 | bin_data.append({ 105 | "interval_bin": interval_bin, 106 | "group_df": group_df, 107 | "quota": bin_quotas[i], 108 | "size": len(group_df) # for sorting 109 | }) 110 | 111 | # Sort by size in ascending order 112 | bin_data.sort(key=lambda x: x["size"]) 113 | 114 | # ----------------------------------------------------------- 115 | # 5) Process each bin, using leftover_quota mechanism 116 | # ----------------------------------------------------------- 117 | sampled_df = pd.DataFrame() 118 | leftover_quota = 0 119 | 120 | for bin_info in bin_data: 121 | group_df = bin_info["group_df"] 122 | 
bin_quota = bin_info["quota"] 123 | 124 | current_quota = bin_quota + leftover_quota 125 | print(f"[sample_dataframe] current_quota v.s. group_df: {current_quota} v.s. {len(group_df)}") 126 | 127 | if len(group_df) <= current_quota: 128 | # Not enough for quota, take all 129 | sampled_df = pd.concat([sampled_df, group_df], ignore_index=True) 130 | # leftover = current_quota - actual amount taken 131 | leftover_quota = current_quota - len(group_df) 132 | else: 133 | # More than quota, randomly sample current_quota 134 | sampled_part = group_df.sample(n=current_quota) 135 | sampled_df = pd.concat([sampled_df, sampled_part], ignore_index=True) 136 | leftover_quota = 0 137 | 138 | # If leftover_quota > 0, data is insufficient 139 | if leftover_quota > 0: 140 | print(f"[sample_dataframe] Warning: bins not enough to reach {all_overlap_samples}; leftover {leftover_quota}") 141 | 142 | # ----------------------------------------------------------- 143 | # 6) Merge with overlap==0 samples 144 | # ----------------------------------------------------------- 145 | final_sampled_df = pd.concat([sampled_df, sampled_non_overlap_df], ignore_index=True) 146 | 147 | # Cleanup 148 | if "overlap_group" in final_sampled_df.columns: 149 | final_sampled_df.drop(columns=["overlap_group"], inplace=True, errors="ignore") 150 | 151 | return final_sampled_df 152 | 153 | def build_training_sample(scene_infos: SceneInfoHandler, row, idx: int, question_type: str): 154 | scene_id = row["scene_id"] 155 | image1 = row["image_id1"] 156 | image2 = row["image_id2"] 157 | 158 | overlap = float(row["overlap"]) 159 | yaw_angle = float(row["yaw"]) 160 | pitch_angle = float(row["pitch"]) 161 | 162 | # randomly determine whether to swap image1 and image2 163 | if random.random() < 0.5: 164 | yaw_angle = -yaw_angle 165 | pitch_angle = -pitch_angle 166 | image1, image2 = image2, image1 167 | 168 | if abs(yaw_angle) > 180: 169 | if yaw_angle > 0: 170 | yaw_angle = yaw_angle - 360 171 | else: 172 | yaw_angle = yaw_angle + 360 173 | 174 | 175 | images = [f"{scene_id}/{image1}.jpg", f"{scene_id}/{image2}.jpg"] 176 | 177 | E1 = scene_infos.get_extrinsic_matrix_align(scene_id, image1) # camera to world 178 | E2 = scene_infos.get_extrinsic_matrix_align(scene_id, image2) 179 | 180 | # assert no nan 181 | assert not np.isnan(E1).any(), f"E1 is nan for {scene_id} {image1}" 182 | assert not np.isnan(E2).any(), f"E2 is nan for {scene_id} {image2}" 183 | 184 | # Transform E2 into the coordinate system of E1 185 | E1_inv = np.linalg.inv(E1) 186 | E2_relative = E1_inv @ E2 187 | 188 | # calculate the displacement vector in the first frame's coordinate system 189 | displacement_vector = E2_relative[:3, 3] 190 | distance = np.linalg.norm(displacement_vector) 191 | 192 | # should be close to the distance from df 193 | assert abs(distance - row['distance']) < 0.1, f"distance is not close to the distance from df for {scene_id} {image1} {image2}." 194 | 195 | # the output format should be one item: 196 | # {"id": 1358431, "image": ["scene0006_01/01550.jpg", "scene0006_01/01245.jpg"], "conversations": [{"from": "human", "value": "Image-1: \nImage-2: \nAssume the scene remains unchanged. Your task is to perceive the spatial information of the scene based on the captured images. 
Calculate the distance (in mm) separating the cameras of images and ."}, {"from": "gpt", "value": "The difference is about `1150`."}], "height_list": [968, 968], "width_list": [1296, 1296]} 197 | task_description = random.choice(TASK_DESCRIPTION) 198 | 199 | if overlap < 0.1: 200 | # random select from q1, q2, q3 201 | raise NotImplementedError("overlap < 0.1 is not supported yet.") 202 | else: 203 | question = random.choice(QUESTION_TEMPLATES[question_type]) 204 | answer_template = random.choice(ANSWER_TEMPLATES[question_type]) 205 | 206 | # replace the placeholder with the actual values 207 | # for movement, need to use > 0 to get left/right, 'forward'/'backward', 'up'/'down' 208 | # x -> left/right, y -> up/down, z -> forward/backward 209 | answer_values = { 210 | "x_movement": "right" if displacement_vector[0] > 0 else "left", 211 | "y_movement": "down" if displacement_vector[1] > 0 else "up", 212 | "z_movement": "forward" if displacement_vector[2] > 0 else "backward", 213 | "yaw_movement": "left" if yaw_angle > 0 else "right", 214 | "pitch_movement": "up" if pitch_angle > 0 else "down", 215 | "x_distance": int(abs(displacement_vector[0]) * 1000), 216 | "y_distance": int(abs(displacement_vector[1]) * 1000), 217 | "z_distance": int(abs(displacement_vector[2]) * 1000), 218 | "yaw_angle": int(abs(yaw_angle)), 219 | "pitch_angle": int(abs(pitch_angle)), 220 | "x_value": int(displacement_vector[0] * 1000), 221 | "y_value": int(displacement_vector[1] * 1000), 222 | "z_value": int(displacement_vector[2] * 1000), 223 | "total_distance": int(np.linalg.norm(displacement_vector) * 1000), 224 | "displacement_vector": displacement_vector.tolist(), 225 | } 226 | # map 227 | answer_text = answer_template.format(**answer_values) 228 | 229 | conversation = [ 230 | {"from": "human", "value": f"{task_description}\n{question}"}, 231 | {"from": "gpt", "value": answer_text}, 232 | ] 233 | 234 | train_sample = { 235 | "id": idx, 236 | "image": images, 237 | "conversations": conversation, 238 | "height_list": [scene_infos.get_image_shape(scene_id, image1)[0]] * len(images), 239 | "width_list": [scene_infos.get_image_shape(scene_id, image1)[1]] * len(images), 240 | "answer_values": answer_values, 241 | "question_type": question_type, 242 | "gt_value": answer_values[question_type], 243 | } 244 | 245 | return train_sample 246 | 247 | def convert_train_sample_to_eval_sample(train_sample): 248 | # a train sample looks like this: 249 | # { 250 | # "id": f"{scene_id}_{object_id}_{pair_id}", 251 | # "image": [f"{scene_id}/{image1}.jpg", f"{scene_id}/{image2}.jpg"], 252 | # "conversations": conversation, 253 | # "height_list": [image_height] * 2, 254 | # "width_list": [image_width] * 2, 255 | # "answer_values": answer_values, 256 | # "question_type": "visual_correspondence" 257 | # } 258 | # we need a eval sample like: 259 | # { 260 | # "id": f"{scene_id}_{object_id}_{pair_id}", 261 | # "image": [f"{scene_id}/{image1}.jpg", f"{scene_id}/{image2}.jpg"], 262 | # "text": question, 263 | # "gt_value": value, 264 | # "question_type": attr, 265 | # } 266 | conversation = train_sample.pop("conversations") 267 | train_sample['text'] = conversation[0]['value'] 268 | 269 | return train_sample 270 | 271 | def build_train_dataset( 272 | parquet_path, 273 | output_dir, 274 | scene_infos, 275 | qtype, 276 | desired_count, 277 | overlap_min, 278 | overlap_max, 279 | interval 280 | ): 281 | df = pd.read_parquet(parquet_path) 282 | print(f"[Train: {qtype}] Loaded DF with {len(df)} rows from {parquet_path}") 283 | 284 | print(f"[Train: 
{qtype}] sampling {desired_count} samples in overlap=[{overlap_min}..{overlap_max}]") 285 | 286 | df_sampled = sample_dataframe( 287 | df, 288 | all_overlap_samples=desired_count, 289 | non_overlap_samples=0, # or your chosen number 290 | overlap_min=overlap_min, 291 | overlap_max=overlap_max, 292 | interval=interval 293 | ) 294 | print(f"[Train: {qtype}] got {len(df_sampled)} sampled rows") 295 | 296 | # build samples 297 | out_samples = [] 298 | for idx in tqdm(range(len(df_sampled)), desc=f"{qtype}"): 299 | row = df_sampled.iloc[idx] 300 | s = build_training_sample(scene_infos, row, idx, qtype) 301 | out_samples.append(s) 302 | 303 | random.shuffle(out_samples) 304 | out_file = os.path.join(output_dir, f"{qtype}_train.jsonl") 305 | print(f"[Train: {qtype}] writing {len(out_samples)} items to {out_file}") 306 | with open(out_file, "w") as f: 307 | for item in out_samples: 308 | f.write(json.dumps(item)+"\n") 309 | 310 | ######################################################################## 311 | # Build val dataset for one question type 312 | ######################################################################## 313 | 314 | def build_val_dataset( 315 | parquet_path, 316 | output_dir, 317 | scene_infos, 318 | qtype, 319 | desired_count, 320 | overlap_min, 321 | overlap_max, 322 | interval 323 | ): 324 | df = pd.read_parquet(parquet_path) 325 | print(f"[Val: {qtype}] Loaded DF with {len(df)} rows from {parquet_path}") 326 | 327 | # same sampling logic 328 | print(f"[Val: {qtype}] sampling {desired_count} samples in overlap=[{overlap_min}..{overlap_max}]") 329 | 330 | df_sampled = sample_dataframe( 331 | df, 332 | all_overlap_samples=desired_count, 333 | non_overlap_samples=0, 334 | overlap_min=overlap_min, 335 | overlap_max=overlap_max, 336 | interval=interval 337 | ) 338 | print(f"[Val: {qtype}] got {len(df_sampled)} sampled rows") 339 | 340 | # build as "train" but then convert 341 | out_samples = [] 342 | for idx in tqdm(range(len(df_sampled)), desc=f"{qtype}_val"): 343 | row = df_sampled.iloc[idx] 344 | s_train = build_training_sample(scene_infos, row, idx, qtype) 345 | s_eval = convert_train_sample_to_eval_sample(s_train) 346 | out_samples.append(s_eval) 347 | 348 | random.shuffle(out_samples) 349 | out_file = os.path.join(output_dir, f"{qtype}_val.jsonl") 350 | print(f"[Val: {qtype}] writing {len(out_samples)} items to {out_file}") 351 | with open(out_file, "w") as f: 352 | for item in out_samples: 353 | f.write(json.dumps(item)+"\n") 354 | 355 | ######################################################################## 356 | # Main 357 | ######################################################################## 358 | DEBUG = False 359 | 360 | def main(): 361 | info_path = "data/scannet/scannet_instance_data/scenes_train_val_info_i_D5.pkl" 362 | overlap_min = 6 363 | overlap_max = 35 364 | interval = 1 365 | 366 | version = "v1_0" 367 | 368 | # Desired sample counts 369 | train_question_samples = { 370 | "x_movement": 1000000, 371 | "y_movement": 1000000, 372 | "z_movement": 1000000, 373 | "yaw_movement": 1000000, 374 | "pitch_movement": 1000000, 375 | "yaw_angle": 1000000, 376 | "pitch_angle": 1000000, 377 | "total_distance": 3000000, 378 | "displacement_vector": 3000000, 379 | } 380 | val_question_samples = { 381 | "x_movement": 300, 382 | "y_movement": 300, 383 | "z_movement": 300, 384 | "yaw_movement": 300, 385 | "pitch_movement": 300, 386 | "total_distance": 300, 387 | "yaw_angle": 300, 388 | "pitch_angle": 300, 389 | "displacement_vector": 300, 390 | } 391 | 392 | # Debug 
overrides 393 | global DEBUG 394 | if DEBUG: 395 | train_parquet_path = "training_data/camera_movement/train_camera_info_D5_debug_nonzero.parquet" 396 | val_parquet_path = "evaluation_data/camera_movement/val_camera_info_D5_debug_nonzero.parquet" 397 | for k in train_question_samples: 398 | train_question_samples[k] = 100 399 | for k in val_question_samples: 400 | val_question_samples[k] = 100 401 | version += "_debug" 402 | else: 403 | train_parquet_path = "training_data/camera_movement/train_camera_info_D5.parquet" 404 | val_parquet_path = "evaluation_data/camera_movement/val_camera_info_D5.parquet" 405 | 406 | train_output_dir = f"training_data/camera_movement/{version}" 407 | val_output_dir = f"evaluation_data/camera_movement/{version}" 408 | 409 | mmengine.mkdir_or_exist(train_output_dir) 410 | mmengine.mkdir_or_exist(val_output_dir) 411 | 412 | # Initialize SceneInfoHandler 413 | scene_infos = SceneInfoHandler(info_path) 414 | 415 | # Build train & val for each question type 416 | for qtype in train_question_samples.keys(): 417 | print(f"\n=== Processing question type: {qtype} ===") 418 | # Each type costs about 4 mins to generate 1M samples 419 | 420 | # Build val 421 | build_val_dataset( 422 | parquet_path=val_parquet_path, 423 | output_dir=val_output_dir, 424 | scene_infos=scene_infos, 425 | qtype=qtype, 426 | desired_count=val_question_samples[qtype], 427 | overlap_min=overlap_min, 428 | overlap_max=overlap_max, 429 | interval=interval 430 | ) 431 | 432 | # Build train 433 | build_train_dataset( 434 | parquet_path=train_parquet_path, 435 | output_dir=train_output_dir, 436 | scene_infos=scene_infos, 437 | qtype=qtype, 438 | desired_count=train_question_samples[qtype], 439 | overlap_min=overlap_min, 440 | overlap_max=overlap_max, 441 | interval=interval 442 | ) 443 | 444 | print("All question types processed. Done.") 445 | 446 | if __name__ == "__main__": 447 | main() -------------------------------------------------------------------------------- /spatial_engine/depth_perception/depth_estimation_dot_engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
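Before the depth-perception engines, it is worth making the pose arithmetic used by `build_training_sample` above explicit: both extrinsics are camera-to-world matrices, so `inv(E1) @ E2` expresses the second camera in the first camera's coordinate frame, and the translation column of that product is the displacement the answers report after conversion to millimeters. A self-contained sketch with toy matrices rather than real ScanNet poses:

```python
import numpy as np

# Toy camera-to-world extrinsics: camera 1 sits at the origin, camera 2 is
# shifted 0.5 m along x and 1.2 m along z with no rotation (values in meters).
E1 = np.eye(4)
E2 = np.eye(4)
E2[:3, 3] = [0.5, 0.0, 1.2]

E2_relative = np.linalg.inv(E1) @ E2   # camera 2 expressed in camera 1's frame
displacement = E2_relative[:3, 3]      # array([0.5, 0. , 1.2])
distance_mm = round(np.linalg.norm(displacement) * 1000)
print(distance_mm)                     # sqrt(0.5**2 + 1.2**2) = 1.3 m -> 1300
```

The ±180° wrap applied to `yaw_angle` in `build_training_sample` keeps the reported rotation on the shorter side of a full turn after the optional image swap.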
6 | 7 | import json 8 | import random 9 | import os 10 | from tqdm import tqdm 11 | import mmengine 12 | import cv2 13 | import numpy 14 | # set seed 15 | numpy.random.seed(5) 16 | random.seed(5) 17 | 18 | from spatial_engine.utils.scannet_utils.handler.info_handler import SceneInfoHandler, VisibilityInfoHandler 19 | 20 | import argparse 21 | 22 | def generate_distinct_colors(n, max_retries=10): 23 | colors = [] 24 | retries = 0 25 | while len(colors) < n and retries < max_retries: 26 | color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) 27 | if all(sum(abs(c1 - c2) for c1, c2 in zip(color, existing_color)) > 300 for existing_color in colors): 28 | colors.append(color) 29 | retries += 1 30 | if len(colors) < n: 31 | predefined_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (0, 0, 0), (255, 255, 255)] # Red, Green, Blue, Black, White 32 | colors += random.sample(predefined_colors, n - len(colors)) 33 | return colors 34 | 35 | class DepthEstimationDotQAEngine: 36 | def __init__(self, scene_info_path, 37 | version_num="v1_0", 38 | all_max_samples=-1, 39 | image_output_dir=None, 40 | visibility_info_path=None, 41 | max_n_points_per_image=1, 42 | warning_file=None): 43 | self.scene_info = SceneInfoHandler(scene_info_path) 44 | self.version_num = version_num 45 | self.image_output_dir = image_output_dir 46 | self.all_max_samples = all_max_samples 47 | self.task_name = "depth_estimation_dot" 48 | # * note, in v1_0, max_n_points_per_image is 1 49 | # * and even if it is set to be larger than 1, it will generate different QA pairs, i.e., single-round QA and the total number of samples is max_n_points_per_image * num_images (all_max_samples * max_n_points_per_image) 50 | self.max_n_points_per_image = max_n_points_per_image 51 | self.warning_file = warning_file 52 | # read visibility infos 53 | self.visibility_info = VisibilityInfoHandler(visibility_info_path) 54 | 55 | self.task_description = [ 56 | "\nGiven an image with an annotated point, complete the question-answer task.", 57 | "\nFor an image with an annotated point, answer the depth-related questions.", 58 | "\nUsing the provided image with an annotated point, complete the QA task.", 59 | "\nGiven an image with a specific annotated point, perform the question-answer process.", 60 | "\nWork with the image that has an annotated point to answer the related questions.", 61 | "\nAnalyze the image with an annotated point and complete the QA task.", 62 | "\nGiven an image where a point is annotated, proceed with the question-answer task.", 63 | "\nFor the image with an annotated point, determine the depth-related answers.", 64 | "\nUsing the image with an annotated point, perform the QA task.", 65 | "\nGiven an image with a marked point, complete the question-answer process.", 66 | "\nWork with the image containing an annotated point to answer the questions.", 67 | "\nGiven an image with a highlighted point, complete the QA task.", 68 | "\nFor an image with a marked point, answer the depth-related questions.", 69 | "\nUsing the image with a highlighted point, perform the QA task.", 70 | "\nGiven an image with a designated point, complete the question-answer process.", 71 | "\nWork with the image that has a marked point to answer the related questions.", 72 | "\nAnalyze the image with a designated point and complete the QA task.", 73 | "\nGiven an image where a point is highlighted, proceed with the question-answer task.", 74 | "\nFor the image with a designated point, determine the depth-related answers.", 75 | 
"\nUsing the image with a marked point, perform the QA task.", 76 | "\nGiven an image with a pinpointed location, engage in the question-answer task.", 77 | "\nFor an image with a specified point, resolve the depth-related queries.", 78 | "\nUtilize the image with a pinpointed spot to complete the QA task.", 79 | "\nGiven an image with a noted point, carry out the question-answer process.", 80 | "\nWork with the image that has a pinpointed point to answer the related questions.", 81 | "\nExamine the image with a noted point and complete the QA task.", 82 | "\nGiven an image where a point is pinpointed, proceed with the question-answer task.", 83 | "\nFor the image with a noted point, ascertain the depth-related answers.", 84 | "\nUsing the image with a pinpointed point, perform the QA task.", 85 | "\nGiven an image with a specified point, complete the question-answer process." 86 | ] 87 | 88 | self.templates = { 89 | "questions": [ 90 | "What is the depth of the annotated point in the image (in mm)?", 91 | "How far is the annotated point from the camera in millimeters?", 92 | "Determine the depth value of the annotated point in the given image (mm).", 93 | "Find the distance from the observer to the annotated point in the image, in mm.", 94 | "What is the measured depth of the annotated point in mm?", 95 | "How far away is the annotated point from the viewer in the image (mm)?", 96 | "Identify the depth value for the annotated point in millimeters.", 97 | "Given the annotated point, what is its depth in the image (mm)?", 98 | "What is the distance between the camera and the annotated point in mm?", 99 | "How deep is the annotated point in the given image (in millimeters)?", 100 | "What is the distance to the annotated point in the image (in mm)?", 101 | "How far is the annotated point located from the camera in mm?", 102 | "Determine the distance of the annotated point from the observer in mm.", 103 | "What is the depth measurement of the annotated point in the image (mm)?", 104 | "How far is the annotated point from the camera lens in millimeters?", 105 | "What is the depth of the highlighted point in the image (in mm)?", 106 | "How far is the highlighted point from the camera in millimeters?", 107 | "Determine the depth value of the highlighted point in the given image (mm).", 108 | "Find the distance from the observer to the highlighted point in the image, in mm.", 109 | "What is the measured depth of the highlighted point in mm?", 110 | "How far away is the highlighted point from the viewer in the image (mm)?", 111 | "Identify the depth value for the highlighted point in millimeters.", 112 | "Given the highlighted point, what is its depth in the image (mm)?", 113 | "What is the distance between the camera and the highlighted point in mm?", 114 | "How deep is the highlighted point in the given image (in millimeters)?", 115 | "What is the distance to the highlighted point in the image (in mm)?", 116 | "How far is the highlighted point located from the camera in mm?", 117 | "Determine the distance of the highlighted point from the observer in mm.", 118 | "What is the depth measurement of the highlighted point in the image (mm)?", 119 | "How far is the highlighted point from the camera lens in millimeters?" 
120 | ], 121 | 122 | "answers": [ 123 | "The depth of the annotated point is `{depth}` mm.", 124 | "It is `{depth}` mm away from the camera.", 125 | "The depth value of the annotated point is `{depth}` mm.", 126 | "The distance to the annotated point is `{depth}` mm.", 127 | "Measured depth of the annotated point is `{depth}` mm.", 128 | "The annotated point is located `{depth}` mm away from the viewer.", 129 | "The depth of the annotated point is `{depth}` mm.", 130 | "Its depth in the image is `{depth}` mm.", 131 | "The distance from the camera to the annotated point is `{depth}` mm.", 132 | "The annotated point is `{depth}` mm deep in the image.", 133 | "The depth of the highlighted point is `{depth}` mm.", 134 | "It is `{depth}` mm away from the camera.", 135 | "The depth value of the highlighted point is `{depth}` mm.", 136 | "The distance to the highlighted point is `{depth}` mm.", 137 | "Measured depth of the highlighted point is `{depth}` mm.", 138 | "The highlighted point is located `{depth}` mm away from the viewer.", 139 | "The depth of the highlighted point is `{depth}` mm.", 140 | "Its depth in the image is `{depth}` mm.", 141 | "The distance from the camera to the highlighted point is `{depth}` mm.", 142 | "The highlighted point is `{depth}` mm deep in the image.", 143 | "The depth of the marked point is `{depth}` mm.", 144 | "It is `{depth}` mm away from the camera.", 145 | "The depth value of the marked point is `{depth}` mm.", 146 | "The distance to the marked point is `{depth}` mm.", 147 | "Measured depth of the marked point is `{depth}` mm.", 148 | "The marked point is located `{depth}` mm away from the viewer.", 149 | "The depth of the marked point is `{depth}` mm.", 150 | "Its depth in the image is `{depth}` mm.", 151 | "The distance from the camera to the marked point is `{depth}` mm.", 152 | "The marked point is `{depth}` mm deep in the image." 153 | ] 154 | } 155 | 156 | def generate_random_point(self, height, width, margin=50): 157 | """Generate random point within image boundaries with margin""" 158 | x = random.randint(margin, width - margin) 159 | y = random.randint(margin, height - margin) 160 | return (x, y) 161 | 162 | def annotate_image(self, image, point): 163 | """Annotate image with a single point""" 164 | annotated_img = image.copy() 165 | x, y = point 166 | 167 | # Generate a random distinct color 168 | color = generate_distinct_colors(1)[0] 169 | 170 | # Draw circle 171 | cv2.circle(annotated_img, (x, y), 10, color, -1) 172 | 173 | return annotated_img 174 | 175 | def generate_qa_training_single_scene(self, scene_id): 176 | # Get all valid image IDs for the scene 177 | image_ids = self.scene_info.get_all_extrinsic_valid_image_ids(scene_id) 178 | scene_image_height, scene_image_width = self.scene_info.get_image_shape(scene_id) 179 | 180 | # Calculate how many images to sample from this scene 181 | if self.max_samples > 0: 182 | n_images = min(self.max_samples, len(image_ids)) 183 | else: 184 | n_images = len(image_ids) 185 | 186 | # Randomly sample images 187 | sampled_image_ids = random.sample(image_ids, n_images) 188 | 189 | all_samples = [] 190 | for image_id in sampled_image_ids: 191 | # * Load all the visible points in this image 192 | visible_points = self.visibility_info.get_image_to_points_info(scene_id, image_id) # [point_index, ...] 
193 | 194 | # * Randomly sample a set of points from the visible points in the image 195 | # * need to consider the number of available points, if not enough, use put back sampling 196 | if len(visible_points) < self.max_n_points_per_image: 197 | sampled_points = random.choices(visible_points, k=self.max_n_points_per_image) 198 | else: 199 | sampled_points = random.sample(visible_points, self.max_n_points_per_image) 200 | 201 | for point in sampled_points: 202 | # Calculate the 2D coordinates of the point in the image 203 | point_2d, point_depth = self.scene_info.get_point_2d_coordinates_in_image( 204 | scene_id, image_id, point, align=True, check_visible=True, return_depth=True 205 | ) 206 | 207 | if len(point_2d) == 0: 208 | # If the point is not visible in the image, print a warning and skip it 209 | message = f"Warning: Point-Id {point} is not visible in image {image_id} in scene {scene_id}.\n" 210 | print(message.strip()) 211 | with open(self.warning_file, 'a') as wf: 212 | wf.write(message.strip()) 213 | continue 214 | 215 | # Convert the normalized coordinates to scaled values (0-1000 range) 216 | x = round((point_2d[0][0] / scene_image_width) * 1000) 217 | y = round((point_2d[0][1] / scene_image_height) * 1000) 218 | depth = round(point_depth[0] * 1000) # depth is in meters 219 | 220 | # read and annotate the image 221 | img_path = self.scene_info.get_image_path(scene_id, image_id) 222 | img = cv2.imread(img_path) 223 | annotated_img = self.annotate_image(img, (int(point_2d[0][0]), int(point_2d[0][1]))) 224 | 225 | # save the annotated image 226 | save_dir = os.path.join(self.image_output_dir, scene_id) 227 | mmengine.mkdir_or_exist(save_dir) 228 | save_path = os.path.join(save_dir, f"{image_id}_p{point}_annotated.jpg") 229 | cv2.imwrite(save_path, annotated_img) 230 | 231 | # Fill the question and answer templates with the coordinates and depth 232 | question_template = random.choice(self.templates["questions"]) 233 | question = question_template 234 | 235 | answer_template = random.choice(self.templates["answers"]) 236 | answer = answer_template.format(x1=x, y1=y, depth=depth) 237 | 238 | task_description = random.choice(self.task_description) 239 | 240 | # If it's the first round, add the task description 241 | conversation = [ 242 | { 243 | "from": "human", 244 | "value": f"{task_description}\n{question}" 245 | }, 246 | { 247 | "from": "gpt", 248 | "value": answer 249 | } 250 | ] 251 | 252 | # Complete the training sample for the current image 253 | training_sample = { 254 | "id": f"{scene_id}_{image_id}_point{point}", 255 | "image": [f"{scene_id}/{image_id}_p{point}_annotated.jpg"], 256 | "conversations": conversation, 257 | "height_list": [scene_image_height], 258 | "width_list": [scene_image_width], 259 | "question_type": "depth_estimation_dot", 260 | "gt_value": depth, 261 | "ori_coordinates": [int(point_2d[0][0]), int(point_2d[0][1])], 262 | } 263 | all_samples.append(training_sample) 264 | 265 | return all_samples 266 | 267 | def generate_qa_training_data(self, output_dir, save_file=True): 268 | scene_ids = self.scene_info.get_sorted_keys() 269 | 270 | # if max_samples is not -1, then we need to sample the scenes 271 | if self.all_max_samples > 0: 272 | # need to calculate how many samples for each scene 273 | self.max_samples = max(self.all_max_samples // len(scene_ids) + 1, 1) 274 | if self.max_samples == 1: 275 | scene_ids = random.sample(scene_ids, self.all_max_samples) 276 | else: 277 | self.max_samples = -1 278 | self.num_used_scenes = len(scene_ids) 279 | 280 | 
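To make the unit handling in `generate_qa_training_single_scene` above concrete: projected pixel coordinates are rescaled to a 0-1000 range relative to the image width/height, while the depth returned in meters is converted to millimeters before it is formatted into the answer. A short worked example with hypothetical values (1296x968 is the RGB resolution seen in the sample record earlier in this repo):

```python
# Hypothetical projection of one annotated point (values chosen for illustration).
scene_image_width, scene_image_height = 1296, 968   # ScanNet RGB frame size
px, py = 648.0, 484.0                               # projected 2D point, in pixels
depth_m = 1.234                                     # depth from the handler, in meters

x = round(px / scene_image_width * 1000)    # 500  -> normalized 0-1000 width coord
y = round(py / scene_image_height * 1000)   # 500  -> normalized 0-1000 height coord
depth_mm = round(depth_m * 1000)            # 1234 -> value substituted into `{depth}`
print(x, y, depth_mm)
```

Only `{depth}` actually appears in the dot-based answer templates; the normalized `x`/`y` are computed but the point location itself is conveyed by the annotated image, with the raw pixel coordinates stored in `ori_coordinates`.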
train_data = [] 281 | for scene_id in tqdm(scene_ids, desc="Generating QA Training Data"): 282 | train_data.extend(self.generate_qa_training_single_scene(scene_id)) 283 | 284 | if len(train_data) > self.all_max_samples: 285 | train_data = random.sample(train_data, self.all_max_samples) 286 | 287 | random.shuffle(train_data) 288 | 289 | if save_file: 290 | output_jsonl_filepath = f"{output_dir}/{self.task_name}.jsonl" 291 | 292 | mmengine.mkdir_or_exist(output_dir) 293 | with open(output_jsonl_filepath, 'w') as f: 294 | for entry in train_data: 295 | f.write(json.dumps(entry) + '\n') 296 | print(f"[Train] Training data saved to {output_jsonl_filepath}. Generated {len(train_data)} samples in total.") 297 | else: 298 | return train_data 299 | 300 | def convert_train_sample_to_eval_sample(self, train_sample): 301 | conversation = train_sample["conversations"] 302 | train_sample["text"] = conversation[0]["value"] 303 | return train_sample 304 | 305 | def generate_qa_eval_data(self, output_dir): 306 | assert self.max_n_points_per_image == 1, "max_n_points_per_image should be 1 for evaluation" 307 | train_data = self.generate_qa_training_data(output_dir, save_file=False) 308 | all_data = [self.convert_train_sample_to_eval_sample(train_sample) for train_sample in train_data] 309 | 310 | output_jsonl_filepath = f"{output_dir}/{self.task_name}.jsonl" 311 | 312 | mmengine.mkdir_or_exist(output_dir) 313 | with open(output_jsonl_filepath, 'w') as f: 314 | for entry in all_data: 315 | f.write(json.dumps(entry) + '\n') 316 | 317 | print(f"[Eval] Evaluation data saved to {output_jsonl_filepath}. Generated {len(all_data)} samples in total.") 318 | 319 | if __name__ == "__main__": 320 | parser = argparse.ArgumentParser() 321 | parser.add_argument("--output_suffix", type=str, default="") 322 | 323 | parser.add_argument("--train_scene_info_path", type=str, 324 | default=f"data/scannet/scannet_instance_data/scenes_train_info_i_D5.pkl") 325 | parser.add_argument("--val_scene_info_path", type=str, 326 | default=f"data/scannet/scannet_instance_data/scenes_val_info_i_D5.pkl") 327 | parser.add_argument("--train_all_max_samples", type=int, default=500000) 328 | parser.add_argument("--val_all_max_samples", type=int, default=300) 329 | 330 | parser.add_argument("--output_dir_train", type=str, 331 | default=f"training_data/depth_estimation_dot") 332 | parser.add_argument("--output_dir_val", type=str, 333 | default=f"evaluation_data/depth_estimation_dot") 334 | 335 | parser.add_argument("--version_num", type=str, default="v1_0") 336 | args = parser.parse_args() 337 | 338 | args.output_dir_train = os.path.join(args.output_dir_train, args.version_num, args.output_suffix.replace('_', '')) 339 | args.output_dir_val = os.path.join(args.output_dir_val, args.version_num, args.output_suffix.replace('_', '')) 340 | args.image_output_dir_train = os.path.join(args.output_dir_train, "images") 341 | args.image_output_dir_val = os.path.join(args.output_dir_val, "images") 342 | 343 | # read the visibility info files 344 | train_visibility_info_path = f"data/scannet/scannet_instance_data/train_visibility_info_D5.parquet" 345 | val_visibility_info_path = f"data/scannet/scannet_instance_data/val_visibility_info_D5.parquet" 346 | 347 | val_warning_file = f"{args.output_dir_val}/val_warning.txt" 348 | train_warning_file = f"{args.output_dir_train}/train_warning.txt" 349 | 350 | print("Generating evaluation data...") # cost 3s to generate 351 | qa_engine_eval = DepthEstimationDotQAEngine( 352 | scene_info_path=args.val_scene_info_path, 353 
| version_num=args.version_num, 354 | all_max_samples=args.val_all_max_samples, 355 | image_output_dir=args.image_output_dir_val, 356 | visibility_info_path=val_visibility_info_path, 357 | warning_file=val_warning_file 358 | ) 359 | qa_engine_eval.generate_qa_eval_data(args.output_dir_val) 360 | 361 | print("Generating training data...") # cost 1.5 hours to generate, will generate 337523 samples 362 | qa_engine_train = DepthEstimationDotQAEngine( 363 | scene_info_path=args.train_scene_info_path, 364 | version_num=args.version_num, 365 | all_max_samples=args.train_all_max_samples, 366 | image_output_dir=args.image_output_dir_train, 367 | visibility_info_path=train_visibility_info_path, 368 | warning_file=train_warning_file 369 | ) 370 | qa_engine_train.generate_qa_training_data(args.output_dir_train) -------------------------------------------------------------------------------- /spatial_engine/object_perception/compute_object_visibility.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | This script uses the saved visibility parquet file along with SceneInfoHandler to 9 | compute the visibility of each object (excluding categories in NONINFORMATIVE_DESC) 10 | in all valid images for each scene. 11 | 12 | This script needs to be run after generating the point-level visibility parquet file and before running the object frame coverage script. 13 | 14 | The final saved result structure is: 15 | { 16 | scene_id: { 17 | "object_to_images": { 18 | object_id: [ 19 | { 20 | "image_id": image_id, 21 | "intersection_count": number of intersecting points, 22 | "visibility": visibility percentage 23 | }, 24 | ... 25 | ] 26 | }, 27 | "image_to_objects": { 28 | image_id: [ 29 | { 30 | "object_id": object_id, 31 | "intersection_count": number of intersecting points, 32 | "visibility": visibility percentage 33 | }, 34 | ... 35 | ] 36 | } 37 | } 38 | } 39 | 40 | The result saves both: 41 | 1. object_to_images: Maps each object to a list of images that see it, 42 | each entry contains image_id, intersection_count and visibility percentage. 43 | 2. image_to_objects: Maps each image to a list of objects it can see, 44 | also saving intersection_count and visibility. 45 | For skipped cases, warnings are printed and written to warning.txt in the output directory. 
46 | """ 47 | 48 | import os 49 | import pickle 50 | import json 51 | import numpy as np 52 | import pandas as pd 53 | from tqdm import tqdm 54 | 55 | from spatial_engine.utils.scannet_utils.handler.info_handler import SceneInfoHandler 56 | 57 | # Non-informative categories (keep original strings, don't call lower()) 58 | NONINFORMATIVE_DESC = {"wall", "object", "floor", "ceiling", "window"} 59 | 60 | def load_visibility_dict(parquet_file): 61 | """ 62 | Load parquet file and convert to dict with structure: 63 | key: "scene_id:image_to_points:image_id" 64 | value: parsed point index list 65 | """ 66 | df = pd.read_parquet(parquet_file) 67 | keys = df["key"].tolist() 68 | values = df["values"].tolist() 69 | 70 | return dict(zip(keys, values)) 71 | 72 | def process_scene(scene_id, scene_info_handler, visibility_dict): 73 | """ 74 | Process a single scene: 75 | - Iterate through each object in the scene, 76 | Skip and print warning if object's raw category is in NONINFORMATIVE_DESC; 77 | Skip and print warning if object has no point indices. 78 | - For remaining objects, call get_object_point_index to get object point indices, 79 | Calculate threshold (5% of object points, minimum 1). 80 | - Iterate through all valid images (scene_info.get_all_extrinsic_valid_image_ids), 81 | Look up visible points for that image in visibility_dict, 82 | Calculate intersection with object points, record the image if intersection count meets threshold, 83 | Save intersection count and visibility percentage. 84 | - Build two mappings: 85 | object_to_images: { object_id: [ { "image_id": ..., "intersection_count": n, "visibility": v }, ... ] } 86 | image_to_objects: { image_id: [ { "object_id": ..., "intersection_count": n, "visibility": v }, ... ] } 87 | Returns: (scene_id, result, warnings) 88 | where result is dict of above two mappings, warnings is list of all warning messages for this scene. 89 | """ 90 | print(f"Processing scene {scene_id}.") 91 | warnings_list = [] 92 | result = { 93 | "object_to_images": {}, 94 | "image_to_objects": {} 95 | } 96 | 97 | # Use the scene_info_handler 98 | if scene_id not in scene_info_handler.infos: 99 | msg = f"[Warning] Scene {scene_id} not found in scene_info." 100 | warnings_list.append(msg) 101 | print(msg) 102 | return scene_id, result, warnings_list 103 | 104 | num_objects = scene_info_handler.get_num_objects(scene_id) 105 | for object_id in range(num_objects): 106 | raw_category = scene_info_handler.get_object_raw_category(scene_id, object_id) 107 | if raw_category in NONINFORMATIVE_DESC: 108 | continue 109 | 110 | object_points = scene_info_handler.get_object_point_index(scene_id, object_id) 111 | if len(object_points) == 0: 112 | msg = f"[Warning] Scene {scene_id}, object {object_id} has no point indices, skipping." 113 | warnings_list.append(msg) 114 | print(msg) 115 | continue 116 | 117 | if isinstance(object_points, np.ndarray): 118 | object_points_set = set(object_points.tolist()) 119 | else: 120 | object_points_set = set(object_points) 121 | total_points = len(object_points_set) 122 | threshold = max(1, int(0.05 * total_points)) 123 | 124 | valid_image_ids = scene_info_handler.get_all_extrinsic_valid_image_ids(scene_id) 125 | for image_id in valid_image_ids: 126 | key = f"{scene_id}:image_to_points:{image_id}" 127 | if key not in visibility_dict: 128 | msg = f"[Warning] Scene {scene_id}, image {image_id} not found in visibility dict." 
129 | warnings_list.append(msg) 130 | print(msg) 131 | continue 132 | 133 | visible_points = set(json.loads(visibility_dict[key])) 134 | intersection_count = len(visible_points & object_points_set) 135 | if intersection_count >= threshold: 136 | visibility_percent = (intersection_count / total_points) * 100.0 137 | if object_id not in result["object_to_images"]: 138 | result["object_to_images"][object_id] = [] 139 | result["object_to_images"][object_id].append({ 140 | "image_id": image_id, 141 | "intersection_count": intersection_count, 142 | "visibility": visibility_percent 143 | }) 144 | if image_id not in result["image_to_objects"]: 145 | result["image_to_objects"][image_id] = [] 146 | result["image_to_objects"][image_id].append({ 147 | "object_id": object_id, 148 | "intersection_count": intersection_count, 149 | "visibility": visibility_percent 150 | }) 151 | 152 | return scene_id, result, warnings_list 153 | 154 | def process_split(split_name, scene_info_path, visibility_parquet_file, output_dir): 155 | """ 156 | Process one split (train or val): 157 | - Load visibility parquet file and convert to dict; 158 | - Load scene_info to get all scene_ids; 159 | - Process each scene sequentially (call process_scene); 160 | - Save combined results to pkl file and write all warnings to warning.txt in output_dir. 161 | """ 162 | if not os.path.exists(output_dir): 163 | os.makedirs(output_dir) 164 | 165 | output_pkl_file = os.path.join(output_dir, "object_visibility.pkl") 166 | warning_file = os.path.join(output_dir, "warning.txt") 167 | with open(warning_file, "w") as wf: 168 | wf.write("") 169 | 170 | # Get scene_ids in the main process 171 | scene_info_handler = SceneInfoHandler(scene_info_path) 172 | scene_ids = scene_info_handler.get_all_scene_ids() 173 | 174 | results = {} 175 | all_warnings = [] 176 | 177 | scene_info_handler = SceneInfoHandler(scene_info_path) 178 | print(f"Loading visibility dict from {visibility_parquet_file}.") 179 | visibility_dict = load_visibility_dict(visibility_parquet_file) 180 | print(f"Loaded visibility dict.") 181 | 182 | for scene_id in tqdm(scene_ids, desc=f"Processing {split_name} scenes"): 183 | scene_id, scene_result, warnings = process_scene(scene_id, scene_info_handler, visibility_dict) 184 | results[scene_id] = scene_result 185 | all_warnings.extend(warnings) 186 | 187 | with open(warning_file, "a") as wf: 188 | for w in all_warnings: 189 | wf.write(w + "\n") 190 | 191 | with open(output_pkl_file, "wb") as f: 192 | pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL) 193 | 194 | print(f"Finished processing split '{split_name}'.") 195 | print(f"Result saved to {output_pkl_file}") 196 | print(f"Warnings saved to {warning_file}") 197 | 198 | 199 | def main(): 200 | # Configure parameters for train and val 201 | splits = { 202 | "val": { # * take 15 mins 203 | "scene_info_path": "data/scannet/scannet_instance_data/scenes_val_info_i_D5.pkl", 204 | "visibility_parquet_file": "data/scannet/scannet_instance_data/val_visibility_info_D5.parquet", 205 | "output_dir": "evaluation_data/object_perception" 206 | }, 207 | "train": { # * take 1h 46 mins to generate, 94M size 208 | "scene_info_path": "data/scannet/scannet_instance_data/scenes_train_info_i_D5.pkl", 209 | "visibility_parquet_file": "data/scannet/scannet_instance_data/train_visibility_info_D5.parquet", 210 | "output_dir": "training_data/object_perception" 211 | } 212 | } 213 | 214 | for split_name, config in splits.items(): 215 | process_split(split_name, 216 | config["scene_info_path"], 217 | 
config["visibility_parquet_file"], 218 | config["output_dir"]) 219 | 220 | 221 | if __name__ == "__main__": 222 | main() -------------------------------------------------------------------------------- /spatial_engine/object_perception/find_object_coverage.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | #!/bin/bash 8 | # run_tasks.sh 9 | # Usage: ./run_tasks.sh [start] [chunk_size] 10 | # This script runs object coverage finding for both train and val splits 11 | 12 | # Set default values 13 | DEFAULT_START=0 14 | DEFAULT_CHUNK_SIZE=10 15 | 16 | # Get command line parameters or use default values 17 | START=${1:-$DEFAULT_START} 18 | CHUNK_SIZE=${2:-$DEFAULT_CHUNK_SIZE} 19 | 20 | echo "Launching tasks for both train and val splits from scene $START with chunk size $CHUNK_SIZE" 21 | 22 | # Process train split 23 | echo "Processing train split..." 24 | for ((current_start=START; ; current_start+=CHUNK_SIZE)); do 25 | current_end=$((current_start+CHUNK_SIZE)) 26 | echo " Starting task for train: scenes $current_start to $current_end" 27 | # Launch task in background, output files will include start and end suffixes 28 | python spatial_engine/object_perception/single_object_coverage_finder.py --split "train" --start "$current_start" --end "$current_end" & 29 | 30 | # Add a limit to prevent infinite loop - can be adjusted as needed 31 | if [ $current_start -ge 1000 ]; then 32 | break 33 | fi 34 | done 35 | 36 | # Process val split 37 | echo "Processing val split..." 38 | for ((current_start=START; ; current_start+=CHUNK_SIZE)); do 39 | current_end=$((current_start+CHUNK_SIZE)) 40 | echo " Starting task for val: scenes $current_start to $current_end" 41 | # Launch task in background, output files will include start and end suffixes 42 | python spatial_engine/object_perception/single_object_coverage_finder.py --split "val" --start "$current_start" --end "$current_end" & 43 | 44 | # Add a limit to prevent infinite loop - can be adjusted as needed 45 | if [ $current_start -ge 1000 ]; then 46 | break 47 | fi 48 | done 49 | 50 | # Wait for all background tasks to complete 51 | wait 52 | echo "All tasks completed for both train and val splits." -------------------------------------------------------------------------------- /spatial_engine/object_perception/merge_object_coverage.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | #!/usr/bin/env python3 8 | """ 9 | This script should be run after find_object_coverage.sh. It's used to merge the object coverage results of different scenes. 10 | """ 11 | 12 | import os 13 | import pickle 14 | import glob 15 | import re 16 | 17 | def merge_dimension(split, base_dir, dimension): 18 | """ 19 | For the specified split ("train" or "val"), base_dir, and dimension, 20 | traverse all subdirectories starting with split_ under base_dir, 21 | load pkl files with filenames matching f"{split}_object_coverage_{dimension}__.pkl", 22 | merge them and return the combined dictionary. 
23 | """ 24 | merged = {} 25 | # List all subdirectories starting with split_ 26 | subdirs = [d for d in os.listdir(base_dir) 27 | if os.path.isdir(os.path.join(base_dir, d)) and d.startswith(f"{split}_")] 28 | if not subdirs: 29 | print(f"No subdirectories starting with {split}_ found in {base_dir}.") 30 | return merged 31 | 32 | # Use regex to parse directory names (e.g., "train_0_10") 33 | pattern = re.compile(fr"{split}_(\d+)_(\d+)") 34 | dir_ranges = [] 35 | for d in subdirs: 36 | m = pattern.match(d) 37 | if m: 38 | start = int(m.group(1)) 39 | end = int(m.group(2)) 40 | dir_ranges.append((d, start, end)) 41 | else: 42 | print(f"Directory name format does not match requirements, skipping: {d}") 43 | # Sort by start value 44 | dir_ranges.sort(key=lambda x: x[1]) 45 | print(f"Found {len(dir_ranges)} {split} subdirectories:") 46 | for d, s, e in dir_ranges: 47 | print(f" {d}: {s} ~ {e}") 48 | 49 | # Traverse each subdirectory, looking for pkl files corresponding to the dimension 50 | for d, s, e in dir_ranges: 51 | subdir_path = os.path.join(base_dir, d) 52 | # Filename format: "{split}_object_coverage_{dimension}_{s}_{e}.pkl" 53 | pattern_file = os.path.join(subdir_path, f"{split}_object_coverage_{dimension}_*_*.pkl") 54 | files = glob.glob(pattern_file) 55 | if not files: 56 | print(f"No {dimension} files found in subdirectory {d}, skipping.") 57 | continue 58 | for file in files: 59 | print(f"Loading file: {file}") 60 | with open(file, "rb") as f: 61 | data = pickle.load(f) 62 | # Assume each file contains a dict with scene_id as keys 63 | merged.update(data) 64 | return merged 65 | 66 | def merge_split(split, base_dir, output_dir): 67 | """ 68 | For the specified split and base_dir, merge files for height, length, width dimensions separately, 69 | and save to output_dir. 70 | """ 71 | dims = ["height", "length", "width"] 72 | merged_dict = {} 73 | for d in dims: 74 | print(f"\nMerging {split} {d} files ...") 75 | merged = merge_dimension(split, base_dir, d) 76 | merged_dict[d] = merged 77 | num_scene = len(merged) 78 | print(f"After merging {split} {d}, there are {num_scene} scene_ids in total.") 79 | # Save results 80 | output_file = os.path.join(output_dir, f"merged_{split}_object_coverage_{d}.pkl") 81 | with open(output_file, "wb") as f: 82 | pickle.dump(merged, f) 83 | print(f"Saved merged results to {output_file}") 84 | return merged_dict 85 | 86 | def main(): 87 | # Set base_dir for training and validation sets (please modify according to actual paths) 88 | train_base = "training_data/object_perception" 89 | val_base = "evaluation_data/object_perception" 90 | 91 | # Output directory can be the same as base_dir or set separately, here we use the respective base_dir directly 92 | # Note: Output files will be saved to the corresponding folders 93 | print("Starting to merge training set files ...") 94 | merge_split("train", train_base, train_base) 95 | 96 | print("\nStarting to merge validation set files ...") 97 | merge_split("val", val_base, val_base) 98 | 99 | if __name__ == "__main__": 100 | main() -------------------------------------------------------------------------------- /spatial_engine/object_perception/single_object_coverage_finder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | #!/usr/bin/env python 8 | """ 9 | This script processes object coverage in three dimensions (height, length, width) 10 | based on the precomputed object visibility information and SceneInfoHandler. 11 | For each object (excluding those with categories in NONINFORMATIVE_DESC), we aim to find 12 | a minimal combination of images such that the union of the visible 3D points in that 13 | dimension exactly (within a specified tolerance) covers the object’s target size. 14 | A combination is considered minimal if removing any image from it would result in incomplete coverage. 15 | The output for each object is stored as a dict keyed by the number of images in the combination, 16 | e.g.: 17 | { 18 | "height": { 1: [ [img1], [img3], ... ], 2: [ [img2, img5], ... ], ... }, 19 | "length": { ... }, 20 | "width": { ... } 21 | } 22 | Note: Theoretically, one could compute visible images on the fly by projecting the object's 3D points 23 | into each image and checking visibility; however, precomputing and saving the object visibility data 24 | (such as in our previous step) can save time in subsequent processing. 25 | """ 26 | 27 | import os 28 | import pickle 29 | import json 30 | import numpy as np 31 | import pandas as pd 32 | from tqdm import tqdm 33 | import mmengine 34 | import random 35 | import argparse 36 | random.seed(0) 37 | # Tolerance for dimension coverage (10% of target) 38 | TOLERANCE = 0.1 39 | 40 | from spatial_engine.utils.scannet_utils.handler.info_handler import SceneInfoHandler 41 | 42 | DEBUG = False 43 | 44 | 45 | def load_visibility_dict(parquet_file): 46 | """ 47 | Load parquet file and convert to dict with structure: 48 | key: "scene_id:image_to_points:image_id" 49 | value: parsed point index list (stored as JSON string) 50 | """ 51 | df = pd.read_parquet(parquet_file) 52 | keys = df["key"].tolist() 53 | values = df["values"].tolist() 54 | return dict(zip(keys, values)) 55 | 56 | 57 | def compute_coverage(scene_pts, indices_bool_mask, axis): 58 | """ 59 | Given a set of indices (subset of scene_pts) and the 3D points (scene_pts), 60 | compute the coverage along the specified axis. 61 | """ 62 | if not indices_bool_mask.any(): 63 | return None 64 | coords = scene_pts[indices_bool_mask][:, axis] 65 | return max(coords) - min(coords) 66 | 67 | 68 | def covers_dimension(coverage, target, tolerance): 69 | """ 70 | Check if the computed coverage is within tolerance of the target dimension. 71 | """ 72 | if coverage is None: 73 | return False 74 | return abs(coverage - target) <= tolerance * target 75 | 76 | def find_minimal_combinations( 77 | scene_id, 78 | scene_pts, 79 | object_points_indices, 80 | visible_images, 81 | images_to_visible_points_dict, 82 | axis, 83 | target_dim, 84 | tolerance, 85 | max_images=5 86 | ): 87 | """ 88 | Use two-phase breadth-first search (BFS) to find "all minimal combinations" layer by layer: 89 | 1. Phase A: Check coverage, record minimal solutions, and add uncovered combinations to expansion list; 90 | 2. Phase B: Save minimal solutions to global set for pruning, then expand the list to the next layer. 91 | Once a combination covers the target, it stops expanding; but this doesn't affect other combinations 92 | continuing to search, thus finding all mutually exclusive minimal solutions. 93 | 94 | Returns a dictionary: {k: [list of minimal combinations of size k]}. 
95 | """ 96 | # ----------- Preprocessing ----------- 97 | # 1) Build boolean mask for object points 98 | object_points_indices_mask = np.zeros(len(scene_pts), dtype=bool) 99 | object_points_indices_mask[object_points_indices] = True 100 | 101 | # 2) Extract boolean mask for each image (only keeping object points) 102 | image_bool_masks = {} 103 | for img in visible_images: 104 | key = f"{scene_id}:image_to_points:{img}" 105 | if key not in images_to_visible_points_dict: 106 | print(f"[Warning] Scene {scene_id}, image {img} not found in visibility dict. Skip this combination.") 107 | continue # skip 108 | bool_mask = np.zeros(len(scene_pts), dtype=bool) 109 | bool_mask[json.loads(images_to_visible_points_dict[key])] = True 110 | bool_mask = np.logical_and(bool_mask, object_points_indices_mask) 111 | image_bool_masks[img] = bool_mask 112 | 113 | valid_images = list(image_bool_masks.keys()) 114 | # Sort by number of covered points in descending order, optional (sometimes finds smaller combinations faster) 115 | # valid_images.sort(key=lambda x: np.sum(image_bool_masks[x]), reverse=True) 116 | 117 | # only get 25 images 118 | if len(valid_images) > 25: 119 | valid_images = random.sample(valid_images, 25) 120 | 121 | # Pre-compute cumulative remaining coverage, the union mask of all images from index i to the end 122 | cumulative_union = [None] * len(valid_images) 123 | if valid_images: 124 | cumulative_union[-1] = image_bool_masks[valid_images[-1]].copy() 125 | for i in range(len(valid_images) - 2, -1, -1): 126 | cumulative_union[i] = np.logical_or(image_bool_masks[valid_images[i]], cumulative_union[i + 1]) 127 | 128 | # Initialize a list to store the minimal sets as boolean masks 129 | found_minimal_sets = [] 130 | 131 | def is_superset_of_any_minimal_bit(comb_bitmask): 132 | # Check if the combination is a superset of any known minimal set using numpy operations 133 | if not found_minimal_sets: 134 | return False 135 | # Perform a logical AND between the combination bitmask and each minimal set 136 | # Then check if any row in the result matches the minimal set itself 137 | for minimal_set in found_minimal_sets: 138 | if np.array_equal(np.logical_and(minimal_set, comb_bitmask), minimal_set): 139 | return True 140 | return False 141 | 142 | # Helper function: Check coverage 143 | def can_cover(union_mask): 144 | cov = compute_coverage(scene_pts, union_mask, axis) 145 | return covers_dimension(cov, target_dim, tolerance) 146 | 147 | # ----------- BFS Initialization ----------- 148 | # current_level stores (comb, union_mask, last_idx), representing all size=k combinations at the current level 149 | current_level = [] 150 | for i, img in enumerate(valid_images): 151 | comb_bitmask = np.zeros(len(valid_images), dtype=bool) 152 | comb_bitmask[i] = True 153 | current_level.append(([img], image_bool_masks[img], i, comb_bitmask)) 154 | 155 | # Store results: dictionary {combination size k: [list of minimal combinations of size k]} 156 | minimal_solutions = {} 157 | first_layer_to_expand_list = [] 158 | 159 | k = 1 160 | # Enter loop while k is within limit and current_level is not empty 161 | while k <= max_images and current_level: 162 | # ---------- Phase A: Check Coverage ----------- 163 | to_expand = [] # Combinations to expand to the next level 164 | new_minimal_sets = [] 165 | 166 | # Traverse the current level 167 | for comb, union_mask, last_idx, comb_bitmask in current_level: 168 | # for comb, union_mask, last_idx, comb_bitmask in tqdm(current_level, desc=f"Processing layer {k}"): 169 | # 
Prune: Skip if the combination is a superset of any known minimal set 170 | if is_superset_of_any_minimal_bit(comb_bitmask): 171 | continue 172 | 173 | # If not a superset, check coverage 174 | if can_cover(union_mask): 175 | # This is a new minimal combination 176 | new_minimal_sets.append(comb_bitmask) 177 | # Save to minimal_solutions 178 | minimal_solutions.setdefault(k, []) 179 | minimal_solutions[k].append(tuple(comb)) 180 | else: 181 | # Not yet covered, add to the list to expand 182 | # early prune: if the cumulative union is not covered, skip 183 | if last_idx < len(valid_images) - 1: 184 | possible_union_mask = np.logical_or(cumulative_union[last_idx], union_mask) 185 | if not can_cover(possible_union_mask): 186 | continue 187 | to_expand.append((comb, union_mask, last_idx, comb_bitmask)) 188 | if k == 1: 189 | first_layer_to_expand_list.append((comb, union_mask, last_idx, comb_bitmask)) 190 | 191 | # Update known minimal sets 192 | if new_minimal_sets: 193 | found_minimal_sets.extend(new_minimal_sets) 194 | 195 | # ---------- Phase B: Expand to Next Level ----------- 196 | # Generate next level combinations (k+1), only expand "not covered" combinations 197 | 198 | next_level = [] 199 | if k < max_images: # If further expansion is possible 200 | for comb, union_mask, last_idx, comb_bitmask in to_expand: 201 | # Only use those in first_layer_to_expand_list to expand 202 | # Need to find the first larger image index 203 | for potential_comb, potential_union_mask, potential_last_idx, potential_comb_bitmask in first_layer_to_expand_list: 204 | if potential_last_idx > last_idx: 205 | # Make a new combination 206 | assert len(potential_comb) == 1, f"Number of images in potential_comb should be 1. {potential_comb}." 207 | new_comb = comb + potential_comb 208 | new_union_mask = np.logical_or(union_mask, potential_union_mask) 209 | new_comb_bitmask = np.logical_or(potential_comb_bitmask, comb_bitmask) 210 | next_level.append((new_comb, new_union_mask, potential_last_idx, new_comb_bitmask)) 211 | 212 | if len(next_level) > 5000: 213 | # random sample 5000 214 | next_level = random.sample(next_level, 5000) 215 | 216 | # Update current_level to the next level 217 | current_level = next_level 218 | k += 1 219 | 220 | return minimal_solutions 221 | 222 | def process_object(scene_id, object_id, scene_info_handler: SceneInfoHandler, visible_images, images_to_visible_points_dict): 223 | """ 224 | For a given object in a scene: 225 | - Get its point indices and corresponding 3D coordinates (aligned). 226 | - Use the precomputed object visibility (from the visibility file) to obtain the set of images 227 | that see this object. 228 | - For each dimension, compute the target value and corresponding axis: 229 | Height: target = get_object_height, axis = 2. 230 | For length and width, use get_object_width_axis_aligned to determine: 231 | if width_axis == 0 then width is along x (axis=0) and length along y (axis=1); 232 | else width is along y (axis=1) and length along x (axis=0). 233 | - Use find_minimal_combinations to obtain minimal image combinations covering the target. 234 | - Group the combinations by the number of images. 235 | Returns a dict: { "height": {n: [...], ...}, "length": {n: [...], ...}, "width": {n: [...], ...} }. 
236 | 237 | Note: Theoretically, one could compute the visible images on the fly by projecting the object's 3D points 238 | into each image and checking visibility; however, precomputing and saving the object visibility data 239 | (e.g., in object_visibility.pkl) can save time in subsequent processing. 240 | """ 241 | # Get full aligned scene points (first 3 coords) 242 | scene_pts = scene_info_handler.get_scene_points_align(scene_id)[:, :3] 243 | object_points_indices = scene_info_handler.get_object_point_index(scene_id, object_id) 244 | 245 | # Height: target from get_object_height, axis=2. 246 | height_target = scene_info_handler.get_object_height(scene_id, object_id) 247 | height_axis = 2 248 | 249 | # For length and width, determine axis using get_object_width_axis_aligned. 250 | width_axis = scene_info_handler.get_object_width_axis_aligned(scene_id, object_id) # 0 or 1 251 | length_axis = 1 if width_axis == 0 else 0 252 | length_target = scene_info_handler.get_object_length(scene_id, object_id) 253 | width_target = scene_info_handler.get_object_width(scene_id, object_id) 254 | 255 | height_combs = find_minimal_combinations(scene_id, scene_pts, object_points_indices, visible_images, images_to_visible_points_dict, height_axis, height_target, TOLERANCE) 256 | # for height, length, width, we do not consider those less than 0.02 m 257 | length_combs = find_minimal_combinations(scene_id, scene_pts, object_points_indices, visible_images, images_to_visible_points_dict, length_axis, length_target, TOLERANCE) 258 | width_combs = find_minimal_combinations(scene_id, scene_pts, object_points_indices, visible_images, images_to_visible_points_dict, width_axis, width_target, TOLERANCE) 259 | 260 | return { 261 | "height": height_combs, 262 | "length": length_combs, 263 | "width": width_combs 264 | } 265 | 266 | def process_scene_for_coverage(scene_id, scene_info_handler, images_to_visible_points_dict, object_visibility_dict): 267 | """ 268 | Process one scene: 269 | For each object (excluding non-informative ones), compute the minimal image combinations 270 | that cover the object's height, length, and width. 271 | Returns: (scene_id, scene_result) where scene_result maps object_id to 272 | { "height": {...}, "length": {...}, "width": {...} }. 273 | """ 274 | print(f"Processing scene {scene_id} for object coverage.") 275 | scene_result = {} 276 | # iterate through object_visibility_dict 277 | scene_object_visibility = object_visibility_dict[scene_id]["object_to_images"] 278 | for object_id, visibility_list in tqdm(scene_object_visibility.items(), desc=f"Processing scene {scene_id} for object coverage"): 279 | # * process each object 280 | visible_images = [img["image_id"] for img in visibility_list] 281 | res = process_object(scene_id, object_id, scene_info_handler, visible_images, images_to_visible_points_dict) 282 | if res is not None: 283 | scene_result[object_id] = res 284 | return scene_id, scene_result 285 | 286 | 287 | def process_split_objects(split_name, scene_info_path, visibility_parquet_file, object_visibility_file, output_dir): 288 | """ 289 | Process one split (train or val) for object coverage. 290 | For each scene, compute per-object minimal image combinations for height, length, and width. 291 | Save three separate result files in output_dir. 
292 | """ 293 | if not os.path.exists(output_dir): 294 | os.makedirs(output_dir) 295 | output_height_file = os.path.join(output_dir, f"{split_name}_object_coverage_height.pkl") 296 | output_length_file = os.path.join(output_dir, f"{split_name}_object_coverage_length.pkl") 297 | output_width_file = os.path.join(output_dir, f"{split_name}_object_coverage_width.pkl") 298 | warning_file = os.path.join(output_dir, "coverage_finding_warning_objects.txt") 299 | with open(warning_file, "w") as wf: 300 | wf.write("") 301 | 302 | # Load necessary data once 303 | scene_info_handler = SceneInfoHandler(scene_info_path) 304 | images_to_visible_points_dict = load_visibility_dict(visibility_parquet_file) 305 | object_visibility_dict = mmengine.load(object_visibility_file) 306 | 307 | # Get all scene_ids from SceneInfoHandler in main process 308 | scene_ids = scene_info_handler.get_all_scene_ids() 309 | 310 | results_height = {} 311 | results_length = {} 312 | results_width = {} 313 | 314 | if DEBUG: 315 | scene_ids = ["scene0011_00"] 316 | 317 | for scene_id in tqdm(scene_ids, desc=f"Processing {split_name} scenes for object coverage"): 318 | scene_id, scene_result = process_scene_for_coverage(scene_id, scene_info_handler, images_to_visible_points_dict, object_visibility_dict) 319 | if scene_result: 320 | results_height[scene_id] = {} 321 | results_length[scene_id] = {} 322 | results_width[scene_id] = {} 323 | for object_id, res in scene_result.items(): 324 | results_height[scene_id][object_id] = res["height"] # may be empty 325 | results_length[scene_id][object_id] = res["length"] 326 | results_width[scene_id][object_id] = res["width"] 327 | 328 | with open(output_height_file, "wb") as f: 329 | pickle.dump(results_height, f) 330 | with open(output_length_file, "wb") as f: 331 | pickle.dump(results_length, f) 332 | with open(output_width_file, "wb") as f: 333 | pickle.dump(results_width, f) 334 | 335 | print(f"Finished processing split '{split_name}' for object coverage.") 336 | print(f"Height coverage saved to {output_height_file}") 337 | print(f"Length coverage saved to {output_length_file}") 338 | print(f"Width coverage saved to {output_width_file}") 339 | print(f"Warnings saved to {warning_file}") 340 | 341 | 342 | def main(): 343 | parser = argparse.ArgumentParser(description="Process object coverage for a given split and scene index range.") 344 | parser.add_argument("--split", type=str, required=True, help="Specify the split: train or val") 345 | parser.add_argument("--start", type=int, required=True, help="Start scene index (inclusive)") 346 | parser.add_argument("--end", type=int, default=None, help="End scene index (exclusive)") 347 | args = parser.parse_args() 348 | 349 | split = args.split 350 | start_index = args.start 351 | end_index = args.end 352 | 353 | # Configure the parameters for each split (modify the paths according to your setup) 354 | splits = { 355 | "val": { 356 | "scene_info_path": "data/scannet/scannet_instance_data/scenes_val_info_i_D5.pkl", 357 | "visibility_parquet_file": "data/scannet/scannet_instance_data/val_visibility_info_D5.parquet", 358 | "object_visibility_file": "evaluation_data/object_perception/object_visibility.pkl", 359 | "output_dir": "evaluation_data/object_perception" 360 | }, 361 | "train": { 362 | "scene_info_path": "data/scannet/scannet_instance_data/scenes_train_info_i_D5.pkl", 363 | "visibility_parquet_file": "data/scannet/scannet_instance_data/train_visibility_info_D5.parquet", 364 | "object_visibility_file": "training_data/object_perception/object_visibility.pkl", 365 | "output_dir":
"training_data/object_perception" 366 | } 367 | } 368 | 369 | if DEBUG: 370 | split = "val" 371 | start_index = 0 372 | end_index = 1 373 | # output dir should have a subdir named "debug" 374 | splits["val"]["output_dir"] = os.path.join(splits["val"]["output_dir"], "debug") 375 | splits["val"]["visibility_parquet_file"] = "data/scannet/scannet_instance_data/val_visibility_info_D5_debug.parquet" 376 | 377 | if split not in splits: 378 | raise ValueError("Invalid split. Choose train or val.") 379 | 380 | config = splits[split] 381 | # 在输出目录后加上 start_end 后缀 382 | config["output_dir"] = os.path.join(config["output_dir"], f"{split}_{start_index}_{end_index}") 383 | if not os.path.exists(config["output_dir"]): 384 | os.makedirs(config["output_dir"]) 385 | 386 | scene_info_handler = SceneInfoHandler(config["scene_info_path"]) 387 | images_to_visible_points_dict = load_visibility_dict(config["visibility_parquet_file"]) 388 | object_visibility_dict = mmengine.load(config["object_visibility_file"]) 389 | 390 | all_scene_ids = scene_info_handler.get_all_scene_ids() 391 | selected_scene_ids = all_scene_ids[start_index:end_index] 392 | print(f"Processing scenes from index {start_index} to {end_index} (total {len(selected_scene_ids)}) in split {split}.") 393 | 394 | results_height = {} 395 | results_length = {} 396 | results_width = {} 397 | 398 | for scene_id in tqdm(selected_scene_ids, desc=f"Processing {split} scenes"): 399 | scene_id, scene_result = process_scene_for_coverage(scene_id, scene_info_handler, images_to_visible_points_dict, object_visibility_dict) 400 | if scene_result: 401 | results_height[scene_id] = {} 402 | results_length[scene_id] = {} 403 | results_width[scene_id] = {} 404 | for object_id, res in scene_result.items(): 405 | results_height[scene_id][object_id] = res["height"] 406 | results_length[scene_id][object_id] = res["length"] 407 | results_width[scene_id][object_id] = res["width"] 408 | 409 | output_height_file = os.path.join(config["output_dir"], f"{split}_object_coverage_height_{start_index}_{end_index}.pkl") 410 | output_length_file = os.path.join(config["output_dir"], f"{split}_object_coverage_length_{start_index}_{end_index}.pkl") 411 | output_width_file = os.path.join(config["output_dir"], f"{split}_object_coverage_width_{start_index}_{end_index}.pkl") 412 | 413 | with open(output_height_file, "wb") as f: 414 | pickle.dump(results_height, f) 415 | with open(output_length_file, "wb") as f: 416 | pickle.dump(results_length, f) 417 | with open(output_width_file, "wb") as f: 418 | pickle.dump(results_width, f) 419 | 420 | print(f"Finished processing split '{split}' for scenes {start_index} to {end_index}.") 421 | print(f"Height coverage saved to {output_height_file}") 422 | print(f"Length coverage saved to {output_length_file}") 423 | print(f"Width coverage saved to {output_width_file}") 424 | 425 | if __name__ == "__main__": 426 | main() -------------------------------------------------------------------------------- /spatial_engine/object_perception/single_object_perception_engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import os 8 | import json 9 | import pickle 10 | import random 11 | from tqdm import tqdm 12 | import numpy as np 13 | import random 14 | random.seed(1) 15 | np.random.seed(1) # * use a different seed from cam_cam 16 | import json 17 | from spatial_engine.utils.scannet_utils.handler.info_handler import SceneInfoHandler 18 | 19 | 20 | # ==================== Global Variables ==================== 21 | max_train_samples = -1 # -1 means no downsampling for training data; if positive, downsample to this number 22 | val_max_samples = 3000 # Validation data will be randomly sampled to 3000 samples 23 | 24 | 25 | TASK_DESCRIPTION = [ 26 | "Assume the scene remains unchanged. Your task is to determine the spatial properties based on the images. You need to integrate and analyze information from all provided images to get the answer.", 27 | "Given the static scene, determine the spatial properties using the images. Synthesize and evaluate information from all provided images to derive the answer.", 28 | "Analyze the images to determine the spatial properties of the scene. You must combine and interpret data from all provided images to find the answer.", 29 | "Using the images, identify the spatial properties of the scene. Collate and assess information from all provided images to reach the answer.", 30 | "Examine the images to find the spatial properties of the scene. You need to merge and scrutinize information from all provided images to obtain the answer.", 31 | "Determine the spatial properties of the scene based on the images. Integrate and review information from all provided images to conclude the answer.", 32 | "Identify the spatial properties of the scene using the images. You must gather and analyze data from all provided images to ascertain the answer.", 33 | "Find the spatial properties of the scene by analyzing the images. Combine and evaluate information from all provided images to deduce the answer.", 34 | "Use the images to determine the spatial properties of the scene. You need to synthesize and interpret information from all provided images to get the answer.", 35 | "Analyze the provided images to identify the spatial properties of the scene. Collate and assess data from all provided images to derive the answer.", 36 | "Determine the spatial properties of the scene using the images. You must integrate and scrutinize information from all provided images to find the answer.", 37 | "Identify the spatial properties of the scene by examining the images. Merge and evaluate data from all provided images to reach the answer.", 38 | "Find the spatial properties of the scene using the images. You need to gather and interpret information from all provided images to obtain the answer.", 39 | "Analyze the images to determine the spatial properties of the scene. Combine and review data from all provided images to ascertain the answer.", 40 | "Using the images, identify the spatial properties of the scene. Synthesize and scrutinize information from all provided images to deduce the answer.", 41 | "Examine the images to find the spatial properties of the scene. You must integrate and assess information from all provided images to get the answer.", 42 | "Determine the spatial properties of the scene based on the images. Collate and interpret data from all provided images to conclude the answer.", 43 | "Identify the spatial properties of the scene using the images.
You need to merge and evaluate information from all provided images to derive the answer.", 44 | "Find the spatial properties of the scene by analyzing the images. Gather and review data from all provided images to find the answer.", 45 | "Use the images to determine the spatial properties of the scene. You must synthesize and assess information from all provided images to reach the answer.", 46 | "Analyze the provided images to identify the spatial properties of the scene. Integrate and interpret data from all provided images to obtain the answer.", 47 | "Determine the spatial properties of the scene using the images. You need to combine and scrutinize information from all provided images to ascertain the answer.", 48 | "Identify the spatial properties of the scene by examining the images. Collate and review data from all provided images to deduce the answer.", 49 | "Find the spatial properties of the scene using the images. Merge and assess information from all provided images to get the answer.", 50 | "Analyze the images to determine the spatial properties of the scene. You must gather and interpret data from all provided images to conclude the answer.", 51 | "Using the images, identify the spatial properties of the scene. Combine and evaluate information from all provided images to derive the answer.", 52 | "Examine the images to find the spatial properties of the scene. You need to synthesize and review data from all provided images to find the answer.", 53 | "Determine the spatial properties of the scene based on the images. Integrate and scrutinize information from all provided images to reach the answer.", 54 | "Identify the spatial properties of the scene using the images. Collate and interpret data from all provided images to obtain the answer.", 55 | "Find the spatial properties of the scene by analyzing the images. You must merge and assess information from all provided images to ascertain the answer." 
56 | ] 57 | 58 | QUESTION_TEMPLATES = [ 59 | "What is the {dimension} (in millimeters) of the {object_category} itself commonly visible in these images?", 60 | "Calculate the {dimension} (in millimeters) of the {object_category} that is commonly visible in these images.", 61 | "Determine the {dimension} (in millimeters) of the {object_category} which is commonly visible in these images.", 62 | "Find the {dimension} (in millimeters) of the {object_category} that is commonly visible in these images.", 63 | "Estimate the {dimension} (in millimeters) of the {object_category} itself commonly visible in these images.", 64 | "Measure the {dimension} (in millimeters) of the {object_category} that is commonly visible in these images.", 65 | "Could you tell me the {dimension} (in millimeters) of the {object_category} which is commonly visible in these images?", 66 | "Please compute the {dimension} (in millimeters) of the {object_category} itself commonly visible in these images.", 67 | "What is the approximate {dimension} (in millimeters) of the {object_category} that is commonly visible in these images?", 68 | "Give the {dimension} (in millimeters) of the {object_category} which is commonly visible in these images.", 69 | "What is the {dimension} (in millimeters) of the {object_category} that is commonly visible across these images?", 70 | "Calculate the {dimension} (in millimeters) of the {object_category} itself commonly visible across these images.", 71 | "Determine the {dimension} (in millimeters) of the {object_category} which is commonly visible across these images.", 72 | "Find the {dimension} (in millimeters) of the {object_category} that is commonly visible across these images.", 73 | "Estimate the {dimension} (in millimeters) of the {object_category} itself commonly visible across these images.", 74 | "Measure the {dimension} (in millimeters) of the {object_category} that is commonly visible across these images.", 75 | "Could you tell me the {dimension} (in millimeters) of the {object_category} which is commonly visible across these images?", 76 | "Please compute the {dimension} (in millimeters) of the {object_category} itself commonly visible across these images.", 77 | "What is the approximate {dimension} (in millimeters) of the {object_category} that is commonly visible across these images?", 78 | "Give the {dimension} (in millimeters) of the {object_category} which is commonly visible across these images.", 79 | "What is the {dimension} (in millimeters) of the {object_category} itself that is commonly visible in these images?", 80 | "Calculate the {dimension} (in millimeters) of the {object_category} which is commonly visible in these images.", 81 | "Determine the {dimension} (in millimeters) of the {object_category} itself that is commonly visible in these images.", 82 | "Find the {dimension} (in millimeters) of the {object_category} which is commonly visible in these images.", 83 | "Estimate the {dimension} (in millimeters) of the {object_category} that is commonly visible in these images.", 84 | "Measure the {dimension} (in millimeters) of the {object_category} itself that is commonly visible in these images.", 85 | "Could you tell me the {dimension} (in millimeters) of the {object_category} that is commonly visible in these images?", 86 | "Please compute the {dimension} (in millimeters) of the {object_category} which is commonly visible in these images.", 87 | "What is the approximate {dimension} (in millimeters) of the {object_category} itself that is commonly visible in these 
images?", 88 | "Give the {dimension} (in millimeters) of the {object_category} that is commonly visible in these images." 89 | ] 90 | 91 | ANSWER_TEMPLATES = [ 92 | "The {dimension} is approximately `{value_mm}` millimeters.", 93 | "It measures about `{value_mm}` millimeters in {dimension}.", 94 | "I estimate the {dimension} to be around `{value_mm}` millimeters.", 95 | "The {object_category}'s {dimension} is roughly `{value_mm}` millimeters.", 96 | "Based on the images, the {dimension} is near `{value_mm}` millimeters.", 97 | "It appears that the {dimension} is `{value_mm}` millimeters.", 98 | "From my estimation, the {dimension} is `{value_mm}` millimeters.", 99 | "The {dimension} seems to be around `{value_mm}` millimeters.", 100 | "Approximately, the {dimension} is `{value_mm}` millimeters.", 101 | "I would say the {dimension} is `{value_mm}` millimeters.", 102 | "The {dimension} is estimated to be `{value_mm}` millimeters.", 103 | "In my view, the {dimension} is about `{value_mm}` millimeters.", 104 | "The {dimension} is likely around `{value_mm}` millimeters.", 105 | "Judging by the images, the {dimension} is approximately `{value_mm}` millimeters.", 106 | "The {dimension} is calculated to be `{value_mm}` millimeters.", 107 | "It looks like the {dimension} is `{value_mm}` millimeters.", 108 | "The {dimension} is assessed to be `{value_mm}` millimeters.", 109 | "The {dimension} is gauged at `{value_mm}` millimeters.", 110 | "The {dimension} is reckoned to be `{value_mm}` millimeters.", 111 | "The {dimension} is figured to be `{value_mm}` millimeters.", 112 | "The {dimension} is computed to be `{value_mm}` millimeters.", 113 | "The {dimension} is deduced to be `{value_mm}` millimeters.", 114 | "The {dimension} is inferred to be `{value_mm}` millimeters.", 115 | "The {dimension} is surmised to be `{value_mm}` millimeters.", 116 | "The {dimension} is supposed to be `{value_mm}` millimeters.", 117 | "The {dimension} is thought to be `{value_mm}` millimeters.", 118 | "The {dimension} is understood to be `{value_mm}` millimeters.", 119 | "The {dimension} is viewed as `{value_mm}` millimeters.", 120 | "The {dimension} is approximated to be `{value_mm}` millimeters based on the data.", 121 | "After analyzing the images, the {dimension} is concluded to be `{value_mm}` millimeters." 122 | ] 123 | 124 | def convert_train_sample_to_eval_sample(train_sample): 125 | conversation = train_sample.pop("conversations") 126 | train_sample["text"] = conversation[0]["value"] 127 | return train_sample 128 | 129 | def build_lwh_qa_samples(scene_info_handler, dimension_info_path, dimension_name, split, output_dir, max_k=6, max_samples=-1): 130 | """ 131 | Construct QA samples based on merged info files, and write to different jsonl files according to combination size K. 132 | 133 | Example file structure (info file, dict saved with pickle): 134 | { 135 | scene_id: { 136 | object_id: { 137 | "1": [ [img1], [img3], ... ], 138 | "2": [ [img2, img5], ... ], 139 | ... 140 | }, 141 | ... 142 | }, 143 | ... 144 | } 145 | 146 | For each combination sample, construct a training sample: 147 | { 148 | "id": "___", 149 | "image": [ "scene_id/img1.jpg", "scene_id/img2.jpg", ... 
], 150 | "conversations": [ 151 | {"from": "human", "value": "\n\n"}, 152 | {"from": "gpt", "value": ""} 153 | ], 154 | "height_list": [scene_info_handler.image_height, ...], // repeated len(combo) times 155 | "width_list": [scene_info_handler.image_width, ...], 156 | "question_type": "{dimension}_estimation", 157 | "gt_value": 158 | } 159 | """ 160 | print(f"Processing dimension: {dimension_name}, split: {split}") 161 | with open(dimension_info_path, "rb") as f: 162 | dim_info = pickle.load(f) 163 | os.makedirs(output_dir, exist_ok=True) 164 | 165 | samples_by_k = {k: [] for k in range(1, max_k+1)} 166 | 167 | for scene_id, obj_dict in tqdm(dim_info.items(), desc=f"Processing {dimension_name} info"): 168 | for object_id, k_dict in obj_dict.items(): 169 | if dimension_name == "height": 170 | val_m = scene_info_handler.get_object_height(scene_id, object_id) 171 | elif dimension_name == "length": 172 | val_m = scene_info_handler.get_object_length(scene_id, object_id) 173 | elif dimension_name == "width": 174 | val_m = scene_info_handler.get_object_width(scene_id, object_id) 175 | else: 176 | val_m = 0.0 177 | val_mm = int(round(val_m * 1000)) 178 | object_category = scene_info_handler.get_object_raw_category(scene_id, object_id) 179 | for k_str, combos in k_dict.items(): 180 | try: 181 | k_val = int(k_str) 182 | except: 183 | continue 184 | if k_val < 1 or k_val > max_k: 185 | continue 186 | for combo_idx, combo in enumerate(combos): 187 | if not combo: 188 | continue 189 | combo = list(combo) 190 | random.shuffle(combo) 191 | prefix_lines = [f"Image-{i}: " for i in range(1, len(combo)+1)] 192 | prefix = "\n".join(prefix_lines) 193 | task_line = random.choice(TASK_DESCRIPTION) 194 | q_template = random.choice(QUESTION_TEMPLATES) 195 | question = q_template.format(dimension=dimension_name, object_category=object_category) 196 | full_question = f"{prefix}\n{task_line}\n{question}" 197 | a_template = random.choice(ANSWER_TEMPLATES) 198 | answer = a_template.format(dimension=dimension_name, value_mm=val_mm, object_category=object_category) 199 | conversation = [ 200 | {"from": "human", "value": full_question}, 201 | {"from": "gpt", "value": answer} 202 | ] 203 | sample = { 204 | "id": f"{scene_id}_{object_id}_{k_val}_{combo_idx}", 205 | "image": [f"{scene_id}/{img}.jpg" for img in combo], 206 | "conversations": conversation, 207 | "height_list": [scene_info_handler.image_height] * len(combo), 208 | "width_list": [scene_info_handler.image_width] * len(combo), 209 | "question_type": f"object_perception_{dimension_name}_estimation", 210 | "gt_value": val_mm 211 | } 212 | samples_by_k[k_val].append(sample) 213 | 214 | for k in range(1, max_k+1): 215 | # only consider those have at least one sample 216 | if len(samples_by_k[k]) == 0: 217 | continue 218 | if max_samples > 0 and len(samples_by_k[k]) > max_samples: 219 | samples_by_k[k] = random.sample(samples_by_k[k], max_samples) 220 | fname = f"object_perception_{dimension_name}_k{k}_{split}_{max_samples}.jsonl" 221 | fpath = os.path.join(output_dir, fname) 222 | with open(fpath, "w", encoding="utf-8") as f: 223 | for sample in samples_by_k[k]: 224 | f.write(json.dumps(sample) + "\n") 225 | print(f"Written K={k} {len(samples_by_k[k])} samples to {fpath}") 226 | 227 | print(f"Finished building QA samples for {dimension_name}.") 228 | 229 | def build_train_and_val_datasets(): 230 | scene_info_path = "data/scannet/scannet_instance_data/scenes_train_val_info_i_D5.pkl" 231 | 232 | train_height_info = 
"training_data/object_perception/merged_train_object_coverage_height.pkl" 233 | train_length_info = "training_data/object_perception/merged_train_object_coverage_length.pkl" 234 | train_width_info = "training_data/object_perception/merged_train_object_coverage_width.pkl" 235 | 236 | val_height_info = "evaluation_data/object_perception/merged_val_object_coverage_height.pkl" 237 | val_length_info = "evaluation_data/object_perception/merged_val_object_coverage_length.pkl" 238 | val_width_info = "evaluation_data/object_perception/merged_val_object_coverage_width.pkl" 239 | 240 | train_output_dir = "training_data/object_perception" 241 | val_output_dir = "evaluation_data/object_perception" 242 | os.makedirs(train_output_dir, exist_ok=True) 243 | os.makedirs(val_output_dir, exist_ok=True) 244 | 245 | scene_info_handler = SceneInfoHandler(scene_info_path) 246 | 247 | print("\nBuilding TRAIN samples ...") 248 | build_lwh_qa_samples(scene_info_handler, train_height_info, "height", "train", train_output_dir, max_k=6, max_samples=max_train_samples) # will generate K=1, 281335 samples, K=2 457409 samples 249 | build_lwh_qa_samples(scene_info_handler, train_length_info, "length", "train", train_output_dir, max_k=6, max_samples=max_train_samples) # will generate K=1, 274175 samples, K=2 667299 samples 250 | build_lwh_qa_samples(scene_info_handler, train_width_info, "width", "train", train_output_dir, max_k=6, max_samples=max_train_samples) # will generate K=1 256017 samples, K=2 425229 samples 251 | 252 | temp_val_dir = os.path.join(val_output_dir, "temp") 253 | os.makedirs(temp_val_dir, exist_ok=True) 254 | print("\nBuilding VAL samples (train format) ...") 255 | build_lwh_qa_samples(scene_info_handler, val_height_info, "height", "val", temp_val_dir, max_k=6, max_samples=val_max_samples) 256 | build_lwh_qa_samples(scene_info_handler, val_length_info, "length", "val", temp_val_dir, max_k=6, max_samples=val_max_samples) 257 | build_lwh_qa_samples(scene_info_handler, val_width_info, "width", "val", temp_val_dir, max_k=6, max_samples=val_max_samples) 258 | for fname in os.listdir(temp_val_dir): 259 | temp_path = os.path.join(temp_val_dir, fname) 260 | output_path = os.path.join(val_output_dir, fname) 261 | with open(temp_path, "r", encoding="utf-8") as fin, open(output_path, "w", encoding="utf-8") as fout: 262 | for line in fin: 263 | sample = json.loads(line) 264 | sample = convert_train_sample_to_eval_sample(sample) 265 | fout.write(json.dumps(sample) + "\n") 266 | 267 | import shutil; shutil.rmtree(temp_val_dir) 268 | print("Finished building both TRAIN and VAL datasets.") 269 | 270 | def main(): 271 | build_train_and_val_datasets() 272 | 273 | if __name__ == "__main__": 274 | main() -------------------------------------------------------------------------------- /spatial_engine/utils/scannet_utils/README.md: -------------------------------------------------------------------------------- 1 | #### Notes 2 | 1. This doc assume the current directory is the root directory of the project (Multi-SpatialMLLM). 3 | 2. Create a directory `data/scannet` to store the downloaded data. 4 | ```bash 5 | cd Multi-SpatialMLLM 6 | mkdir -p data/scannet 7 | ``` 8 | 9 | #### Download ScanNet data 10 | 1. Follow the official website of [ScanNet](https://github.com/ScanNet/ScanNet) to get the download file `scannet_download.py`. 11 | 2. 
Specify the data types we need to download: 12 | ```python 13 | FILETYPES = [ 14 | ".sens", 15 | ".aggregation.json", 16 | "_vh_clean.ply", 17 | "_vh_clean_2.0.010000.segs.json", 18 | "_vh_clean_2.ply", 19 | "_vh_clean.segs.json", 20 | "_vh_clean.aggregation.json", 21 | "_vh_clean_2.labels.ply", 22 | "_2d-instance.zip", 23 | "_2d-instance-filt.zip", 24 | "_2d-label.zip", 25 | "_2d-label-filt.zip", 26 | ] 27 | ``` 28 | 3. Run the script to download the data: 29 | ```bash 30 | python spatial_engine/utils/scannet_utils/scannet_download.py -o data/scannet 31 | ``` 32 | 4. After downloading, you should have a folder structure like: 33 | ``` 34 | Multi-SpatialMLLM 35 | ├── data 36 | │ └── scannet 37 | │ ├── scans 38 | │ │ ├── scene0000_00 39 | │ │ ├── ... 40 | │ └── ... 41 | ``` 42 | 43 | #### Process ScanNet data 44 | 1. Load point clouds, object point clouds, bounding boxes, etc. 45 | ```bash 46 | python spatial_engine/utils/scannet_utils/batch_load_scannet_data.py 47 | ``` 48 | This will generate a file called `scenes_train_val_info.pkl`, which contains both the train and val splits. For convenience, we can split this info file into train and val (after updating it with image info if needed). 49 | 50 | ```python 51 | import mmengine 52 | ori_train_info = mmengine.load("data/scannet/scannet_instance_data/scenes_train_val_info.pkl") 53 | 54 | # * load train scene ids from a txt file (one id per line); mmengine does not support the txt format 55 | train_scene_ids = mmengine.list_from_file("data/scannet/meta_data/scannetv2_train.txt") 56 | train_scene_ids.sort() 57 | val_scene_ids = mmengine.list_from_file("data/scannet/meta_data/scannetv2_val.txt") 58 | val_scene_ids.sort() 59 | 60 | # * ori_train_info is a dict, key is scene_id 61 | # * according to the above train/val ids, split the info file into two separate files and save them. 62 | # * remember to check train + val = ori_train_info.keys() 63 | train_info = {k: ori_train_info[k] for k in train_scene_ids} 64 | val_info = {k: ori_train_info[k] for k in val_scene_ids} 65 | 66 | # * check 67 | assert len(train_info) + len(val_info) == len(ori_train_info) 68 | 69 | # * save 70 | mmengine.dump(train_info, "data/scannet/scannet_instance_data/scenes_train_info.pkl") 71 | mmengine.dump(val_info, "data/scannet/scannet_instance_data/scenes_val_info.pkl") 72 | ``` 73 | 74 | 2. Load posed images. 75 | 76 | By default, we extract every image; if you want to export every nth image, set `frame_skip=n` as below. 77 | ```bash 78 | python spatial_engine/utils/scannet_utils/extract_posed_images.py --frame_skip 1 79 | ``` 80 | Note that no matter what `frame_skip` is, the extracted images are still saved with IDs like `00000.jpg`, `00001.jpg`, etc. It also exports depth images; read them with imageio and divide the values by 1000 to get the real depth in meters. 81 | 82 | 3. Update the info file with posed image information; this takes about 40 minutes if all images were extracted. In this project, we only use every 5th image, so you need to set `frame_skip=5` in the script to skip images. Note that you still need to set `frame_skip=1` in `extract_posed_images.py` to make sure the image IDs are consistent with our setting.
83 | ```bash 84 | python spatial_engine/utils/scannet_utils/update_info_file_with_images.py 85 | ``` 86 | 87 | The images_info contains information like: 88 | ```python 89 | # Update the image data dictionary with this image's information 90 | image_data[image_id] = { 91 | "image_path": image_path, 92 | "depth_image_path": depth_image_path, 93 | "extrinsic_matrix": extrinsic_matrix 94 | } 95 | 96 | # Update the scene_info dictionary for the current scene_id 97 | scene_info[scene_id].update({ 98 | "num_posed_images": num_posed_images, 99 | "images": image_data, 100 | "intrinsic_matrix": intrinsic_matrix 101 | }) 102 | ``` 103 | 104 | #### Final Results 105 | After processing, you should have a folder structure like: 106 | ``` 107 | Multi-SpatialMLLM 108 | ├── data 109 | │ └── scannet 110 | │ ├── meta_data 111 | │ ├── scans 112 | │ │ ├── scene0000_00 113 | │ │ ├── ... 114 | │ ├── posed_images 115 | │ │ ├── scene0000_00 116 | │ │ ├── ... 117 | │ └── scannet_instance_data 118 | │ ├── scene0000_00 119 | │ ├── ... 120 | │ ├── scenes_train_info.pkl 121 | │ └── scenes_val_info.pkl 122 | ``` 123 | The `scannet_instance_data` folder contains the point clouds of the whole scene and of each instance, in both axis-aligned and non-axis-aligned world coordinates. 124 | 125 | -------------------------------------------------------------------------------- /spatial_engine/utils/scannet_utils/batch_load_scannet_data.py: -------------------------------------------------------------------------------- 1 | # Modified from 2 | # https://github.com/facebookresearch/votenet/blob/master/scannet/batch_load_scannet_data.py 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Batch mode in loading Scannet scenes with vertices and ground truth labels 8 | for semantic and instance segmentations. 9 | 10 | Usage example: python ./batch_load_scannet_data.py 11 | """ 12 | 13 | import argparse 14 | import os 15 | import pickle 16 | from multiprocessing import Pool 17 | from os import path as osp 18 | 19 | import numpy as np 20 | import scannet_utils 21 | 22 | DONOTCARE_CLASS_IDS = np.array([]) 23 | # OBJ_CLASS_IDS = np.array( 24 | # [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39]) 25 | # OBJ_CLASS_IDS = np.array([]) 26 | 27 | 28 | def export(mesh_file, agg_file, seg_file, meta_file, label_map_file, test_mode=False): 29 | """Export original files to vert, ins_label, sem_label and bbox file. 30 | 31 | Args: 32 | mesh_file (str): Path of the mesh_file. 33 | agg_file (str): Path of the agg_file. 34 | seg_file (str): Path of the seg_file. 35 | meta_file (str): Path of the meta_file. 36 | label_map_file (str): Path of the label_map_file. 37 | test_mode (bool): Whether it is generating test data without labels. 38 | Default: False. 39 | 40 | It returns a tuple, which contains the following things: 41 | np.ndarray: Vertices of points data. 42 | np.ndarray: Indexes of label. 43 | np.ndarray: Indexes of instance. 44 | np.ndarray: Instance bboxes. 45 | dict: Map from object_id to label_id.
46 | """ 47 | 48 | label_map = scannet_utils.read_label_mapping( 49 | label_map_file, label_from="raw_category", label_to="nyu40id" 50 | ) 51 | mesh_vertices = scannet_utils.read_mesh_vertices_rgb(mesh_file) 52 | 53 | # Load scene axis alignment matrix 54 | lines = open(meta_file).readlines() 55 | # test set data doesn't have align_matrix 56 | axis_align_matrix = np.eye(4) 57 | for line in lines: 58 | if "axisAlignment" in line: 59 | axis_align_matrix = [ 60 | float(x) for x in line.rstrip().strip("axisAlignment = ").split(" ") 61 | ] 62 | break 63 | axis_align_matrix = np.array(axis_align_matrix).reshape((4, 4)) 64 | 65 | # perform global alignment of mesh vertices 66 | pts = np.ones((mesh_vertices.shape[0], 4)) 67 | pts[:, 0:3] = mesh_vertices[:, 0:3] 68 | pts = np.dot(pts, axis_align_matrix.transpose()) # Nx4 69 | aligned_mesh_vertices = np.concatenate([pts[:, 0:3], mesh_vertices[:, 3:]], axis=1) 70 | 71 | # Load semantic and instance labels 72 | if not test_mode: 73 | object_id_to_segs, label_to_segs = scannet_utils.read_aggregation( 74 | agg_file 75 | ) # * return dicts with id(int) or label(str) to lists of seg ids, object ids are 1-indexed 76 | seg_to_verts, num_verts = scannet_utils.read_segmentation(seg_file) 77 | label_ids = np.zeros(shape=(num_verts), dtype=np.uint32) 78 | raw_categories = np.array([None] * num_verts) # Array to store raw categories 79 | 80 | object_id_to_label_id = {} 81 | object_id_to_raw_category = {} 82 | for raw_category, segs in label_to_segs.items(): 83 | label_id = label_map[raw_category] 84 | for seg in segs: 85 | verts = seg_to_verts[seg] 86 | label_ids[verts] = label_id 87 | raw_categories[verts] = raw_category # Assign raw category 88 | 89 | instance_ids = np.zeros(shape=(num_verts), dtype=np.uint32) # 0: unannotated 90 | for object_id, segs in object_id_to_segs.items(): 91 | for seg in segs: 92 | verts = seg_to_verts[seg] 93 | instance_ids[verts] = object_id 94 | if object_id not in object_id_to_label_id: 95 | object_id_to_label_id[object_id] = label_ids[verts][ 96 | 0 97 | ] # * obj_id: int 98 | if object_id not in object_id_to_raw_category: 99 | object_id_to_raw_category[object_id] = raw_categories[verts][ 100 | 0 101 | ] # * obj_id: str, note, the obj_id is 1-indexed 102 | unaligned_bboxes, unaligned_obj_point_clouds = scannet_utils.extract_bbox( 103 | mesh_vertices, object_id_to_segs, object_id_to_label_id, instance_ids 104 | ) 105 | aligned_bboxes, aligned_obj_point_clouds = scannet_utils.extract_bbox( 106 | aligned_mesh_vertices, 107 | object_id_to_segs, 108 | object_id_to_label_id, 109 | instance_ids, 110 | ) 111 | else: 112 | label_ids = None 113 | raw_categories = None 114 | instance_ids = None 115 | unaligned_bboxes = None 116 | aligned_bboxes = None 117 | object_id_to_label_id = None 118 | aligned_obj_point_clouds = None 119 | unaligned_obj_point_clouds = None 120 | object_id_to_raw_category = None 121 | 122 | return ( 123 | mesh_vertices, 124 | aligned_mesh_vertices, 125 | label_ids, 126 | raw_categories, 127 | instance_ids, 128 | unaligned_bboxes, 129 | aligned_bboxes, 130 | unaligned_obj_point_clouds, 131 | aligned_obj_point_clouds, 132 | object_id_to_raw_category, 133 | object_id_to_label_id, 134 | axis_align_matrix, 135 | ) 136 | 137 | 138 | def export_one_scan( 139 | scan_name, 140 | output_filename_prefix, 141 | max_num_point, 142 | label_map_file, 143 | scannet_dir, 144 | test_mode=False, 145 | ): 146 | if not osp.exists(output_filename_prefix): 147 | os.makedirs(output_filename_prefix) 148 | 149 | mesh_file = 
osp.join(scannet_dir, scan_name, scan_name + "_vh_clean_2.ply") 150 | agg_file = osp.join(scannet_dir, scan_name, scan_name + ".aggregation.json") 151 | seg_file = osp.join( 152 | scannet_dir, scan_name, scan_name + "_vh_clean_2.0.010000.segs.json" 153 | ) 154 | # includes axisAlignment info for the train set scans. 155 | meta_file = osp.join(scannet_dir, scan_name, f"{scan_name}.txt") 156 | ( 157 | mesh_vertices, 158 | aligned_mesh_vertices, 159 | semantic_labels, 160 | raw_categories, 161 | instance_labels, 162 | unaligned_bboxes, 163 | aligned_bboxes, 164 | unaligned_obj_point_clouds, 165 | aligned_obj_point_clouds, 166 | object_id_to_raw_category, 167 | object_id_to_label_id, 168 | axis_align_matrix, 169 | ) = export(mesh_file, agg_file, seg_file, meta_file, label_map_file, test_mode) 170 | 171 | if not test_mode: 172 | # mask = np.logical_not(np.in1d(semantic_labels, DONOTCARE_CLASS_IDS)) 173 | # mesh_vertices = mesh_vertices[mask, :] 174 | # semantic_labels = semantic_labels[mask] 175 | # instance_labels = instance_labels[mask] 176 | # raw_categories = raw_categories[mask] 177 | 178 | num_instances = len(np.unique(instance_labels)) 179 | print(f"Num of instances: {num_instances - 1}") 180 | 181 | # bbox_mask = np.in1d(unaligned_bboxes[:, -1], OBJ_CLASS_IDS) # * keep all instances 182 | # unaligned_bboxes = unaligned_bboxes[bbox_mask, :] 183 | # bbox_mask = np.in1d(aligned_bboxes[:, -1], OBJ_CLASS_IDS) 184 | # aligned_bboxes = aligned_bboxes[bbox_mask, :] 185 | assert unaligned_bboxes.shape[0] == aligned_bboxes.shape[0] 186 | print(f"Num of care instances: {unaligned_bboxes.shape[0]}") 187 | 188 | if max_num_point is not None: 189 | max_num_point = int(max_num_point) 190 | N = mesh_vertices.shape[0] 191 | if N > max_num_point: 192 | choices = np.random.choice(N, max_num_point, replace=False) 193 | mesh_vertices = mesh_vertices[choices, :] 194 | if not test_mode: 195 | semantic_labels = semantic_labels[choices] 196 | instance_labels = instance_labels[choices] 197 | raw_categories = raw_categories[choices] 198 | 199 | # Save points, semantic_labels, instance_labels as .npy files 200 | np.save(f"{output_filename_prefix}/unaligned_points.npy", mesh_vertices) 201 | np.save(f"{output_filename_prefix}/aligned_points.npy", aligned_mesh_vertices) 202 | scene_info = {} # Dictionary to hold scene information 203 | 204 | if not test_mode: 205 | np.save(f"{output_filename_prefix}/semantic_mask.npy", semantic_labels) 206 | np.save(f"{output_filename_prefix}/instance_mask.npy", instance_labels) 207 | np.save(f"{output_filename_prefix}/raw_category_mask.npy", raw_categories) 208 | 209 | # * assert these four npy have the same length 210 | assert ( 211 | len(semantic_labels) 212 | == len(instance_labels) 213 | == len(raw_categories) 214 | == len(mesh_vertices) 215 | ), "Lengths of semantic_labels, instance_labels, raw_categories, and mesh_vertices are not equal." 216 | 217 | # Save bounding boxes and raw category names in a dict 218 | for obj_id, (aligned_bbox, unaligned_bbox) in enumerate( 219 | zip(aligned_bboxes, unaligned_bboxes) 220 | ): 221 | raw_category_name = object_id_to_raw_category.get( 222 | obj_id + 1, "None" 223 | ) # * object_id_to_raw_category is 1 indexed 224 | if raw_category_name == "None": 225 | print( 226 | f"Something wrong for the raw category name of object {obj_id} in scan {scan_name}." 
227 | ) 228 | exit(0) 229 | scene_info[obj_id] = { 230 | "aligned_bbox": aligned_bbox, 231 | "unaligned_bbox": unaligned_bbox, 232 | "raw_category": raw_category_name, 233 | } 234 | 235 | # * save aligned and unaligned points 236 | # * first check if the two types of points have the same shape 237 | 238 | np.save( 239 | f"{output_filename_prefix}/object_{obj_id}_aligned_points.npy", 240 | aligned_obj_point_clouds[obj_id], 241 | ) 242 | np.save( 243 | f"{output_filename_prefix}/object_{obj_id}_unaligned_points.npy", 244 | unaligned_obj_point_clouds[obj_id], 245 | ) 246 | 247 | scene_info["axis_align_matrix"] = axis_align_matrix 248 | # * store the object number 249 | scene_info["num_objects"] = len(aligned_bboxes) 250 | 251 | return {scan_name: scene_info} 252 | 253 | 254 | def worker(args): 255 | ( 256 | scan_name, 257 | output_filename_prefix, 258 | max_num_point, 259 | label_map_file, 260 | scannet_dir, 261 | test_mode, 262 | ) = args 263 | print("-" * 20 + f"begin for {scan_name}.") 264 | return export_one_scan( 265 | scan_name, 266 | output_filename_prefix, 267 | max_num_point, 268 | label_map_file, 269 | scannet_dir, 270 | test_mode, 271 | ) 272 | 273 | 274 | def batch_export( 275 | max_num_point, 276 | output_folder, 277 | scan_names_file, 278 | label_map_file, 279 | scannet_dir, 280 | test_mode=False, 281 | num_workers=20, 282 | ): 283 | if test_mode and not os.path.exists(scannet_dir): 284 | return 285 | if not os.path.exists(output_folder): 286 | os.makedirs(output_folder) 287 | 288 | scan_names = [line.rstrip() for line in open(scan_names_file)] 289 | # * sort scan_names 290 | scan_names.sort() 291 | args = [ 292 | ( 293 | scan_name, 294 | osp.join(output_folder, scan_name), 295 | max_num_point, 296 | label_map_file, 297 | scannet_dir, 298 | test_mode, 299 | ) 300 | for scan_name in scan_names 301 | ] 302 | 303 | all_scene_info = {} 304 | with Pool(num_workers) as p: 305 | results = p.map(worker, args) 306 | for result in results: 307 | all_scene_info.update(result) 308 | 309 | # Save the combined scene information 310 | if test_mode: 311 | file_name = "scenes_test_info.pkl" 312 | else: 313 | file_name = "scenes_train_val_info.pkl" 314 | with open(osp.join(output_folder, file_name), "wb") as f: 315 | pickle.dump(all_scene_info, f) 316 | 317 | 318 | def main(): 319 | parser = argparse.ArgumentParser() 320 | parser.add_argument( 321 | "--max_num_point", default=None, help="The maximum number of the points." 322 | ) 323 | parser.add_argument( 324 | "--output_folder", 325 | default="data/scannet/scannet_instance_data", 326 | help="output folder of the result.", 327 | ) 328 | parser.add_argument( 329 | "--train_scannet_dir", default="scans", help="scannet data directory." 330 | ) 331 | parser.add_argument( 332 | "--test_scannet_dir", default="scans_test", help="scannet data directory." 
333 | ) 334 | parser.add_argument( 335 | "--label_map_file", 336 | default="data/scannet/meta_data/scannetv2-labels.combined.tsv", 337 | help="The path of label map file.", 338 | ) 339 | parser.add_argument( 340 | "--train_scan_names_file", 341 | default="data/scannet/meta_data/scannet_train.txt", 342 | help="The path of the file that stores the scan names.", 343 | ) 344 | parser.add_argument( 345 | "--test_scan_names_file", 346 | default="data/scannet/meta_data/scannetv2_test.txt", 347 | help="The path of the file that stores the scan names.", 348 | ) 349 | args = parser.parse_args() 350 | batch_export( 351 | args.max_num_point, 352 | args.output_folder, 353 | args.train_scan_names_file, 354 | args.label_map_file, 355 | args.train_scannet_dir, 356 | test_mode=False, 357 | ) 358 | # * change output folder for test 359 | args.output_folder = args.output_folder.replace("scannet", "scannet_test") 360 | batch_export( 361 | args.max_num_point, 362 | args.output_folder, 363 | args.test_scan_names_file, 364 | args.label_map_file, 365 | args.test_scannet_dir, 366 | test_mode=True, 367 | ) 368 | 369 | 370 | if __name__ == "__main__": 371 | main() 372 | -------------------------------------------------------------------------------- /spatial_engine/utils/scannet_utils/extract_posed_images.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import struct 9 | import time 10 | import zlib 11 | from argparse import ArgumentParser 12 | from functools import partial 13 | 14 | import imageio.v2 as imageio # * to surpress warning 15 | import mmengine 16 | import numpy as np 17 | 18 | COMPRESSION_TYPE_COLOR = {-1: "unknown", 0: "raw", 1: "png", 2: "jpeg"} 19 | 20 | COMPRESSION_TYPE_DEPTH = { 21 | -1: "unknown", 22 | 0: "raw_ushort", 23 | 1: "zlib_ushort", 24 | 2: "occi_ushort", 25 | } 26 | 27 | 28 | class RGBDFrame: 29 | """Class for single ScanNet RGB-D image processing.""" 30 | 31 | def load(self, file_handle): 32 | self.camera_to_world = np.asarray( 33 | struct.unpack("f" * 16, file_handle.read(16 * 4)), dtype=np.float32 34 | ).reshape(4, 4) 35 | self.timestamp_color = struct.unpack("Q", file_handle.read(8))[0] 36 | self.timestamp_depth = struct.unpack("Q", file_handle.read(8))[0] 37 | self.color_size_bytes = struct.unpack("Q", file_handle.read(8))[0] 38 | self.depth_size_bytes = struct.unpack("Q", file_handle.read(8))[0] 39 | self.color_data = b"".join( 40 | struct.unpack( 41 | "c" * self.color_size_bytes, file_handle.read(self.color_size_bytes) 42 | ) 43 | ) 44 | self.depth_data = b"".join( 45 | struct.unpack( 46 | "c" * self.depth_size_bytes, file_handle.read(self.depth_size_bytes) 47 | ) 48 | ) 49 | 50 | def decompress_depth(self, compression_type): 51 | assert compression_type == "zlib_ushort" 52 | return zlib.decompress(self.depth_data) 53 | 54 | def decompress_color(self, compression_type): 55 | assert compression_type == "jpeg" 56 | return imageio.imread(self.color_data) 57 | 58 | 59 | class SensorData: 60 | """Class for single ScanNet scene processing. 61 | 62 | Single scene file contains multiple RGB-D images. 
63 | """ 64 | 65 | def __init__(self, filename, frame_skip): 66 | self.version = 4 67 | self.load(filename, frame_skip) 68 | 69 | def load(self, filename, frame_skip): 70 | with open(filename, "rb") as f: 71 | version = struct.unpack("I", f.read(4))[0] 72 | assert self.version == version 73 | strlen = struct.unpack("Q", f.read(8))[0] 74 | self.sensor_name = b"".join(struct.unpack("c" * strlen, f.read(strlen))) 75 | self.intrinsic_color = np.asarray( 76 | struct.unpack("f" * 16, f.read(16 * 4)), dtype=np.float32 77 | ).reshape(4, 4) 78 | self.extrinsic_color = np.asarray( 79 | struct.unpack("f" * 16, f.read(16 * 4)), dtype=np.float32 80 | ).reshape(4, 4) 81 | self.intrinsic_depth = np.asarray( 82 | struct.unpack("f" * 16, f.read(16 * 4)), dtype=np.float32 83 | ).reshape(4, 4) 84 | self.extrinsic_depth = np.asarray( 85 | struct.unpack("f" * 16, f.read(16 * 4)), dtype=np.float32 86 | ).reshape(4, 4) 87 | self.color_compression_type = COMPRESSION_TYPE_COLOR[ 88 | struct.unpack("i", f.read(4))[0] 89 | ] 90 | self.depth_compression_type = COMPRESSION_TYPE_DEPTH[ 91 | struct.unpack("i", f.read(4))[0] 92 | ] 93 | self.color_width = struct.unpack("I", f.read(4))[0] 94 | self.color_height = struct.unpack("I", f.read(4))[0] 95 | self.depth_width = struct.unpack("I", f.read(4))[0] 96 | self.depth_height = struct.unpack("I", f.read(4))[0] 97 | self.depth_shift = struct.unpack("f", f.read(4))[0] 98 | num_frames = struct.unpack("Q", f.read(8))[0] 99 | print(f"Number of total frames: {num_frames}") 100 | self.frames = [] 101 | 102 | # * use frame_skip to get index 103 | index = list(range(0, num_frames, frame_skip)) 104 | for i in range(num_frames): 105 | frame = RGBDFrame() 106 | frame.load(f) # should iterate to get the next frame 107 | if i in index: 108 | self.frames.append(frame) 109 | 110 | assert len(index) == len(self.frames), "Number of frames mismatch." 111 | print(f"Exported {len(index)} frames. 
Frame skip is {frame_skip}.") 112 | 113 | def export_depth_images(self, output_path): 114 | if not os.path.exists(output_path): 115 | os.makedirs(output_path) 116 | for f in range(len(self.frames)): 117 | depth_data = self.frames[f].decompress_depth(self.depth_compression_type) 118 | depth = np.fromstring(depth_data, dtype=np.uint16).reshape( 119 | self.depth_height, self.depth_width 120 | ) 121 | imageio.imwrite( 122 | os.path.join(output_path, self.index_to_str(f) + ".png"), depth 123 | ) 124 | 125 | def export_color_images(self, output_path): 126 | if not os.path.exists(output_path): 127 | os.makedirs(output_path) 128 | for f in range(len(self.frames)): 129 | color = self.frames[f].decompress_color(self.color_compression_type) 130 | imageio.imwrite( 131 | os.path.join(output_path, self.index_to_str(f) + ".jpg"), color 132 | ) 133 | 134 | @staticmethod 135 | def index_to_str(index): 136 | return str(index).zfill(5) 137 | 138 | @staticmethod 139 | def save_mat_to_file(matrix, filename): 140 | with open(filename, "w") as f: 141 | for line in matrix: 142 | np.savetxt(f, line[np.newaxis], fmt="%f") 143 | 144 | def export_poses(self, output_path): 145 | if not os.path.exists(output_path): 146 | os.makedirs(output_path) 147 | for f in range(len(self.frames)): 148 | self.save_mat_to_file( 149 | self.frames[f].camera_to_world, 150 | os.path.join(output_path, self.index_to_str(f) + ".txt"), 151 | ) 152 | 153 | def export_intrinsics(self, output_path): 154 | if not os.path.exists(output_path): 155 | os.makedirs(output_path) 156 | self.save_mat_to_file( 157 | self.intrinsic_color, os.path.join(output_path, "intrinsic.txt") 158 | ) 159 | 160 | 161 | def process_scene(path, frame_skip, idx): 162 | """Process single ScanNet scene. 163 | 164 | Extract RGB images, poses and camera intrinsics. 165 | """ 166 | print(f"Processing {idx}.") 167 | t1 = time.time() 168 | output_path = os.path.join("posed_images", idx) 169 | if mmengine.exists(output_path): 170 | print(f"{output_path} already exists. Skip.") 171 | return 172 | data = SensorData(os.path.join(path, idx, f"{idx}.sens"), frame_skip) 173 | data.export_color_images(output_path) 174 | data.export_intrinsics(output_path) 175 | data.export_poses(output_path) 176 | data.export_depth_images(output_path) 177 | print(f"Finish processing {idx}. Using {time.time() - t1}s.") 178 | 179 | 180 | def process_directory(path, frame_skip, nproc): 181 | print(f"processing {path}") 182 | scan_ids = os.listdir(path) 183 | # debug 184 | mmengine.track_parallel_progress( 185 | func=partial(process_scene, path, frame_skip), 186 | tasks=scan_ids, 187 | nproc=nproc, 188 | ) 189 | 190 | 191 | if __name__ == "__main__": 192 | parser = ArgumentParser() 193 | parser.add_argument( 194 | "--frame_skip", type=int, default=1, help="export every nth frame" 195 | ) # * use this or --max-images-per-scene 196 | parser.add_argument("--nproc", type=int, default=20) 197 | args = parser.parse_args() 198 | 199 | # process train and val scenes 200 | if os.path.exists("scans"): 201 | process_directory( 202 | "scans", args.frame_skip, args.nproc 203 | ) 204 | # process test scenes 205 | if os.path.exists("scans_test"): 206 | process_directory( 207 | "scans_test", args.frame_skip, args.nproc 208 | ) 209 | -------------------------------------------------------------------------------- /spatial_engine/utils/scannet_utils/handler/ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import List 8 | 9 | import cv2 10 | import numpy as np 11 | import open3d as o3d 12 | 13 | 14 | def filter_images_laplacian(image_paths, threshold=100): 15 | res = [] 16 | for i in range(len(image_paths)): 17 | if calculate_image_sharpness(image_paths[i]) >= threshold: 18 | res.append(image_paths[i]) 19 | return res 20 | 21 | 22 | def calculate_image_sharpness(image_path): 23 | """ 24 | calculate the sharpness of image 25 | """ 26 | image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE) 27 | 28 | if image is None: 29 | raise ValueError(f"Unable to read image from: {image_path}") 30 | 31 | # calculate laplacion 32 | laplacian = cv2.Laplacian(image, cv2.CV_64F) 33 | sharpness_score = laplacian.var() 34 | return sharpness_score 35 | 36 | 37 | def convert_to_corners(bounding_boxes: List[np.ndarray]) -> List[np.ndarray]: 38 | """ 39 | Convert bounding boxes to eight corners 40 | Args: 41 | bounding_boxes: List of bounding boxes with format [cx, cy, cz, dx, dy, dz, ...] 42 | Returns: 43 | List of eight corners for each bounding box with format [[x1, y1, z1], [x2, y2, z2], ...] 44 | """ 45 | corners = [] 46 | for bbox in bounding_boxes: 47 | corners.append( 48 | np.array( 49 | [ 50 | [ 51 | bbox[0] - bbox[3] / 2, 52 | bbox[1] - bbox[4] / 2, 53 | bbox[2] - bbox[5] / 2, 54 | ], 55 | [ 56 | bbox[0] + bbox[3] / 2, 57 | bbox[1] - bbox[4] / 2, 58 | bbox[2] - bbox[5] / 2, 59 | ], 60 | [ 61 | bbox[0] - bbox[3] / 2, 62 | bbox[1] + bbox[4] / 2, 63 | bbox[2] - bbox[5] / 2, 64 | ], 65 | [ 66 | bbox[0] + bbox[3] / 2, 67 | bbox[1] + bbox[4] / 2, 68 | bbox[2] - bbox[5] / 2, 69 | ], 70 | [ 71 | bbox[0] - bbox[3] / 2, 72 | bbox[1] - bbox[4] / 2, 73 | bbox[2] + bbox[5] / 2, 74 | ], 75 | [ 76 | bbox[0] + bbox[3] / 2, 77 | bbox[1] - bbox[4] / 2, 78 | bbox[2] + bbox[5] / 2, 79 | ], 80 | [ 81 | bbox[0] - bbox[3] / 2, 82 | bbox[1] + bbox[4] / 2, 83 | bbox[2] + bbox[5] / 2, 84 | ], 85 | [ 86 | bbox[0] + bbox[3] / 2, 87 | bbox[1] + bbox[4] / 2, 88 | bbox[2] + bbox[5] / 2, 89 | ], 90 | ], 91 | dtype=np.float32, 92 | ) 93 | ) 94 | return corners 95 | 96 | 97 | def calculate_iou_2d(mask1, mask2): 98 | """ 99 | Calculate 2D Intersection over Union (IoU) of two masks, both of which are H * W binary numpy arrays. 100 | """ 101 | # Calculate the intersection of the two masks 102 | intersection = np.logical_and(mask1, mask2) 103 | 104 | # Calculate the union of the two masks 105 | union = np.logical_or(mask1, mask2) 106 | 107 | # Calculate the IoU 108 | # handle zero 109 | iou = np.sum(intersection) / np.sum(union) if np.sum(union) != 0 else 0.0 110 | 111 | return iou 112 | 113 | 114 | def calculate_iou_3d(box1, box2): 115 | """ 116 | Calculate 3D Intersection over Union (IoU) of two 3D boxes, both of which are N * 6 117 | Boxes are defined by numpy arrays with [x, y, z, dx, dy, dz], 118 | where (x, y, z) is the center of the box, and (dx, dy, dz) are the size of the box along each axis. 
119 | """ 120 | # Calculate the coordinates of the intersections points 121 | inter_min = np.maximum(box1[:3] - box1[3:] / 2, box2[:3] - box2[3:] / 2) 122 | inter_max = np.minimum(box1[:3] + box1[3:] / 2, box2[:3] + box2[3:] / 2) 123 | 124 | # Calculate intersection volume 125 | inter_dim = inter_max - inter_min 126 | inter_volume = np.prod(inter_dim) if np.all(inter_dim > 0) else 0 127 | 128 | # Calculate the volume of each box 129 | box1_volume = np.prod(box1[3:]) 130 | box2_volume = np.prod(box2[3:]) 131 | 132 | # Calculate IoU 133 | iou = inter_volume / (box1_volume + box2_volume - inter_volume) 134 | 135 | return iou 136 | 137 | 138 | def remove_statistical_outliers(point_cloud_data, nb_neighbors=20, std_ratio=1.0): 139 | """ 140 | Removes statistical outliers from a point cloud data and retains all original dimensions. 141 | 142 | Args: 143 | point_cloud_data (numpy.ndarray): The input point cloud data as a NxC numpy array where C >= 3. [N, C] 144 | nb_neighbors (int): Number of nearest neighbors to consider for calculating average distance. 145 | std_ratio (float): Standard deviation ratio; points beyond this many standard deviations are considered outliers. 146 | 147 | Returns: 148 | numpy.ndarray: Filtered point cloud data with outliers removed, including all original dimensions. [n, C] 149 | """ 150 | # Convert numpy array to Open3D point cloud for XYZ coordinates 151 | pcd = o3d.geometry.PointCloud() 152 | pcd.points = o3d.utility.Vector3dVector(point_cloud_data[:, :3]) 153 | 154 | # Perform statistical outlier removal 155 | clean_pcd, ind = pcd.remove_statistical_outlier(nb_neighbors, std_ratio) 156 | 157 | # Use indices to filter the original full-dimension data 158 | inlier_data = point_cloud_data[ind, :] 159 | 160 | return inlier_data 161 | 162 | 163 | def remove_truncated_outliers(point_cloud_data, tx: float, ty: float, tz: float): 164 | """ 165 | Removes statistical outliers from a point cloud data and retains all original dimensions. 166 | 167 | Args: 168 | point_cloud_data (numpy.ndarray): The input point cloud data as a NxC numpy array where C >= 3. [N, C] 169 | tx: Ratio of points to remove from the beginning and end of the sorted x values 170 | ty: Ratio of points to remove from the beginning and end of the sorted y values 171 | tz: Ratio of points to remove from the beginning and end of the sorted z values 172 | 173 | Returns: 174 | numpy.ndarray: Filtered point cloud data with outliers removed, including all original dimensions. [n, C] 175 | """ 176 | 177 | # assert tx, ty, tz all < 0.5 178 | assert tx < 0.5 and ty < 0.5 and tz < 0.5, "tx, ty, tz must be less than 0.5." 
179 | 
180 |     n_points = len(point_cloud_data)
181 |     # Calculate the number of points to remove based on the given percentages
182 |     if tx == 0 and ty == 0 and tz == 0:
183 |         return point_cloud_data
184 | 
185 |     nx = int(tx * n_points)
186 |     ny = int(ty * n_points)
187 |     nz = int(tz * n_points)
188 | 
189 |     # Process x-axis
190 |     x_sorted_indices = np.argsort(point_cloud_data[:, 0])
191 |     valid_x_indices = x_sorted_indices[nx:n_points - nx] if 2 * nx < n_points else np.array([], dtype=int)  # upper bound n_points - nx so that nx == 0 keeps all points ([nx:-nx] would be empty)
192 | 
193 |     # Process y-axis
194 |     y_sorted_indices = np.argsort(point_cloud_data[:, 1])
195 |     valid_y_indices = y_sorted_indices[ny:n_points - ny] if 2 * ny < n_points else np.array([], dtype=int)
196 | 
197 |     # Process z-axis
198 |     z_sorted_indices = np.argsort(point_cloud_data[:, 2])
199 |     valid_z_indices = z_sorted_indices[nz:n_points - nz] if 2 * nz < n_points else np.array([], dtype=int)
200 | 
201 |     # Find the intersection of valid indices across all axes
202 |     valid_indices = np.intersect1d(valid_x_indices, valid_y_indices)
203 |     valid_indices = np.intersect1d(valid_indices, valid_z_indices)
204 | 
205 |     # Filter the original full-dimension data
206 |     inlier_data = point_cloud_data[valid_indices]
207 | 
208 |     return inlier_data
209 | 
210 | 
211 | def calculate_aabb(point_cloud_data):
212 |     """
213 |     Calculates the axis-aligned bounding box (AABB) of a point cloud.
214 | 
215 |     Args:
216 |         point_cloud_data (numpy.ndarray): The input point cloud data as a NxC numpy array where C >= 3. [N, C]
217 | 
218 |     Returns:
219 |         numpy.ndarray: The center of the AABB followed by its dimensions, concatenated as [x, y, z, dx, dy, dz].
220 |     """
221 |     # Calculate the min and max along each column (x, y, z)
222 |     min_corner = np.min(point_cloud_data[:, :3], axis=0)
223 |     max_corner = np.max(point_cloud_data[:, :3], axis=0)
224 | 
225 |     # Calculate center and dimensions
226 |     center = (max_corner + min_corner) / 2
227 |     dimensions = max_corner - min_corner
228 | 
229 |     # Combine center and dimensions into a single array
230 |     result = np.concatenate([center, dimensions])
231 | 
232 |     return result
233 | 
234 | 
235 | def project_mask_to_3d(
236 |     depth_image,
237 |     intrinsic_matrix,
238 |     extrinsic_matrix,
239 |     mask=None,
240 |     world_to_axis_align_matrix=None,
241 |     color_image=None,
242 | ):
243 |     """
244 |     Projects a mask to 3D space using the provided depth map and camera parameters.
245 |     Optionally appends RGB values from a color image to the 3D points. (RGB order with 0-255 range)
246 | 
247 |     Parameters:
248 |     - depth_image (str or ndarray): Path to the depth image or a numpy array of depth values. h, w
249 |     - intrinsic_matrix (ndarray): The camera's intrinsic matrix. 4 * 4
250 |     - extrinsic_matrix (ndarray): The camera's extrinsic matrix. 4 * 4
251 |     - mask (ndarray): A binary mask (zero, non-zero array) where True values indicate pixels to project; it has the same shape as color_image. H, W. Can be None, in which case all pixels are projected.
252 |     - world_to_axis_align_matrix (ndarray, optional): Matrix to align the world coordinates. 4 * 4
253 |     - color_image (str or ndarray, optional): Path to the color image or a numpy array of color values. H, W, 3
254 | 
255 |     Returns:
256 |     - ndarray: Array of 3D coordinates, optionally with RGB values appended.
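      Shape is [n, 3] without color_image and [n, 6] (XYZ followed by RGB) when color_image is given.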
All False mask will give `array([], shape=(0, C), dtype=float64)` 257 | """ 258 | 259 | # Load depth image from path if it's a string 260 | if isinstance(depth_image, str): 261 | depth_image = cv2.imread(depth_image, -1) 262 | 263 | # Load color image from path if it's a string 264 | if isinstance(color_image, str): 265 | color_image = cv2.imread(color_image) 266 | color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB) 267 | 268 | if mask is None: 269 | mask = np.ones(color_image.shape[:2], dtype=bool) 270 | 271 | # Calculate scaling factors 272 | scale_y = depth_image.shape[0] / mask.shape[0] 273 | scale_x = depth_image.shape[1] / mask.shape[1] 274 | 275 | # Get coordinates of True values in mask 276 | mask_indices = np.where(mask) 277 | mask_y = mask_indices[0] 278 | mask_x = mask_indices[1] 279 | 280 | # Scale coordinates to match the depth image size 281 | # depth_y = (mask_y * scale_y).astype(int) 282 | # depth_x = (mask_x * scale_x).astype(int) 283 | 284 | # # scale and use round 285 | depth_y = np.round(mask_y * scale_y).astype(int) 286 | depth_x = np.round(mask_x * scale_x).astype(int) 287 | 288 | # Clip scaled coordinates to ensure they are within the image boundary 289 | depth_y = np.clip(depth_y, 0, depth_image.shape[0] - 1) 290 | depth_x = np.clip(depth_x, 0, depth_image.shape[1] - 1) 291 | 292 | # Extract depth values 293 | depth_values = ( 294 | depth_image[depth_y, depth_x] * 0.001 295 | ) # Assume depth is in millimeters 296 | 297 | # Filter out zero depth values 298 | valid = depth_values > 0 299 | depth_values = depth_values[valid] 300 | mask_x = mask_x[valid] 301 | mask_y = mask_y[valid] 302 | 303 | # Construct normalized pixel coordinates 304 | normalized_pixels = np.vstack( 305 | ( 306 | mask_x * depth_values, 307 | mask_y * depth_values, 308 | depth_values, 309 | np.ones_like(depth_values), 310 | ) 311 | ) 312 | 313 | # Compute points in camera coordinate system 314 | cam_coords = np.dot(np.linalg.inv(intrinsic_matrix), normalized_pixels) 315 | 316 | # Transform to world coordinates 317 | world_coords = np.dot(extrinsic_matrix, cam_coords) 318 | 319 | # Apply world-to-axis alignment if provided 320 | if world_to_axis_align_matrix is not None: 321 | world_coords = np.dot(world_to_axis_align_matrix, world_coords) 322 | 323 | # Append color information if color image is provided 324 | if color_image is not None: 325 | # Scale mask coordinates for the color image 326 | rgb_values = color_image[mask_y, mask_x] 327 | return np.hstack((world_coords[:3].T, rgb_values)) 328 | 329 | return world_coords[:3].T 330 | -------------------------------------------------------------------------------- /spatial_engine/utils/scannet_utils/make_visibility_info.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import os 9 | import mmengine 10 | from multiprocessing import Pool 11 | from tqdm import tqdm 12 | from mmengine.utils.dl_utils import TimeCounter 13 | from spatial_engine.utils.scannet_utils.handler.info_handler import SceneInfoHandler 14 | 15 | 16 | """ 17 | { 18 | scene_id: { 19 | "image_to_points": { 20 | image_id: [point_index, ...], 21 | ... 22 | }, 23 | "point_to_images": { 24 | point_index: [image_id, ...], 25 | ... 26 | } 27 | }, 28 | ... 
29 | } 30 | 31 | After generating, will convert to parquet file. 32 | Note, at the begnining, the file is saved as pkl, so I use convert_pkl_to_parquet to convert. 33 | 34 | """ 35 | 36 | DEBUG = False 37 | 38 | def convert_pkl_to_parquet(pkl_file): 39 | """ 40 | Converts a previously saved PKL file to a Parquet file. 41 | 42 | Args: 43 | pkl_file (str): Path to the input PKL file. 44 | """ 45 | # Load the PKL file 46 | import json 47 | import pickle 48 | import pandas as pd 49 | with open(pkl_file, 'rb') as f: 50 | scene_visibility_dict = pickle.load(f) 51 | print(f"Loaded pkl file from {pkl_file}.") 52 | 53 | parquet_file = pkl_file.replace(".pkl", ".parquet") 54 | # Convert to a format suitable for Parquet 55 | data = [] 56 | for scene_id, visibility_info in scene_visibility_dict.items(): 57 | # Process image_to_points 58 | for image_id, points in visibility_info["image_to_points"].items(): 59 | key = f"{scene_id}:image_to_points:{image_id}" 60 | data.append((key, json.dumps(points))) # Convert list to JSON string 61 | 62 | # Process point_to_images 63 | for point_idx, images in visibility_info["point_to_images"].items(): 64 | key = f"{scene_id}:point_to_images:{point_idx}" 65 | data.append((key, json.dumps(images))) # Convert list to JSON string 66 | 67 | # Create a DataFrame 68 | df = pd.DataFrame(data, columns=["key", "values"]) 69 | 70 | # Save as a Parquet file 71 | df.to_parquet(parquet_file, index=False) 72 | 73 | print(f"Converted {pkl_file} to {parquet_file}. The file has {len(data)} items in total.") 74 | 75 | def process_scene(scene_id, scene_infos, warning_file): 76 | """ 77 | For one scene: 78 | 1) Gather which points each image sees. 79 | 2) Invert that mapping to know which images see a given point. 80 | 3) Return scene_id and the final dict. 
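    Args:
        scene_id: ScanNet scene identifier to process.
        scene_infos: SceneInfoHandler loaded from the split's info pickle.
        warning_file: text file that collects warnings for images with no in-bound points.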
81 | """ 82 | print(f"[process_scene] Start: {scene_id}") 83 | image_ids = scene_infos.get_all_extrinsic_valid_image_ids(scene_id) 84 | 85 | # Get all points in the scene (use only first 3 coords) 86 | scene_points = scene_infos.get_scene_points_align(scene_id)[:, :3] 87 | num_points = scene_points.shape[0] 88 | 89 | image_to_points = {} 90 | # Initially keep track of points -> images in a list of sets (index = point_id, value = set of image_ids) 91 | point_to_images_sets = [set() for _ in range(num_points)] 92 | 93 | for image_id in image_ids: 94 | E = scene_infos.get_extrinsic_matrix_align(scene_id, image_id) 95 | scene_points_2d, scene_points_depth = scene_infos.project_3d_point_to_image( 96 | scene_id, image_id, scene_points 97 | ) 98 | in_bounds_mask = scene_infos.check_point_visibility( 99 | scene_id, image_id, scene_points_2d, scene_points_depth 100 | ) 101 | 102 | # Record which points are visible for this image 103 | visible_point_indices = np.where(in_bounds_mask)[0] 104 | image_to_points[image_id] = visible_point_indices.tolist() 105 | 106 | # Also record the inverse mapping 107 | for idx in visible_point_indices: 108 | point_to_images_sets[idx].add(image_id) 109 | 110 | # Optional: warning if no visible points 111 | if len(visible_point_indices) == 0: 112 | with open(warning_file, 'a') as f: 113 | f.write(f"[Warning] {scene_id}: {image_id} has no in-bound points.\n") 114 | 115 | # Convert from list of sets -> dict 116 | point_to_images = { 117 | idx: sorted(list(img_set)) for idx, img_set in enumerate(point_to_images_sets) # * if one point is not observed in any images, the value is an empty list 118 | } 119 | 120 | result_dict = { 121 | "image_to_points": image_to_points, 122 | "point_to_images": point_to_images 123 | } 124 | print(f"[process_scene] Done: {scene_id}") 125 | return scene_id, result_dict 126 | 127 | @TimeCounter() 128 | def run_split(scene_info_path, output_file, warning_file, num_workers=8): 129 | """ 130 | 1. Loads the SceneInfoHandler for the given split. 131 | 2. Processes each scene in parallel to get visibility info. 132 | 3. Accumulates them into a top-level dict -> { scene_id: {...} } 133 | 4. Saves everything into a pickle (or any other format) for easy reload. 134 | """ 135 | scene_infos = SceneInfoHandler(scene_info_path) 136 | all_scene_ids = scene_infos.get_all_scene_ids() 137 | mmengine.mkdir_or_exist(os.path.dirname(output_file)) 138 | 139 | if DEBUG and len(all_scene_ids) > 1: 140 | all_scene_ids = all_scene_ids[:1] 141 | print("[run_split] DEBUG mode. 
Only processing first scene.")
142 | 
143 |     print(f"[run_split] Found {len(all_scene_ids)} scenes in {scene_info_path}")
144 |     print(f"[run_split] Output will be saved to {output_file}")
145 | 
146 |     scene_visibility_dict = {}
147 | 
148 |     # Prepare pool args
149 |     args = [(scene_id, scene_infos, warning_file) for scene_id in all_scene_ids]
150 | 
151 |     with Pool(num_workers) as pool:
152 |         results = [pool.apply_async(process_scene, arg) for arg in args]
153 | 
154 |         for r in tqdm(results, desc=f"Processing scenes in {scene_info_path}"):
155 |             scene_id, visibility_info = r.get()
156 |             scene_visibility_dict[scene_id] = visibility_info
157 |     import pandas as pd  # * pandas is only imported inside convert_pkl_to_parquet above, so import it here before building the DataFrame
158 |     # Convert scene_visibility_dict to a DataFrame
159 |     data = []
160 |     for scene_id, visibility_info in scene_visibility_dict.items():
161 |         # Process image_to_points
162 |         for image_id, points in visibility_info["image_to_points"].items():
163 |             key = f"{scene_id},image_to_points,{image_id}"
164 |             data.append((key, points))
165 | 
166 |         # Process point_to_images
167 |         for point_idx, images in visibility_info["point_to_images"].items():
168 |             key = f"{scene_id},point_to_images,{point_idx}"
169 |             data.append((key, images))
170 | 
171 |     # Create DataFrame
172 |     df = pd.DataFrame(data, columns=["key", "values"])
173 | 
174 |     # Save to Parquet file
175 |     df.to_parquet(output_file, index=False)
176 | 
177 |     print(f"[run_split] Done. Wrote {len(df)} entries to {output_file}")
178 | 
179 | def main():
180 |     # Adjust these as needed
181 |     train_info_path = "data/scannet/scannet_instance_data/scenes_train_info_i_D5.pkl"
182 |     val_info_path = "data/scannet/scannet_instance_data/scenes_val_info_i_D5.pkl"
183 | 
184 |     train_output_dir = "data/scannet/scannet_instance_data"
185 |     val_output_dir = "data/scannet/scannet_instance_data"
186 |     mmengine.mkdir_or_exist(train_output_dir)
187 |     mmengine.mkdir_or_exist(val_output_dir)
188 | 
189 |     # Warnings
190 |     train_warning_file = os.path.join(train_output_dir, "make_visibility_train_warning.txt")
191 |     val_warning_file = os.path.join(val_output_dir, "make_visibility_val_warning.txt")
192 | 
193 |     # Output parquet files
194 |     train_output_file = os.path.join(train_output_dir, "train_visibility_info_D5.parquet")
195 |     val_output_file = os.path.join(val_output_dir, "val_visibility_info_D5.parquet")
196 | 
197 |     # If DEBUG, tweak output to avoid overwriting real data
198 |     global DEBUG
199 |     if DEBUG:
200 |         train_warning_file = train_warning_file.replace(".txt", "_debug.txt")
201 |         val_warning_file = val_warning_file.replace(".txt", "_debug.txt")
202 |         train_output_file = train_output_file.replace(".parquet", "_debug.parquet")
203 |         val_output_file = val_output_file.replace(".parquet", "_debug.parquet")
204 | 
205 |     # Number of processes to run in parallel
206 |     num_workers = 25
207 | 
208 |     print("[main] DEBUG =", DEBUG)
209 | 
210 |     print(f"[main] Generating val visibility -> {val_output_file}")
211 |     run_split(val_info_path, val_output_file, val_warning_file, num_workers=num_workers) # * costs 47 mins, the info file is 6G as pkl and 3.7G as parquet
212 | 
213 |     print(f"[main] Generating train visibility -> {train_output_file}")
214 |     run_split(train_info_path, train_output_file, train_warning_file, num_workers=num_workers) # * costs 3 hours, the info file is 21G as pkl and 13G as parquet
215 | 
216 | if __name__ == "__main__":
217 |     main()
--------------------------------------------------------------------------------
/spatial_engine/utils/scannet_utils/scannet_utils.py:
-------------------------------------------------------------------------------- 1 | # Modified from 2 | # https://github.com/facebookresearch/votenet/blob/master/scannet/scannet_utils.py 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import csv 10 | import json 11 | import os 12 | 13 | import numpy as np 14 | from plyfile import PlyData 15 | 16 | 17 | def read_aggregation(filename): 18 | assert os.path.isfile(filename) 19 | object_id_to_segs = {} 20 | label_to_segs = {} 21 | with open(filename) as f: 22 | data = json.load(f) 23 | num_objects = len(data["segGroups"]) 24 | for i in range(num_objects): 25 | object_id = ( 26 | data["segGroups"][i]["objectId"] + 1 27 | ) # instance ids should be 1-indexed 28 | label = data["segGroups"][i]["label"] 29 | segs = data["segGroups"][i]["segments"] 30 | object_id_to_segs[object_id] = segs # * segs are is list 31 | if label in label_to_segs: 32 | label_to_segs[label].extend(segs) # * for semantic segmentation 33 | else: 34 | label_to_segs[label] = segs 35 | return object_id_to_segs, label_to_segs 36 | 37 | 38 | def read_segmentation(filename): 39 | assert os.path.isfile(filename) 40 | seg_to_verts = {} 41 | with open(filename) as f: 42 | data = json.load(f) 43 | num_verts = len(data["segIndices"]) 44 | for i in range(num_verts): 45 | seg_id = data["segIndices"][i] 46 | if seg_id in seg_to_verts: 47 | seg_to_verts[seg_id].append(i) 48 | else: 49 | seg_to_verts[seg_id] = [i] 50 | return seg_to_verts, num_verts 51 | 52 | 53 | def extract_bbox(mesh_vertices, object_id_to_segs, object_id_to_label_id, instance_ids): 54 | """ 55 | Extracts bounding boxes and point clouds for each instance. 56 | 57 | Parameters: 58 | mesh_vertices (numpy.ndarray): The mesh vertices. 59 | object_id_to_segs (dict): Mapping of object IDs to segments. 60 | object_id_to_label_id (dict): Mapping of object IDs to label IDs. 61 | instance_ids (numpy.ndarray): Array of instance IDs. 62 | 63 | Returns: 64 | numpy.ndarray: An array containing bounding boxes for each instance. 65 | The ID (1-index) of each bounding box is its index in the returned array plus 1. 66 | list of numpy.ndarray: A list containing point clouds for each instance, 67 | corresponding to the bounding boxes. 'None' for instances without points. 68 | """ 69 | num_instances = len(np.unique(list(object_id_to_segs.keys()))) 70 | instance_bboxes = np.zeros((num_instances, 7)) 71 | instance_pcs = [None] * num_instances # Initialize list with None 72 | 73 | for obj_id in object_id_to_segs: 74 | label_id = object_id_to_label_id[obj_id] 75 | obj_pc = mesh_vertices[instance_ids == obj_id, 0:3] 76 | obj_pc_rgb = mesh_vertices[instance_ids == obj_id, :] 77 | if len(obj_pc) == 0: 78 | print( 79 | f"WARNING: object id {obj_id} does not have points. Corresponding entry is set to None." 80 | ) 81 | continue 82 | 83 | xyz_min = np.min(obj_pc, axis=0) 84 | xyz_max = np.max(obj_pc, axis=0) 85 | bbox = np.concatenate( 86 | [(xyz_min + xyz_max) / 2.0, xyz_max - xyz_min, np.array([label_id])] 87 | ) 88 | 89 | instance_bboxes[obj_id - 1, :] = bbox 90 | instance_pcs[obj_id - 1] = ( 91 | obj_pc_rgb # Store the point cloud at the appropriate index 92 | ) 93 | 94 | return instance_bboxes, instance_pcs 95 | 96 | 97 | def represents_int(s): 98 | """Judge whether string s represents an int. 99 | 100 | Args: 101 | s(str): The input string to be judged. 
102 | 103 | Returns: 104 | bool: Whether s represents int or not. 105 | """ 106 | try: 107 | int(s) 108 | return True 109 | except ValueError: 110 | return False 111 | 112 | 113 | def read_label_mapping(filename, label_from="raw_category", label_to="nyu40id"): 114 | assert os.path.isfile(filename) 115 | mapping = dict() 116 | with open(filename) as csvfile: 117 | reader = csv.DictReader(csvfile, delimiter="\t") 118 | for row in reader: 119 | mapping[row[label_from]] = int(row[label_to]) 120 | if represents_int(list(mapping.keys())[0]): 121 | mapping = {int(k): v for k, v in mapping.items()} 122 | return mapping 123 | 124 | 125 | def read_mesh_vertices(filename): 126 | """Read XYZ for each vertex. 127 | 128 | Args: 129 | filename(str): The name of the mesh vertices file. 130 | 131 | Returns: 132 | ndarray: Vertices. 133 | """ 134 | assert os.path.isfile(filename) 135 | with open(filename, "rb") as f: 136 | plydata = PlyData.read(f) 137 | num_verts = plydata["vertex"].count 138 | vertices = np.zeros(shape=[num_verts, 3], dtype=np.float32) 139 | vertices[:, 0] = plydata["vertex"].data["x"] 140 | vertices[:, 1] = plydata["vertex"].data["y"] 141 | vertices[:, 2] = plydata["vertex"].data["z"] 142 | return vertices 143 | 144 | 145 | def read_mesh_vertices_rgb(filename): 146 | """Read XYZ and RGB for each vertex. 147 | 148 | Args: 149 | filename(str): The name of the mesh vertices file. 150 | 151 | Returns: 152 | Vertices. Note that RGB values are in 0-255. 153 | """ 154 | assert os.path.isfile(filename) 155 | with open(filename, "rb") as f: 156 | plydata = PlyData.read(f) 157 | num_verts = plydata["vertex"].count 158 | vertices = np.zeros(shape=[num_verts, 6], dtype=np.float32) 159 | vertices[:, 0] = plydata["vertex"].data["x"] 160 | vertices[:, 1] = plydata["vertex"].data["y"] 161 | vertices[:, 2] = plydata["vertex"].data["z"] 162 | vertices[:, 3] = plydata["vertex"].data["red"] 163 | vertices[:, 4] = plydata["vertex"].data["green"] 164 | vertices[:, 5] = plydata["vertex"].data["blue"] 165 | return vertices 166 | -------------------------------------------------------------------------------- /spatial_engine/utils/scannet_utils/update_info_file_with_images.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
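# Walks data/scannet/posed_images/<scene_id> for every scene in the info pickle, keeps
# every frame_skip-th posed RGB image together with its depth image path and extrinsic
# matrix, reads the shared intrinsic matrix, and writes the augmented scene info to a
# new *_i_D{frame_skip}.pkl file (the original pickle is left untouched).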
6 | 7 | import os 8 | 9 | import mmengine 10 | import numpy as np 11 | from tqdm import tqdm 12 | 13 | # Base directory where the scene_id folders are located 14 | base_dir = "data/scannet/posed_images" 15 | scene_infos_file = "data/scannet/scannet_instance_data/scenes_train_val_info.pkl" 16 | frame_skip = 5 # store one image from every frame_skip images 17 | scene_infos = mmengine.load(scene_infos_file) 18 | 19 | # Iterate through each scene_id in the scene_info dict 20 | for scene_id in tqdm(scene_infos.keys()): 21 | # Construct the path to the current scene_id folder 22 | scene_path = os.path.join(base_dir, scene_id) 23 | 24 | # Initialize the number of posed images to 0 25 | num_posed_images = 0 26 | 27 | # Initialize a dictionary to hold image data 28 | image_data = {} 29 | 30 | # Read the intrinsic matrix 31 | intrinsic_path = os.path.join(scene_path, "intrinsic.txt") 32 | with open(intrinsic_path, "r") as f: 33 | intrinsic_matrix = np.array( 34 | [list(map(float, line.split())) for line in f.readlines()] 35 | ) 36 | 37 | # Iterate through each file in the scene_id directory 38 | all_files = os.listdir(scene_path) 39 | all_jpg_files = [f for f in all_files if f.endswith(".jpg")] 40 | all_jpg_files.sort() 41 | for i, filename in enumerate(all_jpg_files): 42 | if i % frame_skip == 0: # Only process every frame_skip-th image 43 | # Extract the image_id from the filename (e.g., "00000.jpg" -> "00000") 44 | image_id = filename.split(".")[0] 45 | # Construct paths to the image, depth image, and extrinsic matrix file 46 | image_path = f"posed_images/{scene_id}/{filename}" 47 | depth_image_path = f"posed_images/{scene_id}/{image_id}.png" 48 | extrinsic_path = os.path.join(scene_path, f"{image_id}.txt") 49 | # Read the extrinsic matrix from the file 50 | with open(extrinsic_path, "r") as f: 51 | extrinsic_matrix = np.array( 52 | [list(map(float, line.split())) for line in f.readlines()] 53 | ) 54 | # Update the image data dictionary with this image's information 55 | image_data[image_id] = { 56 | "image_path": image_path, 57 | "depth_image_path": depth_image_path, 58 | "extrinsic_matrix": extrinsic_matrix, 59 | } 60 | # Increment the count of posed images 61 | num_posed_images += 1 62 | 63 | # Update the scene_info dictionary for the current scene_id 64 | scene_infos[scene_id].update( 65 | { 66 | "num_posed_images": num_posed_images, 67 | "images_info": image_data, 68 | "intrinsic_matrix": intrinsic_matrix, 69 | } 70 | ) 71 | 72 | mmengine.dump(scene_infos, scene_infos_file.replace(".pkl", f"_i_D{frame_skip}.pkl")) 73 | --------------------------------------------------------------------------------
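A minimal sketch of how the regenerated info file can be inspected, assuming the default paths and frame_skip = 5 used in the script above:

import mmengine

# load the updated info pickle and peek at one scene / one posed image
infos = mmengine.load("data/scannet/scannet_instance_data/scenes_train_val_info_i_D5.pkl")
scene_id, scene = next(iter(infos.items()))
print(scene_id, scene["num_posed_images"])             # number of kept posed images
print(scene["intrinsic_matrix"].shape)                 # (4, 4) color-camera intrinsics
image_id, image_info = next(iter(scene["images_info"].items()))
print(image_id, image_info["image_path"], image_info["extrinsic_matrix"].shape)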