├── LICENSE
├── README.md
├── configs
│   ├── config_dtu.txt
│   ├── config_general.txt
│   ├── config_llff.txt
│   ├── config_nerf.txt
│   └── lists
│       ├── dtu_pairs.txt
│       ├── dtu_pairs_ft.txt
│       ├── dtu_pairs_val.txt
│       ├── dtu_train_all.txt
│       └── dtu_val_all.txt
├── data
│   ├── __init__.py
│   ├── dtu.py
│   ├── get_datasets.py
│   ├── llff.py
│   └── nerf.py
├── model
│   ├── __init__.py
│   ├── geo_reasoner.py
│   └── self_attn_renderer.py
├── pretrained_weights
│   └── .gitignore
├── requirements.txt
├── run_geo_nerf.py
└── utils
    ├── __init__.py
    ├── options.py
    ├── rendering.py
    └── utils.py
/LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software.
The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 
117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 
174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 
234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 
296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 
414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. 
The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. 
You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 
583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | 635 | Copyright (C) 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 649 | 650 | Also add information on how to contact you by electronic and paper mail. 
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | Copyright (C) 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | . 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | . 675 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | > # [CVPR 2022] GeoNeRF: Generalizing NeRF with Geometry Priors
2 | > Mohammad Mahdi Johari, Yann Lepoittevin, François Fleuret
3 | > [Project Page](https://www.idiap.ch/paper/geonerf/) | [Paper](https://arxiv.org/abs/2111.13539) 4 | 5 | This repository contains a PyTorch Lightning implementation of our paper, GeoNeRF: Generalizing NeRF with Geometry Priors. 6 | 7 | ## Installation 8 | 9 | #### Tested on NVIDIA Tesla V100 and GeForce RTX 3090 GPUs with PyTorch 1.9 and PyTorch Lightning 1.3.7 10 | 11 | To install the dependencies, in addition to PyTorch, run: 12 | 13 | ``` 14 | pip install -r requirements.txt 15 | ``` 16 | 17 | ## Evaluation and Training 18 | To reproduce our results, download the pretrained weights from [here](https://drive.google.com/drive/folders/1ZtAc7VYvltcdodT_BrUrQ_4IAhz_L-Rf?usp=sharing) and put them in the [pretrained_weights](./pretrained_weights) folder. Then, follow the instructions for each of the [LLFF (Real Forward-Facing)](#llff-real-forward-facing-dataset), [NeRF (Realistic Synthetic)](#nerf-realistic-synthetic-dataset), and [DTU](#dtu-dataset) datasets. 19 | 20 | ## LLFF (Real Forward-Facing) Dataset 21 | Download `nerf_llff_data.zip` from [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) and set its path as `llff_path` in the [config_llff.txt](./configs/config_llff.txt) file. 22 | 23 | To evaluate our generalizable model (the `pretrained.ckpt` model in the [pretrained_weights](./pretrained_weights) folder), set `scene` appropriately (e.g. fern), set the number of source views to 9 (nb_views = 9) in the [config_llff.txt](./configs/config_llff.txt) file, and run the following command: 24 | 25 | ``` 26 | python run_geo_nerf.py --config configs/config_llff.txt --eval 27 | ``` 28 | 29 | For fine-tuning on a specific scene, set nb_views = 7 and run the following command: 30 | 31 | ``` 32 | python run_geo_nerf.py --config configs/config_llff.txt 33 | ``` 34 | 35 | Once fine-tuning is finished, run the evaluation command with nb_views = 9 to get the final rendered results. 36 | 37 | ## NeRF (Realistic Synthetic) Dataset 38 | Download `nerf_synthetic.zip` from [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) and set its path as `nerf_path` in the [config_nerf.txt](configs/config_nerf.txt) file. 39 | 40 | To evaluate our generalizable model (the `pretrained.ckpt` model in the [pretrained_weights](./pretrained_weights) folder), set `scene` appropriately (e.g. lego), set the number of source views to 9 (nb_views = 9) in the [config_nerf.txt](configs/config_nerf.txt) file, and run the following command: 41 | 42 | ``` 43 | python run_geo_nerf.py --config configs/config_nerf.txt --eval 44 | ``` 45 | 46 | For fine-tuning on a specific scene, set nb_views = 7 and run the following command: 47 | 48 | ``` 49 | python run_geo_nerf.py --config configs/config_nerf.txt 50 | ``` 51 | 52 | Once fine-tuning is finished, run the evaluation command with nb_views = 9 to get the final rendered results. 53 | 54 | ## DTU Dataset 55 | Download the preprocessed [DTU training data](https://drive.google.com/file/d/1eDjh-_bxKKnEuz5h-HXS7EDJn59clx6V/view) 56 | and replace its `Depths` directory with [Depths_raw](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/cascade-stereo/CasMVSNet/dtu_data/dtu_train_hr/Depths_raw.zip) from the original [MVSNet repository](https://github.com/YoYo000/MVSNet), and set `dtu_pre_path` to point to this dataset in the [config_dtu.txt](configs/config_dtu.txt) file.
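For reference, the DTU loader in `data/dtu.py` (included further down in this repository) expects per-view camera files at `Cameras/train/{view_id:08d}_cam.txt` under `dtu_pre_path` and reads depth maps in PFM format. The sketch below is only an illustration inferred from those paths and from the `Depths`/`Depths_raw` replacement described above; any directory names beyond `Cameras/train` and `Depths` are assumptions, so defer to the structure of the downloaded archives.

```
dtu_pre_path/
├── Cameras/
│   └── train/
│       ├── 00000000_cam.txt   # per-view extrinsics, intrinsics, and depth_min/depth_interval
│       └── ...
└── Depths/                    # replaced by the contents of Depths_raw.zip (*.pfm depth maps)
```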
57 | 58 | Then, download the original `Rectified` images from the [DTU Website](https://roboimagedata.compute.dtu.dk/?page_id=36), and set `dtu_path` in the [config_dtu.txt](configs/config_dtu.txt) file accordingly. 59 | 60 | To evaluate our generalizable model (the `pretrained.ckpt` model in the [pretrained_weights](./pretrained_weights) folder), set `scene` appropriately (e.g. scan21), set the number of source views to 9 (nb_views = 9) in the [config_dtu.txt](./configs/config_dtu.txt) file, and run the following command: 61 | 62 | ``` 63 | python run_geo_nerf.py --config configs/config_dtu.txt --eval 64 | ``` 65 | 66 | For fine-tuning on a specific scene, keep the same nb_views = 9 and run the following command: 67 | 68 | ``` 69 | python run_geo_nerf.py --config configs/config_dtu.txt 70 | ``` 71 | 72 | Once fine-tuning is finished, run the evaluation command with nb_views = 9 to get the final rendered results. 73 | 74 | ### RGBD Compatible Model 75 | By adding the `--use_depth` argument to the aforementioned commands, you can use our RGBD compatible model on the DTU dataset and exploit the ground-truth, low-resolution depths to help the rendering process. The pretrained weights for this model are `pretrained_w_depth.ckpt`. 76 | 77 | ## Training From Scratch 78 | To train our model from scratch, first prepare the following datasets: 79 | 80 | * The original `Rectified` images from [DTU](https://roboimagedata.compute.dtu.dk/?page_id=36). Set the corresponding path as `dtu_path` in the [config_general.txt](configs/config_general.txt) file. 81 | 82 | * The preprocessed [DTU training data](https://drive.google.com/file/d/1eDjh-_bxKKnEuz5h-HXS7EDJn59clx6V/view) 83 | with its `Depths` directory replaced by [Depths_raw](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/cascade-stereo/CasMVSNet/dtu_data/dtu_train_hr/Depths_raw.zip). Set the corresponding path as `dtu_pre_path` in the [config_general.txt](configs/config_general.txt) file. 84 | 85 | * LLFF released scenes. Download [real_iconic_noface.zip](https://drive.google.com/drive/folders/1M-_Fdn4ajDa0CS8-iqejv0fQQeuonpKF) and remove the test scenes with the following commands: 86 | ``` 87 | unzip real_iconic_noface.zip 88 | cd real_iconic_noface/ 89 | rm -rf data2_fernvlsb data2_hugetrike data2_trexsanta data3_orchid data5_leafscene data5_lotr data5_redflower 90 | ``` 91 | Then, set the corresponding path as `llff_path` in the [config_general.txt](configs/config_general.txt) file. 92 | 93 | * Collected scenes from [IBRNet](https://github.com/googleinterns/IBRNet) ([Subset1](https://drive.google.com/file/d/1rkzl3ecL3H0Xxf5WTyc2Swv30RIyr1R_/view?usp=sharing) and [Subset2](https://drive.google.com/file/d/1Uxw0neyiIn3Ve8mpRsO6A06KfbqNrWuq/view?usp=sharing)). Set the corresponding paths as `ibrnet1_path` and `ibrnet2_path` in the [config_general.txt](configs/config_general.txt) file. 94 | 95 | Also, download `nerf_llff_data.zip` and `nerf_synthetic.zip` from [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) for validation and testing, and set their corresponding paths as `llff_test_path` and `nerf_path` in the [config_general.txt](configs/config_general.txt) file. 96 | 97 | Once all the datasets are available, train the network from scratch with the following command: 98 | ``` 99 | python run_geo_nerf.py --config configs/config_general.txt 100 | ``` 101 | ### Contact 102 | You can contact the author via email: mohammad.johari At idiap.ch.
103 | 104 | ## Citing 105 | If you find our work useful, please consider citing: 106 | ```BibTeX 107 | @inproceedings{johari-et-al-2022, 108 | author = {Johari, M. and Lepoittevin, Y. and Fleuret, F.}, 109 | title = {GeoNeRF: Generalizing NeRF with Geometry Priors}, 110 | booktitle = {Proceedings of the IEEE International Conference on Computer Vision and Pattern Recognition (CVPR)}, 111 | year = {2022} 112 | } 113 | ``` 114 | 115 | ### Acknowledgement 116 | This work was supported by ams OSRAM. -------------------------------------------------------------------------------- /configs/config_dtu.txt: -------------------------------------------------------------------------------- 1 | ### INPUT 2 | expname = scan21_test 3 | logdir = ./logs 4 | nb_views = 9 #### use 9 for both evaluation and fine-tuning 5 | 6 | ## dataset 7 | dataset_name = dtu 8 | dtu_path = Path to DTU MVS 9 | dtu_pre_path = Path to preprocessed DTU MVS 10 | scene = scan21 11 | 12 | ### TESTING 13 | chunk = 4096 ### Reduce it to save memory 14 | 15 | ### TRAINING 16 | num_steps = 10000 17 | lrate = 0.0002 -------------------------------------------------------------------------------- /configs/config_general.txt: -------------------------------------------------------------------------------- 1 | ### INPUT 2 | expname = Generalizable 3 | logdir = ./logs 4 | nb_views = 6 5 | 6 | ## dataset 7 | dataset_name = llff 8 | dtu_path = Path to DTU MVS 9 | dtu_pre_path = Path to preprocessed DTU MVS 10 | llff_path = Path to LLFF training scenes (real_iconic_noface) 11 | ibrnet1_path = Path to IBRNet dataset 1 (ibrnet_collected_1) 12 | ibrnet2_path = Path to IBRNet dataset 1 (ibrnet_collected_2) 13 | nerf_path = Path to NeRF dataset (nerf_synthetic) 14 | llff_test_path = Path to LLFF test scenes (nerf_llff_data) 15 | scene = None 16 | 17 | ### TESTING 18 | chunk = 4096 ### Reduce it to save memory 19 | 20 | ### TRAINING 21 | num_steps = 250000 22 | lrate = 0.0005 -------------------------------------------------------------------------------- /configs/config_llff.txt: -------------------------------------------------------------------------------- 1 | ### INPUT 2 | expname = fern_test 3 | logdir = ./logs 4 | nb_views = 9 #### Set to 7 for fine-tuning 5 | 6 | ## dataset 7 | dataset_name = llff 8 | llff_path = Path to LLFF test scenes (nerf_llff_data) 9 | scene = fern 10 | 11 | ### TESTING 12 | chunk = 4096 ### Reduce it to save memory 13 | 14 | ### TRAINING 15 | num_steps = 10000 16 | lrate = 0.0002 -------------------------------------------------------------------------------- /configs/config_nerf.txt: -------------------------------------------------------------------------------- 1 | ### INPUT 2 | expname = lego_test 3 | logdir = ./logs 4 | nb_views = 9 #### Set to 7 for fine-tuning 5 | 6 | ## dataset 7 | dataset_name = nerf 8 | nerf_path = Path to NeRF dataset (nerf_synthetic) 9 | scene = lego 10 | 11 | ### TESTING 12 | chunk = 4096 ### Reduce it to save memory 13 | 14 | ### TRAINING 15 | num_steps = 10000 16 | lrate = 0.0002 -------------------------------------------------------------------------------- /configs/lists/dtu_pairs.txt: -------------------------------------------------------------------------------- 1 | 33 2 | 1 3 | 10 9 2850.87 10 2583.94 2 2105.59 0 2052.84 8 1868.24 13 1184.23 14 1017.51 12 961.966 7 670.208 15 657.218 4 | 2 5 | 10 8 2501.24 1 2106.88 7 1856.5 9 1782.34 3 1141.77 15 1061.76 14 815.457 16 762.153 6 709.789 10 699.921 6 | 8 7 | 10 15 3124.01 9 3099.92 14 2756.29 2 2501.22 7 2449.32 1 1875.94 16 
1726.04 13 1325.76 23 1177.09 24 1108.82 8 | 9 9 | 10 13 3355.62 14 3226.07 8 3098.8 10 3097.07 1 2861.42 12 1873.63 2 1785.98 15 1753.32 25 1365.45 0 1261.59 10 | 10 11 | 10 12 3750.7 9 3085.87 13 3028.39 1 2590.55 0 2369.79 11 2266.67 14 1524.16 26 1448.15 27 1293.6 8 1041.84 12 | 11 13 | 10 12 3543.76 27 3056.05 10 2248.07 26 1524.28 28 1273.33 13 1265.9 29 1129.55 0 998.164 9 591.176 30 572.919 14 | 12 15 | 10 27 3889.87 10 3754.54 13 3745.21 11 3584.26 26 3574.56 25 1877.11 9 1866.34 29 1482.72 30 1418.51 14 1341.86 16 | 13 17 | 10 12 3773.14 26 3699.28 25 3657.17 14 3652.04 9 3356.29 10 3049.27 24 2098.91 27 1900.96 31 1460.96 30 1349.62 18 | 14 19 | 10 13 3663.52 24 3610.69 9 3232.55 25 3216.4 15 3128.84 8 2758.04 23 2219.91 26 1567.45 10 1536.6 32 1419.33 20 | 15 21 | 10 23 3194.92 14 3126 8 3120.43 16 2897.02 24 2562.49 7 2084.05 22 2041.63 9 1752.08 33 1232.29 13 1137.55 22 | 22 23 | 10 23 3232.68 34 3175.15 35 2831.09 16 2712.51 21 2632.19 15 2033.39 33 1712.67 17 1393.86 36 1290.96 24 1195.33 24 | 23 25 | 10 24 3710.9 33 3603.07 22 3244.2 15 3190.62 34 3086.49 14 2220.11 32 2100 16 1917.1 35 1359.79 25 1356.71 26 | 24 27 | 10 25 3844.6 32 3750.75 23 3710.6 14 3609.09 33 3091.04 15 2559.24 31 2423.71 13 2109.36 26 1440.58 34 1410.03 28 | 25 29 | 10 26 3951.74 31 3888.57 24 3833.07 13 3667.35 14 3208.21 32 2993.46 30 2681.52 12 1900.23 45 1484.03 27 1462.88 30 | 26 31 | 10 30 4033.35 27 3970.47 25 3925.25 13 3686.34 12 3595.59 29 2943.87 31 2917 14 1556.34 11 1554.75 46 1503.84 32 | 27 33 | 10 29 4027.84 26 3929.94 12 3875.58 11 3085.03 28 2908.6 30 2792.67 13 1878.42 25 1438.55 47 1425.2 10 1290.25 34 | 28 35 | 10 29 3687.02 48 3209.13 27 2872.86 47 2014.53 30 1361.95 11 1273.6 26 1062.85 12 840.841 46 672.985 31 271.952 36 | 29 37 | 10 27 4029.43 30 3909.55 28 3739.93 47 3695.23 48 3135.87 26 2910.97 46 2229.55 12 1479.16 31 1430.26 11 1144.56 38 | 30 39 | 10 26 4029.86 29 3953.72 31 3811.12 46 3630.46 47 3105.96 27 2824.43 25 2657.89 45 2347.75 32 1459.11 12 1429.62 40 | 31 41 | 10 25 3882.21 30 3841.88 32 3808.5 45 3649.82 46 3000.67 26 2939.94 24 2409.93 44 2381.3 13 1467.59 29 1459.56 42 | 32 43 | 10 31 3826.5 24 3744.14 33 3613.24 44 3552.04 25 3004.6 45 2884.59 43 2393.34 23 2095.27 30 1478.6 14 1420.78 44 | 33 45 | 10 32 3618.11 23 3598.1 34 3530.53 43 3462.37 24 3091.53 44 2608.08 42 2426 22 1717.94 31 1407.65 25 1324.78 46 | 34 47 | 10 33 3523.37 42 3356.55 35 3210.34 22 3178.85 23 3079.03 43 2396.45 41 2386.86 24 1408.02 32 1301.34 21 1256.45 48 | 35 49 | 10 34 3187.88 41 3106.44 36 2866.04 22 2817.74 21 2654.87 40 2416.98 42 2137.81 23 1346.86 33 1150.33 16 1044.66 50 | 40 51 | 10 36 2918.14 41 2852.62 39 2782.6 35 2392.96 37 1641.45 21 1124.3 42 1056.48 34 877.946 38 853.944 20 788.701 52 | 41 53 | 10 35 3111.05 42 3049.71 40 2885.36 34 2371.02 36 1813.69 43 1164.71 22 1126.9 39 1011.26 21 906.536 33 903.238 54 | 42 55 | 10 34 3356.98 43 3183 41 3070.54 33 2421.77 35 2155.08 44 1278.41 23 1183.52 22 1147.07 40 1077.08 32 899.646 56 | 43 57 | 10 33 3461.24 44 3380.74 42 3188.7 34 2400.6 32 2399.09 45 1359.37 23 1314.08 41 1176.12 24 1159.62 31 901.556 58 | 44 59 | 10 32 3550.81 45 3510.16 43 3373.11 33 2602.33 31 2395.93 24 1410.43 46 1386.31 42 1279 25 1095.24 34 968.44 60 | 45 61 | 10 31 3650.09 46 3555.09 44 3491.15 32 2868.39 30 2373.59 25 1485.37 47 1405.28 43 1349.54 33 1104.77 26 1046.81 62 | 46 63 | 10 30 3635.64 47 3562.17 45 3524.17 31 2976.82 29 2264.04 26 1508.87 44 1367.41 48 1352.1 32 1211.24 25 1102.17 64 | 47 65 | 10 29 3705.31 46 3519.76 48 3450.48 
30 3074.77 28 2054.63 27 1434.57 45 1377.34 31 1268.23 26 1223.83 25 471.111 66 | 48 67 | 10 47 3401.95 28 3224.84 29 3101.16 46 1317.1 30 1306.7 27 1235.07 26 537.731 31 291.919 45 276.869 11 258.856 -------------------------------------------------------------------------------- /configs/lists/dtu_pairs_ft.txt: -------------------------------------------------------------------------------- 1 | 29 2 | 1 3 | 10 9 2850.87 10 2583.94 2 2105.59 0 2052.84 8 1868.24 13 1184.23 14 1017.51 12 961.966 7 670.208 15 657.218 4 | 2 5 | 10 8 2501.24 1 2106.88 7 1856.5 9 1782.34 3 1141.77 15 1061.76 14 815.457 16 762.153 6 709.789 10 699.921 6 | 8 7 | 10 15 3124.01 9 3099.92 14 2756.29 2 2501.22 7 2449.32 1 1875.94 16 1726.04 13 1325.76 23 1177.09 24 1108.82 8 | 9 9 | 10 13 3355.62 14 3226.07 8 3098.8 10 3097.07 1 2861.42 12 1873.63 2 1785.98 15 1753.32 25 1365.45 0 1261.59 10 | 10 11 | 10 12 3750.7 9 3085.87 13 3028.39 1 2590.55 0 2369.79 11 2266.67 14 1524.16 26 1448.15 27 1293.6 8 1041.84 12 | 11 13 | 10 12 3543.76 27 3056.05 10 2248.07 26 1524.28 28 1273.33 13 1265.9 29 1129.55 0 998.164 9 591.176 30 572.919 14 | 12 15 | 10 27 3889.87 10 3754.54 13 3745.21 11 3584.26 26 3574.56 25 1877.11 9 1866.34 29 1482.72 30 1418.51 14 1341.86 16 | 13 17 | 10 12 3773.14 26 3699.28 25 3657.17 14 3652.04 9 3356.29 10 3049.27 24 2098.91 27 1900.96 31 1460.96 30 1349.62 18 | 14 19 | 10 13 3663.52 24 3610.69 9 3232.55 25 3216.4 15 3128.84 8 2758.04 23 2219.91 26 1567.45 10 1536.6 32 1419.33 20 | 15 21 | 10 23 3194.92 14 3126 8 3120.43 16 2897.02 24 2562.49 7 2084.05 22 2041.63 9 1752.08 33 1232.29 13 1137.55 22 | 22 23 | 10 23 3232.68 34 3175.15 35 2831.09 16 2712.51 21 2632.19 15 2033.39 33 1712.67 17 1393.86 36 1290.96 24 1195.33 24 | 26 25 | 10 30 4033.35 27 3970.47 25 3925.25 13 3686.34 12 3595.59 29 2943.87 31 2917 14 1556.34 11 1554.75 46 1503.84 26 | 27 27 | 10 29 4027.84 26 3929.94 12 3875.58 11 3085.03 28 2908.6 30 2792.67 13 1878.42 25 1438.55 47 1425.2 10 1290.25 28 | 28 29 | 10 29 3687.02 48 3209.13 27 2872.86 47 2014.53 30 1361.95 11 1273.6 26 1062.85 12 840.841 46 672.985 31 271.952 30 | 29 31 | 10 27 4029.43 30 3909.55 28 3739.93 47 3695.23 48 3135.87 26 2910.97 46 2229.55 12 1479.16 31 1430.26 11 1144.56 32 | 30 33 | 10 26 4029.86 29 3953.72 31 3811.12 46 3630.46 47 3105.96 27 2824.43 25 2657.89 45 2347.75 32 1459.11 12 1429.62 34 | 31 35 | 10 25 3882.21 30 3841.88 32 3808.5 45 3649.82 46 3000.67 26 2939.94 24 2409.93 44 2381.3 13 1467.59 29 1459.56 36 | 33 37 | 10 32 3618.11 23 3598.1 34 3530.53 43 3462.37 24 3091.53 44 2608.08 42 2426 22 1717.94 31 1407.65 25 1324.78 38 | 34 39 | 10 33 3523.37 42 3356.55 35 3210.34 22 3178.85 23 3079.03 43 2396.45 41 2386.86 24 1408.02 32 1301.34 21 1256.45 40 | 35 41 | 10 34 3187.88 41 3106.44 36 2866.04 22 2817.74 21 2654.87 40 2416.98 42 2137.81 23 1346.86 33 1150.33 16 1044.66 42 | 40 43 | 10 36 2918.14 41 2852.62 39 2782.6 35 2392.96 37 1641.45 21 1124.3 42 1056.48 34 877.946 38 853.944 20 788.701 44 | 41 45 | 10 35 3111.05 42 3049.71 40 2885.36 34 2371.02 36 1813.69 43 1164.71 22 1126.9 39 1011.26 21 906.536 33 903.238 46 | 42 47 | 10 34 3356.98 43 3183 41 3070.54 33 2421.77 35 2155.08 44 1278.41 23 1183.52 22 1147.07 40 1077.08 32 899.646 48 | 43 49 | 10 33 3461.24 44 3380.74 42 3188.7 34 2400.6 32 2399.09 45 1359.37 23 1314.08 41 1176.12 24 1159.62 31 901.556 50 | 44 51 | 10 32 3550.81 45 3510.16 43 3373.11 33 2602.33 31 2395.93 24 1410.43 46 1386.31 42 1279 25 1095.24 34 968.44 52 | 45 53 | 10 31 3650.09 46 3555.09 44 3491.15 32 2868.39 30 2373.59 25 
1485.37 47 1405.28 43 1349.54 33 1104.77 26 1046.81 54 | 46 55 | 10 30 3635.64 47 3562.17 45 3524.17 31 2976.82 29 2264.04 26 1508.87 44 1367.41 48 1352.1 32 1211.24 25 1102.17 56 | 47 57 | 10 29 3705.31 46 3519.76 48 3450.48 30 3074.77 28 2054.63 27 1434.57 45 1377.34 31 1268.23 26 1223.83 25 471.111 58 | 48 59 | 10 47 3401.95 28 3224.84 29 3101.16 46 1317.1 30 1306.7 27 1235.07 26 537.731 31 291.919 45 276.869 11 258.856 -------------------------------------------------------------------------------- /configs/lists/dtu_pairs_val.txt: -------------------------------------------------------------------------------- 1 | 4 2 | 23 3 | 10 24 3710.9 33 3603.07 22 3244.2 15 3190.62 34 3086.49 14 2220.11 32 2100 16 1917.1 35 1359.79 25 1356.71 4 | 24 5 | 10 25 3844.6 32 3750.75 23 3710.6 14 3609.09 33 3091.04 15 2559.24 31 2423.71 13 2109.36 26 1440.58 34 1410.03 6 | 25 7 | 10 26 3951.74 31 3888.57 24 3833.07 13 3667.35 14 3208.21 32 2993.46 30 2681.52 12 1900.23 45 1484.03 27 1462.88 8 | 32 9 | 10 31 3826.5 24 3744.14 33 3613.24 44 3552.04 25 3004.6 45 2884.59 43 2393.34 23 2095.27 30 1478.6 14 1420.78 -------------------------------------------------------------------------------- /configs/lists/dtu_train_all.txt: -------------------------------------------------------------------------------- 1 | scan3 2 | scan4 3 | scan5 4 | scan6 5 | scan9 6 | scan10 7 | scan11 8 | scan12 9 | scan13 10 | scan14 11 | scan15 12 | scan16 13 | scan17 14 | scan18 15 | scan19 16 | scan20 17 | scan22 18 | scan23 19 | scan24 20 | scan28 21 | scan32 22 | scan33 23 | scan35 24 | scan36 25 | scan37 26 | scan42 27 | scan43 28 | scan44 29 | scan46 30 | scan47 31 | scan48 32 | scan49 33 | scan50 34 | scan52 35 | scan53 36 | scan59 37 | scan60 38 | scan61 39 | scan62 40 | scan64 41 | scan65 42 | scan66 43 | scan67 44 | scan68 45 | scan69 46 | scan70 47 | scan71 48 | scan72 49 | scan74 50 | scan75 51 | scan76 52 | scan77 53 | scan84 54 | scan85 55 | scan86 56 | scan87 57 | scan88 58 | scan89 59 | scan90 60 | scan91 61 | scan92 62 | scan93 63 | scan94 64 | scan95 65 | scan96 66 | scan97 67 | scan98 68 | scan99 69 | scan100 70 | scan101 71 | scan102 72 | scan104 73 | scan105 74 | scan106 75 | scan107 76 | scan108 77 | scan109 78 | scan118 79 | scan119 80 | scan120 81 | scan121 82 | scan122 83 | scan123 84 | scan124 85 | scan125 86 | scan126 87 | scan127 88 | scan128 -------------------------------------------------------------------------------- /configs/lists/dtu_val_all.txt: -------------------------------------------------------------------------------- 1 | scan8 2 | scan21 3 | scan30 4 | scan31 5 | scan34 6 | scan38 7 | scan40 8 | scan41 9 | scan45 10 | scan55 11 | scan63 12 | scan82 13 | scan103 14 | scan110 15 | scan114 -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/GeoNeRF/e6249fdae5672853c6bbbd4ba380c4c166d02c95/data/__init__.py -------------------------------------------------------------------------------- /data/dtu.py: -------------------------------------------------------------------------------- 1 | # GeoNeRF is a generalizable NeRF model that renders novel views 2 | # without requiring per-scene optimization. This software is the 3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with 4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin, 5 | # and Francois Fleuret. 
6 | 7 | # Copyright (c) 2022 ams International AG 8 | 9 | # This file is part of GeoNeRF. 10 | # GeoNeRF is free software: you can redistribute it and/or modify 11 | # it under the terms of the GNU General Public License version 3 as 12 | # published by the Free Software Foundation. 13 | 14 | # GeoNeRF is distributed in the hope that it will be useful, 15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 | # GNU General Public License for more details. 18 | 19 | # You should have received a copy of the GNU General Public License 20 | # along with GeoNeRF. If not, see . 21 | 22 | # This file incorporates work covered by the following copyright and 23 | # permission notice: 24 | 25 | # MIT License 26 | 27 | # Copyright (c) 2021 apchenstu 28 | 29 | # Permission is hereby granted, free of charge, to any person obtaining a copy 30 | # of this software and associated documentation files (the "Software"), to deal 31 | # in the Software without restriction, including without limitation the rights 32 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 33 | # copies of the Software, and to permit persons to whom the Software is 34 | # furnished to do so, subject to the following conditions: 35 | 36 | # The above copyright notice and this permission notice shall be included in all 37 | # copies or substantial portions of the Software. 38 | 39 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 40 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 41 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 42 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 43 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 44 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 45 | # SOFTWARE. 46 | 47 | from torch.utils.data import Dataset 48 | from torchvision import transforms as T 49 | 50 | import os 51 | import cv2 52 | import numpy as np 53 | from PIL import Image 54 | 55 | from utils.utils import read_pfm, get_nearest_pose_ids 56 | 57 | class DTU_Dataset(Dataset): 58 | def __init__( 59 | self, 60 | original_root_dir, 61 | preprocessed_root_dir, 62 | split, 63 | nb_views, 64 | downSample=1.0, 65 | max_len=-1, 66 | scene="None", 67 | ): 68 | self.original_root_dir = original_root_dir 69 | self.preprocessed_root_dir = preprocessed_root_dir 70 | self.split = split 71 | self.scene = scene 72 | 73 | self.downSample = downSample 74 | self.scale_factor = 1.0 / 200 75 | self.interval_scale = 1.06 76 | self.max_len = max_len 77 | self.nb_views = nb_views 78 | 79 | self.build_metas() 80 | self.build_proj_mats() 81 | self.define_transforms() 82 | 83 | def define_transforms(self): 84 | self.transform = T.Compose( 85 | [ 86 | T.ToTensor(), 87 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 88 | ] 89 | ) 90 | 91 | def build_metas(self): 92 | self.metas = [] 93 | with open(f"configs/lists/dtu_{self.split}_all.txt") as f: 94 | self.scans = [line.rstrip() for line in f.readlines()] 95 | if self.scene != "None": 96 | self.scans = [self.scene] 97 | 98 | # light conditions 2-5 for training 99 | # light condition 3 for testing (the brightest?) 
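# Note on the selection below: range(2, 5) yields light indices 2, 3 and 4 for the
# generalizable training setting, while a single light index (3) is used for
# validation and per-scene fine-tuning.
# The pair files in configs/lists/ (dtu_pairs.txt, dtu_pairs_ft.txt, dtu_pairs_val.txt)
# share one layout, visible in the listings above: the first line gives the number of
# reference views in the file; each reference view then contributes two lines, its view
# id and a line of the form "10 <src_id> <score> <src_id> <score> ...". The [1::2]
# slice below keeps the source-view ids and discards the scores.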
100 | light_idxs = ( 101 | [3] if "train" != self.split or self.scene != "None" else range(2, 5) 102 | ) 103 | 104 | self.id_list = [] 105 | 106 | if self.split == "train": 107 | if self.scene == "None": 108 | pair_file = f"configs/lists/dtu_pairs.txt" 109 | else: 110 | pair_file = f"configs/lists/dtu_pairs_ft.txt" 111 | else: 112 | pair_file = f"configs/lists/dtu_pairs_val.txt" 113 | 114 | for scan in self.scans: 115 | with open(pair_file) as f: 116 | num_viewpoint = int(f.readline()) 117 | for _ in range(num_viewpoint): 118 | ref_view = int(f.readline().rstrip()) 119 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] 120 | for light_idx in light_idxs: 121 | self.metas += [(scan, light_idx, ref_view, src_views)] 122 | self.id_list.append([ref_view] + src_views) 123 | 124 | self.id_list = np.unique(self.id_list) 125 | self.build_remap() 126 | 127 | def build_proj_mats(self): 128 | near_fars, intrinsics, world2cams, cam2worlds = [], [], [], [] 129 | for vid in self.id_list: 130 | proj_mat_filename = os.path.join( 131 | self.preprocessed_root_dir, f"Cameras/train/{vid:08d}_cam.txt" 132 | ) 133 | intrinsic, extrinsic, near_far = self.read_cam_file(proj_mat_filename) 134 | intrinsic[:2] *= 4 135 | extrinsic[:3, 3] *= self.scale_factor 136 | 137 | intrinsic[:2] = intrinsic[:2] * self.downSample 138 | intrinsics += [intrinsic.copy()] 139 | 140 | near_fars += [near_far] 141 | world2cams += [extrinsic] 142 | cam2worlds += [np.linalg.inv(extrinsic)] 143 | 144 | self.near_fars, self.intrinsics = np.stack(near_fars), np.stack(intrinsics) 145 | self.world2cams, self.cam2worlds = np.stack(world2cams), np.stack(cam2worlds) 146 | 147 | def read_cam_file(self, filename): 148 | with open(filename) as f: 149 | lines = [line.rstrip() for line in f.readlines()] 150 | # extrinsics: line [1,5), 4x4 matrix 151 | extrinsics = np.fromstring(" ".join(lines[1:5]), dtype=np.float32, sep=" ") 152 | extrinsics = extrinsics.reshape((4, 4)) 153 | # intrinsics: line [7-10), 3x3 matrix 154 | intrinsics = np.fromstring(" ".join(lines[7:10]), dtype=np.float32, sep=" ") 155 | intrinsics = intrinsics.reshape((3, 3)) 156 | # depth_min & depth_interval: line 11 157 | depth_min, depth_interval = lines[11].split() 158 | depth_min = float(depth_min) * self.scale_factor 159 | depth_max = depth_min + float(depth_interval) * 192 * self.interval_scale * self.scale_factor 160 | 161 | intrinsics[0, 2] = intrinsics[0, 2] + 80.0 / 4.0 162 | intrinsics[1, 2] = intrinsics[1, 2] + 44.0 / 4.0 163 | intrinsics[:2] = intrinsics[:2] 164 | 165 | return intrinsics, extrinsics, [depth_min, depth_max] 166 | 167 | def read_depth(self, filename, far_bound, noisy_factor=1.0): 168 | depth_h = self.scale_factor * np.array( 169 | read_pfm(filename)[0], dtype=np.float32 170 | ) 171 | depth_h = cv2.resize( 172 | depth_h, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_NEAREST 173 | ) 174 | 175 | depth_h = cv2.resize( 176 | depth_h, 177 | None, 178 | fx=self.downSample * noisy_factor, 179 | fy=self.downSample * noisy_factor, 180 | interpolation=cv2.INTER_NEAREST, 181 | ) 182 | 183 | ## Exclude points beyond the bounds 184 | depth_h[depth_h > far_bound * 0.95] = 0.0 185 | 186 | depth = {} 187 | for l in range(3): 188 | depth[f"level_{l}"] = cv2.resize( 189 | depth_h, 190 | None, 191 | fx=1.0 / (2**l), 192 | fy=1.0 / (2**l), 193 | interpolation=cv2.INTER_NEAREST, 194 | ) 195 | 196 | if self.split == "train": 197 | cutout = np.ones_like(depth[f"level_2"]) 198 | h0 = int(np.random.randint(0, high=cutout.shape[0] // 5, size=1)) 199 | h1 = int( 200 | 
np.random.randint( 201 | 4 * cutout.shape[0] // 5, high=cutout.shape[0], size=1 202 | ) 203 | ) 204 | w0 = int(np.random.randint(0, high=cutout.shape[1] // 5, size=1)) 205 | w1 = int( 206 | np.random.randint( 207 | 4 * cutout.shape[1] // 5, high=cutout.shape[1], size=1 208 | ) 209 | ) 210 | cutout[h0:h1, w0:w1] = 0 211 | depth_aug = depth[f"level_2"] * cutout 212 | else: 213 | depth_aug = depth[f"level_2"].copy() 214 | 215 | return depth, depth_h, depth_aug 216 | 217 | def build_remap(self): 218 | self.remap = np.zeros(np.max(self.id_list) + 1).astype("int") 219 | for i, item in enumerate(self.id_list): 220 | self.remap[item] = i 221 | 222 | def __len__(self): 223 | return len(self.metas) if self.max_len <= 0 else self.max_len 224 | 225 | def __getitem__(self, idx): 226 | if self.split == "train" and self.scene == "None": 227 | noisy_factor = float(np.random.choice([1.0, 0.5], 1)) 228 | close_views = int(np.random.choice([3, 4, 5], 1)) 229 | else: 230 | noisy_factor = 1.0 231 | close_views = 5 232 | 233 | scan, light_idx, target_view, src_views = self.metas[idx] 234 | view_ids = src_views[:self.nb_views] + [target_view] 235 | 236 | affine_mats, affine_mats_inv = [], [] 237 | imgs, depths_h, depths_aug = [], [], [] 238 | depths = {"level_0": [], "level_1": [], "level_2": []} 239 | intrinsics, w2cs, c2ws, near_fars = [], [], [], [] 240 | 241 | for vid in view_ids: 242 | # Note that the id in image file names is from 1 to 49 (not 0~48) 243 | img_filename = os.path.join( 244 | self.original_root_dir, 245 | f"Rectified/{scan}/rect_{vid + 1:03d}_{light_idx}_r5000.png", 246 | ) 247 | depth_filename = os.path.join( 248 | self.preprocessed_root_dir, f"Depths/{scan}/depth_map_{vid:04d}.pfm" 249 | ) 250 | img = Image.open(img_filename) 251 | img_wh = np.round( 252 | np.array(img.size) / 2.0 * self.downSample * noisy_factor 253 | ).astype("int") 254 | img = img.resize(img_wh, Image.BICUBIC) 255 | img = self.transform(img) 256 | imgs += [img] 257 | 258 | index_mat = self.remap[vid] 259 | 260 | intrinsic = self.intrinsics[index_mat].copy() 261 | intrinsic[:2] = intrinsic[:2] * noisy_factor 262 | intrinsics.append(intrinsic) 263 | 264 | w2c = self.world2cams[index_mat] 265 | w2cs.append(w2c) 266 | c2ws.append(self.cam2worlds[index_mat]) 267 | 268 | aff = [] 269 | aff_inv = [] 270 | for l in range(3): 271 | proj_mat_l = np.eye(4) 272 | intrinsic_temp = intrinsic.copy() 273 | intrinsic_temp[:2] = intrinsic_temp[:2] / (2**l) 274 | proj_mat_l[:3, :4] = intrinsic_temp @ w2c[:3, :4] 275 | aff.append(proj_mat_l.copy()) 276 | aff_inv.append(np.linalg.inv(proj_mat_l)) 277 | aff = np.stack(aff, axis=-1) 278 | aff_inv = np.stack(aff_inv, axis=-1) 279 | 280 | affine_mats.append(aff) 281 | affine_mats_inv.append(aff_inv) 282 | 283 | near_far = self.near_fars[index_mat] 284 | 285 | depth, depth_h, depth_aug = self.read_depth( 286 | depth_filename, near_far[1], noisy_factor 287 | ) 288 | 289 | depths["level_0"].append(depth["level_0"]) 290 | depths["level_1"].append(depth["level_1"]) 291 | depths["level_2"].append(depth["level_2"]) 292 | depths_h.append(depth_h) 293 | depths_aug.append(depth_aug) 294 | 295 | near_fars.append(near_far) 296 | 297 | imgs = np.stack(imgs) 298 | depths_h, depths_aug = np.stack(depths_h), np.stack(depths_aug) 299 | depths["level_0"] = np.stack(depths["level_0"]) 300 | depths["level_1"] = np.stack(depths["level_1"]) 301 | depths["level_2"] = np.stack(depths["level_2"]) 302 | affine_mats, affine_mats_inv = np.stack(affine_mats), np.stack(affine_mats_inv) 303 | intrinsics = np.stack(intrinsics) 
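# Every per-view list is stacked along dim 0 into a (nb_views + 1, ...) array,
# source views first and the target view last (view_ids was built above as
# src_views[:self.nb_views] + [target_view]).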
304 | w2cs = np.stack(w2cs) 305 | c2ws = np.stack(c2ws) 306 | near_fars = np.stack(near_fars) 307 | 308 | closest_idxs = [] 309 | for pose in c2ws[:-1]: 310 | closest_idxs.append( 311 | get_nearest_pose_ids( 312 | pose, 313 | ref_poses=c2ws[:-1], 314 | num_select=close_views, 315 | angular_dist_method="dist", 316 | ) 317 | ) 318 | closest_idxs = np.stack(closest_idxs, axis=0) 319 | 320 | sample = {} 321 | sample["images"] = imgs 322 | sample["depths"] = depths 323 | sample["depths_h"] = depths_h 324 | sample["depths_aug"] = depths_aug 325 | sample["w2cs"] = w2cs 326 | sample["c2ws"] = c2ws 327 | sample["near_fars"] = near_fars 328 | sample["intrinsics"] = intrinsics 329 | sample["affine_mats"] = affine_mats 330 | sample["affine_mats_inv"] = affine_mats_inv 331 | sample["closest_idxs"] = closest_idxs 332 | 333 | return sample 334 | -------------------------------------------------------------------------------- /data/get_datasets.py: -------------------------------------------------------------------------------- 1 | # GeoNeRF is a generalizable NeRF model that renders novel views 2 | # without requiring per-scene optimization. This software is the 3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with 4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin, 5 | # and Francois Fleuret. 6 | 7 | # Copyright (c) 2022 ams International AG 8 | 9 | # This file is part of GeoNeRF. 10 | # GeoNeRF is free software: you can redistribute it and/or modify 11 | # it under the terms of the GNU General Public License version 3 as 12 | # published by the Free Software Foundation. 13 | 14 | # GeoNeRF is distributed in the hope that it will be useful, 15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 | # GNU General Public License for more details. 18 | 19 | # You should have received a copy of the GNU General Public License 20 | # along with GeoNeRF. If not, see . 
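# get_training_dataset() below concatenates one DTU split with three LLFF-style
# collections and samples them with a WeightedRandomSampler so that each source
# contributes a fixed fraction of the draws (0.5 / 0.22 / 0.12 / 0.16) regardless
# of its size. A minimal, self-contained sketch of that weighting rule follows;
# the helper name is illustrative only and is not referenced by the code below.

def _mixture_sample_weights(dataset_sizes, fractions):
    """Give every sample of dataset i the weight fractions[i] / dataset_sizes[i].

    With these weights, a WeightedRandomSampler draws roughly fractions[i] of
    its samples from dataset i, independent of how many items it holds.
    """
    weights = []
    for size, fraction in zip(dataset_sizes, fractions):
        weights.extend([fraction / size] * size)
    return weights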
21 | 22 | import torch 23 | from torch.utils.data import ConcatDataset, WeightedRandomSampler 24 | import numpy as np 25 | 26 | from data.llff import LLFF_Dataset 27 | from data.dtu import DTU_Dataset 28 | from data.nerf import NeRF_Dataset 29 | 30 | def get_training_dataset(args, downsample=1.0): 31 | train_datasets = [ 32 | DTU_Dataset( 33 | original_root_dir=args.dtu_path, 34 | preprocessed_root_dir=args.dtu_pre_path, 35 | split="train", 36 | max_len=-1, 37 | downSample=downsample, 38 | nb_views=args.nb_views, 39 | ), 40 | LLFF_Dataset( 41 | root_dir=args.ibrnet1_path, 42 | split="train", 43 | max_len=-1, 44 | downSample=downsample, 45 | nb_views=args.nb_views, 46 | imgs_folder_name="images", 47 | ), 48 | LLFF_Dataset( 49 | root_dir=args.ibrnet2_path, 50 | split="train", 51 | max_len=-1, 52 | downSample=downsample, 53 | nb_views=args.nb_views, 54 | imgs_folder_name="images", 55 | ), 56 | LLFF_Dataset( 57 | root_dir=args.llff_path, 58 | split="train", 59 | max_len=-1, 60 | downSample=downsample, 61 | nb_views=args.nb_views, 62 | imgs_folder_name="images_4", 63 | ), 64 | ] 65 | weights = [0.5, 0.22, 0.12, 0.16] 66 | 67 | train_weights_samples = [] 68 | for dataset, weight in zip(train_datasets, weights): 69 | num_samples = len(dataset) 70 | weight_each_sample = weight / num_samples 71 | train_weights_samples.extend([weight_each_sample] * num_samples) 72 | 73 | train_dataset = ConcatDataset(train_datasets) 74 | train_weights = torch.from_numpy(np.array(train_weights_samples)) 75 | train_sampler = WeightedRandomSampler(train_weights, len(train_weights)) 76 | 77 | return train_dataset, train_sampler 78 | 79 | 80 | def get_finetuning_dataset(args, downsample=1.0): 81 | if args.dataset_name == "dtu": 82 | train_dataset = DTU_Dataset( 83 | original_root_dir=args.dtu_path, 84 | preprocessed_root_dir=args.dtu_pre_path, 85 | split="train", 86 | max_len=-1, 87 | downSample=downsample, 88 | nb_views=args.nb_views, 89 | scene=args.scene, 90 | ) 91 | elif args.dataset_name == "llff": 92 | train_dataset = LLFF_Dataset( 93 | root_dir=args.llff_path, 94 | split="train", 95 | max_len=-1, 96 | downSample=downsample, 97 | nb_views=args.nb_views, 98 | scene=args.scene, 99 | imgs_folder_name="images_4", 100 | ) 101 | elif args.dataset_name == "nerf": 102 | train_dataset = NeRF_Dataset( 103 | root_dir=args.nerf_path, 104 | split="train", 105 | max_len=-1, 106 | downSample=downsample, 107 | nb_views=args.nb_views, 108 | scene=args.scene, 109 | ) 110 | 111 | train_sampler = None 112 | 113 | return train_dataset, train_sampler 114 | 115 | 116 | def get_validation_dataset(args, downsample=1.0): 117 | if args.scene == "None": 118 | max_len = 2 119 | else: 120 | max_len = -1 121 | 122 | if args.dataset_name == "dtu": 123 | val_dataset = DTU_Dataset( 124 | original_root_dir=args.dtu_path, 125 | preprocessed_root_dir=args.dtu_pre_path, 126 | split="val", 127 | max_len=max_len, 128 | downSample=downsample, 129 | nb_views=args.nb_views, 130 | scene=args.scene, 131 | ) 132 | elif args.dataset_name == "llff": 133 | val_dataset = LLFF_Dataset( 134 | root_dir=args.llff_test_path if not args.llff_test_path is None else args.llff_path, 135 | split="val", 136 | max_len=max_len, 137 | downSample=downsample, 138 | nb_views=args.nb_views, 139 | scene=args.scene, 140 | imgs_folder_name="images_4", 141 | ) 142 | elif args.dataset_name == "nerf": 143 | val_dataset = NeRF_Dataset( 144 | root_dir=args.nerf_path, 145 | split="val", 146 | max_len=max_len, 147 | downSample=downsample, 148 | nb_views=args.nb_views, 149 | scene=args.scene, 
150 | ) 151 | 152 | return val_dataset 153 | -------------------------------------------------------------------------------- /data/llff.py: -------------------------------------------------------------------------------- 1 | # GeoNeRF is a generalizable NeRF model that renders novel views 2 | # without requiring per-scene optimization. This software is the 3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with 4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin, 5 | # and Francois Fleuret. 6 | 7 | # Copyright (c) 2022 ams International AG 8 | 9 | # This file is part of GeoNeRF. 10 | # GeoNeRF is free software: you can redistribute it and/or modify 11 | # it under the terms of the GNU General Public License version 3 as 12 | # published by the Free Software Foundation. 13 | 14 | # GeoNeRF is distributed in the hope that it will be useful, 15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 | # GNU General Public License for more details. 18 | 19 | # You should have received a copy of the GNU General Public License 20 | # along with GeoNeRF. If not, see . 21 | 22 | # This file incorporates work covered by the following copyright and 23 | # permission notice: 24 | 25 | # MIT License 26 | 27 | # Copyright (c) 2021 apchenstu 28 | 29 | # Permission is hereby granted, free of charge, to any person obtaining a copy 30 | # of this software and associated documentation files (the "Software"), to deal 31 | # in the Software without restriction, including without limitation the rights 32 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 33 | # copies of the Software, and to permit persons to whom the Software is 34 | # furnished to do so, subject to the following conditions: 35 | 36 | # The above copyright notice and this permission notice shall be included in all 37 | # copies or substantial portions of the Software. 38 | 39 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 40 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 41 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 42 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 43 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 44 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 45 | # SOFTWARE. 46 | 47 | from torch.utils.data import Dataset 48 | from torchvision import transforms as T 49 | 50 | import os 51 | import glob 52 | import numpy as np 53 | from PIL import Image 54 | 55 | from utils.utils import get_nearest_pose_ids 56 | 57 | def normalize(v): 58 | return v / np.linalg.norm(v) 59 | 60 | 61 | def average_poses(poses): 62 | # 1. Compute the center 63 | center = poses[..., 3].mean(0) # (3) 64 | 65 | # 2. Compute the z axis 66 | z = normalize(poses[..., 2].mean(0)) # (3) 67 | 68 | # 3. Compute axis y' (no need to normalize as it's not the final output) 69 | y_ = poses[..., 1].mean(0) # (3) 70 | 71 | # 4. Compute the x axis 72 | x = normalize(np.cross(y_, z)) # (3) 73 | 74 | # 5. 
Compute the y axis (as z and x are normalized, y is already of norm 1) 75 | y = np.cross(z, x) # (3) 76 | 77 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4) 78 | 79 | return pose_avg 80 | 81 | 82 | def center_poses(poses, blender2opencv): 83 | pose_avg = average_poses(poses) # (3, 4) 84 | pose_avg_homo = np.eye(4) 85 | 86 | # convert to homogeneous coordinate for faster computation 87 | # by simply adding 0, 0, 0, 1 as the last row 88 | pose_avg_homo[:3] = pose_avg 89 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4) 90 | 91 | # (N_images, 4, 4) homogeneous coordinate 92 | poses_homo = np.concatenate([poses, last_row], 1) 93 | 94 | poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo # (N_images, 4, 4) 95 | poses_centered = poses_centered @ blender2opencv 96 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4) 97 | 98 | return poses_centered, np.linalg.inv(pose_avg_homo) @ blender2opencv 99 | 100 | 101 | class LLFF_Dataset(Dataset): 102 | def __init__( 103 | self, 104 | root_dir, 105 | split, 106 | nb_views, 107 | downSample=1.0, 108 | max_len=-1, 109 | scene="None", 110 | imgs_folder_name="images", 111 | ): 112 | self.root_dir = root_dir 113 | self.split = split 114 | self.nb_views = nb_views 115 | self.scene = scene 116 | self.imgs_folder_name = imgs_folder_name 117 | 118 | self.downsample = downSample 119 | self.max_len = max_len 120 | self.img_wh = (int(960 * self.downsample), int(720 * self.downsample)) 121 | 122 | self.define_transforms() 123 | self.blender2opencv = np.array( 124 | [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] 125 | ) 126 | 127 | self.build_metas() 128 | 129 | def define_transforms(self): 130 | self.transform = T.Compose( 131 | [ 132 | T.ToTensor(), 133 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 134 | ] 135 | ) 136 | 137 | def build_metas(self): 138 | if self.scene != "None": 139 | self.scans = [ 140 | os.path.basename(scan_dir) 141 | for scan_dir in sorted( 142 | glob.glob(os.path.join(self.root_dir, self.scene)) 143 | ) 144 | ] 145 | else: 146 | self.scans = [ 147 | os.path.basename(scan_dir) 148 | for scan_dir in sorted(glob.glob(os.path.join(self.root_dir, "*"))) 149 | ] 150 | 151 | self.meta = [] 152 | self.image_paths = {} 153 | self.near_far = {} 154 | self.id_list = {} 155 | self.closest_idxs = {} 156 | self.c2ws = {} 157 | self.w2cs = {} 158 | self.intrinsics = {} 159 | self.affine_mats = {} 160 | self.affine_mats_inv = {} 161 | for scan in self.scans: 162 | self.image_paths[scan] = sorted( 163 | glob.glob(os.path.join(self.root_dir, scan, self.imgs_folder_name, "*")) 164 | ) 165 | poses_bounds = np.load( 166 | os.path.join(self.root_dir, scan, "poses_bounds.npy") 167 | ) # (N_images, 17) 168 | poses = poses_bounds[:, :15].reshape(-1, 3, 5) # (N_images, 3, 5) 169 | bounds = poses_bounds[:, -2:] # (N_images, 2) 170 | 171 | # Step 1: rescale focal length according to training resolution 172 | H, W, focal = poses[0, :, -1] # original intrinsics, same for all images 173 | 174 | focal = [focal * self.img_wh[0] / W, focal * self.img_wh[1] / H] 175 | 176 | # Step 2: correct poses 177 | poses = np.concatenate( 178 | [poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1 179 | ) 180 | poses, _ = center_poses(poses, self.blender2opencv) 181 | # poses = poses @ self.blender2opencv 182 | 183 | # Step 3: correct scale so that the nearest depth is at a little more than 1.0 184 | near_original = bounds.min() 185 | scale_factor = near_original * 0.75 # 0.75 is the default parameter 
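# Dividing the bounds and the camera centres (poses[..., 3]) by
# near_original * 0.75 places the nearest depth at 1 / 0.75 ≈ 1.33,
# i.e. "a little more than 1.0" as intended by Step 3 above.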
186 | bounds /= scale_factor 187 | poses[..., 3] /= scale_factor 188 | 189 | self.near_far[scan] = bounds.astype('float32') 190 | 191 | num_viewpoint = len(self.image_paths[scan]) 192 | val_ids = [idx for idx in range(0, num_viewpoint, 8)] 193 | w, h = self.img_wh 194 | 195 | self.id_list[scan] = [] 196 | self.closest_idxs[scan] = [] 197 | self.c2ws[scan] = [] 198 | self.w2cs[scan] = [] 199 | self.intrinsics[scan] = [] 200 | self.affine_mats[scan] = [] 201 | self.affine_mats_inv[scan] = [] 202 | for idx in range(num_viewpoint): 203 | if ( 204 | (self.split == "val" and idx in val_ids) 205 | or ( 206 | self.split == "train" 207 | and self.scene != "None" 208 | and idx not in val_ids 209 | ) 210 | or (self.split == "train" and self.scene == "None") 211 | ): 212 | self.meta.append({"scan": scan, "target_idx": idx}) 213 | 214 | view_ids = get_nearest_pose_ids( 215 | poses[idx, :, :], 216 | ref_poses=poses[..., :], 217 | num_select=self.nb_views + 1, 218 | angular_dist_method="dist", 219 | ) 220 | 221 | self.id_list[scan].append(view_ids) 222 | 223 | closest_idxs = [] 224 | source_views = view_ids[1:] 225 | for vid in source_views: 226 | closest_idxs.append( 227 | get_nearest_pose_ids( 228 | poses[vid, :, :], 229 | ref_poses=poses[source_views], 230 | num_select=5, 231 | angular_dist_method="dist", 232 | ) 233 | ) 234 | self.closest_idxs[scan].append(np.stack(closest_idxs, axis=0)) 235 | 236 | c2w = np.eye(4).astype('float32') 237 | c2w[:3] = poses[idx] 238 | w2c = np.linalg.inv(c2w) 239 | self.c2ws[scan].append(c2w) 240 | self.w2cs[scan].append(w2c) 241 | 242 | intrinsic = np.array([[focal[0], 0, w / 2], [0, focal[1], h / 2], [0, 0, 1]]).astype('float32') 243 | self.intrinsics[scan].append(intrinsic) 244 | 245 | def __len__(self): 246 | return len(self.meta) if self.max_len <= 0 else self.max_len 247 | 248 | def __getitem__(self, idx): 249 | if self.split == "train" and self.scene == "None": 250 | noisy_factor = float(np.random.choice([1.0, 0.75, 0.5], 1)) 251 | close_views = int(np.random.choice([3, 4, 5], 1)) 252 | else: 253 | noisy_factor = 1.0 254 | close_views = 5 255 | 256 | scan = self.meta[idx]["scan"] 257 | target_idx = self.meta[idx]["target_idx"] 258 | 259 | view_ids = self.id_list[scan][target_idx] 260 | target_view = view_ids[0] 261 | src_views = view_ids[1:] 262 | view_ids = [vid for vid in src_views] + [target_view] 263 | 264 | closest_idxs = self.closest_idxs[scan][target_idx][:, :close_views] 265 | 266 | imgs, depths, depths_h, depths_aug = [], [], [], [] 267 | intrinsics, w2cs, c2ws, near_fars = [], [], [], [] 268 | affine_mats, affine_mats_inv = [], [] 269 | 270 | w, h = self.img_wh 271 | w, h = int(w * noisy_factor), int(h * noisy_factor) 272 | 273 | for vid in view_ids: 274 | img_filename = self.image_paths[scan][vid] 275 | img = Image.open(img_filename).convert("RGB") 276 | if img.size != (w, h): 277 | img = img.resize((w, h), Image.BICUBIC) 278 | img = self.transform(img) 279 | imgs.append(img) 280 | 281 | intrinsic = self.intrinsics[scan][vid].copy() 282 | intrinsic[:2] = intrinsic[:2] * noisy_factor 283 | intrinsics.append(intrinsic) 284 | 285 | w2c = self.w2cs[scan][vid] 286 | w2cs.append(w2c) 287 | c2ws.append(self.c2ws[scan][vid]) 288 | 289 | aff = [] 290 | aff_inv = [] 291 | for l in range(3): 292 | proj_mat_l = np.eye(4) 293 | intrinsic_temp = intrinsic.copy() 294 | intrinsic_temp[:2] = intrinsic_temp[:2] / (2**l) 295 | proj_mat_l[:3, :4] = intrinsic_temp @ w2c[:3, :4] 296 | aff.append(proj_mat_l.copy()) 297 | aff_inv.append(np.linalg.inv(proj_mat_l)) 298 | aff = 
np.stack(aff, axis=-1) 299 | aff_inv = np.stack(aff_inv, axis=-1) 300 | 301 | affine_mats.append(aff) 302 | affine_mats_inv.append(aff_inv) 303 | 304 | near_fars.append(self.near_far[scan][vid]) 305 | 306 | depths_h.append(np.zeros([h, w])) 307 | depths.append(np.zeros([h // 4, w // 4])) 308 | depths_aug.append(np.zeros([h // 4, w // 4])) 309 | 310 | imgs = np.stack(imgs) 311 | depths = np.stack(depths) 312 | depths_h = np.stack(depths_h) 313 | depths_aug = np.stack(depths_aug) 314 | affine_mats = np.stack(affine_mats) 315 | affine_mats_inv = np.stack(affine_mats_inv) 316 | intrinsics = np.stack(intrinsics) 317 | w2cs = np.stack(w2cs) 318 | c2ws = np.stack(c2ws) 319 | near_fars = np.stack(near_fars) 320 | 321 | sample = {} 322 | sample["images"] = imgs 323 | sample["depths"] = depths 324 | sample["depths_h"] = depths_h 325 | sample["depths_aug"] = depths_aug 326 | sample["w2cs"] = w2cs 327 | sample["c2ws"] = c2ws 328 | sample["near_fars"] = near_fars 329 | sample["affine_mats"] = affine_mats 330 | sample["affine_mats_inv"] = affine_mats_inv 331 | sample["intrinsics"] = intrinsics 332 | sample["closest_idxs"] = closest_idxs 333 | 334 | return sample 335 | -------------------------------------------------------------------------------- /data/nerf.py: -------------------------------------------------------------------------------- 1 | # GeoNeRF is a generalizable NeRF model that renders novel views 2 | # without requiring per-scene optimization. This software is the 3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with 4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin, 5 | # and Francois Fleuret. 6 | 7 | # Copyright (c) 2022 ams International AG 8 | 9 | # This file is part of GeoNeRF. 10 | # GeoNeRF is free software: you can redistribute it and/or modify 11 | # it under the terms of the GNU General Public License version 3 as 12 | # published by the Free Software Foundation. 13 | 14 | # GeoNeRF is distributed in the hope that it will be useful, 15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 | # GNU General Public License for more details. 18 | 19 | # You should have received a copy of the GNU General Public License 20 | # along with GeoNeRF. If not, see . 21 | 22 | # This file incorporates work covered by the following copyright and 23 | # permission notice: 24 | 25 | # MIT License 26 | 27 | # Copyright (c) 2021 apchenstu 28 | 29 | # Permission is hereby granted, free of charge, to any person obtaining a copy 30 | # of this software and associated documentation files (the "Software"), to deal 31 | # in the Software without restriction, including without limitation the rights 32 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 33 | # copies of the Software, and to permit persons to whom the Software is 34 | # furnished to do so, subject to the following conditions: 35 | 36 | # The above copyright notice and this permission notice shall be included in all 37 | # copies or substantial portions of the Software. 38 | 39 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 40 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 41 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 42 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 43 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 44 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 45 | # SOFTWARE. 46 | 47 | from torch.utils.data import Dataset 48 | from torchvision import transforms as T 49 | 50 | import os 51 | import json 52 | import numpy as np 53 | from PIL import Image 54 | 55 | from utils.utils import get_nearest_pose_ids 56 | 57 | class NeRF_Dataset(Dataset): 58 | def __init__( 59 | self, 60 | root_dir, 61 | split, 62 | nb_views, 63 | downSample=1.0, 64 | max_len=-1, 65 | scene="None", 66 | ): 67 | self.root_dir = root_dir 68 | self.split = split 69 | self.nb_views = nb_views 70 | self.scene = scene 71 | 72 | self.downsample = downSample 73 | self.max_len = max_len 74 | 75 | self.img_wh = (int(800 * self.downsample), int(800 * self.downsample)) 76 | 77 | self.define_transforms() 78 | self.blender2opencv = np.array( 79 | [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] 80 | ) 81 | 82 | self.build_metas() 83 | 84 | def define_transforms(self): 85 | self.transform = T.ToTensor() 86 | 87 | self.src_transform = T.Compose( 88 | [ 89 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 90 | ] 91 | ) 92 | 93 | def build_metas(self): 94 | self.meta = {} 95 | with open( 96 | os.path.join(self.root_dir, self.scene, "transforms_train.json"), "r" 97 | ) as f: 98 | self.meta["train"] = json.load(f) 99 | 100 | with open( 101 | os.path.join(self.root_dir, self.scene, "transforms_test.json"), "r" 102 | ) as f: 103 | self.meta["val"] = json.load(f) 104 | 105 | w, h = self.img_wh 106 | 107 | # original focal length 108 | focal = 0.5 * 800 / np.tan(0.5 * self.meta["train"]["camera_angle_x"]) 109 | 110 | # modify focal length to match size self.img_wh 111 | focal *= self.img_wh[0] / 800 112 | 113 | self.near_far = np.array([2.0, 6.0]) 114 | 115 | self.image_paths = {"train": [], "val": []} 116 | self.c2ws = {"train": [], "val": []} 117 | self.w2cs = {"train": [], "val": []} 118 | self.intrinsics = {"train": [], "val": []} 119 | 120 | for frame in self.meta["train"]["frames"]: 121 | self.image_paths["train"].append( 122 | os.path.join(self.root_dir, self.scene, f"{frame['file_path']}.png") 123 | ) 124 | 125 | c2w = np.array(frame["transform_matrix"]) @ self.blender2opencv 126 | w2c = np.linalg.inv(c2w) 127 | self.c2ws["train"].append(c2w) 128 | self.w2cs["train"].append(w2c) 129 | 130 | intrinsic = np.array([[focal, 0, w / 2], [0, focal, h / 2], [0, 0, 1]]) 131 | self.intrinsics["train"].append(intrinsic.copy()) 132 | 133 | self.c2ws["train"] = np.stack(self.c2ws["train"], axis=0) 134 | self.w2cs["train"] = np.stack(self.w2cs["train"], axis=0) 135 | self.intrinsics["train"] = np.stack(self.intrinsics["train"], axis=0) 136 | 137 | for frame in self.meta["val"]["frames"]: 138 | self.image_paths["val"].append( 139 | os.path.join(self.root_dir, self.scene, f"{frame['file_path']}.png") 140 | ) 141 | 142 | c2w = np.array(frame["transform_matrix"]) @ self.blender2opencv 143 | w2c = np.linalg.inv(c2w) 144 | self.c2ws["val"].append(c2w) 145 | self.w2cs["val"].append(w2c) 146 | 147 | intrinsic = np.array([[focal, 0, w / 2], [0, focal, h / 2], [0, 0, 1]]) 148 | self.intrinsics["val"].append(intrinsic.copy()) 149 | 150 | self.c2ws["val"] = np.stack(self.c2ws["val"], axis=0) 151 | self.w2cs["val"] = np.stack(self.w2cs["val"], axis=0) 152 | self.intrinsics["val"] = np.stack(self.intrinsics["val"], 
axis=0) 153 | 154 | def __len__(self): 155 | return len(self.image_paths[self.split]) if self.max_len <= 0 else self.max_len 156 | 157 | def __getitem__(self, idx): 158 | target_frame = self.meta[self.split]["frames"][idx] 159 | c2w = np.array(target_frame["transform_matrix"]) @ self.blender2opencv 160 | w2c = np.linalg.inv(c2w) 161 | 162 | if self.split == "train": 163 | src_views = get_nearest_pose_ids( 164 | c2w, 165 | ref_poses=self.c2ws["train"], 166 | num_select=self.nb_views + 1, 167 | angular_dist_method="dist", 168 | )[1:] 169 | else: 170 | src_views = get_nearest_pose_ids( 171 | c2w, 172 | ref_poses=self.c2ws["train"], 173 | num_select=self.nb_views, 174 | angular_dist_method="dist", 175 | ) 176 | 177 | imgs, depths, depths_h, depths_aug = [], [], [], [] 178 | intrinsics, w2cs, c2ws, near_fars = [], [], [], [] 179 | affine_mats, affine_mats_inv = [], [] 180 | 181 | w, h = self.img_wh 182 | 183 | for vid in src_views: 184 | img_filename = self.image_paths["train"][vid] 185 | img = Image.open(img_filename) 186 | if img.size != (w, h): 187 | img = img.resize((w, h), Image.BICUBIC) 188 | 189 | img = self.transform(img) 190 | img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB 191 | imgs.append(self.src_transform(img)) 192 | 193 | intrinsic = self.intrinsics["train"][vid] 194 | intrinsics.append(intrinsic) 195 | 196 | w2c = self.w2cs["train"][vid] 197 | w2cs.append(w2c) 198 | c2ws.append(self.c2ws["train"][vid]) 199 | 200 | aff = [] 201 | aff_inv = [] 202 | for l in range(3): 203 | proj_mat_l = np.eye(4) 204 | intrinsic_temp = intrinsic.copy() 205 | intrinsic_temp[:2] = intrinsic_temp[:2] / (2**l) 206 | proj_mat_l[:3, :4] = intrinsic_temp @ w2c[:3, :4] 207 | aff.append(proj_mat_l.copy()) 208 | aff_inv.append(np.linalg.inv(proj_mat_l)) 209 | aff = np.stack(aff, axis=-1) 210 | aff_inv = np.stack(aff_inv, axis=-1) 211 | 212 | affine_mats.append(aff) 213 | affine_mats_inv.append(aff_inv) 214 | 215 | near_fars.append(self.near_far) 216 | 217 | depths_h.append(np.zeros([h, w])) 218 | depths.append(np.zeros([h // 4, w // 4])) 219 | depths_aug.append(np.zeros([h // 4, w // 4])) 220 | 221 | ## Adding target data 222 | img_filename = self.image_paths[self.split][idx] 223 | img = Image.open(img_filename) 224 | if img.size != (w, h): 225 | img = img.resize((w, h), Image.BICUBIC) 226 | 227 | img = self.transform(img) # (4, h, w) 228 | img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB 229 | imgs.append(self.src_transform(img)) 230 | 231 | intrinsic = self.intrinsics[self.split][idx] 232 | intrinsics.append(intrinsic) 233 | 234 | w2c = self.w2cs[self.split][idx] 235 | w2cs.append(w2c) 236 | c2ws.append(self.c2ws[self.split][idx]) 237 | 238 | near_fars.append(self.near_far) 239 | 240 | depths_h.append(np.zeros([h, w])) 241 | depths.append(np.zeros([h // 4, w // 4])) 242 | depths_aug.append(np.zeros([h // 4, w // 4])) 243 | 244 | ## Stacking 245 | imgs = np.stack(imgs) 246 | depths = np.stack(depths) 247 | depths_h = np.stack(depths_h) 248 | depths_aug = np.stack(depths_aug) 249 | affine_mats = np.stack(affine_mats) 250 | affine_mats_inv = np.stack(affine_mats_inv) 251 | intrinsics = np.stack(intrinsics) 252 | w2cs = np.stack(w2cs) 253 | c2ws = np.stack(c2ws) 254 | near_fars = np.stack(near_fars) 255 | 256 | closest_idxs = [] 257 | for pose in c2ws[:-1]: 258 | closest_idxs.append( 259 | get_nearest_pose_ids( 260 | pose, ref_poses=c2ws[:-1], num_select=5, angular_dist_method="dist" 261 | ) 262 | ) 263 | closest_idxs = np.stack(closest_idxs, axis=0) 264 | 265 | sample = {} 266 | 
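# As in the DTU and LLFF loaders, the returned views are ordered with the
# source views first and the target view last; the depth entries were filled
# with zeros above, so these synthetic scenes carry no ground-truth depth.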
sample["images"] = imgs 267 | sample["depths"] = depths 268 | sample["depths_h"] = depths_h 269 | sample["depths_aug"] = depths_aug 270 | sample["w2cs"] = w2cs.astype("float32") 271 | sample["c2ws"] = c2ws.astype("float32") 272 | sample["near_fars"] = near_fars 273 | sample["affine_mats"] = affine_mats 274 | sample["affine_mats_inv"] = affine_mats_inv 275 | sample["intrinsics"] = intrinsics.astype("float32") 276 | sample["closest_idxs"] = closest_idxs 277 | 278 | return sample 279 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/GeoNeRF/e6249fdae5672853c6bbbd4ba380c4c166d02c95/model/__init__.py -------------------------------------------------------------------------------- /model/geo_reasoner.py: -------------------------------------------------------------------------------- 1 | # GeoNeRF is a generalizable NeRF model that renders novel views 2 | # without requiring per-scene optimization. This software is the 3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with 4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin, 5 | # and Francois Fleuret. 6 | 7 | # Copyright (c) 2022 ams International AG 8 | 9 | # This file is part of GeoNeRF. 10 | # GeoNeRF is free software: you can redistribute it and/or modify 11 | # it under the terms of the GNU General Public License version 3 as 12 | # published by the Free Software Foundation. 13 | 14 | # GeoNeRF is distributed in the hope that it will be useful, 15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 | # GNU General Public License for more details. 18 | 19 | # You should have received a copy of the GNU General Public License 20 | # along with GeoNeRF. If not, see . 21 | 22 | # This file incorporates work covered by the following copyright and 23 | # permission notice: 24 | 25 | # Copyright (c) 2020 AI葵 26 | 27 | # This file is part of CasMVSNet_pl. 28 | # CasMVSNet_pl is free software: you can redistribute it and/or modify 29 | # it under the terms of the GNU General Public License version 3 as 30 | # published by the Free Software Foundation. 31 | 32 | # CasMVSNet_pl is distributed in the hope that it will be useful, 33 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 34 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 35 | # GNU General Public License for more details. 36 | 37 | # You should have received a copy of the GNU General Public License 38 | # along with CasMVSNet_pl. If not, see . 
39 | 40 | import torch 41 | import torch.nn as nn 42 | import torch.nn.functional as F 43 | from torch.utils.checkpoint import checkpoint 44 | 45 | from utils.utils import homo_warp 46 | from inplace_abn import InPlaceABN 47 | 48 | 49 | def get_depth_values(current_depth, n_depths, depth_interval): 50 | depth_min = torch.clamp_min(current_depth - n_depths / 2 * depth_interval, 1e-7) 51 | depth_values = ( 52 | depth_min 53 | + depth_interval 54 | * torch.arange( 55 | 0, n_depths, device=current_depth.device, dtype=current_depth.dtype 56 | )[None, :, None, None] 57 | ) 58 | return depth_values 59 | 60 | 61 | class ConvBnReLU(nn.Module): 62 | def __init__( 63 | self, 64 | in_channels, 65 | out_channels, 66 | kernel_size=3, 67 | stride=1, 68 | pad=1, 69 | norm_act=InPlaceABN, 70 | ): 71 | super(ConvBnReLU, self).__init__() 72 | self.conv = nn.Conv2d( 73 | in_channels, 74 | out_channels, 75 | kernel_size, 76 | stride=stride, 77 | padding=pad, 78 | bias=False, 79 | ) 80 | self.bn = norm_act(out_channels) 81 | 82 | def forward(self, x): 83 | return self.bn(self.conv(x)) 84 | 85 | 86 | class ConvBnReLU3D(nn.Module): 87 | def __init__( 88 | self, 89 | in_channels, 90 | out_channels, 91 | kernel_size=3, 92 | stride=1, 93 | pad=1, 94 | norm_act=InPlaceABN, 95 | ): 96 | super(ConvBnReLU3D, self).__init__() 97 | self.conv = nn.Conv3d( 98 | in_channels, 99 | out_channels, 100 | kernel_size, 101 | stride=stride, 102 | padding=pad, 103 | bias=False, 104 | ) 105 | self.bn = norm_act(out_channels) 106 | 107 | def forward(self, x): 108 | return self.bn(self.conv(x)) 109 | 110 | 111 | class FeatureNet(nn.Module): 112 | def __init__(self, norm_act=InPlaceABN): 113 | super(FeatureNet, self).__init__() 114 | 115 | self.conv0 = nn.Sequential( 116 | ConvBnReLU(3, 8, 3, 1, 1, norm_act=norm_act), 117 | ConvBnReLU(8, 8, 3, 1, 1, norm_act=norm_act), 118 | ) 119 | 120 | self.conv1 = nn.Sequential( 121 | ConvBnReLU(8, 16, 5, 2, 2, norm_act=norm_act), 122 | ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act), 123 | ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act), 124 | ) 125 | 126 | self.conv2 = nn.Sequential( 127 | ConvBnReLU(16, 32, 5, 2, 2, norm_act=norm_act), 128 | ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act), 129 | ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act), 130 | ) 131 | 132 | self.toplayer = nn.Conv2d(32, 32, 1) 133 | self.lat1 = nn.Conv2d(16, 32, 1) 134 | self.lat0 = nn.Conv2d(8, 32, 1) 135 | 136 | # to reduce channel size of the outputs from FPN 137 | self.smooth1 = nn.Conv2d(32, 16, 3, padding=1) 138 | self.smooth0 = nn.Conv2d(32, 8, 3, padding=1) 139 | 140 | def _upsample_add(self, x, y): 141 | return F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) + y 142 | 143 | def forward(self, x, dummy=None): 144 | # x: (B, 3, H, W) 145 | conv0 = self.conv0(x) # (B, 8, H, W) 146 | conv1 = self.conv1(conv0) # (B, 16, H//2, W//2) 147 | conv2 = self.conv2(conv1) # (B, 32, H//4, W//4) 148 | feat2 = self.toplayer(conv2) # (B, 32, H//4, W//4) 149 | feat1 = self._upsample_add(feat2, self.lat1(conv1)) # (B, 32, H//2, W//2) 150 | feat0 = self._upsample_add(feat1, self.lat0(conv0)) # (B, 32, H, W) 151 | 152 | # reduce output channels 153 | feat1 = self.smooth1(feat1) # (B, 16, H//2, W//2) 154 | feat0 = self.smooth0(feat0) # (B, 8, H, W) 155 | 156 | feats = {"level_0": feat0, "level_1": feat1, "level_2": feat2} 157 | 158 | return feats 159 | 160 | 161 | class CostRegNet(nn.Module): 162 | def __init__(self, in_channels, norm_act=InPlaceABN): 163 | super(CostRegNet, self).__init__() 164 | self.conv0 = 
ConvBnReLU3D(in_channels, 8, norm_act=norm_act) 165 | 166 | self.conv1 = ConvBnReLU3D(8, 16, stride=2, norm_act=norm_act) 167 | self.conv2 = ConvBnReLU3D(16, 16, norm_act=norm_act) 168 | 169 | self.conv3 = ConvBnReLU3D(16, 32, stride=2, norm_act=norm_act) 170 | self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act) 171 | 172 | self.conv5 = ConvBnReLU3D(32, 64, stride=2, norm_act=norm_act) 173 | self.conv6 = ConvBnReLU3D(64, 64, norm_act=norm_act) 174 | 175 | self.conv7 = nn.Sequential( 176 | nn.ConvTranspose3d( 177 | 64, 32, 3, padding=1, output_padding=1, stride=2, bias=False 178 | ), 179 | norm_act(32), 180 | ) 181 | 182 | self.conv9 = nn.Sequential( 183 | nn.ConvTranspose3d( 184 | 32, 16, 3, padding=1, output_padding=1, stride=2, bias=False 185 | ), 186 | norm_act(16), 187 | ) 188 | 189 | self.conv11 = nn.Sequential( 190 | nn.ConvTranspose3d( 191 | 16, 8, 3, padding=1, output_padding=1, stride=2, bias=False 192 | ), 193 | norm_act(8), 194 | ) 195 | 196 | self.br1 = ConvBnReLU3D(8, 8, norm_act=norm_act) 197 | self.br2 = ConvBnReLU3D(8, 8, norm_act=norm_act) 198 | 199 | self.prob = nn.Conv3d(8, 1, 3, stride=1, padding=1) 200 | 201 | def forward(self, x): 202 | if x.shape[-2] % 8 != 0 or x.shape[-1] % 8 != 0: 203 | pad_h = 8 * (x.shape[-2] // 8 + 1) - x.shape[-2] 204 | pad_w = 8 * (x.shape[-1] // 8 + 1) - x.shape[-1] 205 | x = F.pad(x, (0, pad_w, 0, pad_h), mode="constant", value=0) 206 | else: 207 | pad_h = 0 208 | pad_w = 0 209 | 210 | conv0 = self.conv0(x) 211 | conv2 = self.conv2(self.conv1(conv0)) 212 | conv4 = self.conv4(self.conv3(conv2)) 213 | 214 | x = self.conv6(self.conv5(conv4)) 215 | x = conv4 + self.conv7(x) 216 | del conv4 217 | x = conv2 + self.conv9(x) 218 | del conv2 219 | x = conv0 + self.conv11(x) 220 | del conv0 221 | #################### 222 | x1 = self.br1(x) 223 | with torch.enable_grad(): 224 | x2 = self.br2(x) 225 | #################### 226 | p = self.prob(x1) 227 | 228 | if pad_h > 0 or pad_w > 0: 229 | x2 = x2[..., :-pad_h, :-pad_w] 230 | p = p[..., :-pad_h, :-pad_w] 231 | 232 | return x2, p 233 | 234 | 235 | class CasMVSNet(nn.Module): 236 | def __init__(self, num_groups=8, norm_act=InPlaceABN, levels=3, use_depth=False): 237 | super(CasMVSNet, self).__init__() 238 | self.levels = levels # 3 depth levels 239 | self.n_depths = [8, 32, 48] 240 | self.interval_ratios = [1, 2, 4] 241 | self.use_depth = use_depth 242 | 243 | self.G = num_groups # number of groups in groupwise correlation 244 | self.feature = FeatureNet() 245 | 246 | for l in range(self.levels): 247 | if l == self.levels - 1 and self.use_depth: 248 | cost_reg_l = CostRegNet(self.G + 1, norm_act) 249 | else: 250 | cost_reg_l = CostRegNet(self.G, norm_act) 251 | 252 | setattr(self, f"cost_reg_{l}", cost_reg_l) 253 | 254 | def build_cost_volumes(self, feats, affine_mats, affine_mats_inv, depth_values, idx, spikes): 255 | B, V, C, H, W = feats.shape 256 | D = depth_values.shape[1] 257 | 258 | ref_feats, src_feats = feats[:, idx[0]], feats[:, idx[1:]] 259 | src_feats = src_feats.permute(1, 0, 2, 3, 4) # (V-1, B, C, h, w) 260 | 261 | affine_mats_inv = affine_mats_inv[:, idx[0]] 262 | affine_mats = affine_mats[:, idx[1:]] 263 | affine_mats = affine_mats.permute(1, 0, 2, 3) # (V-1, B, 3, 4) 264 | 265 | ref_volume = ref_feats.unsqueeze(2).repeat(1, 1, D, 1, 1) # (B, C, D, h, w) 266 | 267 | ref_volume = ref_volume.view(B, self.G, C // self.G, *ref_volume.shape[-3:]) 268 | volume_sum = 0 269 | 270 | for i in range(len(idx) - 1): 271 | proj_mat = (affine_mats[i].double() @ affine_mats_inv.double()).float()[ 272 | 
:, :3 273 | ] 274 | warped_volume, grid = homo_warp(src_feats[i], proj_mat, depth_values) 275 | 276 | warped_volume = warped_volume.view_as(ref_volume) 277 | volume_sum = volume_sum + warped_volume # (B, G, C//G, D, h, w) 278 | 279 | volume = (volume_sum * ref_volume).mean(dim=2) / (V - 1) 280 | 281 | if spikes is None: 282 | output = volume 283 | else: 284 | output = torch.cat([volume, spikes], dim=1) 285 | 286 | return output 287 | 288 | def create_neural_volume( 289 | self, 290 | feats, 291 | affine_mats, 292 | affine_mats_inv, 293 | idx, 294 | init_depth_min, 295 | depth_interval, 296 | gt_depths, 297 | ): 298 | if feats["level_0"].shape[-1] >= 800: 299 | hres_input = True 300 | else: 301 | hres_input = False 302 | 303 | B, V = affine_mats.shape[:2] 304 | 305 | v_feat = {} 306 | depth_maps = {} 307 | depth_values = {} 308 | for l in reversed(range(self.levels)): # (2, 1, 0) 309 | feats_l = feats[f"level_{l}"] # (B*V, C, h, w) 310 | feats_l = feats_l.view(B, V, *feats_l.shape[1:]) # (B, V, C, h, w) 311 | h, w = feats_l.shape[-2:] 312 | depth_interval_l = depth_interval * self.interval_ratios[l] 313 | D = self.n_depths[l] 314 | if l == self.levels - 1: # coarsest level 315 | depth_values_l = init_depth_min + depth_interval_l * torch.arange( 316 | 0, D, device=feats_l.device, dtype=feats_l.dtype 317 | ) # (D) 318 | depth_values_l = depth_values_l[None, :, None, None].expand( 319 | -1, -1, h, w 320 | ) 321 | 322 | if self.use_depth: 323 | gt_mask = gt_depths > 0 324 | sp_idx_float = ( 325 | gt_mask * (gt_depths - init_depth_min) / (depth_interval_l) 326 | )[:, :, None] 327 | spikes = ( 328 | torch.arange(D).view(1, 1, -1, 1, 1).cuda() 329 | == sp_idx_float.floor().long() 330 | ) * (1 - sp_idx_float.frac()) 331 | spikes = spikes + ( 332 | torch.arange(D).view(1, 1, -1, 1, 1).cuda() 333 | == sp_idx_float.ceil().long() 334 | ) * (sp_idx_float.frac()) 335 | spikes = (spikes * gt_mask[:, :, None]).float() 336 | else: 337 | depth_lm1 = depth_l.detach() # the depth of previous level 338 | depth_lm1 = F.interpolate( 339 | depth_lm1, scale_factor=2, mode="bilinear", align_corners=True 340 | ) # (B, 1, h, w) 341 | depth_values_l = get_depth_values(depth_lm1, D, depth_interval_l) 342 | 343 | affine_mats_l = affine_mats[..., l] 344 | affine_mats_inv_l = affine_mats_inv[..., l] 345 | 346 | if l == self.levels - 1 and self.use_depth: 347 | spikes_ = spikes 348 | else: 349 | spikes_ = None 350 | 351 | if hres_input: 352 | v_feat_l = checkpoint( 353 | self.build_cost_volumes, 354 | feats_l, 355 | affine_mats_l, 356 | affine_mats_inv_l, 357 | depth_values_l, 358 | idx, 359 | spikes_, 360 | preserve_rng_state=False, 361 | ) 362 | else: 363 | v_feat_l = self.build_cost_volumes( 364 | feats_l, 365 | affine_mats_l, 366 | affine_mats_inv_l, 367 | depth_values_l, 368 | idx, 369 | spikes_, 370 | ) 371 | 372 | cost_reg_l = getattr(self, f"cost_reg_{l}") 373 | v_feat_l, depth_prob = cost_reg_l(v_feat_l) # (B, 1, D, h, w) 374 | 375 | depth_l = (F.softmax(depth_prob, dim=2) * depth_values_l[:, None]).sum( 376 | dim=2 377 | ) 378 | 379 | v_feat[f"level_{l}"] = v_feat_l 380 | depth_maps[f"level_{l}"] = depth_l 381 | depth_values[f"level_{l}"] = depth_values_l 382 | 383 | return v_feat, depth_maps, depth_values 384 | 385 | def forward( 386 | self, imgs, affine_mats, affine_mats_inv, near_far, closest_idxs, gt_depths=None 387 | ): 388 | B, V, _, H, W = imgs.shape 389 | 390 | ## Feature Pyramid 391 | feats = self.feature( 392 | imgs.reshape(B * V, 3, H, W) 393 | ) # (B*V, 8, H, W), (B*V, 16, H//2, W//2), (B*V, 32, H//4, 
W//4) 394 | feats_fpn = feats[f"level_0"].reshape(B, V, *feats[f"level_0"].shape[1:]) 395 | 396 | feats_vol = {"level_0": [], "level_1": [], "level_2": []} 397 | depth_map = {"level_0": [], "level_1": [], "level_2": []} 398 | depth_values = {"level_0": [], "level_1": [], "level_2": []} 399 | ## Create cost volumes for each view 400 | for i in range(0, V): 401 | permuted_idx = torch.tensor(closest_idxs[0, i]).cuda() 402 | 403 | init_depth_min = near_far[0, i, 0] 404 | depth_interval = ( 405 | (near_far[0, i, 1] - near_far[0, i, 0]) 406 | / self.n_depths[-1] 407 | / self.interval_ratios[-1] 408 | ) 409 | 410 | v_feat, d_map, d_values = self.create_neural_volume( 411 | feats, 412 | affine_mats, 413 | affine_mats_inv, 414 | idx=permuted_idx, 415 | init_depth_min=init_depth_min, 416 | depth_interval=depth_interval, 417 | gt_depths=gt_depths[:, i : i + 1], 418 | ) 419 | 420 | for l in range(3): 421 | feats_vol[f"level_{l}"].append(v_feat[f"level_{l}"]) 422 | depth_map[f"level_{l}"].append(d_map[f"level_{l}"]) 423 | depth_values[f"level_{l}"].append(d_values[f"level_{l}"]) 424 | 425 | for l in range(3): 426 | feats_vol[f"level_{l}"] = torch.stack(feats_vol[f"level_{l}"], dim=1) 427 | depth_map[f"level_{l}"] = torch.cat(depth_map[f"level_{l}"], dim=1) 428 | depth_values[f"level_{l}"] = torch.stack(depth_values[f"level_{l}"], dim=1) 429 | 430 | return feats_vol, feats_fpn, depth_map, depth_values 431 | -------------------------------------------------------------------------------- /model/self_attn_renderer.py: -------------------------------------------------------------------------------- 1 | # GeoNeRF is a generalizable NeRF model that renders novel views 2 | # without requiring per-scene optimization. This software is the 3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with 4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin, 5 | # and Francois Fleuret. 6 | 7 | # Copyright (c) 2022 ams International AG 8 | 9 | # This file is part of GeoNeRF. 10 | # GeoNeRF is free software: you can redistribute it and/or modify 11 | # it under the terms of the GNU General Public License version 3 as 12 | # published by the Free Software Foundation. 13 | 14 | # GeoNeRF is distributed in the hope that it will be useful, 15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 | # GNU General Public License for more details. 18 | 19 | # You should have received a copy of the GNU General Public License 20 | # along with GeoNeRF. If not, see . 21 | 22 | # This file incorporates work covered by the following copyright and 23 | # permission notice: 24 | 25 | # Copyright 2020 Google LLC 26 | # 27 | # Licensed under the Apache License, Version 2.0 (the "License"); 28 | # you may not use this file except in compliance with the License. 29 | # You may obtain a copy of the License at 30 | # 31 | # https://www.apache.org/licenses/LICENSE-2.0 32 | # 33 | # Unless required by applicable law or agreed to in writing, software 34 | # distributed under the License is distributed on an "AS IS" BASIS, 35 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 36 | # See the License for the specific language governing permissions and 37 | # limitations under the License. 
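# The Renderer below builds, for every ray sample, one token per source view
# from its sampled features and visibility flag, appends a view-independent
# token from the mean and variance of the 2D features, runs four self-attention
# layers over these tokens, and then predicts sigma with a 1D auto-encoder
# along the ray and RGB as a masked-softmax blend of the source-view colors.
# masked_softmax fills masked entries with -inf before the softmax so they get
# exactly zero weight; a small worked example (hypothetical tensors, values
# rounded):
#
#     x    = torch.tensor([[1.0, 2.0, 3.0]])
#     mask = torch.tensor([[1, 0, 1]])
#     torch.softmax(x.masked_fill(mask == 0, -float("inf")), dim=-1)
#     # -> [[0.1192, 0.0000, 0.8808]]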
38 | 39 | import torch 40 | import torch.nn as nn 41 | import torch.nn.functional as F 42 | 43 | import math 44 | 45 | def weights_init(m): 46 | if isinstance(m, nn.Linear): 47 | stdv = 1.0 / math.sqrt(m.weight.size(1)) 48 | m.weight.data.uniform_(-stdv, stdv) 49 | if m.bias is not None: 50 | m.bias.data.uniform_(stdv, stdv) 51 | 52 | 53 | def masked_softmax(x, mask, **kwargs): 54 | x_masked = x.masked_fill(mask == 0, -float("inf")) 55 | 56 | return torch.softmax(x_masked, **kwargs) 57 | 58 | 59 | ## Auto-encoder network 60 | class ConvAutoEncoder(nn.Module): 61 | def __init__(self, num_ch, S): 62 | super(ConvAutoEncoder, self).__init__() 63 | 64 | # Encoder 65 | self.conv1 = nn.Sequential( 66 | nn.Conv1d(num_ch, num_ch * 2, 3, stride=1, padding=1), 67 | nn.LayerNorm(S, elementwise_affine=False), 68 | nn.ELU(alpha=1.0, inplace=True), 69 | nn.MaxPool1d(2), 70 | ) 71 | self.conv2 = nn.Sequential( 72 | nn.Conv1d(num_ch * 2, num_ch * 4, 3, stride=1, padding=1), 73 | nn.LayerNorm(S // 2, elementwise_affine=False), 74 | nn.ELU(alpha=1.0, inplace=True), 75 | nn.MaxPool1d(2), 76 | ) 77 | self.conv3 = nn.Sequential( 78 | nn.Conv1d(num_ch * 4, num_ch * 4, 3, stride=1, padding=1), 79 | nn.LayerNorm(S // 4, elementwise_affine=False), 80 | nn.ELU(alpha=1.0, inplace=True), 81 | nn.MaxPool1d(2), 82 | ) 83 | 84 | # Decoder 85 | self.t_conv1 = nn.Sequential( 86 | nn.ConvTranspose1d(num_ch * 4, num_ch * 4, 4, stride=2, padding=1), 87 | nn.LayerNorm(S // 4, elementwise_affine=False), 88 | nn.ELU(alpha=1.0, inplace=True), 89 | ) 90 | self.t_conv2 = nn.Sequential( 91 | nn.ConvTranspose1d(num_ch * 8, num_ch * 2, 4, stride=2, padding=1), 92 | nn.LayerNorm(S // 2, elementwise_affine=False), 93 | nn.ELU(alpha=1.0, inplace=True), 94 | ) 95 | self.t_conv3 = nn.Sequential( 96 | nn.ConvTranspose1d(num_ch * 4, num_ch, 4, stride=2, padding=1), 97 | nn.LayerNorm(S, elementwise_affine=False), 98 | nn.ELU(alpha=1.0, inplace=True), 99 | ) 100 | # Output 101 | self.conv_out = nn.Sequential( 102 | nn.Conv1d(num_ch * 2, num_ch, 3, stride=1, padding=1), 103 | nn.LayerNorm(S, elementwise_affine=False), 104 | nn.ELU(alpha=1.0, inplace=True), 105 | ) 106 | 107 | def forward(self, x): 108 | input = x 109 | x = self.conv1(x) 110 | conv1_out = x 111 | x = self.conv2(x) 112 | conv2_out = x 113 | x = self.conv3(x) 114 | 115 | x = self.t_conv1(x) 116 | x = self.t_conv2(torch.cat([x, conv2_out], dim=1)) 117 | x = self.t_conv3(torch.cat([x, conv1_out], dim=1)) 118 | 119 | x = self.conv_out(torch.cat([x, input], dim=1)) 120 | 121 | return x 122 | 123 | 124 | class ScaledDotProductAttention(nn.Module): 125 | def __init__(self, temperature, attn_dropout=0.1): 126 | super().__init__() 127 | self.temperature = temperature 128 | 129 | def forward(self, q, k, v, mask=None): 130 | 131 | attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) 132 | 133 | if mask is not None: 134 | attn = masked_softmax(attn, mask, dim=-1) 135 | else: 136 | attn = F.softmax(attn, dim=-1) 137 | 138 | output = torch.matmul(attn, v) 139 | 140 | return output, attn 141 | 142 | 143 | class PositionwiseFeedForward(nn.Module): 144 | def __init__(self, d_in, d_hid, dropout=0.1): 145 | super().__init__() 146 | self.w_1 = nn.Linear(d_in, d_hid) # position-wise 147 | self.w_2 = nn.Linear(d_hid, d_in) # position-wise 148 | self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) 149 | 150 | def forward(self, x): 151 | 152 | residual = x 153 | 154 | x = self.w_2(F.relu(self.w_1(x))) 155 | x += residual 156 | 157 | x = self.layer_norm(x) 158 | 159 | return x 160 | 161 | 162 | class 
MultiHeadAttention(nn.Module): 163 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 164 | super().__init__() 165 | 166 | self.n_head = n_head 167 | self.d_k = d_k 168 | self.d_v = d_v 169 | 170 | self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) 171 | self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) 172 | self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False) 173 | self.fc = nn.Linear(n_head * d_v, d_model, bias=False) 174 | 175 | self.attention = ScaledDotProductAttention(temperature=d_k**0.5) 176 | 177 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 178 | 179 | def forward(self, q, k, v, mask=None): 180 | 181 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 182 | sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) 183 | 184 | residual = q 185 | 186 | # Pass through the pre-attention projection: b x lq x (n*dv) 187 | # Separate different heads: b x lq x n x dv 188 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 189 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 190 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 191 | 192 | # Transpose for attention dot product: b x n x lq x dv 193 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) 194 | 195 | if mask is not None: 196 | mask = mask.transpose(1, 2).unsqueeze(1) # For head axis broadcasting. 197 | 198 | q, attn = self.attention(q, k, v, mask=mask) 199 | 200 | # Transpose to move the head dimension back: b x lq x n x dv 201 | # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv) 202 | q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1) 203 | q = self.fc(q) 204 | q += residual 205 | 206 | q = self.layer_norm(q) 207 | 208 | return q, attn 209 | 210 | 211 | class EncoderLayer(nn.Module): 212 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0): 213 | super(EncoderLayer, self).__init__() 214 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 215 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 216 | 217 | def forward(self, enc_input, slf_attn_mask=None): 218 | enc_output, enc_slf_attn = self.slf_attn( 219 | enc_input, enc_input, enc_input, mask=slf_attn_mask 220 | ) 221 | enc_output = self.pos_ffn(enc_output) 222 | return enc_output, enc_slf_attn 223 | 224 | 225 | class Renderer(nn.Module): 226 | def __init__(self, nb_samples_per_ray): 227 | super(Renderer, self).__init__() 228 | 229 | self.dim = 32 230 | self.attn_token_gen = nn.Linear(24 + 1 + 8, self.dim) 231 | 232 | ## Self-Attention Settings 233 | d_inner = self.dim 234 | n_head = 4 235 | d_k = self.dim // n_head 236 | d_v = self.dim // n_head 237 | num_layers = 4 238 | self.attn_layers = nn.ModuleList( 239 | [ 240 | EncoderLayer(self.dim, d_inner, n_head, d_k, d_v) 241 | for i in range(num_layers) 242 | ] 243 | ) 244 | 245 | ## Processing the mean and variance of input features 246 | self.var_mean_fc1 = nn.Linear(16, self.dim) 247 | self.var_mean_fc2 = nn.Linear(self.dim, self.dim) 248 | 249 | ## Setting mask of var_mean always enabled 250 | self.var_mean_mask = torch.tensor([1]).cuda() 251 | self.var_mean_mask.requires_grad = False 252 | 253 | ## For aggregating data along ray samples 254 | self.auto_enc = ConvAutoEncoder(self.dim, nb_samples_per_ray) 255 | 256 | self.sigma_fc1 = nn.Linear(self.dim, self.dim) 257 | self.sigma_fc2 = nn.Linear(self.dim, self.dim // 2) 258 | self.sigma_fc3 = nn.Linear(self.dim // 2, 1) 259 | 260 | self.rgb_fc1 = nn.Linear(self.dim + 9, self.dim) 261 | self.rgb_fc2 = 
nn.Linear(self.dim, self.dim // 2) 262 | self.rgb_fc3 = nn.Linear(self.dim // 2, 1) 263 | 264 | ## Initialization 265 | self.sigma_fc3.apply(weights_init) 266 | 267 | def forward(self, viewdirs, feat, occ_masks): 268 | ## Viewing samples regardless of batch or ray 269 | N, S, V = feat.shape[:3] 270 | feat = feat.view(-1, *feat.shape[2:]) 271 | v_feat = feat[..., :24] 272 | s_feat = feat[..., 24 : 24 + 8] 273 | colors = feat[..., 24 + 8 : -1] 274 | vis_mask = feat[..., -1:].detach() 275 | 276 | occ_masks = occ_masks.view(-1, *occ_masks.shape[2:]) 277 | viewdirs = viewdirs.view(-1, *viewdirs.shape[2:]) 278 | 279 | ## Mean and variance of 2D features provide view-independent tokens 280 | var_mean = torch.var_mean(s_feat, dim=1, unbiased=False, keepdim=True) 281 | var_mean = torch.cat(var_mean, dim=-1) 282 | var_mean = F.elu(self.var_mean_fc1(var_mean)) 283 | var_mean = F.elu(self.var_mean_fc2(var_mean)) 284 | 285 | ## Converting the input features to tokens (view-dependent) before self-attention 286 | tokens = F.elu( 287 | self.attn_token_gen(torch.cat([v_feat, vis_mask, s_feat], dim=-1)) 288 | ) 289 | tokens = torch.cat([tokens, var_mean], dim=1) 290 | 291 | ## Adding a new channel to mask for var_mean 292 | vis_mask = torch.cat( 293 | [vis_mask, self.var_mean_mask.view(1, 1, 1).expand(N * S, -1, -1)], dim=1 294 | ) 295 | ## If a point is not visible by any source view, force its masks to enabled 296 | vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1) 297 | 298 | ## Taking occ_masks into account, but remembering if there were any visibility before that 299 | mask_cloned = vis_mask.clone() 300 | vis_mask[:, :-1] *= occ_masks 301 | vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1) 302 | masks = vis_mask * mask_cloned 303 | 304 | ## Performing self-attention 305 | for layer in self.attn_layers: 306 | tokens, _ = layer(tokens, masks) 307 | 308 | ## Predicting sigma with an Auto-Encoder and MLP 309 | sigma_tokens = tokens[:, -1:] 310 | sigma_tokens = sigma_tokens.view(N, S, self.dim).transpose(1, 2) 311 | sigma_tokens = self.auto_enc(sigma_tokens) 312 | sigma_tokens = sigma_tokens.transpose(1, 2).reshape(N * S, 1, self.dim) 313 | 314 | sigma_tokens = F.elu(self.sigma_fc1(sigma_tokens)) 315 | sigma_tokens = F.elu(self.sigma_fc2(sigma_tokens)) 316 | sigma = torch.relu(self.sigma_fc3(sigma_tokens[:, 0])) 317 | 318 | ## Concatenating positional encodings and predicting RGB weights 319 | rgb_tokens = torch.cat([tokens[:, :-1], viewdirs], dim=-1) 320 | rgb_tokens = F.elu(self.rgb_fc1(rgb_tokens)) 321 | rgb_tokens = F.elu(self.rgb_fc2(rgb_tokens)) 322 | rgb_w = self.rgb_fc3(rgb_tokens) 323 | rgb_w = masked_softmax(rgb_w, masks[:, :-1], dim=1) 324 | 325 | rgb = (colors * rgb_w).sum(1) 326 | 327 | outputs = torch.cat([rgb, sigma], -1) 328 | outputs = outputs.reshape(N, S, -1) 329 | 330 | return outputs 331 | -------------------------------------------------------------------------------- /pretrained_weights/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-lightning==1.3.7 2 | inplace_abn 3 | imageio 4 | pillow 5 | scikit-image 6 | opencv-python 7 | ConfigArgParse 8 | lpips 9 | kornia 10 | ipdb 
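The Renderer defined in model/self_attn_renderer.py above consumes, for each ray sample and each source view, a feature vector laid out as 24 interpolated cost-volume features, 8 FPN features, 3 source-view colors, and 1 visibility flag, together with 9-dimensional positionally encoded viewing angles and per-view occlusion masks, and returns RGB plus sigma for every sample. The snippet below is only a minimal smoke test of those shapes, a sketch under stated assumptions: the tensors are random, and the sizes (4 rays, 128 samples per ray, 3 views) are assumptions chosen to match the nb_coarse + nb_fine = 96 + 32 and nb_views defaults in utils/options.py, not values taken from the training loop. A CUDA device is needed because the module calls .cuda() on one of its buffers at construction time.

import torch

from model.self_attn_renderer import Renderer

if torch.cuda.is_available():
    N, S, V = 4, 128, 3  # rays, samples per ray, source views (assumed toy sizes)
    renderer = Renderer(nb_samples_per_ray=S).cuda()

    # Per-sample layout assumed from Renderer.forward's slicing:
    # 24 cost-volume feats + 8 FPN feats + 3 colors + 1 visibility mask = 36 channels
    feat = torch.rand(N, S, V, 24 + 8 + 3 + 1, device="cuda")
    # 9 = 1 raw angle + 4 sine and 4 cosine frequencies from the embedder in utils/rendering.py
    viewdirs = torch.rand(N, S, V, 9, device="cuda")
    # All points assumed unoccluded in this toy example
    occ_masks = torch.ones(N, S, V, 1, device="cuda")

    out = renderer(viewdirs, feat, occ_masks)
    print(out.shape)  # torch.Size([4, 128, 4]) -> RGB + sigma per sample

In the actual pipeline, render_rays in utils/rendering.py builds these inputs by interpolating the cost volumes, FPN features, and source images at the sampled points and by positionally encoding the ray angles before calling the renderer.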
-------------------------------------------------------------------------------- /run_geo_nerf.py: -------------------------------------------------------------------------------- 1 | # GeoNeRF is a generalizable NeRF model that renders novel views 2 | # without requiring per-scene optimization. This software is the 3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with 4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin, 5 | # and Francois Fleuret. 6 | 7 | # Copyright (c) 2022 ams International AG 8 | 9 | # This file is part of GeoNeRF. 10 | # GeoNeRF is free software: you can redistribute it and/or modify 11 | # it under the terms of the GNU General Public License version 3 as 12 | # published by the Free Software Foundation. 13 | 14 | # GeoNeRF is distributed in the hope that it will be useful, 15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 | # GNU General Public License for more details. 18 | 19 | # You should have received a copy of the GNU General Public License 20 | # along with GeoNeRF. If not, see . 21 | 22 | # This file incorporates work covered by the following copyright and 23 | # permission notice: 24 | 25 | # MIT License 26 | 27 | # Copyright (c) 2021 apchenstu 28 | 29 | # Permission is hereby granted, free of charge, to any person obtaining a copy 30 | # of this software and associated documentation files (the "Software"), to deal 31 | # in the Software without restriction, including without limitation the rights 32 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 33 | # copies of the Software, and to permit persons to whom the Software is 34 | # furnished to do so, subject to the following conditions: 35 | 36 | # The above copyright notice and this permission notice shall be included in all 37 | # copies or substantial portions of the Software. 38 | 39 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 40 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 41 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 42 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 43 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 44 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 45 | # SOFTWARE. 
46 | 47 | import torch 48 | from torch.utils.data import DataLoader 49 | from torch.optim.lr_scheduler import CosineAnnealingLR 50 | 51 | from pytorch_lightning.callbacks import ModelCheckpoint 52 | from pytorch_lightning import LightningModule, Trainer, loggers 53 | from pytorch_lightning.loggers import WandbLogger 54 | 55 | import os 56 | import time 57 | import numpy as np 58 | import imageio 59 | import lpips 60 | from skimage.metrics import structural_similarity as ssim 61 | 62 | from model.geo_reasoner import CasMVSNet 63 | from model.self_attn_renderer import Renderer 64 | from utils.rendering import render_rays 65 | from utils.utils import ( 66 | load_ckpt, 67 | init_log, 68 | get_rays_pts, 69 | SL1Loss, 70 | self_supervision_loss, 71 | img2mse, 72 | mse2psnr, 73 | acc_threshold, 74 | abs_error, 75 | visualize_depth, 76 | ) 77 | from utils.options import config_parser 78 | from data.get_datasets import ( 79 | get_training_dataset, 80 | get_finetuning_dataset, 81 | get_validation_dataset, 82 | ) 83 | 84 | lpips_fn = lpips.LPIPS(net="vgg") 85 | 86 | class GeoNeRF(LightningModule): 87 | def __init__(self, hparams): 88 | super(GeoNeRF, self).__init__() 89 | self.hparams.update(vars(hparams)) 90 | self.wr_cntr = 0 91 | 92 | self.depth_loss = SL1Loss() 93 | self.learning_rate = hparams.lrate 94 | 95 | # Create geometry_reasoner and renderer models 96 | self.geo_reasoner = CasMVSNet(use_depth=hparams.use_depth).cuda() 97 | self.renderer = Renderer( 98 | nb_samples_per_ray=hparams.nb_coarse + hparams.nb_fine 99 | ).cuda() 100 | 101 | self.eval_metric = [0.01, 0.05, 0.1] 102 | 103 | self.automatic_optimization = False 104 | self.save_hyperparameters() 105 | 106 | def unpreprocess(self, data, shape=(1, 1, 3, 1, 1)): 107 | # to unnormalize image for visualization 108 | device = data.device 109 | mean = ( 110 | torch.tensor([-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225]) 111 | .view(*shape) 112 | .to(device) 113 | ) 114 | std = torch.tensor([1 / 0.229, 1 / 0.224, 1 / 0.225]).view(*shape).to(device) 115 | 116 | return (data - mean) / std 117 | 118 | def prepare_data(self): 119 | if self.hparams.scene == "None": ## Generalizable 120 | self.train_dataset, self.train_sampler = get_training_dataset(self.hparams) 121 | self.val_dataset = get_validation_dataset(self.hparams) 122 | else: ## Fine-tune 123 | self.train_dataset, self.train_sampler = get_finetuning_dataset( 124 | self.hparams 125 | ) 126 | self.val_dataset = get_validation_dataset(self.hparams) 127 | 128 | def configure_optimizers(self): 129 | eps = 1e-5 130 | 131 | opt = torch.optim.Adam( 132 | list(self.geo_reasoner.parameters()) + list(self.renderer.parameters()), 133 | lr=self.learning_rate, 134 | betas=(0.9, 0.999), 135 | ) 136 | sch = CosineAnnealingLR(opt, T_max=self.hparams.num_steps, eta_min=eps) 137 | 138 | return [opt], [sch] 139 | 140 | def train_dataloader(self): 141 | return DataLoader( 142 | self.train_dataset, 143 | sampler=self.train_sampler, 144 | shuffle=True if self.train_sampler is None else False, 145 | num_workers=8, 146 | batch_size=1, 147 | pin_memory=True, 148 | ) 149 | 150 | def val_dataloader(self): 151 | return DataLoader( 152 | self.val_dataset, 153 | shuffle=False, 154 | num_workers=1, 155 | batch_size=1, 156 | pin_memory=True, 157 | ) 158 | 159 | def training_step(self, batch, batch_nb): 160 | loss = 0 161 | nb_views = self.hparams.nb_views 162 | H, W = batch["images"].shape[-2:] 163 | H, W = int(H), int(W) 164 | 165 | ## Inferring Geometry Reasoner 166 | feats_vol, feats_fpn, depth_map, depth_values = 
self.geo_reasoner( 167 | imgs=batch["images"][:, :nb_views], 168 | affine_mats=batch["affine_mats"][:, :nb_views], 169 | affine_mats_inv=batch["affine_mats_inv"][:, :nb_views], 170 | near_far=batch["near_fars"][:, :nb_views], 171 | closest_idxs=batch["closest_idxs"][:, :nb_views], 172 | gt_depths=batch["depths_aug"][:, :nb_views], 173 | ) 174 | 175 | ## Normalizing depth maps in NDC coordinate 176 | depth_map_norm = {} 177 | for l in range(3): 178 | depth_map_norm[f"level_{l}"] = ( 179 | depth_map[f"level_{l}"].detach() - depth_values[f"level_{l}"][:, :, 0] 180 | ) / ( 181 | depth_values[f"level_{l}"][:, :, -1] 182 | - depth_values[f"level_{l}"][:, :, 0] 183 | ) 184 | 185 | unpre_imgs = self.unpreprocess(batch["images"]) 186 | 187 | ( 188 | pts_depth, 189 | rays_pts, 190 | rays_pts_ndc, 191 | rays_dir, 192 | rays_gt_rgb, 193 | rays_gt_depth, 194 | rays_pixs, 195 | ) = get_rays_pts( 196 | H, 197 | W, 198 | batch["c2ws"], 199 | batch["w2cs"], 200 | batch["intrinsics"], 201 | batch["near_fars"], 202 | depth_values, 203 | self.hparams.nb_coarse, 204 | self.hparams.nb_fine, 205 | nb_views=nb_views, 206 | train=True, 207 | train_batch_size=self.hparams.batch_size, 208 | target_img=unpre_imgs[0, -1], 209 | target_depth=batch["depths_h"][0, -1], 210 | ) 211 | 212 | ## Rendering 213 | rendered_rgb, rendered_depth = render_rays( 214 | c2ws=batch["c2ws"][0, :nb_views], 215 | rays_pts=rays_pts, 216 | rays_pts_ndc=rays_pts_ndc, 217 | pts_depth=pts_depth, 218 | rays_dir=rays_dir, 219 | feats_vol=feats_vol, 220 | feats_fpn=feats_fpn[:, :nb_views], 221 | imgs=unpre_imgs[:, :nb_views], 222 | depth_map_norm=depth_map_norm, 223 | renderer_net=self.renderer, 224 | ) 225 | 226 | # Supervising depth maps with either ground truth depth or self-supervision loss 227 | ## This loss is only used in the generalizable model 228 | if self.hparams.scene == "None": 229 | ## if ground truth is available 230 | if isinstance(batch["depths"], dict): 231 | loss = loss + 1 * self.depth_loss(depth_map, batch["depths"]) 232 | if loss != 0: 233 | self.log("train/dlossgt", loss.item(), prog_bar=False) 234 | else: 235 | loss = loss + 0.1 * self_supervision_loss( 236 | self.depth_loss, 237 | rays_pixs, 238 | rendered_depth.detach(), 239 | depth_map, 240 | rays_gt_rgb, 241 | unpre_imgs, 242 | rendered_rgb.detach(), 243 | batch["intrinsics"], 244 | batch["c2ws"], 245 | batch["w2cs"], 246 | ) 247 | if loss != 0: 248 | self.log("train/dlosspgt", loss.item(), prog_bar=False) 249 | 250 | mask = rays_gt_depth > 0 251 | depth_available = mask.sum() > 0 252 | 253 | ## Supervising ray depths 254 | if depth_available: 255 | ## This loss is only used in the generalizable model 256 | if self.hparams.scene == "None": 257 | loss = loss + 0.1 * self.depth_loss(rendered_depth, rays_gt_depth) 258 | 259 | self.log( 260 | f"train/acc_l_{self.eval_metric[0]}mm", 261 | acc_threshold( 262 | rendered_depth, rays_gt_depth, mask, self.eval_metric[0] 263 | ).mean(), 264 | prog_bar=False, 265 | ) 266 | self.log( 267 | f"train/acc_l_{self.eval_metric[1]}mm", 268 | acc_threshold( 269 | rendered_depth, rays_gt_depth, mask, self.eval_metric[1] 270 | ).mean(), 271 | prog_bar=False, 272 | ) 273 | self.log( 274 | f"train/acc_l_{self.eval_metric[2]}mm", 275 | acc_threshold( 276 | rendered_depth, rays_gt_depth, mask, self.eval_metric[2] 277 | ).mean(), 278 | prog_bar=False, 279 | ) 280 | 281 | abs_err = abs_error(rendered_depth, rays_gt_depth, mask).mean() 282 | self.log("train/abs_err", abs_err, prog_bar=False) 283 | 284 | ## Reconstruction loss 285 | mse_loss = 
img2mse(rendered_rgb, rays_gt_rgb) 286 | loss = loss + mse_loss 287 | 288 | with torch.no_grad(): 289 | self.log("train/loss", loss.item(), prog_bar=True) 290 | psnr = mse2psnr(mse_loss.detach()) 291 | self.log("train/PSNR", psnr.item(), prog_bar=False) 292 | self.log("train/img_mse_loss", mse_loss.item(), prog_bar=False) 293 | 294 | # Manual Optimization 295 | self.manual_backward(loss) 296 | 297 | opt = self.optimizers() 298 | sch = self.lr_schedulers() 299 | 300 | # Warming up the learning rate 301 | if self.trainer.global_step < self.hparams.warmup_steps: 302 | lr_scale = min( 303 | 1.0, float(self.trainer.global_step + 1) / self.hparams.warmup_steps 304 | ) 305 | for pg in opt.param_groups: 306 | pg["lr"] = lr_scale * self.learning_rate 307 | 308 | self.log("train/lr", opt.param_groups[0]["lr"], prog_bar=False) 309 | 310 | opt.step() 311 | opt.zero_grad() 312 | sch.step() 313 | 314 | return {"loss": loss} 315 | 316 | def validation_step(self, batch, batch_nb): 317 | ## This makes Batchnorm to behave like InstanceNorm 318 | self.geo_reasoner.train() 319 | 320 | log_keys = [ 321 | "val_psnr", 322 | "val_ssim", 323 | "val_lpips", 324 | "val_depth_loss_r", 325 | "val_abs_err", 326 | "mask_sum", 327 | ] + [f"val_acc_{i}mm" for i in self.eval_metric] 328 | log = {} 329 | log = init_log(log, log_keys) 330 | 331 | H, W = batch["images"].shape[-2:] 332 | H, W = int(H), int(W) 333 | 334 | nb_views = self.hparams.nb_views 335 | 336 | with torch.no_grad(): 337 | ## Inferring Geometry Reasoner 338 | feats_vol, feats_fpn, depth_map, depth_values = self.geo_reasoner( 339 | imgs=batch["images"][:, :nb_views], 340 | affine_mats=batch["affine_mats"][:, :nb_views], 341 | affine_mats_inv=batch["affine_mats_inv"][:, :nb_views], 342 | near_far=batch["near_fars"][:, :nb_views], 343 | closest_idxs=batch["closest_idxs"][:, :nb_views], 344 | gt_depths=batch["depths_aug"][:, :nb_views], 345 | ) 346 | 347 | ## Normalizing depth maps in NDC coordinate 348 | depth_map_norm = {} 349 | for l in range(3): 350 | depth_map_norm[f"level_{l}"] = ( 351 | depth_map[f"level_{l}"] - depth_values[f"level_{l}"][:, :, 0] 352 | ) / ( 353 | depth_values[f"level_{l}"][:, :, -1] 354 | - depth_values[f"level_{l}"][:, :, 0] 355 | ) 356 | 357 | unpre_imgs = self.unpreprocess(batch["images"]) 358 | 359 | rendered_rgb, rendered_depth = [], [] 360 | for chunk_idx in range( 361 | H * W // self.hparams.chunk + int(H * W % self.hparams.chunk > 0) 362 | ): 363 | pts_depth, rays_pts, rays_pts_ndc, rays_dir, _, _, _ = get_rays_pts( 364 | H, 365 | W, 366 | batch["c2ws"], 367 | batch["w2cs"], 368 | batch["intrinsics"], 369 | batch["near_fars"], 370 | depth_values, 371 | self.hparams.nb_coarse, 372 | self.hparams.nb_fine, 373 | nb_views=nb_views, 374 | chunk=self.hparams.chunk, 375 | chunk_idx=chunk_idx, 376 | ) 377 | 378 | ## Rendering 379 | rend_rgb, rend_depth = render_rays( 380 | c2ws=batch["c2ws"][0, :nb_views], 381 | rays_pts=rays_pts, 382 | rays_pts_ndc=rays_pts_ndc, 383 | pts_depth=pts_depth, 384 | rays_dir=rays_dir, 385 | feats_vol=feats_vol, 386 | feats_fpn=feats_fpn[:, :nb_views], 387 | imgs=unpre_imgs[:, :nb_views], 388 | depth_map_norm=depth_map_norm, 389 | renderer_net=self.renderer, 390 | ) 391 | rendered_rgb.append(rend_rgb) 392 | rendered_depth.append(rend_depth) 393 | rendered_rgb = torch.clamp( 394 | torch.cat(rendered_rgb).reshape(H, W, 3).permute(2, 0, 1), 0, 1 395 | ) 396 | rendered_depth = torch.cat(rendered_depth).reshape(H, W) 397 | 398 | ## Check if there is any ground truth depth information for the dataset 399 | 
depth_available = batch["depths_h"].sum() > 0 400 | 401 | ## Evaluate only on pixels with meaningful ground truth depths 402 | if depth_available: 403 | mask = batch["depths_h"] > 0 404 | img_gt_masked = (unpre_imgs[0, -1] * mask[0, -1][None]).cpu() 405 | rendered_rgb_masked = (rendered_rgb * mask[0, -1][None]).cpu() 406 | else: 407 | img_gt_masked = unpre_imgs[0, -1].cpu() 408 | rendered_rgb_masked = rendered_rgb.cpu() 409 | 410 | unpre_imgs = unpre_imgs.cpu() 411 | rendered_rgb, rendered_depth = rendered_rgb.cpu(), rendered_depth.cpu() 412 | img_err_abs = (rendered_rgb_masked - img_gt_masked).abs() 413 | 414 | depth_target = batch["depths_h"][0, -1].cpu() 415 | mask_target = depth_target > 0 416 | 417 | if depth_available: 418 | log["val_psnr"] = mse2psnr(torch.mean(img_err_abs[:, mask_target] ** 2)) 419 | else: 420 | log["val_psnr"] = mse2psnr(torch.mean(img_err_abs**2)) 421 | log["val_ssim"] = ssim( 422 | rendered_rgb_masked.permute(1, 2, 0).numpy(), 423 | img_gt_masked.permute(1, 2, 0).numpy(), 424 | data_range=1, 425 | multichannel=True, 426 | ) 427 | log["val_lpips"] = lpips_fn( 428 | rendered_rgb_masked[None] * 2 - 1, img_gt_masked[None] * 2 - 1 429 | ).item() # Normalize to [-1,1] 430 | 431 | depth_minmax = [ 432 | 0.9 * batch["near_fars"].min().detach().cpu().numpy(), 433 | 1.1 * batch["near_fars"].max().detach().cpu().numpy(), 434 | ] 435 | rendered_depth_vis, _ = visualize_depth(rendered_depth, depth_minmax) 436 | 437 | if depth_available: 438 | log["val_abs_err"] = abs_error( 439 | rendered_depth, depth_target, mask_target 440 | ).sum() 441 | log[f"val_acc_{self.eval_metric[0]}mm"] = acc_threshold( 442 | rendered_depth, depth_target, mask_target, self.eval_metric[0] 443 | ).sum() 444 | log[f"val_acc_{self.eval_metric[1]}mm"] = acc_threshold( 445 | rendered_depth, depth_target, mask_target, self.eval_metric[1] 446 | ).sum() 447 | log[f"val_acc_{self.eval_metric[2]}mm"] = acc_threshold( 448 | rendered_depth, depth_target, mask_target, self.eval_metric[2] 449 | ).sum() 450 | log["mask_sum"] = mask_target.float().sum() 451 | 452 | img_vis = ( 453 | torch.cat( 454 | ( 455 | unpre_imgs[:, -1], 456 | torch.stack([rendered_rgb, img_err_abs * 5]), 457 | rendered_depth_vis[None], 458 | ), 459 | dim=0, 460 | ) 461 | .clip(0, 1) 462 | .permute(2, 0, 3, 1) 463 | .reshape(H, -1, 3) 464 | .numpy() 465 | ) 466 | 467 | os.makedirs( 468 | f"{self.hparams.logdir}/{self.hparams.dataset_name}/{self.hparams.expname}/rendered_results/", 469 | exist_ok=True, 470 | ) 471 | imageio.imwrite( 472 | f"{self.hparams.logdir}/{self.hparams.dataset_name}/{self.hparams.expname}/rendered_results/{self.wr_cntr:03d}.png", 473 | ( 474 | rendered_rgb.detach().permute(1, 2, 0).clip(0.0, 1.0).cpu().numpy() 475 | * 255 476 | ).astype("uint8"), 477 | ) 478 | 479 | os.makedirs( 480 | f"{self.hparams.logdir}/{self.hparams.dataset_name}/{self.hparams.expname}/evaluation/", 481 | exist_ok=True, 482 | ) 483 | imageio.imwrite( 484 | f"{self.hparams.logdir}/{self.hparams.dataset_name}/{self.hparams.expname}/evaluation/{self.global_step:08d}_{self.wr_cntr:02d}.png", 485 | (img_vis * 255).astype("uint8"), 486 | ) 487 | 488 | print(f"Image {self.wr_cntr:02d} rendered.") 489 | self.wr_cntr += 1 490 | 491 | return log 492 | 493 | def validation_epoch_end(self, outputs): 494 | mean_psnr = torch.stack([x["val_psnr"] for x in outputs]).mean() 495 | mean_ssim = np.stack([x["val_ssim"] for x in outputs]).mean() 496 | mean_lpips = np.stack([x["val_lpips"] for x in outputs]).mean() 497 | mask_sum = torch.stack([x["mask_sum"] for x in 
outputs]).sum() 498 | mean_d_loss_r = torch.stack([x["val_depth_loss_r"] for x in outputs]).mean() 499 | mean_abs_err = torch.stack([x["val_abs_err"] for x in outputs]).sum() / mask_sum 500 | mean_acc_1mm = ( 501 | torch.stack([x[f"val_acc_{self.eval_metric[0]}mm"] for x in outputs]).sum() 502 | / mask_sum 503 | ) 504 | mean_acc_2mm = ( 505 | torch.stack([x[f"val_acc_{self.eval_metric[1]}mm"] for x in outputs]).sum() 506 | / mask_sum 507 | ) 508 | mean_acc_4mm = ( 509 | torch.stack([x[f"val_acc_{self.eval_metric[2]}mm"] for x in outputs]).sum() 510 | / mask_sum 511 | ) 512 | 513 | self.log("val/PSNR", mean_psnr, prog_bar=False) 514 | self.log("val/SSIM", mean_ssim, prog_bar=False) 515 | self.log("val/LPIPS", mean_lpips, prog_bar=False) 516 | if mask_sum > 0: 517 | self.log("val/d_loss_r", mean_d_loss_r, prog_bar=False) 518 | self.log("val/abs_err", mean_abs_err, prog_bar=False) 519 | self.log(f"val/acc_{self.eval_metric[0]}mm", mean_acc_1mm, prog_bar=False) 520 | self.log(f"val/acc_{self.eval_metric[1]}mm", mean_acc_2mm, prog_bar=False) 521 | self.log(f"val/acc_{self.eval_metric[2]}mm", mean_acc_4mm, prog_bar=False) 522 | 523 | with open( 524 | f"{self.hparams.logdir}/{self.hparams.dataset_name}/{self.hparams.expname}/{self.hparams.expname}_metrics.txt", 525 | "w", 526 | ) as metric_file: 527 | metric_file.write(f"PSNR: {mean_psnr}\n") 528 | metric_file.write(f"SSIM: {mean_ssim}\n") 529 | metric_file.write(f"LPIPS: {mean_lpips}") 530 | 531 | return 532 | 533 | 534 | if __name__ == "__main__": 535 | torch.set_default_dtype(torch.float32) 536 | args = config_parser() 537 | geonerf = GeoNeRF(args) 538 | 539 | ## Checking to logdir to see if there is any checkpoint file to continue with 540 | ckpt_path = f"{args.logdir}/{args.dataset_name}/{args.expname}/ckpts" 541 | if os.path.isdir(ckpt_path) and len(os.listdir(ckpt_path)) > 0: 542 | ckpt_file = os.path.join(ckpt_path, os.listdir(ckpt_path)[-1]) 543 | else: 544 | ckpt_file = None 545 | 546 | ## Setting a callback to automatically save checkpoints 547 | checkpoint_callback = ModelCheckpoint( 548 | f"{args.logdir}/{args.dataset_name}/{args.expname}/ckpts", 549 | filename="ckpt_step-{step:06d}", 550 | auto_insert_metric_name=False, 551 | save_top_k=-1, 552 | ) 553 | 554 | ## Setting up a logger 555 | if args.logger == "wandb": 556 | logger = WandbLogger( 557 | name=args.expname, 558 | project="GeoNeRF", 559 | save_dir=f"{args.logdir}", 560 | resume="allow", 561 | id=args.expname, 562 | ) 563 | elif args.logger == "tensorboard": 564 | logger = loggers.TestTubeLogger( 565 | save_dir=f"{args.logdir}/{args.dataset_name}/{args.expname}", 566 | name=args.expname + "_logs", 567 | debug=False, 568 | create_git_tag=False, 569 | ) 570 | else: 571 | logger = None 572 | 573 | args.use_amp = False if args.eval else True 574 | trainer = Trainer( 575 | max_steps=args.num_steps, 576 | callbacks=checkpoint_callback, 577 | checkpoint_callback=True, 578 | resume_from_checkpoint=ckpt_file, 579 | logger=logger, 580 | progress_bar_refresh_rate=1, 581 | gpus=1, 582 | num_sanity_val_steps=0, 583 | val_check_interval=2000 if args.scene == "None" else 1.0, 584 | check_val_every_n_epoch=1000 if args.scene != 'None' else 1, 585 | benchmark=True, 586 | precision=16 if args.use_amp else 32, 587 | amp_level="O1", 588 | ) 589 | 590 | if not args.eval: ## Train 591 | if args.scene != "None": ## Fine-tune 592 | if args.use_depth: 593 | ckpt_file = "pretrained_weights/pretrained_w_depth.ckpt" 594 | else: 595 | ckpt_file = "pretrained_weights/pretrained.ckpt" 596 | 
load_ckpt(geonerf.geo_reasoner, ckpt_file, "geo_reasoner") 597 | load_ckpt(geonerf.renderer, ckpt_file, "renderer") 598 | elif not args.use_depth: ## Generalizable 599 | ## Loading the pretrained weights from Cascade MVSNet 600 | torch.utils.model_zoo.load_url( 601 | "https://github.com/kwea123/CasMVSNet_pl/releases/download/1.5/epoch.15.ckpt", 602 | model_dir="pretrained_weights", 603 | ) 604 | ckpt_file = "pretrained_weights/epoch.15.ckpt" 605 | load_ckpt(geonerf.geo_reasoner, ckpt_file, "model", strict=False) 606 | 607 | trainer.fit(geonerf) 608 | else: ## Eval 609 | geonerf = GeoNeRF(args) 610 | 611 | if ckpt_file is None: 612 | if args.use_depth: 613 | ckpt_file = "pretrained_weights/pretrained_w_depth.ckpt" 614 | else: 615 | ckpt_file = "pretrained_weights/pretrained.ckpt" 616 | load_ckpt(geonerf.geo_reasoner, ckpt_file, "geo_reasoner") 617 | load_ckpt(geonerf.renderer, ckpt_file, "renderer") 618 | 619 | trainer.validate(geonerf) 620 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/GeoNeRF/e6249fdae5672853c6bbbd4ba380c4c166d02c95/utils/__init__.py -------------------------------------------------------------------------------- /utils/options.py: -------------------------------------------------------------------------------- 1 | # GeoNeRF is a generalizable NeRF model that renders novel views 2 | # without requiring per-scene optimization. This software is the 3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with 4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin, 5 | # and Francois Fleuret. 6 | 7 | # Copyright (c) 2022 ams International AG 8 | 9 | # This file is part of GeoNeRF. 10 | # GeoNeRF is free software: you can redistribute it and/or modify 11 | # it under the terms of the GNU General Public License version 3 as 12 | # published by the Free Software Foundation. 13 | 14 | # GeoNeRF is distributed in the hope that it will be useful, 15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 | # GNU General Public License for more details. 18 | 19 | # You should have received a copy of the GNU General Public License 20 | # along with GeoNeRF. If not, see . 
21 | 22 | import configargparse 23 | 24 | def config_parser(): 25 | parser = configargparse.ArgumentParser() 26 | parser.add_argument("--config", is_config_file=True, help="Config file path") 27 | 28 | # Datasets options 29 | parser.add_argument("--dataset_name", type=str, default="llff", choices=["llff", "nerf", "dtu"],) 30 | parser.add_argument("--llff_path", type=str, help="Path to llff dataset") 31 | parser.add_argument("--llff_test_path", type=str, help="Path to llff dataset") 32 | parser.add_argument("--dtu_path", type=str, help="Path to dtu dataset") 33 | parser.add_argument("--dtu_pre_path", type=str, help="Path to preprocessed dtu dataset") 34 | parser.add_argument("--nerf_path", type=str, help="Path to nerf dataset") 35 | parser.add_argument("--ams_path", type=str, help="Path to ams dataset") 36 | parser.add_argument("--ibrnet1_path", type=str, help="Path to ibrnet1 dataset") 37 | parser.add_argument("--ibrnet2_path", type=str, help="Path to ibrnet2 dataset") 38 | 39 | # Training options 40 | parser.add_argument("--batch_size", type=int, default=512) 41 | parser.add_argument("--num_steps", type=int, default=200000) 42 | parser.add_argument("--nb_views", type=int, default=3) 43 | parser.add_argument("--lrate", type=float, default=5e-4, help="Learning rate") 44 | parser.add_argument("--warmup_steps", type=int, default=500, help="Gradually warm-up learning rate in optimizer") 45 | parser.add_argument("--scene", type=str, default="None", help="Scene for fine-tuning") 46 | 47 | # Rendering options 48 | parser.add_argument("--chunk", type=int, default=4096, help="Number of rays rendered in parallel") 49 | parser.add_argument("--nb_coarse", type=int, default=96, help="Number of coarse samples per ray") 50 | parser.add_argument("--nb_fine", type=int, default=32, help="Number of additional fine samples per ray",) 51 | 52 | # Other options 53 | parser.add_argument("--expname", type=str, help="Experiment name") 54 | parser.add_argument("--logger", type=str, default="tensorboard", choices=["wandb", "tensorboard", "none"]) 55 | parser.add_argument("--logdir", type=str, default="./logs/", help="Where to store ckpts and logs") 56 | parser.add_argument("--eval", action="store_true", help="Render and evaluate the test set") 57 | parser.add_argument("--use_depth", action="store_true", help="Use ground truth low-res depth maps in rendering process") 58 | 59 | return parser.parse_args() 60 | -------------------------------------------------------------------------------- /utils/rendering.py: -------------------------------------------------------------------------------- 1 | # GeoNeRF is a generalizable NeRF model that renders novel views 2 | # without requiring per-scene optimization. This software is the 3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with 4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin, 5 | # and Francois Fleuret. 6 | 7 | # Copyright (c) 2022 ams International AG 8 | 9 | # This file is part of GeoNeRF. 10 | # GeoNeRF is free software: you can redistribute it and/or modify 11 | # it under the terms of the GNU General Public License version 3 as 12 | # published by the Free Software Foundation. 13 | 14 | # GeoNeRF is distributed in the hope that it will be useful, 15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 | # GNU General Public License for more details. 
18 | 19 | # You should have received a copy of the GNU General Public License 20 | # along with GeoNeRF. If not, see . 21 | 22 | # This file incorporates work covered by the following copyright and 23 | # permission notice: 24 | 25 | # MIT License 26 | 27 | # Copyright (c) 2021 apchenstu 28 | 29 | # Permission is hereby granted, free of charge, to any person obtaining a copy 30 | # of this software and associated documentation files (the "Software"), to deal 31 | # in the Software without restriction, including without limitation the rights 32 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 33 | # copies of the Software, and to permit persons to whom the Software is 34 | # furnished to do so, subject to the following conditions: 35 | 36 | # The above copyright notice and this permission notice shall be included in all 37 | # copies or substantial portions of the Software. 38 | 39 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 40 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 41 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 42 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 43 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 44 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 45 | # SOFTWARE. 46 | 47 | import torch 48 | import torch.nn.functional as F 49 | 50 | from utils.utils import normal_vect, interpolate_3D, interpolate_2D 51 | 52 | 53 | class Embedder: 54 | def __init__(self, **kwargs): 55 | self.kwargs = kwargs 56 | self.create_embedding_fn() 57 | 58 | def create_embedding_fn(self): 59 | embed_fns = [] 60 | 61 | if self.kwargs["include_input"]: 62 | embed_fns.append(lambda x: x) 63 | 64 | max_freq = self.kwargs["max_freq_log2"] 65 | N_freqs = self.kwargs["num_freqs"] 66 | 67 | if self.kwargs["log_sampling"]: 68 | freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs) 69 | else: 70 | freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs) 71 | self.freq_bands = freq_bands.reshape(1, -1, 1).cuda() 72 | 73 | for freq in freq_bands: 74 | for p_fn in self.kwargs["periodic_fns"]: 75 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 76 | 77 | self.embed_fns = embed_fns 78 | 79 | def embed(self, inputs): 80 | repeat = inputs.dim() - 1 81 | inputs_scaled = ( 82 | inputs.unsqueeze(-2) * self.freq_bands.view(*[1] * repeat, -1, 1) 83 | ).reshape(*inputs.shape[:-1], -1) 84 | inputs_scaled = torch.cat( 85 | (inputs, torch.sin(inputs_scaled), torch.cos(inputs_scaled)), dim=-1 86 | ) 87 | return inputs_scaled 88 | 89 | 90 | def get_embedder(multires=4): 91 | 92 | embed_kwargs = { 93 | "include_input": True, 94 | "max_freq_log2": multires - 1, 95 | "num_freqs": multires, 96 | "log_sampling": True, 97 | "periodic_fns": [torch.sin, torch.cos], 98 | } 99 | 100 | embedder_obj = Embedder(**embed_kwargs) 101 | embed = lambda x, eo=embedder_obj: eo.embed(x) 102 | return embed 103 | 104 | 105 | def sigma2weights(sigma): 106 | alpha = 1.0 - torch.exp(-sigma) 107 | T = torch.cumprod( 108 | torch.cat( 109 | [torch.ones(alpha.shape[0], 1).to(alpha.device), 1.0 - alpha + 1e-10], -1 110 | ), 111 | -1, 112 | )[:, :-1] 113 | weights = alpha * T 114 | 115 | return weights 116 | 117 | 118 | def volume_rendering(rgb_sigma, pts_depth): 119 | rgb = rgb_sigma[..., :3] 120 | weights = sigma2weights(rgb_sigma[..., 3]) 121 | 122 | rendered_rgb = torch.sum(weights[..., 
None] * rgb, -2) 123 | rendered_depth = torch.sum(weights * pts_depth, -1) 124 | 125 | return rendered_rgb, rendered_depth 126 | 127 | 128 | def get_angle_wrt_src_cams(c2ws, rays_pts, rays_dir_unit): 129 | nb_rays = rays_pts.shape[0] 130 | ## Unit vectors from source cameras to the points on the ray 131 | dirs = normal_vect(rays_pts.unsqueeze(2) - c2ws[:, :3, 3][None, None]) 132 | ## Cosine of the angle between two directions 133 | angle_cos = torch.sum( 134 | dirs * rays_dir_unit.reshape(nb_rays, 1, 1, 3), dim=-1, keepdim=True 135 | ) 136 | # Cosine to Sine and approximating it as the angle (angle << 1 => sin(angle) = angle) 137 | angle = (1 - (angle_cos**2)).abs().sqrt() 138 | 139 | return angle 140 | 141 | 142 | def interpolate_pts_feats(imgs, feats_fpn, feats_vol, rays_pts_ndc): 143 | nb_views = feats_fpn.shape[1] 144 | interpolated_feats = [] 145 | 146 | for i in range(nb_views): 147 | ray_feats_0 = interpolate_3D( 148 | feats_vol[f"level_0"][:, i], rays_pts_ndc[f"level_0"][:, :, i] 149 | ) 150 | ray_feats_1 = interpolate_3D( 151 | feats_vol[f"level_1"][:, i], rays_pts_ndc[f"level_1"][:, :, i] 152 | ) 153 | ray_feats_2 = interpolate_3D( 154 | feats_vol[f"level_2"][:, i], rays_pts_ndc[f"level_2"][:, :, i] 155 | ) 156 | 157 | ray_feats_fpn, ray_colors, ray_masks = interpolate_2D( 158 | feats_fpn[:, i], imgs[:, i], rays_pts_ndc[f"level_0"][:, :, i] 159 | ) 160 | 161 | interpolated_feats.append( 162 | torch.cat( 163 | [ 164 | ray_feats_0, 165 | ray_feats_1, 166 | ray_feats_2, 167 | ray_feats_fpn, 168 | ray_colors, 169 | ray_masks, 170 | ], 171 | dim=-1, 172 | ) 173 | ) 174 | interpolated_feats = torch.stack(interpolated_feats, dim=2) 175 | 176 | return interpolated_feats 177 | 178 | 179 | def get_occ_masks(depth_map_norm, rays_pts_ndc, visibility_thr=0.2): 180 | nb_views = depth_map_norm["level_0"].shape[1] 181 | z_diff = [] 182 | for i in range(nb_views): 183 | ## Interpolate depth maps corresponding to each sample point 184 | # [1 H W 3] (x,y,z) 185 | grid = rays_pts_ndc[f"level_0"][None, :, :, i, :2] * 2 - 1.0 186 | rays_depths = F.grid_sample( 187 | depth_map_norm["level_0"][:, i : i + 1], 188 | grid, 189 | align_corners=True, 190 | mode="bilinear", 191 | padding_mode="border", 192 | )[0, 0] 193 | z_diff.append(rays_pts_ndc["level_0"][:, :, i, 2] - rays_depths) 194 | z_diff = torch.stack(z_diff, dim=2) 195 | 196 | occ_masks = z_diff.unsqueeze(-1) < visibility_thr 197 | 198 | return occ_masks 199 | 200 | 201 | def render_rays( 202 | c2ws, 203 | rays_pts, 204 | rays_pts_ndc, 205 | pts_depth, 206 | rays_dir, 207 | feats_vol, 208 | feats_fpn, 209 | imgs, 210 | depth_map_norm, 211 | renderer_net, 212 | ): 213 | ## The angles between the ray and source camera vectors 214 | rays_dir_unit = rays_dir / torch.norm(rays_dir, dim=-1, keepdim=True) 215 | angles = get_angle_wrt_src_cams(c2ws, rays_pts, rays_dir_unit) 216 | 217 | ## Positional encoding 218 | embedded_angles = get_embedder()(angles) 219 | 220 | ## Interpolate all features for sample points 221 | pts_feat = interpolate_pts_feats(imgs, feats_fpn, feats_vol, rays_pts_ndc) 222 | 223 | ## Getting Occlusion Masks based on predicted depths 224 | occ_masks = get_occ_masks(depth_map_norm, rays_pts_ndc) 225 | 226 | ## rendering sigma and RGB values 227 | rgb_sigma = renderer_net(embedded_angles, pts_feat, occ_masks) 228 | 229 | rendered_rgb, rendered_depth = volume_rendering(rgb_sigma, pts_depth) 230 | 231 | return rendered_rgb, rendered_depth 232 | -------------------------------------------------------------------------------- 
/utils/utils.py: -------------------------------------------------------------------------------- 1 | # GeoNeRF is a generalizable NeRF model that renders novel views 2 | # without requiring per-scene optimization. This software is the 3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with 4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin, 5 | # and Francois Fleuret. 6 | 7 | # Copyright (c) 2022 ams International AG 8 | 9 | # This file is part of GeoNeRF. 10 | # GeoNeRF is free software: you can redistribute it and/or modify 11 | # it under the terms of the GNU General Public License version 3 as 12 | # published by the Free Software Foundation. 13 | 14 | # GeoNeRF is distributed in the hope that it will be useful, 15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 | # GNU General Public License for more details. 18 | 19 | # You should have received a copy of the GNU General Public License 20 | # along with GeoNeRF. If not, see . 21 | 22 | # This file incorporates work covered by the following copyright and 23 | # permission notice: 24 | 25 | # MIT License 26 | 27 | # Copyright (c) 2021 apchenstu 28 | 29 | # Permission is hereby granted, free of charge, to any person obtaining a copy 30 | # of this software and associated documentation files (the "Software"), to deal 31 | # in the Software without restriction, including without limitation the rights 32 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 33 | # copies of the Software, and to permit persons to whom the Software is 34 | # furnished to do so, subject to the following conditions: 35 | 36 | # The above copyright notice and this permission notice shall be included in all 37 | # copies or substantial portions of the Software. 38 | 39 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 40 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 41 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 42 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 43 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 44 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 45 | # SOFTWARE. 
46 | 47 | import torch 48 | import torch.nn as nn 49 | import torch.nn.functional as F 50 | import torchvision.transforms as T 51 | 52 | import numpy as np 53 | import cv2 54 | import re 55 | 56 | from PIL import Image 57 | from kornia.utils import create_meshgrid 58 | 59 | img2mse = lambda x, y: torch.mean((x - y) ** 2) 60 | mse2psnr = lambda x: -10.0 * torch.log(x) / torch.log(torch.Tensor([10.0]).to(x.device)) 61 | 62 | 63 | def load_ckpt(network, ckpt_file, key_prefix, strict=True): 64 | ckpt_dict = torch.load(ckpt_file) 65 | 66 | if "state_dict" in ckpt_dict.keys(): 67 | ckpt_dict = ckpt_dict["state_dict"] 68 | 69 | state_dict = {} 70 | for key, val in ckpt_dict.items(): 71 | if key_prefix in key: 72 | state_dict[key[len(key_prefix) + 1 :]] = val 73 | network.load_state_dict(state_dict, strict) 74 | 75 | 76 | def init_log(log, keys): 77 | for key in keys: 78 | log[key] = torch.tensor([0.0], dtype=float) 79 | return log 80 | 81 | 82 | class SL1Loss(nn.Module): 83 | def __init__(self, levels=3): 84 | super(SL1Loss, self).__init__() 85 | self.levels = levels 86 | self.loss = nn.SmoothL1Loss(reduction="mean") 87 | self.loss_ray = nn.SmoothL1Loss(reduction="none") 88 | 89 | def forward(self, inputs, targets): 90 | loss = 0 91 | if isinstance(inputs, dict): 92 | for l in range(self.levels): 93 | depth_pred_l = inputs[f"level_{l}"] 94 | V = depth_pred_l.shape[1] 95 | 96 | depth_gt_l = targets[f"level_{l}"] 97 | depth_gt_l = depth_gt_l[:, :V] 98 | mask_l = depth_gt_l > 0 99 | 100 | loss = loss + self.loss( 101 | depth_pred_l[mask_l], depth_gt_l[mask_l] 102 | ) * 2 ** (1 - l) 103 | else: 104 | mask = targets > 0 105 | loss = loss + (self.loss_ray(inputs, targets) * mask).sum() / len(mask) 106 | 107 | return loss 108 | 109 | 110 | def self_supervision_loss( 111 | loss_fn, 112 | rays_pixs, 113 | rendered_depth, 114 | depth_map, 115 | rays_gt_rgb, 116 | unpre_imgs, 117 | rendered_rgb, 118 | intrinsics, 119 | c2ws, 120 | w2cs, 121 | ): 122 | loss = 0 123 | target_points = torch.stack( 124 | [rays_pixs[1], rays_pixs[0], torch.ones(rays_pixs[0].shape[0]).cuda()], dim=-1 125 | ) 126 | target_points = rendered_depth.view(-1, 1) * ( 127 | target_points @ torch.inverse(intrinsics[0, -1]).t() 128 | ) 129 | target_points = target_points @ c2ws[0, -1][:3, :3].t() + c2ws[0, -1][:3, 3] 130 | 131 | rgb_mask = (rendered_rgb - rays_gt_rgb).abs().mean(dim=-1) < 0.02 132 | 133 | for v in range(len(w2cs[0]) - 1): 134 | points_v = target_points @ w2cs[0, v][:3, :3].t() + w2cs[0, v][:3, 3] 135 | points_v = points_v @ intrinsics[0, v].t() 136 | z_pred = points_v[:, -1].clone() 137 | points_v = points_v[:, :2] / points_v[:, -1:] 138 | 139 | points_unit = points_v.clone() 140 | H, W = depth_map["level_0"].shape[-2:] 141 | points_unit[:, 0] = points_unit[:, 0] / W 142 | points_unit[:, 1] = points_unit[:, 1] / H 143 | grid = 2 * points_unit - 1 144 | 145 | warped_rgbs = F.grid_sample( 146 | unpre_imgs[:, v], 147 | grid.view(1, -1, 1, 2), 148 | align_corners=True, 149 | mode="bilinear", 150 | padding_mode="zeros", 151 | ).squeeze() 152 | photo_mask = (warped_rgbs.t() - rays_gt_rgb).abs().mean(dim=-1) < 0.02 153 | 154 | pixel_coor = points_v.round().long() 155 | k = 5 156 | pixel_coor[:, 0] = pixel_coor[:, 0].clip(k // 2, W - (k // 2) - 1) 157 | pixel_coor[:, 1] = pixel_coor[:, 1].clip(2, H - (k // 2) - 1) 158 | lower_b = pixel_coor - (k // 2) 159 | higher_b = pixel_coor + (k // 2) 160 | 161 | ind_h = ( 162 | lower_b[:, 1:] * torch.arange(k - 1, -1, -1).view(1, -1).cuda() 163 | + higher_b[:, 1:] * torch.arange(0, 
k).view(1, -1).cuda() 164 | ) // (k - 1) 165 | ind_w = ( 166 | lower_b[:, 0:1] * torch.arange(k - 1, -1, -1).view(1, -1).cuda() 167 | + higher_b[:, 0:1] * torch.arange(0, k).view(1, -1).cuda() 168 | ) // (k - 1) 169 | 170 | patches_h = torch.gather( 171 | unpre_imgs[:, v].mean(dim=1).expand(ind_h.shape[0], -1, -1), 172 | 1, 173 | ind_h.unsqueeze(-1).expand(-1, -1, W), 174 | ) 175 | patches = torch.gather(patches_h, 2, ind_w.unsqueeze(1).expand(-1, k, -1)) 176 | ent_mask = patches.view(-1, k * k).std(dim=-1) > 0.05 177 | 178 | for l in range(3): 179 | depth = F.grid_sample( 180 | depth_map[f"level_{l}"][:, v : v + 1], 181 | grid.view(1, -1, 1, 2), 182 | align_corners=True, 183 | mode="bilinear", 184 | padding_mode="zeros", 185 | ).squeeze() 186 | in_mask = (grid > -1.0) * (grid < 1.0) 187 | in_mask = (in_mask[..., 0] * in_mask[..., 1]).float() 188 | loss = loss + loss_fn( 189 | depth, z_pred * in_mask * photo_mask * ent_mask * rgb_mask 190 | ) * 2 ** (1 - l) 191 | loss = loss / (len(w2cs[0]) - 1) 192 | 193 | return loss 194 | 195 | 196 | def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET): 197 | if type(depth) is not np.ndarray: 198 | depth = depth.cpu().numpy() 199 | 200 | x = np.nan_to_num(depth) # change nan to 0 201 | if minmax is None: 202 | mi = np.min(x[x > 0]) # get minimum positive depth (ignore background) 203 | ma = np.max(x) 204 | else: 205 | mi, ma = minmax 206 | 207 | x = (x - mi) / (ma - mi + 1e-8) # normalize to 0~1 208 | x = (255 * x).astype(np.uint8) 209 | x_ = Image.fromarray(cv2.applyColorMap(x, cmap)) 210 | x_ = T.ToTensor()(x_) # (3, H, W) 211 | return x_, [mi, ma] 212 | 213 | 214 | def abs_error(depth_pred, depth_gt, mask): 215 | depth_pred, depth_gt = depth_pred[mask], depth_gt[mask] 216 | err = depth_pred - depth_gt 217 | return np.abs(err) if type(depth_pred) is np.ndarray else err.abs() 218 | 219 | 220 | def acc_threshold(depth_pred, depth_gt, mask, threshold): 221 | errors = abs_error(depth_pred, depth_gt, mask) 222 | acc_mask = errors < threshold 223 | return ( 224 | acc_mask.astype("float") if type(depth_pred) is np.ndarray else acc_mask.float() 225 | ) 226 | 227 | 228 | # Ray helpers 229 | def get_rays( 230 | H, 231 | W, 232 | intrinsics_target, 233 | c2w_target, 234 | chunk=-1, 235 | chunk_id=-1, 236 | train=True, 237 | train_batch_size=-1, 238 | mask=None, 239 | ): 240 | if train: 241 | if mask is None: 242 | xs, ys = ( 243 | torch.randint(0, W, (train_batch_size,)).float().cuda(), 244 | torch.randint(0, H, (train_batch_size,)).float().cuda(), 245 | ) 246 | else: # Sample 8 times more points to get mask points as much as possible 247 | xs, ys = ( 248 | torch.randint(0, W, (8 * train_batch_size,)).float().cuda(), 249 | torch.randint(0, H, (8 * train_batch_size,)).float().cuda(), 250 | ) 251 | masked_points = mask[ys.long(), xs.long()] 252 | xs_, ys_ = xs[~masked_points], ys[~masked_points] 253 | xs, ys = xs[masked_points], ys[masked_points] 254 | xs, ys = torch.cat([xs, xs_]), torch.cat([ys, ys_]) 255 | xs, ys = xs[:train_batch_size], ys[:train_batch_size] 256 | else: 257 | ys, xs = torch.meshgrid( 258 | torch.linspace(0, H - 1, H), torch.linspace(0, W - 1, W) 259 | ) # pytorch's meshgrid has indexing='ij' 260 | ys, xs = ys.cuda().reshape(-1), xs.cuda().reshape(-1) 261 | if chunk > 0: 262 | ys, xs = ( 263 | ys[chunk_id * chunk : (chunk_id + 1) * chunk], 264 | xs[chunk_id * chunk : (chunk_id + 1) * chunk], 265 | ) 266 | 267 | dirs = torch.stack( 268 | [ 269 | (xs - intrinsics_target[0, 2]) / intrinsics_target[0, 0], 270 | (ys - intrinsics_target[1, 
2]) / intrinsics_target[1, 1], 271 | torch.ones_like(xs), 272 | ], 273 | -1, 274 | ) # use 1 instead of -1 275 | 276 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 277 | rays_dir = ( 278 | dirs @ c2w_target[:3, :3].t() 279 | ) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 280 | rays_orig = c2w_target[:3, -1].clone().reshape(1, 3).expand(rays_dir.shape[0], -1) 281 | 282 | rays_pixs = torch.stack((ys, xs)) # row col 283 | 284 | return rays_orig, rays_dir, rays_pixs 285 | 286 | 287 | def conver_to_ndc(ray_pts, w2c_ref, intrinsics_ref, W_H, depth_values): 288 | nb_rays, nb_samples = ray_pts.shape[:2] 289 | ray_pts = ray_pts.reshape(-1, 3) 290 | 291 | R = w2c_ref[:3, :3] # (3, 3) 292 | T = w2c_ref[:3, 3:] # (3, 1) 293 | ray_pts = torch.matmul(ray_pts, R.t()) + T.reshape(1, 3) 294 | 295 | ray_pts_ndc = ray_pts @ intrinsics_ref.t() 296 | ray_pts_ndc[:, :2] = ray_pts_ndc[:, :2] / ( 297 | ray_pts_ndc[:, -1:] * W_H.reshape(1, 2) 298 | ) # normalize x,y to 0~1 299 | 300 | grid = ray_pts_ndc[None, None, :, :2] * 2 - 1 301 | near = F.grid_sample( 302 | depth_values[:, :1], 303 | grid, 304 | align_corners=True, 305 | mode="bilinear", 306 | padding_mode="border", 307 | ).squeeze() 308 | far = F.grid_sample( 309 | depth_values[:, -1:], 310 | grid, 311 | align_corners=True, 312 | mode="bilinear", 313 | padding_mode="border", 314 | ).squeeze() 315 | ray_pts_ndc[:, 2] = (ray_pts_ndc[:, 2] - near) / (far - near) # normalize z to 0~1 316 | 317 | ray_pts_ndc = ray_pts_ndc.view(nb_rays, nb_samples, 3) 318 | 319 | return ray_pts_ndc 320 | 321 | 322 | def get_sample_points( 323 | nb_coarse, 324 | nb_fine, 325 | near, 326 | far, 327 | rays_o, 328 | rays_d, 329 | nb_views, 330 | w2cs, 331 | intrinsics, 332 | depth_values, 333 | W_H, 334 | with_noise=False, 335 | ): 336 | device = rays_o.device 337 | nb_rays = rays_o.shape[0] 338 | 339 | with torch.no_grad(): 340 | t_vals = torch.linspace(0.0, 1.0, steps=nb_coarse).view(1, nb_coarse).to(device) 341 | pts_depth = near * (1.0 - t_vals) + far * (t_vals) 342 | pts_depth = pts_depth.expand([nb_rays, nb_coarse]) 343 | ray_pts = rays_o.unsqueeze(1) + pts_depth.unsqueeze(-1) * rays_d.unsqueeze(1) 344 | 345 | ## Counting the number of source views for which the points are valid 346 | valid_points = torch.zeros([nb_rays, nb_coarse]).to(device) 347 | for idx in range(nb_views): 348 | w2c_ref, intrinsic_ref = w2cs[0, idx], intrinsics[0, idx] 349 | ray_pts_ndc = conver_to_ndc( 350 | ray_pts, 351 | w2c_ref, 352 | intrinsic_ref, 353 | W_H, 354 | depth_values=depth_values[f"level_0"][:, idx], 355 | ) 356 | valid_points += ( 357 | ((ray_pts_ndc >= 0) & (ray_pts_ndc <= 1)).sum(dim=-1) == 3 358 | ).float() 359 | 360 | ## Creating a distribution based on the counted values and sample more points 361 | if nb_fine > 0: 362 | point_distr = torch.distributions.categorical.Categorical( 363 | logits=valid_points 364 | ) 365 | t_vals = ( 366 | point_distr.sample([nb_fine]).t() 367 | - torch.rand([nb_rays, nb_fine]).cuda() 368 | ) / (nb_coarse - 1) 369 | pts_depth_fine = near * (1.0 - t_vals) + far * (t_vals) 370 | 371 | pts_depth = torch.cat([pts_depth, pts_depth_fine], dim=-1) 372 | pts_depth, _ = torch.sort(pts_depth) 373 | 374 | if with_noise: ## Add noise to sample points during training 375 | # get intervals between samples 376 | mids = 0.5 * (pts_depth[..., 1:] + pts_depth[..., :-1]) 377 | upper = torch.cat([mids, pts_depth[..., -1:]], -1) 378 | lower = torch.cat([pts_depth[..., :1], mids], -1) 379 | # stratified samples in those 
intervals 380 | t_rand = torch.rand(pts_depth.shape, device=device) 381 | pts_depth = lower + (upper - lower) * t_rand 382 | 383 | ray_pts = rays_o.unsqueeze(1) + pts_depth.unsqueeze(-1) * rays_d.unsqueeze(1) 384 | 385 | ray_pts_ndc = {"level_0": [], "level_1": [], "level_2": []} 386 | for idx in range(nb_views): 387 | w2c_ref, intrinsic_ref = w2cs[0, idx], intrinsics[0, idx] 388 | for l in range(3): 389 | ray_pts_ndc[f"level_{l}"].append( 390 | conver_to_ndc( 391 | ray_pts, 392 | w2c_ref, 393 | intrinsic_ref, 394 | W_H, 395 | depth_values=depth_values[f"level_{l}"][:, idx], 396 | ) 397 | ) 398 | for l in range(3): 399 | ray_pts_ndc[f"level_{l}"] = torch.stack(ray_pts_ndc[f"level_{l}"], dim=2) 400 | 401 | return pts_depth, ray_pts, ray_pts_ndc 402 | 403 | 404 | def get_rays_pts( 405 | H, 406 | W, 407 | c2ws, 408 | w2cs, 409 | intrinsics, 410 | near_fars, 411 | depth_values, 412 | nb_coarse, 413 | nb_fine, 414 | nb_views, 415 | chunk=-1, 416 | chunk_idx=-1, 417 | train=False, 418 | train_batch_size=-1, 419 | target_img=None, 420 | target_depth=None, 421 | ): 422 | if train: 423 | if target_depth.sum() > 0: 424 | depth_mask = target_depth > 0 425 | else: 426 | depth_mask = None 427 | else: 428 | depth_mask = None 429 | 430 | rays_orig, rays_dir, rays_pixs = get_rays( 431 | H, 432 | W, 433 | intrinsics[0, -1], 434 | c2ws[0, -1], 435 | chunk=chunk, 436 | chunk_id=chunk_idx, 437 | train=train, 438 | train_batch_size=train_batch_size, 439 | mask=depth_mask, 440 | ) 441 | 442 | ## Extracting ground truth color and depth of target view 443 | if train: 444 | rays_pixs_int = rays_pixs.long() 445 | rays_gt_rgb = target_img[:, rays_pixs_int[0], rays_pixs_int[1]].permute(1, 0) 446 | rays_gt_depth = target_depth[rays_pixs_int[0], rays_pixs_int[1]] 447 | else: 448 | rays_gt_rgb = None 449 | rays_gt_depth = None 450 | 451 | # travel along the rays 452 | near, far = near_fars[0, -1, 0], near_fars[0, -1, 1] ## near/far of the target view 453 | W_H = torch.tensor([W - 1, H - 1]).cuda() 454 | pts_depth, ray_pts, ray_pts_ndc = get_sample_points( 455 | nb_coarse, 456 | nb_fine, 457 | near, 458 | far, 459 | rays_orig, 460 | rays_dir, 461 | nb_views, 462 | w2cs, 463 | intrinsics, 464 | depth_values, 465 | W_H, 466 | with_noise=train, 467 | ) 468 | 469 | return ( 470 | pts_depth, 471 | ray_pts, 472 | ray_pts_ndc, 473 | rays_dir, 474 | rays_gt_rgb, 475 | rays_gt_depth, 476 | rays_pixs, 477 | ) 478 | 479 | 480 | def normal_vect(vect, dim=-1): 481 | return vect / (torch.sqrt(torch.sum(vect**2, dim=dim, keepdim=True)) + 1e-7) 482 | 483 | 484 | def interpolate_3D(feats, pts_ndc): 485 | H, W = pts_ndc.shape[-3:-1] 486 | grid = pts_ndc.view(-1, 1, H, W, 3) * 2 - 1.0 # [1 1 H W 3] (x,y,z) 487 | features = ( 488 | F.grid_sample( 489 | feats, grid, align_corners=True, mode="bilinear", padding_mode="border" 490 | )[:, :, 0] 491 | .permute(2, 3, 0, 1) 492 | .squeeze() 493 | ) 494 | 495 | return features 496 | 497 | 498 | def interpolate_2D(feats, imgs, pts_ndc): 499 | H, W = pts_ndc.shape[-3:-1] 500 | grid = pts_ndc[..., :2].view(-1, H, W, 2) * 2 - 1.0 # [1 H W 2] (x,y) 501 | features = ( 502 | F.grid_sample( 503 | feats, grid, align_corners=True, mode="bilinear", padding_mode="border" 504 | ) 505 | .permute(2, 3, 1, 0) 506 | .squeeze() 507 | ) 508 | images = ( 509 | F.grid_sample( 510 | imgs, grid, align_corners=True, mode="bilinear", padding_mode="border" 511 | ) 512 | .permute(2, 3, 1, 0) 513 | .squeeze() 514 | ) 515 | with torch.no_grad(): 516 | in_mask = (grid > -1.0) * (grid < 1.0) 517 | in_mask = (in_mask[..., 0] * 
in_mask[..., 1]).float().permute(1, 2, 0) 518 | 519 | return features, images, in_mask 520 | 521 | 522 | def read_pfm(filename): 523 | file = open(filename, "rb") 524 | color = None 525 | width = None 526 | height = None 527 | scale = None 528 | endian = None 529 | 530 | header = file.readline().decode("utf-8").rstrip() 531 | if header == "PF": 532 | color = True 533 | elif header == "Pf": 534 | color = False 535 | else: 536 | raise Exception("Not a PFM file.") 537 | 538 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("utf-8")) 539 | if dim_match: 540 | width, height = map(int, dim_match.groups()) 541 | else: 542 | raise Exception("Malformed PFM header.") 543 | 544 | scale = float(file.readline().rstrip()) 545 | if scale < 0: # little-endian 546 | endian = "<" 547 | scale = -scale 548 | else: 549 | endian = ">" # big-endian 550 | 551 | data = np.fromfile(file, endian + "f") 552 | shape = (height, width, 3) if color else (height, width) 553 | 554 | data = np.reshape(data, shape) 555 | data = np.flipud(data) 556 | file.close() 557 | return data, scale 558 | 559 | 560 | def homo_warp(src_feat, proj_mat, depth_values, src_grid=None, pad=0): 561 | if src_grid == None: 562 | B, C, H, W = src_feat.shape 563 | device = src_feat.device 564 | 565 | if pad > 0: 566 | H_pad, W_pad = H + pad * 2, W + pad * 2 567 | else: 568 | H_pad, W_pad = H, W 569 | 570 | if depth_values.dim() != 4: 571 | depth_values = depth_values[..., None, None].repeat(1, 1, H_pad, W_pad) 572 | D = depth_values.shape[1] 573 | 574 | R = proj_mat[:, :, :3] # (B, 3, 3) 575 | T = proj_mat[:, :, 3:] # (B, 3, 1) 576 | # create grid from the ref frame 577 | ref_grid = create_meshgrid( 578 | H_pad, W_pad, normalized_coordinates=False, device=device 579 | ) # (1, H, W, 2) 580 | if pad > 0: 581 | ref_grid -= pad 582 | 583 | ref_grid = ref_grid.permute(0, 3, 1, 2) # (1, 2, H, W) 584 | ref_grid = ref_grid.reshape(1, 2, W_pad * H_pad) # (1, 2, H*W) 585 | ref_grid = ref_grid.expand(B, -1, -1) # (B, 2, H*W) 586 | ref_grid = torch.cat( 587 | (ref_grid, torch.ones_like(ref_grid[:, :1])), 1 588 | ) # (B, 3, H*W) 589 | ref_grid_d = ref_grid.repeat(1, 1, D) # (B, 3, D*H*W) 590 | src_grid_d = R @ ref_grid_d + T / depth_values.reshape(B, 1, D * W_pad * H_pad) 591 | del ref_grid_d, ref_grid, proj_mat, R, T, depth_values # release (GPU) memory 592 | 593 | src_grid = ( 594 | src_grid_d[:, :2] / src_grid_d[:, 2:] 595 | ) # divide by depth (B, 2, D*H*W) 596 | del src_grid_d 597 | src_grid[:, 0] = src_grid[:, 0] / ((W - 1) / 2) - 1 # scale to -1~1 598 | src_grid[:, 1] = src_grid[:, 1] / ((H - 1) / 2) - 1 # scale to -1~1 599 | src_grid = src_grid.permute(0, 2, 1) # (B, D*H*W, 2) 600 | src_grid = src_grid.view(B, D, W_pad, H_pad, 2) 601 | 602 | B, D, W_pad, H_pad = src_grid.shape[:4] 603 | warped_src_feat = F.grid_sample( 604 | src_feat, 605 | src_grid.view(B, D, W_pad * H_pad, 2), 606 | mode="bilinear", 607 | padding_mode="zeros", 608 | align_corners=True, 609 | ) # (B, C, D, H*W) 610 | warped_src_feat = warped_src_feat.view(B, -1, D, H_pad, W_pad) 611 | # src_grid = src_grid.view(B, 1, D, H_pad, W_pad, 2) 612 | return warped_src_feat, src_grid 613 | 614 | ##### Functions for view selection 615 | TINY_NUMBER = 1e-5 # float32 only has 7 decimal digits precision 616 | 617 | def angular_dist_between_2_vectors(vec1, vec2): 618 | vec1_unit = vec1 / (np.linalg.norm(vec1, axis=1, keepdims=True) + TINY_NUMBER) 619 | vec2_unit = vec2 / (np.linalg.norm(vec2, axis=1, keepdims=True) + TINY_NUMBER) 620 | angular_dists = np.arccos( 621 | 
np.clip(np.sum(vec1_unit * vec2_unit, axis=-1), -1.0, 1.0) 622 | ) 623 | return angular_dists 624 | 625 | 626 | def batched_angular_dist_rot_matrix(R1, R2): 627 | assert ( 628 | R1.shape[-1] == 3 629 | and R2.shape[-1] == 3 630 | and R1.shape[-2] == 3 631 | and R2.shape[-2] == 3 632 | ) 633 | return np.arccos( 634 | np.clip( 635 | (np.trace(np.matmul(R2.transpose(0, 2, 1), R1), axis1=1, axis2=2) - 1) 636 | / 2.0, 637 | a_min=-1 + TINY_NUMBER, 638 | a_max=1 - TINY_NUMBER, 639 | ) 640 | ) 641 | 642 | 643 | def get_nearest_pose_ids( 644 | tar_pose, 645 | ref_poses, 646 | num_select, 647 | tar_id=-1, 648 | angular_dist_method="dist", 649 | scene_center=(0, 0, 0), 650 | ): 651 | num_cams = len(ref_poses) 652 | num_select = min(num_select, num_cams - 1) 653 | batched_tar_pose = tar_pose[None, ...].repeat(num_cams, 0) 654 | 655 | if angular_dist_method == "matrix": 656 | dists = batched_angular_dist_rot_matrix( 657 | batched_tar_pose[:, :3, :3], ref_poses[:, :3, :3] 658 | ) 659 | elif angular_dist_method == "vector": 660 | tar_cam_locs = batched_tar_pose[:, :3, 3] 661 | ref_cam_locs = ref_poses[:, :3, 3] 662 | scene_center = np.array(scene_center)[None, ...] 663 | tar_vectors = tar_cam_locs - scene_center 664 | ref_vectors = ref_cam_locs - scene_center 665 | dists = angular_dist_between_2_vectors(tar_vectors, ref_vectors) 666 | elif angular_dist_method == "dist": 667 | tar_cam_locs = batched_tar_pose[:, :3, 3] 668 | ref_cam_locs = ref_poses[:, :3, 3] 669 | dists = np.linalg.norm(tar_cam_locs - ref_cam_locs, axis=1) 670 | else: 671 | raise Exception("unknown angular distance calculation method!") 672 | 673 | if tar_id >= 0: 674 | assert tar_id < num_cams 675 | dists[tar_id] = 1e3 # make sure not to select the target id itself 676 | 677 | sorted_ids = np.argsort(dists) 678 | selected_ids = sorted_ids[:num_select] 679 | 680 | return selected_ids --------------------------------------------------------------------------------
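The functions under "Functions for view selection" at the end of utils/utils.py rank candidate source cameras relative to a target pose. The sketch below exercises get_nearest_pose_ids in its "dist" mode (ranking by camera-center distance); the five poses are made-up 4x4 camera-to-world matrices, which is the layout the [:, :3, :3] and [:, :3, 3] slicing above assumes, and the numbers are purely illustrative.

import numpy as np

from utils.utils import get_nearest_pose_ids

# Five dummy camera-to-world poses spread along the x axis.
ref_poses = np.stack([np.eye(4) for _ in range(5)])
ref_poses[:, 0, 3] = np.arange(5)

tar_pose = ref_poses[2]  # treat view 2 as the render target

# Three nearest source views by camera-center distance;
# tar_id keeps the target itself from being selected.
ids = get_nearest_pose_ids(
    tar_pose, ref_poses, num_select=3, tar_id=2, angular_dist_method="dist"
)
print(ids)  # the two neighbours at distance 1, then one camera at distance 2

The "matrix" and "vector" modes instead rank by relative rotation angle between poses and by the angle between viewing directions towards scene_center, respectively.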