├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── assets
│   └── teaser.png
├── requirements
│   └── data_engine.yaml
└── spatial_engine
    ├── camera_movement
    │   ├── TEMPLATES.py
    │   ├── calculate_frames_relations.py
    │   └── camera_movement_engine_train_val.py
    ├── depth_perception
    │   ├── depth_comparison_coor_engine.py
    │   ├── depth_comparison_dot_engine.py
    │   ├── depth_estimation_coor_engine.py
    │   └── depth_estimation_dot_engine.py
    ├── object_movement
    │   ├── single_object_movement_engine_coord.py
    │   └── single_object_movement_engine_dot.py
    ├── object_perception
    │   ├── compute_object_visibility.py
    │   ├── find_object_coverage.sh
    │   ├── merge_object_coverage.py
    │   ├── single_object_coverage_finder.py
    │   └── single_object_perception_engine.py
    ├── utils
    │   └── scannet_utils
    │       ├── README.md
    │       ├── batch_load_scannet_data.py
    │       ├── extract_posed_images.py
    │       ├── handler
    │       │   ├── info_handler.py
    │       │   └── ops.py
    │       ├── make_visibility_info.py
    │       ├── scannet_utils.py
    │       └── update_info_file_with_images.py
    └── visual_correspondence
        ├── visual_correspondence_qa_engine_coor_2_coor.py
        └── visual_correspondence_qa_engine_dot_2_multichoice.py
/.gitignore:
--------------------------------------------------------------------------------
1 | data
2 | __pycache__
3 | *.pkl
4 | *.parquet
5 |
6 | evaluation_data
7 | training_data
8 | tmp
9 | visualization_results
10 | Book2.csv
11 | group_acc_results.csv
12 |
13 | training_data
14 | evaluation_data
15 |
16 | .vscode
17 | .env
18 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to make participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | This Code of Conduct also applies outside the project spaces when there is a
56 | reasonable belief that an individual's behavior may have a negative impact on
57 | the project or its community.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported by contacting the project team at . All
63 | complaints will be reviewed and investigated and will result in a response that
64 | is deemed necessary and appropriate to the circumstances. The project team is
65 | obligated to maintain confidentiality with regard to the reporter of an incident.
66 | Further details of specific enforcement policies may be posted separately.
67 |
68 | Project maintainers who do not follow or enforce the Code of Conduct in good
69 | faith may face temporary or permanent repercussions as determined by other
70 | members of the project's leadership.
71 |
72 | ## Attribution
73 |
74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76 |
77 | [homepage]: https://www.contributor-covenant.org
78 |
79 | For answers to common questions about this code of conduct, see
80 | https://www.contributor-covenant.org/faq
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to Multi-SpatialMLLM
2 | We want to make contributing to this project as easy and transparent as
3 | possible.
4 |
5 | ## Pull Requests
6 | We actively welcome your pull requests.
7 |
8 | 1. Fork the repo and create your branch from `main`.
9 | 2. If you've added code that should be tested, add tests.
10 | 3. If you've changed APIs, update the documentation.
11 | 4. Ensure the test suite passes.
12 | 5. Make sure your code lints.
13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14 |
15 | ## Contributor License Agreement ("CLA")
16 | In order to accept your pull request, we need you to submit a CLA. You only need
17 | to do this once to work on any of Meta's open source projects.
18 |
19 | Complete your CLA here:
20 |
21 | ## Issues
22 | We use GitHub issues to track public bugs. Please ensure your description is
23 | clear and has sufficient instructions to be able to reproduce the issue.
24 |
25 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
26 | disclosure of security bugs. In those cases, please go through the process
27 | outlined on that page and do not file a public issue.
28 |
29 | ## License
30 | By contributing to Multi-SpatialMLLM, you agree that your contributions will be licensed
31 | under the LICENSE file in the root directory of this source tree.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-NonCommercial 4.0 International Public
58 | License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-NonCommercial 4.0 International Public License ("Public
63 | License"). To the extent this Public License may be interpreted as a
64 | contract, You are granted the Licensed Rights in consideration of Your
65 | acceptance of these terms and conditions, and the Licensor grants You
66 | such rights in consideration of benefits the Licensor receives from
67 | making the Licensed Material available under these terms and
68 | conditions.
69 |
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Adapter's License means the license You apply to Your Copyright
84 | and Similar Rights in Your contributions to Adapted Material in
85 | accordance with the terms and conditions of this Public License.
86 |
87 | c. Copyright and Similar Rights means copyright and/or similar rights
88 | closely related to copyright including, without limitation,
89 | performance, broadcast, sound recording, and Sui Generis Database
90 | Rights, without regard to how the rights are labeled or
91 | categorized. For purposes of this Public License, the rights
92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
93 | Rights.
94 | d. Effective Technological Measures means those measures that, in the
95 | absence of proper authority, may not be circumvented under laws
96 | fulfilling obligations under Article 11 of the WIPO Copyright
97 | Treaty adopted on December 20, 1996, and/or similar international
98 | agreements.
99 |
100 | e. Exceptions and Limitations means fair use, fair dealing, and/or
101 | any other exception or limitation to Copyright and Similar Rights
102 | that applies to Your use of the Licensed Material.
103 |
104 | f. Licensed Material means the artistic or literary work, database,
105 | or other material to which the Licensor applied this Public
106 | License.
107 |
108 | g. Licensed Rights means the rights granted to You subject to the
109 | terms and conditions of this Public License, which are limited to
110 | all Copyright and Similar Rights that apply to Your use of the
111 | Licensed Material and that the Licensor has authority to license.
112 |
113 | h. Licensor means the individual(s) or entity(ies) granting rights
114 | under this Public License.
115 |
116 | i. NonCommercial means not primarily intended for or directed towards
117 | commercial advantage or monetary compensation. For purposes of
118 | this Public License, the exchange of the Licensed Material for
119 | other material subject to Copyright and Similar Rights by digital
120 | file-sharing or similar means is NonCommercial provided there is
121 | no payment of monetary compensation in connection with the
122 | exchange.
123 |
124 | j. Share means to provide material to the public by any means or
125 | process that requires permission under the Licensed Rights, such
126 | as reproduction, public display, public performance, distribution,
127 | dissemination, communication, or importation, and to make material
128 | available to the public including in ways that members of the
129 | public may access the material from a place and at a time
130 | individually chosen by them.
131 |
132 | k. Sui Generis Database Rights means rights other than copyright
133 | resulting from Directive 96/9/EC of the European Parliament and of
134 | the Council of 11 March 1996 on the legal protection of databases,
135 | as amended and/or succeeded, as well as other essentially
136 | equivalent rights anywhere in the world.
137 |
138 | l. You means the individual or entity exercising the Licensed Rights
139 | under this Public License. Your has a corresponding meaning.
140 |
141 |
142 | Section 2 -- Scope.
143 |
144 | a. License grant.
145 |
146 | 1. Subject to the terms and conditions of this Public License,
147 | the Licensor hereby grants You a worldwide, royalty-free,
148 | non-sublicensable, non-exclusive, irrevocable license to
149 | exercise the Licensed Rights in the Licensed Material to:
150 |
151 | a. reproduce and Share the Licensed Material, in whole or
152 | in part, for NonCommercial purposes only; and
153 |
154 | b. produce, reproduce, and Share Adapted Material for
155 | NonCommercial purposes only.
156 |
157 | 2. Exceptions and Limitations. For the avoidance of doubt, where
158 | Exceptions and Limitations apply to Your use, this Public
159 | License does not apply, and You do not need to comply with
160 | its terms and conditions.
161 |
162 | 3. Term. The term of this Public License is specified in Section
163 | 6(a).
164 |
165 | 4. Media and formats; technical modifications allowed. The
166 | Licensor authorizes You to exercise the Licensed Rights in
167 | all media and formats whether now known or hereafter created,
168 | and to make technical modifications necessary to do so. The
169 | Licensor waives and/or agrees not to assert any right or
170 | authority to forbid You from making technical modifications
171 | necessary to exercise the Licensed Rights, including
172 | technical modifications necessary to circumvent Effective
173 | Technological Measures. For purposes of this Public License,
174 | simply making modifications authorized by this Section 2(a)
175 | (4) never produces Adapted Material.
176 |
177 | 5. Downstream recipients.
178 |
179 | a. Offer from the Licensor -- Licensed Material. Every
180 | recipient of the Licensed Material automatically
181 | receives an offer from the Licensor to exercise the
182 | Licensed Rights under the terms and conditions of this
183 | Public License.
184 |
185 | b. No downstream restrictions. You may not offer or impose
186 | any additional or different terms or conditions on, or
187 | apply any Effective Technological Measures to, the
188 | Licensed Material if doing so restricts exercise of the
189 | Licensed Rights by any recipient of the Licensed
190 | Material.
191 |
192 | 6. No endorsement. Nothing in this Public License constitutes or
193 | may be construed as permission to assert or imply that You
194 | are, or that Your use of the Licensed Material is, connected
195 | with, or sponsored, endorsed, or granted official status by,
196 | the Licensor or others designated to receive attribution as
197 | provided in Section 3(a)(1)(A)(i).
198 |
199 | b. Other rights.
200 |
201 | 1. Moral rights, such as the right of integrity, are not
202 | licensed under this Public License, nor are publicity,
203 | privacy, and/or other similar personality rights; however, to
204 | the extent possible, the Licensor waives and/or agrees not to
205 | assert any such rights held by the Licensor to the limited
206 | extent necessary to allow You to exercise the Licensed
207 | Rights, but not otherwise.
208 |
209 | 2. Patent and trademark rights are not licensed under this
210 | Public License.
211 |
212 | 3. To the extent possible, the Licensor waives any right to
213 | collect royalties from You for the exercise of the Licensed
214 | Rights, whether directly or through a collecting society
215 | under any voluntary or waivable statutory or compulsory
216 | licensing scheme. In all other cases the Licensor expressly
217 | reserves any right to collect such royalties, including when
218 | the Licensed Material is used other than for NonCommercial
219 | purposes.
220 |
221 |
222 | Section 3 -- License Conditions.
223 |
224 | Your exercise of the Licensed Rights is expressly made subject to the
225 | following conditions.
226 |
227 | a. Attribution.
228 |
229 | 1. If You Share the Licensed Material (including in modified
230 | form), You must:
231 |
232 | a. retain the following if it is supplied by the Licensor
233 | with the Licensed Material:
234 |
235 | i. identification of the creator(s) of the Licensed
236 | Material and any others designated to receive
237 | attribution, in any reasonable manner requested by
238 | the Licensor (including by pseudonym if
239 | designated);
240 |
241 | ii. a copyright notice;
242 |
243 | iii. a notice that refers to this Public License;
244 |
245 | iv. a notice that refers to the disclaimer of
246 | warranties;
247 |
248 | v. a URI or hyperlink to the Licensed Material to the
249 | extent reasonably practicable;
250 |
251 | b. indicate if You modified the Licensed Material and
252 | retain an indication of any previous modifications; and
253 |
254 | c. indicate the Licensed Material is licensed under this
255 | Public License, and include the text of, or the URI or
256 | hyperlink to, this Public License.
257 |
258 | 2. You may satisfy the conditions in Section 3(a)(1) in any
259 | reasonable manner based on the medium, means, and context in
260 | which You Share the Licensed Material. For example, it may be
261 | reasonable to satisfy the conditions by providing a URI or
262 | hyperlink to a resource that includes the required
263 | information.
264 |
265 | 3. If requested by the Licensor, You must remove any of the
266 | information required by Section 3(a)(1)(A) to the extent
267 | reasonably practicable.
268 |
269 | 4. If You Share Adapted Material You produce, the Adapter's
270 | License You apply must not prevent recipients of the Adapted
271 | Material from complying with this Public License.
272 |
273 |
274 | Section 4 -- Sui Generis Database Rights.
275 |
276 | Where the Licensed Rights include Sui Generis Database Rights that
277 | apply to Your use of the Licensed Material:
278 |
279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
280 | to extract, reuse, reproduce, and Share all or a substantial
281 | portion of the contents of the database for NonCommercial purposes
282 | only;
283 |
284 | b. if You include all or a substantial portion of the database
285 | contents in a database in which You have Sui Generis Database
286 | Rights, then the database in which You have Sui Generis Database
287 | Rights (but not its individual contents) is Adapted Material; and
288 |
289 | c. You must comply with the conditions in Section 3(a) if You Share
290 | all or a substantial portion of the contents of the database.
291 |
292 | For the avoidance of doubt, this Section 4 supplements and does not
293 | replace Your obligations under this Public License where the Licensed
294 | Rights include other Copyright and Similar Rights.
295 |
296 |
297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
298 |
299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
309 |
310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
319 |
320 | c. The disclaimer of warranties and limitation of liability provided
321 | above shall be interpreted in a manner that, to the extent
322 | possible, most closely approximates an absolute disclaimer and
323 | waiver of all liability.
324 |
325 |
326 | Section 6 -- Term and Termination.
327 |
328 | a. This Public License applies for the term of the Copyright and
329 | Similar Rights licensed here. However, if You fail to comply with
330 | this Public License, then Your rights under this Public License
331 | terminate automatically.
332 |
333 | b. Where Your right to use the Licensed Material has terminated under
334 | Section 6(a), it reinstates:
335 |
336 | 1. automatically as of the date the violation is cured, provided
337 | it is cured within 30 days of Your discovery of the
338 | violation; or
339 |
340 | 2. upon express reinstatement by the Licensor.
341 |
342 | For the avoidance of doubt, this Section 6(b) does not affect any
343 | right the Licensor may have to seek remedies for Your violations
344 | of this Public License.
345 |
346 | c. For the avoidance of doubt, the Licensor may also offer the
347 | Licensed Material under separate terms or conditions or stop
348 | distributing the Licensed Material at any time; however, doing so
349 | will not terminate this Public License.
350 |
351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
352 | License.
353 |
354 |
355 | Section 7 -- Other Terms and Conditions.
356 |
357 | a. The Licensor shall not be bound by any additional or different
358 | terms or conditions communicated by You unless expressly agreed.
359 |
360 | b. Any arrangements, understandings, or agreements regarding the
361 | Licensed Material not stated herein are separate from and
362 | independent of the terms and conditions of this Public License.
363 |
364 |
365 | Section 8 -- Interpretation.
366 |
367 | a. For the avoidance of doubt, this Public License does not, and
368 | shall not be interpreted to, reduce, limit, restrict, or impose
369 | conditions on any use of the Licensed Material that could lawfully
370 | be made without permission under this Public License.
371 |
372 | b. To the extent possible, if any provision of this Public License is
373 | deemed unenforceable, it shall be automatically reformed to the
374 | minimum extent necessary to make it enforceable. If the provision
375 | cannot be reformed, it shall be severed from this Public License
376 | without affecting the enforceability of the remaining terms and
377 | conditions.
378 |
379 | c. No term or condition of this Public License will be waived and no
380 | failure to comply consented to unless expressly agreed to by the
381 | Licensor.
382 |
383 | d. Nothing in this Public License constitutes or may be interpreted
384 | as a limitation upon, or waiver of, any privileges and immunities
385 | that apply to the Licensor or You, including from the legal
386 | processes of any jurisdiction or authority.
387 |
388 | =======================================================================
389 |
390 | Creative Commons is not a party to its public
391 | licenses. Notwithstanding, Creative Commons may elect to apply one of
392 | its public licenses to material it publishes and in those instances
393 | will be considered the "Licensor." The text of the Creative Commons
394 | public licenses is dedicated to the public domain under the CC0 Public
395 | Domain Dedication. Except for the limited purpose of indicating that
396 | material is shared under a Creative Commons public license or as
397 | otherwise permitted by the Creative Commons policies published at
398 | creativecommons.org/policies, Creative Commons does not authorize the
399 | use of the trademark "Creative Commons" or any other trademark or logo
400 | of Creative Commons without its prior written consent including,
401 | without limitation, in connection with any unauthorized modifications
402 | to any of its public licenses or any other arrangements,
403 | understandings, or agreements concerning use of licensed material. For
404 | the avoidance of doubt, this paragraph does not form part of the
405 | public licenses.
406 |
407 | Creative Commons may be contacted at creativecommons.org.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Multi-SpatialMLLM: Multi-Frame Spatial Understanding with Multi-Modal Large Language Models
3 |
4 | Runsen Xu
5 | Weiyao Wang
6 | Hao Tang
7 | Xingyu Chen
8 | Xiaodong Wang
9 | Fu-Jen Chu
10 |
11 | Dahua Lin
12 | Matt Feiszli
13 | Kevin J. Liang
14 |
15 | FAIR, Meta · The Chinese University of Hong Kong
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 | ## 🏠 About
35 |
36 |
37 | ![teaser](assets/teaser.png)
38 |
39 | We present Multi-SpatialMLLM to equip MLLMs with robust multi-frame spatial understanding by integrating depth perception, visual correspondence, and dynamic perception. Central to our approach is the MultiSPA dataset, a novel, large-scale collection of more than 27 million samples spanning diverse 3D and 4D scenes. Alongside MultiSPA, we introduce a comprehensive benchmark that tests a wide spectrum of spatial tasks under uniform metrics. Our model achieves significant gains over baselines and proprietary systems, demonstrating scalable, generalizable multi-frame reasoning. We further observe multi-task benefits and early indications of emergent capabilities in challenging scenarios, and showcase that our model can serve as a multi-frame reward annotator for robotics.
40 |
41 |
42 | ## 🔥 News
43 | - [2025-05-22] We release the [paper](https://arxiv.org/abs/2505.17015) of Multi-SpatialMLLM and the code of our data engine. 🎉
44 |
45 | ## 📋 Contents
46 | - [📦 Data Engine](#-data-engine)
47 | - [🏋️♂️ Model Training](#-model-training)
48 | - [🔗 Citation](#-citation)
49 | - [📄 License](#-license)
50 | - [👥 Contributing](#-contributing)
51 |
52 | ## 📦 Data Engine
53 | ### Environment
54 | To set up the Conda environment for our data engine, please follow these steps:
55 |
56 | 1. Ensure you have [Anaconda](https://www.anaconda.com/products/distribution) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html) installed.
57 | 2. Clone this repository.
58 | ```bash
59 | git clone https://github.com/facebookresearch/Multi-SpatialMLLM.git
60 | cd Multi-SpatialMLLM
61 | ```
62 | 3. Create the Conda environment using the provided YAML file:
63 | ```bash
64 | conda env create -f requirements/data_engine.yaml
65 | ```
66 | 4. Activate the newly created environment:
67 | ```bash
68 | conda activate data_engine
69 | ```
70 |
71 | ### Data Preparation
72 | #### ScanNet
73 | Please follow [spatial_engine/utils/scannet_utils/README.md](spatial_engine/utils/scannet_utils/README.md) to download and process the ScanNet data.
74 |
75 | #### TAPVid-3D
76 | Follow [TAPVid-3D](https://github.com/google-deepmind/tapnet/tree/main/tapnet/tapvid3d) to download the data. We only use the ADT and PStudio subsets. You need to download the version with camera extrinsics annotations, as described in [this issue](https://github.com/google-deepmind/tapnet/issues/115).
77 |
78 | Here are some notes for your reference, but simply following the official script is enough.
79 |
80 | - Change the download URL in `tapnet/tapnet/tapvid3d/annotation_generation/gcs_utils.py` from `https://storage.googleapis.com/dm-tapnet/tapvid3d/release_files/rc4` to `https://storage.googleapis.com/dm-tapnet/tapvid3d/release_files/rc5`. Also, modify `tapnet/tapnet/tapvid3d/annotation_generation/adt_utils.py` to store `extrinsics_w2c` in the npz file as shown below.
81 | ```python
82 | # also add this for warning
83 | sequence_path = os.path.join(input_adt_path, adt_v2_name)
84 | # if the sequence_path does not exist, write to a warning file and exit
85 | if not os.path.exists(sequence_path):
86 |     with open("adt_warning.txt", "a") as f:
87 |         f.write(f"Sequence {seq_name} does not exist.\n")
88 |     return
89 | ...
90 | ...
91 |
92 | queries_xyt = in_npz["queries_xyt"]
93 | trajectories = in_npz["tracks_XYZ"]
94 | visibilities = in_npz["visibility"]
95 | extrinsics_w2c = in_npz["extrinsics_w2c"]  # add this
96 |
97 | # Verify video means.
98 | video_means = np.stack([np.mean(x, axis=(0, 1)) for x in rgb_ims], axis=0)
99 | assert np.allclose(video_means, in_npz["video_means"], atol=1e-3)
100 |
101 | example = {
102 |     "images_jpeg_bytes": rgb_jpegs,
103 |     "queries_xyt": queries_xyt,
104 |     "tracks_XYZ": trajectories,
105 |     "visibility": visibilities,
106 |     "fx_fy_cx_cy": np.array(
107 |         [FOCAL_LENGTH, FOCAL_LENGTH, WIDTH / 2, HEIGHT / 2]
108 |     ),
109 |     "extrinsics_w2c": extrinsics_w2c,  # add this
110 | }
111 | ```
112 | - Each npz file from TAPVid-3D contains the following keys (a minimal loading sketch is given further below):
113 | ```python
114 | """
115 | Each *.npz file contains:
116 |
117 | images_jpeg_bytes: tensor of shape [# of frames, height, width, 3], each frame stored as JPEG bytes that must be decoded
118 | intrinsics: (fx, fy, cx, cy) camera intrinsics of the video
119 | tracks_xyz: tensor of shape (# of frames, # of point tracks, 3), representing the 3D point trajectories and the last dimension is the (x, y, z) point position in meters. They are in camera coordinates.
120 | visibility: tensor of shape (# of frames, # of point tracks), representing the visibility of each point along its trajectory
121 | queries_xyt: tensor of shape (# of point tracks, 3), representing the query point used in the benchmark as the initial given point to track. The last dimension is given in (x, y, t), where x,y are the pixel location of the query point and t is the query frame.
122 | extrinsics_w2c: tensor of shape (#, 4, 4)
123 | """
124 | ```
125 | - For PStudio, after running the official script, there will be a `tmp` folder inside, which stores the original videos (images) from the PStudio dataset. You can simply ignore this folder.
126 | - For ADT
127 | - You need to download the original files from the Project Aria website and place them in `data/projectaria_tools_adt_data`.
128 | ```
129 | pip install projectaria-tools'[all]'
130 |
131 | # get an adt_download_urls.json from the official website
132 | mkdir data/projectaria_tools_adt_data
133 | mv adt_download_urls.json data/projectaria_tools_adt_data
134 |
135 | # download all data types; this takes about 1.4 TB in total.
136 | aria_dataset_downloader -c data/projectaria_tools_adt_data/adt_download_urls.json -o data/projectaria_tools_adt_data/ -l all
137 | ```
138 | - Then run the official script to download the query points and postprocess them to store image info inside the npz file.
139 | ```
140 | cd tapnet
141 | ADT_OUTPUT_DIRECTORY="tapvid3d_dataset/adt/" && mkdir -p $ADT_OUTPUT_DIRECTORY
142 | PYTHON_DEBUG="False"
143 | conda activate projectaria # if applicable, use a new env
144 | python3 -m tapnet.tapvid3d.annotation_generation.generate_adt --output_dir=$ADT_OUTPUT_DIRECTORY --debug=$PYTHON_DEBUG --split=all --adt_base_path data/projectaria_tools_adt_data
145 | ```
146 | Specifically, in `tapnet/tapnet/tapvid3d/annotation_generation/generate_adt.py`:
147 | ```
148 | gcs_utils.download_tapvid3d_files(tmp_adt_dir, _SPLIT.value, "adt", _DEBUG.value)
149 | ```
150 | This function downloads the npz files from the given URL to `tmp` (about 11 GB), and the following function merges the images/videos from `adt_base_path` into the npz files.
151 | ```
152 | generate_adt_npz(_ADT_BASE_PATH.value, tmp_adt_dir, _OUTPUT_DIR.value)
153 | ```
154 |
155 |
156 |
157 | Finally, we assume the data is structured as follows:
158 | ```bash
159 | data/tapvid3d_dataset
160 | ├── adt
161 | │ ├── "scene_id".npz
162 | ├── pstudio
163 | │ ├── "scene_id".npz
164 | ```
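
Once the npz files are in place, the snippet below is a minimal sanity-check sketch for loading one sequence (the path is a placeholder; key names follow the listing above, with `extrinsics_w2c` present only in the patched version):
```python
import cv2  # opencv-python
import numpy as np

# Placeholder path; substitute a real "scene_id".npz from the structure above.
npz_path = "data/tapvid3d_dataset/adt/example_scene.npz"

with np.load(npz_path, allow_pickle=True) as data:
    # Each frame is stored as JPEG bytes and must be decoded.
    frames = [
        cv2.imdecode(np.frombuffer(buf, dtype=np.uint8), cv2.IMREAD_COLOR)
        for buf in data["images_jpeg_bytes"]
    ]
    tracks = data["tracks_XYZ"]              # (num_frames, num_tracks, 3), camera coords in meters
    visibility = data["visibility"]          # (num_frames, num_tracks)
    extrinsics_w2c = data["extrinsics_w2c"]  # (num_frames, 4, 4), only in the patched npz files

print(len(frames), frames[0].shape, tracks.shape, visibility.shape, extrinsics_w2c.shape)
```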
165 |
166 | ### Data Generation
167 | We generate the data in the conversation format of [InternVL](https://internvl.readthedocs.io/en/latest/get_started/chat_data_format.html#multi-image-data); a minimal example record is sketched below. You can easily convert the generated jsonl files to your own format.
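
For orientation, a hypothetical record in this multi-image conversation style might look roughly like the sketch below (field names, paths, and text are illustrative; refer to the InternVL documentation linked above for the authoritative format):
```python
import json

# Illustrative only: a two-image conversation record appended to a jsonl file.
record = {
    "id": 0,
    "image": ["images/scene0000_00/000000.jpg", "images/scene0000_00/000040.jpg"],
    "conversations": [
        {"from": "human", "value": "<image>\n<image>\nHow far did the camera move between these two frames?"},
        {"from": "gpt", "value": "The camera moved about 0.8 meters."},
    ],
}
with open("training_data/example.jsonl", "a") as f:
    f.write(json.dumps(record) + "\n")
```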
168 | #### Camera Movement
169 | 1. Run `python spatial_engine/camera_movement/calculate_frames_relations.py` to calculate the spatial relations between frames, e.g., their view overlap ratios. After running this script, parquet files containing this spatial information will be generated in `training_data/camera_movement` and `evaluation_data/camera_movement` (a sketch for inspecting them follows this list).
170 |
171 | 2. Then run `python spatial_engine/camera_movement/camera_movement_engine_train_val.py` to generate the training and evaluation data.
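
After step 1, a minimal sketch for inspecting the generated frame-relation parquet (assuming the default output names used by `calculate_frames_relations.py`; the `*_nonzero.parquet` variant keeps only frame pairs with non-zero view overlap):
```python
import pandas as pd

df = pd.read_parquet("training_data/camera_movement/train_camera_info_D5_nonzero.parquet")

print(df.columns.tolist())  # scene_id, image_id1, image_id2, overlap, distance, yaw, pitch
# e.g. keep frame pairs with a moderate view overlap (overlap is a percentage)
moderate = df[(df["overlap"] > 20) & (df["overlap"] < 80)]
print(f"{len(moderate)} / {len(df)} pairs have 20-80% overlap")
```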
172 |
173 | #### Depth Perception
174 | 1. Run `python spatial_engine/utils/scannet_utils/make_visibility_info.py` to compute the visibility information for each frame. Note that loading this visibility file later takes a while, on the order of several minutes.
175 | 2. Run `python spatial_engine/depth_perception/depth_estimation_dot_engine.py` to generate the training and evaluation data for visual-based depth estimation.
176 | 3. Run `python spatial_engine/depth_perception/depth_estimation_coor_engine.py` to generate the training and evaluation data for coordinate-based depth estimation.
177 | 4. Run `python spatial_engine/depth_perception/depth_comparison_dot_engine.py` to generate the training and evaluation data for visual-based depth comparison.
178 | 5. Run `python spatial_engine/depth_perception/depth_comparison_coor_engine.py` to generate the training and evaluation data for coordinate-based depth comparison.
179 |
180 | #### Visual Correspondence
181 | 1. Run `python spatial_engine/visual_correspondence/visual_correspondence_qa_engine_dot_2_multichoice.py` to generate the training and evaluation data for visual correspondence in dot-based multichoice format.
182 | 2. Run `python spatial_engine/visual_correspondence/visual_correspondence_qa_engine_coor_2_coor.py` to generate the training and evaluation data for visual correspondence in coordinate-based format.
183 |
184 | #### Object Perception
185 | 1. Run `python spatial_engine/object_perception/compute_object_visibility.py` to compute the visibility information for each object. After running this script, a pkl file containing this visibility information will be saved to `training_data/object_perception` and `evaluation_data/object_perception`.
186 | 2. Run `bash find_object_coverage.sh` to compute the coverage information for each object in each scene. You can check `spatial_engine/object_perception/single_object_coverage_finder.py` to modify the parameters and run it with multiple processes.
187 | 3. After generating all the coverage information for each scene, run `python spatial_engine/object_perception/merge_object_coverage.py` to merge the coverage information.
188 | 4. Run `python spatial_engine/object_perception/single_object_perception_engine.py` to generate the training and evaluation data for object perception.
189 |
190 | #### Object Movement
191 | 1. Run `python spatial_engine/object_movement/single_object_movement_engine_coord.py` to generate the training and evaluation data for object movement in coordinate-based format. After running this script, images will be extracted from the npz file and saved to `data/my_tapvid3d_images`.
192 | 2. Run `python spatial_engine/object_movement/single_object_movement_engine_dot.py` to generate the training and evaluation data for object movement in dot-based format.
193 |
194 | ## 🏋️♂️ Model Training
195 |
196 | We use the [InternVL-2](https://internvl.readthedocs.io/en/latest/internvl2.0/finetune.html#start-2nd-fine-tuning) models for experiments in our paper. You could follow their official instructions to easily fine-tune the models with the generated data and reproduce our results. Other VLMs can also be used. Below are some training details used in our experiments, and more can be found in our paper.
197 | - All images should be resized to `H*W=1296*968` for training.
198 | - Unlike the original InternVL setting, which dynamically allocates 12 image tiles across all images, we make sure each image can use up to 6 image tiles for training and evaluation. Please change this [line](https://github.com/OpenGVLab/InternVL/blob/dd3635206874c92386185d586fffeda1026d3a76/internvl_chat/internvl/train/internvl_chat_finetune.py#L488) to `max_num=self.max_dynamic_patch`. Watch out for GPU OOM issues; you may need to increase `--max_seq_length` to `8192`.
199 | - The training config used for our main paper is in `data/configs/mix3M.json`. Note that this config only uses 3M training samples, and we use LoRA training for research efficiency. You could use more data and fully fine-tune the whole model to get much better performance.
200 | - To preserve the original ability of the model, some general instruction-following data should be added to the training data.
201 |
202 | ## 🔗 Citation
203 |
204 | If you find our work and this codebase helpful, please consider starring this repo 🌟 and citing:
205 |
206 | ```bibtex
207 | @article{xu2025multi,
208 | title={Multi-SpatialMLLM: Multi-Frame Spatial Understanding with Multi-Modal Large Language Models},
209 | author={Xu, Runsen and Wang, Weiyao and Tang, Hao and Chen, Xingyu and Wang, Xiaodong and Chu, Fu-Jen and Lin, Dahua and Feiszli, Matt and Liang, Kevin J.},
210 | journal={arXiv preprint arXiv:2505.17015},
211 | year={2025}
212 | }
213 | ```
214 |
215 | ## 📄 License
216 |
217 | Shield: [![CC BY-NC 4.0][cc-by-nc-shield]][cc-by-nc]
218 |
219 | This work is licensed under a
220 | [Creative Commons Attribution-NonCommercial 4.0 International License][cc-by-nc].
221 |
222 | [![CC BY-NC 4.0][cc-by-nc-image]][cc-by-nc]
223 |
224 | [cc-by-nc]: https://creativecommons.org/licenses/by-nc/4.0/
225 | [cc-by-nc-image]: https://licensebuttons.net/l/by-nc/4.0/88x31.png
226 | [cc-by-nc-shield]: https://img.shields.io/badge/License-CC%20BY--NC%204.0-lightgrey.svg
227 |
228 | ## 👥 Contributing
229 |
230 | See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
231 |
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/Multi-SpatialMLLM/2f2eda329f9aa047428413d77e925271db5d4f33/assets/teaser.png
--------------------------------------------------------------------------------
/requirements/data_engine.yaml:
--------------------------------------------------------------------------------
1 | name: data_engine
2 | channels:
3 | - conda-forge
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=conda_forge
7 | - _openmp_mutex=4.5=2_gnu
8 | - asttokens=3.0.0=pyhd8ed1ab_1
9 | - bzip2=1.0.8=h5eee18b_6
10 | - ca-certificates=2024.12.31=h06a4308_0
11 | - comm=0.2.2=pyhd8ed1ab_1
12 | - debugpy=1.8.12=py311hfdbb021_0
13 | - decorator=5.1.1=pyhd8ed1ab_1
14 | - exceptiongroup=1.2.2=pyhd8ed1ab_1
15 | - executing=2.1.0=pyhd8ed1ab_1
16 | - importlib-metadata=8.6.1=pyha770c72_0
17 | - ipykernel=6.29.5=pyh3099207_0
18 | - ipython=8.31.0=pyh707e725_0
19 | - jedi=0.19.2=pyhd8ed1ab_1
20 | - jupyter_client=8.6.3=pyhd8ed1ab_1
21 | - jupyter_core=5.7.2=pyh31011fe_1
22 | - krb5=1.21.3=h143b758_0
23 | - ld_impl_linux-64=2.40=h12ee557_0
24 | - libedit=3.1.20230828=h5eee18b_0
25 | - libffi=3.4.4=h6a678d5_1
26 | - libgcc=14.2.0=h77fa898_1
27 | - libgcc-ng=14.2.0=h69a702a_1
28 | - libgomp=14.2.0=h77fa898_1
29 | - libsodium=1.0.20=h4ab18f5_0
30 | - libstdcxx=14.2.0=hc0a3c3a_1
31 | - libstdcxx-ng=11.2.0=h1234567_1
32 | - libuuid=1.41.5=h5eee18b_0
33 | - matplotlib-inline=0.1.7=pyhd8ed1ab_1
34 | - ncurses=6.4=h6a678d5_0
35 | - nest-asyncio=1.6.0=pyhd8ed1ab_1
36 | - openssl=3.4.0=h7b32b05_1
37 | - packaging=24.2=pyhd8ed1ab_2
38 | - parso=0.8.4=pyhd8ed1ab_1
39 | - pexpect=4.9.0=pyhd8ed1ab_1
40 | - pickleshare=0.7.5=pyhd8ed1ab_1004
41 | - pip=24.2=py311h06a4308_0
42 | - platformdirs=4.3.6=pyhd8ed1ab_1
43 | - prompt-toolkit=3.0.50=pyha770c72_0
44 | - psutil=6.1.1=py311h9ecbd09_0
45 | - ptyprocess=0.7.0=pyhd8ed1ab_1
46 | - pure_eval=0.2.3=pyhd8ed1ab_1
47 | - pygments=2.19.1=pyhd8ed1ab_0
48 | - python=3.11.11=he870216_0
49 | - python-dateutil=2.9.0.post0=pyhff2d567_1
50 | - python_abi=3.11=2_cp311
51 | - pyzmq=26.2.0=py311h7deb3e3_3
52 | - readline=8.2=h5eee18b_0
53 | - setuptools=75.1.0=py311h06a4308_0
54 | - six=1.17.0=pyhd8ed1ab_0
55 | - sqlite=3.45.3=h5eee18b_0
56 | - stack_data=0.6.3=pyhd8ed1ab_1
57 | - tk=8.6.14=h39e8969_0
58 | - tornado=6.4.2=py311h9ecbd09_0
59 | - traitlets=5.14.3=pyhd8ed1ab_1
60 | - typing_extensions=4.12.2=pyha770c72_1
61 | - wcwidth=0.2.13=pyhd8ed1ab_1
62 | - wheel=0.44.0=py311h06a4308_0
63 | - xz=5.4.6=h5eee18b_1
64 | - zeromq=4.3.5=h3b0a872_7
65 | - zipp=3.21.0=pyhd8ed1ab_1
66 | - zlib=1.2.13=h5eee18b_1
67 | - pip:
68 | - absl-py==2.1.0
69 | - accelerate==0.34.2
70 | - addict==2.4.0
71 | - aiofiles==24.1.0
72 | - aiohappyeyeballs==2.4.4
73 | - aiohttp==3.11.11
74 | - aiosignal==1.3.2
75 | - altair==5.5.0
76 | - annotated-types==0.7.0
77 | - anyio==4.8.0
78 | - attrs==24.3.0
79 | - beautifulsoup4==4.12.3
80 | - bitsandbytes==0.41.0
81 | - blinker==1.9.0
82 | - cachetools==5.5.1
83 | - certifi==2024.12.14
84 | - charset-normalizer==3.4.1
85 | - chex==0.1.88
86 | - click==8.1.8
87 | - colorama==0.4.6
88 | - contourpy==1.3.1
89 | - cycler==0.12.1
90 | - decord==0.6.0
91 | - deepspeed==0.13.5
92 | - defusedxml==0.7.1
93 | - distro==1.9.0
94 | - docker-pycreds==0.4.0
95 | - einops==0.6.1
96 | - einops-exts==0.0.4
97 | - einshape==1.0
98 | - et-xmlfile==2.0.0
99 | - fastapi==0.115.7
100 | - ffmpy==0.5.0
101 | - filelock==3.17.0
102 | - flow-vis==0.1
103 | - fonttools==4.55.5
104 | - frozenlist==1.5.0
105 | - fsspec==2024.12.0
106 | - future==1.0.0
107 | - gdown==5.2.0
108 | - gitdb==4.0.12
109 | - gitpython==3.1.44
110 | - gradio==3.35.2
111 | - gradio-client==0.2.9
112 | - greenlet==3.1.1
113 | - grpcio==1.70.0
114 | - h11==0.14.0
115 | - h5py==3.13.0
116 | - hjson==3.1.0
117 | - httpcore==0.17.3
118 | - httpx==0.24.0
119 | - huggingface-hub==0.27.1
120 | - idna==3.10
121 | - imageio==2.37.0
122 | - ipdb==0.13.13
123 | - ipywidgets==8.1.5
124 | - jax==0.5.0
125 | - jaxlib==0.5.0
126 | - jinja2==3.1.5
127 | - jiter==0.8.2
128 | - joblib==1.4.2
129 | - jsonlines==4.0.0
130 | - jsonpatch==1.33
131 | - jsonpointer==3.0.0
132 | - jsonschema==4.23.0
133 | - jsonschema-specifications==2024.10.1
134 | - jupyterlab-widgets==3.0.13
135 | - kiwisolver==1.4.8
136 | - langchain==0.3.0
137 | - langchain-core==0.3.6
138 | - langchain-openai==0.2.0
139 | - langchain-text-splitters==0.3.0
140 | - langsmith==0.1.129
141 | - latex2mathml==3.77.0
142 | - linkify-it-py==2.0.3
143 | - markdown==3.7
144 | - markdown-it-py==2.2.0
145 | - markdown2==2.5.2
146 | - markupsafe==3.0.2
147 | - matplotlib==3.10.0
148 | - mdit-py-plugins==0.3.3
149 | - mdurl==0.1.2
150 | - mediapy==1.2.2
151 | - ml-dtypes==0.5.1
152 | - mmcls==0.25.0
153 | - mmcv-full==1.6.2
154 | - mmengine==0.10.6
155 | - mmsegmentation==0.30.0
156 | - model-index==0.1.11
157 | - mpmath==1.3.0
158 | - multidict==6.1.0
159 | - narwhals==1.23.0
160 | - networkx==3.4.2
161 | - ninja==1.11.1.3
162 | - numpy==1.26.4
163 | - nvidia-cublas-cu11==11.11.3.6
164 | - nvidia-cublas-cu12==12.4.5.8
165 | - nvidia-cuda-cupti-cu11==11.8.87
166 | - nvidia-cuda-cupti-cu12==12.4.127
167 | - nvidia-cuda-nvrtc-cu11==11.8.89
168 | - nvidia-cuda-nvrtc-cu12==12.4.127
169 | - nvidia-cuda-runtime-cu11==11.8.89
170 | - nvidia-cuda-runtime-cu12==12.4.127
171 | - nvidia-cudnn-cu11==8.7.0.84
172 | - nvidia-cudnn-cu12==9.1.0.70
173 | - nvidia-cufft-cu11==10.9.0.58
174 | - nvidia-cufft-cu12==11.2.1.3
175 | - nvidia-curand-cu11==10.3.0.86
176 | - nvidia-curand-cu12==10.3.5.147
177 | - nvidia-cusolver-cu11==11.4.1.48
178 | - nvidia-cusolver-cu12==11.6.1.9
179 | - nvidia-cusparse-cu11==11.7.5.86
180 | - nvidia-cusparse-cu12==12.3.1.170
181 | - nvidia-ml-py==12.560.30
182 | - nvidia-nccl-cu11==2.20.5
183 | - nvidia-nccl-cu12==2.21.5
184 | - nvidia-nvjitlink-cu12==12.4.127
185 | - nvidia-nvtx-cu11==11.8.86
186 | - nvidia-nvtx-cu12==12.4.127
187 | - openai==1.60.1
188 | - opencv-python==4.11.0.86
189 | - opendatalab==0.0.10
190 | - openmim==0.3.9
191 | - openpyxl==3.1.5
192 | - openxlab==0.0.11
193 | - opt-einsum==3.4.0
194 | - ordered-set==4.1.0
195 | - orjson==3.10.15
196 | - pandas==2.2.3
197 | - peft==0.12.0
198 | - pillow==11.1.0
199 | - prettytable==3.12.0
200 | - propcache==0.2.1
201 | - protobuf==5.29.3
202 | - py-cpuinfo==9.0.0
203 | - pyarrow==19.0.0
204 | - pycocoevalcap==1.2
205 | - pycocotools==2.0.8
206 | - pycryptodome==3.21.0
207 | - pydantic==2.10.6
208 | - pydantic-core==2.27.2
209 | - pydeck==0.9.1
210 | - pydub==0.25.1
211 | - pynvml==12.0.0
212 | - pyparsing==3.2.1
213 | - pysocks==1.7.1
214 | - python-dotenv==1.0.1
215 | - python-multipart==0.0.20
216 | - pytz==2024.2
217 | - pyyaml==6.0.2
218 | - referencing==0.36.1
219 | - regex==2024.11.6
220 | - requests==2.32.3
221 | - requests-toolbelt==1.0.0
222 | - rich==13.9.4
223 | - rpds-py==0.22.3
224 | - safetensors==0.5.2
225 | - scenepic==1.1.1
226 | - scikit-learn==1.6.1
227 | - scipy==1.15.1
228 | - seaborn==0.13.2
229 | - semantic-version==2.10.0
230 | - sentencepiece==0.1.99
231 | - sentry-sdk==2.20.0
232 | - setproctitle==1.3.4
233 | - shortuuid==1.0.13
234 | - smmap==5.0.2
235 | - sniffio==1.3.1
236 | - soupsieve==2.6
237 | - sqlalchemy==2.0.37
238 | - starlette==0.45.2
239 | - streamlit==1.41.1
240 | - streamlit-image-select==0.6.0
241 | - supervision==0.25.1
242 | - svgwrite==1.4.3
243 | - sympy==1.13.1
244 | - tabulate==0.9.0
245 | - tenacity==8.5.0
246 | - tensorboard==2.18.0
247 | - tensorboard-data-server==0.7.2
248 | - tensorboardx==2.6.2.2
249 | - termcolor==2.5.0
250 | - threadpoolctl==3.5.0
251 | - tiktoken==0.8.0
252 | - timm==0.9.12
253 | - tokenizers==0.15.1
254 | - toml==0.10.2
255 | - tomli==2.2.1
256 | - toolz==1.0.0
257 | - torch==2.3.1+cu118
258 | - torchaudio==2.3.1+cu118
259 | - torchvision==0.18.1+cu118
260 | - tqdm==4.67.1
261 | - transformers==4.37.2
262 | - triton==2.3.1
263 | - tzdata==2025.1
264 | - uc-micro-py==1.0.3
265 | - ultralytics==8.3.67
266 | - ultralytics-thop==2.0.14
267 | - urllib3==2.3.0
268 | - uvicorn==0.34.0
269 | - wandb==0.19.4
270 | - watchdog==6.0.0
271 | - wavedrom==2.0.3.post3
272 | - websockets==14.2
273 | - werkzeug==3.1.3
274 | - widgetsnbextension==4.0.13
275 | - yacs==0.1.8
276 | - yapf==0.40.1
277 | - yarl==1.18.3
278 | - zstandard==0.23.0
279 | prefix: /home/runsenxu/miniconda3/envs/data_engine
280 |
--------------------------------------------------------------------------------
/spatial_engine/camera_movement/calculate_frames_relations.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | In total, with D5 scene info of ScanNet
9 | for train, we have 82M records, with 35M nonzero records. (82654914 and 35063709)
10 | for val, we have 24M records, with 10M nonzero records. (24117039 and 10156707)
11 | """
12 |
13 | import numpy as np
14 | from tqdm import tqdm
15 | import torch
16 | import mmengine
17 |
18 | from spatial_engine.utils.scannet_utils.handler.info_handler import SceneInfoHandler
19 | from multiprocessing import Pool
20 | import pandas as pd
21 | import os
22 |
23 | # set random seed
24 | np.random.seed(0)
25 |
26 | DEBUG = False # Set to True for debug mode.
27 |
28 | def save_overlap_info(overlap_info, parquet_file):
29 | """
30 | Convert a nested dictionary into a DataFrame and save it as Parquet.
31 | The nested dictionary is expected in the form:
32 |
33 | overlap_info[scene_id][(image_id1, image_id2)] = {
34 | 'overlap': ...,
35 | 'distance': ...,
36 | 'yaw': ...,
37 | 'pitch': ...
38 | }
39 | """
40 | rows = []
41 | for scene_id, pair_dict in overlap_info.items():
42 | for (img1, img2), vals in pair_dict.items():
43 | rows.append({
44 | 'scene_id': scene_id,
45 | 'image_id1': img1,
46 | 'image_id2': img2,
47 | 'overlap': vals['overlap'],
48 | 'distance': vals['distance'],
49 | 'yaw': vals['yaw'],
50 | 'pitch': vals['pitch']
51 | })
52 | if not rows:
53 | print(f"[save_overlap_info] Nothing to save to {parquet_file}.")
54 | return
55 | df = pd.DataFrame(rows)
56 | df.to_parquet(parquet_file, index=False)
57 | print(f"[save_overlap_info] Saved {len(df)} records to {parquet_file}.")
58 |
59 | def save_overlap_info_nonzero(overlap_info, parquet_file_nonzero):
60 | """
61 | Similar to save_overlap_info, but filter out any rows where overlap == 0.
62 | """
63 | rows = []
64 | for scene_id, pair_dict in overlap_info.items():
65 | for (img1, img2), vals in pair_dict.items():
66 | # Only collect if overlap != 0
67 | if vals['overlap'] != 0.0:
68 | rows.append({
69 | 'scene_id': scene_id,
70 | 'image_id1': img1,
71 | 'image_id2': img2,
72 | 'overlap': vals['overlap'],
73 | 'distance': vals['distance'],
74 | 'yaw': vals['yaw'],
75 | 'pitch': vals['pitch']
76 | })
77 |
78 | if not rows:
79 | print(f"[save_overlap_info_nonzero] No nonzero-overlap pairs to save.")
80 | return
81 |
82 | df = pd.DataFrame(rows)
83 | df.to_parquet(parquet_file_nonzero, index=False)
84 | print(f"[save_overlap_info_nonzero] Saved {len(df)} records to {parquet_file_nonzero}.")
85 |
86 | def extract_yaw_pitch(R):
87 | """
88 | Extract the yaw and pitch angles from a rotation matrix.
89 |
90 | :param R: A 4x4 or 3x3 matrix. If 4x4, only the top-left 3x3 is used.
91 | :return: A tuple (yaw, pitch) in degrees.
92 | """
93 | R3 = R[:3, :3] if R.shape == (4, 4) else R
94 | rotated_z_axis = R3[:, 2]
95 |
96 | # Yaw: arctan2 of (y, x)
97 | yaw = np.degrees(np.arctan2(rotated_z_axis[1], rotated_z_axis[0]))
98 | # Pitch: arcsin of z component
99 | pitch = np.degrees(np.arcsin(rotated_z_axis[2] / np.linalg.norm(rotated_z_axis)))
100 | return yaw, pitch
101 |
102 | def calculate_camera_overlap(in_bounds_dict, image_id1, image_id2, use_cuda=False):
103 | """
104 | Calculate the percentage of overlap in the field of view of two cameras using precomputed in-bounds points.
105 |
106 | :param in_bounds_dict: Dictionary containing in-bounds information for each image ID
107 | :param image_id1: Image ID of the first camera
108 | :param image_id2: Image ID of the second camera
109 | :return: Percentage of overlap
110 | """
111 | in_bounds1 = in_bounds_dict[image_id1]
112 | in_bounds2 = in_bounds_dict[image_id2]
113 |
114 | if torch.cuda.is_available() and use_cuda:
115 | # Move data to GPU
116 | in_bounds1 = torch.from_numpy(in_bounds1).to('cuda')
117 | in_bounds2 = torch.from_numpy(in_bounds2).to('cuda')
118 |
119 | # Points that are visible in at least one of the cameras
120 | visible_points_union = torch.logical_or(in_bounds1, in_bounds2)
121 |
122 | # Points that are visible in both cameras
123 | overlap_points = torch.logical_and(in_bounds1, in_bounds2)
124 |
125 | # Calculate the overlap percentage
126 | overlap_percentage = torch.sum(overlap_points).float() / torch.sum(visible_points_union).float() * 100
127 | return overlap_percentage.item() # Move result back to CPU and convert to Python float
128 | else:
129 | # Points that are visible in at least one of the cameras
130 | visible_points_union = np.logical_or(in_bounds1, in_bounds2)
131 |
132 | # Points that are visible in both cameras
133 | overlap_points = np.logical_and(in_bounds1, in_bounds2)
134 |
135 | # Calculate the overlap percentage
136 | overlap_percentage = np.sum(overlap_points) / np.sum(visible_points_union) * 100
137 | return overlap_percentage
138 |
139 | def process_scene(scene_id, scene_infos: SceneInfoHandler, warning_file):
140 | print(f"Start processing {scene_id}.")
141 | image_ids = scene_infos.get_all_extrinsic_valid_image_ids(scene_id)
142 |
143 | # get all points in the scene
144 | scene_points = scene_infos.get_scene_points_align(scene_id)
145 | scene_points = scene_points[:, :3]
146 |
147 | # Precompute in-bounds points for each image
148 | in_bounds_dict = {}
149 | yaw_dict = {}
150 | pitch_dict = {}
151 | positions_dict = {}
152 | for image_id in image_ids:
153 | E = scene_infos.get_extrinsic_matrix_align(scene_id, image_id)
154 |
155 | # project points to camera
156 | scene_points_2d, scene_points_depth = scene_infos.project_3d_point_to_image(scene_id, image_id, scene_points)
157 | in_bounds_points = scene_infos.check_point_visibility(scene_id, image_id, scene_points_2d, scene_points_depth) # * a True or False mask
158 |
159 | if np.sum(in_bounds_points) == 0:
160 | with open(warning_file, 'a') as f:
161 | f.write(f"{scene_id}: {image_id} has no in bound points\n")
162 |
163 | in_bounds_dict[image_id] = in_bounds_points
164 |
165 | # get yaw and pitch
166 | yaw, pitch = extract_yaw_pitch(E)
167 |
168 | yaw_dict[image_id] = yaw
169 | pitch_dict[image_id] = pitch
170 |
171 | # positions
172 | positions_dict[image_id] = E[:3, 3]
173 |
174 | scene_overlap_info = {}
175 |
176 | for i, image_id1 in enumerate(image_ids):
177 | for j in range(i + 1, len(image_ids)):
178 | image_id2 = image_ids[j]
179 | overlap_percentage = calculate_camera_overlap(in_bounds_dict, image_id1, image_id2)
180 |
181 | yaw_diff = yaw_dict[image_id2] - yaw_dict[image_id1]
182 | pitch_diff = pitch_dict[image_id2] - pitch_dict[image_id1]
183 | distance = np.linalg.norm(positions_dict[image_id2] - positions_dict[image_id1])
184 |
185 | scene_overlap_info[(image_id1, image_id2)] = {}
186 | scene_overlap_info[(image_id1, image_id2)]['overlap'] = overlap_percentage
187 | scene_overlap_info[(image_id1, image_id2)]['distance'] = distance
188 | scene_overlap_info[(image_id1, image_id2)]['yaw'] = yaw_diff
189 | scene_overlap_info[(image_id1, image_id2)]['pitch'] = pitch_diff
190 |
191 | if np.any(np.isnan(list(scene_overlap_info[(image_id1, image_id2)].values()))) or \
192 | np.any(np.isinf(list(scene_overlap_info[(image_id1, image_id2)].values()))):
193 | with open(warning_file, 'a') as f:
194 | f.write(f"{scene_id}: {(image_id1, image_id2)} has something wrong {list(scene_overlap_info[(image_id1, image_id2)].values())}. \n")
195 |
196 | print(f"Finished scene {scene_id}.")
197 | return scene_id, scene_overlap_info
198 |
199 |
200 | def run_split(scene_info_path, output_parquet, warning_file, num_workers=15, save_interval=20):
201 | """
202 | 1. Instantiate SceneInfoHandler for the given `scene_info_path`.
203 | 2. Load existing overlap info from `output_parquet`.
204 | 3. Process each scene in parallel and accumulate results in a dictionary.
205 | 4. Periodically save to parquet + nonzero parquet, then final save at the end.
206 | """
207 | scene_infos = SceneInfoHandler(scene_info_path)
208 | overlap_info = {}
209 |
210 | # Gather all scenes
211 | all_scene_ids = scene_infos.get_all_scene_ids()
212 | print(f"[run_split] Found {len(all_scene_ids)} scenes in {scene_info_path}.")
213 |
214 | # If we're debugging, only process the first scene
215 | if DEBUG and len(all_scene_ids) > 1:
216 | all_scene_ids = all_scene_ids[:1]
217 | print("[run_split] DEBUG mode: processing only the first scene.")
218 |
219 | # Prepare arguments
220 | args = [(scene_id, scene_infos, warning_file) for scene_id in all_scene_ids]
221 |
222 | with Pool(num_workers) as pool:
223 | results = []
224 | for arg in args:
225 | results.append(pool.apply_async(process_scene, arg))
226 |
227 | for count, r in enumerate(tqdm(results, desc=f"Processing {scene_info_path}")):
228 | scene_id, scene_overlap_info = r.get()
229 | overlap_info[scene_id] = scene_overlap_info
230 |
231 | # Save partial results every 'save_interval' scenes
232 | if (count + 1) % save_interval == 0:
233 | save_overlap_info(overlap_info, output_parquet)
234 |
235 | # Also save nonzero
236 | nonzero_parquet = output_parquet.replace(".parquet", "_nonzero.parquet")
237 | save_overlap_info_nonzero(overlap_info, nonzero_parquet)
238 |
239 | print(f"[run_split] Saved partial results for {count + 1} scenes to {output_parquet}")
240 |
241 | # Final save
242 | save_overlap_info(overlap_info, output_parquet)
243 | nonzero_parquet = output_parquet.replace(".parquet", "_nonzero.parquet")
244 | save_overlap_info_nonzero(overlap_info, nonzero_parquet)
245 | print(f"[run_split] Final save to {output_parquet} complete.")
246 | # print total number of records
247 | total_records = sum(len(v) for v in overlap_info.values())
248 | print(f"[run_split] Total number of records: {total_records}")
249 |
250 | print(f"[run_split] Nonzero overlap also saved to {nonzero_parquet}.")
251 | # need to iterate over the overlap_info to get the number of nonzero records
252 | nonzero_records = sum(1 for scene in overlap_info.values() for pair in scene.values() if pair['overlap'] != 0.0)
253 | print(f"[run_split] Total number of nonzero records: {nonzero_records}")
254 |
255 | def main():
256 | # Adjust these paths if needed
257 | train_info_path = "data/scannet/scannet_instance_data/scenes_train_info_i_D5.pkl"
258 | val_info_path = "data/scannet/scannet_instance_data/scenes_val_info_i_D5.pkl"
259 |
260 | train_output_dir = "training_data/camera_movement"
261 | val_output_dir = "evaluation_data/camera_movement"
262 |
263 | mmengine.mkdir_or_exist(train_output_dir)
264 | mmengine.mkdir_or_exist(val_output_dir)
265 |
266 | train_output_file = os.path.join(train_output_dir, "train_camera_info_D5.parquet")
267 | val_output_file = os.path.join(val_output_dir, "val_camera_info_D5.parquet")
268 |
269 | # Warnings
270 | train_warning_file = os.path.join(train_output_dir, "train_warning_D5.txt")
271 | val_warning_file = os.path.join(val_output_dir, "val_warning_D5.txt")
272 |
273 | # If DEBUG is true, add a suffix to file names
274 | if DEBUG:
275 | train_output_file = train_output_file.replace(".parquet", "_debug.parquet")
276 | val_output_file = val_output_file.replace(".parquet", "_debug.parquet")
277 | train_warning_file = train_warning_file.replace(".txt", "_debug.txt")
278 | val_warning_file = val_warning_file.replace(".txt", "_debug.txt")
279 |
280 | num_workers = 25
281 |
282 | print(f"[main] DEBUG mode: {DEBUG}")
283 | print(f"[main] Processing train split -> {train_output_file}")
284 | run_split(train_info_path, train_output_file, train_warning_file, num_workers=num_workers, save_interval=20)
285 |
286 | print(f"[main] Processing val split -> {val_output_file}")
287 | run_split(val_info_path, val_output_file, val_warning_file, num_workers=num_workers, save_interval=20)
288 |
289 | if __name__ == "__main__":
290 | main()
291 |
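292 | # -----------------------------------------------------------------------------
293 | # Illustrative sketch (not part of the pipeline above): calculate_camera_overlap
294 | # computes an intersection-over-union of the two boolean in-bounds masks and
295 | # reports it as a percentage. The toy masks below are made up for illustration.
296 | # -----------------------------------------------------------------------------
297 | def _overlap_percentage_demo():
298 |     import numpy as np  # repeated here only to keep the sketch self-contained
299 |     in_bounds1 = np.array([True, True, False, False])   # points visible to camera 1
300 |     in_bounds2 = np.array([True, False, True, False])   # points visible to camera 2
301 |     union = np.logical_or(in_bounds1, in_bounds2).sum()          # 3 points seen by either camera
302 |     intersection = np.logical_and(in_bounds1, in_bounds2).sum()  # 1 point seen by both cameras
303 |     return intersection / union * 100                            # -> ~33.3 (percent overlap)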
--------------------------------------------------------------------------------
/spatial_engine/camera_movement/camera_movement_engine_train_val.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Multi-round QA about qualitative, quantitative, and vector descriptions of camera movement (distance and two rotation angles).
9 | 
10 | For questions, we have 30 templates. For answers and task descriptions, we have 10 templates.
11 | """
12 |
13 | import pandas as pd
14 | from tqdm import tqdm
15 | import numpy as np
16 | import random
17 | random.seed(0)
18 | np.random.seed(0)
19 | import json
20 | from spatial_engine.utils.scannet_utils.handler.info_handler import SceneInfoHandler
21 | import os
22 | import mmengine
23 | import sys
24 |
25 | sys.path.append("spatial_engine/camera_movement")
26 | from TEMPLATES import QUESTION_TEMPLATES, ANSWER_TEMPLATES, TASK_DESCRIPTION
27 |
28 |
29 | def sample_dataframe(df, all_overlap_samples, non_overlap_samples,
30 | overlap_min=0, overlap_max=100, interval=1):
31 | """
32 | Sample from the input DataFrame, aiming to collect a total of all_overlap_samples
33 | from bins where (overlap != 0). If a bin doesn't have enough samples,
34 | the remaining quota will be passed to the next bin, and so on.
35 |
36 | :param df: Input DataFrame, must contain an 'overlap' column
37 | :param all_overlap_samples: Total desired samples from overlap>0 bins
38 | :param non_overlap_samples: Number of samples to take where overlap == 0
39 | :param overlap_min: Minimum overlap value, default is 0
40 | :param overlap_max: Maximum overlap value, default is 100
41 | :param interval: Bin interval step size, default is 1
42 | :return: Sampled DataFrame
43 | """
44 |
45 | # -----------------------------------------------------------
46 | # 1) overlap == 0: Sample these separately first
47 | # -----------------------------------------------------------
48 | non_overlap_df = df[df["overlap"] == 0].copy()
49 | if len(non_overlap_df) <= non_overlap_samples:
50 | sampled_non_overlap_df = non_overlap_df
51 | else:
52 | sampled_non_overlap_df = non_overlap_df.sample(n=non_overlap_samples)
53 |
54 | # Remaining data (overlap != 0)
55 | remaining_df = df[df["overlap"] != 0].copy()
56 |
57 | # -----------------------------------------------------------
58 | # 2) Group by overlap into bins
59 | # -----------------------------------------------------------
60 | bins = np.arange(overlap_min, overlap_max + interval, interval)
61 | remaining_df["overlap_group"] = pd.cut(
62 | remaining_df["overlap"],
63 | bins=bins,
64 | include_lowest=True
65 | )
66 |
67 | # Exclude rows not assigned to any bin (NaN)
68 | remaining_df = remaining_df.dropna(subset=["overlap_group"])
69 |
70 | # Collect data for each bin
71 | bin_dfs = []
72 | for interval_bin, group_df in remaining_df.groupby("overlap_group"):
73 | bin_dfs.append((interval_bin, group_df))
74 |
75 | if len(bin_dfs) == 0:
76 | # If no bins exist, return only the overlap==0 samples
77 | final_sampled_df = sampled_non_overlap_df
78 | if "overlap_group" in final_sampled_df.columns:
79 | final_sampled_df.drop(columns=["overlap_group"], inplace=True)
80 | return final_sampled_df
81 |
82 | # -----------------------------------------------------------
83 | # 3) Distribute all_overlap_samples evenly across bins
84 | # and handle the remainder
85 | # -----------------------------------------------------------
86 | N = len(bin_dfs)
87 | base_quota = all_overlap_samples // N
88 | remainder = all_overlap_samples % N
89 |
90 | # bin_quotas[i] = base_quota, with first remainder bins getting +1
91 | bin_quotas = [base_quota] * N
92 | for i in range(remainder):
93 | bin_quotas[i] += 1
94 |
95 | # -----------------------------------------------------------
96 | # 4) Sort bins by data volume (small to large), not by overlap interval
97 | # -----------------------------------------------------------
98 | # Key for sorting bin_dfs: len(group_df)
99 | # bin_dfs[i] = (interval_bin, group_df)
100 |     # bin_quotas[i] = base_quota or base_quota + 1
101 | # To maintain one-to-one correspondence, we zip bin_dfs and bin_quotas together and sort
102 | bin_data = []
103 | for i, (interval_bin, group_df) in enumerate(bin_dfs):
104 | bin_data.append({
105 | "interval_bin": interval_bin,
106 | "group_df": group_df,
107 | "quota": bin_quotas[i],
108 | "size": len(group_df) # for sorting
109 | })
110 |
111 | # Sort by size in ascending order
112 | bin_data.sort(key=lambda x: x["size"])
113 |
114 | # -----------------------------------------------------------
115 | # 5) Process each bin, using leftover_quota mechanism
116 | # -----------------------------------------------------------
117 | sampled_df = pd.DataFrame()
118 | leftover_quota = 0
119 |
120 | for bin_info in bin_data:
121 | group_df = bin_info["group_df"]
122 | bin_quota = bin_info["quota"]
123 |
124 | current_quota = bin_quota + leftover_quota
125 | print(f"[sample_dataframe] current_quota v.s. group_df: {current_quota} v.s. {len(group_df)}")
126 |
127 | if len(group_df) <= current_quota:
128 | # Not enough for quota, take all
129 | sampled_df = pd.concat([sampled_df, group_df], ignore_index=True)
130 | # leftover = current_quota - actual amount taken
131 | leftover_quota = current_quota - len(group_df)
132 | else:
133 | # More than quota, randomly sample current_quota
134 | sampled_part = group_df.sample(n=current_quota)
135 | sampled_df = pd.concat([sampled_df, sampled_part], ignore_index=True)
136 | leftover_quota = 0
137 |
138 | # If leftover_quota > 0, data is insufficient
139 | if leftover_quota > 0:
140 | print(f"[sample_dataframe] Warning: bins not enough to reach {all_overlap_samples}; leftover {leftover_quota}")
141 |
142 | # -----------------------------------------------------------
143 | # 6) Merge with overlap==0 samples
144 | # -----------------------------------------------------------
145 | final_sampled_df = pd.concat([sampled_df, sampled_non_overlap_df], ignore_index=True)
146 |
147 | # Cleanup
148 | if "overlap_group" in final_sampled_df.columns:
149 | final_sampled_df.drop(columns=["overlap_group"], inplace=True, errors="ignore")
150 |
151 | return final_sampled_df
152 |
153 | def build_training_sample(scene_infos: SceneInfoHandler, row, idx: int, question_type: str):
154 | scene_id = row["scene_id"]
155 | image1 = row["image_id1"]
156 | image2 = row["image_id2"]
157 |
158 | overlap = float(row["overlap"])
159 | yaw_angle = float(row["yaw"])
160 | pitch_angle = float(row["pitch"])
161 |
162 |     # randomly decide whether to swap image1 and image2
163 | if random.random() < 0.5:
164 | yaw_angle = -yaw_angle
165 | pitch_angle = -pitch_angle
166 | image1, image2 = image2, image1
167 |
168 | if abs(yaw_angle) > 180:
169 | if yaw_angle > 0:
170 | yaw_angle = yaw_angle - 360
171 | else:
172 | yaw_angle = yaw_angle + 360
173 |
174 |
175 | images = [f"{scene_id}/{image1}.jpg", f"{scene_id}/{image2}.jpg"]
176 |
177 | E1 = scene_infos.get_extrinsic_matrix_align(scene_id, image1) # camera to world
178 | E2 = scene_infos.get_extrinsic_matrix_align(scene_id, image2)
179 |
180 | # assert no nan
181 | assert not np.isnan(E1).any(), f"E1 is nan for {scene_id} {image1}"
182 | assert not np.isnan(E2).any(), f"E2 is nan for {scene_id} {image2}"
183 |
184 | # Transform E2 into the coordinate system of E1
185 | E1_inv = np.linalg.inv(E1)
186 | E2_relative = E1_inv @ E2
187 |
188 | # calculate the displacement vector in the first frame's coordinate system
189 | displacement_vector = E2_relative[:3, 3]
190 | distance = np.linalg.norm(displacement_vector)
191 |
192 |     # should be close to the distance recorded in the dataframe
193 | assert abs(distance - row['distance']) < 0.1, f"distance is not close to the distance from df for {scene_id} {image1} {image2}."
194 |
195 | # the output format should be one item:
196 | # {"id": 1358431, "image": ["scene0006_01/01550.jpg", "scene0006_01/01245.jpg"], "conversations": [{"from": "human", "value": "Image-1: \nImage-2: \nAssume the scene remains unchanged. Your task is to perceive the spatial information of the scene based on the captured images. Calculate the distance (in mm) separating the cameras of images and ."}, {"from": "gpt", "value": "The difference is about `1150`."}], "height_list": [968, 968], "width_list": [1296, 1296]}
197 | task_description = random.choice(TASK_DESCRIPTION)
198 |
199 | if overlap < 0.1:
200 | # random select from q1, q2, q3
201 | raise NotImplementedError("overlap < 0.1 is not supported yet.")
202 | else:
203 | question = random.choice(QUESTION_TEMPLATES[question_type])
204 | answer_template = random.choice(ANSWER_TEMPLATES[question_type])
205 |
206 | # replace the placeholder with the actual values
207 | # for movement, need to use > 0 to get left/right, 'forward'/'backward', 'up'/'down'
208 | # x -> left/right, y -> up/down, z -> forward/backward
209 | answer_values = {
210 | "x_movement": "right" if displacement_vector[0] > 0 else "left",
211 | "y_movement": "down" if displacement_vector[1] > 0 else "up",
212 | "z_movement": "forward" if displacement_vector[2] > 0 else "backward",
213 | "yaw_movement": "left" if yaw_angle > 0 else "right",
214 | "pitch_movement": "up" if pitch_angle > 0 else "down",
215 | "x_distance": int(abs(displacement_vector[0]) * 1000),
216 | "y_distance": int(abs(displacement_vector[1]) * 1000),
217 | "z_distance": int(abs(displacement_vector[2]) * 1000),
218 | "yaw_angle": int(abs(yaw_angle)),
219 | "pitch_angle": int(abs(pitch_angle)),
220 | "x_value": int(displacement_vector[0] * 1000),
221 | "y_value": int(displacement_vector[1] * 1000),
222 | "z_value": int(displacement_vector[2] * 1000),
223 | "total_distance": int(np.linalg.norm(displacement_vector) * 1000),
224 | "displacement_vector": displacement_vector.tolist(),
225 | }
226 | # map
227 | answer_text = answer_template.format(**answer_values)
228 |
229 | conversation = [
230 | {"from": "human", "value": f"{task_description}\n{question}"},
231 | {"from": "gpt", "value": answer_text},
232 | ]
233 |
234 | train_sample = {
235 | "id": idx,
236 | "image": images,
237 | "conversations": conversation,
238 | "height_list": [scene_infos.get_image_shape(scene_id, image1)[0]] * len(images),
239 | "width_list": [scene_infos.get_image_shape(scene_id, image1)[1]] * len(images),
240 | "answer_values": answer_values,
241 | "question_type": question_type,
242 | "gt_value": answer_values[question_type],
243 | }
244 |
245 | return train_sample
246 |
247 | def convert_train_sample_to_eval_sample(train_sample):
248 | # a train sample looks like this:
249 | # {
250 | # "id": f"{scene_id}_{object_id}_{pair_id}",
251 | # "image": [f"{scene_id}/{image1}.jpg", f"{scene_id}/{image2}.jpg"],
252 | # "conversations": conversation,
253 | # "height_list": [image_height] * 2,
254 | # "width_list": [image_width] * 2,
255 | # "answer_values": answer_values,
256 | # "question_type": "visual_correspondence"
257 | # }
258 | # we need a eval sample like:
259 | # {
260 | # "id": f"{scene_id}_{object_id}_{pair_id}",
261 | # "image": [f"{scene_id}/{image1}.jpg", f"{scene_id}/{image2}.jpg"],
262 | # "text": question,
263 | # "gt_value": value,
264 | # "question_type": attr,
265 | # }
266 | conversation = train_sample.pop("conversations")
267 | train_sample['text'] = conversation[0]['value']
268 |
269 | return train_sample
270 |
271 | def build_train_dataset(
272 | parquet_path,
273 | output_dir,
274 | scene_infos,
275 | qtype,
276 | desired_count,
277 | overlap_min,
278 | overlap_max,
279 | interval
280 | ):
281 | df = pd.read_parquet(parquet_path)
282 | print(f"[Train: {qtype}] Loaded DF with {len(df)} rows from {parquet_path}")
283 |
284 | print(f"[Train: {qtype}] sampling {desired_count} samples in overlap=[{overlap_min}..{overlap_max}]")
285 |
286 | df_sampled = sample_dataframe(
287 | df,
288 | all_overlap_samples=desired_count,
289 | non_overlap_samples=0, # or your chosen number
290 | overlap_min=overlap_min,
291 | overlap_max=overlap_max,
292 | interval=interval
293 | )
294 | print(f"[Train: {qtype}] got {len(df_sampled)} sampled rows")
295 |
296 | # build samples
297 | out_samples = []
298 | for idx in tqdm(range(len(df_sampled)), desc=f"{qtype}"):
299 | row = df_sampled.iloc[idx]
300 | s = build_training_sample(scene_infos, row, idx, qtype)
301 | out_samples.append(s)
302 |
303 | random.shuffle(out_samples)
304 | out_file = os.path.join(output_dir, f"{qtype}_train.jsonl")
305 | print(f"[Train: {qtype}] writing {len(out_samples)} items to {out_file}")
306 | with open(out_file, "w") as f:
307 | for item in out_samples:
308 | f.write(json.dumps(item)+"\n")
309 |
310 | ########################################################################
311 | # Build val dataset for one question type
312 | ########################################################################
313 |
314 | def build_val_dataset(
315 | parquet_path,
316 | output_dir,
317 | scene_infos,
318 | qtype,
319 | desired_count,
320 | overlap_min,
321 | overlap_max,
322 | interval
323 | ):
324 | df = pd.read_parquet(parquet_path)
325 | print(f"[Val: {qtype}] Loaded DF with {len(df)} rows from {parquet_path}")
326 |
327 | # same sampling logic
328 | print(f"[Val: {qtype}] sampling {desired_count} samples in overlap=[{overlap_min}..{overlap_max}]")
329 |
330 | df_sampled = sample_dataframe(
331 | df,
332 | all_overlap_samples=desired_count,
333 | non_overlap_samples=0,
334 | overlap_min=overlap_min,
335 | overlap_max=overlap_max,
336 | interval=interval
337 | )
338 | print(f"[Val: {qtype}] got {len(df_sampled)} sampled rows")
339 |
340 | # build as "train" but then convert
341 | out_samples = []
342 | for idx in tqdm(range(len(df_sampled)), desc=f"{qtype}_val"):
343 | row = df_sampled.iloc[idx]
344 | s_train = build_training_sample(scene_infos, row, idx, qtype)
345 | s_eval = convert_train_sample_to_eval_sample(s_train)
346 | out_samples.append(s_eval)
347 |
348 | random.shuffle(out_samples)
349 | out_file = os.path.join(output_dir, f"{qtype}_val.jsonl")
350 | print(f"[Val: {qtype}] writing {len(out_samples)} items to {out_file}")
351 | with open(out_file, "w") as f:
352 | for item in out_samples:
353 | f.write(json.dumps(item)+"\n")
354 |
355 | ########################################################################
356 | # Main
357 | ########################################################################
358 | DEBUG = False
359 |
360 | def main():
361 | info_path = "data/scannet/scannet_instance_data/scenes_train_val_info_i_D5.pkl"
362 | overlap_min = 6
363 | overlap_max = 35
364 | interval = 1
365 |
366 | version = "v1_0"
367 |
368 | # Desired sample counts
369 | train_question_samples = {
370 | "x_movement": 1000000,
371 | "y_movement": 1000000,
372 | "z_movement": 1000000,
373 | "yaw_movement": 1000000,
374 | "pitch_movement": 1000000,
375 | "yaw_angle": 1000000,
376 | "pitch_angle": 1000000,
377 | "total_distance": 3000000,
378 | "displacement_vector": 3000000,
379 | }
380 | val_question_samples = {
381 | "x_movement": 300,
382 | "y_movement": 300,
383 | "z_movement": 300,
384 | "yaw_movement": 300,
385 | "pitch_movement": 300,
386 | "total_distance": 300,
387 | "yaw_angle": 300,
388 | "pitch_angle": 300,
389 | "displacement_vector": 300,
390 | }
391 |
392 | # Debug overrides
393 | global DEBUG
394 | if DEBUG:
395 | train_parquet_path = "training_data/camera_movement/train_camera_info_D5_debug_nonzero.parquet"
396 | val_parquet_path = "evaluation_data/camera_movement/val_camera_info_D5_debug_nonzero.parquet"
397 | for k in train_question_samples:
398 | train_question_samples[k] = 100
399 | for k in val_question_samples:
400 | val_question_samples[k] = 100
401 | version += "_debug"
402 | else:
403 | train_parquet_path = "training_data/camera_movement/train_camera_info_D5.parquet"
404 | val_parquet_path = "evaluation_data/camera_movement/val_camera_info_D5.parquet"
405 |
406 | train_output_dir = f"training_data/camera_movement/{version}"
407 | val_output_dir = f"evaluation_data/camera_movement/{version}"
408 |
409 | mmengine.mkdir_or_exist(train_output_dir)
410 | mmengine.mkdir_or_exist(val_output_dir)
411 |
412 | # Initialize SceneInfoHandler
413 | scene_infos = SceneInfoHandler(info_path)
414 |
415 | # Build train & val for each question type
416 | for qtype in train_question_samples.keys():
417 | print(f"\n=== Processing question type: {qtype} ===")
418 |         # Each type takes about 4 minutes to generate 1M samples
419 |
420 | # Build val
421 | build_val_dataset(
422 | parquet_path=val_parquet_path,
423 | output_dir=val_output_dir,
424 | scene_infos=scene_infos,
425 | qtype=qtype,
426 | desired_count=val_question_samples[qtype],
427 | overlap_min=overlap_min,
428 | overlap_max=overlap_max,
429 | interval=interval
430 | )
431 |
432 | # Build train
433 | build_train_dataset(
434 | parquet_path=train_parquet_path,
435 | output_dir=train_output_dir,
436 | scene_infos=scene_infos,
437 | qtype=qtype,
438 | desired_count=train_question_samples[qtype],
439 | overlap_min=overlap_min,
440 | overlap_max=overlap_max,
441 | interval=interval
442 | )
443 |
444 | print("All question types processed. Done.")
445 |
446 | if __name__ == "__main__":
447 | main()
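448 | 
449 | # -----------------------------------------------------------------------------
450 | # Illustrative sketch (not part of the dataset build above): how
451 | # build_training_sample derives the relative camera motion. The 4x4
452 | # camera-to-world matrices and the yaw value below are made up; the real ones
453 | # come from SceneInfoHandler and the precomputed parquet file.
454 | # -----------------------------------------------------------------------------
455 | def _relative_pose_demo():
456 |     E1 = np.eye(4)                        # camera 1 pose; identity, so its frame equals the world frame
457 |     E2 = np.eye(4)
458 |     E2[:3, 3] = [0.5, 0.0, 1.2]           # camera 2 translated (x -> right, z -> forward convention above)
459 |     E2_relative = np.linalg.inv(E1) @ E2  # camera-2 pose expressed in camera-1's coordinate system
460 |     displacement = E2_relative[:3, 3]     # [0.5, 0.0, 1.2]
461 |     distance_mm = int(np.linalg.norm(displacement) * 1000)  # ~1300 mm
462 |     yaw = 350.0                           # example raw yaw difference in degrees
463 |     if abs(yaw) > 180:                    # wrap into (-180, 180], as done above
464 |         yaw = yaw - 360 if yaw > 0 else yaw + 360
465 |     return displacement, distance_mm, yaw  # -> (array([0.5, 0. , 1.2]), 1300, -10.0)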
--------------------------------------------------------------------------------
/spatial_engine/depth_perception/depth_estimation_dot_engine.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import json
8 | import random
9 | import os
10 | from tqdm import tqdm
11 | import mmengine
12 | import cv2
13 | import numpy
14 | # set seed
15 | numpy.random.seed(5)
16 | random.seed(5)
17 |
18 | from spatial_engine.utils.scannet_utils.handler.info_handler import SceneInfoHandler, VisibilityInfoHandler
19 |
20 | import argparse
21 |
22 | def generate_distinct_colors(n, max_retries=10):
23 | colors = []
24 | retries = 0
25 | while len(colors) < n and retries < max_retries:
26 | color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
27 | if all(sum(abs(c1 - c2) for c1, c2 in zip(color, existing_color)) > 300 for existing_color in colors):
28 | colors.append(color)
29 | retries += 1
30 | if len(colors) < n:
31 | predefined_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (0, 0, 0), (255, 255, 255)] # Red, Green, Blue, Black, White
32 | colors += random.sample(predefined_colors, n - len(colors))
33 | return colors
34 |
35 | class DepthEstimationDotQAEngine:
36 | def __init__(self, scene_info_path,
37 | version_num="v1_0",
38 | all_max_samples=-1,
39 | image_output_dir=None,
40 | visibility_info_path=None,
41 | max_n_points_per_image=1,
42 | warning_file=None):
43 | self.scene_info = SceneInfoHandler(scene_info_path)
44 | self.version_num = version_num
45 | self.image_output_dir = image_output_dir
46 | self.all_max_samples = all_max_samples
47 | self.task_name = "depth_estimation_dot"
48 | # * note, in v1_0, max_n_points_per_image is 1
49 |         # * Even if it is set larger than 1, each point yields its own single-round QA pair, so the total number of samples is num_images * max_n_points_per_image (i.e., all_max_samples * max_n_points_per_image)
50 | self.max_n_points_per_image = max_n_points_per_image
51 | self.warning_file = warning_file
52 | # read visibility infos
53 | self.visibility_info = VisibilityInfoHandler(visibility_info_path)
54 |
55 | self.task_description = [
56 | "\nGiven an image with an annotated point, complete the question-answer task.",
57 | "\nFor an image with an annotated point, answer the depth-related questions.",
58 | "\nUsing the provided image with an annotated point, complete the QA task.",
59 | "\nGiven an image with a specific annotated point, perform the question-answer process.",
60 | "\nWork with the image that has an annotated point to answer the related questions.",
61 | "\nAnalyze the image with an annotated point and complete the QA task.",
62 | "\nGiven an image where a point is annotated, proceed with the question-answer task.",
63 | "\nFor the image with an annotated point, determine the depth-related answers.",
64 | "\nUsing the image with an annotated point, perform the QA task.",
65 | "\nGiven an image with a marked point, complete the question-answer process.",
66 | "\nWork with the image containing an annotated point to answer the questions.",
67 | "\nGiven an image with a highlighted point, complete the QA task.",
68 | "\nFor an image with a marked point, answer the depth-related questions.",
69 | "\nUsing the image with a highlighted point, perform the QA task.",
70 | "\nGiven an image with a designated point, complete the question-answer process.",
71 | "\nWork with the image that has a marked point to answer the related questions.",
72 | "\nAnalyze the image with a designated point and complete the QA task.",
73 | "\nGiven an image where a point is highlighted, proceed with the question-answer task.",
74 | "\nFor the image with a designated point, determine the depth-related answers.",
75 | "\nUsing the image with a marked point, perform the QA task.",
76 | "\nGiven an image with a pinpointed location, engage in the question-answer task.",
77 | "\nFor an image with a specified point, resolve the depth-related queries.",
78 | "\nUtilize the image with a pinpointed spot to complete the QA task.",
79 | "\nGiven an image with a noted point, carry out the question-answer process.",
80 | "\nWork with the image that has a pinpointed point to answer the related questions.",
81 | "\nExamine the image with a noted point and complete the QA task.",
82 | "\nGiven an image where a point is pinpointed, proceed with the question-answer task.",
83 | "\nFor the image with a noted point, ascertain the depth-related answers.",
84 | "\nUsing the image with a pinpointed point, perform the QA task.",
85 | "\nGiven an image with a specified point, complete the question-answer process."
86 | ]
87 |
88 | self.templates = {
89 | "questions": [
90 | "What is the depth of the annotated point in the image (in mm)?",
91 | "How far is the annotated point from the camera in millimeters?",
92 | "Determine the depth value of the annotated point in the given image (mm).",
93 | "Find the distance from the observer to the annotated point in the image, in mm.",
94 | "What is the measured depth of the annotated point in mm?",
95 | "How far away is the annotated point from the viewer in the image (mm)?",
96 | "Identify the depth value for the annotated point in millimeters.",
97 | "Given the annotated point, what is its depth in the image (mm)?",
98 | "What is the distance between the camera and the annotated point in mm?",
99 | "How deep is the annotated point in the given image (in millimeters)?",
100 | "What is the distance to the annotated point in the image (in mm)?",
101 | "How far is the annotated point located from the camera in mm?",
102 | "Determine the distance of the annotated point from the observer in mm.",
103 | "What is the depth measurement of the annotated point in the image (mm)?",
104 | "How far is the annotated point from the camera lens in millimeters?",
105 | "What is the depth of the highlighted point in the image (in mm)?",
106 | "How far is the highlighted point from the camera in millimeters?",
107 | "Determine the depth value of the highlighted point in the given image (mm).",
108 | "Find the distance from the observer to the highlighted point in the image, in mm.",
109 | "What is the measured depth of the highlighted point in mm?",
110 | "How far away is the highlighted point from the viewer in the image (mm)?",
111 | "Identify the depth value for the highlighted point in millimeters.",
112 | "Given the highlighted point, what is its depth in the image (mm)?",
113 | "What is the distance between the camera and the highlighted point in mm?",
114 | "How deep is the highlighted point in the given image (in millimeters)?",
115 | "What is the distance to the highlighted point in the image (in mm)?",
116 | "How far is the highlighted point located from the camera in mm?",
117 | "Determine the distance of the highlighted point from the observer in mm.",
118 | "What is the depth measurement of the highlighted point in the image (mm)?",
119 | "How far is the highlighted point from the camera lens in millimeters?"
120 | ],
121 |
122 | "answers": [
123 | "The depth of the annotated point is `{depth}` mm.",
124 | "It is `{depth}` mm away from the camera.",
125 | "The depth value of the annotated point is `{depth}` mm.",
126 | "The distance to the annotated point is `{depth}` mm.",
127 | "Measured depth of the annotated point is `{depth}` mm.",
128 | "The annotated point is located `{depth}` mm away from the viewer.",
129 | "The depth of the annotated point is `{depth}` mm.",
130 | "Its depth in the image is `{depth}` mm.",
131 | "The distance from the camera to the annotated point is `{depth}` mm.",
132 | "The annotated point is `{depth}` mm deep in the image.",
133 | "The depth of the highlighted point is `{depth}` mm.",
134 | "It is `{depth}` mm away from the camera.",
135 | "The depth value of the highlighted point is `{depth}` mm.",
136 | "The distance to the highlighted point is `{depth}` mm.",
137 | "Measured depth of the highlighted point is `{depth}` mm.",
138 | "The highlighted point is located `{depth}` mm away from the viewer.",
139 | "The depth of the highlighted point is `{depth}` mm.",
140 | "Its depth in the image is `{depth}` mm.",
141 | "The distance from the camera to the highlighted point is `{depth}` mm.",
142 | "The highlighted point is `{depth}` mm deep in the image.",
143 | "The depth of the marked point is `{depth}` mm.",
144 | "It is `{depth}` mm away from the camera.",
145 | "The depth value of the marked point is `{depth}` mm.",
146 | "The distance to the marked point is `{depth}` mm.",
147 | "Measured depth of the marked point is `{depth}` mm.",
148 | "The marked point is located `{depth}` mm away from the viewer.",
149 | "The depth of the marked point is `{depth}` mm.",
150 | "Its depth in the image is `{depth}` mm.",
151 | "The distance from the camera to the marked point is `{depth}` mm.",
152 | "The marked point is `{depth}` mm deep in the image."
153 | ]
154 | }
155 |
156 | def generate_random_point(self, height, width, margin=50):
157 | """Generate random point within image boundaries with margin"""
158 | x = random.randint(margin, width - margin)
159 | y = random.randint(margin, height - margin)
160 | return (x, y)
161 |
162 | def annotate_image(self, image, point):
163 | """Annotate image with a single point"""
164 | annotated_img = image.copy()
165 | x, y = point
166 |
167 | # Generate a random distinct color
168 | color = generate_distinct_colors(1)[0]
169 |
170 | # Draw circle
171 | cv2.circle(annotated_img, (x, y), 10, color, -1)
172 |
173 | return annotated_img
174 |
175 | def generate_qa_training_single_scene(self, scene_id):
176 | # Get all valid image IDs for the scene
177 | image_ids = self.scene_info.get_all_extrinsic_valid_image_ids(scene_id)
178 | scene_image_height, scene_image_width = self.scene_info.get_image_shape(scene_id)
179 |
180 | # Calculate how many images to sample from this scene
181 | if self.max_samples > 0:
182 | n_images = min(self.max_samples, len(image_ids))
183 | else:
184 | n_images = len(image_ids)
185 |
186 | # Randomly sample images
187 | sampled_image_ids = random.sample(image_ids, n_images)
188 |
189 | all_samples = []
190 | for image_id in sampled_image_ids:
191 | # * Load all the visible points in this image
192 | visible_points = self.visibility_info.get_image_to_points_info(scene_id, image_id) # [point_index, ...]
193 |
194 | # * Randomly sample a set of points from the visible points in the image
195 |             # * If there are not enough visible points, sample with replacement
196 | if len(visible_points) < self.max_n_points_per_image:
197 | sampled_points = random.choices(visible_points, k=self.max_n_points_per_image)
198 | else:
199 | sampled_points = random.sample(visible_points, self.max_n_points_per_image)
200 |
201 | for point in sampled_points:
202 | # Calculate the 2D coordinates of the point in the image
203 | point_2d, point_depth = self.scene_info.get_point_2d_coordinates_in_image(
204 | scene_id, image_id, point, align=True, check_visible=True, return_depth=True
205 | )
206 |
207 | if len(point_2d) == 0:
208 | # If the point is not visible in the image, print a warning and skip it
209 | message = f"Warning: Point-Id {point} is not visible in image {image_id} in scene {scene_id}.\n"
210 | print(message.strip())
211 | with open(self.warning_file, 'a') as wf:
212 |                         wf.write(message)  # message already ends with a newline
213 | continue
214 |
215 |                 # Normalize the pixel coordinates to the 0-1000 range
216 |                 x = round((point_2d[0][0] / scene_image_width) * 1000)
217 |                 y = round((point_2d[0][1] / scene_image_height) * 1000)
218 |                 depth = round(point_depth[0] * 1000)  # convert depth from meters to millimeters
219 |
220 | # read and annotate the image
221 | img_path = self.scene_info.get_image_path(scene_id, image_id)
222 | img = cv2.imread(img_path)
223 | annotated_img = self.annotate_image(img, (int(point_2d[0][0]), int(point_2d[0][1])))
224 |
225 | # save the annotated image
226 | save_dir = os.path.join(self.image_output_dir, scene_id)
227 | mmengine.mkdir_or_exist(save_dir)
228 | save_path = os.path.join(save_dir, f"{image_id}_p{point}_annotated.jpg")
229 | cv2.imwrite(save_path, annotated_img)
230 |
231 | # Fill the question and answer templates with the coordinates and depth
232 | question_template = random.choice(self.templates["questions"])
233 | question = question_template
234 |
235 | answer_template = random.choice(self.templates["answers"])
236 | answer = answer_template.format(x1=x, y1=y, depth=depth)
237 |
238 | task_description = random.choice(self.task_description)
239 |
240 |                 # Single-round conversation: prepend the task description to the question
241 | conversation = [
242 | {
243 | "from": "human",
244 | "value": f"{task_description}\n{question}"
245 | },
246 | {
247 | "from": "gpt",
248 | "value": answer
249 | }
250 | ]
251 |
252 | # Complete the training sample for the current image
253 | training_sample = {
254 | "id": f"{scene_id}_{image_id}_point{point}",
255 | "image": [f"{scene_id}/{image_id}_p{point}_annotated.jpg"],
256 | "conversations": conversation,
257 | "height_list": [scene_image_height],
258 | "width_list": [scene_image_width],
259 | "question_type": "depth_estimation_dot",
260 | "gt_value": depth,
261 | "ori_coordinates": [int(point_2d[0][0]), int(point_2d[0][1])],
262 | }
263 | all_samples.append(training_sample)
264 |
265 | return all_samples
266 |
267 | def generate_qa_training_data(self, output_dir, save_file=True):
268 | scene_ids = self.scene_info.get_sorted_keys()
269 |
270 |         # If all_max_samples is not -1, cap how many images we sample per scene
271 | if self.all_max_samples > 0:
272 | # need to calculate how many samples for each scene
273 | self.max_samples = max(self.all_max_samples // len(scene_ids) + 1, 1)
274 | if self.max_samples == 1:
275 | scene_ids = random.sample(scene_ids, self.all_max_samples)
276 | else:
277 | self.max_samples = -1
278 | self.num_used_scenes = len(scene_ids)
279 |
280 | train_data = []
281 | for scene_id in tqdm(scene_ids, desc="Generating QA Training Data"):
282 | train_data.extend(self.generate_qa_training_single_scene(scene_id))
283 |
284 |         if self.all_max_samples > 0 and len(train_data) > self.all_max_samples:
285 | train_data = random.sample(train_data, self.all_max_samples)
286 |
287 | random.shuffle(train_data)
288 |
289 | if save_file:
290 | output_jsonl_filepath = f"{output_dir}/{self.task_name}.jsonl"
291 |
292 | mmengine.mkdir_or_exist(output_dir)
293 | with open(output_jsonl_filepath, 'w') as f:
294 | for entry in train_data:
295 | f.write(json.dumps(entry) + '\n')
296 | print(f"[Train] Training data saved to {output_jsonl_filepath}. Generated {len(train_data)} samples in total.")
297 | else:
298 | return train_data
299 |
300 | def convert_train_sample_to_eval_sample(self, train_sample):
301 | conversation = train_sample["conversations"]
302 | train_sample["text"] = conversation[0]["value"]
303 | return train_sample
304 |
305 | def generate_qa_eval_data(self, output_dir):
306 | assert self.max_n_points_per_image == 1, "max_n_points_per_image should be 1 for evaluation"
307 | train_data = self.generate_qa_training_data(output_dir, save_file=False)
308 | all_data = [self.convert_train_sample_to_eval_sample(train_sample) for train_sample in train_data]
309 |
310 | output_jsonl_filepath = f"{output_dir}/{self.task_name}.jsonl"
311 |
312 | mmengine.mkdir_or_exist(output_dir)
313 | with open(output_jsonl_filepath, 'w') as f:
314 | for entry in all_data:
315 | f.write(json.dumps(entry) + '\n')
316 |
317 | print(f"[Eval] Evaluation data saved to {output_jsonl_filepath}. Generated {len(all_data)} samples in total.")
318 |
319 | if __name__ == "__main__":
320 | parser = argparse.ArgumentParser()
321 | parser.add_argument("--output_suffix", type=str, default="")
322 |
323 | parser.add_argument("--train_scene_info_path", type=str,
324 | default=f"data/scannet/scannet_instance_data/scenes_train_info_i_D5.pkl")
325 | parser.add_argument("--val_scene_info_path", type=str,
326 | default=f"data/scannet/scannet_instance_data/scenes_val_info_i_D5.pkl")
327 | parser.add_argument("--train_all_max_samples", type=int, default=500000)
328 | parser.add_argument("--val_all_max_samples", type=int, default=300)
329 |
330 | parser.add_argument("--output_dir_train", type=str,
331 | default=f"training_data/depth_estimation_dot")
332 | parser.add_argument("--output_dir_val", type=str,
333 | default=f"evaluation_data/depth_estimation_dot")
334 |
335 | parser.add_argument("--version_num", type=str, default="v1_0")
336 | args = parser.parse_args()
337 |
338 | args.output_dir_train = os.path.join(args.output_dir_train, args.version_num, args.output_suffix.replace('_', ''))
339 | args.output_dir_val = os.path.join(args.output_dir_val, args.version_num, args.output_suffix.replace('_', ''))
340 | args.image_output_dir_train = os.path.join(args.output_dir_train, "images")
341 | args.image_output_dir_val = os.path.join(args.output_dir_val, "images")
342 |
343 | # read the visibility info files
344 | train_visibility_info_path = f"data/scannet/scannet_instance_data/train_visibility_info_D5.parquet"
345 | val_visibility_info_path = f"data/scannet/scannet_instance_data/val_visibility_info_D5.parquet"
346 |
347 | val_warning_file = f"{args.output_dir_val}/val_warning.txt"
348 | train_warning_file = f"{args.output_dir_train}/train_warning.txt"
349 |
350 | print("Generating evaluation data...") # cost 3s to generate
351 | qa_engine_eval = DepthEstimationDotQAEngine(
352 | scene_info_path=args.val_scene_info_path,
353 | version_num=args.version_num,
354 | all_max_samples=args.val_all_max_samples,
355 | image_output_dir=args.image_output_dir_val,
356 | visibility_info_path=val_visibility_info_path,
357 | warning_file=val_warning_file
358 | )
359 | qa_engine_eval.generate_qa_eval_data(args.output_dir_val)
360 |
361 | print("Generating training data...") # cost 1.5 hours to generate, will generate 337523 samples
362 | qa_engine_train = DepthEstimationDotQAEngine(
363 | scene_info_path=args.train_scene_info_path,
364 | version_num=args.version_num,
365 | all_max_samples=args.train_all_max_samples,
366 | image_output_dir=args.image_output_dir_train,
367 | visibility_info_path=train_visibility_info_path,
368 | warning_file=train_warning_file
369 | )
370 | qa_engine_train.generate_qa_training_data(args.output_dir_train)
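371 | 
372 | # -----------------------------------------------------------------------------
373 | # Illustrative sketch (not part of the engine above): how a projected point is
374 | # turned into the values stored in a QA sample. The pixel coordinates and depth
375 | # below are made-up numbers; the 1296x968 size matches the ScanNet images used
376 | # elsewhere in this repo.
377 | # -----------------------------------------------------------------------------
378 | def _normalize_point_demo():
379 |     width, height = 1296, 968         # image size
380 |     px, py = 648.0, 242.0             # projected 2D point in pixels
381 |     depth_m = 1.537                   # depth returned in meters
382 |     x = round(px / width * 1000)      # -> 500 on the 0-1000 scale
383 |     y = round(py / height * 1000)     # -> 250 on the 0-1000 scale
384 |     depth_mm = round(depth_m * 1000)  # -> 1537 mm, stored as gt_value
385 |     return x, y, depth_mm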
--------------------------------------------------------------------------------
/spatial_engine/object_perception/compute_object_visibility.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | This script uses the saved visibility parquet file along with SceneInfoHandler to
9 | compute the visibility of each object (excluding categories in NONINFORMATIVE_DESC)
10 | in all valid images for each scene.
11 |
12 | This script needs to be run after generating the point-level visibility parquet file and before running the object frame coverage script.
13 |
14 | The final saved result structure is:
15 | {
16 | scene_id: {
17 | "object_to_images": {
18 | object_id: [
19 | {
20 | "image_id": image_id,
21 | "intersection_count": number of intersecting points,
22 | "visibility": visibility percentage
23 | },
24 | ...
25 | ]
26 | },
27 | "image_to_objects": {
28 | image_id: [
29 | {
30 | "object_id": object_id,
31 | "intersection_count": number of intersecting points,
32 | "visibility": visibility percentage
33 | },
34 | ...
35 | ]
36 | }
37 | }
38 | }
39 |
40 | The result saves both:
41 | 1. object_to_images: Maps each object to a list of images that see it,
42 | each entry contains image_id, intersection_count and visibility percentage.
43 | 2. image_to_objects: Maps each image to a list of objects it can see,
44 | also saving intersection_count and visibility.
45 | For skipped cases, warnings are printed and written to warning.txt in the output directory.
46 | """
47 |
48 | import os
49 | import pickle
50 | import json
51 | import numpy as np
52 | import pandas as pd
53 | from tqdm import tqdm
54 |
55 | from spatial_engine.utils.scannet_utils.handler.info_handler import SceneInfoHandler
56 |
57 | # Non-informative categories (keep original strings, don't call lower())
58 | NONINFORMATIVE_DESC = {"wall", "object", "floor", "ceiling", "window"}
59 |
60 | def load_visibility_dict(parquet_file):
61 | """
62 | Load parquet file and convert to dict with structure:
63 | key: "scene_id:image_to_points:image_id"
64 | value: parsed point index list
65 | """
66 | df = pd.read_parquet(parquet_file)
67 | keys = df["key"].tolist()
68 | values = df["values"].tolist()
69 |
70 | return dict(zip(keys, values))
71 |
72 | def process_scene(scene_id, scene_info_handler, visibility_dict):
73 | """
74 | Process a single scene:
75 | - Iterate through each object in the scene,
76 | Skip and print warning if object's raw category is in NONINFORMATIVE_DESC;
77 | Skip and print warning if object has no point indices.
78 | - For remaining objects, call get_object_point_index to get object point indices,
79 | Calculate threshold (5% of object points, minimum 1).
80 | - Iterate through all valid images (scene_info.get_all_extrinsic_valid_image_ids),
81 | Look up visible points for that image in visibility_dict,
82 | Calculate intersection with object points, record the image if intersection count meets threshold,
83 | Save intersection count and visibility percentage.
84 | - Build two mappings:
85 | object_to_images: { object_id: [ { "image_id": ..., "intersection_count": n, "visibility": v }, ... ] }
86 | image_to_objects: { image_id: [ { "object_id": ..., "intersection_count": n, "visibility": v }, ... ] }
87 | Returns: (scene_id, result, warnings)
88 | where result is dict of above two mappings, warnings is list of all warning messages for this scene.
89 | """
90 | print(f"Processing scene {scene_id}.")
91 | warnings_list = []
92 | result = {
93 | "object_to_images": {},
94 | "image_to_objects": {}
95 | }
96 |
97 | # Use the scene_info_handler
98 | if scene_id not in scene_info_handler.infos:
99 | msg = f"[Warning] Scene {scene_id} not found in scene_info."
100 | warnings_list.append(msg)
101 | print(msg)
102 | return scene_id, result, warnings_list
103 |
104 | num_objects = scene_info_handler.get_num_objects(scene_id)
105 | for object_id in range(num_objects):
106 | raw_category = scene_info_handler.get_object_raw_category(scene_id, object_id)
107 | if raw_category in NONINFORMATIVE_DESC:
108 | continue
109 |
110 | object_points = scene_info_handler.get_object_point_index(scene_id, object_id)
111 | if len(object_points) == 0:
112 | msg = f"[Warning] Scene {scene_id}, object {object_id} has no point indices, skipping."
113 | warnings_list.append(msg)
114 | print(msg)
115 | continue
116 |
117 | if isinstance(object_points, np.ndarray):
118 | object_points_set = set(object_points.tolist())
119 | else:
120 | object_points_set = set(object_points)
121 | total_points = len(object_points_set)
122 | threshold = max(1, int(0.05 * total_points))
123 |
124 | valid_image_ids = scene_info_handler.get_all_extrinsic_valid_image_ids(scene_id)
125 | for image_id in valid_image_ids:
126 | key = f"{scene_id}:image_to_points:{image_id}"
127 | if key not in visibility_dict:
128 | msg = f"[Warning] Scene {scene_id}, image {image_id} not found in visibility dict."
129 | warnings_list.append(msg)
130 | print(msg)
131 | continue
132 |
133 | visible_points = set(json.loads(visibility_dict[key]))
134 | intersection_count = len(visible_points & object_points_set)
135 | if intersection_count >= threshold:
136 | visibility_percent = (intersection_count / total_points) * 100.0
137 | if object_id not in result["object_to_images"]:
138 | result["object_to_images"][object_id] = []
139 | result["object_to_images"][object_id].append({
140 | "image_id": image_id,
141 | "intersection_count": intersection_count,
142 | "visibility": visibility_percent
143 | })
144 | if image_id not in result["image_to_objects"]:
145 | result["image_to_objects"][image_id] = []
146 | result["image_to_objects"][image_id].append({
147 | "object_id": object_id,
148 | "intersection_count": intersection_count,
149 | "visibility": visibility_percent
150 | })
151 |
152 | return scene_id, result, warnings_list
153 |
154 | def process_split(split_name, scene_info_path, visibility_parquet_file, output_dir):
155 | """
156 | Process one split (train or val):
157 | - Load visibility parquet file and convert to dict;
158 | - Load scene_info to get all scene_ids;
159 | - Process each scene sequentially (call process_scene);
160 | - Save combined results to pkl file and write all warnings to warning.txt in output_dir.
161 | """
162 | if not os.path.exists(output_dir):
163 | os.makedirs(output_dir)
164 |
165 | output_pkl_file = os.path.join(output_dir, "object_visibility.pkl")
166 | warning_file = os.path.join(output_dir, "warning.txt")
167 | with open(warning_file, "w") as wf:
168 | wf.write("")
169 |
170 | # Get scene_ids in the main process
171 | scene_info_handler = SceneInfoHandler(scene_info_path)
172 | scene_ids = scene_info_handler.get_all_scene_ids()
173 |
174 | results = {}
175 | all_warnings = []
176 |
177 |     # Reuse the SceneInfoHandler instantiated above instead of loading the info file again
178 | print(f"Loading visibility dict from {visibility_parquet_file}.")
179 | visibility_dict = load_visibility_dict(visibility_parquet_file)
180 | print(f"Loaded visibility dict.")
181 |
182 | for scene_id in tqdm(scene_ids, desc=f"Processing {split_name} scenes"):
183 | scene_id, scene_result, warnings = process_scene(scene_id, scene_info_handler, visibility_dict)
184 | results[scene_id] = scene_result
185 | all_warnings.extend(warnings)
186 |
187 | with open(warning_file, "a") as wf:
188 | for w in all_warnings:
189 | wf.write(w + "\n")
190 |
191 | with open(output_pkl_file, "wb") as f:
192 | pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)
193 |
194 | print(f"Finished processing split '{split_name}'.")
195 | print(f"Result saved to {output_pkl_file}")
196 | print(f"Warnings saved to {warning_file}")
197 |
198 |
199 | def main():
200 | # Configure parameters for train and val
201 | splits = {
202 | "val": { # * take 15 mins
203 | "scene_info_path": "data/scannet/scannet_instance_data/scenes_val_info_i_D5.pkl",
204 | "visibility_parquet_file": "data/scannet/scannet_instance_data/val_visibility_info_D5.parquet",
205 | "output_dir": "evaluation_data/object_perception"
206 | },
207 | "train": { # * take 1h 46 mins to generate, 94M size
208 | "scene_info_path": "data/scannet/scannet_instance_data/scenes_train_info_i_D5.pkl",
209 | "visibility_parquet_file": "data/scannet/scannet_instance_data/train_visibility_info_D5.parquet",
210 | "output_dir": "training_data/object_perception"
211 | }
212 | }
213 |
214 | for split_name, config in splits.items():
215 | process_split(split_name,
216 | config["scene_info_path"],
217 | config["visibility_parquet_file"],
218 | config["output_dir"])
219 |
220 |
221 | if __name__ == "__main__":
222 | main()
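223 | 
224 | # -----------------------------------------------------------------------------
225 | # Illustrative sketch (not part of the script above): the visibility test used
226 | # in process_scene. The point-index sets below are made up; in the real script
227 | # they come from SceneInfoHandler and the visibility parquet file.
228 | # -----------------------------------------------------------------------------
229 | def _object_visibility_demo():
230 |     object_points_set = set(range(100))      # an object with 100 points
231 |     visible_points = set(range(90, 300))     # points visible in some image
232 |     threshold = max(1, int(0.05 * len(object_points_set)))        # 5 points (5% of the object)
233 |     intersection_count = len(visible_points & object_points_set)  # 10 overlapping points
234 |     is_visible = intersection_count >= threshold                  # True -> this image is recorded
235 |     visibility_percent = (intersection_count / len(object_points_set)) * 100.0  # 10.0
236 |     return is_visible, intersection_count, visibility_percent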
--------------------------------------------------------------------------------
/spatial_engine/object_perception/find_object_coverage.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | #!/bin/bash
8 | # find_object_coverage.sh
9 | # Usage: ./find_object_coverage.sh [start] [chunk_size]
10 | # This script runs object coverage finding for both train and val splits
11 |
12 | # Set default values
13 | DEFAULT_START=0
14 | DEFAULT_CHUNK_SIZE=10
15 |
16 | # Get command line parameters or use default values
17 | START=${1:-$DEFAULT_START}
18 | CHUNK_SIZE=${2:-$DEFAULT_CHUNK_SIZE}
19 |
20 | echo "Launching tasks for both train and val splits from scene $START with chunk size $CHUNK_SIZE"
21 |
22 | # Process train split
23 | echo "Processing train split..."
24 | for ((current_start=START; ; current_start+=CHUNK_SIZE)); do
25 | current_end=$((current_start+CHUNK_SIZE))
26 | echo " Starting task for train: scenes $current_start to $current_end"
27 | # Launch task in background, output files will include start and end suffixes
28 | python spatial_engine/object_perception/single_object_coverage_finder.py --split "train" --start "$current_start" --end "$current_end" &
29 |
30 | # Add a limit to prevent infinite loop - can be adjusted as needed
31 | if [ $current_start -ge 1000 ]; then
32 | break
33 | fi
34 | done
35 |
36 | # Process val split
37 | echo "Processing val split..."
38 | for ((current_start=START; ; current_start+=CHUNK_SIZE)); do
39 | current_end=$((current_start+CHUNK_SIZE))
40 | echo " Starting task for val: scenes $current_start to $current_end"
41 | # Launch task in background, output files will include start and end suffixes
42 | python spatial_engine/object_perception/single_object_coverage_finder.py --split "val" --start "$current_start" --end "$current_end" &
43 |
44 | # Add a limit to prevent infinite loop - can be adjusted as needed
45 | if [ $current_start -ge 1000 ]; then
46 | break
47 | fi
48 | done
49 |
50 | # Wait for all background tasks to complete
51 | wait
52 | echo "All tasks completed for both train and val splits."
--------------------------------------------------------------------------------
/spatial_engine/object_perception/merge_object_coverage.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | #!/usr/bin/env python3
8 | """
9 | This script should be run after find_object_coverage.sh. It's used to merge the object coverage results of different scenes.
10 | """
11 |
12 | import os
13 | import pickle
14 | import glob
15 | import re
16 |
17 | def merge_dimension(split, base_dir, dimension):
18 | """
19 | For the specified split ("train" or "val"), base_dir, and dimension,
20 | traverse all subdirectories starting with split_ under base_dir,
21 |     load pkl files with filenames matching f"{split}_object_coverage_{dimension}_*_*.pkl",
22 | merge them and return the combined dictionary.
23 | """
24 | merged = {}
25 | # List all subdirectories starting with split_
26 | subdirs = [d for d in os.listdir(base_dir)
27 | if os.path.isdir(os.path.join(base_dir, d)) and d.startswith(f"{split}_")]
28 | if not subdirs:
29 | print(f"No subdirectories starting with {split}_ found in {base_dir}.")
30 | return merged
31 |
32 | # Use regex to parse directory names (e.g., "train_0_10")
33 | pattern = re.compile(fr"{split}_(\d+)_(\d+)")
34 | dir_ranges = []
35 | for d in subdirs:
36 | m = pattern.match(d)
37 | if m:
38 | start = int(m.group(1))
39 | end = int(m.group(2))
40 | dir_ranges.append((d, start, end))
41 | else:
42 | print(f"Directory name format does not match requirements, skipping: {d}")
43 | # Sort by start value
44 | dir_ranges.sort(key=lambda x: x[1])
45 | print(f"Found {len(dir_ranges)} {split} subdirectories:")
46 | for d, s, e in dir_ranges:
47 | print(f" {d}: {s} ~ {e}")
48 |
49 | # Traverse each subdirectory, looking for pkl files corresponding to the dimension
50 | for d, s, e in dir_ranges:
51 | subdir_path = os.path.join(base_dir, d)
52 | # Filename format: "{split}_object_coverage_{dimension}_{s}_{e}.pkl"
53 | pattern_file = os.path.join(subdir_path, f"{split}_object_coverage_{dimension}_*_*.pkl")
54 | files = glob.glob(pattern_file)
55 | if not files:
56 | print(f"No {dimension} files found in subdirectory {d}, skipping.")
57 | continue
58 | for file in files:
59 | print(f"Loading file: {file}")
60 | with open(file, "rb") as f:
61 | data = pickle.load(f)
62 | # Assume each file contains a dict with scene_id as keys
63 | merged.update(data)
64 | return merged
65 |
66 | def merge_split(split, base_dir, output_dir):
67 | """
68 | For the specified split and base_dir, merge files for height, length, width dimensions separately,
69 | and save to output_dir.
70 | """
71 | dims = ["height", "length", "width"]
72 | merged_dict = {}
73 | for d in dims:
74 | print(f"\nMerging {split} {d} files ...")
75 | merged = merge_dimension(split, base_dir, d)
76 | merged_dict[d] = merged
77 | num_scene = len(merged)
78 | print(f"After merging {split} {d}, there are {num_scene} scene_ids in total.")
79 | # Save results
80 | output_file = os.path.join(output_dir, f"merged_{split}_object_coverage_{d}.pkl")
81 | with open(output_file, "wb") as f:
82 | pickle.dump(merged, f)
83 | print(f"Saved merged results to {output_file}")
84 | return merged_dict
85 |
86 | def main():
87 | # Set base_dir for training and validation sets (please modify according to actual paths)
88 | train_base = "training_data/object_perception"
89 | val_base = "evaluation_data/object_perception"
90 |
91 | # Output directory can be the same as base_dir or set separately, here we use the respective base_dir directly
92 | # Note: Output files will be saved to the corresponding folders
93 | print("Starting to merge training set files ...")
94 | merge_split("train", train_base, train_base)
95 |
96 | print("\nStarting to merge validation set files ...")
97 | merge_split("val", val_base, val_base)
98 |
99 | if __name__ == "__main__":
100 | main()
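101 | 
102 | # -----------------------------------------------------------------------------
103 | # Illustrative sketch (not part of the merge logic above): how the chunked
104 | # output directories are parsed and ordered before their pkl files are merged.
105 | # The directory names below are made up.
106 | # -----------------------------------------------------------------------------
107 | def _parse_chunk_dirs_demo(split="train"):
108 |     subdirs = ["train_20_30", "train_0_10", "train_10_20", "notes"]  # hypothetical listing
109 |     pattern = re.compile(fr"{split}_(\d+)_(\d+)")
110 |     dir_ranges = []
111 |     for d in subdirs:
112 |         m = pattern.match(d)
113 |         if m:                                # "notes" does not match and is skipped
114 |             dir_ranges.append((d, int(m.group(1)), int(m.group(2))))
115 |     dir_ranges.sort(key=lambda x: x[1])      # order chunks by their start index
116 |     return dir_ranges  # -> [("train_0_10", 0, 10), ("train_10_20", 10, 20), ("train_20_30", 20, 30)]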
--------------------------------------------------------------------------------
/spatial_engine/object_perception/single_object_coverage_finder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | #!/usr/bin/env python
8 | """
9 | This script processes object coverage in three dimensions (height, length, width)
10 | based on the precomputed object visibility information and SceneInfoHandler.
11 | For each object (excluding those with categories in NONINFORMATIVE_DESC), we aim to find
12 | a minimal combination of images such that the union of the visible 3D points in that
13 | dimension exactly (within a specified tolerance) covers the object’s target size.
14 | A combination is considered minimal if removing any image from it would result in incomplete coverage.
15 | The output for each object is stored as a dict keyed by the number of images in the combination,
16 | e.g.:
17 | {
18 | "height": { 1: [ [img1], [img3], ... ], 2: [ [img2, img5], ... ], ... },
19 | "length": { ... },
20 | "width": { ... }
21 | }
22 | Note: Theoretically, one could compute visible images on the fly by projecting the object's 3D points
23 | into each image and checking visibility; however, precomputing and saving the object visibility data
24 | (such as in our previous step) can save time in subsequent processing.
25 | """
26 |
27 | import os
28 | import pickle
29 | import json
30 | import numpy as np
31 | import pandas as pd
32 | from tqdm import tqdm
33 | import mmengine
34 | import random
35 | import argparse
36 | random.seed(0)
37 | # Tolerance for dimension coverage (10% of target)
38 | TOLERANCE = 0.1
39 |
40 | from spatial_engine.utils.scannet_utils.handler.info_handler import SceneInfoHandler
41 |
42 | DEBUG = False
43 |
44 |
45 | def load_visibility_dict(parquet_file):
46 | """
47 | Load parquet file and convert to dict with structure:
48 | key: "scene_id:image_to_points:image_id"
49 |     value: point index list stored as a JSON string (decoded later with json.loads)
50 | """
51 | df = pd.read_parquet(parquet_file)
52 | keys = df["key"].tolist()
53 | values = df["values"].tolist()
54 | return dict(zip(keys, values))
55 |
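# Illustrative key/value pair (ids are made up): keys follow the
# "{scene_id}:image_to_points:{image_id}" pattern, e.g.
#   "scene0011_00:image_to_points:00050" -> "[12, 57, 301, ...]"
# and the JSON-encoded value is decoded with json.loads by the callers below.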
56 |
57 | def compute_coverage(scene_pts, indices_bool_mask, axis):
58 | """
59 | Given a set of indices (subset of scene_pts) and the 3D points (scene_pts),
60 | compute the coverage along the specified axis.
61 | """
62 | if not indices_bool_mask.any():
63 | return None
64 | coords = scene_pts[indices_bool_mask][:, axis]
65 | return max(coords) - min(coords)
66 |
67 |
68 | def covers_dimension(coverage, target, tolerance):
69 | """
70 | Check if the computed coverage is within tolerance of the target dimension.
71 | """
72 | if coverage is None:
73 | return False
74 | return abs(coverage - target) <= tolerance * target
75 |
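# Worked example for the tolerance check above (numbers are illustrative):
# with TOLERANCE = 0.1 and a target extent of 0.50 m, a combination is accepted
# when |coverage - 0.50| <= 0.05, i.e. the covered extent lies in [0.45 m, 0.55 m]:
#   covers_dimension(0.47, 0.50, 0.1) -> True
#   covers_dimension(0.40, 0.50, 0.1) -> False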
76 | def find_minimal_combinations(
77 | scene_id,
78 | scene_pts,
79 | object_points_indices,
80 | visible_images,
81 | images_to_visible_points_dict,
82 | axis,
83 | target_dim,
84 | tolerance,
85 | max_images=5
86 | ):
87 | """
88 | Use two-phase breadth-first search (BFS) to find "all minimal combinations" layer by layer:
89 | 1. Phase A: Check coverage, record minimal solutions, and add uncovered combinations to expansion list;
90 | 2. Phase B: Save minimal solutions to global set for pruning, then expand the list to the next layer.
91 |     Once a combination covers the target, it stops expanding; other combinations keep
92 |     being searched, so all mutually non-redundant minimal solutions are found.
93 |
94 | Returns a dictionary: {k: [list of minimal combinations of size k]}.
95 | """
96 | # ----------- Preprocessing -----------
97 | # 1) Build boolean mask for object points
98 | object_points_indices_mask = np.zeros(len(scene_pts), dtype=bool)
99 | object_points_indices_mask[object_points_indices] = True
100 |
101 | # 2) Extract boolean mask for each image (only keeping object points)
102 | image_bool_masks = {}
103 | for img in visible_images:
104 | key = f"{scene_id}:image_to_points:{img}"
105 | if key not in images_to_visible_points_dict:
106 |             print(f"[Warning] Scene {scene_id}, image {img} not found in visibility dict. Skipping this image.")
107 | continue # skip
108 | bool_mask = np.zeros(len(scene_pts), dtype=bool)
109 | bool_mask[json.loads(images_to_visible_points_dict[key])] = True
110 | bool_mask = np.logical_and(bool_mask, object_points_indices_mask)
111 | image_bool_masks[img] = bool_mask
112 |
113 | valid_images = list(image_bool_masks.keys())
114 | # Sort by number of covered points in descending order, optional (sometimes finds smaller combinations faster)
115 | # valid_images.sort(key=lambda x: np.sum(image_bool_masks[x]), reverse=True)
116 |
117 | # only get 25 images
118 | if len(valid_images) > 25:
119 | valid_images = random.sample(valid_images, 25)
120 |
121 | # Pre-compute cumulative remaining coverage, the union mask of all images from index i to the end
122 | cumulative_union = [None] * len(valid_images)
123 | if valid_images:
124 | cumulative_union[-1] = image_bool_masks[valid_images[-1]].copy()
125 | for i in range(len(valid_images) - 2, -1, -1):
126 | cumulative_union[i] = np.logical_or(image_bool_masks[valid_images[i]], cumulative_union[i + 1])
127 |
128 | # Initialize a list to store the minimal sets as boolean masks
129 | found_minimal_sets = []
130 |
131 | def is_superset_of_any_minimal_bit(comb_bitmask):
132 | # Check if the combination is a superset of any known minimal set using numpy operations
133 | if not found_minimal_sets:
134 | return False
135 |         # AND the combination bitmask with each known minimal set;
136 |         # if the result equals that minimal set, the combination is a superset of it
137 | for minimal_set in found_minimal_sets:
138 | if np.array_equal(np.logical_and(minimal_set, comb_bitmask), minimal_set):
139 | return True
140 | return False
141 |
142 | # Helper function: Check coverage
143 | def can_cover(union_mask):
144 | cov = compute_coverage(scene_pts, union_mask, axis)
145 | return covers_dimension(cov, target_dim, tolerance)
146 |
147 | # ----------- BFS Initialization -----------
148 | # current_level stores (comb, union_mask, last_idx), representing all size=k combinations at the current level
149 | current_level = []
150 | for i, img in enumerate(valid_images):
151 | comb_bitmask = np.zeros(len(valid_images), dtype=bool)
152 | comb_bitmask[i] = True
153 | current_level.append(([img], image_bool_masks[img], i, comb_bitmask))
154 |
155 | # Store results: dictionary {combination size k: [list of minimal combinations of size k]}
156 | minimal_solutions = {}
157 | first_layer_to_expand_list = []
158 |
159 | k = 1
160 | # Enter loop while k is within limit and current_level is not empty
161 | while k <= max_images and current_level:
162 | # ---------- Phase A: Check Coverage -----------
163 | to_expand = [] # Combinations to expand to the next level
164 | new_minimal_sets = []
165 |
166 | # Traverse the current level
167 | for comb, union_mask, last_idx, comb_bitmask in current_level:
168 | # for comb, union_mask, last_idx, comb_bitmask in tqdm(current_level, desc=f"Processing layer {k}"):
169 | # Prune: Skip if the combination is a superset of any known minimal set
170 | if is_superset_of_any_minimal_bit(comb_bitmask):
171 | continue
172 |
173 | # If not a superset, check coverage
174 | if can_cover(union_mask):
175 | # This is a new minimal combination
176 | new_minimal_sets.append(comb_bitmask)
177 | # Save to minimal_solutions
178 | minimal_solutions.setdefault(k, [])
179 | minimal_solutions[k].append(tuple(comb))
180 | else:
181 | # Not yet covered, add to the list to expand
182 | # early prune: if the cumulative union is not covered, skip
183 | if last_idx < len(valid_images) - 1:
184 | possible_union_mask = np.logical_or(cumulative_union[last_idx], union_mask)
185 | if not can_cover(possible_union_mask):
186 | continue
187 | to_expand.append((comb, union_mask, last_idx, comb_bitmask))
188 | if k == 1:
189 | first_layer_to_expand_list.append((comb, union_mask, last_idx, comb_bitmask))
190 |
191 | # Update known minimal sets
192 | if new_minimal_sets:
193 | found_minimal_sets.extend(new_minimal_sets)
194 |
195 | # ---------- Phase B: Expand to Next Level -----------
196 | # Generate next level combinations (k+1), only expand "not covered" combinations
197 |
198 | next_level = []
199 | if k < max_images: # If further expansion is possible
200 | for comb, union_mask, last_idx, comb_bitmask in to_expand:
201 | # Only use those in first_layer_to_expand_list to expand
202 | # Need to find the first larger image index
203 | for potential_comb, potential_union_mask, potential_last_idx, potential_comb_bitmask in first_layer_to_expand_list:
204 | if potential_last_idx > last_idx:
205 | # Make a new combination
206 | assert len(potential_comb) == 1, f"Number of images in potential_comb should be 1. {potential_comb}."
207 | new_comb = comb + potential_comb
208 | new_union_mask = np.logical_or(union_mask, potential_union_mask)
209 | new_comb_bitmask = np.logical_or(potential_comb_bitmask, comb_bitmask)
210 | next_level.append((new_comb, new_union_mask, potential_last_idx, new_comb_bitmask))
211 |
212 | if len(next_level) > 5000:
213 | # random sample 5000
214 | next_level = random.sample(next_level, 5000)
215 |
216 | # Update current_level to the next level
217 | current_level = next_level
218 | k += 1
219 |
220 | return minimal_solutions
221 |
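# Pruning intuition for the BFS above (illustrative): if the singleton {img_2}
# already covers the target, every superset such as {img_2, img_5} is skipped by
# the bitmask superset check, so only minimal combinations survive. Expansion only
# appends first-layer images with a strictly larger index than the current last
# image, which prevents generating the same combination in different orders.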
222 | def process_object(scene_id, object_id, scene_info_handler: SceneInfoHandler, visible_images, images_to_visible_points_dict):
223 | """
224 | For a given object in a scene:
225 | - Get its point indices and corresponding 3D coordinates (aligned).
226 | - Use the precomputed object visibility (from the visibility file) to obtain the set of images
227 | that see this object.
228 | - For each dimension, compute the target value and corresponding axis:
229 | Height: target = get_object_height, axis = 2.
230 | For length and width, use get_object_width_axis_aligned to determine:
231 | if width_axis == 0 then width is along x (axis=0) and length along y (axis=1);
232 | else width is along y (axis=1) and length along x (axis=0).
233 | - Use find_minimal_combinations to obtain minimal image combinations covering the target.
234 | - Group the combinations by the number of images.
235 | Returns a dict: { "height": {n: [...], ...}, "length": {n: [...], ...}, "width": {n: [...], ...} }.
236 |
237 | Note: Theoretically, one could compute the visible images on the fly by projecting the object's 3D points
238 | into each image and checking visibility; however, precomputing and saving the object visibility data
239 | (e.g., in object_visibility.pkl) can save time in subsequent processing.
240 | """
241 | # Get full aligned scene points (first 3 coords)
242 | scene_pts = scene_info_handler.get_scene_points_align(scene_id)[:, :3]
243 | object_points_indices = scene_info_handler.get_object_point_index(scene_id, object_id)
244 |
245 | # Height: target from get_object_height, axis=2.
246 | height_target = scene_info_handler.get_object_height(scene_id, object_id)
247 | height_axis = 2
248 |
249 | # For length and width, determine axis using get_object_width_axis_aligned.
250 | width_axis = scene_info_handler.get_object_width_axis_aligned(scene_id, object_id) # 0 or 1
251 | length_axis = 1 if width_axis == 0 else 0
252 | length_target = scene_info_handler.get_object_length(scene_id, object_id)
253 | width_target = scene_info_handler.get_object_width(scene_id, object_id)
254 |
255 | height_combs = find_minimal_combinations(scene_id, scene_pts, object_points_indices, visible_images, images_to_visible_points_dict, height_axis, height_target, TOLERANCE)
256 | # for height, length, width, we do not consider those less than 0.02 m
257 | length_combs = find_minimal_combinations(scene_id, scene_pts, object_points_indices, visible_images, images_to_visible_points_dict, length_axis, length_target, TOLERANCE)
258 | width_combs = find_minimal_combinations(scene_id, scene_pts, object_points_indices, visible_images, images_to_visible_points_dict, width_axis, width_target, TOLERANCE)
259 |
260 | return {
261 | "height": height_combs,
262 | "length": length_combs,
263 | "width": width_combs
264 | }
265 |
266 | def process_scene_for_coverage(scene_id, scene_info_handler, images_to_visible_points_dict, object_visibility_dict):
267 | """
268 | Process one scene:
269 | For each object (excluding non-informative ones), compute the minimal image combinations
270 | that cover the object's height, length, and width.
271 | Returns: (scene_id, scene_result) where scene_result maps object_id to
272 | { "height": {...}, "length": {...}, "width": {...} }.
273 | """
274 | print(f"Processing scene {scene_id} for object coverage.")
275 | scene_result = {}
276 | # iterate through object_visibility_dict
277 | scene_object_visibility = object_visibility_dict[scene_id]["object_to_images"]
278 | for object_id, visibility_list in tqdm(scene_object_visibility.items(), desc=f"Processing scene {scene_id} for object coverage"):
279 | # * process each object
280 | visible_images = [img["image_id"] for img in visibility_list]
281 | res = process_object(scene_id, object_id, scene_info_handler, visible_images, images_to_visible_points_dict)
282 | if res is not None:
283 | scene_result[object_id] = res
284 | return scene_id, scene_result
285 |
286 |
287 | def process_split_objects(split_name, scene_info_path, visibility_parquet_file, object_visibility_file, output_dir):
288 | """
289 | Process one split (train or val) for object coverage.
290 | For each scene, compute per-object minimal image combinations for height, length, and width.
291 | Save three separate result files in output_dir.
292 | """
293 | if not os.path.exists(output_dir):
294 | os.makedirs(output_dir)
295 | output_height_file = os.path.join(output_dir, f"{split_name}_object_coverage_height.pkl")
296 | output_length_file = os.path.join(output_dir, f"{split_name}_object_coverage_length.pkl")
297 | output_width_file = os.path.join(output_dir, f"{split_name}_object_coverage_width.pkl")
298 | warning_file = os.path.join(output_dir, "coverage_finding_warning_objects.txt")
299 | with open(warning_file, "w") as wf:
300 | wf.write("")
301 |
302 | # Load necessary data once
303 | scene_info_handler = SceneInfoHandler(scene_info_path)
304 | images_to_visible_points_dict = load_visibility_dict(visibility_parquet_file)
305 | object_visibility_dict = mmengine.load(object_visibility_file)
306 |
307 | # Get all scene_ids from SceneInfoHandler in main process
308 | scene_ids = scene_info_handler.get_all_scene_ids()
309 |
310 | results_height = {}
311 | results_length = {}
312 | results_width = {}
313 |
314 | if DEBUG:
315 | scene_ids = ["scene0011_00"]
316 |
317 | for scene_id in tqdm(scene_ids, desc=f"Processing {split_name} scenes for object coverage"):
318 | scene_id, scene_result = process_scene_for_coverage(scene_id, scene_info_handler, images_to_visible_points_dict, object_visibility_dict)
319 | if scene_result:
320 | results_height[scene_id] = {}
321 | results_length[scene_id] = {}
322 | results_width[scene_id] = {}
323 | for object_id, res in scene_result.items():
324 | results_height[scene_id][object_id] = res["height"] # possible to be empty
325 | results_length[scene_id][object_id] = res["length"]
326 | results_width[scene_id][object_id] = res["width"]
327 |
328 | with open(output_height_file, "wb") as f:
329 | pickle.dump(results_height, f)
330 | with open(output_length_file, "wb") as f:
331 | pickle.dump(results_length, f)
332 | with open(output_width_file, "wb") as f:
333 | pickle.dump(results_width, f)
334 |
335 | print(f"Finished processing split '{split_name}' for object coverage.")
336 | print(f"Height coverage saved to {output_height_file}")
337 | print(f"Length coverage saved to {output_length_file}")
338 | print(f"Width coverage saved to {output_width_file}")
339 | print(f"Warnings saved to {warning_file}")
340 |
341 |
342 | def main():
343 | parser = argparse.ArgumentParser(description="Process object coverage for a given split and scene index range.")
344 | parser.add_argument("--split", type=str, required=True, help="Specify the split: train or val")
345 | parser.add_argument("--start", type=int, required=True, help="Start scene index (inclusive)")
346 | parser.add_argument("--end", type=int, default=None, help="End scene index (exclusive)")
347 | args = parser.parse_args()
348 |
349 | split = args.split
350 | start_index = args.start
351 | end_index = args.end
352 |
353 |     # Configure the parameters for each split (modify the paths according to your actual setup)
354 | splits = {
355 | "val": {
356 | "scene_info_path": "data/scannet/scannet_instance_data/scenes_val_info_i_D5.pkl",
357 | "visibility_parquet_file": "data/scannet/scannet_instance_data/val_visibility_info_D5.parquet",
358 | "object_visibility_file": "evaluation_data/object_perception/object_visibility.pkl",
359 | "output_dir": "evaluation_data/object_perception"
360 | },
361 | "train": {
362 | "scene_info_path": "data/scannet/scannet_instance_data/scenes_train_info_i_D5.pkl",
363 | "visibility_parquet_file": "data/scannet/scannet_instance_data/train_visibility_info_D5.parquet",
364 | "object_visibility_file": "training_data/object_perception/object_visibility.pkl",
365 | "output_dir": "training_data/object_perception"
366 | }
367 | }
368 |
369 | if DEBUG:
370 | split = "val"
371 | start_index = 0
372 | end_index = 1
373 | # output dir should have a subdir named "debug"
374 | splits["val"]["output_dir"] = os.path.join(splits["val"]["output_dir"], "debug")
375 | splits["val"]["visibility_parquet_file"] = "data/scannet/scannet_instance_data/val_visibility_info_D5_debug.parquet"
376 |
377 | if split not in splits:
378 | raise ValueError("Invalid split. Choose train or val.")
379 |
380 | config = splits[split]
381 |     # Append the split name and start/end indices as a suffix to the output directory
382 | config["output_dir"] = os.path.join(config["output_dir"], f"{split}_{start_index}_{end_index}")
383 | if not os.path.exists(config["output_dir"]):
384 | os.makedirs(config["output_dir"])
385 |
386 | scene_info_handler = SceneInfoHandler(config["scene_info_path"])
387 | images_to_visible_points_dict = load_visibility_dict(config["visibility_parquet_file"])
388 | object_visibility_dict = mmengine.load(config["object_visibility_file"])
389 |
390 | all_scene_ids = scene_info_handler.get_all_scene_ids()
391 | selected_scene_ids = all_scene_ids[start_index:end_index]
392 | print(f"Processing scenes from index {start_index} to {end_index} (total {len(selected_scene_ids)}) in split {split}.")
393 |
394 | results_height = {}
395 | results_length = {}
396 | results_width = {}
397 |
398 | for scene_id in tqdm(selected_scene_ids, desc=f"Processing {split} scenes"):
399 | scene_id, scene_result = process_scene_for_coverage(scene_id, scene_info_handler, images_to_visible_points_dict, object_visibility_dict)
400 | if scene_result:
401 | results_height[scene_id] = {}
402 | results_length[scene_id] = {}
403 | results_width[scene_id] = {}
404 | for object_id, res in scene_result.items():
405 | results_height[scene_id][object_id] = res["height"]
406 | results_length[scene_id][object_id] = res["length"]
407 | results_width[scene_id][object_id] = res["width"]
408 |
409 | output_height_file = os.path.join(config["output_dir"], f"{split}_object_coverage_height_{start_index}_{end_index}.pkl")
410 | output_length_file = os.path.join(config["output_dir"], f"{split}_object_coverage_length_{start_index}_{end_index}.pkl")
411 | output_width_file = os.path.join(config["output_dir"], f"{split}_object_coverage_width_{start_index}_{end_index}.pkl")
412 |
413 | with open(output_height_file, "wb") as f:
414 | pickle.dump(results_height, f)
415 | with open(output_length_file, "wb") as f:
416 | pickle.dump(results_length, f)
417 | with open(output_width_file, "wb") as f:
418 | pickle.dump(results_width, f)
419 |
420 | print(f"Finished processing split '{split}' for scenes {start_index} to {end_index}.")
421 | print(f"Height coverage saved to {output_height_file}")
422 | print(f"Length coverage saved to {output_length_file}")
423 | print(f"Width coverage saved to {output_width_file}")
424 |
425 | if __name__ == "__main__":
426 | main()
--------------------------------------------------------------------------------
/spatial_engine/object_perception/single_object_perception_engine.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | import json
9 | import pickle
10 | import random
11 | from tqdm import tqdm
12 | import numpy as np
13 | import random
14 | random.seed(1)
15 | np.random.seed(1) # * use a different seed from cam_cam
16 | import json
17 | from spatial_engine.utils.scannet_utils.handler.info_handler import SceneInfoHandler
18 |
19 |
20 | # ==================== Global Variables ====================
21 | max_train_samples = -1 # -1 means no downsampling for training data; if positive, downsample to this number
22 | val_max_samples = 3000 # Validation data will be randomly sampled to at most 3000 samples per file
23 |
24 |
25 | TASK_DESCRIPTION = [
26 | "Assume the scene remains unchanged. Your task is to determine the spatial properties based on the images. You need to integrate and analyze information from all provided images to get the answer.",
27 | "Given the static scene, determine the spatial properties using the images. Synthesize and evaluate information from all provided images to derive the answer.",
28 | "Analyze the images to determine the spatial properties of the scene. You must combine and interpret data from all provided images to find the answer.",
29 | "Using the images, identify the spatial properties of the scene. Collate and assess information from all provided images to reach the answer.",
30 | "Examine the images to find the spatial properties of the scene. You need to merge and scrutinize information from all provided images to obtain the answer.",
31 | "Determine the spatial properties of the scene based on the images. Integrate and review information from all provided images to conclude the answer.",
32 | "Identify the spatial properties of the scene using the images. You must gather and analyze data from all provided images to ascertain the answer.",
33 | "Find the spatial properties of the scene by analyzing the images. Combine and evaluate information from all provided images to deduce the answer.",
34 | "Use the images to determine the spatial properties of the scene. You need to synthesize and interpret information from all provided images to get the answer.",
35 | "Analyze the provided images to identify the spatial properties of the scene. Collate and assess data from all provided images to derive the answer.",
36 | "Determine the spatial properties of the scene using the images. You must integrate and scrutinize information from all provided images to find the answer.",
37 | "Identify the spatial properties of the scene by examining the images. Merge and evaluate data from all provided images to reach the answer.",
38 | "Find the spatial properties of the scene using the images. You need to gather and interpret information from all provided images to obtain the answer.",
39 | "Analyze the images to determine the spatial properties of the scene. Combine and review data from all provided images to ascertain the answer.",
40 | "Using the images, identify the spatial properties of the scene. Synthesize and scrutinize information from all provided images to deduce the answer.",
41 | "Examine the images to find the spatial properties of the scene. You must integrate and assess information from all provided images to get the answer.",
42 | "Determine the spatial properties of the scene based on the images. Collate and interpret data from all provided images to conclude the answer.",
43 | "Identify the spatial properties of the scene using the images. You need to merge and evaluate information from all provided images to derive the answer.",
44 | "Find the spatial properties of the scene by analyzing the images. Gather and review data from all provided images to find the answer.",
45 | "Use the images to determine the spatial properties of the scene. You must synthesize and assess information from all provided images to reach the answer.",
46 | "Analyze the provided images to identify the spatial properties of the scene. Integrate and interpret data from all provided images to obtain the answer.",
47 | "Determine the spatial properties of the scene using the images. You need to combine and scrutinize information from all provided images to ascertain the answer.",
48 | "Identify the spatial properties of the scene by examining the images. Collate and review data from all provided images to deduce the answer.",
49 | "Find the spatial properties of the scene using the images. Merge and assess information from all provided images to get the answer.",
50 | "Analyze the images to determine the spatial properties of the scene. You must gather and interpret data from all provided images to conclude the answer.",
51 | "Using the images, identify the spatial properties of the scene. Combine and evaluate information from all provided images to derive the answer.",
52 | "Examine the images to find the spatial properties of the scene. You need to synthesize and review data from all provided images to find the answer.",
53 | "Determine the spatial properties of the scene based on the images. Integrate and scrutinize information from all provided images to reach the answer.",
54 | "Identify the spatial properties of the scene using the images. Collate and interpret data from all provided images to obtain the answer.",
55 | "Find the spatial properties of the scene by analyzing the images. You must merge and assess information from all provided images to ascertain the answer."
56 | ]
57 |
58 | QUESTION_TEMPLATES = [
59 | "What is the {dimension} (in millimeters) of the {object_category} itself commonly visible in these images?",
60 | "Calculate the {dimension} (in millimeters) of the {object_category} that is commonly visible in these images.",
61 | "Determine the {dimension} (in millimeters) of the {object_category} which is commonly visible in these images.",
62 | "Find the {dimension} (in millimeters) of the {object_category} that is commonly visible in these images.",
63 | "Estimate the {dimension} (in millimeters) of the {object_category} itself commonly visible in these images.",
64 | "Measure the {dimension} (in millimeters) of the {object_category} that is commonly visible in these images.",
65 | "Could you tell me the {dimension} (in millimeters) of the {object_category} which is commonly visible in these images?",
66 | "Please compute the {dimension} (in millimeters) of the {object_category} itself commonly visible in these images.",
67 | "What is the approximate {dimension} (in millimeters) of the {object_category} that is commonly visible in these images?",
68 | "Give the {dimension} (in millimeters) of the {object_category} which is commonly visible in these images.",
69 | "What is the {dimension} (in millimeters) of the {object_category} that is commonly visible across these images?",
70 | "Calculate the {dimension} (in millimeters) of the {object_category} itself commonly visible across these images.",
71 | "Determine the {dimension} (in millimeters) of the {object_category} which is commonly visible across these images.",
72 | "Find the {dimension} (in millimeters) of the {object_category} that is commonly visible across these images.",
73 | "Estimate the {dimension} (in millimeters) of the {object_category} itself commonly visible across these images.",
74 | "Measure the {dimension} (in millimeters) of the {object_category} that is commonly visible across these images.",
75 | "Could you tell me the {dimension} (in millimeters) of the {object_category} which is commonly visible across these images?",
76 | "Please compute the {dimension} (in millimeters) of the {object_category} itself commonly visible across these images.",
77 | "What is the approximate {dimension} (in millimeters) of the {object_category} that is commonly visible across these images?",
78 | "Give the {dimension} (in millimeters) of the {object_category} which is commonly visible across these images.",
79 | "What is the {dimension} (in millimeters) of the {object_category} itself that is commonly visible in these images?",
80 | "Calculate the {dimension} (in millimeters) of the {object_category} which is commonly visible in these images.",
81 | "Determine the {dimension} (in millimeters) of the {object_category} itself that is commonly visible in these images.",
82 | "Find the {dimension} (in millimeters) of the {object_category} which is commonly visible in these images.",
83 | "Estimate the {dimension} (in millimeters) of the {object_category} that is commonly visible in these images.",
84 | "Measure the {dimension} (in millimeters) of the {object_category} itself that is commonly visible in these images.",
85 | "Could you tell me the {dimension} (in millimeters) of the {object_category} that is commonly visible in these images?",
86 | "Please compute the {dimension} (in millimeters) of the {object_category} which is commonly visible in these images.",
87 | "What is the approximate {dimension} (in millimeters) of the {object_category} itself that is commonly visible in these images?",
88 | "Give the {dimension} (in millimeters) of the {object_category} that is commonly visible in these images."
89 | ]
90 |
91 | ANSWER_TEMPLATES = [
92 | "The {dimension} is approximately `{value_mm}` millimeters.",
93 | "It measures about `{value_mm}` millimeters in {dimension}.",
94 | "I estimate the {dimension} to be around `{value_mm}` millimeters.",
95 | "The {object_category}'s {dimension} is roughly `{value_mm}` millimeters.",
96 | "Based on the images, the {dimension} is near `{value_mm}` millimeters.",
97 | "It appears that the {dimension} is `{value_mm}` millimeters.",
98 | "From my estimation, the {dimension} is `{value_mm}` millimeters.",
99 | "The {dimension} seems to be around `{value_mm}` millimeters.",
100 | "Approximately, the {dimension} is `{value_mm}` millimeters.",
101 | "I would say the {dimension} is `{value_mm}` millimeters.",
102 | "The {dimension} is estimated to be `{value_mm}` millimeters.",
103 | "In my view, the {dimension} is about `{value_mm}` millimeters.",
104 | "The {dimension} is likely around `{value_mm}` millimeters.",
105 | "Judging by the images, the {dimension} is approximately `{value_mm}` millimeters.",
106 | "The {dimension} is calculated to be `{value_mm}` millimeters.",
107 | "It looks like the {dimension} is `{value_mm}` millimeters.",
108 | "The {dimension} is assessed to be `{value_mm}` millimeters.",
109 | "The {dimension} is gauged at `{value_mm}` millimeters.",
110 | "The {dimension} is reckoned to be `{value_mm}` millimeters.",
111 | "The {dimension} is figured to be `{value_mm}` millimeters.",
112 | "The {dimension} is computed to be `{value_mm}` millimeters.",
113 | "The {dimension} is deduced to be `{value_mm}` millimeters.",
114 | "The {dimension} is inferred to be `{value_mm}` millimeters.",
115 | "The {dimension} is surmised to be `{value_mm}` millimeters.",
116 | "The {dimension} is supposed to be `{value_mm}` millimeters.",
117 | "The {dimension} is thought to be `{value_mm}` millimeters.",
118 | "The {dimension} is understood to be `{value_mm}` millimeters.",
119 | "The {dimension} is viewed as `{value_mm}` millimeters.",
120 | "The {dimension} is approximated to be `{value_mm}` millimeters based on the data.",
121 | "After analyzing the images, the {dimension} is concluded to be `{value_mm}` millimeters."
122 | ]
123 |
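# Illustrative use of the templates above (values are made up): the sample builder
# below fills in {dimension}, {object_category} and {value_mm}, e.g.
#   QUESTION_TEMPLATES[0].format(dimension="height", object_category="chair")
#   ANSWER_TEMPLATES[0].format(dimension="height", value_mm=820)
# which yields a question/answer pair whose ground-truth value is in millimeters.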
124 | def convert_train_sample_to_eval_sample(train_sample):
125 | conversation = train_sample.pop("conversations")
126 | train_sample["text"] = conversation[0]["value"]
127 | return train_sample
128 |
129 | def build_lwh_qa_samples(scene_info_handler, dimension_info_path, dimension_name, split, output_dir, max_k=6, max_samples=-1):
130 | """
131 | Construct QA samples based on merged info files, and write to different jsonl files according to combination size K.
132 |
133 | Example file structure (info file, dict saved with pickle):
134 | {
135 | scene_id: {
136 | object_id: {
137 | "1": [ [img1], [img3], ... ],
138 | "2": [ [img2, img5], ... ],
139 | ...
140 | },
141 | ...
142 | },
143 | ...
144 | }
145 |
146 | For each combination sample, construct a training sample:
147 | {
148 | "id": "___",
149 | "image": [ "scene_id/img1.jpg", "scene_id/img2.jpg", ... ],
150 | "conversations": [
151 | {"from": "human", "value": "\n\n"},
152 | {"from": "gpt", "value": ""}
153 | ],
154 | "height_list": [scene_info_handler.image_height, ...], // repeated len(combo) times
155 | "width_list": [scene_info_handler.image_width, ...],
156 | "question_type": "{dimension}_estimation",
157 | "gt_value":
158 | }
159 | """
160 | print(f"Processing dimension: {dimension_name}, split: {split}")
161 | with open(dimension_info_path, "rb") as f:
162 | dim_info = pickle.load(f)
163 | os.makedirs(output_dir, exist_ok=True)
164 |
165 | samples_by_k = {k: [] for k in range(1, max_k+1)}
166 |
167 | for scene_id, obj_dict in tqdm(dim_info.items(), desc=f"Processing {dimension_name} info"):
168 | for object_id, k_dict in obj_dict.items():
169 | if dimension_name == "height":
170 | val_m = scene_info_handler.get_object_height(scene_id, object_id)
171 | elif dimension_name == "length":
172 | val_m = scene_info_handler.get_object_length(scene_id, object_id)
173 | elif dimension_name == "width":
174 | val_m = scene_info_handler.get_object_width(scene_id, object_id)
175 | else:
176 | val_m = 0.0
177 | val_mm = int(round(val_m * 1000))
178 | object_category = scene_info_handler.get_object_raw_category(scene_id, object_id)
179 | for k_str, combos in k_dict.items():
180 | try:
181 | k_val = int(k_str)
182 |                 except (TypeError, ValueError):
183 | continue
184 | if k_val < 1 or k_val > max_k:
185 | continue
186 | for combo_idx, combo in enumerate(combos):
187 | if not combo:
188 | continue
189 | combo = list(combo)
190 | random.shuffle(combo)
191 |                     prefix_lines = [f"Image-{i}: <image>" for i in range(1, len(combo)+1)]
192 | prefix = "\n".join(prefix_lines)
193 | task_line = random.choice(TASK_DESCRIPTION)
194 | q_template = random.choice(QUESTION_TEMPLATES)
195 | question = q_template.format(dimension=dimension_name, object_category=object_category)
196 | full_question = f"{prefix}\n{task_line}\n{question}"
197 | a_template = random.choice(ANSWER_TEMPLATES)
198 | answer = a_template.format(dimension=dimension_name, value_mm=val_mm, object_category=object_category)
199 | conversation = [
200 | {"from": "human", "value": full_question},
201 | {"from": "gpt", "value": answer}
202 | ]
203 | sample = {
204 | "id": f"{scene_id}_{object_id}_{k_val}_{combo_idx}",
205 | "image": [f"{scene_id}/{img}.jpg" for img in combo],
206 | "conversations": conversation,
207 | "height_list": [scene_info_handler.image_height] * len(combo),
208 | "width_list": [scene_info_handler.image_width] * len(combo),
209 | "question_type": f"object_perception_{dimension_name}_estimation",
210 | "gt_value": val_mm
211 | }
212 | samples_by_k[k_val].append(sample)
213 |
214 | for k in range(1, max_k+1):
215 | # only consider those have at least one sample
216 | if len(samples_by_k[k]) == 0:
217 | continue
218 | if max_samples > 0 and len(samples_by_k[k]) > max_samples:
219 | samples_by_k[k] = random.sample(samples_by_k[k], max_samples)
220 | fname = f"object_perception_{dimension_name}_k{k}_{split}_{max_samples}.jsonl"
221 | fpath = os.path.join(output_dir, fname)
222 | with open(fpath, "w", encoding="utf-8") as f:
223 | for sample in samples_by_k[k]:
224 | f.write(json.dumps(sample) + "\n")
225 | print(f"Written K={k} {len(samples_by_k[k])} samples to {fpath}")
226 |
227 | print(f"Finished building QA samples for {dimension_name}.")
228 |
229 | def build_train_and_val_datasets():
230 | scene_info_path = "data/scannet/scannet_instance_data/scenes_train_val_info_i_D5.pkl"
231 |
232 | train_height_info = "training_data/object_perception/merged_train_object_coverage_height.pkl"
233 | train_length_info = "training_data/object_perception/merged_train_object_coverage_length.pkl"
234 | train_width_info = "training_data/object_perception/merged_train_object_coverage_width.pkl"
235 |
236 | val_height_info = "evaluation_data/object_perception/merged_val_object_coverage_height.pkl"
237 | val_length_info = "evaluation_data/object_perception/merged_val_object_coverage_length.pkl"
238 | val_width_info = "evaluation_data/object_perception/merged_val_object_coverage_width.pkl"
239 |
240 | train_output_dir = "training_data/object_perception"
241 | val_output_dir = "evaluation_data/object_perception"
242 | os.makedirs(train_output_dir, exist_ok=True)
243 | os.makedirs(val_output_dir, exist_ok=True)
244 |
245 | scene_info_handler = SceneInfoHandler(scene_info_path)
246 |
247 | print("\nBuilding TRAIN samples ...")
248 | build_lwh_qa_samples(scene_info_handler, train_height_info, "height", "train", train_output_dir, max_k=6, max_samples=max_train_samples) # will generate K=1, 281335 samples, K=2 457409 samples
249 | build_lwh_qa_samples(scene_info_handler, train_length_info, "length", "train", train_output_dir, max_k=6, max_samples=max_train_samples) # will generate K=1, 274175 samples, K=2 667299 samples
250 | build_lwh_qa_samples(scene_info_handler, train_width_info, "width", "train", train_output_dir, max_k=6, max_samples=max_train_samples) # will generate K=1 256017 samples, K=2 425229 samples
251 |
252 | temp_val_dir = os.path.join(val_output_dir, "temp")
253 | os.makedirs(temp_val_dir, exist_ok=True)
254 | print("\nBuilding VAL samples (train format) ...")
255 | build_lwh_qa_samples(scene_info_handler, val_height_info, "height", "val", temp_val_dir, max_k=6, max_samples=val_max_samples)
256 | build_lwh_qa_samples(scene_info_handler, val_length_info, "length", "val", temp_val_dir, max_k=6, max_samples=val_max_samples)
257 | build_lwh_qa_samples(scene_info_handler, val_width_info, "width", "val", temp_val_dir, max_k=6, max_samples=val_max_samples)
258 | for fname in os.listdir(temp_val_dir):
259 | temp_path = os.path.join(temp_val_dir, fname)
260 | output_path = os.path.join(val_output_dir, fname)
261 | with open(temp_path, "r", encoding="utf-8") as fin, open(output_path, "w", encoding="utf-8") as fout:
262 | for line in fin:
263 | sample = json.loads(line)
264 | sample = convert_train_sample_to_eval_sample(sample)
265 | fout.write(json.dumps(sample) + "\n")
266 |
267 | import shutil; shutil.rmtree(temp_val_dir)
268 | print("Finished building both TRAIN and VAL datasets.")
269 |
270 | def main():
271 | build_train_and_val_datasets()
272 |
273 | if __name__ == "__main__":
274 | main()
--------------------------------------------------------------------------------
/spatial_engine/utils/scannet_utils/README.md:
--------------------------------------------------------------------------------
1 | #### Notes
2 | 1. This doc assumes the current directory is the root directory of the project (Multi-SpatialMLLM).
3 | 2. Create a directory `data/scannet` to store the downloaded data.
4 | ```bash
5 | cd Multi-SpatialMLLM
6 | mkdir -p data/scannet
7 | ```
8 |
9 | #### Download ScanNet data
10 | 1. Follow the official website of [ScanNet](https://github.com/ScanNet/ScanNet) to get the download file `scannet_download.py`.
11 | 2. Specify the data types we need to download:
12 | ```python
13 | FILETYPES = [
14 | ".sens",
15 | ".aggregation.json",
16 | "_vh_clean.ply",
17 | "_vh_clean_2.0.010000.segs.json",
18 | "_vh_clean_2.ply",
19 | "_vh_clean.segs.json",
20 | "_vh_clean.aggregation.json",
21 | "_vh_clean_2.labels.ply",
22 | "_2d-instance.zip",
23 | "_2d-instance-filt.zip",
24 | "_2d-label.zip",
25 | "_2d-label-filt.zip",
26 | ]
27 | ```
28 | 3. Run the script to download the data:
29 | ```bash
30 | python spatial_engine/utils/scannet_utils/scannet_download.py -o data/scannet
31 | ```
32 | 4. After downloading, you should have a folder structure like:
33 | ```
34 | Multi-SpatialMLLM
35 | ├── data
36 | │ └── scannet
37 | │ ├── scans
38 | │ │ ├── scene0000_00
39 | │ │ ├── ...
40 | │ └── ...
41 | ```
42 |
43 | #### Process ScanNet data
44 | 1. Load point clouds, object point clouds, bounding boxes, etc.
45 | ```bash
46 | python spatial_engine/utils/scannet_utils/batch_load_scannet_data.py
47 | ```
48 | This will generate a file called `scenes_train_val_info.pkl`, which contains both the train and val splits. For convenience, we can split this info file into separate train and val files (after updating it with image info if needed).
49 |
50 | ```python
51 | import mmengine
52 | ori_train_info = mmengine.load("data/scannet/scannet_instance_data/scenes_train_val_info.pkl")
53 |
54 | # * load train scene id, read a txt, one line with one id, mmengine does not support txt format
55 | train_scene_ids = mmengine.list_from_file("data/scannet/meta_data/scannetv2_train.txt")
56 | train_scene_ids.sort()
57 | val_scene_ids = mmengine.list_from_file("data/scannet/meta_data/scannetv2_val.txt")
58 | val_scene_ids.sort()
59 |
60 | # * ori_train_info is a dict, key is scene_id
61 | # * according to the above train/val ids, split the info file to two separate files and save it.
62 | # * remember to check train + val = ori_train_info.keys()
63 | train_info = {k: ori_train_info[k] for k in train_scene_ids}
64 | val_info = {k: ori_train_info[k] for k in val_scene_ids}
65 |
66 | # * check
67 | assert len(train_info) + len(val_info) == len(ori_train_info)
68 |
69 | # * save
70 | mmengine.dump(train_info, "data/scannet/scannet_instance_data/scenes_train_info.pkl")
71 | mmengine.dump(val_info, "data/scannet/scannet_instance_data/scenes_val_info.pkl")
72 | ```
73 |
74 | 2. Load posed images.
75 |
76 | By default, we extract every image. If you want to export only every nth image, set `frame_skip=n` as below.
77 | ```bash
78 | python spatial_engine/utils/scannet_utils/extract_posed_images.py --frame_skip 1
79 | ```
80 | Note that no matter what `frame_skip` is, the extracted images are still saved with IDs like `00000.jpg`, `00001.jpg`, etc. The script also exports depth images; read them with imageio and divide by 1000 to get the real depth value in meters.
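For example, a minimal sketch for reading one exported depth frame (the path and frame id are illustrative):
```python
import imageio.v2 as imageio

depth_raw = imageio.imread("data/scannet/posed_images/scene0000_00/00000.png")  # uint16, millimeters
depth_m = depth_raw.astype("float32") / 1000.0  # divide by 1000 to get meters
```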
81 |
82 | 3. Update the info file with posed-image information; this takes about 40 minutes if all images were extracted. In this project, we only use every 5th image, so set `frame_skip=5` in this script to skip images. Note that you still need to set `frame_skip=1` in `extract_posed_images.py` so that the image ids stay consistent with our setting.
83 | ```bash
84 | python spatial_engine/utils/scannet_utils/update_info_file_with_images.py
85 | ```
86 |
87 | The images_info contains information like:
88 | ```python
89 | # Update the image data dictionary with this image's information
90 | image_data[image_id] = {
91 | "image_path": image_path,
92 | "depth_image_path": depth_image_path,
93 | "extrinsic_matrix": extrinsic_matrix
94 | }
95 |
96 | # Update the scene_info dictionary for the current scene_id
97 | scene_info[scene_id].update({
98 | "num_posed_images": num_posed_images,
99 | "images": image_data,
100 | "intrinsic_matrix": intrinsic_matrix
101 | })
102 | ```
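As a quick sanity check, the updated info file can be inspected like this (the file name depends on your split; scene and image ids are illustrative):
```python
import mmengine

info = mmengine.load("data/scannet/scannet_instance_data/scenes_train_info.pkl")
scene = info["scene0000_00"]
print(scene["num_posed_images"])
print(scene["intrinsic_matrix"])
print(scene["images"]["00000"]["extrinsic_matrix"])  # per-image pose; image id format is illustrative
```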
103 |
104 | #### Final Results
105 | After processing, you should have a folder structure like:
106 | ```
107 | Multi-SpatialMLLM
108 | ├── data
109 | │ └── scannet
110 | │ ├── meta_data
111 | │ ├── scans
112 | │ │ ├── scene0000_00
113 | │ │ ├── ...
114 | │ ├── posed_images
115 | │ │ ├── scene0000_00
116 | │ │ ├── ...
117 | │ └── scannet_instance_data
118 | │ ├── scene0000_00
119 | │ ├── ...
120 | │ ├── scenes_train_info.pkl
121 | │ └── scenes_val_info.pkl
122 | ```
123 | The `scannet_instance_data` folder contains the point clouds of the whole scene and of each instance, in both axis-aligned and non-axis-aligned world coordinates.
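The per-scene arrays can be loaded directly with numpy, for example (paths and ids are illustrative):
```python
import numpy as np

scene_dir = "data/scannet/scannet_instance_data/scene0000_00"
points = np.load(f"{scene_dir}/aligned_points.npy")               # whole-scene points (xyz + color)
obj_points = np.load(f"{scene_dir}/object_0_aligned_points.npy")  # points of object 0
print(points.shape, obj_points.shape)
```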
124 |
125 |
--------------------------------------------------------------------------------
/spatial_engine/utils/scannet_utils/batch_load_scannet_data.py:
--------------------------------------------------------------------------------
1 | # Modified from
2 | # https://github.com/facebookresearch/votenet/blob/master/scannet/batch_load_scannet_data.py
3 | # Copyright (c) Meta Platforms, Inc. and affiliates.
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """Batch mode in loading Scannet scenes with vertices and ground truth labels
8 | for semantic and instance segmentations.
9 |
10 | Usage example: python ./batch_load_scannet_data.py
11 | """
12 |
13 | import argparse
14 | import os
15 | import pickle
16 | from multiprocessing import Pool
17 | from os import path as osp
18 |
19 | import numpy as np
20 | import scannet_utils
21 |
22 | DONOTCARE_CLASS_IDS = np.array([])
23 | # OBJ_CLASS_IDS = np.array(
24 | # [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39])
25 | # OBJ_CLASS_IDS = np.array([])
26 |
27 |
28 | def export(mesh_file, agg_file, seg_file, meta_file, label_map_file, test_mode=False):
29 | """Export original files to vert, ins_label, sem_label and bbox file.
30 |
31 | Args:
32 | mesh_file (str): Path of the mesh_file.
33 | agg_file (str): Path of the agg_file.
34 | seg_file (str): Path of the seg_file.
35 | meta_file (str): Path of the meta_file.
36 | label_map_file (str): Path of the label_map_file.
37 | test_mode (bool): Whether is generating test data without labels.
38 | Default: False.
39 |
40 |     It returns a tuple, which contains the following things:
41 | np.ndarray: Vertices of points data.
42 | np.ndarray: Indexes of label.
43 | np.ndarray: Indexes of instance.
44 | np.ndarray: Instance bboxes.
45 | dict: Map from object_id to label_id.
46 | """
47 |
48 | label_map = scannet_utils.read_label_mapping(
49 | label_map_file, label_from="raw_category", label_to="nyu40id"
50 | )
51 | mesh_vertices = scannet_utils.read_mesh_vertices_rgb(mesh_file)
52 |
53 | # Load scene axis alignment matrix
54 | lines = open(meta_file).readlines()
55 | # test set data doesn't have align_matrix
56 | axis_align_matrix = np.eye(4)
57 | for line in lines:
58 | if "axisAlignment" in line:
59 | axis_align_matrix = [
60 | float(x) for x in line.rstrip().strip("axisAlignment = ").split(" ")
61 | ]
62 | break
63 | axis_align_matrix = np.array(axis_align_matrix).reshape((4, 4))
64 |
65 | # perform global alignment of mesh vertices
66 | pts = np.ones((mesh_vertices.shape[0], 4))
67 | pts[:, 0:3] = mesh_vertices[:, 0:3]
68 | pts = np.dot(pts, axis_align_matrix.transpose()) # Nx4
69 | aligned_mesh_vertices = np.concatenate([pts[:, 0:3], mesh_vertices[:, 3:]], axis=1)
70 |
71 | # Load semantic and instance labels
72 | if not test_mode:
73 | object_id_to_segs, label_to_segs = scannet_utils.read_aggregation(
74 | agg_file
75 |         ) # * returns dicts mapping id (int) or label (str) to lists of seg ids; object ids are 1-indexed
76 | seg_to_verts, num_verts = scannet_utils.read_segmentation(seg_file)
77 | label_ids = np.zeros(shape=(num_verts), dtype=np.uint32)
78 | raw_categories = np.array([None] * num_verts) # Array to store raw categories
79 |
80 | object_id_to_label_id = {}
81 | object_id_to_raw_category = {}
82 | for raw_category, segs in label_to_segs.items():
83 | label_id = label_map[raw_category]
84 | for seg in segs:
85 | verts = seg_to_verts[seg]
86 | label_ids[verts] = label_id
87 | raw_categories[verts] = raw_category # Assign raw category
88 |
89 | instance_ids = np.zeros(shape=(num_verts), dtype=np.uint32) # 0: unannotated
90 | for object_id, segs in object_id_to_segs.items():
91 | for seg in segs:
92 | verts = seg_to_verts[seg]
93 | instance_ids[verts] = object_id
94 | if object_id not in object_id_to_label_id:
95 | object_id_to_label_id[object_id] = label_ids[verts][
96 | 0
97 | ] # * obj_id: int
98 | if object_id not in object_id_to_raw_category:
99 | object_id_to_raw_category[object_id] = raw_categories[verts][
100 | 0
101 | ] # * obj_id: str, note, the obj_id is 1-indexed
102 | unaligned_bboxes, unaligned_obj_point_clouds = scannet_utils.extract_bbox(
103 | mesh_vertices, object_id_to_segs, object_id_to_label_id, instance_ids
104 | )
105 | aligned_bboxes, aligned_obj_point_clouds = scannet_utils.extract_bbox(
106 | aligned_mesh_vertices,
107 | object_id_to_segs,
108 | object_id_to_label_id,
109 | instance_ids,
110 | )
111 | else:
112 | label_ids = None
113 | raw_categories = None
114 | instance_ids = None
115 | unaligned_bboxes = None
116 | aligned_bboxes = None
117 | object_id_to_label_id = None
118 | aligned_obj_point_clouds = None
119 | unaligned_obj_point_clouds = None
120 | object_id_to_raw_category = None
121 |
122 | return (
123 | mesh_vertices,
124 | aligned_mesh_vertices,
125 | label_ids,
126 | raw_categories,
127 | instance_ids,
128 | unaligned_bboxes,
129 | aligned_bboxes,
130 | unaligned_obj_point_clouds,
131 | aligned_obj_point_clouds,
132 | object_id_to_raw_category,
133 | object_id_to_label_id,
134 | axis_align_matrix,
135 | )
136 |
137 |
138 | def export_one_scan(
139 | scan_name,
140 | output_filename_prefix,
141 | max_num_point,
142 | label_map_file,
143 | scannet_dir,
144 | test_mode=False,
145 | ):
146 | if not osp.exists(output_filename_prefix):
147 | os.makedirs(output_filename_prefix)
148 |
149 | mesh_file = osp.join(scannet_dir, scan_name, scan_name + "_vh_clean_2.ply")
150 | agg_file = osp.join(scannet_dir, scan_name, scan_name + ".aggregation.json")
151 | seg_file = osp.join(
152 | scannet_dir, scan_name, scan_name + "_vh_clean_2.0.010000.segs.json"
153 | )
154 | # includes axisAlignment info for the train set scans.
155 | meta_file = osp.join(scannet_dir, scan_name, f"{scan_name}.txt")
156 | (
157 | mesh_vertices,
158 | aligned_mesh_vertices,
159 | semantic_labels,
160 | raw_categories,
161 | instance_labels,
162 | unaligned_bboxes,
163 | aligned_bboxes,
164 | unaligned_obj_point_clouds,
165 | aligned_obj_point_clouds,
166 | object_id_to_raw_category,
167 | object_id_to_label_id,
168 | axis_align_matrix,
169 | ) = export(mesh_file, agg_file, seg_file, meta_file, label_map_file, test_mode)
170 |
171 | if not test_mode:
172 | # mask = np.logical_not(np.in1d(semantic_labels, DONOTCARE_CLASS_IDS))
173 | # mesh_vertices = mesh_vertices[mask, :]
174 | # semantic_labels = semantic_labels[mask]
175 | # instance_labels = instance_labels[mask]
176 | # raw_categories = raw_categories[mask]
177 |
178 | num_instances = len(np.unique(instance_labels))
179 | print(f"Num of instances: {num_instances - 1}")
180 |
181 | # bbox_mask = np.in1d(unaligned_bboxes[:, -1], OBJ_CLASS_IDS) # * keep all instances
182 | # unaligned_bboxes = unaligned_bboxes[bbox_mask, :]
183 | # bbox_mask = np.in1d(aligned_bboxes[:, -1], OBJ_CLASS_IDS)
184 | # aligned_bboxes = aligned_bboxes[bbox_mask, :]
185 | assert unaligned_bboxes.shape[0] == aligned_bboxes.shape[0]
186 | print(f"Num of care instances: {unaligned_bboxes.shape[0]}")
187 |
188 | if max_num_point is not None:
189 | max_num_point = int(max_num_point)
190 | N = mesh_vertices.shape[0]
191 | if N > max_num_point:
192 | choices = np.random.choice(N, max_num_point, replace=False)
193 | mesh_vertices = mesh_vertices[choices, :]
194 | if not test_mode:
195 | semantic_labels = semantic_labels[choices]
196 | instance_labels = instance_labels[choices]
197 | raw_categories = raw_categories[choices]
198 |
199 | # Save points, semantic_labels, instance_labels as .npy files
200 | np.save(f"{output_filename_prefix}/unaligned_points.npy", mesh_vertices)
201 | np.save(f"{output_filename_prefix}/aligned_points.npy", aligned_mesh_vertices)
202 | scene_info = {} # Dictionary to hold scene information
203 |
204 | if not test_mode:
205 | np.save(f"{output_filename_prefix}/semantic_mask.npy", semantic_labels)
206 | np.save(f"{output_filename_prefix}/instance_mask.npy", instance_labels)
207 | np.save(f"{output_filename_prefix}/raw_category_mask.npy", raw_categories)
208 |
209 | # * assert these four npy have the same length
210 | assert (
211 | len(semantic_labels)
212 | == len(instance_labels)
213 | == len(raw_categories)
214 | == len(mesh_vertices)
215 | ), "Lengths of semantic_labels, instance_labels, raw_categories, and mesh_vertices are not equal."
216 |
217 | # Save bounding boxes and raw category names in a dict
218 | for obj_id, (aligned_bbox, unaligned_bbox) in enumerate(
219 | zip(aligned_bboxes, unaligned_bboxes)
220 | ):
221 | raw_category_name = object_id_to_raw_category.get(
222 | obj_id + 1, "None"
223 | ) # * object_id_to_raw_category is 1 indexed
224 | if raw_category_name == "None":
225 | print(
226 | f"Something wrong for the raw category name of object {obj_id} in scan {scan_name}."
227 | )
228 | exit(0)
229 | scene_info[obj_id] = {
230 | "aligned_bbox": aligned_bbox,
231 | "unaligned_bbox": unaligned_bbox,
232 | "raw_category": raw_category_name,
233 | }
234 |
235 | # * save aligned and unaligned points
236 | # * first check if the two types of points have the same shape
237 |
238 | np.save(
239 | f"{output_filename_prefix}/object_{obj_id}_aligned_points.npy",
240 | aligned_obj_point_clouds[obj_id],
241 | )
242 | np.save(
243 | f"{output_filename_prefix}/object_{obj_id}_unaligned_points.npy",
244 | unaligned_obj_point_clouds[obj_id],
245 | )
246 |
247 | scene_info["axis_align_matrix"] = axis_align_matrix
248 | # * store the object number
249 | scene_info["num_objects"] = len(aligned_bboxes)
250 |
251 | return {scan_name: scene_info}
252 |
253 |
254 | def worker(args):
255 | (
256 | scan_name,
257 | output_filename_prefix,
258 | max_num_point,
259 | label_map_file,
260 | scannet_dir,
261 | test_mode,
262 | ) = args
263 | print("-" * 20 + f"begin for {scan_name}.")
264 | return export_one_scan(
265 | scan_name,
266 | output_filename_prefix,
267 | max_num_point,
268 | label_map_file,
269 | scannet_dir,
270 | test_mode,
271 | )
272 |
273 |
274 | def batch_export(
275 | max_num_point,
276 | output_folder,
277 | scan_names_file,
278 | label_map_file,
279 | scannet_dir,
280 | test_mode=False,
281 | num_workers=20,
282 | ):
283 | if test_mode and not os.path.exists(scannet_dir):
284 | return
285 | if not os.path.exists(output_folder):
286 | os.makedirs(output_folder)
287 |
288 | scan_names = [line.rstrip() for line in open(scan_names_file)]
289 | # * sort scan_names
290 | scan_names.sort()
291 | args = [
292 | (
293 | scan_name,
294 | osp.join(output_folder, scan_name),
295 | max_num_point,
296 | label_map_file,
297 | scannet_dir,
298 | test_mode,
299 | )
300 | for scan_name in scan_names
301 | ]
302 |
303 | all_scene_info = {}
304 | with Pool(num_workers) as p:
305 | results = p.map(worker, args)
306 | for result in results:
307 | all_scene_info.update(result)
308 |
309 | # Save the combined scene information
310 | if test_mode:
311 | file_name = "scenes_test_info.pkl"
312 | else:
313 | file_name = "scenes_train_val_info.pkl"
314 | with open(osp.join(output_folder, file_name), "wb") as f:
315 | pickle.dump(all_scene_info, f)
316 |
317 |
318 | def main():
319 | parser = argparse.ArgumentParser()
320 | parser.add_argument(
321 | "--max_num_point", default=None, help="The maximum number of the points."
322 | )
323 | parser.add_argument(
324 | "--output_folder",
325 | default="data/scannet/scannet_instance_data",
326 | help="output folder of the result.",
327 | )
328 | parser.add_argument(
329 | "--train_scannet_dir", default="scans", help="scannet data directory."
330 | )
331 | parser.add_argument(
332 | "--test_scannet_dir", default="scans_test", help="scannet data directory."
333 | )
334 | parser.add_argument(
335 | "--label_map_file",
336 | default="data/scannet/meta_data/scannetv2-labels.combined.tsv",
337 | help="The path of label map file.",
338 | )
339 | parser.add_argument(
340 | "--train_scan_names_file",
341 | default="data/scannet/meta_data/scannet_train.txt",
342 | help="The path of the file that stores the scan names.",
343 | )
344 | parser.add_argument(
345 | "--test_scan_names_file",
346 | default="data/scannet/meta_data/scannetv2_test.txt",
347 | help="The path of the file that stores the scan names.",
348 | )
349 | args = parser.parse_args()
350 | batch_export(
351 | args.max_num_point,
352 | args.output_folder,
353 | args.train_scan_names_file,
354 | args.label_map_file,
355 | args.train_scannet_dir,
356 | test_mode=False,
357 | )
358 | # * change output folder for test
359 | args.output_folder = args.output_folder.replace("scannet", "scannet_test")
360 | batch_export(
361 | args.max_num_point,
362 | args.output_folder,
363 | args.test_scan_names_file,
364 | args.label_map_file,
365 | args.test_scannet_dir,
366 | test_mode=True,
367 | )
368 |
369 |
370 | if __name__ == "__main__":
371 | main()
372 |
--------------------------------------------------------------------------------
/spatial_engine/utils/scannet_utils/extract_posed_images.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | import struct
9 | import time
10 | import zlib
11 | from argparse import ArgumentParser
12 | from functools import partial
13 |
14 | import imageio.v2 as imageio # * to suppress warning
15 | import mmengine
16 | import numpy as np
17 |
18 | COMPRESSION_TYPE_COLOR = {-1: "unknown", 0: "raw", 1: "png", 2: "jpeg"}
19 |
20 | COMPRESSION_TYPE_DEPTH = {
21 | -1: "unknown",
22 | 0: "raw_ushort",
23 | 1: "zlib_ushort",
24 | 2: "occi_ushort",
25 | }
26 |
27 |
28 | class RGBDFrame:
29 | """Class for single ScanNet RGB-D image processing."""
30 |
31 | def load(self, file_handle):
32 | self.camera_to_world = np.asarray(
33 | struct.unpack("f" * 16, file_handle.read(16 * 4)), dtype=np.float32
34 | ).reshape(4, 4)
35 | self.timestamp_color = struct.unpack("Q", file_handle.read(8))[0]
36 | self.timestamp_depth = struct.unpack("Q", file_handle.read(8))[0]
37 | self.color_size_bytes = struct.unpack("Q", file_handle.read(8))[0]
38 | self.depth_size_bytes = struct.unpack("Q", file_handle.read(8))[0]
39 | self.color_data = b"".join(
40 | struct.unpack(
41 | "c" * self.color_size_bytes, file_handle.read(self.color_size_bytes)
42 | )
43 | )
44 | self.depth_data = b"".join(
45 | struct.unpack(
46 | "c" * self.depth_size_bytes, file_handle.read(self.depth_size_bytes)
47 | )
48 | )
49 |
50 | def decompress_depth(self, compression_type):
51 | assert compression_type == "zlib_ushort"
52 | return zlib.decompress(self.depth_data)
53 |
54 | def decompress_color(self, compression_type):
55 | assert compression_type == "jpeg"
56 | return imageio.imread(self.color_data)
57 |
58 |
59 | class SensorData:
60 | """Class for single ScanNet scene processing.
61 |
62 | Single scene file contains multiple RGB-D images.
63 | """
64 |
65 | def __init__(self, filename, frame_skip):
66 | self.version = 4
67 | self.load(filename, frame_skip)
68 |
69 | def load(self, filename, frame_skip):
70 | with open(filename, "rb") as f:
71 | version = struct.unpack("I", f.read(4))[0]
72 | assert self.version == version
73 | strlen = struct.unpack("Q", f.read(8))[0]
74 | self.sensor_name = b"".join(struct.unpack("c" * strlen, f.read(strlen)))
75 | self.intrinsic_color = np.asarray(
76 | struct.unpack("f" * 16, f.read(16 * 4)), dtype=np.float32
77 | ).reshape(4, 4)
78 | self.extrinsic_color = np.asarray(
79 | struct.unpack("f" * 16, f.read(16 * 4)), dtype=np.float32
80 | ).reshape(4, 4)
81 | self.intrinsic_depth = np.asarray(
82 | struct.unpack("f" * 16, f.read(16 * 4)), dtype=np.float32
83 | ).reshape(4, 4)
84 | self.extrinsic_depth = np.asarray(
85 | struct.unpack("f" * 16, f.read(16 * 4)), dtype=np.float32
86 | ).reshape(4, 4)
87 | self.color_compression_type = COMPRESSION_TYPE_COLOR[
88 | struct.unpack("i", f.read(4))[0]
89 | ]
90 | self.depth_compression_type = COMPRESSION_TYPE_DEPTH[
91 | struct.unpack("i", f.read(4))[0]
92 | ]
93 | self.color_width = struct.unpack("I", f.read(4))[0]
94 | self.color_height = struct.unpack("I", f.read(4))[0]
95 | self.depth_width = struct.unpack("I", f.read(4))[0]
96 | self.depth_height = struct.unpack("I", f.read(4))[0]
97 | self.depth_shift = struct.unpack("f", f.read(4))[0]
98 | num_frames = struct.unpack("Q", f.read(8))[0]
99 | print(f"Number of total frames: {num_frames}")
100 | self.frames = []
101 |
102 | # * use frame_skip to get index
103 | index = list(range(0, num_frames, frame_skip))
104 | for i in range(num_frames):
105 | frame = RGBDFrame()
106 |             frame.load(f)  # every frame must be read sequentially to advance the file handle
107 | if i in index:
108 | self.frames.append(frame)
109 |
110 | assert len(index) == len(self.frames), "Number of frames mismatch."
111 | print(f"Exported {len(index)} frames. Frame skip is {frame_skip}.")
112 |
113 | def export_depth_images(self, output_path):
114 | if not os.path.exists(output_path):
115 | os.makedirs(output_path)
116 | for f in range(len(self.frames)):
117 | depth_data = self.frames[f].decompress_depth(self.depth_compression_type)
118 |             depth = np.frombuffer(depth_data, dtype=np.uint16).reshape(
119 | self.depth_height, self.depth_width
120 | )
121 | imageio.imwrite(
122 | os.path.join(output_path, self.index_to_str(f) + ".png"), depth
123 | )
124 |
125 | def export_color_images(self, output_path):
126 | if not os.path.exists(output_path):
127 | os.makedirs(output_path)
128 | for f in range(len(self.frames)):
129 | color = self.frames[f].decompress_color(self.color_compression_type)
130 | imageio.imwrite(
131 | os.path.join(output_path, self.index_to_str(f) + ".jpg"), color
132 | )
133 |
134 | @staticmethod
135 | def index_to_str(index):
136 | return str(index).zfill(5)
137 |
138 | @staticmethod
139 | def save_mat_to_file(matrix, filename):
140 | with open(filename, "w") as f:
141 | for line in matrix:
142 | np.savetxt(f, line[np.newaxis], fmt="%f")
143 |
144 | def export_poses(self, output_path):
145 | if not os.path.exists(output_path):
146 | os.makedirs(output_path)
147 | for f in range(len(self.frames)):
148 | self.save_mat_to_file(
149 | self.frames[f].camera_to_world,
150 | os.path.join(output_path, self.index_to_str(f) + ".txt"),
151 | )
152 |
153 | def export_intrinsics(self, output_path):
154 | if not os.path.exists(output_path):
155 | os.makedirs(output_path)
156 | self.save_mat_to_file(
157 | self.intrinsic_color, os.path.join(output_path, "intrinsic.txt")
158 | )
159 |
160 |
161 | def process_scene(path, frame_skip, idx):
162 | """Process single ScanNet scene.
163 |
164 | Extract RGB images, poses and camera intrinsics.
165 | """
166 | print(f"Processing {idx}.")
167 | t1 = time.time()
168 | output_path = os.path.join("posed_images", idx)
169 | if mmengine.exists(output_path):
170 | print(f"{output_path} already exists. Skip.")
171 | return
172 | data = SensorData(os.path.join(path, idx, f"{idx}.sens"), frame_skip)
173 | data.export_color_images(output_path)
174 | data.export_intrinsics(output_path)
175 | data.export_poses(output_path)
176 | data.export_depth_images(output_path)
177 | print(f"Finish processing {idx}. Using {time.time() - t1}s.")
178 |
179 |
180 | def process_directory(path, frame_skip, nproc):
181 | print(f"processing {path}")
182 | scan_ids = os.listdir(path)
183 | # debug
184 | mmengine.track_parallel_progress(
185 | func=partial(process_scene, path, frame_skip),
186 | tasks=scan_ids,
187 | nproc=nproc,
188 | )
189 |
190 |
191 | if __name__ == "__main__":
192 | parser = ArgumentParser()
193 | parser.add_argument(
194 | "--frame_skip", type=int, default=1, help="export every nth frame"
195 | ) # * use this or --max-images-per-scene
196 | parser.add_argument("--nproc", type=int, default=20)
197 | args = parser.parse_args()
198 |
199 | # process train and val scenes
200 | if os.path.exists("scans"):
201 | process_directory(
202 | "scans", args.frame_skip, args.nproc
203 | )
204 | # process test scenes
205 | if os.path.exists("scans_test"):
206 | process_directory(
207 | "scans_test", args.frame_skip, args.nproc
208 | )
209 |
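
For reference, a minimal sketch of driving SensorData directly on a single scene instead of going through process_directory; the scene id and paths are hypothetical:

    # hypothetical scene id; point this at a scene that actually exists under ./scans
    scene_id = "scene0000_00"
    data = SensorData(f"scans/{scene_id}/{scene_id}.sens", frame_skip=10)
    out_dir = f"posed_images/{scene_id}"
    data.export_color_images(out_dir)
    data.export_depth_images(out_dir)
    data.export_poses(out_dir)
    data.export_intrinsics(out_dir)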
--------------------------------------------------------------------------------
/spatial_engine/utils/scannet_utils/handler/ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import List
8 |
9 | import cv2
10 | import numpy as np
11 | import open3d as o3d
12 |
13 |
14 | def filter_images_laplacian(image_paths, threshold=100):
15 | res = []
16 | for i in range(len(image_paths)):
17 | if calculate_image_sharpness(image_paths[i]) >= threshold:
18 | res.append(image_paths[i])
19 | return res
20 |
21 |
22 | def calculate_image_sharpness(image_path):
23 | """
24 |     Calculate the sharpness of an image as the variance of its Laplacian.
25 | """
26 | image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
27 |
28 | if image is None:
29 | raise ValueError(f"Unable to read image from: {image_path}")
30 |
31 |     # calculate the Laplacian; its variance serves as the sharpness score
32 | laplacian = cv2.Laplacian(image, cv2.CV_64F)
33 | sharpness_score = laplacian.var()
34 | return sharpness_score
35 |
36 |
37 | def convert_to_corners(bounding_boxes: List[np.ndarray]) -> List[np.ndarray]:
38 | """
39 | Convert bounding boxes to eight corners
40 | Args:
41 | bounding_boxes: List of bounding boxes with format [cx, cy, cz, dx, dy, dz, ...]
42 | Returns:
43 | List of eight corners for each bounding box with format [[x1, y1, z1], [x2, y2, z2], ...]
44 | """
45 | corners = []
46 | for bbox in bounding_boxes:
47 | corners.append(
48 | np.array(
49 | [
50 | [
51 | bbox[0] - bbox[3] / 2,
52 | bbox[1] - bbox[4] / 2,
53 | bbox[2] - bbox[5] / 2,
54 | ],
55 | [
56 | bbox[0] + bbox[3] / 2,
57 | bbox[1] - bbox[4] / 2,
58 | bbox[2] - bbox[5] / 2,
59 | ],
60 | [
61 | bbox[0] - bbox[3] / 2,
62 | bbox[1] + bbox[4] / 2,
63 | bbox[2] - bbox[5] / 2,
64 | ],
65 | [
66 | bbox[0] + bbox[3] / 2,
67 | bbox[1] + bbox[4] / 2,
68 | bbox[2] - bbox[5] / 2,
69 | ],
70 | [
71 | bbox[0] - bbox[3] / 2,
72 | bbox[1] - bbox[4] / 2,
73 | bbox[2] + bbox[5] / 2,
74 | ],
75 | [
76 | bbox[0] + bbox[3] / 2,
77 | bbox[1] - bbox[4] / 2,
78 | bbox[2] + bbox[5] / 2,
79 | ],
80 | [
81 | bbox[0] - bbox[3] / 2,
82 | bbox[1] + bbox[4] / 2,
83 | bbox[2] + bbox[5] / 2,
84 | ],
85 | [
86 | bbox[0] + bbox[3] / 2,
87 | bbox[1] + bbox[4] / 2,
88 | bbox[2] + bbox[5] / 2,
89 | ],
90 | ],
91 | dtype=np.float32,
92 | )
93 | )
94 | return corners
95 |
96 |
97 | def calculate_iou_2d(mask1, mask2):
98 | """
99 | Calculate 2D Intersection over Union (IoU) of two masks, both of which are H * W binary numpy arrays.
100 | """
101 | # Calculate the intersection of the two masks
102 | intersection = np.logical_and(mask1, mask2)
103 |
104 | # Calculate the union of the two masks
105 | union = np.logical_or(mask1, mask2)
106 |
107 | # Calculate the IoU
108 | # handle zero
109 | iou = np.sum(intersection) / np.sum(union) if np.sum(union) != 0 else 0.0
110 |
111 | return iou
112 |
113 |
114 | def calculate_iou_3d(box1, box2):
115 | """
116 |     Calculate 3D Intersection over Union (IoU) of two axis-aligned 3D boxes, each given as a length-6 array.
117 | Boxes are defined by numpy arrays with [x, y, z, dx, dy, dz],
118 | where (x, y, z) is the center of the box, and (dx, dy, dz) are the size of the box along each axis.
119 | """
120 | # Calculate the coordinates of the intersections points
121 | inter_min = np.maximum(box1[:3] - box1[3:] / 2, box2[:3] - box2[3:] / 2)
122 | inter_max = np.minimum(box1[:3] + box1[3:] / 2, box2[:3] + box2[3:] / 2)
123 |
124 | # Calculate intersection volume
125 | inter_dim = inter_max - inter_min
126 | inter_volume = np.prod(inter_dim) if np.all(inter_dim > 0) else 0
127 |
128 | # Calculate the volume of each box
129 | box1_volume = np.prod(box1[3:])
130 | box2_volume = np.prod(box2[3:])
131 |
132 | # Calculate IoU
133 | iou = inter_volume / (box1_volume + box2_volume - inter_volume)
134 |
135 | return iou
136 |
137 |
138 | def remove_statistical_outliers(point_cloud_data, nb_neighbors=20, std_ratio=1.0):
139 | """
140 |     Removes statistical outliers from point cloud data and retains all original dimensions.
141 |
142 | Args:
143 | point_cloud_data (numpy.ndarray): The input point cloud data as a NxC numpy array where C >= 3. [N, C]
144 | nb_neighbors (int): Number of nearest neighbors to consider for calculating average distance.
145 | std_ratio (float): Standard deviation ratio; points beyond this many standard deviations are considered outliers.
146 |
147 | Returns:
148 | numpy.ndarray: Filtered point cloud data with outliers removed, including all original dimensions. [n, C]
149 | """
150 | # Convert numpy array to Open3D point cloud for XYZ coordinates
151 | pcd = o3d.geometry.PointCloud()
152 | pcd.points = o3d.utility.Vector3dVector(point_cloud_data[:, :3])
153 |
154 | # Perform statistical outlier removal
155 | clean_pcd, ind = pcd.remove_statistical_outlier(nb_neighbors, std_ratio)
156 |
157 | # Use indices to filter the original full-dimension data
158 | inlier_data = point_cloud_data[ind, :]
159 |
160 | return inlier_data
161 |
162 |
163 | def remove_truncated_outliers(point_cloud_data, tx: float, ty: float, tz: float):
164 | """
165 |     Removes outliers by trimming a fraction of points from each end of the sorted x, y, and z values, retaining all original dimensions.
166 |
167 | Args:
168 | point_cloud_data (numpy.ndarray): The input point cloud data as a NxC numpy array where C >= 3. [N, C]
169 | tx: Ratio of points to remove from the beginning and end of the sorted x values
170 | ty: Ratio of points to remove from the beginning and end of the sorted y values
171 | tz: Ratio of points to remove from the beginning and end of the sorted z values
172 |
173 | Returns:
174 | numpy.ndarray: Filtered point cloud data with outliers removed, including all original dimensions. [n, C]
175 | """
176 |
177 | # assert tx, ty, tz all < 0.5
178 | assert tx < 0.5 and ty < 0.5 and tz < 0.5, "tx, ty, tz must be less than 0.5."
179 |
180 | n_points = len(point_cloud_data)
181 | # Calculate the number of points to remove based on the given percentages
182 | if tx == 0 and ty == 0 and tz == 0:
183 | return point_cloud_data
184 |
185 | nx = int(tx * n_points)
186 | ny = int(ty * n_points)
187 | nz = int(tz * n_points)
188 |
189 |     # Process x-axis (a zero ratio keeps all points; a plain [0:-0] slice would drop everything)
190 |     x_sorted_indices = np.argsort(point_cloud_data[:, 0])
191 |     valid_x_indices = x_sorted_indices if nx == 0 else (x_sorted_indices[nx:-nx] if 2 * nx < n_points else np.array([], dtype=int))
192 |
193 |     # Process y-axis
194 |     y_sorted_indices = np.argsort(point_cloud_data[:, 1])
195 |     valid_y_indices = y_sorted_indices if ny == 0 else (y_sorted_indices[ny:-ny] if 2 * ny < n_points else np.array([], dtype=int))
196 |
197 |     # Process z-axis
198 |     z_sorted_indices = np.argsort(point_cloud_data[:, 2])
199 |     valid_z_indices = z_sorted_indices if nz == 0 else (z_sorted_indices[nz:-nz] if 2 * nz < n_points else np.array([], dtype=int))
200 |
201 | # Find the intersection of valid indices across all axes
202 | valid_indices = np.intersect1d(valid_x_indices, valid_y_indices)
203 | valid_indices = np.intersect1d(valid_indices, valid_z_indices)
204 |
205 | # Filter the original full-dimension data
206 | inlier_data = point_cloud_data[valid_indices]
207 |
208 | return inlier_data
209 |
210 |
211 | def calculate_aabb(point_cloud_data):
212 | """
213 | Calculates the axis-aligned bounding box (AABB) of a point cloud.
214 |
215 | Args:
216 | point_cloud_data (numpy.ndarray): The input point cloud data as a NxC numpy array where C >= 3. [N, C]
217 |
218 | Returns:
219 | tuple: Contains the center of the AABB (numpy.ndarray) and the dimensions of the AABB (numpy.ndarray). [x, y, z, dx, dy, dz]
220 | """
221 | # Calculate the min and max along each column (x, y, z)
222 | min_corner = np.min(point_cloud_data[:, :3], axis=0)
223 | max_corner = np.max(point_cloud_data[:, :3], axis=0)
224 |
225 | # Calculate center and dimensions
226 | center = (max_corner + min_corner) / 2
227 | dimensions = max_corner - min_corner
228 |
229 | # Combine center and dimensions into a single array
230 | result = np.concatenate([center, dimensions])
231 |
232 | return result
233 |
234 |
235 | def project_mask_to_3d(
236 | depth_image,
237 | intrinsic_matrix,
238 | extrinsic_matrix,
239 | mask=None,
240 | world_to_axis_align_matrix=None,
241 | color_image=None,
242 | ):
243 | """
244 | Projects a mask to 3D space using the provided depth map and camera parameters.
245 | Optionally appends RGB values from a color image to the 3D points. (RGB order with 0-255 range)
246 |
247 | Parameters:
248 | - depth_image (str or ndarray): Path to the depth image or a numpy array of depth values. h, w
249 | - intrinsic_matrix (ndarray): The camera's intrinsic matrix. 4 * 4
250 | - extrinsic_matrix (ndarray): The camera's extrinsic matrix. 4 * 4
251 |     - mask (ndarray): A binary mask (zero / non-zero array) where non-zero values indicate pixels to project; it has the same shape as color_image (H, W). May be None, in which case all pixels are projected.
252 | - world_to_axis_align_matrix (ndarray, optional): Matrix to align the world coordinates. 4 * 4
253 | - color_image (str or ndarray, optional): Path to the color image or a numpy array of color values. H, W, 3
254 |
255 | Returns:
256 |     - ndarray: Array of 3D coordinates, optionally with RGB values appended. An all-False mask yields `array([], shape=(0, C), dtype=float64)`.
257 | """
258 |
259 | # Load depth image from path if it's a string
260 | if isinstance(depth_image, str):
261 | depth_image = cv2.imread(depth_image, -1)
262 |
263 | # Load color image from path if it's a string
264 | if isinstance(color_image, str):
265 | color_image = cv2.imread(color_image)
266 | color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
267 |
268 |     if mask is None:  # project all pixels
269 |         mask = np.ones(color_image.shape[:2] if color_image is not None else depth_image.shape[:2], dtype=bool)
270 |
271 | # Calculate scaling factors
272 | scale_y = depth_image.shape[0] / mask.shape[0]
273 | scale_x = depth_image.shape[1] / mask.shape[1]
274 |
275 | # Get coordinates of True values in mask
276 | mask_indices = np.where(mask)
277 | mask_y = mask_indices[0]
278 | mask_x = mask_indices[1]
279 |
280 |     # Scale mask coordinates to match the depth image size.
281 |     # An earlier version truncated with .astype(int):
282 |     #   depth_y = (mask_y * scale_y).astype(int)
283 |     #   depth_x = (mask_x * scale_x).astype(int)
284 |     # Rounding is used instead:
285 | depth_y = np.round(mask_y * scale_y).astype(int)
286 | depth_x = np.round(mask_x * scale_x).astype(int)
287 |
288 | # Clip scaled coordinates to ensure they are within the image boundary
289 | depth_y = np.clip(depth_y, 0, depth_image.shape[0] - 1)
290 | depth_x = np.clip(depth_x, 0, depth_image.shape[1] - 1)
291 |
292 | # Extract depth values
293 | depth_values = (
294 | depth_image[depth_y, depth_x] * 0.001
295 | ) # Assume depth is in millimeters
296 |
297 | # Filter out zero depth values
298 | valid = depth_values > 0
299 | depth_values = depth_values[valid]
300 | mask_x = mask_x[valid]
301 | mask_y = mask_y[valid]
302 |
303 | # Construct normalized pixel coordinates
304 | normalized_pixels = np.vstack(
305 | (
306 | mask_x * depth_values,
307 | mask_y * depth_values,
308 | depth_values,
309 | np.ones_like(depth_values),
310 | )
311 | )
312 |
313 | # Compute points in camera coordinate system
314 | cam_coords = np.dot(np.linalg.inv(intrinsic_matrix), normalized_pixels)
315 |
316 | # Transform to world coordinates
317 | world_coords = np.dot(extrinsic_matrix, cam_coords)
318 |
319 | # Apply world-to-axis alignment if provided
320 | if world_to_axis_align_matrix is not None:
321 | world_coords = np.dot(world_to_axis_align_matrix, world_coords)
322 |
323 | # Append color information if color image is provided
324 | if color_image is not None:
325 | # Scale mask coordinates for the color image
326 | rgb_values = color_image[mask_y, mask_x]
327 | return np.hstack((world_coords[:3].T, rgb_values))
328 |
329 | return world_coords[:3].T
330 |
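
A minimal sketch of chaining the helpers above: back-project an image into world coordinates, clean up the points, and fit an axis-aligned box. The file paths are placeholders and the helpers are assumed to be imported from this module:

    import numpy as np

    # placeholder inputs; in practice they come from the posed_images export
    depth_path = "posed_images/scene0000_00/00000.png"
    color_path = "posed_images/scene0000_00/00000.jpg"
    intrinsic = np.loadtxt("posed_images/scene0000_00/intrinsic.txt")   # 4 x 4
    extrinsic = np.loadtxt("posed_images/scene0000_00/00000.txt")       # 4 x 4 camera-to-world

    points = project_mask_to_3d(depth_path, intrinsic, extrinsic, mask=None, color_image=color_path)
    points = remove_statistical_outliers(points, nb_neighbors=20, std_ratio=1.0)
    bbox = calculate_aabb(points)  # [cx, cy, cz, dx, dy, dz]
    print(bbox)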
--------------------------------------------------------------------------------
/spatial_engine/utils/scannet_utils/make_visibility_info.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import os
9 | import mmengine
10 | from multiprocessing import Pool
11 | from tqdm import tqdm
12 | from mmengine.utils.dl_utils import TimeCounter
13 | from spatial_engine.utils.scannet_utils.handler.info_handler import SceneInfoHandler
14 | import pandas as pd
15 |
16 | """
17 | {
18 | scene_id: {
19 | "image_to_points": {
20 | image_id: [point_index, ...],
21 | ...
22 | },
23 | "point_to_images": {
24 | point_index: [image_id, ...],
25 | ...
26 | }
27 | },
28 | ...
29 | }
30 |
31 | After generating, the result is saved as a Parquet file.
32 | Note: early versions saved the result as a PKL file, so convert_pkl_to_parquet is provided to convert those legacy outputs.
33 |
34 | """
35 |
36 | DEBUG = False
37 |
38 | def convert_pkl_to_parquet(pkl_file):
39 | """
40 | Converts a previously saved PKL file to a Parquet file.
41 |
42 | Args:
43 | pkl_file (str): Path to the input PKL file.
44 | """
45 | # Load the PKL file
46 | import json
47 | import pickle
48 | import pandas as pd
49 | with open(pkl_file, 'rb') as f:
50 | scene_visibility_dict = pickle.load(f)
51 | print(f"Loaded pkl file from {pkl_file}.")
52 |
53 | parquet_file = pkl_file.replace(".pkl", ".parquet")
54 | # Convert to a format suitable for Parquet
55 | data = []
56 | for scene_id, visibility_info in scene_visibility_dict.items():
57 | # Process image_to_points
58 | for image_id, points in visibility_info["image_to_points"].items():
59 | key = f"{scene_id}:image_to_points:{image_id}"
60 | data.append((key, json.dumps(points))) # Convert list to JSON string
61 |
62 | # Process point_to_images
63 | for point_idx, images in visibility_info["point_to_images"].items():
64 | key = f"{scene_id}:point_to_images:{point_idx}"
65 | data.append((key, json.dumps(images))) # Convert list to JSON string
66 |
67 | # Create a DataFrame
68 | df = pd.DataFrame(data, columns=["key", "values"])
69 |
70 | # Save as a Parquet file
71 | df.to_parquet(parquet_file, index=False)
72 |
73 | print(f"Converted {pkl_file} to {parquet_file}. The file has {len(data)} items in total.")
74 |
75 | def process_scene(scene_id, scene_infos, warning_file):
76 | """
77 | For one scene:
78 | 1) Gather which points each image sees.
79 | 2) Invert that mapping to know which images see a given point.
80 | 3) Return scene_id and the final dict.
81 | """
82 | print(f"[process_scene] Start: {scene_id}")
83 | image_ids = scene_infos.get_all_extrinsic_valid_image_ids(scene_id)
84 |
85 | # Get all points in the scene (use only first 3 coords)
86 | scene_points = scene_infos.get_scene_points_align(scene_id)[:, :3]
87 | num_points = scene_points.shape[0]
88 |
89 | image_to_points = {}
90 | # Initially keep track of points -> images in a list of sets (index = point_id, value = set of image_ids)
91 | point_to_images_sets = [set() for _ in range(num_points)]
92 |
93 | for image_id in image_ids:
94 | E = scene_infos.get_extrinsic_matrix_align(scene_id, image_id)
95 | scene_points_2d, scene_points_depth = scene_infos.project_3d_point_to_image(
96 | scene_id, image_id, scene_points
97 | )
98 | in_bounds_mask = scene_infos.check_point_visibility(
99 | scene_id, image_id, scene_points_2d, scene_points_depth
100 | )
101 |
102 | # Record which points are visible for this image
103 | visible_point_indices = np.where(in_bounds_mask)[0]
104 | image_to_points[image_id] = visible_point_indices.tolist()
105 |
106 | # Also record the inverse mapping
107 | for idx in visible_point_indices:
108 | point_to_images_sets[idx].add(image_id)
109 |
110 | # Optional: warning if no visible points
111 | if len(visible_point_indices) == 0:
112 | with open(warning_file, 'a') as f:
113 | f.write(f"[Warning] {scene_id}: {image_id} has no in-bound points.\n")
114 |
115 | # Convert from list of sets -> dict
116 | point_to_images = {
117 | idx: sorted(list(img_set)) for idx, img_set in enumerate(point_to_images_sets) # * if one point is not observed in any images, the value is an empty list
118 | }
119 |
120 | result_dict = {
121 | "image_to_points": image_to_points,
122 | "point_to_images": point_to_images
123 | }
124 | print(f"[process_scene] Done: {scene_id}")
125 | return scene_id, result_dict
126 |
127 | @TimeCounter()
128 | def run_split(scene_info_path, output_file, warning_file, num_workers=8):
129 | """
130 | 1. Loads the SceneInfoHandler for the given split.
131 | 2. Processes each scene in parallel to get visibility info.
132 | 3. Accumulates them into a top-level dict -> { scene_id: {...} }
133 | 4. Saves everything into a pickle (or any other format) for easy reload.
134 | """
135 | scene_infos = SceneInfoHandler(scene_info_path)
136 | all_scene_ids = scene_infos.get_all_scene_ids()
137 | mmengine.mkdir_or_exist(os.path.dirname(output_file))
138 |
139 | if DEBUG and len(all_scene_ids) > 1:
140 | all_scene_ids = all_scene_ids[:1]
141 | print("[run_split] DEBUG mode. Only processing first scene.")
142 |
143 | print(f"[run_split] Found {len(all_scene_ids)} scenes in {scene_info_path}")
144 | print(f"[run_split] Output will be saved to {output_file}")
145 |
146 | scene_visibility_dict = {}
147 |
148 | # Prepare pool args
149 | args = [(scene_id, scene_infos, warning_file) for scene_id in all_scene_ids]
150 |
151 | with Pool(num_workers) as pool:
152 | results = [pool.apply_async(process_scene, arg) for arg in args]
153 |
154 | for r in tqdm(results, desc=f"Processing scenes in {scene_info_path}"):
155 | scene_id, visibility_info = r.get()
156 | scene_visibility_dict[scene_id] = visibility_info
157 |
158 | # Convert scene_visibility_dict to a DataFrame
159 | data = []
160 | for scene_id, visibility_info in scene_visibility_dict.items():
161 | # Process image_to_points
162 | for image_id, points in visibility_info["image_to_points"].items():
163 | key = f"{scene_id},image_to_points,{image_id}"
164 | data.append((key, points))
165 |
166 | # Process point_to_images
167 | for point_idx, images in visibility_info["point_to_images"].items():
168 | key = f"{scene_id},point_to_images,{point_idx}"
169 | data.append((key, images))
170 |
171 | # Create DataFrame
172 | df = pd.DataFrame(data, columns=["key", "values"])
173 |
174 | # Save to Parquet file
175 | df.to_parquet(output_file, index=False)
176 |
177 | print(f"[run_split] Done. Wrote {len(df)} entries to {output_file}")
178 |
179 | def main():
180 | # Adjust these as needed
181 | train_info_path = "data/scannet/scannet_instance_data/scenes_train_info_i_D5.pkl"
182 | val_info_path = "data/scannet/scannet_instance_data/scenes_val_info_i_D5.pkl"
183 |
184 | train_output_dir = "data/scannet/scannet_instance_data"
185 | val_output_dir = "data/scannet/scannet_instance_data"
186 | mmengine.mkdir_or_exist(train_output_dir)
187 | mmengine.mkdir_or_exist(val_output_dir)
188 |
189 | # Warnings
190 | train_warning_file = os.path.join(train_output_dir, "make_visibility_train_warning.txt")
191 | val_warning_file = os.path.join(val_output_dir, "make_visibility_val_warning.txt")
192 |
193 | # Output pickle files
194 | train_output_file = os.path.join(train_output_dir, "train_visibility_info_D5.parquet")
195 | val_output_file = os.path.join(val_output_dir, "val_visibility_info_D5.parquet")
196 |
197 | # If DEBUG, tweak output to avoid overwriting real data
198 | global DEBUG
199 | if DEBUG:
200 | train_warning_file = train_warning_file.replace(".txt", "_debug.txt")
201 | val_warning_file = val_warning_file.replace(".txt", "_debug.txt")
202 | train_output_file = train_output_file.replace(".parquet", "_debug.parquet")
203 | val_output_file = val_output_file.replace(".parquet", "_debug.parquet")
204 |
205 | # Number of processes to run in parallel
206 | num_workers = 25
207 |
208 | print("[main] DEBUG =", DEBUG)
209 |
210 | print(f"[main] Generating val visibility -> {val_output_file}")
211 | run_split(val_info_path, val_output_file, val_warning_file, num_workers=num_workers) # * costs 47 mins, the info file is 6G as pkl and 3.7G as parquet
212 |
213 | print(f"[main] Generating train visibility -> {train_output_file}")
214 | run_split(train_info_path, train_output_file, train_warning_file, num_workers=num_workers) # * costs 3 hours, the info file is 21G as pkl and 13G as parquet
215 |
216 | if __name__ == "__main__":
217 | main()
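
Once a split has been processed, the Parquet file can be read back and the flat keys split apart again. A minimal sketch; the scene and image ids are placeholders:

    import pandas as pd

    df = pd.read_parquet("data/scannet/scannet_instance_data/val_visibility_info_D5.parquet")
    # keys look like "<scene_id>,image_to_points,<image_id>" or "<scene_id>,point_to_images,<point_idx>"
    row = df[df["key"] == "scene0011_00,image_to_points,00000"]
    if not row.empty:
        visible_points = row["values"].iloc[0]
        print(f"{len(visible_points)} points visible in this image")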
--------------------------------------------------------------------------------
/spatial_engine/utils/scannet_utils/scannet_utils.py:
--------------------------------------------------------------------------------
1 | # Modified from
2 | # https://github.com/facebookresearch/votenet/blob/master/scannet/scannet_utils.py
3 | # Copyright (c) Meta Platforms, Inc. and affiliates.
4 | # All rights reserved.
5 |
6 | # This source code is licensed under the license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | import csv
10 | import json
11 | import os
12 |
13 | import numpy as np
14 | from plyfile import PlyData
15 |
16 |
17 | def read_aggregation(filename):
18 | assert os.path.isfile(filename)
19 | object_id_to_segs = {}
20 | label_to_segs = {}
21 | with open(filename) as f:
22 | data = json.load(f)
23 | num_objects = len(data["segGroups"])
24 | for i in range(num_objects):
25 | object_id = (
26 | data["segGroups"][i]["objectId"] + 1
27 | ) # instance ids should be 1-indexed
28 | label = data["segGroups"][i]["label"]
29 | segs = data["segGroups"][i]["segments"]
30 |             object_id_to_segs[object_id] = segs  # * segs is a list of segment ids
31 | if label in label_to_segs:
32 | label_to_segs[label].extend(segs) # * for semantic segmentation
33 | else:
34 | label_to_segs[label] = segs
35 | return object_id_to_segs, label_to_segs
36 |
37 |
38 | def read_segmentation(filename):
39 | assert os.path.isfile(filename)
40 | seg_to_verts = {}
41 | with open(filename) as f:
42 | data = json.load(f)
43 | num_verts = len(data["segIndices"])
44 | for i in range(num_verts):
45 | seg_id = data["segIndices"][i]
46 | if seg_id in seg_to_verts:
47 | seg_to_verts[seg_id].append(i)
48 | else:
49 | seg_to_verts[seg_id] = [i]
50 | return seg_to_verts, num_verts
51 |
52 |
53 | def extract_bbox(mesh_vertices, object_id_to_segs, object_id_to_label_id, instance_ids):
54 | """
55 | Extracts bounding boxes and point clouds for each instance.
56 |
57 | Parameters:
58 | mesh_vertices (numpy.ndarray): The mesh vertices.
59 | object_id_to_segs (dict): Mapping of object IDs to segments.
60 | object_id_to_label_id (dict): Mapping of object IDs to label IDs.
61 | instance_ids (numpy.ndarray): Array of instance IDs.
62 |
63 | Returns:
64 | numpy.ndarray: An array containing bounding boxes for each instance.
65 | The ID (1-index) of each bounding box is its index in the returned array plus 1.
66 | list of numpy.ndarray: A list containing point clouds for each instance,
67 | corresponding to the bounding boxes. 'None' for instances without points.
68 | """
69 | num_instances = len(np.unique(list(object_id_to_segs.keys())))
70 | instance_bboxes = np.zeros((num_instances, 7))
71 | instance_pcs = [None] * num_instances # Initialize list with None
72 |
73 | for obj_id in object_id_to_segs:
74 | label_id = object_id_to_label_id[obj_id]
75 | obj_pc = mesh_vertices[instance_ids == obj_id, 0:3]
76 | obj_pc_rgb = mesh_vertices[instance_ids == obj_id, :]
77 | if len(obj_pc) == 0:
78 | print(
79 | f"WARNING: object id {obj_id} does not have points. Corresponding entry is set to None."
80 | )
81 | continue
82 |
83 | xyz_min = np.min(obj_pc, axis=0)
84 | xyz_max = np.max(obj_pc, axis=0)
85 | bbox = np.concatenate(
86 | [(xyz_min + xyz_max) / 2.0, xyz_max - xyz_min, np.array([label_id])]
87 | )
88 |
89 | instance_bboxes[obj_id - 1, :] = bbox
90 | instance_pcs[obj_id - 1] = (
91 | obj_pc_rgb # Store the point cloud at the appropriate index
92 | )
93 |
94 | return instance_bboxes, instance_pcs
95 |
96 |
97 | def represents_int(s):
98 | """Judge whether string s represents an int.
99 |
100 | Args:
101 | s(str): The input string to be judged.
102 |
103 | Returns:
104 | bool: Whether s represents int or not.
105 | """
106 | try:
107 | int(s)
108 | return True
109 | except ValueError:
110 | return False
111 |
112 |
113 | def read_label_mapping(filename, label_from="raw_category", label_to="nyu40id"):
114 | assert os.path.isfile(filename)
115 | mapping = dict()
116 | with open(filename) as csvfile:
117 | reader = csv.DictReader(csvfile, delimiter="\t")
118 | for row in reader:
119 | mapping[row[label_from]] = int(row[label_to])
120 | if represents_int(list(mapping.keys())[0]):
121 | mapping = {int(k): v for k, v in mapping.items()}
122 | return mapping
123 |
124 |
125 | def read_mesh_vertices(filename):
126 | """Read XYZ for each vertex.
127 |
128 | Args:
129 | filename(str): The name of the mesh vertices file.
130 |
131 | Returns:
132 | ndarray: Vertices.
133 | """
134 | assert os.path.isfile(filename)
135 | with open(filename, "rb") as f:
136 | plydata = PlyData.read(f)
137 | num_verts = plydata["vertex"].count
138 | vertices = np.zeros(shape=[num_verts, 3], dtype=np.float32)
139 | vertices[:, 0] = plydata["vertex"].data["x"]
140 | vertices[:, 1] = plydata["vertex"].data["y"]
141 | vertices[:, 2] = plydata["vertex"].data["z"]
142 | return vertices
143 |
144 |
145 | def read_mesh_vertices_rgb(filename):
146 | """Read XYZ and RGB for each vertex.
147 |
148 | Args:
149 | filename(str): The name of the mesh vertices file.
150 |
151 | Returns:
152 | Vertices. Note that RGB values are in 0-255.
153 | """
154 | assert os.path.isfile(filename)
155 | with open(filename, "rb") as f:
156 | plydata = PlyData.read(f)
157 | num_verts = plydata["vertex"].count
158 | vertices = np.zeros(shape=[num_verts, 6], dtype=np.float32)
159 | vertices[:, 0] = plydata["vertex"].data["x"]
160 | vertices[:, 1] = plydata["vertex"].data["y"]
161 | vertices[:, 2] = plydata["vertex"].data["z"]
162 | vertices[:, 3] = plydata["vertex"].data["red"]
163 | vertices[:, 4] = plydata["vertex"].data["green"]
164 | vertices[:, 5] = plydata["vertex"].data["blue"]
165 | return vertices
166 |
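
A minimal sketch of combining these readers for a single scan, assuming the usual ScanNet file layout inside the scan folder (the scan id is hypothetical):

    scan = "scene0000_00"  # hypothetical scan id
    label_map = read_label_mapping(
        "data/scannet/meta_data/scannetv2-labels.combined.tsv",
        label_from="raw_category",
        label_to="nyu40id",
    )
    vertices = read_mesh_vertices_rgb(f"scans/{scan}/{scan}_vh_clean_2.ply")
    object_id_to_segs, label_to_segs = read_aggregation(f"scans/{scan}/{scan}.aggregation.json")
    seg_to_verts, num_verts = read_segmentation(f"scans/{scan}/{scan}_vh_clean_2.0.010000.segs.json")
    print(vertices.shape, len(object_id_to_segs), num_verts)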
--------------------------------------------------------------------------------
/spatial_engine/utils/scannet_utils/update_info_file_with_images.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 |
9 | import mmengine
10 | import numpy as np
11 | from tqdm import tqdm
12 |
13 | # Base directory where the scene_id folders are located
14 | base_dir = "data/scannet/posed_images"
15 | scene_infos_file = "data/scannet/scannet_instance_data/scenes_train_val_info.pkl"
16 | frame_skip = 5  # keep one image out of every frame_skip images
17 | scene_infos = mmengine.load(scene_infos_file)
18 |
19 | # Iterate through each scene_id in the scene_info dict
20 | for scene_id in tqdm(scene_infos.keys()):
21 | # Construct the path to the current scene_id folder
22 | scene_path = os.path.join(base_dir, scene_id)
23 |
24 | # Initialize the number of posed images to 0
25 | num_posed_images = 0
26 |
27 | # Initialize a dictionary to hold image data
28 | image_data = {}
29 |
30 | # Read the intrinsic matrix
31 | intrinsic_path = os.path.join(scene_path, "intrinsic.txt")
32 | with open(intrinsic_path, "r") as f:
33 | intrinsic_matrix = np.array(
34 | [list(map(float, line.split())) for line in f.readlines()]
35 | )
36 |
37 | # Iterate through each file in the scene_id directory
38 | all_files = os.listdir(scene_path)
39 | all_jpg_files = [f for f in all_files if f.endswith(".jpg")]
40 | all_jpg_files.sort()
41 | for i, filename in enumerate(all_jpg_files):
42 | if i % frame_skip == 0: # Only process every frame_skip-th image
43 | # Extract the image_id from the filename (e.g., "00000.jpg" -> "00000")
44 | image_id = filename.split(".")[0]
45 | # Construct paths to the image, depth image, and extrinsic matrix file
46 | image_path = f"posed_images/{scene_id}/{filename}"
47 | depth_image_path = f"posed_images/{scene_id}/{image_id}.png"
48 | extrinsic_path = os.path.join(scene_path, f"{image_id}.txt")
49 | # Read the extrinsic matrix from the file
50 | with open(extrinsic_path, "r") as f:
51 | extrinsic_matrix = np.array(
52 | [list(map(float, line.split())) for line in f.readlines()]
53 | )
54 | # Update the image data dictionary with this image's information
55 | image_data[image_id] = {
56 | "image_path": image_path,
57 | "depth_image_path": depth_image_path,
58 | "extrinsic_matrix": extrinsic_matrix,
59 | }
60 | # Increment the count of posed images
61 | num_posed_images += 1
62 |
63 | # Update the scene_info dictionary for the current scene_id
64 | scene_infos[scene_id].update(
65 | {
66 | "num_posed_images": num_posed_images,
67 | "images_info": image_data,
68 | "intrinsic_matrix": intrinsic_matrix,
69 | }
70 | )
71 |
72 | mmengine.dump(scene_infos, scene_infos_file.replace(".pkl", f"_i_D{frame_skip}.pkl"))
73 |
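
A minimal sketch of inspecting the updated info file written above; the scene id is a placeholder:

    import mmengine

    infos = mmengine.load("data/scannet/scannet_instance_data/scenes_train_val_info_i_D5.pkl")
    scene = infos["scene0000_00"]  # hypothetical scene id
    print(scene["num_posed_images"])
    print(scene["intrinsic_matrix"].shape)  # (4, 4)
    first_image_id = sorted(scene["images_info"])[0]
    print(scene["images_info"][first_image_id]["image_path"])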
--------------------------------------------------------------------------------