├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── README_TRAINING.md ├── cat_spec_training.py ├── data ├── _init_.py ├── distillation.py ├── generic_img_mask_loader.py └── warehouse3d.py ├── data_utils.py ├── demo.ipynb ├── demo_data ├── mask_0.png ├── mask_1.png ├── mask_2.png ├── mask_3.png ├── mask_4.png ├── rgb_0.png ├── rgb_1.png ├── rgb_2.png ├── rgb_3.png └── rgb_4.png ├── env.yaml ├── hydra_config ├── _init_.py └── config.py ├── model.py ├── synth_pretraining.py ├── unified_distillation.py └── volumetric_render.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.obj 2 | data/*__pycache__/* 3 | __pycache__/* 4 | */__pycache__/* 5 | *.DS_Store 6 | *.o 7 | *.a 8 | *.so 9 | *jobs.sh 10 | .ipynb_checkpoints 11 | *.pyc 12 | external 13 | cachedir 14 | outputs 15 | job_logs 16 | ipyNb 17 | *.brf 18 | *.log 19 | *.key 20 | *.out 21 | *.gz 22 | *.blg 23 | *.aux 24 | *.bbl 25 | *.fls 26 | *.pdf 27 | *.texnicle 28 | *.fdb_latexmk 29 | #*.ipynb 30 | volumetric_networks.egg-info 31 | job_outputs/* -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 
50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to ss3d 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Meta's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to ss3d, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 
32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Attribution-NonCommercial 4.0 International 3 | 4 | ======================================================================= 5 | 6 | Creative Commons Corporation ("Creative Commons") is not a law firm and 7 | does not provide legal services or legal advice. Distribution of 8 | Creative Commons public licenses does not create a lawyer-client or 9 | other relationship. Creative Commons makes its licenses and related 10 | information available on an "as-is" basis. Creative Commons gives no 11 | warranties regarding its licenses, any material licensed under their 12 | terms and conditions, or any related information. Creative Commons 13 | disclaims all liability for damages resulting from their use to the 14 | fullest extent possible. 15 | 16 | Using Creative Commons Public Licenses 17 | 18 | Creative Commons public licenses provide a standard set of terms and 19 | conditions that creators and other rights holders may use to share 20 | original works of authorship and other material subject to copyright 21 | and certain other rights specified in the public license below. The 22 | following considerations are for informational purposes only, are not 23 | exhaustive, and do not form part of our licenses. 24 | 25 | Considerations for licensors: Our public licenses are 26 | intended for use by those authorized to give the public 27 | permission to use material in ways otherwise restricted by 28 | copyright and certain other rights. Our licenses are 29 | irrevocable. Licensors should read and understand the terms 30 | and conditions of the license they choose before applying it. 31 | Licensors should also secure all rights necessary before 32 | applying our licenses so that the public can reuse the 33 | material as expected. Licensors should clearly mark any 34 | material not subject to the license. This includes other CC- 35 | licensed material, or material used under an exception or 36 | limitation to copyright. More considerations for licensors: 37 | wiki.creativecommons.org/Considerations_for_licensors 38 | 39 | Considerations for the public: By using one of our public 40 | licenses, a licensor grants the public permission to use the 41 | licensed material under specified terms and conditions. If 42 | the licensor's permission is not necessary for any reason--for 43 | example, because of any applicable exception or limitation to 44 | copyright--then that use is not regulated by the license. Our 45 | licenses grant only permissions under copyright and certain 46 | other rights that a licensor has authority to grant. Use of 47 | the licensed material may still be restricted for other 48 | reasons, including because others have copyright or other 49 | rights in the material. A licensor may make special requests, 50 | such as asking that all changes be marked or described. 51 | Although not required by our licenses, you are encouraged to 52 | respect those requests where reasonable. 
More_considerations 53 | for the public: 54 | wiki.creativecommons.org/Considerations_for_licensees 55 | 56 | ======================================================================= 57 | 58 | Creative Commons Attribution-NonCommercial 4.0 International Public 59 | License 60 | 61 | By exercising the Licensed Rights (defined below), You accept and agree 62 | to be bound by the terms and conditions of this Creative Commons 63 | Attribution-NonCommercial 4.0 International Public License ("Public 64 | License"). To the extent this Public License may be interpreted as a 65 | contract, You are granted the Licensed Rights in consideration of Your 66 | acceptance of these terms and conditions, and the Licensor grants You 67 | such rights in consideration of benefits the Licensor receives from 68 | making the Licensed Material available under these terms and 69 | conditions. 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. 
For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | Section 2 -- Scope. 142 | 143 | a. License grant. 144 | 145 | 1. Subject to the terms and conditions of this Public License, 146 | the Licensor hereby grants You a worldwide, royalty-free, 147 | non-sublicensable, non-exclusive, irrevocable license to 148 | exercise the Licensed Rights in the Licensed Material to: 149 | 150 | a. reproduce and Share the Licensed Material, in whole or 151 | in part, for NonCommercial purposes only; and 152 | 153 | b. produce, reproduce, and Share Adapted Material for 154 | NonCommercial purposes only. 155 | 156 | 2. Exceptions and Limitations. For the avoidance of doubt, where 157 | Exceptions and Limitations apply to Your use, this Public 158 | License does not apply, and You do not need to comply with 159 | its terms and conditions. 160 | 161 | 3. Term. The term of this Public License is specified in Section 162 | 6(a). 163 | 164 | 4. Media and formats; technical modifications allowed. The 165 | Licensor authorizes You to exercise the Licensed Rights in 166 | all media and formats whether now known or hereafter created, 167 | and to make technical modifications necessary to do so. The 168 | Licensor waives and/or agrees not to assert any right or 169 | authority to forbid You from making technical modifications 170 | necessary to exercise the Licensed Rights, including 171 | technical modifications necessary to circumvent Effective 172 | Technological Measures. For purposes of this Public License, 173 | simply making modifications authorized by this Section 2(a) 174 | (4) never produces Adapted Material. 175 | 176 | 5. Downstream recipients. 177 | 178 | a. Offer from the Licensor -- Licensed Material. Every 179 | recipient of the Licensed Material automatically 180 | receives an offer from the Licensor to exercise the 181 | Licensed Rights under the terms and conditions of this 182 | Public License. 183 | 184 | b. No downstream restrictions. You may not offer or impose 185 | any additional or different terms or conditions on, or 186 | apply any Effective Technological Measures to, the 187 | Licensed Material if doing so restricts exercise of the 188 | Licensed Rights by any recipient of the Licensed 189 | Material. 190 | 191 | 6. No endorsement. 
Nothing in this Public License constitutes or 192 | may be construed as permission to assert or imply that You 193 | are, or that Your use of the Licensed Material is, connected 194 | with, or sponsored, endorsed, or granted official status by, 195 | the Licensor or others designated to receive attribution as 196 | provided in Section 3(a)(1)(A)(i). 197 | 198 | b. Other rights. 199 | 200 | 1. Moral rights, such as the right of integrity, are not 201 | licensed under this Public License, nor are publicity, 202 | privacy, and/or other similar personality rights; however, to 203 | the extent possible, the Licensor waives and/or agrees not to 204 | assert any such rights held by the Licensor to the limited 205 | extent necessary to allow You to exercise the Licensed 206 | Rights, but not otherwise. 207 | 208 | 2. Patent and trademark rights are not licensed under this 209 | Public License. 210 | 211 | 3. To the extent possible, the Licensor waives any right to 212 | collect royalties from You for the exercise of the Licensed 213 | Rights, whether directly or through a collecting society 214 | under any voluntary or waivable statutory or compulsory 215 | licensing scheme. In all other cases the Licensor expressly 216 | reserves any right to collect such royalties, including when 217 | the Licensed Material is used other than for NonCommercial 218 | purposes. 219 | 220 | Section 3 -- License Conditions. 221 | 222 | Your exercise of the Licensed Rights is expressly made subject to the 223 | following conditions. 224 | 225 | a. Attribution. 226 | 227 | 1. If You Share the Licensed Material (including in modified 228 | form), You must: 229 | 230 | a. retain the following if it is supplied by the Licensor 231 | with the Licensed Material: 232 | 233 | i. identification of the creator(s) of the Licensed 234 | Material and any others designated to receive 235 | attribution, in any reasonable manner requested by 236 | the Licensor (including by pseudonym if 237 | designated); 238 | 239 | ii. a copyright notice; 240 | 241 | iii. a notice that refers to this Public License; 242 | 243 | iv. a notice that refers to the disclaimer of 244 | warranties; 245 | 246 | v. a URI or hyperlink to the Licensed Material to the 247 | extent reasonably practicable; 248 | 249 | b. indicate if You modified the Licensed Material and 250 | retain an indication of any previous modifications; and 251 | 252 | c. indicate the Licensed Material is licensed under this 253 | Public License, and include the text of, or the URI or 254 | hyperlink to, this Public License. 255 | 256 | 2. You may satisfy the conditions in Section 3(a)(1) in any 257 | reasonable manner based on the medium, means, and context in 258 | which You Share the Licensed Material. For example, it may be 259 | reasonable to satisfy the conditions by providing a URI or 260 | hyperlink to a resource that includes the required 261 | information. 262 | 263 | 3. If requested by the Licensor, You must remove any of the 264 | information required by Section 3(a)(1)(A) to the extent 265 | reasonably practicable. 266 | 267 | 4. If You Share Adapted Material You produce, the Adapter's 268 | License You apply must not prevent recipients of the Adapted 269 | Material from complying with this Public License. 270 | 271 | Section 4 -- Sui Generis Database Rights. 272 | 273 | Where the Licensed Rights include Sui Generis Database Rights that 274 | apply to Your use of the Licensed Material: 275 | 276 | a. 
for the avoidance of doubt, Section 2(a)(1) grants You the right 277 | to extract, reuse, reproduce, and Share all or a substantial 278 | portion of the contents of the database for NonCommercial purposes 279 | only; 280 | 281 | b. if You include all or a substantial portion of the database 282 | contents in a database in which You have Sui Generis Database 283 | Rights, then the database in which You have Sui Generis Database 284 | Rights (but not its individual contents) is Adapted Material; and 285 | 286 | c. You must comply with the conditions in Section 3(a) if You Share 287 | all or a substantial portion of the contents of the database. 288 | 289 | For the avoidance of doubt, this Section 4 supplements and does not 290 | replace Your obligations under this Public License where the Licensed 291 | Rights include other Copyright and Similar Rights. 292 | 293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 294 | 295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 305 | 306 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 315 | 316 | c. The disclaimer of warranties and limitation of liability provided 317 | above shall be interpreted in a manner that, to the extent 318 | possible, most closely approximates an absolute disclaimer and 319 | waiver of all liability. 320 | 321 | Section 6 -- Term and Termination. 322 | 323 | a. This Public License applies for the term of the Copyright and 324 | Similar Rights licensed here. However, if You fail to comply with 325 | this Public License, then Your rights under this Public License 326 | terminate automatically. 327 | 328 | b. Where Your right to use the Licensed Material has terminated under 329 | Section 6(a), it reinstates: 330 | 331 | 1. automatically as of the date the violation is cured, provided 332 | it is cured within 30 days of Your discovery of the 333 | violation; or 334 | 335 | 2. upon express reinstatement by the Licensor. 336 | 337 | For the avoidance of doubt, this Section 6(b) does not affect any 338 | right the Licensor may have to seek remedies for Your violations 339 | of this Public License. 340 | 341 | c. 
For the avoidance of doubt, the Licensor may also offer the 342 | Licensed Material under separate terms or conditions or stop 343 | distributing the Licensed Material at any time; however, doing so 344 | will not terminate this Public License. 345 | 346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 347 | License. 348 | 349 | Section 7 -- Other Terms and Conditions. 350 | 351 | a. The Licensor shall not be bound by any additional or different 352 | terms or conditions communicated by You unless expressly agreed. 353 | 354 | b. Any arrangements, understandings, or agreements regarding the 355 | Licensed Material not stated herein are separate from and 356 | independent of the terms and conditions of this Public License. 357 | 358 | Section 8 -- Interpretation. 359 | 360 | a. For the avoidance of doubt, this Public License does not, and 361 | shall not be interpreted to, reduce, limit, restrict, or impose 362 | conditions on any use of the Licensed Material that could lawfully 363 | be made without permission under this Public License. 364 | 365 | b. To the extent possible, if any provision of this Public License is 366 | deemed unenforceable, it shall be automatically reformed to the 367 | minimum extent necessary to make it enforceable. If the provision 368 | cannot be reformed, it shall be severed from this Public License 369 | without affecting the enforceability of the remaining terms and 370 | conditions. 371 | 372 | c. No term or condition of this Public License will be waived and no 373 | failure to comply consented to unless expressly agreed to by the 374 | Licensor. 375 | 376 | d. Nothing in this Public License constitutes or may be interpreted 377 | as a limitation upon, or waiver of, any privileges and immunities 378 | that apply to the Licensor or You, including from the legal 379 | processes of any jurisdiction or authority. 380 | 381 | ======================================================================= 382 | 383 | Creative Commons is not a party to its public 384 | licenses. Notwithstanding, Creative Commons may elect to apply one of 385 | its public licenses to material it publishes and in those instances 386 | will be considered the “Licensor.” The text of the Creative Commons 387 | public licenses is dedicated to the public domain under the CC0 Public 388 | Domain Dedication. Except for the limited purpose of indicating that 389 | material is shared under a Creative Commons public license or as 390 | otherwise permitted by the Creative Commons policies published at 391 | creativecommons.org/policies, Creative Commons does not authorize the 392 | use of the trademark "Creative Commons" or any other trademark or logo 393 | of Creative Commons without its prior written consent including, 394 | without limitation, in connection with any unauthorized modifications 395 | to any of its public licenses or any other arrangements, 396 | understandings, or agreements concerning use of licensed material. For 397 | the avoidance of doubt, this paragraph does not form part of the 398 | public licenses. 399 | 400 | Creative Commons may be contacted at creativecommons.org. 
401 | 
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | =======================================================================
2 | Siren Model implementation's MIT license
3 | =======================================================================
4 | Copyright (c) 2020 Vincent Sitzmann
5 | 
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 | 
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 | 
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [Pre-train, Self-train, Distill: A simple recipe for Supersizing 3D Reconstruction](https://shubhtuls.github.io/ss3d/)
2 | Kalyan Vasudev Alwala, Abhinav Gupta, Shubham Tulsiani
3 | 
4 | [[Paper](https://arxiv.org/abs/2204.03642)] [[Project Page](https://shubhtuls.github.io/ss3d/)]
5 | 
6 | 
7 | 
8 | ## Setup
9 | Download the final distilled model from [here](https://dl.fbaipublicfiles.com/ss3d/distilled_model.torch).
10 | 
11 | Install the following pre-requisites:
12 | * Python >=3.6
13 | * PyTorch tested with `1.10.0`
14 | * TorchVision tested with `0.11.1`
15 | * Trimesh
16 | * pymcubes
17 | 
18 | ## 3D Reconstruction Interface
19 | 
20 | Reconstruct 3D in 3 simple steps! Please see the [demo notebook](demo.ipynb) for a working example.
21 | 
22 | ```python
23 | 
24 | # 1. Load the pre-trained checkpoint
25 | model_3d = VNet()
26 | model_3d.load_state_dict(torch.load(""))
27 | model_3d.eval()
28 | 
29 | 
30 | # 2. Preprocess an RGB image and its associated object mask according to the model's input interface
31 | inp_img = generate_input_img(
32 |     img_rgb,
33 |     img_mask,
34 | )
35 | 
36 | # 3. Obtain the 3D prediction!
37 | out_mesh = extract_trimesh(model_3d, inp_img, "cuda")
38 | # To save the mesh
39 | out_mesh.export("out_mesh_pymcubes.obj")
40 | # To visualize the mesh
41 | out_mesh.show()
42 | ```
43 | 
44 | ## Training and Evaluation
45 | Please refer to [README_TRAINING.md](https://github.com/facebookresearch/ss3d/blob/main/README_TRAINING.md) for more details.
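As a concrete end-to-end illustration of the reconstruction interface above, the snippet below runs the three steps on one of the bundled `demo_data` images. It is a minimal editorial sketch, not part of the repository: it assumes the helper functions defined in the demo notebook (`VNet`, `generate_input_img`, `extract_trimesh`) are already in scope, and the checkpoint path and the exact image-loading details are placeholders/assumptions.

```python
import numpy as np
import torch
from PIL import Image

# Assumes VNet, generate_input_img and extract_trimesh from demo.ipynb are in scope.

# 1. Load the distilled checkpoint (path is a placeholder for the downloaded file).
model_3d = VNet()
model_3d.load_state_dict(torch.load("distilled_model.torch", map_location="cpu"))
model_3d.eval()

# 2. Read a sample RGB image and its object mask from demo_data/
#    (assumes generate_input_img accepts numpy arrays; see demo.ipynb for the exact interface).
img_rgb = np.asarray(Image.open("demo_data/rgb_0.png").convert("RGB"))
img_mask = np.asarray(Image.open("demo_data/mask_0.png"))
inp_img = generate_input_img(img_rgb, img_mask)

# 3. Predict, save, and visualize the mesh.
out_mesh = extract_trimesh(model_3d, inp_img, "cuda")
out_mesh.export("demo_rgb_0_mesh.obj")
out_mesh.show()
```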
46 | 
47 | 
48 | ## Citation
49 | If you find the project useful for your research, please consider citing:
50 | ```
51 | @inproceedings{vasudev2022ss3d,
52 | title={Pre-train, Self-train, Distill: A simple recipe for Supersizing 3D Reconstruction},
53 | author={Vasudev, Kalyan Alwala and Gupta, Abhinav and Tulsiani, Shubham},
54 | year={2022},
55 | booktitle={Computer Vision and Pattern Recognition (CVPR)}
56 | }
57 | ```
58 | 
59 | ## Contributing
60 | We welcome your pull requests! Please see [CONTRIBUTING](CONTRIBUTING.md) and [CODE_OF_CONDUCT](CODE_OF_CONDUCT.md) for more information.
61 | 
62 | ## License
63 | ss3d is released under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for additional details. However, the Siren implementation is additionally licensed under the MIT license (see [NOTICE](NOTICE) for additional details).
64 | 
--------------------------------------------------------------------------------
/README_TRAINING.md:
--------------------------------------------------------------------------------
1 | # Installation
2 | 
3 | The basic installation includes dependencies such as pytorch, pytorch3d, and pytorch-lightning.
4 | ```
5 | git clone https://github.com/facebookresearch/ss3d.git
6 | cd ss3d
7 | conda env create -f env.yaml
8 | conda activate ss3d
9 | ```
10 | 
11 | 
12 | # Data Preparation
13 | 
14 | For synthetic pre-training we rely on Warehouse3D models as specified in the ShapeNetCore split. For generating pre-rendered images of this dataset, we recommend using [this](https://github.com/shubhtuls/snetRenderer) Blender rendering tool for ShapeNet.
15 | 
16 | The generated images should have the following structure:
17 | ```
18 | synthetic_rendered_images_root
19 | |-- class_a  # In the ShapeNet case, these are synset ids
20 | |-- ...
21 | |-- ...
22 | |-- class_n
23 |     |-- render_0.png  # rgb rendered images
24 |     |-- ...
25 |     |-- render_100.png
26 |     |-- depth_0.png  # depth images used to extract mask images
27 |     |-- ...
28 |     |-- depth_100.png
29 |     |-- camera_0.mat  # camera intrinsic and extrinsic matrices for the rendered images
30 |     |-- ...
31 |     |-- camera_100.mat
32 | ```
33 | 
34 | For category-specific training, for any dataset you wish to work with, please generate .csv files for the training and test phases,
35 | respectively, in the following format:
36 | 
37 | ```
38 | class_name rgb_image_path mask_image_path Bounding_BOX_X_min, Bounding_BOX_Y_min, Bounding_BOX_X_max, Bounding_BOX_Y_max
39 | ```
40 | 
41 | For datasets that already provide bounding-box-cropped images of single object instances, the following .csv format is also acceptable for the
42 | training and test phases:
43 | 
44 | ```
45 | class_name rgb_image_path mask_image_path
46 | ```
47 | 
48 | # Training
49 | To better understand the config overrides, we encourage users to go through the config file located at `hydra_config/config.py`.
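Before launching category-specific training, the annotation files described in the Data Preparation section above can be generated with a few lines of Python. The sketch below is illustrative only: the directory layout, class name, and file names are hypothetical, and the exact column/delimiter conventions expected by `data/generic_img_mask_loader.py` should be verified against that loader.

```python
import csv
import os

# Hypothetical single-category dataset layout:
#   my_dataset/rgb/*.png   -- RGB crops of single object instances
#   my_dataset/mask/*.png  -- matching object masks (same file names)
root = "my_dataset"
class_name = "chair"  # hypothetical category label

with open("my_dataset_train.csv", "w", newline="") as f:
    writer = csv.writer(f, delimiter=" ")
    for fname in sorted(os.listdir(os.path.join(root, "rgb"))):
        rgb_path = os.path.join(root, "rgb", fname)
        mask_path = os.path.join(root, "mask", fname)
        # Images here are assumed to be already cropped to the object,
        # so only the three-column variant of the format is written.
        writer.writerow([class_name, rgb_path, mask_path])
```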
50 | 51 | For synthetic pretraining, 52 | ``` 53 | python synth_pretraining.py resources.gpus=8 resources.num_nodes=4 resources.use_cluster=True \ 54 | logging.name=synthetic_pretraining optim.use_scheduler=True 55 | ``` 56 | 57 | For category-specific finetuning, 58 | ``` 59 | python cat_spec_training.py \ 60 | resources.use_cluster=True resources.gpus=8 resources.num_nodes=2 \ 61 | logging.name="" \ 62 | render.cam_num=10 render.num_pre_rend_masks=10 \ 63 | data=generic_img_mask data.bs_train=4 \ 64 | data.train_dataset_file=" 0, "Cam number should be > 0 for no camera" 67 | self.ray_num = self.cfg.render.cam_num * self.cfg.render.ray_num_per_cam 68 | self.shape_reg = None 69 | 70 | def _get_weights_path(self): 71 | # dummy to get veriosn only 72 | dummy_logger = TensorBoardLogger( 73 | osp.join(_base_path, self.cfg.logging.log_dir), name=self.cfg.logging.name 74 | ) 75 | 76 | return osp.join( 77 | _base_path, 78 | self.cfg.logging.log_dir, 79 | self.cfg.logging.name, 80 | "cameras_" + str(dummy_logger.version), 81 | ) 82 | 83 | def get_cameras(self, frame_list, device): 84 | # fame_list: list of stirngs or list of hashes 85 | cameras = [] 86 | for frame in frame_list: 87 | temp_path = os.path.splitext(frame[1:])[0] 88 | temp_path = os.path.join(self.weights_path, temp_path) 89 | temp_file = os.path.join(temp_path, "cameras.npy") 90 | 91 | if os.path.exists(temp_file): 92 | try: 93 | cam = np.load(temp_file, allow_pickle=True) 94 | cam = torch.from_numpy(cam) 95 | load_from_memory_flag = True 96 | except: 97 | print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") 98 | print("Error loading cameras for frame {}".format(frame)) 99 | print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") 100 | load_from_memory_flag = False 101 | else: 102 | load_from_memory_flag = False 103 | 104 | if not load_from_memory_flag: 105 | cam = [] 106 | azim_init_angles = 10.0 * np.arctanh( 107 | np.linspace(-0.1, 0.1, num=self.cfg.render.cam_num, endpoint=False) 108 | ) 109 | azim_init_angles = azim_init_angles.tolist() 110 | for i in range(self.cfg.render.cam_num): 111 | elev_param = 2 * torch.rand(1) - 1.0 # [-90,90] 112 | # azim_param = 2 * (2 * torch.rand(1) - 1.0) # [-180,180] 113 | azim_param = torch.tensor([azim_init_angles[i]]) 114 | dist = self.dist_range[0] + ( 115 | self.dist_range[1] - self.dist_range[0] 116 | ) * torch.rand(1) 117 | cam.append(torch.cat([elev_param, azim_param, dist], dim=0)) 118 | cam = torch.stack(cam, dim=0) 119 | cameras.append(cam) 120 | return torch.stack(cameras).to(device) 121 | 122 | def _get_cam_weights(self, frame_list, device): 123 | camera_weights = [] 124 | for frame in frame_list: 125 | temp_path = os.path.splitext(frame[1:])[0] 126 | temp_path = os.path.join(self.weights_path, temp_path) 127 | temp_file = os.path.join(temp_path, "cam_weights.npy") 128 | 129 | if os.path.exists(temp_file): 130 | try: 131 | cam_weight = np.load(temp_file, allow_pickle=True) 132 | cam_weight = torch.from_numpy(cam_weight) 133 | load_from_memory_flag = True 134 | except: 135 | print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") 136 | print( 137 | "Error loading camer linear weights for frame {}".format(frame) 138 | ) 139 | print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") 140 | load_from_memory_flag = False 141 | else: 142 | load_from_memory_flag = False 143 | 144 | if not load_from_memory_flag: 145 | cam_weight = torch.rand(self.cfg.render.cam_num) 146 | camera_weights.append(cam_weight) 147 | 148 | return torch.stack(camera_weights).to(device) 149 | 150 | def _update_cam_weights(self, frame_list, 
camera_weights): 151 | for i, frame in enumerate(frame_list): 152 | temp_path = os.path.splitext(frame[1:])[0] 153 | temp_path = os.path.join(self.weights_path, temp_path) 154 | if not os.path.isdir(temp_path): 155 | os.makedirs(temp_path, exist_ok=True) 156 | temp_file = os.path.join(temp_path, "cam_weights.npy") 157 | np.save(temp_file, camera_weights[i].detach().cpu().numpy()) 158 | 159 | def _update_cameras(self, frame_list, camera_params): 160 | for i, frame in enumerate(frame_list): 161 | temp_path = os.path.splitext(frame[1:])[0] 162 | temp_path = os.path.join(self.weights_path, temp_path) 163 | if not os.path.isdir(temp_path): 164 | os.makedirs(temp_path, exist_ok=True) 165 | temp_file = os.path.join(temp_path, "cameras.npy") 166 | np.save(temp_file, camera_params[i].detach().cpu().numpy()) 167 | 168 | def _compute_camera_loss( 169 | self, 170 | frame_list, 171 | rgb_imgs, 172 | mask_imgs, 173 | cameras, 174 | decoder, 175 | c_latent, 176 | device, 177 | shape_regularizer=None, 178 | ): 179 | 180 | mask_ray_labels, rgb_ray_labels, rays = get_rays_multiplex( 181 | cameras, rgb_imgs, mask_imgs, self.cfg.render, device 182 | ) 183 | # print("\n") 184 | # print(mask_ray_labels.shape, rgb_ray_labels.shape,rays.shape) 185 | # print(rgb_ray_labels.max(), rgb_ray_labels.min(), mask_ray_labels.max(), mask_ray_labels.min()) 186 | # print(rgb_imgs.shape, mask_imgs.shape) 187 | # print(rgb_imgs.max(), rgb_imgs.min(), mask_imgs.max(), mask_imgs.min()) 188 | # print("\n") 189 | 190 | ray_outs = render_rays( 191 | ray_batch=rays, # [N*num_rays, 8] 192 | c_latent=c_latent, # [N*N_num_rays, c_dim] 193 | decoder=decoder, # nn.Module 194 | N_samples=self.cfg.render.on_ray_num_samples, # int 195 | has_rgb=self.cfg.render.rgb, 196 | has_normal=self.cfg.render.normals, 197 | retraw=shape_regularizer is not None, 198 | ) 199 | mask_ray_outs = ray_outs["acc_map"].to(device) # [N*num_rays] 1-d 200 | 201 | loss_mask = torch.nn.functional.mse_loss( 202 | mask_ray_labels, mask_ray_outs, reduction="none" 203 | ) # [N*num_rays] 1-d 204 | loss_mask = loss_mask.reshape( 205 | len(frame_list), self.cfg.render.cam_num, self.cfg.render.ray_num_per_cam 206 | ) 207 | loss_mask = loss_mask.mean(-1) 208 | 209 | if self.cfg.render.rgb: 210 | rgb_ray_outs = ray_outs["rgb_map"].to(device) 211 | loss_rgb = torch.nn.functional.mse_loss( 212 | rgb_ray_labels, rgb_ray_outs, reduction="none" 213 | ) 214 | loss_rgb = loss_rgb.reshape( 215 | len(frame_list), 216 | self.cfg.render.cam_num, 217 | self.cfg.render.ray_num_per_cam, 218 | 3, 219 | ) 220 | loss_rgb = loss_rgb.mean((-1, -2)) 221 | 222 | if self.cfg.render.rgb: 223 | # loss = (loss_mask + loss_rgb) / 2.0 224 | loss = 0.2 * loss_mask + 0.8 * loss_rgb 225 | else: 226 | loss = loss_mask 227 | 228 | if shape_regularizer is not None: 229 | reg_outs = torch.sigmoid(shape_regularizer(ray_outs["points"])) 230 | reg_outs = reg_outs.squeeze(-1) 231 | instance_outs = torch.sigmoid(ray_outs["raw"][..., 3]) 232 | 233 | reg_loss = torch.nn.functional.mse_loss(reg_outs, instance_outs) 234 | return loss, reg_loss 235 | else: 236 | return loss 237 | 238 | def _softmin_loss( 239 | self, 240 | frame_list, 241 | rgb_imgs, 242 | mask_imgs, 243 | cameras, 244 | decoder, 245 | c_latent_from_encoder, 246 | device, 247 | shape_reg=None, 248 | ): 249 | # compute decoder loss 250 | cameras = cameras.detach() 251 | decoder.zero_grad() 252 | c_latent = ( 253 | c_latent_from_encoder.unsqueeze(1) 254 | .repeat(1, self.ray_num, 1) 255 | .view(-1, c_latent_from_encoder.shape[1]) 256 | ) # [N*N_num_rays, 
c_dim] 257 | 258 | loss = self._compute_camera_loss( 259 | frame_list, rgb_imgs, mask_imgs, cameras, decoder, c_latent, device 260 | ) 261 | 262 | loss_softmin = torch.nn.functional.softmin( 263 | loss * self.cfg.render.softmin_temp 264 | ).detach() # detatch?? # this will moving average!!! 265 | loss = torch.mul(loss_softmin, loss) 266 | return loss.mean() 267 | 268 | def _softmax_loss( 269 | self, 270 | frame_list, 271 | rgb_imgs, 272 | mask_imgs, 273 | cameras, 274 | decoder, 275 | c_latent_from_encoder, 276 | device, 277 | n_iter=10, # TODO: tune this! 278 | shape_reg=None, 279 | ): 280 | # get weights 281 | camera_weights = self._get_cam_weights(frame_list, device) 282 | camera_weights = torch.nn.Parameter(camera_weights) 283 | 284 | # create optimizer 285 | cam_weight_optimizer = torch.optim.Adam( 286 | [camera_weights], lr=self.optim_lr # TODO: change this! 287 | ) 288 | 289 | cameras = cameras.detach() 290 | c_latent = c_latent_from_encoder.detach() 291 | c_latent = ( 292 | c_latent.unsqueeze(1).repeat(1, self.ray_num, 1).view(-1, c_latent.shape[1]) 293 | ) # [N*N_num_rays, c_dim] 294 | 295 | for i in range(n_iter): 296 | cam_weight_optimizer.zero_grad() 297 | decoder.zero_grad() 298 | 299 | loss = self._compute_camera_loss( 300 | frame_list, rgb_imgs, mask_imgs, cameras, decoder, c_latent, device 301 | ) 302 | 303 | weights_softmax = torch.nn.functional.softmax( 304 | camera_weights 305 | ) # TODO: temperature? 306 | loss = torch.mul(weights_softmax, loss) 307 | loss = loss.mean() 308 | 309 | loss.backward() 310 | cam_weight_optimizer.step() 311 | cam_weight_optimizer.zero_grad() 312 | decoder.zero_grad() 313 | 314 | # update weights 315 | self._update_cam_weights(frame_list, camera_weights) 316 | 317 | # compute model loss! 318 | camera_weights = camera_weights.detach() 319 | c_latent = ( 320 | c_latent_from_encoder.unsqueeze(1) 321 | .repeat(1, self.ray_num, 1) 322 | .view(-1, c_latent_from_encoder.shape[1]) 323 | ) # [N*N_num_rays, c_dim] 324 | 325 | if shape_reg is not None: 326 | loss, reg_loss = self._compute_camera_loss( 327 | frame_list, 328 | rgb_imgs, 329 | mask_imgs, 330 | cameras, 331 | decoder, 332 | c_latent, 333 | device, 334 | shape_regularizer=shape_reg, 335 | ) 336 | else: 337 | loss = self._compute_camera_loss( 338 | frame_list, 339 | rgb_imgs, 340 | mask_imgs, 341 | cameras, 342 | decoder, 343 | c_latent, 344 | device, 345 | ) 346 | weights_softmax = torch.nn.functional.softmax(camera_weights) 347 | loss = torch.mul(weights_softmax, loss) 348 | loss = loss.sum() 349 | 350 | if shape_reg is not None: 351 | return loss, reg_loss 352 | else: 353 | return loss 354 | 355 | def val_get_bet_camera(self, batch_dict, device): 356 | frame_list = batch_dict["label_img_path"] # .detach().cpu() 357 | camera_params = self.get_cameras(frame_list, device) 358 | elev_angles = torch.nn.Parameter(camera_params[..., 0]) 359 | azim_angles = torch.nn.Parameter(camera_params[..., 1]) 360 | dists = torch.nn.Parameter(camera_params[..., 2]) 361 | cameras = torch.empty(len(frame_list), self.cfg.render.cam_num, 3).to(device) 362 | cameras[..., 0] = torch.tanh(0.1 * elev_angles) * 10 * self.elev_range[1] 363 | cameras[..., 1] = torch.tanh(0.1 * azim_angles) * 10 * self.azim_range[1] 364 | cameras[..., 2] = dists 365 | camera_weights = self._get_cam_weights(frame_list, device) 366 | weights_softmax = torch.nn.functional.softmax(camera_weights) 367 | _, inds = torch.max(weights_softmax, dim=1) 368 | 369 | return cameras, weights_softmax, inds 370 | 371 | def optimize_cameras( 372 | self, 
batch_dict, c_latent_from_encoder, decoder, device, shape_reg=None 373 | ): 374 | # frame_list: (n,) 375 | # c_latent: (n, c_dim) 376 | 377 | frame_list = batch_dict["label_img_path"] # .detach().cpu() 378 | rgb_imgs = batch_dict["label_rgb_img"] 379 | mask_imgs = batch_dict["label_mask_img"] 380 | 381 | c_latent = c_latent_from_encoder.detach() 382 | c_latent = ( 383 | c_latent.unsqueeze(1).repeat(1, self.ray_num, 1).view(-1, c_latent.shape[1]) 384 | ) # [N*N_num_rays, c_dim] 385 | # Define optimizer 386 | camera_params = self.get_cameras(frame_list, device) 387 | elev_angles = torch.nn.Parameter(camera_params[..., 0]) 388 | azim_angles = torch.nn.Parameter(camera_params[..., 1]) 389 | dists = torch.nn.Parameter(camera_params[..., 2]) 390 | 391 | cam_optimizer = torch.optim.Adam( 392 | [elev_angles, azim_angles, dists], lr=self.optim_lr 393 | ) 394 | 395 | for i in range(self.num_iters): 396 | cam_optimizer.zero_grad() 397 | decoder.zero_grad() 398 | 399 | cameras = torch.empty(len(frame_list), self.cfg.render.cam_num, 3).to( 400 | device 401 | ) 402 | cameras[..., 0] = torch.tanh(0.1 * elev_angles) * 10 * self.elev_range[1] 403 | cameras[..., 1] = torch.tanh(0.1 * azim_angles) * 10 * self.azim_range[1] 404 | cameras[..., 2] = dists 405 | 406 | loss = self._compute_camera_loss( 407 | frame_list, rgb_imgs, mask_imgs, cameras, decoder, c_latent, device 408 | ) 409 | loss_camera = loss.mean() 410 | 411 | loss_camera.backward() 412 | cam_optimizer.step() 413 | cam_optimizer.zero_grad() 414 | decoder.zero_grad() 415 | if i == 0: 416 | pre_optim_loss = loss_camera.detach() 417 | 418 | # TODO: Check if camera_params are actually updated!! If not uncomment and fix! 419 | # camera_params = torch.cat([elev_angles, azim_angles, dists], dim=-1) 420 | camera_params = camera_params.detach() 421 | self._update_cameras(frame_list, camera_params) 422 | cam_optimizer = None 423 | 424 | # compute decoder loss 425 | cameras = cameras.detach() 426 | decoder.zero_grad() 427 | reg_loss = 0.0 428 | if self.cfg.render.loss_mode == "softmin": 429 | loss = self._softmin_loss( 430 | frame_list, 431 | rgb_imgs, 432 | mask_imgs, 433 | cameras, 434 | decoder, 435 | c_latent_from_encoder, 436 | device, 437 | shape_reg=shape_reg, 438 | ) 439 | elif self.cfg.render.loss_mode == "softmax": 440 | if self.shape_reg is not None: 441 | loss, reg_loss = self._softmax_loss( 442 | frame_list, 443 | rgb_imgs, 444 | mask_imgs, 445 | cameras, 446 | decoder, 447 | c_latent_from_encoder, 448 | device, 449 | n_iter=1, 450 | shape_reg=shape_reg, 451 | ) 452 | else: 453 | loss = self._softmax_loss( 454 | frame_list, 455 | rgb_imgs, 456 | mask_imgs, 457 | cameras, 458 | decoder, 459 | c_latent_from_encoder, 460 | device, 461 | n_iter=1, 462 | shape_reg=shape_reg, 463 | ) 464 | 465 | ret_dict = { 466 | "pre_optim_loss": pre_optim_loss, 467 | "post_optim_loss": loss_camera, 468 | "decoder_loss": loss, 469 | "regularizer_loss": reg_loss, 470 | } 471 | 472 | return ret_dict 473 | 474 | 475 | class VolumetricNetworkCkpt(LightningModule): 476 | def __init__( 477 | self, 478 | cfg: DictConfig, 479 | ): 480 | super().__init__() 481 | self.cfg = cfg 482 | 483 | # make the e2e (encoder + decoder) model. 
484 | self.model = get_model(cfg.model) 485 | 486 | # Save hyperparameters 487 | self.save_hyperparameters(cfg) 488 | 489 | assert self.cfg.render.cam_num > 0, "Cam number should be > 0 for no camera" 490 | self.ray_num = self.cfg.render.cam_num * self.cfg.render.ray_num_per_cam 491 | 492 | 493 | class VolumetricNetwork(LightningModule): 494 | def __init__( 495 | self, 496 | cfg: DictConfig, 497 | ): 498 | super().__init__() 499 | self.cfg = cfg 500 | 501 | # make the e2e (encoder + decoder) model. 502 | self.model = get_model(cfg.model) 503 | 504 | self.shape_reg = None 505 | 506 | # Save hyperparameters 507 | self.save_hyperparameters(cfg) 508 | 509 | assert self.cfg.render.cam_num > 0, "Cam number should be > 0 for no camera" 510 | self.ray_num = self.cfg.render.cam_num * self.cfg.render.ray_num_per_cam 511 | 512 | # Camera pose handler 513 | self.camera_handler = CameraHandler( 514 | self.cfg, 515 | ) 516 | self.temp = torch.rand(2, 2) 517 | 518 | def configure_optimizers(self): 519 | 520 | fine_tune = self.cfg.model.fine_tune 521 | if fine_tune == "none": 522 | params = [self.temp] 523 | elif fine_tune == "all": 524 | params = list(self.model.parameters()) 525 | elif fine_tune == "encoder": 526 | params = list(self.model.encoder.parameters()) 527 | elif fine_tune == "decoder": 528 | params = list(self.model.decoder.parameters()) 529 | 530 | if self.shape_reg is not None: 531 | params += list(self.shape_reg.parameters()) 532 | 533 | return torch.optim.Adam(params, lr=self.cfg.optim.lr) 534 | 535 | def validation_step(self, batch, batch_idx): 536 | """ 537 | This is the method that gets distributed 538 | """ 539 | 540 | with torch.no_grad(): 541 | label_img = batch["label_rgb_img"][-1] 542 | label_img = label_img.reshape(128, 128, 3).permute(2, 0, 1) 543 | 544 | label_mask_img = batch["label_mask_img"][-1] 545 | label_mask_img = label_mask_img.reshape(128, 128) 546 | 547 | inp_imgs = batch["rgb_img"] 548 | inp_imgs = inp_imgs.to(self.device) 549 | 550 | c_latent = self.model.encoder(inp_imgs) 551 | 552 | loss = torch.tensor(float(0.5)).type_as(c_latent) 553 | 554 | # Vol Render output image for logging 555 | render_kwargs = { 556 | "network_query_fn": network_query_fn_validation, 557 | "N_samples": 100, 558 | "decoder": self.model.decoder, 559 | "c_latent": c_latent[-1].reshape(1, -1), 560 | "chunk": 1000, 561 | "device": self.device, 562 | "has_rgb": self.cfg.render.rgb, 563 | "has_normal": False, 564 | } 565 | 566 | cameras, cam_weights, inds = self.camera_handler.val_get_bet_camera( 567 | batch, self.device 568 | ) 569 | 570 | poses = [ 571 | (45.0, 45.0, cameras[-1, inds[-1], 2]), 572 | (0.0, 90.0, cameras[-1, inds[-1], 2]), 573 | (0.0, 0.0, cameras[-1, inds[-1], 2]), 574 | (90.0, 0.0, cameras[-1, inds[-1], 2]), 575 | ] 576 | ref_occ_imgs = [] 577 | ref_rgb_imgs = [] 578 | for p in poses: 579 | elev_angle, azim_angle, dist = p 580 | _, occ_img, rgb_img, _ = render_img( 581 | dist=torch.tensor([dist]), 582 | elev_angle=torch.tensor([elev_angle]), 583 | azim_angle=torch.tensor([azim_angle]), 584 | img_size=self.cfg.render.img_size, 585 | focal=self.cfg.render.focal_length, 586 | render_kwargs=render_kwargs, 587 | ) 588 | ref_occ_imgs.append(occ_img) 589 | ref_rgb_imgs.append(rgb_img) 590 | 591 | ref_occ_imgs = torch.cat(ref_occ_imgs, dim=0) 592 | if self.cfg.render.rgb: 593 | ref_rgb_imgs = (torch.cat(ref_rgb_imgs, dim=0)).permute(2, 0, 1) 594 | else: 595 | ref_rgb_imgs = None 596 | 597 | depth_img, occ_img, rgb_img, normal_img = render_img( 598 | dist=cameras[-1, inds[-1], 2], 599 | 
elev_angle=cameras[-1, inds[-1], 0], 600 | azim_angle=cameras[-1, inds[-1], 1], 601 | img_size=self.cfg.render.img_size, 602 | focal=self.cfg.render.focal_length, 603 | render_kwargs=render_kwargs, 604 | ) 605 | 606 | if rgb_img is not None: 607 | rgb_img = rgb_img.permute(2, 0, 1) 608 | 609 | if normal_img is not None: 610 | normal_img = normal_img.permute(2, 0, 1) 611 | 612 | return { 613 | "loss": loss, 614 | "inp_img": inp_imgs[-1], 615 | "mask_gt": label_mask_img.unsqueeze(0), 616 | "rgb_gt": None, 617 | "normal_gt": None, 618 | "vol_render": occ_img.unsqueeze(0), 619 | "vol_render_rgb": rgb_img, 620 | "vol_render_normal": normal_img, 621 | "label_rgb_img": label_img, 622 | "ref_rgb_imgs": ref_rgb_imgs, 623 | } 624 | 625 | def validation_epoch_end(self, validation_epoch_outputs): 626 | avg_loss = torch.cat( 627 | [l["loss"].unsqueeze(0) for l in validation_epoch_outputs] 628 | ).mean() 629 | 630 | inp_img = torch.cat([l["inp_img"] for l in validation_epoch_outputs], -1) 631 | 632 | mask_gt = torch.cat([l["mask_gt"] for l in validation_epoch_outputs], -1) 633 | vol_render = torch.cat([l["vol_render"] for l in validation_epoch_outputs], -1) 634 | if self.cfg.render.rgb: 635 | vol_render_rgb = torch.cat( 636 | [l["vol_render_rgb"] for l in validation_epoch_outputs], -1 637 | ) 638 | self.logger.experiment.add_image( 639 | "vol_render_rgb", vol_render_rgb, self.global_step 640 | ) 641 | 642 | ref_rgb_imgs = torch.cat( 643 | [l["ref_rgb_imgs"] for l in validation_epoch_outputs], -1 644 | ) 645 | self.logger.experiment.add_image( 646 | "ref_rgb_imgs", ref_rgb_imgs, self.global_step 647 | ) 648 | 649 | label_img = torch.cat( 650 | [l["label_rgb_img"] for l in validation_epoch_outputs], -1 651 | ) 652 | 653 | self.logger.experiment.add_image( 654 | "label rgb img", label_img, self.global_step 655 | ) 656 | 657 | self.logger.experiment.add_image("val_inp_rgb", inp_img, self.global_step) 658 | self.logger.experiment.add_image("val_vol_render", vol_render, self.global_step) 659 | self.logger.experiment.add_image( 660 | "val_mesh_gt_render", mask_gt, self.global_step 661 | ) 662 | 663 | self.log("val_loss", avg_loss, on_step=False, on_epoch=True, prog_bar=True) 664 | return {"val_loss": avg_loss, "progress_bar": {"global_step": self.global_step}} 665 | 666 | def training_step(self, batch, batch_idx): 667 | """ 668 | This is the method that gets distributed 669 | """ 670 | 671 | inp_imgs = batch["rgb_img"].to(self.device) 672 | c_latent = self.model.encoder(inp_imgs) # [N, c_dim] 673 | out_dict = self.camera_handler.optimize_cameras( 674 | batch, c_latent, self.model.decoder, self.device, self.shape_reg 675 | ) 676 | 677 | self.log( 678 | "train_loss", 679 | out_dict["decoder_loss"], 680 | on_step=True, 681 | on_epoch=True, 682 | prog_bar=True, 683 | ) 684 | self.log( 685 | "camera_post_optim_loss", 686 | out_dict["post_optim_loss"], 687 | on_step=True, 688 | on_epoch=True, 689 | prog_bar=True, 690 | ) 691 | self.log( 692 | "camera_pre_optim_loss", 693 | out_dict["pre_optim_loss"], 694 | on_step=True, 695 | on_epoch=True, 696 | prog_bar=True, 697 | ) 698 | 699 | if self.shape_reg is not None: 700 | self.log( 701 | "train_regularizer_loss", 702 | out_dict["regularizer_loss"], 703 | on_step=True, 704 | on_epoch=True, 705 | prog_bar=True, 706 | ) 707 | loss_out = ( 708 | 0.3 * out_dict["decoder_loss"] + 0.7 * out_dict["regularizer_loss"] 709 | ) 710 | return loss_out 711 | else: 712 | return out_dict["decoder_loss"] 713 | 714 | def forward(self, mode, inputs): 715 | pass 716 | 717 | 718 | def 
train_model(cfg): 719 | 720 | print(OmegaConf.to_yaml(cfg)) 721 | 722 | data_module = GenericImgMaskModule( 723 | data_cfg=cfg.data, 724 | render_cfg=cfg.render, 725 | num_workers=cfg.resources.num_workers, 726 | ) 727 | 728 | log_dir = osp.join(_base_path, cfg.logging.log_dir, cfg.logging.name) 729 | os.makedirs(log_dir, exist_ok=True) 730 | OmegaConf.save(cfg, osp.join(log_dir, "config.txt")) 731 | 732 | checkpoint_callback = ModelCheckpoint( 733 | save_top_k=-1, 734 | every_n_val_epochs=cfg.optim.save_freq, 735 | filename="checkpoint_{epoch}", 736 | ) 737 | 738 | lr_monitor = LearningRateMonitor(logging_interval="step") 739 | 740 | # Fine tune only cameras 741 | logger = TensorBoardLogger( 742 | osp.join(_base_path, cfg.logging.log_dir), name=cfg.logging.name 743 | ) 744 | 745 | if cfg.optim.stage_one_epochs > 0: 746 | cfg.model.fine_tune = "none" 747 | print(cfg) 748 | model = VolumetricNetwork(cfg=cfg) 749 | if cfg.optim.use_pretrain: 750 | temp_model = VolumetricNetwork.load_from_checkpoint( 751 | cfg.optim.checkpoint_path 752 | ) 753 | model.model.load_state_dict(temp_model.model.state_dict()) 754 | else: 755 | checkpoint = None 756 | 757 | trainer = Trainer( 758 | logger=logger, 759 | gpus=cfg.resources.gpus, 760 | num_nodes=cfg.resources.num_nodes, 761 | val_check_interval=cfg.optim.val_check_interval, 762 | limit_val_batches=cfg.optim.num_val_iter, 763 | # checkpoint_callback=checkpoint_callback, 764 | # resume_from_checkpoint=checkpoint, 765 | resume_from_checkpoint=None, # Only loading weights 766 | max_epochs=cfg.optim.stage_one_epochs, 767 | accelerator=cfg.resources.accelerator 768 | if cfg.resources.accelerator != "none" 769 | else None, 770 | deterministic=False, 771 | # profiler="simple", 772 | callbacks=[lr_monitor], 773 | ) 774 | trainer.fit(model, data_module) 775 | cam_weights_path = model.camera_handler.weights_path 776 | 777 | # Fine tune the entire network + cameras 778 | print("#########################################################################") 779 | print("################### Fine Tuning Phase #########################") 780 | print("#########################################################################") 781 | 782 | logger = TensorBoardLogger( 783 | osp.join(_base_path, cfg.logging.log_dir), name=cfg.logging.name 784 | ) 785 | 786 | cfg.model.fine_tune = "all" # TODO: Make this a config param! 787 | 788 | print(cfg) 789 | model = VolumetricNetwork(cfg=cfg) 790 | if cfg.optim.use_pretrain: 791 | temp_model = VolumetricNetwork.load_from_checkpoint(cfg.optim.checkpoint_path) 792 | model.model.load_state_dict(temp_model.model.state_dict()) 793 | else: 794 | checkpoint = None 795 | 796 | # Point the model to use old paths 797 | if cfg.optim.stage_one_epochs > 0: 798 | model.camera_handler.weights_path = cam_weights_path 799 | 800 | trainer = Trainer( 801 | logger=logger, 802 | gpus=cfg.resources.gpus, 803 | num_nodes=cfg.resources.num_nodes, 804 | val_check_interval=cfg.optim.val_check_interval, 805 | limit_val_batches=cfg.optim.num_val_iter, 806 | callbacks=[checkpoint_callback, lr_monitor], 807 | resume_from_checkpoint=None, 808 | max_epochs=cfg.optim.max_epochs, 809 | accelerator=cfg.resources.accelerator 810 | if cfg.resources.accelerator != "none" 811 | else None, 812 | deterministic=False, 813 | # profiler="simple", 814 | ) 815 | trainer.fit(model, data_module) 816 | 817 | 818 | @hydra.main(config_name="config") 819 | def main(cfg: config.Config) -> None: 820 | # Set the everythin - randon, numpy, torch, torch manula, cuda!! 
821 | seed_everything(12) 822 | 823 | # If not cluster launch job locally 824 | if not cfg.resources.use_cluster: 825 | train_model(cfg) 826 | else: 827 | print(OmegaConf.to_yaml(cfg)) 828 | 829 | # dummy to get veriosn only 830 | dummy_logger = TensorBoardLogger( 831 | osp.join(_base_path, cfg.logging.log_dir), name=cfg.logging.name 832 | ) 833 | 834 | submitit_dir = osp.join( 835 | _base_path, 836 | cfg.logging.log_dir, 837 | cfg.logging.name, 838 | "submitit_" + str(dummy_logger.version), 839 | ) 840 | executor = submitit.AutoExecutor(folder=submitit_dir) 841 | 842 | job_kwargs = { 843 | "timeout_min": cfg.resources.time, 844 | "name": cfg.logging.name, 845 | "slurm_partition": cfg.resources.partition, 846 | "gpus_per_node": cfg.resources.gpus, 847 | "tasks_per_node": cfg.resources.gpus, # one task per GPU 848 | "cpus_per_task": 5, 849 | "nodes": cfg.resources.num_nodes, 850 | } 851 | if cfg.resources.max_mem: 852 | job_kwargs["slurm_constraint"] = "volta32gb" 853 | if cfg.resources.partition == "priority": 854 | job_kwargs["slurm_comment"] = cfg.resources.comment 855 | 856 | executor.update_parameters(**job_kwargs) 857 | job = executor.submit(train_model, cfg) 858 | print("Submitit Job ID:", job.job_id) # ID of your job 859 | 860 | 861 | if __name__ == "__main__": 862 | main() 863 | -------------------------------------------------------------------------------- /data/_init_.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ss3d/0efd205f678f8325d9fab3a08d67b69bd11d7dfd/data/_init_.py -------------------------------------------------------------------------------- /data/distillation.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 4 | 5 | import fnmatch 6 | import math 7 | import os 8 | import os.path as osp 9 | import random 10 | 11 | import imageio 12 | import numpy as np 13 | import PIL 14 | import pytorch3d 15 | import pytorch3d.renderer 16 | import scipy 17 | import scipy.io 18 | import scipy.misc 19 | import torch 20 | import torch.nn.functional as torch_F 21 | import torchvision.transforms.functional as F 22 | from data.generic_img_mask_loader import GenericImgMaskModule 23 | from data.warehouse3d import WareHouse3DModule 24 | from hydra_config import config 25 | from model import get_model 26 | from PIL import Image 27 | from pytorch_lightning import LightningDataModule 28 | from pytorch_lightning.core.lightning import LightningModule 29 | from torch.utils.data import Dataset, DistributedSampler 30 | from torchvision import transforms 31 | 32 | 33 | class VolumetricNetworkCkpt(LightningModule): 34 | def __init__( 35 | self, 36 | cfg, 37 | ): 38 | super().__init__() 39 | self.cfg = cfg 40 | 41 | # make the e2e (encoder + decoder) model. 
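# NOTE: this class defines no training/validation logic; DistillationDataModule below only
# uses it to call `load_from_checkpoint` and read back the saved `cfg` of previously trained
# models when assembling the distillation dataloaders.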
42 | self.model = get_model(cfg.model) 43 | 44 | # Save hyperparameters 45 | self.save_hyperparameters(cfg) 46 | 47 | cam_num = self.cfg.render.cam_num if self.cfg.render.cam_num > 0 else 1 48 | self.ray_num = cam_num * self.cfg.render.ray_num_per_cam 49 | 50 | 51 | def get_pl_datamodule(cfg): 52 | if cfg.data.name == "common_dl": # generic_img_mask 53 | dataLoaderMethod = GenericImgMaskModule 54 | else: 55 | dataLoaderMethod = WareHouse3DModule 56 | 57 | # configure data loader 58 | data_loader = dataLoaderMethod( 59 | data_cfg=cfg.data, 60 | render_cfg=cfg.render, 61 | num_workers=cfg.resources.num_workers, 62 | ) 63 | return data_loader 64 | 65 | 66 | class DistillationDataset(Dataset): 67 | def __init__(self, dataset_list, cfg): 68 | 69 | self.index_to_dataset = [] 70 | self.warehouse3d_probability = cfg.distillation.warehouse3d_prob 71 | 72 | if self.warehouse3d_probability > 0.0: 73 | self.dataset_list = dataset_list[1:] 74 | self.warehouse3d_ds = dataset_list[0] 75 | else: 76 | self.dataset_list = dataset_list 77 | self.warehouse3d_ds = None 78 | 79 | for dataset_id, dataset in enumerate(self.dataset_list): 80 | length = len(dataset) 81 | dataset_list = [(dataset_id, i) for i in range(len(dataset))] 82 | self.index_to_dataset += dataset_list 83 | 84 | render_cfg = cfg.render 85 | 86 | self.dist_range = [ 87 | render_cfg.camera_near_dist, 88 | render_cfg.camera_far_dist, 89 | ] 90 | self.ray_num = render_cfg.ray_num_per_cam * ( 91 | render_cfg.cam_num if render_cfg.cam_num > 0 else 1 92 | ) 93 | self.img_size = render_cfg.img_size 94 | 95 | def __len__(self): 96 | return len(self.index_to_dataset) 97 | 98 | def __getitem__(self, index): 99 | 100 | if self.warehouse3d_probability == 0.0: 101 | dataset_id, data_idx = self.index_to_dataset[index] 102 | data_dict = self.dataset_list[dataset_id][data_idx] 103 | else: 104 | if torch.rand(1).item() < self.warehouse3d_probability: 105 | dataset_id = 0 106 | data_idx = torch.randint(0, len(self.warehouse3d_ds), (1,)).item() 107 | data_dict = self.warehouse3d_ds[data_idx] 108 | data_dict["orig_img_path"] = data_dict["label_img_path"] 109 | else: 110 | dataset_id, data_idx = self.index_to_dataset[index] 111 | data_dict = self.dataset_list[dataset_id][data_idx] 112 | dataset_id += 1 113 | 114 | # Sample random pose 115 | elev_angle = 90 * (2 * torch.rand(1) - 1.0) # [-90,90] 116 | azim_angle = 180 * (2 * torch.rand(1) - 1.0) # [-180,180] 117 | dist = self.dist_range[0] + ( 118 | self.dist_range[1] - self.dist_range[0] 119 | ) * torch.rand(1) 120 | 121 | temp_idx = torch.randperm(self.img_size * self.img_size) 122 | idx = temp_idx[: self.ray_num] # [1, num_points] 123 | 124 | return { 125 | "rgb_img": data_dict["rgb_img"], 126 | "dataset_id": dataset_id, 127 | "data_idx": data_idx, 128 | "label_img_path": data_dict["label_img_path"], 129 | "label_rgb_img": data_dict["label_rgb_img"], 130 | "label_mask_img": data_dict["label_mask_img"], 131 | "orig_img_path": data_dict["orig_img_path"], 132 | "dist": dist, 133 | "elev_angle": elev_angle, 134 | "azim_angle": azim_angle, 135 | "flat_indices": idx, 136 | } 137 | 138 | 139 | def get_paths(root_dir, regex_match, regex_exclude=""): 140 | if regex_exclude != "": 141 | regex_exclude = regex_exclude.split(",") 142 | else: 143 | regex_exclude = [] 144 | 145 | out_paths = [] 146 | for root, d_names, f_names in os.walk(root_dir): 147 | for f in f_names: 148 | 149 | path = os.path.join(root, f) 150 | 151 | in_exclude = False 152 | for ex_re in regex_exclude: 153 | if fnmatch.fnmatch(path, ex_re): 154 | 
in_exclude = True 155 | 156 | if fnmatch.fnmatch(path, regex_match) and (not in_exclude): 157 | out_paths.append(path) 158 | 159 | return out_paths 160 | 161 | 162 | class DistillationDataModule(LightningDataModule): 163 | def __init__(self, cfg, base_path): 164 | super().__init__() 165 | self.cfg = cfg 166 | self.data_module_list = [] 167 | self.num_workers = cfg.resources.num_workers 168 | assert cfg.distillation.ckpts_root_dir is not None 169 | checkpoint_paths = get_paths( 170 | cfg.distillation.ckpts_root_dir, 171 | cfg.distillation.regex_match, 172 | cfg.distillation.regex_exclude, 173 | ) 174 | checkpoint_paths = [cfg.distillation.warehouse3d_ckpt_path] + checkpoint_paths 175 | print("!!!!!!!!!!!!!Loading checkpoints from root dir for dataloader!!!!!!!!") 176 | print(checkpoint_paths, sep="\n") 177 | for checkpoint_path in checkpoint_paths: 178 | temp_model = VolumetricNetworkCkpt.load_from_checkpoint(checkpoint_path) 179 | self.data_module_list.append(get_pl_datamodule(temp_model.cfg)) 180 | del temp_model 181 | 182 | def train_dataloader(self): 183 | 184 | datasets = [dm.train_dataloader().dataset for dm in self.data_module_list] 185 | distillation_ds = DistillationDataset(datasets, self.cfg) 186 | sampler = DistributedSampler(distillation_ds) 187 | 188 | return torch.utils.data.DataLoader( 189 | distillation_ds, 190 | batch_size=self.cfg.data.bs_train, 191 | num_workers=self.num_workers, 192 | sampler=sampler, 193 | shuffle=sampler is None, 194 | ) 195 | 196 | def val_dataloader(self): 197 | 198 | return self.train_dataloader() 199 | # datasets = [dm.val_dataloader().dataset for dm in self.data_module_list] 200 | # distillation_ds = DistillationDataset(datasets, self.cfg) 201 | # sampler = DistributedSampler(distillation_ds) 202 | # #sampler = None 203 | # return torch.utils.data.DataLoader( 204 | # distillation_ds, 205 | # batch_size=self.cfg.data.bs_val, 206 | # num_workers=self.num_workers, 207 | # sampler=sampler, 208 | # shuffle=sampler is None, 209 | # ) 210 | -------------------------------------------------------------------------------- /data/generic_img_mask_loader.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 4 | 5 | import math 6 | import os.path as osp 7 | import random 8 | 9 | import imageio 10 | import numpy as np 11 | import pandas 12 | import PIL 13 | import torch 14 | import torch.nn.functional as torch_F 15 | import torchvision.transforms.functional as F 16 | from PIL import Image 17 | from pytorch_lightning import LightningDataModule 18 | from torch.utils.data import Dataset, DistributedSampler 19 | from torchvision import transforms 20 | 21 | 22 | class SquarePad: 23 | def __call__(self, image): 24 | w, h = image.size 25 | max_wh = np.max([w, h]) 26 | hp = int((max_wh - w) / 2) 27 | vp = int((max_wh - h) / 2) 28 | padding = (hp, vp, hp, vp) 29 | return F.pad(image, padding, 0, "constant") 30 | 31 | 32 | def get_tight_bbox(mask): 33 | mask_bool = mask > 0.3 34 | row_agg = mask_bool.sum(dim=0) 35 | row_agg = row_agg > 0 36 | col_agg = mask_bool.sum(dim=1) 37 | col_agg = col_agg > 0 38 | 39 | def get_left_right(x): 40 | left = 0 41 | for i in range(len(x)): 42 | if x[i]: 43 | left = i 44 | break 45 | right = len(x) 46 | for i in range(len(x) - 1, 0, -1): 47 | if x[i]: 48 | right = i 49 | right += 1 50 | break 51 | 52 | return left, right 53 | 54 | x1, x2 = get_left_right(col_agg) 55 | y1, y2 = get_left_right(row_agg) 56 | 
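# Note: col_agg sums the mask over dim 1, so x1/x2 run along dim 0 of the mask;
# row_agg sums over dim 0, so y1/y2 run along dim 1.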
57 | return np.array([x1, y1, x2, y2]) 58 | 59 | 60 | def read_paths_and_boxes(file_path, data_cfg): 61 | 62 | class_names = data_cfg.class_ids.split(",") 63 | 64 | split_df = pandas.read_csv( 65 | file_path, 66 | header=None, 67 | names=[ 68 | "class_id", 69 | "rgb_path", 70 | "mask_path", 71 | "BoxXMin", 72 | "BoxYMin", 73 | "BoxXMax", 74 | "BoxYMax", 75 | ], 76 | ) 77 | img_list = [] 78 | for allowed_id in class_names: 79 | class_df = split_df.loc[split_df["class_id"] == allowed_id] 80 | 81 | # valid_df = class_df.loc[class_df["truncated"] > iou_th] 82 | num_images = min(len(class_df), data_cfg.max_per_class) 83 | assert num_images > 10, f"Minimum image criterion not met for {class_names}" 84 | class_df = class_df.head(num_images) 85 | 86 | cls_counter = 0 87 | for index, row in class_df.iterrows(): 88 | 89 | rgb_path = None 90 | mask_path = None 91 | if osp.exists( 92 | osp.join(data_cfg.rgb_path_prefix, row["rgb_path"]) 93 | ) and osp.exists(osp.join(data_cfg.mask_path_prefix, row["rgb_path"])): 94 | rgb_path = osp.join(data_cfg.rgb_path_prefix, row["rgb_path"]) 95 | mask_path = osp.join(data_cfg.mask_path_prefix, row["rgb_path"]) 96 | 97 | if (mask_path is None) or (rgb_path is None): 98 | continue 99 | 100 | if math.isnan(row["BoxYMax"]): 101 | bbox = None 102 | else: 103 | bbox = [row["BoxXMin"], row["BoxYMin"], row["BoxXMax"], row["BoxYMax"]] 104 | 105 | img_list.append( 106 | { 107 | "bbox": bbox, 108 | "rgb_path": rgb_path, 109 | "mask_path": mask_path, 110 | } 111 | ) 112 | cls_counter += 1 113 | 114 | if cls_counter > data_cfg.max_per_class: 115 | break 116 | return img_list 117 | 118 | 119 | class GenericImgMaskDataset(Dataset): 120 | def __init__(self, img_list, data_cfg, render_cfg): 121 | 122 | self.data_cfg = data_cfg 123 | 124 | self.render_cfg = render_cfg 125 | self.img_list = img_list 126 | 127 | self.inp_transforms = transforms.Compose( 128 | [ 129 | SquarePad(), # pad to square 130 | transforms.Pad(30, fill=0, padding_mode="constant"), 131 | # functional.crop, 132 | transforms.Resize((224, 224)), # resize 133 | transforms.ToTensor(), 134 | ] 135 | ) 136 | 137 | self.label_transforms = transforms.Compose( 138 | [ 139 | SquarePad(), # pad to square 140 | transforms.Pad(30, fill=0, padding_mode="constant"), 141 | # functional.crop, 142 | transforms.Resize( 143 | (self.render_cfg.img_size, self.render_cfg.img_size) 144 | ), # resize 145 | transforms.ToTensor(), 146 | ] 147 | ) 148 | 149 | def __len__(self): 150 | return len(self.img_list) 151 | 152 | def __getitem__(self, index): 153 | 154 | img_dict = self.img_list[index] 155 | img_path = img_dict["rgb_path"] 156 | mask_path = img_dict["mask_path"] 157 | 158 | with open(img_path, "rb") as f: 159 | raw_rgb_img = Image.open(f) 160 | raw_rgb_img = np.array(raw_rgb_img.convert("RGB")) 161 | 162 | mask_image = imageio.imread(mask_path) 163 | mask_image = (torch.Tensor(mask_image)).float() / 255.0 164 | 165 | if len(mask_image.shape) == 3: 166 | mask_image = mask_image[..., -1] 167 | 168 | mask_image = torch_F.interpolate( 169 | mask_image.unsqueeze(0).unsqueeze(0), 170 | (raw_rgb_img.shape[0], raw_rgb_img.shape[1]), 171 | ) 172 | mask_image = mask_image.squeeze(0).squeeze(0) 173 | 174 | # If Bounding box is not given, get a tight bounding box from mask image 175 | bbox = ( 176 | np.array(img_dict["bbox"]) 177 | if img_dict["bbox"] 178 | else get_tight_bbox(mask_image) 179 | ) 180 | 181 | img_shape = mask_image.shape 182 | # bbox[0] *= img_shape[1] 183 | # bbox[2] *= img_shape[1] 184 | # bbox[1] *= img_shape[0] 185 | # 
bbox[3] *= img_shape[0] 186 | bbox = bbox.astype(int) 187 | 188 | if len(mask_image.shape) == 3: 189 | mask_image = mask_image[..., -1] 190 | 191 | label_img = mask_image[bbox[1] : bbox[3], bbox[0] : bbox[2]] 192 | 193 | rgb_img = ( 194 | raw_rgb_img[bbox[1] : bbox[3], bbox[0] : bbox[2], :] 195 | * label_img.unsqueeze(-1).numpy() 196 | ) 197 | 198 | rgb_img_proc = self.inp_transforms( 199 | PIL.Image.fromarray(rgb_img.astype(np.uint8)) 200 | ) 201 | 202 | label_rgb_img_proc = self.label_transforms( 203 | PIL.Image.fromarray(rgb_img.astype(np.uint8)) 204 | ) 205 | label_rgb_img_proc = label_rgb_img_proc.permute(1, 2, 0) 206 | label_rgb_img_proc = label_rgb_img_proc.reshape(-1, 3) 207 | 208 | label_img_proc = self.label_transforms( 209 | PIL.Image.fromarray((255.0 * label_img).numpy().astype(np.uint8)) 210 | ) 211 | label_img_proc = label_img_proc.view(-1).float() 212 | 213 | # Simplify mask path to use it as key for hash and storing camera weights. 214 | mask_path = mask_path.replace("/", "") 215 | mask_path = mask_path.replace(".", "") 216 | 217 | return { 218 | "rgb_img": rgb_img_proc.float(), 219 | "label_img_path": mask_path, 220 | "label_rgb_img": label_rgb_img_proc.float(), 221 | "label_mask_img": label_img_proc, 222 | "orig_img_path": img_path, 223 | "mesh_path": "some_placeholder_mesh.off", 224 | } 225 | return ret_dict 226 | 227 | 228 | class DatasetPermutationWrapper(Dataset): 229 | def __init__(self, dset): 230 | self.dset = dset 231 | self._len = len(self.dset) 232 | 233 | def __len__(self): 234 | return self._len 235 | 236 | def __getitem__(self, _): 237 | # TODO(Fix): This random generator behaves same on all gpu's 238 | index = random.randint(0, self._len - 1) 239 | return self.dset[index] 240 | 241 | 242 | class GenericImgMaskModule(LightningDataModule): 243 | def __init__(self, data_cfg, render_cfg, num_workers=0, debug_mode=False, cfg=None): 244 | super().__init__() 245 | self.cfg = cfg 246 | self.data_cfg = data_cfg 247 | self.render_cfg = render_cfg 248 | self.num_workers = num_workers 249 | 250 | self.train_split = read_paths_and_boxes( 251 | self.data_cfg.train_dataset_file, self.data_cfg 252 | ) 253 | self.val_split = read_paths_and_boxes( 254 | self.data_cfg.val_dataset_file, self.data_cfg 255 | ) 256 | 257 | def train_dataloader(self): 258 | 259 | assert self.render_cfg.cam_num > 0, "camera number cannot be 0" 260 | 261 | train_ds = GenericImgMaskDataset( 262 | self.train_split, 263 | self.data_cfg, 264 | self.render_cfg, 265 | ) 266 | 267 | sampler = None 268 | 269 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 270 | sampler = DistributedSampler(train_ds) 271 | 272 | return torch.utils.data.DataLoader( 273 | train_ds, 274 | batch_size=self.data_cfg.bs_train, 275 | num_workers=self.num_workers, 276 | sampler=sampler, 277 | shuffle=sampler is None, 278 | ) 279 | 280 | def val_dataloader(self): 281 | 282 | assert self.render_cfg.cam_num > 0, "camera number cannot be 0" 283 | 284 | val_ds = DatasetPermutationWrapper( 285 | GenericImgMaskDataset( 286 | self.val_split, 287 | self.data_cfg, 288 | self.render_cfg, 289 | ) 290 | ) 291 | return torch.utils.data.DataLoader( 292 | val_ds, batch_size=self.data_cfg.bs_val, num_workers=self.num_workers 293 | ) 294 | -------------------------------------------------------------------------------- /data/warehouse3d.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 
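# Dataset/DataModule for the pre-rendered ShapeNet-style synsets (Blender RGB/depth renders
# plus camera info) consumed by synth_pretraining.py during synthetic pretraining.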
4 | 5 | import copy 6 | import os.path as osp 7 | import random 8 | import time 9 | 10 | import imageio 11 | import numpy as np 12 | import scipy.io as sio 13 | import torch 14 | import torchvision.transforms as transforms 15 | from PIL import Image 16 | from pytorch_lightning import LightningDataModule 17 | from torch.utils.data import Dataset, DistributedSampler 18 | from volumetric_render import ( 19 | get_rays_from_angles, 20 | get_transformation, 21 | ) 22 | 23 | 24 | def loadDepth(dFile, minVal=0, maxVal=10): 25 | dMap = imageio.imread(dFile) 26 | dMap = dMap.astype(np.float32) 27 | dMap = dMap * (maxVal - minVal) / (pow(2, 16) - 1) + minVal 28 | return dMap 29 | 30 | 31 | def getSynsetsV1Paths(data_cfg): 32 | synsets = data_cfg.class_ids.split(",") 33 | synsets.sort() 34 | root_dir = data_cfg.rgb_dir 35 | synsetModels = [ 36 | [f for f in os.listdir(osp.join(root_dir, s)) if len(f) > 3] for s in synsets 37 | ] 38 | 39 | paths = [] 40 | for i in range(len(synsets)): 41 | for m in synsetModels[i]: 42 | paths.append([synsets[i], m]) 43 | 44 | return paths 45 | 46 | 47 | def get_rays_multiplex(cameras, rgb_imgs, mask_imgs, render_cfg, device): 48 | len_cameras = len(cameras) 49 | assert len_cameras == len(rgb_imgs) and len_cameras == len( 50 | mask_imgs 51 | ), "incorrect inputs for camera computation" 52 | 53 | all_rays = [] 54 | all_rgb_labels = [] 55 | all_mask_labels = [] 56 | for i, cam in enumerate(cameras): 57 | # compute rays 58 | 59 | assert len(cam) == render_cfg.cam_num, "incorrect per frame camera number" 60 | indices = [] 61 | ind = torch.randperm(render_cfg.img_size * render_cfg.img_size) 62 | for j in range(len(cam)): 63 | indices.append(ind[: render_cfg.ray_num_per_cam]) 64 | all_rgb_labels.append(rgb_imgs[i, indices[-1]]) 65 | all_mask_labels.append(mask_imgs[i, indices[-1]]) 66 | 67 | all_rays.append( 68 | get_rays_from_angles( 69 | H=render_cfg.img_size, 70 | W=render_cfg.img_size, 71 | focal=float(render_cfg.focal_length), 72 | near_plane=render_cfg.near_plane, 73 | far_plane=render_cfg.far_plane, 74 | elev_angles=cam[:, 0], 75 | azim_angles=cam[:, 1], 76 | dists=cam[:, 2], 77 | device=device, 78 | indices=indices, 79 | ) 80 | ) # [(Num_cams_per_frame*Num_rays), 8] #2d 81 | 82 | return ( 83 | torch.cat(all_mask_labels).to( 84 | device 85 | ), # [(N*Num_cams_per_frame*Num_rays_per_cam)] #1d 86 | torch.cat(all_rgb_labels).to( 87 | device 88 | ), # [[(N*Num_cams_per_frame*Num_rays_per_cam), 3] #2d 89 | torch.cat(all_rays).to( 90 | device 91 | ), # [(N*Num_cams_per_frame*Num_rays_per_cam), 8] #2d 92 | ) 93 | 94 | 95 | def extract_data_train(batch_dict, render_cfg, device): 96 | # If using pre-rendered. 
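# Batches are expected to come from WareHouse3DDataset (defined further below), which already
# samples cameras and rays and returns per-ray mask/RGB labels alongside the encoder image.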
97 | assert "mask_label_rays" in batch_dict 98 | 99 | inp_imgs = batch_dict["rgb_img"] 100 | mask_label_rays = batch_dict["mask_label_rays"].view(-1) 101 | rays = batch_dict["rays"].view(-1, 8) 102 | rgb_label_rays = batch_dict["rgb_label_rays"] 103 | rgb_label_rays = rgb_label_rays.reshape(-1, 3) 104 | 105 | return ( 106 | inp_imgs.to(device), # [N, 3, img_size, img_size] 107 | mask_label_rays.to(device), # [(N*Num_rays)] #1d 108 | rgb_label_rays.to(device), # [[(N*Num_rays), 3] #2d 109 | None, 110 | None, 111 | rays.to(device), # [(N*Num_rays), 8] #2d 112 | ) 113 | 114 | 115 | class DatasetPermutationWrapper(Dataset): 116 | def __init__(self, dset): 117 | self.dset = dset 118 | self._len = len(self.dset) 119 | 120 | def __len__(self): 121 | return self._len 122 | 123 | def __getitem__(self, _): 124 | # TODO(Fix): This random generator behaves same on all gpu's 125 | index = random.randint(0, self._len - 1) 126 | return self.dset[index] 127 | 128 | 129 | class WareHouse3DDataset(Dataset): 130 | """Characterizes a dataset for PyTorch""" 131 | 132 | # TODO: Change hardcoded values! 133 | def __init__(self, data_root, paths, render_cfg, encoder_root): 134 | 135 | super(WareHouse3DDataset, self).__init__() 136 | 137 | self.paths = paths 138 | self.render_cfg = render_cfg 139 | self.data_root = data_root 140 | self.n_cams = self.render_cfg.cam_num 141 | self.n_rays_per_cam = self.render_cfg.ray_num_per_cam 142 | self.encoder_root = encoder_root 143 | 144 | self.transform_img = transforms.Compose( 145 | [ 146 | transforms.Resize((224, 224)), 147 | transforms.ToTensor(), 148 | ] 149 | ) 150 | 151 | self.transform_label = transforms.Compose( 152 | [ 153 | transforms.Resize((self.render_cfg.img_size, self.render_cfg.img_size)), 154 | transforms.ToTensor(), 155 | ] 156 | ) 157 | 158 | def __len__(self): 159 | return len(self.paths) 160 | 161 | def __getitem__(self, index): 162 | st_time = time.time() 163 | rel_path = os.path.join(self.paths[index][0], self.paths[index][1]) 164 | data_path = os.path.join(self.data_root, rel_path) 165 | 166 | # sample a random encoder imput rgb image. 167 | 168 | encoder_cam_info = np.load( 169 | osp.join(self.encoder_root, rel_path, "cam_info.npy") 170 | ) 171 | encoder_sample_num = random.randint(0, encoder_cam_info.shape[0] - 1) 172 | # encoder_sample_num = 4 # Only for debug 173 | inp_angles = encoder_cam_info[encoder_sample_num, :] 174 | inp_angles[0] += 90 175 | img_path = os.path.join( 176 | self.encoder_root, rel_path, "render_{}.png".format(encoder_sample_num) 177 | ) 178 | with open(img_path, "rb") as f: 179 | img = Image.open(f) 180 | img = img.convert("RGB") 181 | inp_rgb_size = img.size[0] 182 | inp_focal = (sio.loadmat(os.path.join(data_path, "camera_0.mat")))["K"][0, 0] 183 | 184 | img = self.transform_img(img) 185 | 186 | # Sample random cameras 187 | cam_info = np.load(osp.join(data_path, "cam_info.npy")) 188 | 189 | # Only use a subset of the data 190 | if self.render_cfg.num_pre_rend_masks > 1: 191 | cam_info = cam_info[: self.render_cfg.num_pre_rend_masks, :] 192 | 193 | # TODO(Fix): This random generator behaves same on all data-workers on each gpu 194 | cam_inds = np.random.choice(cam_info.shape[0], self.n_cams) 195 | # cam_inds = [0] # only for debug 196 | cam_info = torch.Tensor(cam_info[cam_inds, :]) 197 | azim_angle, elev_angle, theta, dist = torch.split(cam_info, 1, dim=-1) 198 | azim_angle += ( 199 | 90 # This a known blender offset. Look at the noteboosks for visual test. 
200 | ) 201 | # sample rays from cameras and mask images 202 | render_cfg = self.render_cfg 203 | pixel_ids = [] 204 | ray_mask_labels = [] 205 | ray_rgb_labels = [] 206 | 207 | for i, nc in enumerate(cam_inds): 208 | temp_idx = torch.randperm( 209 | self.render_cfg.img_size * self.render_cfg.img_size 210 | ) 211 | idx = temp_idx[: self.n_rays_per_cam] # [1, n_rays_per_cam] 212 | pixel_ids.append(idx) 213 | 214 | # Masks from depth 215 | gt_mask = loadDepth( 216 | osp.join(data_path, "depth_{}.png".format(int(nc))), minVal=0, maxVal=10 217 | ) 218 | empty = gt_mask >= 10.0 219 | notempty = gt_mask < 10.0 220 | gt_mask[empty] = 0 221 | gt_mask[notempty] = 1.0 222 | gt_mask = self.transform_label(Image.fromarray(gt_mask)) 223 | gt_mask = gt_mask.view(-1).float() 224 | ray_mask_labels.append(gt_mask[idx]) 225 | 226 | # RGB Pixels 227 | label_rgb_path = osp.join(data_path, "render_{}.png".format(int(nc))) 228 | with open(label_rgb_path, "rb") as f: 229 | gt_rgb = Image.open(f) 230 | gt_rgb = gt_rgb.convert("RGB") 231 | 232 | gt_rgb = self.transform_label(gt_rgb) 233 | gt_rgb = gt_rgb.permute(1, 2, 0) 234 | gt_rgb = gt_rgb.reshape(-1, 3) 235 | ray_rgb_labels.append(gt_rgb[idx]) 236 | 237 | # inp_angles = cam_info[i, :] # TODO: Test!!! 238 | 239 | # n_cams X n_rays_per_cam 240 | pixel_idx = torch.stack(pixel_ids, dim=0) 241 | mask_label_rays = torch.cat(ray_mask_labels, dim=0) 242 | rgb_label_rays = torch.cat(ray_rgb_labels, dim=0) 243 | 244 | label_focal = inp_focal * float(render_cfg.img_size) / inp_rgb_size 245 | 246 | # Used only in relative case 247 | rays = get_rays_from_angles( 248 | H=render_cfg.img_size, 249 | W=render_cfg.img_size, 250 | focal=label_focal, 251 | near_plane=render_cfg.near_plane, 252 | far_plane=render_cfg.far_plane, 253 | elev_angles=elev_angle[:, 0], 254 | azim_angles=azim_angle[:, 0], 255 | dists=dist[:, 0], 256 | device=torch.device("cpu"), 257 | indices=pixel_idx, 258 | transformation_rel=None, 259 | ) # [(n_cams * n_rays_per_cam), 8] #2d 260 | 261 | return { 262 | "rgb_img": img, 263 | "rays": rays, 264 | "mask_label_rays": mask_label_rays, 265 | "rgb_label_rays": rgb_label_rays, 266 | # info useful for debugging 267 | "elev_angle": torch.tensor([elev_angle[-1, 0]]).float(), 268 | "azim_angle": torch.tensor([azim_angle[-1, 0]]).float(), 269 | "dist": torch.tensor([dist[-1, 0]]).float(), 270 | "rel_path": rel_path, 271 | # Used for no camera pose 272 | "label_img_path": label_rgb_path, 273 | "label_rgb_img": gt_rgb, 274 | "label_mask_img": gt_mask, 275 | # class id's 276 | "class_id": self.paths[index][0], 277 | "orig_img_path": label_rgb_path, 278 | } 279 | 280 | 281 | class WareHouse3DModule(LightningDataModule): 282 | def __init__(self, data_cfg, render_cfg, num_workers=0, debug_mode=False): 283 | super().__init__() 284 | self.data_cfg = data_cfg 285 | self.render_cfg = render_cfg 286 | self.num_workers = num_workers 287 | 288 | paths = getSynsetsV1Paths(self.data_cfg) 289 | 290 | train_size = int(self.data_cfg.train_split * len(paths)) 291 | validation_size = int(self.data_cfg.validation_split * len(paths)) 292 | test_size = len(paths) - train_size - validation_size 293 | ( 294 | self.train_split, 295 | self.validation_split, 296 | self.test_split, 297 | ) = torch.utils.data.random_split( 298 | paths, [train_size, validation_size, test_size] 299 | ) 300 | print( 301 | "Total Number of paths:", 302 | len(paths), 303 | len(self.train_split), 304 | len(self.validation_split), 305 | ) 306 | 307 | def train_dataloader(self): 308 | 309 | assert 
self.render_cfg.cam_num > 0, "camera number cannot be 0" 310 | train_ds = WareHouse3DDataset( 311 | self.data_cfg.rgb_dir, 312 | self.train_split, 313 | self.render_cfg, 314 | self.data_cfg.encoder_dir, 315 | ) 316 | sampler = None 317 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 318 | sampler = DistributedSampler(train_ds) 319 | 320 | return torch.utils.data.DataLoader( 321 | train_ds, 322 | batch_size=self.data_cfg.bs_train, 323 | num_workers=self.num_workers, 324 | shuffle=sampler is None, 325 | sampler=sampler, 326 | ) 327 | 328 | def val_dataloader(self): 329 | 330 | assert self.render_cfg.cam_num > 0, "camera number cannot be 0" 331 | 332 | val_ds = DatasetPermutationWrapper( 333 | WareHouse3DDataset( 334 | self.data_cfg.rgb_dir, 335 | self.validation_split, 336 | self.render_cfg, 337 | self.data_cfg.encoder_dir, 338 | ) 339 | ) 340 | return torch.utils.data.DataLoader( 341 | val_ds, batch_size=self.data_cfg.bs_val, num_workers=self.num_workers 342 | ) 343 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import imageio 8 | import mcubes 9 | import numpy as np 10 | import PIL.Image 11 | import torch 12 | import torchvision.transforms.functional as F 13 | import trimesh 14 | from PIL import Image 15 | from torchvision import transforms 16 | 17 | 18 | class SquarePad: 19 | def __call__(self, image): 20 | w, h = image.size 21 | max_wh = np.max([w, h]) 22 | hp = int((max_wh - w) / 2) 23 | vp = int((max_wh - h) / 2) 24 | padding = (hp, vp, hp, vp) 25 | return F.pad(image, padding, 0, "constant") 26 | 27 | 28 | def get_tight_bbox(mask): 29 | mask_bool = mask > 0.3 30 | row_agg = mask_bool.sum(dim=0) 31 | row_agg = row_agg > 0 32 | col_agg = mask_bool.sum(dim=1) 33 | col_agg = col_agg > 0 34 | 35 | def get_left_right(x): 36 | left = 0 37 | for i in range(len(x)): 38 | if x[i]: 39 | left = i 40 | break 41 | right = len(x) 42 | for i in range(len(x) - 1, 0, -1): 43 | if x[i]: 44 | right = i 45 | right += 1 46 | break 47 | 48 | return left, right 49 | 50 | x1, x2 = get_left_right(col_agg) 51 | y1, y2 = get_left_right(row_agg) 52 | 53 | return x1, x2, y1, y2 54 | 55 | 56 | def generate_input_img( 57 | rgb_path, 58 | mask_path, 59 | ): 60 | inp_transforms = transforms.Compose( 61 | [ 62 | SquarePad(), # pad to square 63 | transforms.Pad(30, fill=0, padding_mode="constant"), 64 | # functional.crop, 65 | transforms.Resize((224, 224)), # resize 66 | # transforms.Normalize(3 * [0.5], 3 * [0.5]), 67 | transforms.ToTensor(), 68 | ] 69 | ) 70 | 71 | with open(rgb_path, "rb") as f: 72 | raw_rgb_img = Image.open(f) 73 | raw_rgb_img = np.array(raw_rgb_img.convert("RGB")) 74 | 75 | mask_image = imageio.imread(mask_path) 76 | mask_image = (torch.Tensor(mask_image)).float() / 255.0 77 | 78 | # clip based on bbox 79 | bbox = get_tight_bbox(mask_image) 80 | label_img = mask_image[bbox[0] : bbox[1], bbox[2] : bbox[3]] 81 | 82 | rgb_img = ( 83 | raw_rgb_img[bbox[0] : bbox[1], bbox[2] : bbox[3], :] 84 | * label_img.unsqueeze(-1).numpy() 85 | ) 86 | 87 | return ( 88 | inp_transforms(PIL.Image.fromarray(rgb_img.astype(np.uint8))) 89 | .float() 90 | .unsqueeze(0) 91 | ) 92 | 93 | 94 | def extract_trimesh(model, img, device="cuda", 
threshold=3.0, discretization=100): 95 | 96 | model = model.to(device) 97 | 98 | c_latent = model.encoder(img.to(device)) 99 | assert c_latent.shape[0] == 1, "C should be of shape 1*c_dim for val" 100 | 101 | # Volume during training is contained ot a cube of [-0.5,0.5] 102 | x_l = torch.FloatTensor(np.linspace(-0.5, 0.5, discretization)).to(device) 103 | y_l = torch.FloatTensor(np.linspace(-0.5, 0.5, discretization)).to(device) 104 | z_l = torch.FloatTensor(np.linspace(-0.5, 0.5, discretization)).to(device) 105 | x, y, z = torch.meshgrid(x_l, y_l, z_l) 106 | 107 | points_cords = torch.stack([x, y, z], dim=-1) 108 | with torch.no_grad(): 109 | 110 | # c_inp = torch.cat(points_cords.shape[0] * [c]) 111 | # pred_voxels = network_query_fn_validation(points_cords, c_latent, model.decoder) 112 | 113 | pred_voxels = model.decoder(points_cords, c=c_latent) 114 | pred_voxels = pred_voxels[..., 3] 115 | pred_voxels = pred_voxels.cpu().numpy() 116 | 117 | vertices, triangles = mcubes.marching_cubes(pred_voxels, threshold) 118 | return trimesh.Trimesh(vertices, triangles, vertex_normals=None, process=False) 119 | -------------------------------------------------------------------------------- /demo_data/mask_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ss3d/0efd205f678f8325d9fab3a08d67b69bd11d7dfd/demo_data/mask_0.png -------------------------------------------------------------------------------- /demo_data/mask_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ss3d/0efd205f678f8325d9fab3a08d67b69bd11d7dfd/demo_data/mask_1.png -------------------------------------------------------------------------------- /demo_data/mask_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ss3d/0efd205f678f8325d9fab3a08d67b69bd11d7dfd/demo_data/mask_2.png -------------------------------------------------------------------------------- /demo_data/mask_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ss3d/0efd205f678f8325d9fab3a08d67b69bd11d7dfd/demo_data/mask_3.png -------------------------------------------------------------------------------- /demo_data/mask_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ss3d/0efd205f678f8325d9fab3a08d67b69bd11d7dfd/demo_data/mask_4.png -------------------------------------------------------------------------------- /demo_data/rgb_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ss3d/0efd205f678f8325d9fab3a08d67b69bd11d7dfd/demo_data/rgb_0.png -------------------------------------------------------------------------------- /demo_data/rgb_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ss3d/0efd205f678f8325d9fab3a08d67b69bd11d7dfd/demo_data/rgb_1.png -------------------------------------------------------------------------------- /demo_data/rgb_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ss3d/0efd205f678f8325d9fab3a08d67b69bd11d7dfd/demo_data/rgb_2.png 
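The rgb_k.png / mask_k.png pairs in demo_data are presumably the inputs for demo.ipynb. A minimal inference sketch, assuming a trained checkpoint (the checkpoint path, output name and threshold below are placeholders, not files shipped with the repo), using only utilities shown elsewhere in this listing:

from data.distillation import VolumetricNetworkCkpt
from data_utils import generate_input_img, extract_trimesh

# Load a trained single-view reconstruction model (hypothetical checkpoint path).
ckpt = VolumetricNetworkCkpt.load_from_checkpoint("job_outputs/some_run/checkpoint.ckpt")

# Mask the RGB image, crop to the tight bounding box and resize to 224x224 -> [1, 3, 224, 224].
img = generate_input_img("demo_data/rgb_0.png", "demo_data/mask_0.png")

# Query the decoder on a 100^3 grid and run marching cubes (threshold is a tunable guess).
mesh = extract_trimesh(ckpt.model, img, device="cuda", threshold=2.0, discretization=100)
mesh.export("demo_0.obj")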
-------------------------------------------------------------------------------- /demo_data/rgb_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ss3d/0efd205f678f8325d9fab3a08d67b69bd11d7dfd/demo_data/rgb_3.png -------------------------------------------------------------------------------- /demo_data/rgb_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ss3d/0efd205f678f8325d9fab3a08d67b69bd11d7dfd/demo_data/rgb_4.png -------------------------------------------------------------------------------- /env.yaml: -------------------------------------------------------------------------------- 1 | name: ss3d 2 | channels: 3 | - pytorch 4 | - fvcore 5 | - iopath 6 | - bottler 7 | - pytorch3d 8 | - conda-forge 9 | - defaults 10 | dependencies: 11 | - python=3.9 12 | - pip 13 | - pytorch=1.9.1 14 | - cudatoolkit=10.2 15 | - torchvision 16 | - fvcore 17 | - iopath 18 | - pytorch3d 19 | - pip: 20 | - pytorch-lightning==1.5.10 21 | - submitit 22 | - imageio 23 | - hydra-core 24 | - tqdm 25 | - scipy 26 | - opencv-python 27 | - sklearn 28 | - trimesh 29 | - pandas -------------------------------------------------------------------------------- /hydra_config/_init_.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ss3d/0efd205f678f8325d9fab3a08d67b69bd11d7dfd/hydra_config/_init_.py -------------------------------------------------------------------------------- /hydra_config/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import os.path as osp 4 | import pdb 5 | from dataclasses import dataclass, field 6 | from typing import Any, List 7 | 8 | import hydra 9 | from hydra.core.config_store import ConfigStore 10 | from omegaconf import MISSING, OmegaConf 11 | 12 | cs = ConfigStore.instance() 13 | 14 | 15 | defaults = [ 16 | "_self_", 17 | {"logging": "tensorboard"}, 18 | {"model": "vnet"}, 19 | {"optim": "default"}, 20 | {"resources": "default"}, 21 | {"data": "warehouse3d"}, 22 | {"render": "default"}, 23 | {"distillation": "empty"}, 24 | {"distillation/ckpts": []}, 25 | ] 26 | 27 | 28 | @dataclass 29 | class DataConfig: 30 | # name: str = "" 31 | bs_train: int = 1 32 | bs_val: int = 1 33 | num_workers: int = 0 34 | 35 | 36 | @dataclass 37 | class Warehouse3DData(DataConfig): 38 | train_split: float = 0.95 39 | validation_split: float = 0.05 40 | 41 | # If its not seperately rendered, it should be same as rgb_dir 42 | encoder_dir: str = ( 43 | "/private/home/kalyanv/learning_vision3d/datasets/blender/renders_encoder" 44 | ) 45 | rgb_dir: str = "/private/home/kalyanv/learning_vision3d/datasets/blender/renders" 46 | 47 | # class_ids: str = "02691156,03001627,03790512" 48 | class_ids: str = "02858304,02924116,03790512,04468005,\ 49 | 02992529,02843684,02954340,02691156,\ 50 | 02933112,03001627,03636649,04090263,\ 51 | 04379243,04530566,02828884,02958343,\ 52 | 03211117,03691459,04256520,04401088,\ 53 | 02747177,02773838,02801938,02808440,\ 54 | 02818832,02834778,02871439,02876657,\ 55 | 02880940,02942699,02946921,03085013,\ 56 | 03046257,03207941,03261776,03325088,\ 57 | 03337140,03467517,03513137,03593526,\ 58 | 03624134,03642806,03710193,03759954,\ 59 | 03761084,03797390,03928116,03938244,\ 60 | 
03948459,03991062,04004475,04074963,\ 61 | 04099429,04225987,04330267,04460130,04554684" 62 | name: str = "warehouse3d" 63 | bs_train: int = 8 64 | bs_val: int = 1 65 | bs_test: int = 1 66 | num_workers: int = 0 67 | 68 | 69 | @dataclass 70 | class GenericImgMaskData(DataConfig): 71 | name: str = "common_dl" 72 | train_split: float = 0.9 73 | validation_split: float = 0.1 74 | train_dataset_file: str = "path_to_class_img_path_maskpath_file.csv" 75 | val_dataset_file: str = "path_to_class_img_path_maskpath_file.csv" 76 | mask_path_prefix: str = "" 77 | rgb_path_prefix: str = "" 78 | 79 | class_ids: str = "class1,class2" 80 | max_per_class: int = 3500 81 | 82 | bs_train: int = 1 83 | bs_val: int = 1 84 | bs_test: int = 1 85 | 86 | 87 | cs.store(group="data", name="default", node=DataConfig) 88 | cs.store(group="data", name="generic_img_mask", node=GenericImgMaskData) 89 | cs.store(group="data", name="warehouse3d", node=Warehouse3DData) 90 | 91 | 92 | @dataclass 93 | class RenderConfig: 94 | img_size: int = 128 95 | focal_length: int = 300 96 | near_plane: float = 0.1 97 | far_plane: float = 2.5 98 | camera_near_dist: float = 1.3 99 | camera_far_dist: float = 1.7 100 | cam_num: int = 5 # if -1, render on fly, dont use prerend 101 | num_pre_rend_masks: int = 50 # -1 corresponds to use all 102 | ray_num_per_cam: int = 340 103 | on_ray_num_samples: int = 80 104 | rgb: bool = True 105 | normals: bool = False 106 | depth: bool = False 107 | 108 | # No camera pose params 109 | softmin_temp: float = 1.0 110 | loss_mode: str = "softmax" # other option is "softmax" 111 | use_momentum: bool = True 112 | 113 | 114 | cs.store(group="render", name="default", node=RenderConfig) 115 | 116 | 117 | @dataclass 118 | class CheckpointConfig: 119 | name: Any = None 120 | version: int = 0 121 | epoch: Any = "last" 122 | pl_module: Any = None 123 | 124 | 125 | def extract_ckpt_path(cfg): 126 | path = osp.join(cfg.name, "version_{}".format(cfg.version)) 127 | if cfg.epoch == "last": 128 | checkpoint_path = osp.join(path, "checkpoints", "last.ckpt") 129 | else: 130 | checkpoint_path = osp.join( 131 | path, "checkpoints", "epoch={}.ckpt".format(cfg.epoch) 132 | ) 133 | return checkpoint_path 134 | 135 | 136 | @dataclass 137 | class OptimizationConfig: 138 | val_check_interval: float = 1 # 300 139 | num_val_iter: int = 20 140 | save_freq: int = 25 141 | max_epochs: int = 475 142 | stage_one_epochs: int = 10 143 | lr: float = 0.00005 144 | use_scheduler: bool = False 145 | use_pretrain: bool = False 146 | checkpoint_path: str = "somepath" 147 | 148 | use_shape_reg: bool = False 149 | 150 | 151 | cs.store(group="optim", name="default", node=OptimizationConfig) 152 | 153 | 154 | @dataclass 155 | class LoggingConfig: 156 | log_dir: str = "job_outputs" 157 | name: str = "temp" 158 | 159 | 160 | cs.store(group="logging", name="tensorboard", node=LoggingConfig) 161 | 162 | 163 | @dataclass 164 | class ModelConfig: 165 | encoder: str = "" 166 | decoder: str = "" 167 | c_dim: int = 0 168 | inp_dim: int = 3 169 | fine_tune: str = "all" # "encoder" or "decoder" or "none" 170 | 171 | 172 | @dataclass 173 | class VNetConfig(ModelConfig): 174 | encoder: str = "resnet34_res_fc" 175 | decoder: str = "siren_rgb" 176 | c_dim: int = 2560 177 | inp_dim: int = 3 178 | fine_tune: str = "all" # "encoder" or "decoder" or "none" 179 | 180 | 181 | cs.store(group="model", name="default", node=ModelConfig) 182 | cs.store(group="model", name="vnet", node=VNetConfig) 183 | 184 | 185 | @dataclass 186 | class ResourceConfig: 187 | gpus: int = 1 188 | 
num_nodes: int = 1 189 | num_workers: int = 0 190 | accelerator: Any = "ddp" # ddp or dp or none 191 | 192 | # cluster specific config 193 | use_cluster: bool = False 194 | max_mem: bool = True # if true use volta32gb for SLURM jobs. 195 | time: int = 60 * 36 # minutes 196 | partition: str = "dev" 197 | comment: str = "please fill this if using priority partition" 198 | 199 | # TOOD: later remove this 200 | mesh_th: float = 2.0 201 | 202 | 203 | cs.store(group="resources", name="default", node=ResourceConfig) 204 | 205 | 206 | @dataclass 207 | class DistillationConfig: 208 | # name: str = "" 209 | mode: str = "point" 210 | 211 | num_points: int = 1000 212 | sample_bounds: List[float] = field(default_factory=lambda: [-0.6, 0.6]) 213 | 214 | use_encoder_transforms: bool = False 215 | 216 | warehouse3d_prob: float = 0.3 217 | warehouse3d_ckpt_path: str = "" 218 | 219 | ckpts_root_dir: str = "" 220 | regex_exclude: str = "" 221 | regex_match: str = "" 222 | 223 | 224 | cs.store(group="distillation", name="empty", node=DistillationConfig) 225 | 226 | 227 | @dataclass 228 | class Config: 229 | defaults: List[Any] = field(default_factory=lambda: defaults) 230 | 231 | data: DataConfig = MISSING 232 | render: RenderConfig = MISSING 233 | logging: LoggingConfig = MISSING 234 | model: ModelConfig = MISSING 235 | optim: OptimizationConfig = MISSING 236 | resources: ResourceConfig = MISSING 237 | distillation: DistillationConfig = MISSING 238 | 239 | 240 | cs.store(name="config", node=Config) 241 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | from __future__ import absolute_import, division, print_function 9 | 10 | import math 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | from torchvision import models 16 | 17 | 18 | def normalize_imagenet(x): 19 | """Normalize input images according to ImageNet standards. 
20 | 21 | Args: 22 | x (tensor): input images 23 | """ 24 | x = x.clone() 25 | x[:, 0] = (x[:, 0] - 0.485) / 0.229 26 | x[:, 1] = (x[:, 1] - 0.456) / 0.224 27 | x[:, 2] = (x[:, 2] - 0.406) / 0.225 28 | return x 29 | 30 | 31 | class SineFiLMLayer(nn.Module): 32 | def __init__( 33 | self, in_features, out_features, bias=True, is_first=False, omega_0=30.0 34 | ): 35 | super().__init__() 36 | self.omega_0 = omega_0 37 | self.is_first = is_first 38 | 39 | self.in_features = in_features 40 | self.linear = nn.Linear(in_features, out_features, bias=bias) 41 | 42 | self.init_weights() 43 | 44 | def init_weights(self): 45 | with torch.no_grad(): 46 | if self.is_first: 47 | self.linear.weight.uniform_(-1 / self.in_features, 1 / self.in_features) 48 | else: 49 | self.linear.weight.uniform_( 50 | -np.sqrt(6 / self.in_features) / self.omega_0, 51 | np.sqrt(6 / self.in_features) / self.omega_0, 52 | ) 53 | 54 | def forward(self, x, frequencies=1.0, shift=0.0): 55 | return torch.sin(frequencies * self.omega_0 * self.linear(x) + shift) 56 | 57 | def forward_with_intermediate(self, x, frequencies=1.0, shift=0.0): 58 | # For visualization of activation distributions 59 | intermediate = frequencies * self.omega_0 * self.linear(x) + shift 60 | return torch.sin(intermediate), intermediate 61 | 62 | 63 | class SirenFiLM(nn.Module): 64 | def __init__( 65 | self, 66 | dim=3, 67 | c_dim=2 * (256 * 5), 68 | hidden_size=256, 69 | hidden_layers=5, 70 | output_ch=4, 71 | outermost_linear=True, 72 | first_omega_0=30.0, 73 | hidden_omega_0=30.0, 74 | ): 75 | super().__init__() 76 | if not outermost_linear: 77 | assert c_dim == 2 * ( 78 | hidden_size * hidden_layers + output_ch 79 | ), "Incorrect c_dim in Siren!!!" 80 | else: 81 | assert c_dim == 2 * ( 82 | hidden_size * hidden_layers 83 | ), "Incorrect c_dim in Siren!!!" 
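# forward() splits c in half: the first half supplies a FiLM frequency vector (hidden_size
# entries per hidden layer) and the second half the matching shift, hence the
# 2 * hidden_size * hidden_layers requirement checked above.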
84 | 85 | self.hidden_size = hidden_size 86 | self.output_ch = output_ch 87 | 88 | self.net = nn.ModuleList() 89 | self.net.append( 90 | SineFiLMLayer(dim, hidden_size, is_first=True, omega_0=first_omega_0) 91 | ) 92 | 93 | for i in range(hidden_layers): 94 | self.net.append( 95 | SineFiLMLayer( 96 | hidden_size, hidden_size, is_first=False, omega_0=hidden_omega_0 97 | ) 98 | ) 99 | 100 | if outermost_linear: 101 | final_linear = nn.Linear(hidden_size, output_ch) 102 | 103 | with torch.no_grad(): 104 | final_linear.weight.uniform_( 105 | -np.sqrt(6 / hidden_size) / hidden_omega_0, 106 | np.sqrt(6 / hidden_size) / hidden_omega_0, 107 | ) 108 | 109 | self.net.append(final_linear) 110 | else: 111 | self.net.append( 112 | SineFiLMLayer( 113 | hidden_size, output_ch, is_first=False, omega_0=hidden_omega_0 114 | ) 115 | ) 116 | 117 | def forward(self, p, c, **kwargs): 118 | # shape of c is B * 2*(hidden_size*hidden_layers + output_ch) 119 | output = p # coordinates 120 | split = int(c.shape[1] / 2) 121 | frequencies = c[:, :split] + 1.0 122 | shifts = c[:, split:] 123 | for i, layer in enumerate(self.net): 124 | # initial layer just encodes positions 125 | if i == 0 or (not isinstance(layer, SineFiLMLayer)): 126 | output = layer(output) 127 | else: 128 | f_i = frequencies[ 129 | :, (i - 1) * self.hidden_size : (i) * self.hidden_size 130 | ].unsqueeze(1) 131 | s_i = shifts[ 132 | :, (i - 1) * self.hidden_size : (i) * self.hidden_size 133 | ].unsqueeze(1) 134 | output = layer(output, f_i, s_i) 135 | output = output.squeeze(-1) 136 | return output 137 | 138 | 139 | class ResidualLayer(nn.Module): 140 | def __init__(self, in_features, out_features, bias=True): 141 | super().__init__() 142 | 143 | self.in_features = in_features 144 | self.linear1 = nn.Linear(in_features, out_features, bias=bias) 145 | self.linear2 = nn.Linear(out_features, out_features, bias=bias) 146 | self.relu1 = nn.LeakyReLU(0.01, inplace=False) 147 | self.relu2 = nn.LeakyReLU(0.01, inplace=False) 148 | 149 | def forward(self, x_init): 150 | x = self.relu1(self.linear1(x_init)) 151 | x = x_init + self.linear2(x) 152 | x = self.relu2(x) 153 | return x 154 | 155 | 156 | class Resnet34ResFC(nn.Module): 157 | r"""ResNet-18 encoder network for image input. 158 | Args: 159 | c_dim (int): output dimension of the latent embedding 160 | normalize (bool): whether the input images should be normalized 161 | use_linear (bool): whether a final linear layer should be used 162 | """ 163 | 164 | def __init__(self, c_dim, normalize=True, use_linear=True, linear_dim=512): 165 | super().__init__() 166 | self.normalize = normalize 167 | self.use_linear = use_linear 168 | self.features = models.resnet34(pretrained=True) 169 | self.features.fc = nn.Sequential() 170 | self.fc = nn.Sequential( 171 | nn.Linear(linear_dim, c_dim), 172 | nn.modules.normalization.LayerNorm(c_dim, elementwise_affine=False), 173 | nn.LeakyReLU(0.01, inplace=False), 174 | ResidualLayer(c_dim, c_dim), 175 | ) 176 | 177 | def forward(self, x): 178 | if self.normalize: 179 | x = normalize_imagenet(x) 180 | net = self.features(x) 181 | out = self.fc(net) 182 | return out 183 | 184 | 185 | class VNet(nn.Module): 186 | """Volumetric Network class. 
187 | 188 | Args: 189 | decoder (nn.Module): decoder network 190 | encoder (nn.Module): encoder network 191 | encoder_latent (nn.Module): latent encoder network 192 | p0_z (dist): prior distribution for latent code z 193 | device (device): torch device 194 | """ 195 | 196 | def __init__( 197 | self, 198 | decoder=SirenFiLM(hidden_size=256, hidden_layers=5), 199 | encoder=Resnet34ResFC(c_dim=2560), 200 | device="cuda", 201 | ): 202 | super(VNet, self).__init__() 203 | 204 | self.decoder = decoder 205 | self.encoder = encoder 206 | self._device = device 207 | 208 | def forward(self): 209 | pass 210 | 211 | def to(self, device): 212 | """Puts the model to the device. 213 | 214 | Args: 215 | device (device): pytorch device 216 | """ 217 | model = super().to(device) 218 | model._device = device 219 | return model 220 | 221 | 222 | def get_model(model_cfg): 223 | # TODO: Add more functionality to this method to enable building model from cfgs. 224 | return VNet() 225 | -------------------------------------------------------------------------------- /synth_pretraining.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import os 4 | import os.path as osp 5 | import pdb 6 | 7 | import hydra 8 | import numpy as np 9 | import submitit 10 | import torch 11 | from data.warehouse3d import ( 12 | WareHouse3DModule, 13 | extract_data_train, 14 | ) 15 | from hydra_config import config 16 | from model import get_model 17 | from omegaconf import DictConfig, OmegaConf 18 | from pytorch_lightning import Trainer, seed_everything 19 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 20 | from pytorch_lightning.core.lightning import LightningModule 21 | from pytorch_lightning.loggers import TensorBoardLogger 22 | from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler 23 | from torchvision.transforms import ToTensor 24 | from volumetric_render import ( 25 | network_query_fn_validation, 26 | render_img, 27 | render_rays, 28 | ) 29 | 30 | 31 | os.environ["MKL_THREADING_LAYER"] = "GNU" 32 | 33 | _curr_path = osp.dirname(osp.abspath(__file__)) 34 | _base_path = _curr_path # osp.join(_curr_path, "..") 35 | 36 | 37 | class VolumetricNetwork(LightningModule): 38 | def __init__( 39 | self, 40 | cfg: DictConfig, 41 | ): 42 | super().__init__() 43 | self.cfg = cfg 44 | 45 | # make the e2e (encoder + decoder) model. 
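# matplotlib.pyplot is used for the diagnostic figures a few lines below but is never
# imported in this file (nor listed in env.yaml); importing it here keeps __init__ from
# raising a NameError.
import matplotlib.pyplot as plt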
46 | self.model = get_model(cfg.model) 47 | 48 | # Save hyperparameters 49 | self.save_hyperparameters(cfg) 50 | 51 | cam_num = self.cfg.render.cam_num if self.cfg.render.cam_num > 0 else 1 52 | self.ray_num = cam_num * self.cfg.render.ray_num_per_cam 53 | 54 | # Matplotlib figures 55 | self.iou_fig, self.iou_ax = plt.subplots() 56 | self.IOU_THRESHOLDS = np.linspace(0, 100, 200).tolist() 57 | 58 | self.voxels_fig = plt.figure() 59 | self.voxels_ax = self.voxels_fig.add_subplot(111, projection="3d") 60 | 61 | def configure_optimizers(self): 62 | optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg.optim.lr) 63 | if self.cfg.optim.use_scheduler: 64 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 65 | optimizer, 66 | factor=0.5, 67 | patience=75, 68 | threshold_mode="abs", 69 | threshold=0.005, 70 | ) 71 | scheduler = {"scheduler": scheduler, "monitor": "train_loss"} 72 | return {"optimizer": optimizer, "lr_scheduler": scheduler} 73 | else: 74 | return optimizer 75 | 76 | def extract_batch_input(self, batch): 77 | return extract_data_train(batch, self.cfg.render, self.device) 78 | 79 | def validation_step(self, batch, batch_idx): 80 | """ 81 | This is the method that gets distributed 82 | """ 83 | with torch.no_grad(): 84 | 85 | ( 86 | inp_imgs, 87 | mask_ray_labels, 88 | rgb_label_rays, 89 | normal_ray_labels, 90 | depth_ray_labels, 91 | rays, 92 | ) = self.extract_batch_input(batch) 93 | 94 | c_latent = self.model.encoder(inp_imgs) 95 | c_latent = ( 96 | c_latent.unsqueeze(1) 97 | .repeat(1, self.ray_num, 1) 98 | .view(-1, c_latent.shape[1]) 99 | ) 100 | 101 | ray_outs = render_rays( 102 | ray_batch=rays, 103 | c_latent=c_latent, 104 | decoder=self.model.decoder, 105 | N_samples=self.cfg.render.on_ray_num_samples, 106 | has_rgb=self.cfg.render.rgb, 107 | ) 108 | mask_ray_outs = ray_outs["acc_map"].to(self.device) 109 | 110 | loss = torch.nn.functional.mse_loss( 111 | mask_ray_labels, 112 | mask_ray_outs, # reduction="none" 113 | ) 114 | 115 | # Vol Render output image for logging 116 | render_kwargs = { 117 | "network_query_fn": network_query_fn_validation, 118 | "N_samples": 100, 119 | "decoder": self.model.decoder, 120 | "c_latent": c_latent[-1].reshape(1, -1), 121 | "chunk": 1000, 122 | "device": self.device, 123 | "has_rgb": self.cfg.render.rgb, 124 | "has_normal": False, 125 | } 126 | 127 | poses = [ 128 | (0, 0, 1.5), 129 | (90, 90, 1.5), 130 | (0, 90, 1.5), 131 | (90, 0, 1.5), 132 | ] 133 | occ_imgs = [] 134 | rgb_imgs = [] 135 | for p in poses: 136 | elev_angle, azim_angle, dist = p 137 | _, occ_img, rgb_img, _ = render_img( 138 | dist=torch.tensor([dist]).type_as(batch["dist"][-1]), 139 | elev_angle=torch.tensor([elev_angle]).type_as(batch["dist"][-1]), 140 | azim_angle=torch.tensor([azim_angle]).type_as(batch["dist"][-1]), 141 | img_size=self.cfg.render.img_size, 142 | focal=self.cfg.render.focal_length, 143 | render_kwargs=render_kwargs, 144 | ) 145 | occ_imgs.append(occ_img) 146 | rgb_imgs.append(rgb_img) 147 | 148 | occ_img = torch.cat(occ_imgs, dim=0) 149 | 150 | rgb_img = (torch.cat(rgb_imgs, dim=0)).permute(2, 0, 1) 151 | 152 | return { 153 | "loss": loss, 154 | "inp_img": inp_imgs[-1], 155 | "rgb_gt": None, 156 | "normal_gt": None, 157 | "vol_render": occ_img.unsqueeze(0), 158 | "vol_render_rgb": rgb_img, 159 | "vol_render_normal": None, 160 | } 161 | 162 | def validation_epoch_end(self, validation_epoch_outputs): 163 | 164 | avg_loss = torch.cat( 165 | [l["loss"].unsqueeze(0) for l in validation_epoch_outputs] 166 | ).mean() 167 | 168 | # Input Image 169 
| inp_img = torch.cat([l["inp_img"] for l in validation_epoch_outputs], -1) 170 | self.logger.experiment.add_image("val_inp_rgb", inp_img, self.global_step) 171 | 172 | # Mask Rendering 173 | vol_render = torch.cat([l["vol_render"] for l in validation_epoch_outputs], -1) 174 | self.logger.experiment.add_image("val_vol_render", vol_render, self.global_step) 175 | 176 | # RGB Rendering 177 | vol_render_rgb = torch.cat( 178 | [l["vol_render_rgb"] for l in validation_epoch_outputs], -1 179 | ) 180 | self.logger.experiment.add_image( 181 | "vol_render_rgb", vol_render_rgb, self.global_step 182 | ) 183 | 184 | self.log("val_loss", avg_loss, on_step=False, on_epoch=True, prog_bar=True) 185 | return {"val_loss": avg_loss, "progress_bar": {"global_step": self.global_step}} 186 | 187 | def training_step(self, batch, batch_idx): 188 | """ 189 | This is the method that gets distributed 190 | """ 191 | # [N,3, img_size, img_size], 1-d [N*num_rays], [N*num_rays, 8] 192 | ( 193 | inp_imgs, 194 | mask_ray_labels, 195 | rgb_ray_labels, 196 | normal_ray_labels, 197 | depth_ray_labels, 198 | rays, 199 | ) = self.extract_batch_input(batch) 200 | 201 | c_latent = self.model.encoder(inp_imgs) # [N, c_dim] 202 | # For instance, C = [[1,2],[3,4]] and num_rays = 2 203 | # below lines return C = [[1,2],[1,2],[3,4],[3,4]] 204 | c_latent = ( 205 | c_latent.unsqueeze(1).repeat(1, self.ray_num, 1).view(-1, c_latent.shape[1]) 206 | ) # [N*N_num_rays, c_dim] 207 | 208 | ray_outs = render_rays( 209 | ray_batch=rays, # [N*num_rays, 8] 210 | c_latent=c_latent, # [N*N_num_rays, c_dim] 211 | decoder=self.model.decoder, # nn.Module 212 | N_samples=self.cfg.render.on_ray_num_samples, # int 213 | has_rgb=self.cfg.render.rgb, 214 | has_normal=self.cfg.render.normals, 215 | ) 216 | mask_ray_outs = ray_outs["acc_map"].to(self.device) # [N*num_rays] 1-d 217 | 218 | loss = [] 219 | loss_mask = torch.nn.functional.mse_loss( 220 | mask_ray_labels, 221 | mask_ray_outs, # reduction="none" 222 | ) # [N*num_rays] 1-d 223 | loss.append(loss_mask) 224 | 225 | if self.cfg.render.rgb: 226 | rgb_ray_outs = ray_outs["rgb_map"].to(self.device) 227 | loss_rgb = torch.nn.functional.mse_loss( 228 | rgb_ray_labels, 229 | rgb_ray_outs, 230 | ) 231 | loss.append(loss_rgb) 232 | 233 | if self.cfg.render.normals: 234 | normal_ray_outs = ray_outs["normal_map"].to(self.device) 235 | normal_ray_outs = normal_ray_outs.reshape(-1, 3) 236 | normal_ray_labels = normal_ray_labels.reshape(-1, 3) 237 | loss_normal = normal_ray_outs * normal_ray_labels 238 | loss_normal = torch.abs(loss_normal.sum(-1)) 239 | 240 | loss_normal = 1.0 - loss_normal 241 | loss_normal = loss_normal.mean() 242 | loss.append(loss_normal) 243 | 244 | # loss = torch.tensor(loss, requires_grad=True) 245 | if self.cfg.render.rgb: 246 | loss = (loss_mask + loss_rgb) / 2.0 247 | else: 248 | loss = loss_mask 249 | self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True) 250 | return {"loss": loss, "progress_bar": {"global_step": self.global_step}} 251 | 252 | def forward(self, mode, inputs): 253 | pass 254 | 255 | 256 | def train_model(cfg): 257 | print(OmegaConf.to_yaml(cfg)) 258 | 259 | # configure data loader 260 | data_module = WareHouse3DModule( 261 | data_cfg=cfg.data, 262 | render_cfg=cfg.render, 263 | num_workers=cfg.resources.num_workers, 264 | ) 265 | 266 | model = VolumetricNetwork(cfg=cfg) 267 | log_dir = osp.join(_base_path, cfg.logging.log_dir, cfg.logging.name) 268 | os.makedirs(log_dir, exist_ok=True) 269 | OmegaConf.save(cfg, osp.join(log_dir, "config.txt")) 270 
| 271 | logger = TensorBoardLogger( 272 | osp.join(_base_path, cfg.logging.log_dir), name=cfg.logging.name 273 | ) 274 | checkpoint_callback = ModelCheckpoint( 275 | save_top_k=-1, 276 | every_n_val_epochs=cfg.optim.save_freq, 277 | filename="checkpoint_{epoch}", 278 | ) 279 | 280 | if cfg.optim.use_pretrain: 281 | temp_model = VolumetricNetwork.load_from_checkpoint(cfg.optim.checkpoint_path) 282 | model.model.load_state_dict(temp_model.model.state_dict()) 283 | else: 284 | checkpoint = None 285 | 286 | lr_monitor = LearningRateMonitor(logging_interval="step") 287 | 288 | # https://pytorch-lightning.readthedocs.io/en/latest/trainer.html 289 | trainer = Trainer( 290 | logger=logger, 291 | gpus=cfg.resources.gpus, 292 | num_nodes=cfg.resources.num_nodes, 293 | val_check_interval=cfg.optim.val_check_interval, 294 | limit_val_batches=cfg.optim.num_val_iter, 295 | callbacks=[checkpoint_callback, lr_monitor], 296 | resume_from_checkpoint=None, # Only loading weights 297 | max_epochs=cfg.optim.max_epochs, 298 | accelerator=cfg.resources.accelerator 299 | if cfg.resources.accelerator != "none" 300 | else None, 301 | deterministic=False, 302 | ) 303 | trainer.fit(model, data_module) 304 | 305 | 306 | @hydra.main(config_name="config") 307 | def main(cfg: config.Config) -> None: 308 | # Seed everything - random, numpy, torch manual seed, cuda 309 | seed_everything(12) 310 | 311 | # If not running on a cluster, launch the job locally 312 | if not cfg.resources.use_cluster: 313 | train_model(cfg) 314 | else: 315 | print(OmegaConf.to_yaml(cfg)) # TODO: Add this to tensorboard logging 316 | 317 | # dummy logger just to get the version number 318 | dummy_logger = TensorBoardLogger( 319 | osp.join(_base_path, cfg.logging.log_dir), name=cfg.logging.name 320 | ) 321 | 322 | submitit_dir = osp.join( 323 | _base_path, 324 | cfg.logging.log_dir, 325 | cfg.logging.name, 326 | "submitit_" + str(dummy_logger.version), 327 | ) 328 | executor = submitit.AutoExecutor(folder=submitit_dir) 329 | 330 | job_kwargs = { 331 | "timeout_min": cfg.resources.time, 332 | "name": cfg.logging.name, 333 | "slurm_partition": cfg.resources.partition, 334 | "gpus_per_node": cfg.resources.gpus, 335 | "tasks_per_node": cfg.resources.gpus, # one task per GPU 336 | "cpus_per_task": 5, 337 | "nodes": cfg.resources.num_nodes, 338 | } 339 | if cfg.resources.max_mem: 340 | job_kwargs["slurm_constraint"] = "volta32gb" 341 | if cfg.resources.partition == "priority": 342 | job_kwargs["slurm_comment"] = cfg.resources.comment 343 | 344 | executor.update_parameters(**job_kwargs) 345 | job = executor.submit(train_model, cfg) 346 | print("Submitit Job ID:", job.job_id) # ID of your job 347 | 348 | 349 | if __name__ == "__main__": 350 | main() 351 | -------------------------------------------------------------------------------- /unified_distillation.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import os 4 | import os.path as osp 5 | import pdb 6 | import time 7 | 8 | import hydra 9 | import numpy as np 10 | import submitit 11 | import torch 12 | import torchvision.transforms as T 13 | import torchvision.transforms.functional as T_F 14 | from data.distillation import DistillationDataModule, get_paths 15 | from hydra_config import config 16 | from model import get_model 17 | from omegaconf import DictConfig, OmegaConf 18 | from pytorch_lightning import Trainer, seed_everything 19 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 20 |
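# --- Editor's note (illustrative sketch, not part of the original script) ---
# The cluster branch of main() above uses submitit's AutoExecutor. A minimal
# standalone version of the same pattern, assuming only the submitit package
# and a picklable train_fn/cfg pair:
#
#   import submitit
#   executor = submitit.AutoExecutor(folder="job_logs/example")
#   executor.update_parameters(timeout_min=60, gpus_per_node=1,
#                              tasks_per_node=1, cpus_per_task=5, nodes=1)
#   job = executor.submit(train_fn, cfg)  # returns as soon as the job is queued
#   print(job.job_id)
# ----------------------------------------------------------------------------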
from pytorch_lightning.core.lightning import LightningModule 21 | from pytorch_lightning.loggers import TensorBoardLogger 22 | from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler 23 | from torchvision.transforms import ToTensor 24 | from volumetric_render import ( 25 | network_query_fn_validation, 26 | network_query_fn_train, 27 | render_img, 28 | render_rays, 29 | get_rays_from_angles, 30 | ) 31 | 32 | 33 | os.environ["MKL_THREADING_LAYER"] = "GNU" 34 | 35 | _curr_path = osp.dirname(osp.abspath(__file__)) 36 | _base_path = _curr_path # osp.join(_curr_path, "..") 37 | 38 | 39 | class VolumetricNetworkCkpt(LightningModule): 40 | def __init__( 41 | self, 42 | cfg: DictConfig, 43 | ): 44 | super().__init__() 45 | self.cfg = cfg 46 | 47 | # make the e2e (encoder + decoder) model. 48 | self.model = get_model(cfg.model) 49 | 50 | # Save hyperparameters 51 | self.save_hyperparameters(cfg) 52 | 53 | cam_num = self.cfg.render.cam_num if self.cfg.render.cam_num > 0 else 1 54 | self.ray_num = cam_num * self.cfg.render.ray_num_per_cam 55 | 56 | 57 | def create_encoder_transforms(): 58 | 59 | transforms = T.Compose( 60 | [ 61 | # T.Pad(padding=), 62 | T.RandomApply( 63 | torch.nn.ModuleList( 64 | [T.ColorJitter(brightness=0.15, hue=0.15, saturation=0.15)] 65 | ), 66 | p=0.5, 67 | ), 68 | T.RandomApply( 69 | torch.nn.ModuleList([T.GaussianBlur(kernel_size=(3, 3))]), 70 | p=0.5, 71 | ), 72 | T.RandomRotation(degrees=(-45, 45)), 73 | # T.RandomAdjustSharpness(sharpness_factor=2), 74 | T.RandomHorizontalFlip(p=0.5), 75 | T.Resize((224, 224)), 76 | # T.RandomVerticalFlip(p=0.5), 77 | ] 78 | ) 79 | return transforms 80 | 81 | 82 | class VolumetricNetwork(LightningModule): 83 | def __init__( 84 | self, 85 | cfg: DictConfig, 86 | ): 87 | super().__init__() 88 | self.cfg = cfg 89 | 90 | # make the e2e (encoder + decoder) model. 
91 | self.model = get_model(cfg.model) 92 | 93 | # Save hyperparameters 94 | self.save_hyperparameters(cfg) 95 | 96 | cam_num = self.cfg.render.cam_num if self.cfg.render.cam_num > 0 else 1 97 | self.ray_num = cam_num * self.cfg.render.ray_num_per_cam 98 | 99 | self.load_ckpts_on_cpu() 100 | 101 | self.encoder_transforms = create_encoder_transforms() 102 | self.use_encoder_transforms = self.cfg.distillation.use_encoder_transforms 103 | 104 | def _process_encoder_images(self, inp_images): 105 | if self.use_encoder_transforms: 106 | out_enc_imgs = [] 107 | for img in inp_images: 108 | temp_img = T_F.pad(img, torch.randint(0, 15, (1,)).item()) 109 | temp_img = self.encoder_transforms(temp_img) 110 | out_enc_imgs.append(temp_img) 111 | inp_images = torch.stack(out_enc_imgs, dim=0) 112 | return inp_images 113 | 114 | def load_ckpts_on_cpu(self): 115 | 116 | self.distillation_ckpts = [] 117 | 118 | assert self.cfg.distillation.ckpts_root_dir is not None 119 | 120 | checkpoint_paths = get_paths( 121 | self.cfg.distillation.ckpts_root_dir, 122 | self.cfg.distillation.regex_match, 123 | self.cfg.distillation.regex_exclude, 124 | ) 125 | checkpoint_paths = [ 126 | self.cfg.distillation.warehouse3d_ckpt_path 127 | ] + checkpoint_paths 128 | print("!!!!!!!!!!!!!Loading checkpoints from root dir!!!!!!!!") 129 | print(checkpoint_paths, sep="\n") 130 | ckpt_idx = 0 131 | for checkpoint_path in checkpoint_paths: 132 | temp_model = VolumetricNetworkCkpt.load_from_checkpoint(checkpoint_path) 133 | temp_model = temp_model.to("cpu") 134 | self.distillation_ckpts.append(temp_model) 135 | out_str = f"loading checkpoint {ckpt_idx} onto cpu for distillation" 136 | tot_m, used_m, free_m = map( 137 | int, os.popen("free -t -m").readlines()[-1].split()[1:] 138 | ) 139 | out_str += f" Memory Stats: Total {tot_m}, Used {used_m}, free {free_m}" 140 | ckpt_idx += 1 141 | print(out_str) 142 | 143 | def configure_optimizers(self): 144 | optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg.optim.lr) 145 | if self.cfg.optim.use_scheduler: 146 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 147 | optimizer, 148 | factor=0.5, 149 | patience=75, 150 | threshold_mode="abs", 151 | threshold=0.005, 152 | ) 153 | scheduler = {"scheduler": scheduler, "monitor": "train_loss"} 154 | return {"optimizer": optimizer, "lr_scheduler": scheduler} 155 | else: 156 | return optimizer 157 | 158 | def validation_step(self, batch, batch_idx): 159 | """ 160 | This is the method that gets distributed 161 | """ 162 | with torch.no_grad(): 163 | 164 | if self.use_encoder_transforms: 165 | enc_inp_images = self._process_encoder_images(batch["rgb_img"]) 166 | else: 167 | enc_inp_images = batch["rgb_img"] 168 | 169 | if self.cfg.distillation.mode == "point": 170 | loss = self.compute_point_loss(batch) 171 | else: 172 | loss = self.compute_ray_loss(batch) 173 | 174 | # Vol Render output image for logging 175 | def render_ref_pose_images(model, inp_img): 176 | c_latent = model.encoder(inp_img.unsqueeze(0)) 177 | render_kwargs = { 178 | "network_query_fn": network_query_fn_validation, 179 | "N_samples": 100, 180 | "decoder": model.decoder, 181 | "c_latent": c_latent, 182 | "chunk": 1000, 183 | "device": self.device, 184 | "has_rgb": self.cfg.render.rgb, 185 | "has_normal": False, 186 | } 187 | 188 | poses = [ 189 | (0, 0, 1.5), 190 | (90, 90, 1.5), 191 | (0, 90, 1.5), 192 | (90, 0, 1.5), 193 | ] 194 | occ_imgs = [] 195 | rgb_imgs = [] 196 | for p in poses: 197 | elev_angle, azim_angle, dist = p 198 | _, occ_img, rgb_img, _ = render_img( 
199 | dist=torch.tensor([dist]).type_as(batch["dist"][-1]), 200 | elev_angle=torch.tensor([elev_angle]).type_as( 201 | batch["dist"][-1] 202 | ), 203 | azim_angle=torch.tensor([azim_angle]).type_as( 204 | batch["dist"][-1] 205 | ), 206 | img_size=self.cfg.render.img_size, 207 | focal=self.cfg.render.focal_length, 208 | render_kwargs=render_kwargs, 209 | ) 210 | occ_imgs.append(occ_img) 211 | rgb_imgs.append(rgb_img) 212 | 213 | occ_img = torch.cat(occ_imgs, dim=0) 214 | if self.cfg.render.rgb: 215 | rgb_img = (torch.cat(rgb_imgs, dim=0)).permute(2, 0, 1) 216 | else: 217 | rgb_img = None 218 | 219 | return rgb_img, occ_img.unsqueeze(0) 220 | 221 | rgb_img, occ_img = render_ref_pose_images(self.model, batch["rgb_img"][-1]) 222 | 223 | model_id = batch["dataset_id"][-1] 224 | temp_model = self.distillation_ckpts[model_id.item()] 225 | 226 | temp_model = temp_model.to(self.device) 227 | 228 | rgb_img_label, occ_img_label = render_ref_pose_images( 229 | temp_model.model, batch["rgb_img"][-1] 230 | ) 231 | self.distillation_ckpts[model_id.item()] = temp_model.to("cpu") 232 | 233 | return { 234 | "loss": loss, 235 | "inp_img": batch["rgb_img"][-1], 236 | "mask_teacher": occ_img_label, 237 | "rgb_teacher": rgb_img_label, 238 | "vol_render": occ_img, 239 | "vol_render_rgb": rgb_img, 240 | "transformed_encoder_inp_img": enc_inp_images[-1], 241 | } 242 | 243 | def validation_epoch_end(self, validation_epoch_outputs): 244 | 245 | avg_loss = torch.cat( 246 | [l["loss"].unsqueeze(0) for l in validation_epoch_outputs] 247 | ).mean() 248 | 249 | inp_img = torch.cat([l["inp_img"] for l in validation_epoch_outputs], -1) 250 | self.logger.experiment.add_image("val_inp_rgb", inp_img, self.global_step) 251 | 252 | inp_img = torch.cat( 253 | [l["transformed_encoder_inp_img"] for l in validation_epoch_outputs], -1 254 | ) 255 | self.logger.experiment.add_image( 256 | "transformed_encoder_inp_img", inp_img, self.global_step 257 | ) 258 | 259 | vol_render = torch.cat([l["vol_render"] for l in validation_epoch_outputs], -1) 260 | self.logger.experiment.add_image("val_vol_render", vol_render, self.global_step) 261 | 262 | mask_teacher = torch.cat( 263 | [l["mask_teacher"] for l in validation_epoch_outputs], -1 264 | ) 265 | self.logger.experiment.add_image("mask_teacher", mask_teacher, self.global_step) 266 | 267 | if self.cfg.render.rgb: 268 | vol_render_rgb = torch.cat( 269 | [l["vol_render_rgb"] for l in validation_epoch_outputs], -1 270 | ) 271 | self.logger.experiment.add_image( 272 | "vol_render_rgb", vol_render_rgb, self.global_step 273 | ) 274 | 275 | rgb_teacher = torch.cat( 276 | [l["rgb_teacher"] for l in validation_epoch_outputs], -1 277 | ) 278 | self.logger.experiment.add_image( 279 | "rgb_teacher", rgb_teacher, self.global_step 280 | ) 281 | 282 | self.log("val_loss", avg_loss, on_step=False, on_epoch=True, prog_bar=True) 283 | return {"val_loss": avg_loss, "progress_bar": {"global_step": self.global_step}} 284 | 285 | def _extract_point_batch(self, batch): 286 | 287 | inp_imgs = batch["rgb_img"] 288 | model_ids = batch["dataset_id"] 289 | 290 | batch_size = len(model_ids) 291 | num_points = self.cfg.distillation.num_points # TODO: Tune This! 
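# --- Editor's note (illustrative sketch, not part of the original script) ---
# validation_step above and the batch extractors below query one teacher
# checkpoint at a time, so only the active teacher occupies GPU memory. The
# pattern, in outline (assuming the modules stored in self.distillation_ckpts):
#
#   teacher = self.distillation_ckpts[dataset_id].to(self.device)   # CPU -> GPU
#   with torch.no_grad():
#       c = teacher.model.encoder(img.unsqueeze(0))                 # teacher latent
#   self.distillation_ckpts[dataset_id] = teacher.to("cpu")         # GPU -> CPU
# ----------------------------------------------------------------------------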
292 | sample_bounds_min = self.cfg.distillation.sample_bounds[0] 293 | sample_bounds_max = self.cfg.distillation.sample_bounds[1] 294 | 295 | with torch.no_grad(): 296 | query_points = torch.rand(batch_size, num_points, 3) 297 | query_points = ( 298 | sample_bounds_min 299 | + (sample_bounds_max - sample_bounds_min) * query_points 300 | ) 301 | query_points = query_points.to(self.device) 302 | # outputs are r,g,b,density 303 | # TODO: Still not clear how to handle density!! 304 | outputs = [] 305 | 306 | for i, model_id in enumerate(model_ids): 307 | 308 | temp_model = self.distillation_ckpts[model_id.item()] 309 | temp_model = temp_model.to(self.device) 310 | temp_c_latent = temp_model.model.encoder(inp_imgs[i].unsqueeze(0)) 311 | 312 | outputs.append( 313 | network_query_fn_validation( 314 | query_points[i].unsqueeze(0), 315 | temp_c_latent, 316 | temp_model.model.decoder, 317 | ) 318 | ) 319 | self.distillation_ckpts[model_id.item()] = temp_model.to("cpu") 320 | 321 | outputs = torch.cat(outputs, dim=0) 322 | return inp_imgs, query_points, outputs 323 | 324 | def compute_point_loss(self, batch): 325 | 326 | inp_imgs, query_points, labels_raw = self._extract_point_batch(batch) 327 | 328 | if self.use_encoder_transforms: 329 | inp_imgs = self._process_encoder_images(inp_imgs) 330 | 331 | self.model.to(self.device) 332 | c_latent = self.model.encoder(inp_imgs) 333 | preds_raw = network_query_fn_train(query_points, c_latent, self.model.decoder) 334 | 335 | colors = preds_raw[..., :3] 336 | colors_labels = labels_raw[..., :3] 337 | loss_rgb = torch.nn.functional.mse_loss(colors, colors_labels) 338 | 339 | density = torch.nn.functional.relu(preds_raw[..., 3]) 340 | density_labels = torch.nn.functional.relu(labels_raw[..., 3]) 341 | loss_density = torch.nn.functional.mse_loss(density, density_labels) 342 | 343 | loss = (loss_density + loss_rgb) / 2.0 344 | 345 | return loss 346 | 347 | def _extract_ray_batch(self, batch): 348 | 349 | inp_imgs = batch["rgb_img"] 350 | model_ids = batch["dataset_id"] 351 | batch_size = len(model_ids) 352 | 353 | with torch.no_grad(): 354 | rays = get_rays_from_angles( 355 | H=self.cfg.render.img_size, 356 | W=self.cfg.render.img_size, 357 | focal=float(self.cfg.render.focal_length), 358 | near_plane=self.cfg.render.near_plane, 359 | far_plane=self.cfg.render.far_plane, 360 | elev_angles=batch["elev_angle"], 361 | azim_angles=batch["azim_angle"], 362 | dists=batch["dist"], 363 | device=self.device, 364 | indices=batch["flat_indices"], 365 | ) # [(N*Num_rays), 8] #2d 366 | 367 | mask_ray_labels = [] 368 | rgb_ray_labels = [] 369 | 370 | rays = rays.reshape(batch_size, self.ray_num, rays.shape[-1]) 371 | 372 | for i, model_id in enumerate(model_ids): 373 | 374 | temp_model = self.distillation_ckpts[model_id.item()] 375 | temp_model = temp_model.to(self.device) 376 | temp_c_latent = temp_model.model.encoder(inp_imgs[i].unsqueeze(0)) 377 | temp_c_latent = temp_c_latent.repeat(self.ray_num, 1) 378 | 379 | ray_outs = render_rays( 380 | ray_batch=rays[i], 381 | c_latent=temp_c_latent, 382 | decoder=temp_model.model.decoder, 383 | N_samples=self.cfg.render.on_ray_num_samples, 384 | has_rgb=self.cfg.render.rgb, 385 | has_normal=self.cfg.render.normals, 386 | ) 387 | mask_ray_labels.append(ray_outs["acc_map"]) 388 | rgb_ray_labels.append(ray_outs["rgb_map"]) 389 | 390 | self.distillation_ckpts[model_id.item()] = temp_model.to("cpu") 391 | 392 | rays = rays.reshape(-1, rays.shape[-1]) 393 | mask_ray_labels = torch.cat(mask_ray_labels, dim=0) 394 | rgb_ray_labels = 
torch.cat(rgb_ray_labels, dim=0) 395 | return inp_imgs, mask_ray_labels, rgb_ray_labels, rays 396 | 397 | def compute_ray_loss(self, batch): 398 | 399 | # [N,3, img_size, img_size], 1-d [N*num_rays], [N*num_rays, 8] 400 | ( 401 | inp_imgs, 402 | mask_ray_labels, 403 | rgb_ray_labels, 404 | rays, 405 | ) = self._extract_ray_batch(batch) 406 | 407 | if self.use_encoder_transforms: 408 | inp_imgs = self._process_encoder_images(inp_imgs) 409 | 410 | self.model.to(self.device) 411 | c_latent = self.model.encoder(inp_imgs) # [N, c_dim] 412 | # For instance, C = [[1,2],[3,4]] and num_rays = 2 413 | # below lines return C = [[1,2],[1,2],[3,4],[3,4]] 414 | c_latent = ( 415 | c_latent.unsqueeze(1).repeat(1, self.ray_num, 1).view(-1, c_latent.shape[1]) 416 | ) # [N*N_num_rays, c_dim] 417 | 418 | ray_outs = render_rays( 419 | ray_batch=rays, # [N*num_rays, 8] 420 | c_latent=c_latent, # [N*N_num_rays, c_dim] 421 | decoder=self.model.decoder, # nn.Module 422 | N_samples=self.cfg.render.on_ray_num_samples, # int 423 | has_rgb=self.cfg.render.rgb, 424 | has_normal=self.cfg.render.normals, 425 | ) 426 | mask_ray_outs = ray_outs["acc_map"].to(self.device) # [N*num_rays] 1-d 427 | 428 | loss = [] 429 | loss_mask = torch.nn.functional.mse_loss( 430 | mask_ray_labels, 431 | mask_ray_outs, # reduction="none" 432 | ) # [N*num_rays] 1-d 433 | loss.append(loss_mask) 434 | 435 | if self.cfg.render.rgb: 436 | rgb_ray_outs = ray_outs["rgb_map"].to(self.device) 437 | loss_rgb = torch.nn.functional.mse_loss( 438 | rgb_ray_labels, 439 | rgb_ray_outs, 440 | ) 441 | loss.append(loss_rgb) 442 | 443 | # loss = torch.tensor(loss, requires_grad=True) 444 | if self.cfg.render.rgb: 445 | loss = (loss_mask + loss_rgb) / 2.0 446 | else: 447 | loss = loss_mask 448 | 449 | return loss 450 | 451 | def training_step(self, batch, batch_idx): 452 | """ 453 | This is the method that gets distributed 454 | """ 455 | if self.cfg.distillation.mode == "point": 456 | loss = self.compute_point_loss(batch) 457 | else: 458 | loss = self.compute_ray_loss(batch) 459 | self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True) 460 | return {"loss": loss, "progress_bar": {"global_step": self.global_step}} 461 | 462 | def forward(self, mode, inputs): 463 | pass 464 | 465 | 466 | def train_model(cfg): 467 | print(OmegaConf.to_yaml(cfg)) # TODO: Add this to tensorboard logging 468 | 469 | # configure data loader 470 | distillation_data_loader = DistillationDataModule( 471 | cfg=cfg, 472 | base_path=_base_path, 473 | ) 474 | 475 | model = VolumetricNetwork(cfg=cfg) 476 | log_dir = osp.join(_base_path, cfg.logging.log_dir, cfg.logging.name) 477 | os.makedirs(log_dir, exist_ok=True) 478 | OmegaConf.save(cfg, osp.join(log_dir, "config.txt")) 479 | 480 | logger = TensorBoardLogger( 481 | osp.join(_base_path, cfg.logging.log_dir), name=cfg.logging.name 482 | ) 483 | checkpoint_callback = ModelCheckpoint( 484 | save_top_k=-1, 485 | every_n_val_epochs=cfg.optim.save_freq, 486 | filename="checkpoint_{epoch}", 487 | ) 488 | 489 | if cfg.optim.use_pretrain: 490 | temp_model = VolumetricNetworkCkpt.load_from_checkpoint( 491 | cfg.optim.checkpoint_path 492 | ) 493 | model.model.load_state_dict(temp_model.model.state_dict()) 494 | else: 495 | checkpoint = None 496 | 497 | lr_monitor = LearningRateMonitor(logging_interval="step") 498 | 499 | # https://pytorch-lightning.readthedocs.io/en/latest/trainer.html 500 | trainer = Trainer( 501 | logger=logger, 502 | gpus=cfg.resources.gpus, 503 | num_nodes=cfg.resources.num_nodes, 504 | 
val_check_interval=cfg.optim.val_check_interval, 505 | limit_val_batches=cfg.optim.num_val_iter, 506 | callbacks=[checkpoint_callback, lr_monitor], 507 | # resume_from_checkpoint=checkpoint, 508 | resume_from_checkpoint=None, # Only loading weights 509 | max_epochs=cfg.optim.max_epochs, 510 | accelerator=cfg.resources.accelerator 511 | if cfg.resources.accelerator != "none" 512 | else None, 513 | deterministic=False, 514 | # profiler="simple", 515 | ) 516 | trainer.fit(model, distillation_data_loader) 517 | 518 | 519 | @hydra.main(config_name="config") 520 | def main(cfg: config.Config) -> None: 521 | # Seed everything - random, numpy, torch manual seed, cuda 522 | seed_everything(12) 523 | 524 | # If not running on a cluster, launch the job locally 525 | if not cfg.resources.use_cluster: 526 | train_model(cfg) 527 | else: 528 | print(OmegaConf.to_yaml(cfg)) # TODO: Add this to tensorboard logging 529 | 530 | # dummy logger just to get the version number 531 | dummy_logger = TensorBoardLogger( 532 | osp.join(_base_path, cfg.logging.log_dir), name=cfg.logging.name 533 | ) 534 | 535 | submitit_dir = osp.join( 536 | _base_path, 537 | cfg.logging.log_dir, 538 | cfg.logging.name, 539 | "submitit_" + str(dummy_logger.version), 540 | ) 541 | executor = submitit.AutoExecutor(folder=submitit_dir) 542 | 543 | job_kwargs = { 544 | "mem_gb": 700, 545 | "timeout_min": cfg.resources.time, 546 | "name": cfg.logging.name, 547 | "slurm_partition": cfg.resources.partition, 548 | "gpus_per_node": cfg.resources.gpus, 549 | "tasks_per_node": cfg.resources.gpus, # one task per GPU 550 | "cpus_per_task": 5, 551 | "nodes": cfg.resources.num_nodes, 552 | } 553 | if cfg.resources.max_mem: 554 | job_kwargs["slurm_constraint"] = "volta32gb" 555 | if cfg.resources.partition == "priority": 556 | job_kwargs["slurm_comment"] = cfg.resources.comment 557 | 558 | executor.update_parameters(**job_kwargs) 559 | job = executor.submit(train_model, cfg) 560 | print("Submitit Job ID:", job.job_id) # ID of your job 561 | 562 | 563 | if __name__ == "__main__": 564 | main() 565 | -------------------------------------------------------------------------------- /volumetric_render.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
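# --- Editor's note (illustrative sketch, not part of the original file) ------
# raw2outputs() further below composites per-sample densities into per-ray
# weights with the usual NeRF rule w_i = alpha_i * prod_{j<i}(1 - alpha_j).
# The "exclusive" cumulative product referenced from the TensorFlow code can be
# written in PyTorch as, e.g.:
#
#   ones = torch.ones_like(alpha[..., :1])
#   weights = alpha * torch.cumprod(
#       torch.cat([ones, 1.0 - alpha + 1e-10], -1), -1
#   )[..., :-1]
# -----------------------------------------------------------------------------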
6 | 7 | # Some functions in this code are borrowed from https://github.com/yenchenlin/nerf-pytorch 8 | # Copyright (c) 2020 bmild, MIT License 9 | 10 | import json 11 | import math 12 | import os 13 | import random 14 | import sys 15 | import time 16 | 17 | import imageio 18 | import numpy as np 19 | import pytorch3d 20 | import pytorch3d.renderer 21 | import torch 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | from tqdm import tqdm, trange 25 | 26 | 27 | def network_query_fn_train(inputs, c, decoder, use_abs=False): 28 | # inputs shape : [N_rays, N_samples, 3] 29 | # c shape: [N_rays, c_size] 30 | if use_abs: 31 | return torch.abs(decoder(inputs, c=c)) 32 | 33 | return decoder(inputs, c=c) 34 | 35 | 36 | def network_query_fn_validation(inputs, c, decoder, use_relu=False): 37 | # inputs shape : [ray_chunk, N_samples, 3] 38 | # c shape: [1, c_size] 39 | assert c.shape[0] == 1, "C should be of shape 1*c_dim for val" 40 | 41 | c_inp = torch.cat(inputs.shape[0] * [c]) 42 | out = decoder(inputs, c=c_inp.type_as(inputs)) 43 | 44 | flags_x = (inputs[..., 0] > 0.5) | (inputs[..., 0] < -0.5) 45 | flags_y = (inputs[..., 1] > 0.5) | (inputs[..., 1] < -0.5) 46 | flags_z = (inputs[..., 2] > 0.5) | (inputs[..., 2] < -0.5) 47 | 48 | flags = flags_x | flags_y | flags_z # out of bounds along any axis 49 | flags = torch.stack([flags, flags, flags, flags], dim=-1) 50 | 51 | outs = decoder(inputs, c=c) 52 | outs[flags] = 0.0 53 | return outs 54 | 55 | 56 | def get_transformation(dist, elev, azim): 57 | # returns camera to world 58 | R, T = pytorch3d.renderer.look_at_view_transform(dist, elev, azim) 59 | T_rays = pytorch3d.renderer.camera_position_from_spherical_angles(dist, elev, azim) 60 | c2w = torch.cat((R, T_rays.reshape(1, 3, 1)), -1)[0] 61 | return torch.cat([c2w, torch.tensor([[0.0, 0.0, 0.0, 1.0]])], dim=0) 62 | 63 | 64 | def get_rays_from_angles( 65 | *, 66 | H, 67 | W, 68 | focal, 69 | near_plane, 70 | far_plane, 71 | elev_angles, 72 | azim_angles, 73 | dists, 74 | device, 75 | indices, 76 | transformation_rel=None, # 4x4 homogeneous matrix encoder to world 77 | ): 78 | rays = [] 79 | for i in range(elev_angles.shape[0]): 80 | if transformation_rel is not None: 81 | transformation_view_to_world = get_transformation( 82 | dists[i], elev_angles[i], azim_angles[i] 83 | ) 84 | # view to encoder 85 | c2w = torch.matmul( 86 | transformation_rel.inverse(), transformation_view_to_world 87 | ) 88 | c2w = c2w[:3, :] 89 | else: 90 | R, T = pytorch3d.renderer.look_at_view_transform( 91 | dists[i], elev_angles[i], azim_angles[i] 92 | ) 93 | T_rays = pytorch3d.renderer.camera_position_from_spherical_angles( 94 | dists[i], elev_angles[i], azim_angles[i] 95 | ) 96 | c2w = torch.cat((R, T_rays.reshape(1, 3, 1)), -1)[0] 97 | rays_o, rays_d = get_rays( 98 | H, W, focal, c2w 99 | ) # (rays_o = [H, W,3], rays_d = [H, W,3]) 100 | rays_o = rays_o.reshape(-1, 3) # [H*W, 3] 101 | rays_d = rays_d.reshape(-1, 3) # [H*W, 3] 102 | rays_o = rays_o[indices[i]].to(device) # [num_rays, 3] 103 | rays_d = rays_d[indices[i]].to(device) # [num_rays, 3] 104 | 105 | near, far = ( 106 | # near_plane * torch.ones_like(rays_d[..., :1]), 107 | # far_plane * torch.ones_like(rays_d[..., :1]), 108 | max(dists[i].to(device) - 0.90, 0.0) * torch.ones_like(rays_d[..., :1]), 109 | (dists[i].to(device) + 0.90) * torch.ones_like(rays_d[..., :1]), 110 | ) # [num_rays, 1], # [num_rays, 1] 111 | rays.append( 112 | torch.cat([rays_o, rays_d, near, far], -1).to(device) 113 | ) # List([num_rays, 8]) 114 | 115 | return torch.cat(rays, dim=0) #
[N*num_rays, 8] (N=Batch size) 116 | 117 | 118 | # Ray helpers 119 | def get_rays(H, W, focal, c2w): 120 | i, j = torch.meshgrid( 121 | torch.linspace(0, W - 1, W), torch.linspace(0, H - 1, H) 122 | ) # pytorch's meshgrid has indexing='ij' 123 | i = i.t() 124 | j = j.t() 125 | dirs = torch.stack( 126 | [-(i - W * 0.5) / focal, -(j - H * 0.5) / focal, torch.ones_like(i)], -1 127 | ) # https://pytorch3d.org/docs/renderer_getting_started 128 | # Rotate ray directions from camera frame to the world frame 129 | rays_d = torch.sum( 130 | dirs[..., np.newaxis, :] * c2w[:3, :3], -1 131 | ) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 132 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 133 | rays_o = c2w[:3, -1].expand(rays_d.shape) 134 | return rays_o, rays_d # (rays_o = [H, W,3], rays_d = [H, W,3]) 135 | 136 | 137 | def raw2outputs( 138 | raw, 139 | z_vals, 140 | rays_d, 141 | raw_noise_std=0, 142 | white_bkgd=False, 143 | has_rgb=True, 144 | render_depth=False, 145 | raw_normals=None, 146 | ): 147 | if white_bkgd: 148 | raw2alpha = lambda raw, dists, act_fn=F.relu: 1.0 - torch.exp( 149 | -act_fn(5 * raw) * dists 150 | ) 151 | else: 152 | raw2alpha = lambda raw, dists, act_fn=F.relu: 1.0 - torch.exp( 153 | -act_fn(raw) * dists 154 | ) 155 | 156 | dists = z_vals[..., 1:] - z_vals[..., :-1] 157 | dists = torch.cat( 158 | [dists, torch.Tensor([1e10]).expand(dists[..., :1].shape).type_as(z_vals)], -1 159 | ) # [N_rays, N_samples] 160 | dists = dists * torch.norm(rays_d[..., None, :], dim=-1) 161 | 162 | if has_rgb: 163 | rgb = torch.sigmoid(raw[..., :3]) # [N_rays, N_samples, 3] 164 | 165 | noise = 0.0 166 | if raw_noise_std > 0.0: 167 | noise = torch.randn(raw[..., 3].shape) * raw_noise_std 168 | 169 | if has_rgb: 170 | alpha = raw2alpha(raw[..., 3] + noise, dists) # [N_rays, N_samples] # Ex: Nert 171 | else: 172 | alpha = raw2alpha(raw + noise, dists) # [N_rays, N_samples] # Ex: Onet 173 | 174 | # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True) 175 | weights = ( 176 | alpha # n(sample) W_n = alpha_n * (prod_{j