├── LICENSE
├── README.md
├── configs
│   ├── config_dtu.txt
│   ├── config_general.txt
│   ├── config_llff.txt
│   ├── config_nerf.txt
│   └── lists
│       ├── dtu_pairs.txt
│       ├── dtu_pairs_ft.txt
│       ├── dtu_pairs_val.txt
│       ├── dtu_train_all.txt
│       └── dtu_val_all.txt
├── data
│   ├── __init__.py
│   ├── dtu.py
│   ├── get_datasets.py
│   ├── llff.py
│   └── nerf.py
├── model
│   ├── __init__.py
│   ├── geo_reasoner.py
│   └── self_attn_renderer.py
├── pretrained_weights
│   └── .gitignore
├── requirements.txt
├── run_geo_nerf.py
└── utils
    ├── __init__.py
    ├── options.py
    ├── rendering.py
    └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 | <one line to give the program's name and a brief idea of what it does.>
635 | Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | > # [CVPR 2022] GeoNeRF: Generalizing NeRF with Geometry Priors
2 | > Mohammad Mahdi Johari, Yann Lepoittevin, François Fleuret
3 | > [Project Page](https://www.idiap.ch/paper/geonerf/) | [Paper](https://arxiv.org/abs/2111.13539)
4 |
5 | This repository contains a PyTorch Lightning implementation of our paper, GeoNeRF: Generalizing NeRF with Geometry Priors.
6 |
7 | ## Installation
8 |
9 | #### Tested on NVIDIA Tesla V100 and GeForce RTX 3090 GPUs with PyTorch 1.9 and PyTorch Lightning 1.3.7
10 |
11 | To install the dependencies, in addition to PyTorch, run:
12 |
13 | ```
14 | pip install -r requirements.txt
15 | ```
16 |
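Note that PyTorch itself is not installed by the command above and must be set up separately. As a rough example following the tested versions listed above (the exact build depends on your CUDA setup):

```
pip install torch==1.9.0
```
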
17 | ## Evaluation and Training
18 | To reproduce our results, download the pretrained weights from [here](https://drive.google.com/drive/folders/1ZtAc7VYvltcdodT_BrUrQ_4IAhz_L-Rf?usp=sharing) and put them in the [pretrained_weights](./pretrained_weights) folder. Then, follow the instructions for each of the [LLFF (Real Forward-Facing)](#llff-real-forward-facing-dataset), [NeRF (Realistic Synthetic)](#nerf-realistic-synthetic-dataset), and [DTU](#dtu-dataset) datasets.
19 |
20 | ## LLFF (Real Forward-Facing) Dataset
21 | Download `nerf_llff_data.zip` from [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) and set its path as `llff_path` in the [config_llff.txt](./configs/config_llff.txt) file.
22 |
23 | For evaluating our generalizable model (the `pretrained.ckpt` model in the [pretrained_weights](./pretrained_weights) folder), set `scene` appropriately (e.g. fern) and the number of source views to 9 (nb_views = 9) in the [config_llff.txt](./configs/config_llff.txt) file, then run the following command:
24 |
25 | ```
26 | python run_geo_nerf.py --config configs/config_llff.txt --eval
27 | ```
28 |
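For reference, the relevant entries in [config_llff.txt](./configs/config_llff.txt) for this example are:

```
scene = fern
nb_views = 9
```
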
29 | For fine-tuning on a specific scene, set nb_views = 7 and run the following command:
30 |
31 | ```
32 | python run_geo_nerf.py --config configs/config_llff.txt
33 | ```
34 |
35 | Once fine-tuning is finished, run the evaluation command with nb_views = 9 to get the final rendered results.
36 |
37 | ## NeRF (Realistic Synthetic) Dataset
38 | Download `nerf_synthetic.zip` from [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) and set its path as `nerf_path` in the [config_nerf.txt](configs/config_nerf.txt) file.
39 |
40 | For evaluating our generalizable model (the `pretrained.ckpt` model in the [pretrained_weights](./pretrained_weights) folder), set `scene` appropriately (e.g. lego) and the number of source views to 9 (nb_views = 9) in the [config_nerf.txt](configs/config_nerf.txt) file, then run the following command:
41 |
42 | ```
43 | python run_geo_nerf.py --config configs/config_nerf.txt --eval
44 | ```
45 |
46 | For fine-tuning on a specific scene, set nb_views = 7 and run the following command:
47 |
48 | ```
49 | python run_geo_nerf.py --config configs/config_nerf.txt
50 | ```
51 |
52 | Once fine-tuning is finished, run the evaluation command with nb_views = 9 to get the final rendered results.
53 |
54 | ## DTU Dataset
55 | Download the preprocessed [DTU training data](https://drive.google.com/file/d/1eDjh-_bxKKnEuz5h-HXS7EDJn59clx6V/view)
56 | and replace its `Depths` directory with [Depth_raw](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/cascade-stereo/CasMVSNet/dtu_data/dtu_train_hr/Depths_raw.zip) from the original [MVSNet repository](https://github.com/YoYo000/MVSNet), then set `dtu_pre_path` to this dataset in the [config_dtu.txt](configs/config_dtu.txt) file.
57 |
58 | Then, download the original `Rectified` images from [DTU Website](https://roboimagedata.compute.dtu.dk/?page_id=36), and set `dtu_path` in the [config_dtu.txt](configs/config_dtu.txt) file accordingly.
59 |
60 | For evaluating our generalizable model (the `pretrained.ckpt` model in the [pretrained_weights](./pretrained_weights) folder), set `scene` appropriately (e.g. scan21) and the number of source views to 9 (nb_views = 9) in the [config_dtu.txt](./configs/config_dtu.txt) file, then run the following command:
61 |
62 | ```
63 | python run_geo_nerf.py --config configs/config_dtu.txt --eval
64 | ```
65 |
66 | For fine-tuning on a specific scene, use the same nb_views = 9 and run the following command:
67 |
68 | ```
69 | python run_geo_nerf.py --config configs/config_dtu.txt
70 | ```
71 |
72 | Once fine-tuning is finished, run the evaluation command with nb_views = 9 to get the final rendered results.
73 |
74 | ### RGBD Compatible model
75 | By adding the `--use_depth` argument to the aforementioned commands, you can use our RGBD compatible model on the DTU dataset and exploit the ground-truth, low-resolution depths to help the rendering process. The pretrained weights for this model are `pretrained_w_depth.ckpt`.
76 |
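For instance, the DTU evaluation command above then becomes:

```
python run_geo_nerf.py --config configs/config_dtu.txt --eval --use_depth
```
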
77 | ## Training From Scratch
78 | For training our model from scratch, first, prepare the following datasets:
79 |
80 | * The original `Rectified` images from [DTU](https://roboimagedata.compute.dtu.dk/?page_id=36). Set the corresponding path as `dtu_path` in the [config_general.txt](configs/config_general.txt) file.
81 |
82 | * The preprocessed [DTU training data](https://drive.google.com/file/d/1eDjh-_bxKKnEuz5h-HXS7EDJn59clx6V/view)
83 | with the replacement of its `Depths` directory with [Depth_raw](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/cascade-stereo/CasMVSNet/dtu_data/dtu_train_hr/Depths_raw.zip). Set the corresponding path as `dtu_pre_path` in the [config_general.txt](configs/config_general.txt) file.
84 |
85 | * LLFF released scenes. Download [real_iconic_noface.zip](https://drive.google.com/drive/folders/1M-_Fdn4ajDa0CS8-iqejv0fQQeuonpKF) and remove the test scenes with the following command:
86 | ```
87 | unzip real_iconic_noface.zip
88 | cd real_iconic_noface/
89 | rm -rf data2_fernvlsb data2_hugetrike data2_trexsanta data3_orchid data5_leafscene data5_lotr data5_redflower
90 | ```
91 | Then, set the corresponding path as `llff_path` in the [config_general.txt](configs/config_general.txt) file.
92 |
93 | * Collected scenes from [IBRNet](https://github.com/googleinterns/IBRNet) ([Subset1](https://drive.google.com/file/d/1rkzl3ecL3H0Xxf5WTyc2Swv30RIyr1R_/view?usp=sharing) and [Subset2](https://drive.google.com/file/d/1Uxw0neyiIn3Ve8mpRsO6A06KfbqNrWuq/view?usp=sharing)). Set the corresponding paths as `ibrnet1_path` and `ibrnet2_path` in the [config_general.txt](configs/config_general.txt) file.
94 |
95 | Also, download `nerf_llff_data.zip` and `nerf_synthetic.zip` from [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) for validation and testing and set their corresponding paths as `llff_test_path` and `nerf_path` in the [config_general.txt](configs/config_general.txt) file.
96 |
97 | Once all the datasets are available, train the network from scratch with the following command:
98 | ```
99 | python run_geo_nerf.py --config configs/config_general.txt
100 | ```
101 | ### Contact
102 | You can contact the author through email: mohammad.johari At idiap.ch.
103 |
104 | ## Citing
105 | If you find our work useful, please consider citing:
106 | ```BibTeX
107 | @inproceedings{johari-et-al-2022,
108 | author = {Johari, M. and Lepoittevin, Y. and Fleuret, F.},
109 | title = {GeoNeRF: Generalizing NeRF with Geometry Priors},
110 | booktitle = {Proceedings of the IEEE International Conference on Computer Vision and Pattern Recognition (CVPR)},
111 | year = {2022}
112 | }
113 | ```
114 |
115 | ### Acknowledgement
116 | This work was supported by ams OSRAM.
--------------------------------------------------------------------------------
/configs/config_dtu.txt:
--------------------------------------------------------------------------------
1 | ### INPUT
2 | expname = scan21_test
3 | logdir = ./logs
4 | nb_views = 9 #### use 9 for both evaluation and fine-tuning
5 |
6 | ## dataset
7 | dataset_name = dtu
8 | dtu_path = Path to DTU MVS
9 | dtu_pre_path = Path to preprocessed DTU MVS
10 | scene = scan21
11 |
12 | ### TESTING
13 | chunk = 4096 ### Reduce it to save memory
14 |
15 | ### TRAINING
16 | num_steps = 10000
17 | lrate = 0.0002
--------------------------------------------------------------------------------
/configs/config_general.txt:
--------------------------------------------------------------------------------
1 | ### INPUT
2 | expname = Generalizable
3 | logdir = ./logs
4 | nb_views = 6
5 |
6 | ## dataset
7 | dataset_name = llff
8 | dtu_path = Path to DTU MVS
9 | dtu_pre_path = Path to preprocessed DTU MVS
10 | llff_path = Path to LLFF training scenes (real_iconic_noface)
11 | ibrnet1_path = Path to IBRNet dataset 1 (ibrnet_collected_1)
12 | ibrnet2_path = Path to IBRNet dataset 2 (ibrnet_collected_2)
13 | nerf_path = Path to NeRF dataset (nerf_synthetic)
14 | llff_test_path = Path to LLFF test scenes (nerf_llff_data)
15 | scene = None
16 |
17 | ### TESTING
18 | chunk = 4096 ### Reduce it to save memory
19 |
20 | ### TRAINING
21 | num_steps = 250000
22 | lrate = 0.0005
--------------------------------------------------------------------------------
/configs/config_llff.txt:
--------------------------------------------------------------------------------
1 | ### INPUT
2 | expname = fern_test
3 | logdir = ./logs
4 | nb_views = 9 #### Set to 7 for fine-tuning
5 |
6 | ## dataset
7 | dataset_name = llff
8 | llff_path = Path to LLFF test scenes (nerf_llff_data)
9 | scene = fern
10 |
11 | ### TESTING
12 | chunk = 4096 ### Reduce it to save memory
13 |
14 | ### TRAINING
15 | num_steps = 10000
16 | lrate = 0.0002
--------------------------------------------------------------------------------
/configs/config_nerf.txt:
--------------------------------------------------------------------------------
1 | ### INPUT
2 | expname = lego_test
3 | logdir = ./logs
4 | nb_views = 9 #### Set to 7 for fine-tuning
5 |
6 | ## dataset
7 | dataset_name = nerf
8 | nerf_path = Path to NeRF dataset (nerf_synthetic)
9 | scene = lego
10 |
11 | ### TESTING
12 | chunk = 4096 ### Reduce it to save memory
13 |
14 | ### TRAINING
15 | num_steps = 10000
16 | lrate = 0.0002
--------------------------------------------------------------------------------
/configs/lists/dtu_pairs.txt:
--------------------------------------------------------------------------------
1 | 33
2 | 1
3 | 10 9 2850.87 10 2583.94 2 2105.59 0 2052.84 8 1868.24 13 1184.23 14 1017.51 12 961.966 7 670.208 15 657.218
4 | 2
5 | 10 8 2501.24 1 2106.88 7 1856.5 9 1782.34 3 1141.77 15 1061.76 14 815.457 16 762.153 6 709.789 10 699.921
6 | 8
7 | 10 15 3124.01 9 3099.92 14 2756.29 2 2501.22 7 2449.32 1 1875.94 16 1726.04 13 1325.76 23 1177.09 24 1108.82
8 | 9
9 | 10 13 3355.62 14 3226.07 8 3098.8 10 3097.07 1 2861.42 12 1873.63 2 1785.98 15 1753.32 25 1365.45 0 1261.59
10 | 10
11 | 10 12 3750.7 9 3085.87 13 3028.39 1 2590.55 0 2369.79 11 2266.67 14 1524.16 26 1448.15 27 1293.6 8 1041.84
12 | 11
13 | 10 12 3543.76 27 3056.05 10 2248.07 26 1524.28 28 1273.33 13 1265.9 29 1129.55 0 998.164 9 591.176 30 572.919
14 | 12
15 | 10 27 3889.87 10 3754.54 13 3745.21 11 3584.26 26 3574.56 25 1877.11 9 1866.34 29 1482.72 30 1418.51 14 1341.86
16 | 13
17 | 10 12 3773.14 26 3699.28 25 3657.17 14 3652.04 9 3356.29 10 3049.27 24 2098.91 27 1900.96 31 1460.96 30 1349.62
18 | 14
19 | 10 13 3663.52 24 3610.69 9 3232.55 25 3216.4 15 3128.84 8 2758.04 23 2219.91 26 1567.45 10 1536.6 32 1419.33
20 | 15
21 | 10 23 3194.92 14 3126 8 3120.43 16 2897.02 24 2562.49 7 2084.05 22 2041.63 9 1752.08 33 1232.29 13 1137.55
22 | 22
23 | 10 23 3232.68 34 3175.15 35 2831.09 16 2712.51 21 2632.19 15 2033.39 33 1712.67 17 1393.86 36 1290.96 24 1195.33
24 | 23
25 | 10 24 3710.9 33 3603.07 22 3244.2 15 3190.62 34 3086.49 14 2220.11 32 2100 16 1917.1 35 1359.79 25 1356.71
26 | 24
27 | 10 25 3844.6 32 3750.75 23 3710.6 14 3609.09 33 3091.04 15 2559.24 31 2423.71 13 2109.36 26 1440.58 34 1410.03
28 | 25
29 | 10 26 3951.74 31 3888.57 24 3833.07 13 3667.35 14 3208.21 32 2993.46 30 2681.52 12 1900.23 45 1484.03 27 1462.88
30 | 26
31 | 10 30 4033.35 27 3970.47 25 3925.25 13 3686.34 12 3595.59 29 2943.87 31 2917 14 1556.34 11 1554.75 46 1503.84
32 | 27
33 | 10 29 4027.84 26 3929.94 12 3875.58 11 3085.03 28 2908.6 30 2792.67 13 1878.42 25 1438.55 47 1425.2 10 1290.25
34 | 28
35 | 10 29 3687.02 48 3209.13 27 2872.86 47 2014.53 30 1361.95 11 1273.6 26 1062.85 12 840.841 46 672.985 31 271.952
36 | 29
37 | 10 27 4029.43 30 3909.55 28 3739.93 47 3695.23 48 3135.87 26 2910.97 46 2229.55 12 1479.16 31 1430.26 11 1144.56
38 | 30
39 | 10 26 4029.86 29 3953.72 31 3811.12 46 3630.46 47 3105.96 27 2824.43 25 2657.89 45 2347.75 32 1459.11 12 1429.62
40 | 31
41 | 10 25 3882.21 30 3841.88 32 3808.5 45 3649.82 46 3000.67 26 2939.94 24 2409.93 44 2381.3 13 1467.59 29 1459.56
42 | 32
43 | 10 31 3826.5 24 3744.14 33 3613.24 44 3552.04 25 3004.6 45 2884.59 43 2393.34 23 2095.27 30 1478.6 14 1420.78
44 | 33
45 | 10 32 3618.11 23 3598.1 34 3530.53 43 3462.37 24 3091.53 44 2608.08 42 2426 22 1717.94 31 1407.65 25 1324.78
46 | 34
47 | 10 33 3523.37 42 3356.55 35 3210.34 22 3178.85 23 3079.03 43 2396.45 41 2386.86 24 1408.02 32 1301.34 21 1256.45
48 | 35
49 | 10 34 3187.88 41 3106.44 36 2866.04 22 2817.74 21 2654.87 40 2416.98 42 2137.81 23 1346.86 33 1150.33 16 1044.66
50 | 40
51 | 10 36 2918.14 41 2852.62 39 2782.6 35 2392.96 37 1641.45 21 1124.3 42 1056.48 34 877.946 38 853.944 20 788.701
52 | 41
53 | 10 35 3111.05 42 3049.71 40 2885.36 34 2371.02 36 1813.69 43 1164.71 22 1126.9 39 1011.26 21 906.536 33 903.238
54 | 42
55 | 10 34 3356.98 43 3183 41 3070.54 33 2421.77 35 2155.08 44 1278.41 23 1183.52 22 1147.07 40 1077.08 32 899.646
56 | 43
57 | 10 33 3461.24 44 3380.74 42 3188.7 34 2400.6 32 2399.09 45 1359.37 23 1314.08 41 1176.12 24 1159.62 31 901.556
58 | 44
59 | 10 32 3550.81 45 3510.16 43 3373.11 33 2602.33 31 2395.93 24 1410.43 46 1386.31 42 1279 25 1095.24 34 968.44
60 | 45
61 | 10 31 3650.09 46 3555.09 44 3491.15 32 2868.39 30 2373.59 25 1485.37 47 1405.28 43 1349.54 33 1104.77 26 1046.81
62 | 46
63 | 10 30 3635.64 47 3562.17 45 3524.17 31 2976.82 29 2264.04 26 1508.87 44 1367.41 48 1352.1 32 1211.24 25 1102.17
64 | 47
65 | 10 29 3705.31 46 3519.76 48 3450.48 30 3074.77 28 2054.63 27 1434.57 45 1377.34 31 1268.23 26 1223.83 25 471.111
66 | 48
67 | 10 47 3401.95 28 3224.84 29 3101.16 46 1317.1 30 1306.7 27 1235.07 26 537.731 31 291.919 45 276.869 11 258.856
--------------------------------------------------------------------------------
/configs/lists/dtu_pairs_ft.txt:
--------------------------------------------------------------------------------
1 | 29
2 | 1
3 | 10 9 2850.87 10 2583.94 2 2105.59 0 2052.84 8 1868.24 13 1184.23 14 1017.51 12 961.966 7 670.208 15 657.218
4 | 2
5 | 10 8 2501.24 1 2106.88 7 1856.5 9 1782.34 3 1141.77 15 1061.76 14 815.457 16 762.153 6 709.789 10 699.921
6 | 8
7 | 10 15 3124.01 9 3099.92 14 2756.29 2 2501.22 7 2449.32 1 1875.94 16 1726.04 13 1325.76 23 1177.09 24 1108.82
8 | 9
9 | 10 13 3355.62 14 3226.07 8 3098.8 10 3097.07 1 2861.42 12 1873.63 2 1785.98 15 1753.32 25 1365.45 0 1261.59
10 | 10
11 | 10 12 3750.7 9 3085.87 13 3028.39 1 2590.55 0 2369.79 11 2266.67 14 1524.16 26 1448.15 27 1293.6 8 1041.84
12 | 11
13 | 10 12 3543.76 27 3056.05 10 2248.07 26 1524.28 28 1273.33 13 1265.9 29 1129.55 0 998.164 9 591.176 30 572.919
14 | 12
15 | 10 27 3889.87 10 3754.54 13 3745.21 11 3584.26 26 3574.56 25 1877.11 9 1866.34 29 1482.72 30 1418.51 14 1341.86
16 | 13
17 | 10 12 3773.14 26 3699.28 25 3657.17 14 3652.04 9 3356.29 10 3049.27 24 2098.91 27 1900.96 31 1460.96 30 1349.62
18 | 14
19 | 10 13 3663.52 24 3610.69 9 3232.55 25 3216.4 15 3128.84 8 2758.04 23 2219.91 26 1567.45 10 1536.6 32 1419.33
20 | 15
21 | 10 23 3194.92 14 3126 8 3120.43 16 2897.02 24 2562.49 7 2084.05 22 2041.63 9 1752.08 33 1232.29 13 1137.55
22 | 22
23 | 10 23 3232.68 34 3175.15 35 2831.09 16 2712.51 21 2632.19 15 2033.39 33 1712.67 17 1393.86 36 1290.96 24 1195.33
24 | 26
25 | 10 30 4033.35 27 3970.47 25 3925.25 13 3686.34 12 3595.59 29 2943.87 31 2917 14 1556.34 11 1554.75 46 1503.84
26 | 27
27 | 10 29 4027.84 26 3929.94 12 3875.58 11 3085.03 28 2908.6 30 2792.67 13 1878.42 25 1438.55 47 1425.2 10 1290.25
28 | 28
29 | 10 29 3687.02 48 3209.13 27 2872.86 47 2014.53 30 1361.95 11 1273.6 26 1062.85 12 840.841 46 672.985 31 271.952
30 | 29
31 | 10 27 4029.43 30 3909.55 28 3739.93 47 3695.23 48 3135.87 26 2910.97 46 2229.55 12 1479.16 31 1430.26 11 1144.56
32 | 30
33 | 10 26 4029.86 29 3953.72 31 3811.12 46 3630.46 47 3105.96 27 2824.43 25 2657.89 45 2347.75 32 1459.11 12 1429.62
34 | 31
35 | 10 25 3882.21 30 3841.88 32 3808.5 45 3649.82 46 3000.67 26 2939.94 24 2409.93 44 2381.3 13 1467.59 29 1459.56
36 | 33
37 | 10 32 3618.11 23 3598.1 34 3530.53 43 3462.37 24 3091.53 44 2608.08 42 2426 22 1717.94 31 1407.65 25 1324.78
38 | 34
39 | 10 33 3523.37 42 3356.55 35 3210.34 22 3178.85 23 3079.03 43 2396.45 41 2386.86 24 1408.02 32 1301.34 21 1256.45
40 | 35
41 | 10 34 3187.88 41 3106.44 36 2866.04 22 2817.74 21 2654.87 40 2416.98 42 2137.81 23 1346.86 33 1150.33 16 1044.66
42 | 40
43 | 10 36 2918.14 41 2852.62 39 2782.6 35 2392.96 37 1641.45 21 1124.3 42 1056.48 34 877.946 38 853.944 20 788.701
44 | 41
45 | 10 35 3111.05 42 3049.71 40 2885.36 34 2371.02 36 1813.69 43 1164.71 22 1126.9 39 1011.26 21 906.536 33 903.238
46 | 42
47 | 10 34 3356.98 43 3183 41 3070.54 33 2421.77 35 2155.08 44 1278.41 23 1183.52 22 1147.07 40 1077.08 32 899.646
48 | 43
49 | 10 33 3461.24 44 3380.74 42 3188.7 34 2400.6 32 2399.09 45 1359.37 23 1314.08 41 1176.12 24 1159.62 31 901.556
50 | 44
51 | 10 32 3550.81 45 3510.16 43 3373.11 33 2602.33 31 2395.93 24 1410.43 46 1386.31 42 1279 25 1095.24 34 968.44
52 | 45
53 | 10 31 3650.09 46 3555.09 44 3491.15 32 2868.39 30 2373.59 25 1485.37 47 1405.28 43 1349.54 33 1104.77 26 1046.81
54 | 46
55 | 10 30 3635.64 47 3562.17 45 3524.17 31 2976.82 29 2264.04 26 1508.87 44 1367.41 48 1352.1 32 1211.24 25 1102.17
56 | 47
57 | 10 29 3705.31 46 3519.76 48 3450.48 30 3074.77 28 2054.63 27 1434.57 45 1377.34 31 1268.23 26 1223.83 25 471.111
58 | 48
59 | 10 47 3401.95 28 3224.84 29 3101.16 46 1317.1 30 1306.7 27 1235.07 26 537.731 31 291.919 45 276.869 11 258.856
--------------------------------------------------------------------------------
/configs/lists/dtu_pairs_val.txt:
--------------------------------------------------------------------------------
1 | 4
2 | 23
3 | 10 24 3710.9 33 3603.07 22 3244.2 15 3190.62 34 3086.49 14 2220.11 32 2100 16 1917.1 35 1359.79 25 1356.71
4 | 24
5 | 10 25 3844.6 32 3750.75 23 3710.6 14 3609.09 33 3091.04 15 2559.24 31 2423.71 13 2109.36 26 1440.58 34 1410.03
6 | 25
7 | 10 26 3951.74 31 3888.57 24 3833.07 13 3667.35 14 3208.21 32 2993.46 30 2681.52 12 1900.23 45 1484.03 27 1462.88
8 | 32
9 | 10 31 3826.5 24 3744.14 33 3613.24 44 3552.04 25 3004.6 45 2884.59 43 2393.34 23 2095.27 30 1478.6 14 1420.78
--------------------------------------------------------------------------------
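The dtu_pairs*.txt files above follow the MVSNet pair-file convention: the first line gives the number of reference viewpoints, and each viewpoint then occupies two lines, the reference view id followed by a line that starts with the source-view count and alternates view ids with matching scores. A minimal parsing sketch, mirroring the [1::2] slice used in data/dtu.py (the path argument is illustrative):

    def read_pair_file(path):
        """Parse an MVSNet-style pair file into (ref_view, src_views) tuples."""
        pairs = []
        with open(path) as f:
            num_viewpoints = int(f.readline())
            for _ in range(num_viewpoints):
                ref_view = int(f.readline().rstrip())
                tokens = f.readline().rstrip().split()
                # tokens[0] is the source-view count; view ids sit at the odd
                # positions, matching scores at the even positions after it.
                src_views = [int(x) for x in tokens[1::2]]
                pairs.append((ref_view, src_views))
        return pairs

    # read_pair_file("configs/lists/dtu_pairs_val.txt")[0]
    # -> (23, [24, 33, 22, 15, 34, 14, 32, 16, 35, 25])
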
/configs/lists/dtu_train_all.txt:
--------------------------------------------------------------------------------
1 | scan3
2 | scan4
3 | scan5
4 | scan6
5 | scan9
6 | scan10
7 | scan11
8 | scan12
9 | scan13
10 | scan14
11 | scan15
12 | scan16
13 | scan17
14 | scan18
15 | scan19
16 | scan20
17 | scan22
18 | scan23
19 | scan24
20 | scan28
21 | scan32
22 | scan33
23 | scan35
24 | scan36
25 | scan37
26 | scan42
27 | scan43
28 | scan44
29 | scan46
30 | scan47
31 | scan48
32 | scan49
33 | scan50
34 | scan52
35 | scan53
36 | scan59
37 | scan60
38 | scan61
39 | scan62
40 | scan64
41 | scan65
42 | scan66
43 | scan67
44 | scan68
45 | scan69
46 | scan70
47 | scan71
48 | scan72
49 | scan74
50 | scan75
51 | scan76
52 | scan77
53 | scan84
54 | scan85
55 | scan86
56 | scan87
57 | scan88
58 | scan89
59 | scan90
60 | scan91
61 | scan92
62 | scan93
63 | scan94
64 | scan95
65 | scan96
66 | scan97
67 | scan98
68 | scan99
69 | scan100
70 | scan101
71 | scan102
72 | scan104
73 | scan105
74 | scan106
75 | scan107
76 | scan108
77 | scan109
78 | scan118
79 | scan119
80 | scan120
81 | scan121
82 | scan122
83 | scan123
84 | scan124
85 | scan125
86 | scan126
87 | scan127
88 | scan128
--------------------------------------------------------------------------------
/configs/lists/dtu_val_all.txt:
--------------------------------------------------------------------------------
1 | scan8
2 | scan21
3 | scan30
4 | scan31
5 | scan34
6 | scan38
7 | scan40
8 | scan41
9 | scan45
10 | scan55
11 | scan63
12 | scan82
13 | scan103
14 | scan110
15 | scan114
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/idiap/GeoNeRF/e6249fdae5672853c6bbbd4ba380c4c166d02c95/data/__init__.py
--------------------------------------------------------------------------------
/data/dtu.py:
--------------------------------------------------------------------------------
1 | # GeoNeRF is a generalizable NeRF model that renders novel views
2 | # without requiring per-scene optimization. This software is the
3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with
4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
5 | # and Francois Fleuret.
6 |
7 | # Copyright (c) 2022 ams International AG
8 |
9 | # This file is part of GeoNeRF.
10 | # GeoNeRF is free software: you can redistribute it and/or modify
11 | # it under the terms of the GNU General Public License version 3 as
12 | # published by the Free Software Foundation.
13 |
14 | # GeoNeRF is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 |
19 | # You should have received a copy of the GNU General Public License
20 | # along with GeoNeRF. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | # This file incorporates work covered by the following copyright and
23 | # permission notice:
24 |
25 | # MIT License
26 |
27 | # Copyright (c) 2021 apchenstu
28 |
29 | # Permission is hereby granted, free of charge, to any person obtaining a copy
30 | # of this software and associated documentation files (the "Software"), to deal
31 | # in the Software without restriction, including without limitation the rights
32 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
33 | # copies of the Software, and to permit persons to whom the Software is
34 | # furnished to do so, subject to the following conditions:
35 |
36 | # The above copyright notice and this permission notice shall be included in all
37 | # copies or substantial portions of the Software.
38 |
39 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
40 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
41 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
42 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
43 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
44 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
45 | # SOFTWARE.
46 |
47 | from torch.utils.data import Dataset
48 | from torchvision import transforms as T
49 |
50 | import os
51 | import cv2
52 | import numpy as np
53 | from PIL import Image
54 |
55 | from utils.utils import read_pfm, get_nearest_pose_ids
56 |
57 | class DTU_Dataset(Dataset):
58 | def __init__(
59 | self,
60 | original_root_dir,
61 | preprocessed_root_dir,
62 | split,
63 | nb_views,
64 | downSample=1.0,
65 | max_len=-1,
66 | scene="None",
67 | ):
68 | self.original_root_dir = original_root_dir
69 | self.preprocessed_root_dir = preprocessed_root_dir
70 | self.split = split
71 | self.scene = scene
72 |
73 | self.downSample = downSample
74 | self.scale_factor = 1.0 / 200
75 | self.interval_scale = 1.06
76 | self.max_len = max_len
77 | self.nb_views = nb_views
78 |
79 | self.build_metas()
80 | self.build_proj_mats()
81 | self.define_transforms()
82 |
83 | def define_transforms(self):
84 | self.transform = T.Compose(
85 | [
86 | T.ToTensor(),
87 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
88 | ]
89 | )
90 |
91 | def build_metas(self):
92 | self.metas = []
93 | with open(f"configs/lists/dtu_{self.split}_all.txt") as f:
94 | self.scans = [line.rstrip() for line in f.readlines()]
95 | if self.scene != "None":
96 | self.scans = [self.scene]
97 |
98 |         # light conditions 2, 3, and 4 for training (range(2, 5))
99 | # light condition 3 for testing (the brightest?)
100 | light_idxs = (
101 | [3] if "train" != self.split or self.scene != "None" else range(2, 5)
102 | )
103 |
104 | self.id_list = []
105 |
106 | if self.split == "train":
107 | if self.scene == "None":
108 | pair_file = f"configs/lists/dtu_pairs.txt"
109 | else:
110 | pair_file = f"configs/lists/dtu_pairs_ft.txt"
111 | else:
112 | pair_file = f"configs/lists/dtu_pairs_val.txt"
113 |
114 | for scan in self.scans:
115 | with open(pair_file) as f:
116 | num_viewpoint = int(f.readline())
117 | for _ in range(num_viewpoint):
118 | ref_view = int(f.readline().rstrip())
119 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
120 | for light_idx in light_idxs:
121 | self.metas += [(scan, light_idx, ref_view, src_views)]
122 | self.id_list.append([ref_view] + src_views)
123 |
124 | self.id_list = np.unique(self.id_list)
125 | self.build_remap()
126 |
127 | def build_proj_mats(self):
128 | near_fars, intrinsics, world2cams, cam2worlds = [], [], [], []
129 | for vid in self.id_list:
130 | proj_mat_filename = os.path.join(
131 | self.preprocessed_root_dir, f"Cameras/train/{vid:08d}_cam.txt"
132 | )
133 | intrinsic, extrinsic, near_far = self.read_cam_file(proj_mat_filename)
134 | intrinsic[:2] *= 4
135 | extrinsic[:3, 3] *= self.scale_factor
136 |
137 | intrinsic[:2] = intrinsic[:2] * self.downSample
138 | intrinsics += [intrinsic.copy()]
139 |
140 | near_fars += [near_far]
141 | world2cams += [extrinsic]
142 | cam2worlds += [np.linalg.inv(extrinsic)]
143 |
144 | self.near_fars, self.intrinsics = np.stack(near_fars), np.stack(intrinsics)
145 | self.world2cams, self.cam2worlds = np.stack(world2cams), np.stack(cam2worlds)
146 |
147 | def read_cam_file(self, filename):
148 | with open(filename) as f:
149 | lines = [line.rstrip() for line in f.readlines()]
150 | # extrinsics: line [1,5), 4x4 matrix
151 | extrinsics = np.fromstring(" ".join(lines[1:5]), dtype=np.float32, sep=" ")
152 | extrinsics = extrinsics.reshape((4, 4))
153 |         # intrinsics: lines [7, 10), 3x3 matrix
154 | intrinsics = np.fromstring(" ".join(lines[7:10]), dtype=np.float32, sep=" ")
155 | intrinsics = intrinsics.reshape((3, 3))
156 | # depth_min & depth_interval: line 11
157 | depth_min, depth_interval = lines[11].split()
158 | depth_min = float(depth_min) * self.scale_factor
159 | depth_max = depth_min + float(depth_interval) * 192 * self.interval_scale * self.scale_factor
160 |
161 | intrinsics[0, 2] = intrinsics[0, 2] + 80.0 / 4.0
162 | intrinsics[1, 2] = intrinsics[1, 2] + 44.0 / 4.0
163 | intrinsics[:2] = intrinsics[:2]
164 |
165 | return intrinsics, extrinsics, [depth_min, depth_max]
166 |
167 | def read_depth(self, filename, far_bound, noisy_factor=1.0):
168 | depth_h = self.scale_factor * np.array(
169 | read_pfm(filename)[0], dtype=np.float32
170 | )
171 | depth_h = cv2.resize(
172 | depth_h, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_NEAREST
173 | )
174 |
175 | depth_h = cv2.resize(
176 | depth_h,
177 | None,
178 | fx=self.downSample * noisy_factor,
179 | fy=self.downSample * noisy_factor,
180 | interpolation=cv2.INTER_NEAREST,
181 | )
182 |
183 | ## Exclude points beyond the bounds
184 | depth_h[depth_h > far_bound * 0.95] = 0.0
185 |
186 | depth = {}
187 | for l in range(3):
188 | depth[f"level_{l}"] = cv2.resize(
189 | depth_h,
190 | None,
191 | fx=1.0 / (2**l),
192 | fy=1.0 / (2**l),
193 | interpolation=cv2.INTER_NEAREST,
194 | )
195 |
196 | if self.split == "train":
197 | cutout = np.ones_like(depth[f"level_2"])
198 | h0 = int(np.random.randint(0, high=cutout.shape[0] // 5, size=1))
199 | h1 = int(
200 | np.random.randint(
201 | 4 * cutout.shape[0] // 5, high=cutout.shape[0], size=1
202 | )
203 | )
204 | w0 = int(np.random.randint(0, high=cutout.shape[1] // 5, size=1))
205 | w1 = int(
206 | np.random.randint(
207 | 4 * cutout.shape[1] // 5, high=cutout.shape[1], size=1
208 | )
209 | )
210 | cutout[h0:h1, w0:w1] = 0
211 | depth_aug = depth[f"level_2"] * cutout
212 | else:
213 | depth_aug = depth[f"level_2"].copy()
214 |
215 | return depth, depth_h, depth_aug
216 |
217 | def build_remap(self):
218 | self.remap = np.zeros(np.max(self.id_list) + 1).astype("int")
219 | for i, item in enumerate(self.id_list):
220 | self.remap[item] = i
221 |
222 | def __len__(self):
223 | return len(self.metas) if self.max_len <= 0 else self.max_len
224 |
225 | def __getitem__(self, idx):
226 | if self.split == "train" and self.scene == "None":
227 | noisy_factor = float(np.random.choice([1.0, 0.5], 1))
228 | close_views = int(np.random.choice([3, 4, 5], 1))
229 | else:
230 | noisy_factor = 1.0
231 | close_views = 5
232 |
233 | scan, light_idx, target_view, src_views = self.metas[idx]
234 | view_ids = src_views[:self.nb_views] + [target_view]
235 |
236 | affine_mats, affine_mats_inv = [], []
237 | imgs, depths_h, depths_aug = [], [], []
238 | depths = {"level_0": [], "level_1": [], "level_2": []}
239 | intrinsics, w2cs, c2ws, near_fars = [], [], [], []
240 |
241 | for vid in view_ids:
242 | # Note that the id in image file names is from 1 to 49 (not 0~48)
243 | img_filename = os.path.join(
244 | self.original_root_dir,
245 | f"Rectified/{scan}/rect_{vid + 1:03d}_{light_idx}_r5000.png",
246 | )
247 | depth_filename = os.path.join(
248 | self.preprocessed_root_dir, f"Depths/{scan}/depth_map_{vid:04d}.pfm"
249 | )
250 | img = Image.open(img_filename)
251 | img_wh = np.round(
252 | np.array(img.size) / 2.0 * self.downSample * noisy_factor
253 | ).astype("int")
254 | img = img.resize(img_wh, Image.BICUBIC)
255 | img = self.transform(img)
256 | imgs += [img]
257 |
258 | index_mat = self.remap[vid]
259 |
260 | intrinsic = self.intrinsics[index_mat].copy()
261 | intrinsic[:2] = intrinsic[:2] * noisy_factor
262 | intrinsics.append(intrinsic)
263 |
264 | w2c = self.world2cams[index_mat]
265 | w2cs.append(w2c)
266 | c2ws.append(self.cam2worlds[index_mat])
267 |
268 | aff = []
269 | aff_inv = []
270 | for l in range(3):
271 | proj_mat_l = np.eye(4)
272 | intrinsic_temp = intrinsic.copy()
273 | intrinsic_temp[:2] = intrinsic_temp[:2] / (2**l)
274 | proj_mat_l[:3, :4] = intrinsic_temp @ w2c[:3, :4]
275 | aff.append(proj_mat_l.copy())
276 | aff_inv.append(np.linalg.inv(proj_mat_l))
277 | aff = np.stack(aff, axis=-1)
278 | aff_inv = np.stack(aff_inv, axis=-1)
279 |
280 | affine_mats.append(aff)
281 | affine_mats_inv.append(aff_inv)
282 |
283 | near_far = self.near_fars[index_mat]
284 |
285 | depth, depth_h, depth_aug = self.read_depth(
286 | depth_filename, near_far[1], noisy_factor
287 | )
288 |
289 | depths["level_0"].append(depth["level_0"])
290 | depths["level_1"].append(depth["level_1"])
291 | depths["level_2"].append(depth["level_2"])
292 | depths_h.append(depth_h)
293 | depths_aug.append(depth_aug)
294 |
295 | near_fars.append(near_far)
296 |
297 | imgs = np.stack(imgs)
298 | depths_h, depths_aug = np.stack(depths_h), np.stack(depths_aug)
299 | depths["level_0"] = np.stack(depths["level_0"])
300 | depths["level_1"] = np.stack(depths["level_1"])
301 | depths["level_2"] = np.stack(depths["level_2"])
302 | affine_mats, affine_mats_inv = np.stack(affine_mats), np.stack(affine_mats_inv)
303 | intrinsics = np.stack(intrinsics)
304 | w2cs = np.stack(w2cs)
305 | c2ws = np.stack(c2ws)
306 | near_fars = np.stack(near_fars)
307 |
308 | closest_idxs = []
309 | for pose in c2ws[:-1]:
310 | closest_idxs.append(
311 | get_nearest_pose_ids(
312 | pose,
313 | ref_poses=c2ws[:-1],
314 | num_select=close_views,
315 | angular_dist_method="dist",
316 | )
317 | )
318 | closest_idxs = np.stack(closest_idxs, axis=0)
319 |
320 | sample = {}
321 | sample["images"] = imgs
322 | sample["depths"] = depths
323 | sample["depths_h"] = depths_h
324 | sample["depths_aug"] = depths_aug
325 | sample["w2cs"] = w2cs
326 | sample["c2ws"] = c2ws
327 | sample["near_fars"] = near_fars
328 | sample["intrinsics"] = intrinsics
329 | sample["affine_mats"] = affine_mats
330 | sample["affine_mats_inv"] = affine_mats_inv
331 | sample["closest_idxs"] = closest_idxs
332 |
333 | return sample
334 |
--------------------------------------------------------------------------------
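A usage sketch for DTU_Dataset (the root paths below are placeholders for the rectified DTU images and the preprocessed MVSNet cameras/depths); each sample stacks the nb_views source views with the target view appended last, as assembled in __getitem__ above:

    from torch.utils.data import DataLoader
    from data.dtu import DTU_Dataset

    dataset = DTU_Dataset(
        original_root_dir="/path/to/DTU",          # contains Rectified/
        preprocessed_root_dir="/path/to/DTU_pre",  # contains Cameras/ and Depths/
        split="val",
        nb_views=3,
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)

    sample = next(iter(loader))
    # images: (1, nb_views + 1, 3, H, W); the last view is the render target.
    print(sample["images"].shape, sample["near_fars"].shape)
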
/data/get_datasets.py:
--------------------------------------------------------------------------------
1 | # GeoNeRF is a generalizable NeRF model that renders novel views
2 | # without requiring per-scene optimization. This software is the
3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with
4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
5 | # and Francois Fleuret.
6 |
7 | # Copyright (c) 2022 ams International AG
8 |
9 | # This file is part of GeoNeRF.
10 | # GeoNeRF is free software: you can redistribute it and/or modify
11 | # it under the terms of the GNU General Public License version 3 as
12 | # published by the Free Software Foundation.
13 |
14 | # GeoNeRF is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 |
19 | # You should have received a copy of the GNU General Public License
20 | # along with GeoNeRF. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | import torch
23 | from torch.utils.data import ConcatDataset, WeightedRandomSampler
24 | import numpy as np
25 |
26 | from data.llff import LLFF_Dataset
27 | from data.dtu import DTU_Dataset
28 | from data.nerf import NeRF_Dataset
29 |
30 | def get_training_dataset(args, downsample=1.0):
31 | train_datasets = [
32 | DTU_Dataset(
33 | original_root_dir=args.dtu_path,
34 | preprocessed_root_dir=args.dtu_pre_path,
35 | split="train",
36 | max_len=-1,
37 | downSample=downsample,
38 | nb_views=args.nb_views,
39 | ),
40 | LLFF_Dataset(
41 | root_dir=args.ibrnet1_path,
42 | split="train",
43 | max_len=-1,
44 | downSample=downsample,
45 | nb_views=args.nb_views,
46 | imgs_folder_name="images",
47 | ),
48 | LLFF_Dataset(
49 | root_dir=args.ibrnet2_path,
50 | split="train",
51 | max_len=-1,
52 | downSample=downsample,
53 | nb_views=args.nb_views,
54 | imgs_folder_name="images",
55 | ),
56 | LLFF_Dataset(
57 | root_dir=args.llff_path,
58 | split="train",
59 | max_len=-1,
60 | downSample=downsample,
61 | nb_views=args.nb_views,
62 | imgs_folder_name="images_4",
63 | ),
64 | ]
65 | weights = [0.5, 0.22, 0.12, 0.16]
66 |
67 | train_weights_samples = []
68 | for dataset, weight in zip(train_datasets, weights):
69 | num_samples = len(dataset)
70 | weight_each_sample = weight / num_samples
71 | train_weights_samples.extend([weight_each_sample] * num_samples)
72 |
73 | train_dataset = ConcatDataset(train_datasets)
74 | train_weights = torch.from_numpy(np.array(train_weights_samples))
75 | train_sampler = WeightedRandomSampler(train_weights, len(train_weights))
76 |
77 | return train_dataset, train_sampler
78 |
79 |
80 | def get_finetuning_dataset(args, downsample=1.0):
81 | if args.dataset_name == "dtu":
82 | train_dataset = DTU_Dataset(
83 | original_root_dir=args.dtu_path,
84 | preprocessed_root_dir=args.dtu_pre_path,
85 | split="train",
86 | max_len=-1,
87 | downSample=downsample,
88 | nb_views=args.nb_views,
89 | scene=args.scene,
90 | )
91 | elif args.dataset_name == "llff":
92 | train_dataset = LLFF_Dataset(
93 | root_dir=args.llff_path,
94 | split="train",
95 | max_len=-1,
96 | downSample=downsample,
97 | nb_views=args.nb_views,
98 | scene=args.scene,
99 | imgs_folder_name="images_4",
100 | )
101 | elif args.dataset_name == "nerf":
102 | train_dataset = NeRF_Dataset(
103 | root_dir=args.nerf_path,
104 | split="train",
105 | max_len=-1,
106 | downSample=downsample,
107 | nb_views=args.nb_views,
108 | scene=args.scene,
109 | )
110 |
111 | train_sampler = None
112 |
113 | return train_dataset, train_sampler
114 |
115 |
116 | def get_validation_dataset(args, downsample=1.0):
117 | if args.scene == "None":
118 | max_len = 2
119 | else:
120 | max_len = -1
121 |
122 | if args.dataset_name == "dtu":
123 | val_dataset = DTU_Dataset(
124 | original_root_dir=args.dtu_path,
125 | preprocessed_root_dir=args.dtu_pre_path,
126 | split="val",
127 | max_len=max_len,
128 | downSample=downsample,
129 | nb_views=args.nb_views,
130 | scene=args.scene,
131 | )
132 | elif args.dataset_name == "llff":
133 | val_dataset = LLFF_Dataset(
134 |             root_dir=args.llff_test_path if args.llff_test_path is not None else args.llff_path,
135 | split="val",
136 | max_len=max_len,
137 | downSample=downsample,
138 | nb_views=args.nb_views,
139 | scene=args.scene,
140 | imgs_folder_name="images_4",
141 | )
142 | elif args.dataset_name == "nerf":
143 | val_dataset = NeRF_Dataset(
144 | root_dir=args.nerf_path,
145 | split="val",
146 | max_len=max_len,
147 | downSample=downsample,
148 | nb_views=args.nb_views,
149 | scene=args.scene,
150 | )
151 |
152 | return val_dataset
153 |
--------------------------------------------------------------------------------
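The per-sample weights computed in get_training_dataset make the WeightedRandomSampler draw roughly half of the training samples from DTU and the remainder from the two IBRNet collections and LLFF (0.5 / 0.22 / 0.12 / 0.16), independently of each dataset's size. A sketch of wiring it into a DataLoader; the argparse values are placeholders for the command-line options the repository parses (see utils/options.py):

    from argparse import Namespace
    from torch.utils.data import DataLoader
    from data.get_datasets import get_training_dataset

    args = Namespace(
        dtu_path="/path/to/dtu", dtu_pre_path="/path/to/dtu_pre",
        ibrnet1_path="/path/to/ibrnet1", ibrnet2_path="/path/to/ibrnet2",
        llff_path="/path/to/llff", nb_views=3,
    )

    train_dataset, train_sampler = get_training_dataset(args, downsample=1.0)
    train_loader = DataLoader(
        train_dataset,
        batch_size=1,
        sampler=train_sampler,  # samples with replacement; weights sum to 1.0
        num_workers=8,
    )
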
/data/llff.py:
--------------------------------------------------------------------------------
1 | # GeoNeRF is a generalizable NeRF model that renders novel views
2 | # without requiring per-scene optimization. This software is the
3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with
4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
5 | # and Francois Fleuret.
6 |
7 | # Copyright (c) 2022 ams International AG
8 |
9 | # This file is part of GeoNeRF.
10 | # GeoNeRF is free software: you can redistribute it and/or modify
11 | # it under the terms of the GNU General Public License version 3 as
12 | # published by the Free Software Foundation.
13 |
14 | # GeoNeRF is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 |
19 | # You should have received a copy of the GNU General Public License
20 | # along with GeoNeRF. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | # This file incorporates work covered by the following copyright and
23 | # permission notice:
24 |
25 | # MIT License
26 |
27 | # Copyright (c) 2021 apchenstu
28 |
29 | # Permission is hereby granted, free of charge, to any person obtaining a copy
30 | # of this software and associated documentation files (the "Software"), to deal
31 | # in the Software without restriction, including without limitation the rights
32 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
33 | # copies of the Software, and to permit persons to whom the Software is
34 | # furnished to do so, subject to the following conditions:
35 |
36 | # The above copyright notice and this permission notice shall be included in all
37 | # copies or substantial portions of the Software.
38 |
39 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
40 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
41 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
42 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
43 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
44 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
45 | # SOFTWARE.
46 |
47 | from torch.utils.data import Dataset
48 | from torchvision import transforms as T
49 |
50 | import os
51 | import glob
52 | import numpy as np
53 | from PIL import Image
54 |
55 | from utils.utils import get_nearest_pose_ids
56 |
57 | def normalize(v):
58 | return v / np.linalg.norm(v)
59 |
60 |
61 | def average_poses(poses):
62 | # 1. Compute the center
63 | center = poses[..., 3].mean(0) # (3)
64 |
65 | # 2. Compute the z axis
66 | z = normalize(poses[..., 2].mean(0)) # (3)
67 |
68 | # 3. Compute axis y' (no need to normalize as it's not the final output)
69 | y_ = poses[..., 1].mean(0) # (3)
70 |
71 | # 4. Compute the x axis
72 | x = normalize(np.cross(y_, z)) # (3)
73 |
74 | # 5. Compute the y axis (as z and x are normalized, y is already of norm 1)
75 | y = np.cross(z, x) # (3)
76 |
77 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4)
78 |
79 | return pose_avg
80 |
81 |
82 | def center_poses(poses, blender2opencv):
83 | pose_avg = average_poses(poses) # (3, 4)
84 | pose_avg_homo = np.eye(4)
85 |
86 | # convert to homogeneous coordinate for faster computation
87 | # by simply adding 0, 0, 0, 1 as the last row
88 | pose_avg_homo[:3] = pose_avg
89 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4)
90 |
91 | # (N_images, 4, 4) homogeneous coordinate
92 | poses_homo = np.concatenate([poses, last_row], 1)
93 |
94 | poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo # (N_images, 4, 4)
95 | poses_centered = poses_centered @ blender2opencv
96 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4)
97 |
98 | return poses_centered, np.linalg.inv(pose_avg_homo) @ blender2opencv
99 |
100 |
101 | class LLFF_Dataset(Dataset):
102 | def __init__(
103 | self,
104 | root_dir,
105 | split,
106 | nb_views,
107 | downSample=1.0,
108 | max_len=-1,
109 | scene="None",
110 | imgs_folder_name="images",
111 | ):
112 | self.root_dir = root_dir
113 | self.split = split
114 | self.nb_views = nb_views
115 | self.scene = scene
116 | self.imgs_folder_name = imgs_folder_name
117 |
118 | self.downsample = downSample
119 | self.max_len = max_len
120 | self.img_wh = (int(960 * self.downsample), int(720 * self.downsample))
121 |
122 | self.define_transforms()
123 | self.blender2opencv = np.array(
124 | [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
125 | )
126 |
127 | self.build_metas()
128 |
129 | def define_transforms(self):
130 | self.transform = T.Compose(
131 | [
132 | T.ToTensor(),
133 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
134 | ]
135 | )
136 |
137 | def build_metas(self):
138 | if self.scene != "None":
139 | self.scans = [
140 | os.path.basename(scan_dir)
141 | for scan_dir in sorted(
142 | glob.glob(os.path.join(self.root_dir, self.scene))
143 | )
144 | ]
145 | else:
146 | self.scans = [
147 | os.path.basename(scan_dir)
148 | for scan_dir in sorted(glob.glob(os.path.join(self.root_dir, "*")))
149 | ]
150 |
151 | self.meta = []
152 | self.image_paths = {}
153 | self.near_far = {}
154 | self.id_list = {}
155 | self.closest_idxs = {}
156 | self.c2ws = {}
157 | self.w2cs = {}
158 | self.intrinsics = {}
159 | self.affine_mats = {}
160 | self.affine_mats_inv = {}
161 | for scan in self.scans:
162 | self.image_paths[scan] = sorted(
163 | glob.glob(os.path.join(self.root_dir, scan, self.imgs_folder_name, "*"))
164 | )
165 | poses_bounds = np.load(
166 | os.path.join(self.root_dir, scan, "poses_bounds.npy")
167 | ) # (N_images, 17)
168 | poses = poses_bounds[:, :15].reshape(-1, 3, 5) # (N_images, 3, 5)
169 | bounds = poses_bounds[:, -2:] # (N_images, 2)
170 |
171 | # Step 1: rescale focal length according to training resolution
172 | H, W, focal = poses[0, :, -1] # original intrinsics, same for all images
173 |
174 | focal = [focal * self.img_wh[0] / W, focal * self.img_wh[1] / H]
175 |
176 | # Step 2: correct poses
177 | poses = np.concatenate(
178 | [poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1
179 | )
180 | poses, _ = center_poses(poses, self.blender2opencv)
181 | # poses = poses @ self.blender2opencv
182 |
183 | # Step 3: correct scale so that the nearest depth is at a little more than 1.0
184 | near_original = bounds.min()
185 | scale_factor = near_original * 0.75 # 0.75 is the default parameter
186 | bounds /= scale_factor
187 | poses[..., 3] /= scale_factor
188 |
189 | self.near_far[scan] = bounds.astype('float32')
190 |
191 | num_viewpoint = len(self.image_paths[scan])
192 | val_ids = [idx for idx in range(0, num_viewpoint, 8)]
193 | w, h = self.img_wh
194 |
195 | self.id_list[scan] = []
196 | self.closest_idxs[scan] = []
197 | self.c2ws[scan] = []
198 | self.w2cs[scan] = []
199 | self.intrinsics[scan] = []
200 | self.affine_mats[scan] = []
201 | self.affine_mats_inv[scan] = []
202 | for idx in range(num_viewpoint):
203 | if (
204 | (self.split == "val" and idx in val_ids)
205 | or (
206 | self.split == "train"
207 | and self.scene != "None"
208 | and idx not in val_ids
209 | )
210 | or (self.split == "train" and self.scene == "None")
211 | ):
212 | self.meta.append({"scan": scan, "target_idx": idx})
213 |
214 | view_ids = get_nearest_pose_ids(
215 | poses[idx, :, :],
216 | ref_poses=poses[..., :],
217 | num_select=self.nb_views + 1,
218 | angular_dist_method="dist",
219 | )
220 |
221 | self.id_list[scan].append(view_ids)
222 |
223 | closest_idxs = []
224 | source_views = view_ids[1:]
225 | for vid in source_views:
226 | closest_idxs.append(
227 | get_nearest_pose_ids(
228 | poses[vid, :, :],
229 | ref_poses=poses[source_views],
230 | num_select=5,
231 | angular_dist_method="dist",
232 | )
233 | )
234 | self.closest_idxs[scan].append(np.stack(closest_idxs, axis=0))
235 |
236 | c2w = np.eye(4).astype('float32')
237 | c2w[:3] = poses[idx]
238 | w2c = np.linalg.inv(c2w)
239 | self.c2ws[scan].append(c2w)
240 | self.w2cs[scan].append(w2c)
241 |
242 | intrinsic = np.array([[focal[0], 0, w / 2], [0, focal[1], h / 2], [0, 0, 1]]).astype('float32')
243 | self.intrinsics[scan].append(intrinsic)
244 |
245 | def __len__(self):
246 | return len(self.meta) if self.max_len <= 0 else self.max_len
247 |
248 | def __getitem__(self, idx):
249 | if self.split == "train" and self.scene == "None":
250 | noisy_factor = float(np.random.choice([1.0, 0.75, 0.5], 1))
251 | close_views = int(np.random.choice([3, 4, 5], 1))
252 | else:
253 | noisy_factor = 1.0
254 | close_views = 5
255 |
256 | scan = self.meta[idx]["scan"]
257 | target_idx = self.meta[idx]["target_idx"]
258 |
259 | view_ids = self.id_list[scan][target_idx]
260 | target_view = view_ids[0]
261 | src_views = view_ids[1:]
262 | view_ids = [vid for vid in src_views] + [target_view]
263 |
264 | closest_idxs = self.closest_idxs[scan][target_idx][:, :close_views]
265 |
266 | imgs, depths, depths_h, depths_aug = [], [], [], []
267 | intrinsics, w2cs, c2ws, near_fars = [], [], [], []
268 | affine_mats, affine_mats_inv = [], []
269 |
270 | w, h = self.img_wh
271 | w, h = int(w * noisy_factor), int(h * noisy_factor)
272 |
273 | for vid in view_ids:
274 | img_filename = self.image_paths[scan][vid]
275 | img = Image.open(img_filename).convert("RGB")
276 | if img.size != (w, h):
277 | img = img.resize((w, h), Image.BICUBIC)
278 | img = self.transform(img)
279 | imgs.append(img)
280 |
281 | intrinsic = self.intrinsics[scan][vid].copy()
282 | intrinsic[:2] = intrinsic[:2] * noisy_factor
283 | intrinsics.append(intrinsic)
284 |
285 | w2c = self.w2cs[scan][vid]
286 | w2cs.append(w2c)
287 | c2ws.append(self.c2ws[scan][vid])
288 |
289 | aff = []
290 | aff_inv = []
291 | for l in range(3):
292 | proj_mat_l = np.eye(4)
293 | intrinsic_temp = intrinsic.copy()
294 | intrinsic_temp[:2] = intrinsic_temp[:2] / (2**l)
295 | proj_mat_l[:3, :4] = intrinsic_temp @ w2c[:3, :4]
296 | aff.append(proj_mat_l.copy())
297 | aff_inv.append(np.linalg.inv(proj_mat_l))
298 | aff = np.stack(aff, axis=-1)
299 | aff_inv = np.stack(aff_inv, axis=-1)
300 |
301 | affine_mats.append(aff)
302 | affine_mats_inv.append(aff_inv)
303 |
304 | near_fars.append(self.near_far[scan][vid])
305 |
306 | depths_h.append(np.zeros([h, w]))
307 | depths.append(np.zeros([h // 4, w // 4]))
308 | depths_aug.append(np.zeros([h // 4, w // 4]))
309 |
310 | imgs = np.stack(imgs)
311 | depths = np.stack(depths)
312 | depths_h = np.stack(depths_h)
313 | depths_aug = np.stack(depths_aug)
314 | affine_mats = np.stack(affine_mats)
315 | affine_mats_inv = np.stack(affine_mats_inv)
316 | intrinsics = np.stack(intrinsics)
317 | w2cs = np.stack(w2cs)
318 | c2ws = np.stack(c2ws)
319 | near_fars = np.stack(near_fars)
320 |
321 | sample = {}
322 | sample["images"] = imgs
323 | sample["depths"] = depths
324 | sample["depths_h"] = depths_h
325 | sample["depths_aug"] = depths_aug
326 | sample["w2cs"] = w2cs
327 | sample["c2ws"] = c2ws
328 | sample["near_fars"] = near_fars
329 | sample["affine_mats"] = affine_mats
330 | sample["affine_mats_inv"] = affine_mats_inv
331 | sample["intrinsics"] = intrinsics
332 | sample["closest_idxs"] = closest_idxs
333 |
334 | return sample
335 |
--------------------------------------------------------------------------------
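center_poses re-expresses every camera-to-world matrix in the frame of the "average" pose returned by average_poses, so the recentred camera centers are distributed around the origin. A self-contained check using the helpers above (the toy poses are synthetic and only for illustration):

    import numpy as np
    from data.llff import average_poses, center_poses

    blender2opencv = np.array(
        [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
    )

    # Toy c2w poses (N, 3, 4): identity rotations, slightly scattered centers.
    rng = np.random.default_rng(0)
    poses = np.tile(np.eye(3, 4), (5, 1, 1))
    poses[..., 3] = 0.1 * rng.standard_normal((5, 3))

    centered, _ = center_poses(poses, blender2opencv)
    print(centered.shape)                  # (5, 3, 4)
    print(centered[..., 3].mean(axis=0))   # camera centers now average to ~0
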
/data/nerf.py:
--------------------------------------------------------------------------------
1 | # GeoNeRF is a generalizable NeRF model that renders novel views
2 | # without requiring per-scene optimization. This software is the
3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with
4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
5 | # and Francois Fleuret.
6 |
7 | # Copyright (c) 2022 ams International AG
8 |
9 | # This file is part of GeoNeRF.
10 | # GeoNeRF is free software: you can redistribute it and/or modify
11 | # it under the terms of the GNU General Public License version 3 as
12 | # published by the Free Software Foundation.
13 |
14 | # GeoNeRF is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 |
19 | # You should have received a copy of the GNU General Public License
20 | # along with GeoNeRF. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | # This file incorporates work covered by the following copyright and
23 | # permission notice:
24 |
25 | # MIT License
26 |
27 | # Copyright (c) 2021 apchenstu
28 |
29 | # Permission is hereby granted, free of charge, to any person obtaining a copy
30 | # of this software and associated documentation files (the "Software"), to deal
31 | # in the Software without restriction, including without limitation the rights
32 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
33 | # copies of the Software, and to permit persons to whom the Software is
34 | # furnished to do so, subject to the following conditions:
35 |
36 | # The above copyright notice and this permission notice shall be included in all
37 | # copies or substantial portions of the Software.
38 |
39 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
40 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
41 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
42 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
43 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
44 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
45 | # SOFTWARE.
46 |
47 | from torch.utils.data import Dataset
48 | from torchvision import transforms as T
49 |
50 | import os
51 | import json
52 | import numpy as np
53 | from PIL import Image
54 |
55 | from utils.utils import get_nearest_pose_ids
56 |
57 | class NeRF_Dataset(Dataset):
58 | def __init__(
59 | self,
60 | root_dir,
61 | split,
62 | nb_views,
63 | downSample=1.0,
64 | max_len=-1,
65 | scene="None",
66 | ):
67 | self.root_dir = root_dir
68 | self.split = split
69 | self.nb_views = nb_views
70 | self.scene = scene
71 |
72 | self.downsample = downSample
73 | self.max_len = max_len
74 |
75 | self.img_wh = (int(800 * self.downsample), int(800 * self.downsample))
76 |
77 | self.define_transforms()
78 | self.blender2opencv = np.array(
79 | [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
80 | )
81 |
82 | self.build_metas()
83 |
84 | def define_transforms(self):
85 | self.transform = T.ToTensor()
86 |
87 | self.src_transform = T.Compose(
88 | [
89 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
90 | ]
91 | )
92 |
93 | def build_metas(self):
94 | self.meta = {}
95 | with open(
96 | os.path.join(self.root_dir, self.scene, "transforms_train.json"), "r"
97 | ) as f:
98 | self.meta["train"] = json.load(f)
99 |
100 | with open(
101 | os.path.join(self.root_dir, self.scene, "transforms_test.json"), "r"
102 | ) as f:
103 | self.meta["val"] = json.load(f)
104 |
105 | w, h = self.img_wh
106 |
107 | # original focal length
108 | focal = 0.5 * 800 / np.tan(0.5 * self.meta["train"]["camera_angle_x"])
109 |
110 | # modify focal length to match size self.img_wh
111 | focal *= self.img_wh[0] / 800
112 |
113 | self.near_far = np.array([2.0, 6.0])
114 |
115 | self.image_paths = {"train": [], "val": []}
116 | self.c2ws = {"train": [], "val": []}
117 | self.w2cs = {"train": [], "val": []}
118 | self.intrinsics = {"train": [], "val": []}
119 |
120 | for frame in self.meta["train"]["frames"]:
121 | self.image_paths["train"].append(
122 | os.path.join(self.root_dir, self.scene, f"{frame['file_path']}.png")
123 | )
124 |
125 | c2w = np.array(frame["transform_matrix"]) @ self.blender2opencv
126 | w2c = np.linalg.inv(c2w)
127 | self.c2ws["train"].append(c2w)
128 | self.w2cs["train"].append(w2c)
129 |
130 | intrinsic = np.array([[focal, 0, w / 2], [0, focal, h / 2], [0, 0, 1]])
131 | self.intrinsics["train"].append(intrinsic.copy())
132 |
133 | self.c2ws["train"] = np.stack(self.c2ws["train"], axis=0)
134 | self.w2cs["train"] = np.stack(self.w2cs["train"], axis=0)
135 | self.intrinsics["train"] = np.stack(self.intrinsics["train"], axis=0)
136 |
137 | for frame in self.meta["val"]["frames"]:
138 | self.image_paths["val"].append(
139 | os.path.join(self.root_dir, self.scene, f"{frame['file_path']}.png")
140 | )
141 |
142 | c2w = np.array(frame["transform_matrix"]) @ self.blender2opencv
143 | w2c = np.linalg.inv(c2w)
144 | self.c2ws["val"].append(c2w)
145 | self.w2cs["val"].append(w2c)
146 |
147 | intrinsic = np.array([[focal, 0, w / 2], [0, focal, h / 2], [0, 0, 1]])
148 | self.intrinsics["val"].append(intrinsic.copy())
149 |
150 | self.c2ws["val"] = np.stack(self.c2ws["val"], axis=0)
151 | self.w2cs["val"] = np.stack(self.w2cs["val"], axis=0)
152 | self.intrinsics["val"] = np.stack(self.intrinsics["val"], axis=0)
153 |
154 | def __len__(self):
155 | return len(self.image_paths[self.split]) if self.max_len <= 0 else self.max_len
156 |
157 | def __getitem__(self, idx):
158 | target_frame = self.meta[self.split]["frames"][idx]
159 | c2w = np.array(target_frame["transform_matrix"]) @ self.blender2opencv
160 | w2c = np.linalg.inv(c2w)
161 |
162 | if self.split == "train":
163 | src_views = get_nearest_pose_ids(
164 | c2w,
165 | ref_poses=self.c2ws["train"],
166 | num_select=self.nb_views + 1,
167 | angular_dist_method="dist",
168 | )[1:]
169 | else:
170 | src_views = get_nearest_pose_ids(
171 | c2w,
172 | ref_poses=self.c2ws["train"],
173 | num_select=self.nb_views,
174 | angular_dist_method="dist",
175 | )
176 |
177 | imgs, depths, depths_h, depths_aug = [], [], [], []
178 | intrinsics, w2cs, c2ws, near_fars = [], [], [], []
179 | affine_mats, affine_mats_inv = [], []
180 |
181 | w, h = self.img_wh
182 |
183 | for vid in src_views:
184 | img_filename = self.image_paths["train"][vid]
185 | img = Image.open(img_filename)
186 | if img.size != (w, h):
187 | img = img.resize((w, h), Image.BICUBIC)
188 |
189 | img = self.transform(img)
190 | img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
191 | imgs.append(self.src_transform(img))
192 |
193 | intrinsic = self.intrinsics["train"][vid]
194 | intrinsics.append(intrinsic)
195 |
196 | w2c = self.w2cs["train"][vid]
197 | w2cs.append(w2c)
198 | c2ws.append(self.c2ws["train"][vid])
199 |
200 | aff = []
201 | aff_inv = []
202 | for l in range(3):
203 | proj_mat_l = np.eye(4)
204 | intrinsic_temp = intrinsic.copy()
205 | intrinsic_temp[:2] = intrinsic_temp[:2] / (2**l)
206 | proj_mat_l[:3, :4] = intrinsic_temp @ w2c[:3, :4]
207 | aff.append(proj_mat_l.copy())
208 | aff_inv.append(np.linalg.inv(proj_mat_l))
209 | aff = np.stack(aff, axis=-1)
210 | aff_inv = np.stack(aff_inv, axis=-1)
211 |
212 | affine_mats.append(aff)
213 | affine_mats_inv.append(aff_inv)
214 |
215 | near_fars.append(self.near_far)
216 |
217 | depths_h.append(np.zeros([h, w]))
218 | depths.append(np.zeros([h // 4, w // 4]))
219 | depths_aug.append(np.zeros([h // 4, w // 4]))
220 |
221 | ## Adding target data
222 | img_filename = self.image_paths[self.split][idx]
223 | img = Image.open(img_filename)
224 | if img.size != (w, h):
225 | img = img.resize((w, h), Image.BICUBIC)
226 |
227 | img = self.transform(img) # (4, h, w)
228 | img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
229 | imgs.append(self.src_transform(img))
230 |
231 | intrinsic = self.intrinsics[self.split][idx]
232 | intrinsics.append(intrinsic)
233 |
234 | w2c = self.w2cs[self.split][idx]
235 | w2cs.append(w2c)
236 | c2ws.append(self.c2ws[self.split][idx])
237 |
238 | near_fars.append(self.near_far)
239 |
240 | depths_h.append(np.zeros([h, w]))
241 | depths.append(np.zeros([h // 4, w // 4]))
242 | depths_aug.append(np.zeros([h // 4, w // 4]))
243 |
244 | ## Stacking
245 | imgs = np.stack(imgs)
246 | depths = np.stack(depths)
247 | depths_h = np.stack(depths_h)
248 | depths_aug = np.stack(depths_aug)
249 | affine_mats = np.stack(affine_mats)
250 | affine_mats_inv = np.stack(affine_mats_inv)
251 | intrinsics = np.stack(intrinsics)
252 | w2cs = np.stack(w2cs)
253 | c2ws = np.stack(c2ws)
254 | near_fars = np.stack(near_fars)
255 |
256 | closest_idxs = []
257 | for pose in c2ws[:-1]:
258 | closest_idxs.append(
259 | get_nearest_pose_ids(
260 | pose, ref_poses=c2ws[:-1], num_select=5, angular_dist_method="dist"
261 | )
262 | )
263 | closest_idxs = np.stack(closest_idxs, axis=0)
264 |
265 | sample = {}
266 | sample["images"] = imgs
267 | sample["depths"] = depths
268 | sample["depths_h"] = depths_h
269 | sample["depths_aug"] = depths_aug
270 | sample["w2cs"] = w2cs.astype("float32")
271 | sample["c2ws"] = c2ws.astype("float32")
272 | sample["near_fars"] = near_fars
273 | sample["affine_mats"] = affine_mats
274 | sample["affine_mats_inv"] = affine_mats_inv
275 | sample["intrinsics"] = intrinsics.astype("float32")
276 | sample["closest_idxs"] = closest_idxs
277 |
278 | return sample
279 |
--------------------------------------------------------------------------------
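In build_metas the pinhole focal length is recovered from the horizontal field of view stored in the Blender transforms files: focal = 0.5 * W / tan(0.5 * camera_angle_x), then rescaled by img_wh[0] / 800. A quick numeric check (the camera_angle_x value is the one commonly found in the synthetic NeRF scenes and is only illustrative):

    import numpy as np

    camera_angle_x = 0.6911112070083618
    W = 800
    focal = 0.5 * W / np.tan(0.5 * camera_angle_x)
    print(focal)  # ~1111 pixels at the original 800x800 resolution
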
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/idiap/GeoNeRF/e6249fdae5672853c6bbbd4ba380c4c166d02c95/model/__init__.py
--------------------------------------------------------------------------------
/model/geo_reasoner.py:
--------------------------------------------------------------------------------
1 | # GeoNeRF is a generalizable NeRF model that renders novel views
2 | # without requiring per-scene optimization. This software is the
3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with
4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
5 | # and Francois Fleuret.
6 |
7 | # Copyright (c) 2022 ams International AG
8 |
9 | # This file is part of GeoNeRF.
10 | # GeoNeRF is free software: you can redistribute it and/or modify
11 | # it under the terms of the GNU General Public License version 3 as
12 | # published by the Free Software Foundation.
13 |
14 | # GeoNeRF is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 |
19 | # You should have received a copy of the GNU General Public License
20 | # along with GeoNeRF. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | # This file incorporates work covered by the following copyright and
23 | # permission notice:
24 |
25 | # Copyright (c) 2020 AI葵
26 |
27 | # This file is part of CasMVSNet_pl.
28 | # CasMVSNet_pl is free software: you can redistribute it and/or modify
29 | # it under the terms of the GNU General Public License version 3 as
30 | # published by the Free Software Foundation.
31 |
32 | # CasMVSNet_pl is distributed in the hope that it will be useful,
33 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
34 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
35 | # GNU General Public License for more details.
36 |
37 | # You should have received a copy of the GNU General Public License
38 | # along with CasMVSNet_pl. If not, see <https://www.gnu.org/licenses/>.
39 |
40 | import torch
41 | import torch.nn as nn
42 | import torch.nn.functional as F
43 | from torch.utils.checkpoint import checkpoint
44 |
45 | from utils.utils import homo_warp
46 | from inplace_abn import InPlaceABN
47 |
48 |
49 | def get_depth_values(current_depth, n_depths, depth_interval):
50 | depth_min = torch.clamp_min(current_depth - n_depths / 2 * depth_interval, 1e-7)
51 | depth_values = (
52 | depth_min
53 | + depth_interval
54 | * torch.arange(
55 | 0, n_depths, device=current_depth.device, dtype=current_depth.dtype
56 | )[None, :, None, None]
57 | )
58 | return depth_values
59 |
60 |
61 | class ConvBnReLU(nn.Module):
62 | def __init__(
63 | self,
64 | in_channels,
65 | out_channels,
66 | kernel_size=3,
67 | stride=1,
68 | pad=1,
69 | norm_act=InPlaceABN,
70 | ):
71 | super(ConvBnReLU, self).__init__()
72 | self.conv = nn.Conv2d(
73 | in_channels,
74 | out_channels,
75 | kernel_size,
76 | stride=stride,
77 | padding=pad,
78 | bias=False,
79 | )
80 | self.bn = norm_act(out_channels)
81 |
82 | def forward(self, x):
83 | return self.bn(self.conv(x))
84 |
85 |
86 | class ConvBnReLU3D(nn.Module):
87 | def __init__(
88 | self,
89 | in_channels,
90 | out_channels,
91 | kernel_size=3,
92 | stride=1,
93 | pad=1,
94 | norm_act=InPlaceABN,
95 | ):
96 | super(ConvBnReLU3D, self).__init__()
97 | self.conv = nn.Conv3d(
98 | in_channels,
99 | out_channels,
100 | kernel_size,
101 | stride=stride,
102 | padding=pad,
103 | bias=False,
104 | )
105 | self.bn = norm_act(out_channels)
106 |
107 | def forward(self, x):
108 | return self.bn(self.conv(x))
109 |
110 |
111 | class FeatureNet(nn.Module):
112 | def __init__(self, norm_act=InPlaceABN):
113 | super(FeatureNet, self).__init__()
114 |
115 | self.conv0 = nn.Sequential(
116 | ConvBnReLU(3, 8, 3, 1, 1, norm_act=norm_act),
117 | ConvBnReLU(8, 8, 3, 1, 1, norm_act=norm_act),
118 | )
119 |
120 | self.conv1 = nn.Sequential(
121 | ConvBnReLU(8, 16, 5, 2, 2, norm_act=norm_act),
122 | ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act),
123 | ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act),
124 | )
125 |
126 | self.conv2 = nn.Sequential(
127 | ConvBnReLU(16, 32, 5, 2, 2, norm_act=norm_act),
128 | ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act),
129 | ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act),
130 | )
131 |
132 | self.toplayer = nn.Conv2d(32, 32, 1)
133 | self.lat1 = nn.Conv2d(16, 32, 1)
134 | self.lat0 = nn.Conv2d(8, 32, 1)
135 |
136 | # to reduce channel size of the outputs from FPN
137 | self.smooth1 = nn.Conv2d(32, 16, 3, padding=1)
138 | self.smooth0 = nn.Conv2d(32, 8, 3, padding=1)
139 |
140 | def _upsample_add(self, x, y):
141 | return F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) + y
142 |
143 | def forward(self, x, dummy=None):
144 | # x: (B, 3, H, W)
145 | conv0 = self.conv0(x) # (B, 8, H, W)
146 | conv1 = self.conv1(conv0) # (B, 16, H//2, W//2)
147 | conv2 = self.conv2(conv1) # (B, 32, H//4, W//4)
148 | feat2 = self.toplayer(conv2) # (B, 32, H//4, W//4)
149 | feat1 = self._upsample_add(feat2, self.lat1(conv1)) # (B, 32, H//2, W//2)
150 | feat0 = self._upsample_add(feat1, self.lat0(conv0)) # (B, 32, H, W)
151 |
152 | # reduce output channels
153 | feat1 = self.smooth1(feat1) # (B, 16, H//2, W//2)
154 | feat0 = self.smooth0(feat0) # (B, 8, H, W)
155 |
156 | feats = {"level_0": feat0, "level_1": feat1, "level_2": feat2}
157 |
158 | return feats
159 |
160 |
161 | class CostRegNet(nn.Module):
162 | def __init__(self, in_channels, norm_act=InPlaceABN):
163 | super(CostRegNet, self).__init__()
164 | self.conv0 = ConvBnReLU3D(in_channels, 8, norm_act=norm_act)
165 |
166 | self.conv1 = ConvBnReLU3D(8, 16, stride=2, norm_act=norm_act)
167 | self.conv2 = ConvBnReLU3D(16, 16, norm_act=norm_act)
168 |
169 | self.conv3 = ConvBnReLU3D(16, 32, stride=2, norm_act=norm_act)
170 | self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act)
171 |
172 | self.conv5 = ConvBnReLU3D(32, 64, stride=2, norm_act=norm_act)
173 | self.conv6 = ConvBnReLU3D(64, 64, norm_act=norm_act)
174 |
175 | self.conv7 = nn.Sequential(
176 | nn.ConvTranspose3d(
177 | 64, 32, 3, padding=1, output_padding=1, stride=2, bias=False
178 | ),
179 | norm_act(32),
180 | )
181 |
182 | self.conv9 = nn.Sequential(
183 | nn.ConvTranspose3d(
184 | 32, 16, 3, padding=1, output_padding=1, stride=2, bias=False
185 | ),
186 | norm_act(16),
187 | )
188 |
189 | self.conv11 = nn.Sequential(
190 | nn.ConvTranspose3d(
191 | 16, 8, 3, padding=1, output_padding=1, stride=2, bias=False
192 | ),
193 | norm_act(8),
194 | )
195 |
196 | self.br1 = ConvBnReLU3D(8, 8, norm_act=norm_act)
197 | self.br2 = ConvBnReLU3D(8, 8, norm_act=norm_act)
198 |
199 | self.prob = nn.Conv3d(8, 1, 3, stride=1, padding=1)
200 |
201 | def forward(self, x):
202 | if x.shape[-2] % 8 != 0 or x.shape[-1] % 8 != 0:
203 | pad_h = 8 * (x.shape[-2] // 8 + 1) - x.shape[-2]
204 | pad_w = 8 * (x.shape[-1] // 8 + 1) - x.shape[-1]
205 | x = F.pad(x, (0, pad_w, 0, pad_h), mode="constant", value=0)
206 | else:
207 | pad_h = 0
208 | pad_w = 0
209 |
210 | conv0 = self.conv0(x)
211 | conv2 = self.conv2(self.conv1(conv0))
212 | conv4 = self.conv4(self.conv3(conv2))
213 |
214 | x = self.conv6(self.conv5(conv4))
215 | x = conv4 + self.conv7(x)
216 | del conv4
217 | x = conv2 + self.conv9(x)
218 | del conv2
219 | x = conv0 + self.conv11(x)
220 | del conv0
221 | ####################
222 | x1 = self.br1(x)
223 | with torch.enable_grad():
224 | x2 = self.br2(x)
225 | ####################
226 | p = self.prob(x1)
227 |
228 | if pad_h > 0 or pad_w > 0:
229 | x2 = x2[..., :-pad_h, :-pad_w]
230 | p = p[..., :-pad_h, :-pad_w]
231 |
232 | return x2, p
233 |
234 |
235 | class CasMVSNet(nn.Module):
236 | def __init__(self, num_groups=8, norm_act=InPlaceABN, levels=3, use_depth=False):
237 | super(CasMVSNet, self).__init__()
238 | self.levels = levels # 3 depth levels
239 | self.n_depths = [8, 32, 48]
240 | self.interval_ratios = [1, 2, 4]
241 | self.use_depth = use_depth
242 |
243 | self.G = num_groups # number of groups in groupwise correlation
244 | self.feature = FeatureNet()
245 |
246 | for l in range(self.levels):
247 | if l == self.levels - 1 and self.use_depth:
248 | cost_reg_l = CostRegNet(self.G + 1, norm_act)
249 | else:
250 | cost_reg_l = CostRegNet(self.G, norm_act)
251 |
252 | setattr(self, f"cost_reg_{l}", cost_reg_l)
253 |
254 | def build_cost_volumes(self, feats, affine_mats, affine_mats_inv, depth_values, idx, spikes):
255 | B, V, C, H, W = feats.shape
256 | D = depth_values.shape[1]
257 |
258 | ref_feats, src_feats = feats[:, idx[0]], feats[:, idx[1:]]
259 | src_feats = src_feats.permute(1, 0, 2, 3, 4) # (V-1, B, C, h, w)
260 |
261 | affine_mats_inv = affine_mats_inv[:, idx[0]]
262 | affine_mats = affine_mats[:, idx[1:]]
263 | affine_mats = affine_mats.permute(1, 0, 2, 3) # (V-1, B, 3, 4)
264 |
265 | ref_volume = ref_feats.unsqueeze(2).repeat(1, 1, D, 1, 1) # (B, C, D, h, w)
266 |
267 | ref_volume = ref_volume.view(B, self.G, C // self.G, *ref_volume.shape[-3:])
268 | volume_sum = 0
269 |
270 | for i in range(len(idx) - 1):
271 | proj_mat = (affine_mats[i].double() @ affine_mats_inv.double()).float()[
272 | :, :3
273 | ]
274 | warped_volume, grid = homo_warp(src_feats[i], proj_mat, depth_values)
275 |
276 | warped_volume = warped_volume.view_as(ref_volume)
277 | volume_sum = volume_sum + warped_volume # (B, G, C//G, D, h, w)
278 |
279 | volume = (volume_sum * ref_volume).mean(dim=2) / (V - 1)
280 |
281 | if spikes is None:
282 | output = volume
283 | else:
284 | output = torch.cat([volume, spikes], dim=1)
285 |
286 | return output
287 |
288 | def create_neural_volume(
289 | self,
290 | feats,
291 | affine_mats,
292 | affine_mats_inv,
293 | idx,
294 | init_depth_min,
295 | depth_interval,
296 | gt_depths,
297 | ):
298 | if feats["level_0"].shape[-1] >= 800:
299 | hres_input = True
300 | else:
301 | hres_input = False
302 |
303 | B, V = affine_mats.shape[:2]
304 |
305 | v_feat = {}
306 | depth_maps = {}
307 | depth_values = {}
308 | for l in reversed(range(self.levels)): # (2, 1, 0)
309 | feats_l = feats[f"level_{l}"] # (B*V, C, h, w)
310 | feats_l = feats_l.view(B, V, *feats_l.shape[1:]) # (B, V, C, h, w)
311 | h, w = feats_l.shape[-2:]
312 | depth_interval_l = depth_interval * self.interval_ratios[l]
313 | D = self.n_depths[l]
314 | if l == self.levels - 1: # coarsest level
315 | depth_values_l = init_depth_min + depth_interval_l * torch.arange(
316 | 0, D, device=feats_l.device, dtype=feats_l.dtype
317 | ) # (D)
318 | depth_values_l = depth_values_l[None, :, None, None].expand(
319 | -1, -1, h, w
320 | )
321 |
322 | if self.use_depth:
323 | gt_mask = gt_depths > 0
324 | sp_idx_float = (
325 | gt_mask * (gt_depths - init_depth_min) / (depth_interval_l)
326 | )[:, :, None]
327 | spikes = (
328 | torch.arange(D).view(1, 1, -1, 1, 1).cuda()
329 | == sp_idx_float.floor().long()
330 | ) * (1 - sp_idx_float.frac())
331 | spikes = spikes + (
332 | torch.arange(D).view(1, 1, -1, 1, 1).cuda()
333 | == sp_idx_float.ceil().long()
334 | ) * (sp_idx_float.frac())
335 | spikes = (spikes * gt_mask[:, :, None]).float()
336 | else:
337 | depth_lm1 = depth_l.detach() # the depth of previous level
338 | depth_lm1 = F.interpolate(
339 | depth_lm1, scale_factor=2, mode="bilinear", align_corners=True
340 | ) # (B, 1, h, w)
341 | depth_values_l = get_depth_values(depth_lm1, D, depth_interval_l)
342 |
343 | affine_mats_l = affine_mats[..., l]
344 | affine_mats_inv_l = affine_mats_inv[..., l]
345 |
346 | if l == self.levels - 1 and self.use_depth:
347 | spikes_ = spikes
348 | else:
349 | spikes_ = None
350 |
351 | if hres_input:
352 | v_feat_l = checkpoint(
353 | self.build_cost_volumes,
354 | feats_l,
355 | affine_mats_l,
356 | affine_mats_inv_l,
357 | depth_values_l,
358 | idx,
359 | spikes_,
360 | preserve_rng_state=False,
361 | )
362 | else:
363 | v_feat_l = self.build_cost_volumes(
364 | feats_l,
365 | affine_mats_l,
366 | affine_mats_inv_l,
367 | depth_values_l,
368 | idx,
369 | spikes_,
370 | )
371 |
372 | cost_reg_l = getattr(self, f"cost_reg_{l}")
373 | v_feat_l, depth_prob = cost_reg_l(v_feat_l) # (B, 1, D, h, w)
374 |
375 | depth_l = (F.softmax(depth_prob, dim=2) * depth_values_l[:, None]).sum(
376 | dim=2
377 | )
378 |
379 | v_feat[f"level_{l}"] = v_feat_l
380 | depth_maps[f"level_{l}"] = depth_l
381 | depth_values[f"level_{l}"] = depth_values_l
382 |
383 | return v_feat, depth_maps, depth_values
384 |
385 | def forward(
386 | self, imgs, affine_mats, affine_mats_inv, near_far, closest_idxs, gt_depths=None
387 | ):
388 | B, V, _, H, W = imgs.shape
389 |
390 | ## Feature Pyramid
391 | feats = self.feature(
392 | imgs.reshape(B * V, 3, H, W)
393 | ) # (B*V, 8, H, W), (B*V, 16, H//2, W//2), (B*V, 32, H//4, W//4)
394 | feats_fpn = feats[f"level_0"].reshape(B, V, *feats[f"level_0"].shape[1:])
395 |
396 | feats_vol = {"level_0": [], "level_1": [], "level_2": []}
397 | depth_map = {"level_0": [], "level_1": [], "level_2": []}
398 | depth_values = {"level_0": [], "level_1": [], "level_2": []}
399 | ## Create cost volumes for each view
400 | for i in range(0, V):
401 | permuted_idx = torch.tensor(closest_idxs[0, i]).cuda()
402 |
403 | init_depth_min = near_far[0, i, 0]
404 | depth_interval = (
405 | (near_far[0, i, 1] - near_far[0, i, 0])
406 | / self.n_depths[-1]
407 | / self.interval_ratios[-1]
408 | )
409 |
410 | v_feat, d_map, d_values = self.create_neural_volume(
411 | feats,
412 | affine_mats,
413 | affine_mats_inv,
414 | idx=permuted_idx,
415 | init_depth_min=init_depth_min,
416 | depth_interval=depth_interval,
417 | gt_depths=gt_depths[:, i : i + 1],
418 | )
419 |
420 | for l in range(3):
421 | feats_vol[f"level_{l}"].append(v_feat[f"level_{l}"])
422 | depth_map[f"level_{l}"].append(d_map[f"level_{l}"])
423 | depth_values[f"level_{l}"].append(d_values[f"level_{l}"])
424 |
425 | for l in range(3):
426 | feats_vol[f"level_{l}"] = torch.stack(feats_vol[f"level_{l}"], dim=1)
427 | depth_map[f"level_{l}"] = torch.cat(depth_map[f"level_{l}"], dim=1)
428 | depth_values[f"level_{l}"] = torch.stack(depth_values[f"level_{l}"], dim=1)
429 |
430 | return feats_vol, feats_fpn, depth_map, depth_values
431 |
--------------------------------------------------------------------------------
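The cascade in create_neural_volume sweeps n_depths[2] = 48 coarse hypotheses across the whole [near, far] range and then, at the finer levels, uses get_depth_values to centre a narrower sweep around the upsampled depth from the previous level. A minimal sketch of the hypothesis spacing (tensor sizes are illustrative; importing the module requires the repository's inplace_abn dependency):

    import torch
    from model.geo_reasoner import get_depth_values

    near, far = 2.0, 6.0
    n_depths = [8, 32, 48]           # planes per level, finest to coarsest
    interval_ratios = [1, 2, 4]
    base_interval = (far - near) / n_depths[-1] / interval_ratios[-1]

    # Level 2 (coarsest): 48 planes spaced by base_interval * 4 starting at near.
    coarse = near + base_interval * interval_ratios[2] * torch.arange(n_depths[2])

    # Level 1: 32 planes of spacing base_interval * 2, centred per pixel on the
    # (upsampled) depth predicted at level 2.
    prev_depth = torch.full((1, 1, 4, 4), 3.5)
    fine = get_depth_values(prev_depth, n_depths[1], base_interval * interval_ratios[1])
    print(coarse.shape, fine.shape)  # torch.Size([48]) torch.Size([1, 32, 4, 4])
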
/model/self_attn_renderer.py:
--------------------------------------------------------------------------------
1 | # GeoNeRF is a generalizable NeRF model that renders novel views
2 | # without requiring per-scene optimization. This software is the
3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with
4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
5 | # and Francois Fleuret.
6 |
7 | # Copyright (c) 2022 ams International AG
8 |
9 | # This file is part of GeoNeRF.
10 | # GeoNeRF is free software: you can redistribute it and/or modify
11 | # it under the terms of the GNU General Public License version 3 as
12 | # published by the Free Software Foundation.
13 |
14 | # GeoNeRF is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 |
19 | # You should have received a copy of the GNU General Public License
20 | # along with GeoNeRF. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | # This file incorporates work covered by the following copyright and
23 | # permission notice:
24 |
25 | # Copyright 2020 Google LLC
26 | #
27 | # Licensed under the Apache License, Version 2.0 (the "License");
28 | # you may not use this file except in compliance with the License.
29 | # You may obtain a copy of the License at
30 | #
31 | # https://www.apache.org/licenses/LICENSE-2.0
32 | #
33 | # Unless required by applicable law or agreed to in writing, software
34 | # distributed under the License is distributed on an "AS IS" BASIS,
35 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
36 | # See the License for the specific language governing permissions and
37 | # limitations under the License.
38 |
39 | import torch
40 | import torch.nn as nn
41 | import torch.nn.functional as F
42 |
43 | import math
44 |
45 | def weights_init(m):
46 | if isinstance(m, nn.Linear):
47 | stdv = 1.0 / math.sqrt(m.weight.size(1))
48 | m.weight.data.uniform_(-stdv, stdv)
49 | if m.bias is not None:
50 | m.bias.data.uniform_(stdv, stdv)
51 |
52 |
53 | def masked_softmax(x, mask, **kwargs):
54 | x_masked = x.masked_fill(mask == 0, -float("inf"))
55 |
56 | return torch.softmax(x_masked, **kwargs)
57 |
58 |
59 | ## Auto-encoder network
60 | class ConvAutoEncoder(nn.Module):
61 | def __init__(self, num_ch, S):
62 | super(ConvAutoEncoder, self).__init__()
63 |
64 | # Encoder
65 | self.conv1 = nn.Sequential(
66 | nn.Conv1d(num_ch, num_ch * 2, 3, stride=1, padding=1),
67 | nn.LayerNorm(S, elementwise_affine=False),
68 | nn.ELU(alpha=1.0, inplace=True),
69 | nn.MaxPool1d(2),
70 | )
71 | self.conv2 = nn.Sequential(
72 | nn.Conv1d(num_ch * 2, num_ch * 4, 3, stride=1, padding=1),
73 | nn.LayerNorm(S // 2, elementwise_affine=False),
74 | nn.ELU(alpha=1.0, inplace=True),
75 | nn.MaxPool1d(2),
76 | )
77 | self.conv3 = nn.Sequential(
78 | nn.Conv1d(num_ch * 4, num_ch * 4, 3, stride=1, padding=1),
79 | nn.LayerNorm(S // 4, elementwise_affine=False),
80 | nn.ELU(alpha=1.0, inplace=True),
81 | nn.MaxPool1d(2),
82 | )
83 |
84 | # Decoder
85 | self.t_conv1 = nn.Sequential(
86 | nn.ConvTranspose1d(num_ch * 4, num_ch * 4, 4, stride=2, padding=1),
87 | nn.LayerNorm(S // 4, elementwise_affine=False),
88 | nn.ELU(alpha=1.0, inplace=True),
89 | )
90 | self.t_conv2 = nn.Sequential(
91 | nn.ConvTranspose1d(num_ch * 8, num_ch * 2, 4, stride=2, padding=1),
92 | nn.LayerNorm(S // 2, elementwise_affine=False),
93 | nn.ELU(alpha=1.0, inplace=True),
94 | )
95 | self.t_conv3 = nn.Sequential(
96 | nn.ConvTranspose1d(num_ch * 4, num_ch, 4, stride=2, padding=1),
97 | nn.LayerNorm(S, elementwise_affine=False),
98 | nn.ELU(alpha=1.0, inplace=True),
99 | )
100 | # Output
101 | self.conv_out = nn.Sequential(
102 | nn.Conv1d(num_ch * 2, num_ch, 3, stride=1, padding=1),
103 | nn.LayerNorm(S, elementwise_affine=False),
104 | nn.ELU(alpha=1.0, inplace=True),
105 | )
106 |
107 | def forward(self, x):
108 | input = x
109 | x = self.conv1(x)
110 | conv1_out = x
111 | x = self.conv2(x)
112 | conv2_out = x
113 | x = self.conv3(x)
114 |
115 | x = self.t_conv1(x)
116 | x = self.t_conv2(torch.cat([x, conv2_out], dim=1))
117 | x = self.t_conv3(torch.cat([x, conv1_out], dim=1))
118 |
119 | x = self.conv_out(torch.cat([x, input], dim=1))
120 |
121 | return x
122 |
123 |
124 | class ScaledDotProductAttention(nn.Module):
125 | def __init__(self, temperature, attn_dropout=0.1):
126 | super().__init__()
127 | self.temperature = temperature
128 |
129 | def forward(self, q, k, v, mask=None):
130 |
131 | attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
132 |
133 | if mask is not None:
134 | attn = masked_softmax(attn, mask, dim=-1)
135 | else:
136 | attn = F.softmax(attn, dim=-1)
137 |
138 | output = torch.matmul(attn, v)
139 |
140 | return output, attn
141 |
142 |
143 | class PositionwiseFeedForward(nn.Module):
144 | def __init__(self, d_in, d_hid, dropout=0.1):
145 | super().__init__()
146 | self.w_1 = nn.Linear(d_in, d_hid) # position-wise
147 | self.w_2 = nn.Linear(d_hid, d_in) # position-wise
148 | self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
149 |
150 | def forward(self, x):
151 |
152 | residual = x
153 |
154 | x = self.w_2(F.relu(self.w_1(x)))
155 | x += residual
156 |
157 | x = self.layer_norm(x)
158 |
159 | return x
160 |
161 |
162 | class MultiHeadAttention(nn.Module):
163 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
164 | super().__init__()
165 |
166 | self.n_head = n_head
167 | self.d_k = d_k
168 | self.d_v = d_v
169 |
170 | self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
171 | self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
172 | self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
173 | self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
174 |
175 | self.attention = ScaledDotProductAttention(temperature=d_k**0.5)
176 |
177 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
178 |
179 | def forward(self, q, k, v, mask=None):
180 |
181 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
182 | sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
183 |
184 | residual = q
185 |
186 | # Pass through the pre-attention projection: b x lq x (n*dv)
187 | # Separate different heads: b x lq x n x dv
188 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
189 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
190 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
191 |
192 | # Transpose for attention dot product: b x n x lq x dv
193 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
194 |
195 | if mask is not None:
196 | mask = mask.transpose(1, 2).unsqueeze(1) # For head axis broadcasting.
197 |
198 | q, attn = self.attention(q, k, v, mask=mask)
199 |
200 | # Transpose to move the head dimension back: b x lq x n x dv
201 | # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
202 | q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
203 | q = self.fc(q)
204 | q += residual
205 |
206 | q = self.layer_norm(q)
207 |
208 | return q, attn
209 |
210 |
211 | class EncoderLayer(nn.Module):
212 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0):
213 | super(EncoderLayer, self).__init__()
214 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
215 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
216 |
217 | def forward(self, enc_input, slf_attn_mask=None):
218 | enc_output, enc_slf_attn = self.slf_attn(
219 | enc_input, enc_input, enc_input, mask=slf_attn_mask
220 | )
221 | enc_output = self.pos_ffn(enc_output)
222 | return enc_output, enc_slf_attn
223 |
224 |
225 | class Renderer(nn.Module):
226 | def __init__(self, nb_samples_per_ray):
227 | super(Renderer, self).__init__()
228 |
229 | self.dim = 32
230 | self.attn_token_gen = nn.Linear(24 + 1 + 8, self.dim)
231 |
232 | ## Self-Attention Settings
233 | d_inner = self.dim
234 | n_head = 4
235 | d_k = self.dim // n_head
236 | d_v = self.dim // n_head
237 | num_layers = 4
238 | self.attn_layers = nn.ModuleList(
239 | [
240 | EncoderLayer(self.dim, d_inner, n_head, d_k, d_v)
241 | for i in range(num_layers)
242 | ]
243 | )
244 |
245 | ## Processing the mean and variance of input features
246 | self.var_mean_fc1 = nn.Linear(16, self.dim)
247 | self.var_mean_fc2 = nn.Linear(self.dim, self.dim)
248 |
249 | ## Setting mask of var_mean always enabled
250 | self.var_mean_mask = torch.tensor([1]).cuda()
251 | self.var_mean_mask.requires_grad = False
252 |
253 | ## For aggregating data along ray samples
254 | self.auto_enc = ConvAutoEncoder(self.dim, nb_samples_per_ray)
255 |
256 | self.sigma_fc1 = nn.Linear(self.dim, self.dim)
257 | self.sigma_fc2 = nn.Linear(self.dim, self.dim // 2)
258 | self.sigma_fc3 = nn.Linear(self.dim // 2, 1)
259 |
260 | self.rgb_fc1 = nn.Linear(self.dim + 9, self.dim)
261 | self.rgb_fc2 = nn.Linear(self.dim, self.dim // 2)
262 | self.rgb_fc3 = nn.Linear(self.dim // 2, 1)
263 |
264 | ## Initialization
265 | self.sigma_fc3.apply(weights_init)
266 |
267 | def forward(self, viewdirs, feat, occ_masks):
268 | ## Viewing samples regardless of batch or ray
269 | N, S, V = feat.shape[:3]
270 | feat = feat.view(-1, *feat.shape[2:])
271 | v_feat = feat[..., :24]
272 | s_feat = feat[..., 24 : 24 + 8]
273 | colors = feat[..., 24 + 8 : -1]
274 | vis_mask = feat[..., -1:].detach()
275 |
276 | occ_masks = occ_masks.view(-1, *occ_masks.shape[2:])
277 | viewdirs = viewdirs.view(-1, *viewdirs.shape[2:])
278 |
279 | ## Mean and variance of 2D features provide view-independent tokens
280 | var_mean = torch.var_mean(s_feat, dim=1, unbiased=False, keepdim=True)
281 | var_mean = torch.cat(var_mean, dim=-1)
282 | var_mean = F.elu(self.var_mean_fc1(var_mean))
283 | var_mean = F.elu(self.var_mean_fc2(var_mean))
284 |
285 | ## Converting the input features to tokens (view-dependent) before self-attention
286 | tokens = F.elu(
287 | self.attn_token_gen(torch.cat([v_feat, vis_mask, s_feat], dim=-1))
288 | )
289 | tokens = torch.cat([tokens, var_mean], dim=1)
290 |
291 | ## Adding a new channel to mask for var_mean
292 | vis_mask = torch.cat(
293 | [vis_mask, self.var_mean_mask.view(1, 1, 1).expand(N * S, -1, -1)], dim=1
294 | )
295 |         ## If a point is not visible in any source view, enable all of its masks
296 | vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1)
297 |
298 |         ## Applying occ_masks while remembering whether the point had any visibility beforehand
299 | mask_cloned = vis_mask.clone()
300 | vis_mask[:, :-1] *= occ_masks
301 | vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1)
302 | masks = vis_mask * mask_cloned
303 |
304 | ## Performing self-attention
305 | for layer in self.attn_layers:
306 | tokens, _ = layer(tokens, masks)
307 |
308 | ## Predicting sigma with an Auto-Encoder and MLP
309 | sigma_tokens = tokens[:, -1:]
310 | sigma_tokens = sigma_tokens.view(N, S, self.dim).transpose(1, 2)
311 | sigma_tokens = self.auto_enc(sigma_tokens)
312 | sigma_tokens = sigma_tokens.transpose(1, 2).reshape(N * S, 1, self.dim)
313 |
314 | sigma_tokens = F.elu(self.sigma_fc1(sigma_tokens))
315 | sigma_tokens = F.elu(self.sigma_fc2(sigma_tokens))
316 | sigma = torch.relu(self.sigma_fc3(sigma_tokens[:, 0]))
317 |
318 | ## Concatenating positional encodings and predicting RGB weights
319 | rgb_tokens = torch.cat([tokens[:, :-1], viewdirs], dim=-1)
320 | rgb_tokens = F.elu(self.rgb_fc1(rgb_tokens))
321 | rgb_tokens = F.elu(self.rgb_fc2(rgb_tokens))
322 | rgb_w = self.rgb_fc3(rgb_tokens)
323 | rgb_w = masked_softmax(rgb_w, masks[:, :-1], dim=1)
324 |
325 | rgb = (colors * rgb_w).sum(1)
326 |
327 | outputs = torch.cat([rgb, sigma], -1)
328 | outputs = outputs.reshape(N, S, -1)
329 |
330 | return outputs
331 |
--------------------------------------------------------------------------------
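
A minimal usage sketch for the Renderer module above (illustrative, not part of the repository): it assumes hypothetical ray/sample/view counts N, S, V, random inputs, a CUDA device (the module hard-codes .cuda() for its var_mean mask), and that it is run from the repository root. The feature layout follows the slicing in Renderer.forward: 24 volume-feature channels, 8 FPN channels, 3 colors, and 1 visibility flag; the 9 viewdir channels match the positional encoding of the ray/view angles.

import torch
from model.self_attn_renderer import Renderer

N, S, V = 512, 128, 3                               # rays, samples per ray, source views (hypothetical)
renderer = Renderer(nb_samples_per_ray=S).cuda()    # S = nb_coarse + nb_fine in the training script

feat = torch.rand(N, S, V, 24 + 8 + 3 + 1).cuda()   # [volume feats | FPN feats | colors | visibility]
viewdirs = torch.rand(N, S, V, 9).cuda()            # embedded angles: identity + 4 sin/cos frequencies
occ_masks = (torch.rand(N, S, V, 1) > 0.5).float().cuda()

with torch.no_grad():
    rgb_sigma = renderer(viewdirs, feat, occ_masks)
print(rgb_sigma.shape)                              # torch.Size([512, 128, 4]): RGB + sigma per sample
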
/pretrained_weights/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytorch-lightning==1.3.7
2 | inplace_abn
3 | imageio
4 | pillow
5 | scikit-image
6 | opencv-python
7 | ConfigArgParse
8 | lpips
9 | kornia
10 | ipdb
--------------------------------------------------------------------------------
/run_geo_nerf.py:
--------------------------------------------------------------------------------
1 | # GeoNeRF is a generalizable NeRF model that renders novel views
2 | # without requiring per-scene optimization. This software is the
3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with
4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
5 | # and Francois Fleuret.
6 |
7 | # Copyright (c) 2022 ams International AG
8 |
9 | # This file is part of GeoNeRF.
10 | # GeoNeRF is free software: you can redistribute it and/or modify
11 | # it under the terms of the GNU General Public License version 3 as
12 | # published by the Free Software Foundation.
13 |
14 | # GeoNeRF is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 |
19 | # You should have received a copy of the GNU General Public License
20 | # along with GeoNeRF. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | # This file incorporates work covered by the following copyright and
23 | # permission notice:
24 |
25 | # MIT License
26 |
27 | # Copyright (c) 2021 apchenstu
28 |
29 | # Permission is hereby granted, free of charge, to any person obtaining a copy
30 | # of this software and associated documentation files (the "Software"), to deal
31 | # in the Software without restriction, including without limitation the rights
32 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
33 | # copies of the Software, and to permit persons to whom the Software is
34 | # furnished to do so, subject to the following conditions:
35 |
36 | # The above copyright notice and this permission notice shall be included in all
37 | # copies or substantial portions of the Software.
38 |
39 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
40 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
41 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
42 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
43 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
44 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
45 | # SOFTWARE.
46 |
47 | import torch
48 | from torch.utils.data import DataLoader
49 | from torch.optim.lr_scheduler import CosineAnnealingLR
50 |
51 | from pytorch_lightning.callbacks import ModelCheckpoint
52 | from pytorch_lightning import LightningModule, Trainer, loggers
53 | from pytorch_lightning.loggers import WandbLogger
54 |
55 | import os
56 | import time
57 | import numpy as np
58 | import imageio
59 | import lpips
60 | from skimage.metrics import structural_similarity as ssim
61 |
62 | from model.geo_reasoner import CasMVSNet
63 | from model.self_attn_renderer import Renderer
64 | from utils.rendering import render_rays
65 | from utils.utils import (
66 | load_ckpt,
67 | init_log,
68 | get_rays_pts,
69 | SL1Loss,
70 | self_supervision_loss,
71 | img2mse,
72 | mse2psnr,
73 | acc_threshold,
74 | abs_error,
75 | visualize_depth,
76 | )
77 | from utils.options import config_parser
78 | from data.get_datasets import (
79 | get_training_dataset,
80 | get_finetuning_dataset,
81 | get_validation_dataset,
82 | )
83 |
84 | lpips_fn = lpips.LPIPS(net="vgg")
85 |
86 | class GeoNeRF(LightningModule):
87 | def __init__(self, hparams):
88 | super(GeoNeRF, self).__init__()
89 | self.hparams.update(vars(hparams))
90 | self.wr_cntr = 0
91 |
92 | self.depth_loss = SL1Loss()
93 | self.learning_rate = hparams.lrate
94 |
95 | # Create geometry_reasoner and renderer models
96 | self.geo_reasoner = CasMVSNet(use_depth=hparams.use_depth).cuda()
97 | self.renderer = Renderer(
98 | nb_samples_per_ray=hparams.nb_coarse + hparams.nb_fine
99 | ).cuda()
100 |
101 | self.eval_metric = [0.01, 0.05, 0.1]
102 |
103 | self.automatic_optimization = False
104 | self.save_hyperparameters()
105 |
106 | def unpreprocess(self, data, shape=(1, 1, 3, 1, 1)):
107 | # to unnormalize image for visualization
108 | device = data.device
109 | mean = (
110 | torch.tensor([-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225])
111 | .view(*shape)
112 | .to(device)
113 | )
114 | std = torch.tensor([1 / 0.229, 1 / 0.224, 1 / 0.225]).view(*shape).to(device)
115 |
116 | return (data - mean) / std
117 |
118 | def prepare_data(self):
119 | if self.hparams.scene == "None": ## Generalizable
120 | self.train_dataset, self.train_sampler = get_training_dataset(self.hparams)
121 | self.val_dataset = get_validation_dataset(self.hparams)
122 | else: ## Fine-tune
123 | self.train_dataset, self.train_sampler = get_finetuning_dataset(
124 | self.hparams
125 | )
126 | self.val_dataset = get_validation_dataset(self.hparams)
127 |
128 | def configure_optimizers(self):
129 | eps = 1e-5
130 |
131 | opt = torch.optim.Adam(
132 | list(self.geo_reasoner.parameters()) + list(self.renderer.parameters()),
133 | lr=self.learning_rate,
134 | betas=(0.9, 0.999),
135 | )
136 | sch = CosineAnnealingLR(opt, T_max=self.hparams.num_steps, eta_min=eps)
137 |
138 | return [opt], [sch]
139 |
140 | def train_dataloader(self):
141 | return DataLoader(
142 | self.train_dataset,
143 | sampler=self.train_sampler,
144 | shuffle=True if self.train_sampler is None else False,
145 | num_workers=8,
146 | batch_size=1,
147 | pin_memory=True,
148 | )
149 |
150 | def val_dataloader(self):
151 | return DataLoader(
152 | self.val_dataset,
153 | shuffle=False,
154 | num_workers=1,
155 | batch_size=1,
156 | pin_memory=True,
157 | )
158 |
159 | def training_step(self, batch, batch_nb):
160 | loss = 0
161 | nb_views = self.hparams.nb_views
162 | H, W = batch["images"].shape[-2:]
163 | H, W = int(H), int(W)
164 |
165 | ## Inferring Geometry Reasoner
166 | feats_vol, feats_fpn, depth_map, depth_values = self.geo_reasoner(
167 | imgs=batch["images"][:, :nb_views],
168 | affine_mats=batch["affine_mats"][:, :nb_views],
169 | affine_mats_inv=batch["affine_mats_inv"][:, :nb_views],
170 | near_far=batch["near_fars"][:, :nb_views],
171 | closest_idxs=batch["closest_idxs"][:, :nb_views],
172 | gt_depths=batch["depths_aug"][:, :nb_views],
173 | )
174 |
175 |         ## Normalizing depth maps in NDC coordinates
176 | depth_map_norm = {}
177 | for l in range(3):
178 | depth_map_norm[f"level_{l}"] = (
179 | depth_map[f"level_{l}"].detach() - depth_values[f"level_{l}"][:, :, 0]
180 | ) / (
181 | depth_values[f"level_{l}"][:, :, -1]
182 | - depth_values[f"level_{l}"][:, :, 0]
183 | )
184 |
185 | unpre_imgs = self.unpreprocess(batch["images"])
186 |
187 | (
188 | pts_depth,
189 | rays_pts,
190 | rays_pts_ndc,
191 | rays_dir,
192 | rays_gt_rgb,
193 | rays_gt_depth,
194 | rays_pixs,
195 | ) = get_rays_pts(
196 | H,
197 | W,
198 | batch["c2ws"],
199 | batch["w2cs"],
200 | batch["intrinsics"],
201 | batch["near_fars"],
202 | depth_values,
203 | self.hparams.nb_coarse,
204 | self.hparams.nb_fine,
205 | nb_views=nb_views,
206 | train=True,
207 | train_batch_size=self.hparams.batch_size,
208 | target_img=unpre_imgs[0, -1],
209 | target_depth=batch["depths_h"][0, -1],
210 | )
211 |
212 | ## Rendering
213 | rendered_rgb, rendered_depth = render_rays(
214 | c2ws=batch["c2ws"][0, :nb_views],
215 | rays_pts=rays_pts,
216 | rays_pts_ndc=rays_pts_ndc,
217 | pts_depth=pts_depth,
218 | rays_dir=rays_dir,
219 | feats_vol=feats_vol,
220 | feats_fpn=feats_fpn[:, :nb_views],
221 | imgs=unpre_imgs[:, :nb_views],
222 | depth_map_norm=depth_map_norm,
223 | renderer_net=self.renderer,
224 | )
225 |
226 | # Supervising depth maps with either ground truth depth or self-supervision loss
227 | ## This loss is only used in the generalizable model
228 | if self.hparams.scene == "None":
229 | ## if ground truth is available
230 | if isinstance(batch["depths"], dict):
231 | loss = loss + 1 * self.depth_loss(depth_map, batch["depths"])
232 | if loss != 0:
233 | self.log("train/dlossgt", loss.item(), prog_bar=False)
234 | else:
235 | loss = loss + 0.1 * self_supervision_loss(
236 | self.depth_loss,
237 | rays_pixs,
238 | rendered_depth.detach(),
239 | depth_map,
240 | rays_gt_rgb,
241 | unpre_imgs,
242 | rendered_rgb.detach(),
243 | batch["intrinsics"],
244 | batch["c2ws"],
245 | batch["w2cs"],
246 | )
247 | if loss != 0:
248 | self.log("train/dlosspgt", loss.item(), prog_bar=False)
249 |
250 | mask = rays_gt_depth > 0
251 | depth_available = mask.sum() > 0
252 |
253 | ## Supervising ray depths
254 | if depth_available:
255 | ## This loss is only used in the generalizable model
256 | if self.hparams.scene == "None":
257 | loss = loss + 0.1 * self.depth_loss(rendered_depth, rays_gt_depth)
258 |
259 | self.log(
260 | f"train/acc_l_{self.eval_metric[0]}mm",
261 | acc_threshold(
262 | rendered_depth, rays_gt_depth, mask, self.eval_metric[0]
263 | ).mean(),
264 | prog_bar=False,
265 | )
266 | self.log(
267 | f"train/acc_l_{self.eval_metric[1]}mm",
268 | acc_threshold(
269 | rendered_depth, rays_gt_depth, mask, self.eval_metric[1]
270 | ).mean(),
271 | prog_bar=False,
272 | )
273 | self.log(
274 | f"train/acc_l_{self.eval_metric[2]}mm",
275 | acc_threshold(
276 | rendered_depth, rays_gt_depth, mask, self.eval_metric[2]
277 | ).mean(),
278 | prog_bar=False,
279 | )
280 |
281 | abs_err = abs_error(rendered_depth, rays_gt_depth, mask).mean()
282 | self.log("train/abs_err", abs_err, prog_bar=False)
283 |
284 | ## Reconstruction loss
285 | mse_loss = img2mse(rendered_rgb, rays_gt_rgb)
286 | loss = loss + mse_loss
287 |
288 | with torch.no_grad():
289 | self.log("train/loss", loss.item(), prog_bar=True)
290 | psnr = mse2psnr(mse_loss.detach())
291 | self.log("train/PSNR", psnr.item(), prog_bar=False)
292 | self.log("train/img_mse_loss", mse_loss.item(), prog_bar=False)
293 |
294 | # Manual Optimization
295 | self.manual_backward(loss)
296 |
297 | opt = self.optimizers()
298 | sch = self.lr_schedulers()
299 |
300 | # Warming up the learning rate
301 | if self.trainer.global_step < self.hparams.warmup_steps:
302 | lr_scale = min(
303 | 1.0, float(self.trainer.global_step + 1) / self.hparams.warmup_steps
304 | )
305 | for pg in opt.param_groups:
306 | pg["lr"] = lr_scale * self.learning_rate
307 |
308 | self.log("train/lr", opt.param_groups[0]["lr"], prog_bar=False)
309 |
310 | opt.step()
311 | opt.zero_grad()
312 | sch.step()
313 |
314 | return {"loss": loss}
315 |
316 | def validation_step(self, batch, batch_nb):
317 |         ## This makes BatchNorm behave like InstanceNorm
318 | self.geo_reasoner.train()
319 |
320 | log_keys = [
321 | "val_psnr",
322 | "val_ssim",
323 | "val_lpips",
324 | "val_depth_loss_r",
325 | "val_abs_err",
326 | "mask_sum",
327 | ] + [f"val_acc_{i}mm" for i in self.eval_metric]
328 | log = {}
329 | log = init_log(log, log_keys)
330 |
331 | H, W = batch["images"].shape[-2:]
332 | H, W = int(H), int(W)
333 |
334 | nb_views = self.hparams.nb_views
335 |
336 | with torch.no_grad():
337 | ## Inferring Geometry Reasoner
338 | feats_vol, feats_fpn, depth_map, depth_values = self.geo_reasoner(
339 | imgs=batch["images"][:, :nb_views],
340 | affine_mats=batch["affine_mats"][:, :nb_views],
341 | affine_mats_inv=batch["affine_mats_inv"][:, :nb_views],
342 | near_far=batch["near_fars"][:, :nb_views],
343 | closest_idxs=batch["closest_idxs"][:, :nb_views],
344 | gt_depths=batch["depths_aug"][:, :nb_views],
345 | )
346 |
347 |             ## Normalizing depth maps in NDC coordinates
348 | depth_map_norm = {}
349 | for l in range(3):
350 | depth_map_norm[f"level_{l}"] = (
351 | depth_map[f"level_{l}"] - depth_values[f"level_{l}"][:, :, 0]
352 | ) / (
353 | depth_values[f"level_{l}"][:, :, -1]
354 | - depth_values[f"level_{l}"][:, :, 0]
355 | )
356 |
357 | unpre_imgs = self.unpreprocess(batch["images"])
358 |
359 | rendered_rgb, rendered_depth = [], []
360 | for chunk_idx in range(
361 | H * W // self.hparams.chunk + int(H * W % self.hparams.chunk > 0)
362 | ):
363 | pts_depth, rays_pts, rays_pts_ndc, rays_dir, _, _, _ = get_rays_pts(
364 | H,
365 | W,
366 | batch["c2ws"],
367 | batch["w2cs"],
368 | batch["intrinsics"],
369 | batch["near_fars"],
370 | depth_values,
371 | self.hparams.nb_coarse,
372 | self.hparams.nb_fine,
373 | nb_views=nb_views,
374 | chunk=self.hparams.chunk,
375 | chunk_idx=chunk_idx,
376 | )
377 |
378 | ## Rendering
379 | rend_rgb, rend_depth = render_rays(
380 | c2ws=batch["c2ws"][0, :nb_views],
381 | rays_pts=rays_pts,
382 | rays_pts_ndc=rays_pts_ndc,
383 | pts_depth=pts_depth,
384 | rays_dir=rays_dir,
385 | feats_vol=feats_vol,
386 | feats_fpn=feats_fpn[:, :nb_views],
387 | imgs=unpre_imgs[:, :nb_views],
388 | depth_map_norm=depth_map_norm,
389 | renderer_net=self.renderer,
390 | )
391 | rendered_rgb.append(rend_rgb)
392 | rendered_depth.append(rend_depth)
393 | rendered_rgb = torch.clamp(
394 | torch.cat(rendered_rgb).reshape(H, W, 3).permute(2, 0, 1), 0, 1
395 | )
396 | rendered_depth = torch.cat(rendered_depth).reshape(H, W)
397 |
398 | ## Check if there is any ground truth depth information for the dataset
399 | depth_available = batch["depths_h"].sum() > 0
400 |
401 | ## Evaluate only on pixels with meaningful ground truth depths
402 | if depth_available:
403 | mask = batch["depths_h"] > 0
404 | img_gt_masked = (unpre_imgs[0, -1] * mask[0, -1][None]).cpu()
405 | rendered_rgb_masked = (rendered_rgb * mask[0, -1][None]).cpu()
406 | else:
407 | img_gt_masked = unpre_imgs[0, -1].cpu()
408 | rendered_rgb_masked = rendered_rgb.cpu()
409 |
410 | unpre_imgs = unpre_imgs.cpu()
411 | rendered_rgb, rendered_depth = rendered_rgb.cpu(), rendered_depth.cpu()
412 | img_err_abs = (rendered_rgb_masked - img_gt_masked).abs()
413 |
414 | depth_target = batch["depths_h"][0, -1].cpu()
415 | mask_target = depth_target > 0
416 |
417 | if depth_available:
418 | log["val_psnr"] = mse2psnr(torch.mean(img_err_abs[:, mask_target] ** 2))
419 | else:
420 | log["val_psnr"] = mse2psnr(torch.mean(img_err_abs**2))
421 | log["val_ssim"] = ssim(
422 | rendered_rgb_masked.permute(1, 2, 0).numpy(),
423 | img_gt_masked.permute(1, 2, 0).numpy(),
424 | data_range=1,
425 | multichannel=True,
426 | )
427 | log["val_lpips"] = lpips_fn(
428 | rendered_rgb_masked[None] * 2 - 1, img_gt_masked[None] * 2 - 1
429 | ).item() # Normalize to [-1,1]
430 |
431 | depth_minmax = [
432 | 0.9 * batch["near_fars"].min().detach().cpu().numpy(),
433 | 1.1 * batch["near_fars"].max().detach().cpu().numpy(),
434 | ]
435 | rendered_depth_vis, _ = visualize_depth(rendered_depth, depth_minmax)
436 |
437 | if depth_available:
438 | log["val_abs_err"] = abs_error(
439 | rendered_depth, depth_target, mask_target
440 | ).sum()
441 | log[f"val_acc_{self.eval_metric[0]}mm"] = acc_threshold(
442 | rendered_depth, depth_target, mask_target, self.eval_metric[0]
443 | ).sum()
444 | log[f"val_acc_{self.eval_metric[1]}mm"] = acc_threshold(
445 | rendered_depth, depth_target, mask_target, self.eval_metric[1]
446 | ).sum()
447 | log[f"val_acc_{self.eval_metric[2]}mm"] = acc_threshold(
448 | rendered_depth, depth_target, mask_target, self.eval_metric[2]
449 | ).sum()
450 | log["mask_sum"] = mask_target.float().sum()
451 |
452 | img_vis = (
453 | torch.cat(
454 | (
455 | unpre_imgs[:, -1],
456 | torch.stack([rendered_rgb, img_err_abs * 5]),
457 | rendered_depth_vis[None],
458 | ),
459 | dim=0,
460 | )
461 | .clip(0, 1)
462 | .permute(2, 0, 3, 1)
463 | .reshape(H, -1, 3)
464 | .numpy()
465 | )
466 |
467 | os.makedirs(
468 | f"{self.hparams.logdir}/{self.hparams.dataset_name}/{self.hparams.expname}/rendered_results/",
469 | exist_ok=True,
470 | )
471 | imageio.imwrite(
472 | f"{self.hparams.logdir}/{self.hparams.dataset_name}/{self.hparams.expname}/rendered_results/{self.wr_cntr:03d}.png",
473 | (
474 | rendered_rgb.detach().permute(1, 2, 0).clip(0.0, 1.0).cpu().numpy()
475 | * 255
476 | ).astype("uint8"),
477 | )
478 |
479 | os.makedirs(
480 | f"{self.hparams.logdir}/{self.hparams.dataset_name}/{self.hparams.expname}/evaluation/",
481 | exist_ok=True,
482 | )
483 | imageio.imwrite(
484 | f"{self.hparams.logdir}/{self.hparams.dataset_name}/{self.hparams.expname}/evaluation/{self.global_step:08d}_{self.wr_cntr:02d}.png",
485 | (img_vis * 255).astype("uint8"),
486 | )
487 |
488 | print(f"Image {self.wr_cntr:02d} rendered.")
489 | self.wr_cntr += 1
490 |
491 | return log
492 |
493 | def validation_epoch_end(self, outputs):
494 | mean_psnr = torch.stack([x["val_psnr"] for x in outputs]).mean()
495 | mean_ssim = np.stack([x["val_ssim"] for x in outputs]).mean()
496 | mean_lpips = np.stack([x["val_lpips"] for x in outputs]).mean()
497 | mask_sum = torch.stack([x["mask_sum"] for x in outputs]).sum()
498 | mean_d_loss_r = torch.stack([x["val_depth_loss_r"] for x in outputs]).mean()
499 | mean_abs_err = torch.stack([x["val_abs_err"] for x in outputs]).sum() / mask_sum
500 | mean_acc_1mm = (
501 | torch.stack([x[f"val_acc_{self.eval_metric[0]}mm"] for x in outputs]).sum()
502 | / mask_sum
503 | )
504 | mean_acc_2mm = (
505 | torch.stack([x[f"val_acc_{self.eval_metric[1]}mm"] for x in outputs]).sum()
506 | / mask_sum
507 | )
508 | mean_acc_4mm = (
509 | torch.stack([x[f"val_acc_{self.eval_metric[2]}mm"] for x in outputs]).sum()
510 | / mask_sum
511 | )
512 |
513 | self.log("val/PSNR", mean_psnr, prog_bar=False)
514 | self.log("val/SSIM", mean_ssim, prog_bar=False)
515 | self.log("val/LPIPS", mean_lpips, prog_bar=False)
516 | if mask_sum > 0:
517 | self.log("val/d_loss_r", mean_d_loss_r, prog_bar=False)
518 | self.log("val/abs_err", mean_abs_err, prog_bar=False)
519 | self.log(f"val/acc_{self.eval_metric[0]}mm", mean_acc_1mm, prog_bar=False)
520 | self.log(f"val/acc_{self.eval_metric[1]}mm", mean_acc_2mm, prog_bar=False)
521 | self.log(f"val/acc_{self.eval_metric[2]}mm", mean_acc_4mm, prog_bar=False)
522 |
523 | with open(
524 | f"{self.hparams.logdir}/{self.hparams.dataset_name}/{self.hparams.expname}/{self.hparams.expname}_metrics.txt",
525 | "w",
526 | ) as metric_file:
527 | metric_file.write(f"PSNR: {mean_psnr}\n")
528 | metric_file.write(f"SSIM: {mean_ssim}\n")
529 | metric_file.write(f"LPIPS: {mean_lpips}")
530 |
531 | return
532 |
533 |
534 | if __name__ == "__main__":
535 | torch.set_default_dtype(torch.float32)
536 | args = config_parser()
537 | geonerf = GeoNeRF(args)
538 |
539 |     ## Checking the logdir to see if there is a checkpoint file to continue from
540 | ckpt_path = f"{args.logdir}/{args.dataset_name}/{args.expname}/ckpts"
541 | if os.path.isdir(ckpt_path) and len(os.listdir(ckpt_path)) > 0:
542 | ckpt_file = os.path.join(ckpt_path, os.listdir(ckpt_path)[-1])
543 | else:
544 | ckpt_file = None
545 |
546 | ## Setting a callback to automatically save checkpoints
547 | checkpoint_callback = ModelCheckpoint(
548 | f"{args.logdir}/{args.dataset_name}/{args.expname}/ckpts",
549 | filename="ckpt_step-{step:06d}",
550 | auto_insert_metric_name=False,
551 | save_top_k=-1,
552 | )
553 |
554 | ## Setting up a logger
555 | if args.logger == "wandb":
556 | logger = WandbLogger(
557 | name=args.expname,
558 | project="GeoNeRF",
559 | save_dir=f"{args.logdir}",
560 | resume="allow",
561 | id=args.expname,
562 | )
563 | elif args.logger == "tensorboard":
564 | logger = loggers.TestTubeLogger(
565 | save_dir=f"{args.logdir}/{args.dataset_name}/{args.expname}",
566 | name=args.expname + "_logs",
567 | debug=False,
568 | create_git_tag=False,
569 | )
570 | else:
571 | logger = None
572 |
573 | args.use_amp = False if args.eval else True
574 | trainer = Trainer(
575 | max_steps=args.num_steps,
576 | callbacks=checkpoint_callback,
577 | checkpoint_callback=True,
578 | resume_from_checkpoint=ckpt_file,
579 | logger=logger,
580 | progress_bar_refresh_rate=1,
581 | gpus=1,
582 | num_sanity_val_steps=0,
583 | val_check_interval=2000 if args.scene == "None" else 1.0,
584 | check_val_every_n_epoch=1000 if args.scene != 'None' else 1,
585 | benchmark=True,
586 | precision=16 if args.use_amp else 32,
587 | amp_level="O1",
588 | )
589 |
590 | if not args.eval: ## Train
591 | if args.scene != "None": ## Fine-tune
592 | if args.use_depth:
593 | ckpt_file = "pretrained_weights/pretrained_w_depth.ckpt"
594 | else:
595 | ckpt_file = "pretrained_weights/pretrained.ckpt"
596 | load_ckpt(geonerf.geo_reasoner, ckpt_file, "geo_reasoner")
597 | load_ckpt(geonerf.renderer, ckpt_file, "renderer")
598 | elif not args.use_depth: ## Generalizable
599 | ## Loading the pretrained weights from Cascade MVSNet
600 | torch.utils.model_zoo.load_url(
601 | "https://github.com/kwea123/CasMVSNet_pl/releases/download/1.5/epoch.15.ckpt",
602 | model_dir="pretrained_weights",
603 | )
604 | ckpt_file = "pretrained_weights/epoch.15.ckpt"
605 | load_ckpt(geonerf.geo_reasoner, ckpt_file, "model", strict=False)
606 |
607 | trainer.fit(geonerf)
608 | else: ## Eval
609 | geonerf = GeoNeRF(args)
610 |
611 | if ckpt_file is None:
612 | if args.use_depth:
613 | ckpt_file = "pretrained_weights/pretrained_w_depth.ckpt"
614 | else:
615 | ckpt_file = "pretrained_weights/pretrained.ckpt"
616 | load_ckpt(geonerf.geo_reasoner, ckpt_file, "geo_reasoner")
617 | load_ckpt(geonerf.renderer, ckpt_file, "renderer")
618 |
619 | trainer.validate(geonerf)
620 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/idiap/GeoNeRF/e6249fdae5672853c6bbbd4ba380c4c166d02c95/utils/__init__.py
--------------------------------------------------------------------------------
/utils/options.py:
--------------------------------------------------------------------------------
1 | # GeoNeRF is a generalizable NeRF model that renders novel views
2 | # without requiring per-scene optimization. This software is the
3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with
4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
5 | # and Francois Fleuret.
6 |
7 | # Copyright (c) 2022 ams International AG
8 |
9 | # This file is part of GeoNeRF.
10 | # GeoNeRF is free software: you can redistribute it and/or modify
11 | # it under the terms of the GNU General Public License version 3 as
12 | # published by the Free Software Foundation.
13 |
14 | # GeoNeRF is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 |
19 | # You should have received a copy of the GNU General Public License
20 | # along with GeoNeRF. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | import configargparse
23 |
24 | def config_parser():
25 | parser = configargparse.ArgumentParser()
26 | parser.add_argument("--config", is_config_file=True, help="Config file path")
27 |
28 | # Datasets options
29 | parser.add_argument("--dataset_name", type=str, default="llff", choices=["llff", "nerf", "dtu"],)
30 | parser.add_argument("--llff_path", type=str, help="Path to llff dataset")
31 | parser.add_argument("--llff_test_path", type=str, help="Path to llff dataset")
32 | parser.add_argument("--dtu_path", type=str, help="Path to dtu dataset")
33 | parser.add_argument("--dtu_pre_path", type=str, help="Path to preprocessed dtu dataset")
34 | parser.add_argument("--nerf_path", type=str, help="Path to nerf dataset")
35 | parser.add_argument("--ams_path", type=str, help="Path to ams dataset")
36 | parser.add_argument("--ibrnet1_path", type=str, help="Path to ibrnet1 dataset")
37 | parser.add_argument("--ibrnet2_path", type=str, help="Path to ibrnet2 dataset")
38 |
39 | # Training options
40 | parser.add_argument("--batch_size", type=int, default=512)
41 | parser.add_argument("--num_steps", type=int, default=200000)
42 | parser.add_argument("--nb_views", type=int, default=3)
43 | parser.add_argument("--lrate", type=float, default=5e-4, help="Learning rate")
44 | parser.add_argument("--warmup_steps", type=int, default=500, help="Gradually warm-up learning rate in optimizer")
45 | parser.add_argument("--scene", type=str, default="None", help="Scene for fine-tuning")
46 |
47 | # Rendering options
48 | parser.add_argument("--chunk", type=int, default=4096, help="Number of rays rendered in parallel")
49 | parser.add_argument("--nb_coarse", type=int, default=96, help="Number of coarse samples per ray")
50 | parser.add_argument("--nb_fine", type=int, default=32, help="Number of additional fine samples per ray",)
51 |
52 | # Other options
53 | parser.add_argument("--expname", type=str, help="Experiment name")
54 | parser.add_argument("--logger", type=str, default="tensorboard", choices=["wandb", "tensorboard", "none"])
55 | parser.add_argument("--logdir", type=str, default="./logs/", help="Where to store ckpts and logs")
56 | parser.add_argument("--eval", action="store_true", help="Render and evaluate the test set")
57 | parser.add_argument("--use_depth", action="store_true", help="Use ground truth low-res depth maps in rendering process")
58 |
59 | return parser.parse_args()
60 |
--------------------------------------------------------------------------------
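
A short sketch of how the configargparse-based parser above is typically driven (the config path and experiment name are illustrative assumptions, not taken from the README): options can come from one of the files under configs/ and be overridden on the command line.

import sys
from utils.options import config_parser

# Hypothetical invocation for generalizable training; "--scene None" (the default)
# selects the generalizable branch in run_geo_nerf.py, a named scene selects fine-tuning.
sys.argv = [
    "run_geo_nerf.py",
    "--config", "configs/config_general.txt",   # assumed location of the provided config file
    "--expname", "geonerf_general",             # hypothetical experiment name
]
args = config_parser()
print(args.dataset_name, args.nb_views, args.nb_coarse + args.nb_fine)
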
/utils/rendering.py:
--------------------------------------------------------------------------------
1 | # GeoNeRF is a generalizable NeRF model that renders novel views
2 | # without requiring per-scene optimization. This software is the
3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with
4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
5 | # and Francois Fleuret.
6 |
7 | # Copyright (c) 2022 ams International AG
8 |
9 | # This file is part of GeoNeRF.
10 | # GeoNeRF is free software: you can redistribute it and/or modify
11 | # it under the terms of the GNU General Public License version 3 as
12 | # published by the Free Software Foundation.
13 |
14 | # GeoNeRF is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 |
19 | # You should have received a copy of the GNU General Public License
20 | # along with GeoNeRF. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | # This file incorporates work covered by the following copyright and
23 | # permission notice:
24 |
25 | # MIT License
26 |
27 | # Copyright (c) 2021 apchenstu
28 |
29 | # Permission is hereby granted, free of charge, to any person obtaining a copy
30 | # of this software and associated documentation files (the "Software"), to deal
31 | # in the Software without restriction, including without limitation the rights
32 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
33 | # copies of the Software, and to permit persons to whom the Software is
34 | # furnished to do so, subject to the following conditions:
35 |
36 | # The above copyright notice and this permission notice shall be included in all
37 | # copies or substantial portions of the Software.
38 |
39 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
40 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
41 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
42 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
43 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
44 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
45 | # SOFTWARE.
46 |
47 | import torch
48 | import torch.nn.functional as F
49 |
50 | from utils.utils import normal_vect, interpolate_3D, interpolate_2D
51 |
52 |
53 | class Embedder:
54 | def __init__(self, **kwargs):
55 | self.kwargs = kwargs
56 | self.create_embedding_fn()
57 |
58 | def create_embedding_fn(self):
59 | embed_fns = []
60 |
61 | if self.kwargs["include_input"]:
62 | embed_fns.append(lambda x: x)
63 |
64 | max_freq = self.kwargs["max_freq_log2"]
65 | N_freqs = self.kwargs["num_freqs"]
66 |
67 | if self.kwargs["log_sampling"]:
68 | freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs)
69 | else:
70 | freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs)
71 | self.freq_bands = freq_bands.reshape(1, -1, 1).cuda()
72 |
73 | for freq in freq_bands:
74 | for p_fn in self.kwargs["periodic_fns"]:
75 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
76 |
77 | self.embed_fns = embed_fns
78 |
79 | def embed(self, inputs):
80 | repeat = inputs.dim() - 1
81 | inputs_scaled = (
82 | inputs.unsqueeze(-2) * self.freq_bands.view(*[1] * repeat, -1, 1)
83 | ).reshape(*inputs.shape[:-1], -1)
84 | inputs_scaled = torch.cat(
85 | (inputs, torch.sin(inputs_scaled), torch.cos(inputs_scaled)), dim=-1
86 | )
87 | return inputs_scaled
88 |
89 |
90 | def get_embedder(multires=4):
91 |
92 | embed_kwargs = {
93 | "include_input": True,
94 | "max_freq_log2": multires - 1,
95 | "num_freqs": multires,
96 | "log_sampling": True,
97 | "periodic_fns": [torch.sin, torch.cos],
98 | }
99 |
100 | embedder_obj = Embedder(**embed_kwargs)
101 | embed = lambda x, eo=embedder_obj: eo.embed(x)
102 | return embed
103 |
104 |
105 | def sigma2weights(sigma):
106 | alpha = 1.0 - torch.exp(-sigma)
107 | T = torch.cumprod(
108 | torch.cat(
109 | [torch.ones(alpha.shape[0], 1).to(alpha.device), 1.0 - alpha + 1e-10], -1
110 | ),
111 | -1,
112 | )[:, :-1]
113 | weights = alpha * T
114 |
115 | return weights
116 |
117 |
118 | def volume_rendering(rgb_sigma, pts_depth):
119 | rgb = rgb_sigma[..., :3]
120 | weights = sigma2weights(rgb_sigma[..., 3])
121 |
122 | rendered_rgb = torch.sum(weights[..., None] * rgb, -2)
123 | rendered_depth = torch.sum(weights * pts_depth, -1)
124 |
125 | return rendered_rgb, rendered_depth
126 |
127 |
128 | def get_angle_wrt_src_cams(c2ws, rays_pts, rays_dir_unit):
129 | nb_rays = rays_pts.shape[0]
130 | ## Unit vectors from source cameras to the points on the ray
131 | dirs = normal_vect(rays_pts.unsqueeze(2) - c2ws[:, :3, 3][None, None])
132 | ## Cosine of the angle between two directions
133 | angle_cos = torch.sum(
134 | dirs * rays_dir_unit.reshape(nb_rays, 1, 1, 3), dim=-1, keepdim=True
135 | )
136 | # Cosine to Sine and approximating it as the angle (angle << 1 => sin(angle) = angle)
137 | angle = (1 - (angle_cos**2)).abs().sqrt()
138 |
139 | return angle
140 |
141 |
142 | def interpolate_pts_feats(imgs, feats_fpn, feats_vol, rays_pts_ndc):
143 | nb_views = feats_fpn.shape[1]
144 | interpolated_feats = []
145 |
146 | for i in range(nb_views):
147 | ray_feats_0 = interpolate_3D(
148 | feats_vol[f"level_0"][:, i], rays_pts_ndc[f"level_0"][:, :, i]
149 | )
150 | ray_feats_1 = interpolate_3D(
151 | feats_vol[f"level_1"][:, i], rays_pts_ndc[f"level_1"][:, :, i]
152 | )
153 | ray_feats_2 = interpolate_3D(
154 | feats_vol[f"level_2"][:, i], rays_pts_ndc[f"level_2"][:, :, i]
155 | )
156 |
157 | ray_feats_fpn, ray_colors, ray_masks = interpolate_2D(
158 | feats_fpn[:, i], imgs[:, i], rays_pts_ndc[f"level_0"][:, :, i]
159 | )
160 |
161 | interpolated_feats.append(
162 | torch.cat(
163 | [
164 | ray_feats_0,
165 | ray_feats_1,
166 | ray_feats_2,
167 | ray_feats_fpn,
168 | ray_colors,
169 | ray_masks,
170 | ],
171 | dim=-1,
172 | )
173 | )
174 | interpolated_feats = torch.stack(interpolated_feats, dim=2)
175 |
176 | return interpolated_feats
177 |
178 |
179 | def get_occ_masks(depth_map_norm, rays_pts_ndc, visibility_thr=0.2):
180 | nb_views = depth_map_norm["level_0"].shape[1]
181 | z_diff = []
182 | for i in range(nb_views):
183 | ## Interpolate depth maps corresponding to each sample point
184 | # [1 H W 3] (x,y,z)
185 | grid = rays_pts_ndc[f"level_0"][None, :, :, i, :2] * 2 - 1.0
186 | rays_depths = F.grid_sample(
187 | depth_map_norm["level_0"][:, i : i + 1],
188 | grid,
189 | align_corners=True,
190 | mode="bilinear",
191 | padding_mode="border",
192 | )[0, 0]
193 | z_diff.append(rays_pts_ndc["level_0"][:, :, i, 2] - rays_depths)
194 | z_diff = torch.stack(z_diff, dim=2)
195 |
196 | occ_masks = z_diff.unsqueeze(-1) < visibility_thr
197 |
198 | return occ_masks
199 |
200 |
201 | def render_rays(
202 | c2ws,
203 | rays_pts,
204 | rays_pts_ndc,
205 | pts_depth,
206 | rays_dir,
207 | feats_vol,
208 | feats_fpn,
209 | imgs,
210 | depth_map_norm,
211 | renderer_net,
212 | ):
213 | ## The angles between the ray and source camera vectors
214 | rays_dir_unit = rays_dir / torch.norm(rays_dir, dim=-1, keepdim=True)
215 | angles = get_angle_wrt_src_cams(c2ws, rays_pts, rays_dir_unit)
216 |
217 | ## Positional encoding
218 | embedded_angles = get_embedder()(angles)
219 |
220 | ## Interpolate all features for sample points
221 | pts_feat = interpolate_pts_feats(imgs, feats_fpn, feats_vol, rays_pts_ndc)
222 |
223 | ## Getting Occlusion Masks based on predicted depths
224 | occ_masks = get_occ_masks(depth_map_norm, rays_pts_ndc)
225 |
226 | ## rendering sigma and RGB values
227 | rgb_sigma = renderer_net(embedded_angles, pts_feat, occ_masks)
228 |
229 | rendered_rgb, rendered_depth = volume_rendering(rgb_sigma, pts_depth)
230 |
231 | return rendered_rgb, rendered_depth
232 |
--------------------------------------------------------------------------------
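
A small self-contained check of the volume-rendering helpers above (sigma2weights and volume_rendering) with dummy densities and depths; it runs on CPU and only assumes that the packages in requirements.txt are installed so that utils.rendering imports cleanly.

import torch
from utils.rendering import sigma2weights, volume_rendering

nb_rays, nb_samples = 4, 128
pts_depth = torch.linspace(2.0, 6.0, nb_samples).expand(nb_rays, nb_samples)
rgb_sigma = torch.rand(nb_rays, nb_samples, 4)   # per-sample RGB (3 channels) + sigma (1 channel)

weights = sigma2weights(rgb_sigma[..., 3])       # alpha-compositing weights along each ray
rgb, depth = volume_rendering(rgb_sigma, pts_depth)

print(weights.shape, rgb.shape, depth.shape)     # (4, 128), (4, 3), (4,)
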
/utils/utils.py:
--------------------------------------------------------------------------------
1 | # GeoNeRF is a generalizable NeRF model that renders novel views
2 | # without requiring per-scene optimization. This software is the
3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with
4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
5 | # and Francois Fleuret.
6 |
7 | # Copyright (c) 2022 ams International AG
8 |
9 | # This file is part of GeoNeRF.
10 | # GeoNeRF is free software: you can redistribute it and/or modify
11 | # it under the terms of the GNU General Public License version 3 as
12 | # published by the Free Software Foundation.
13 |
14 | # GeoNeRF is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 |
19 | # You should have received a copy of the GNU General Public License
20 | # along with GeoNeRF. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | # This file incorporates work covered by the following copyright and
23 | # permission notice:
24 |
25 | # MIT License
26 |
27 | # Copyright (c) 2021 apchenstu
28 |
29 | # Permission is hereby granted, free of charge, to any person obtaining a copy
30 | # of this software and associated documentation files (the "Software"), to deal
31 | # in the Software without restriction, including without limitation the rights
32 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
33 | # copies of the Software, and to permit persons to whom the Software is
34 | # furnished to do so, subject to the following conditions:
35 |
36 | # The above copyright notice and this permission notice shall be included in all
37 | # copies or substantial portions of the Software.
38 |
39 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
40 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
41 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
42 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
43 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
44 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
45 | # SOFTWARE.
46 |
47 | import torch
48 | import torch.nn as nn
49 | import torch.nn.functional as F
50 | import torchvision.transforms as T
51 |
52 | import numpy as np
53 | import cv2
54 | import re
55 |
56 | from PIL import Image
57 | from kornia.utils import create_meshgrid
58 |
59 | img2mse = lambda x, y: torch.mean((x - y) ** 2)
60 | mse2psnr = lambda x: -10.0 * torch.log(x) / torch.log(torch.Tensor([10.0]).to(x.device))
61 |
62 |
63 | def load_ckpt(network, ckpt_file, key_prefix, strict=True):
64 | ckpt_dict = torch.load(ckpt_file)
65 |
66 | if "state_dict" in ckpt_dict.keys():
67 | ckpt_dict = ckpt_dict["state_dict"]
68 |
69 | state_dict = {}
70 | for key, val in ckpt_dict.items():
71 | if key_prefix in key:
72 | state_dict[key[len(key_prefix) + 1 :]] = val
73 | network.load_state_dict(state_dict, strict)
74 |
75 |
76 | def init_log(log, keys):
77 | for key in keys:
78 | log[key] = torch.tensor([0.0], dtype=float)
79 | return log
80 |
81 |
82 | class SL1Loss(nn.Module):
83 | def __init__(self, levels=3):
84 | super(SL1Loss, self).__init__()
85 | self.levels = levels
86 | self.loss = nn.SmoothL1Loss(reduction="mean")
87 | self.loss_ray = nn.SmoothL1Loss(reduction="none")
88 |
89 | def forward(self, inputs, targets):
90 | loss = 0
91 | if isinstance(inputs, dict):
92 | for l in range(self.levels):
93 | depth_pred_l = inputs[f"level_{l}"]
94 | V = depth_pred_l.shape[1]
95 |
96 | depth_gt_l = targets[f"level_{l}"]
97 | depth_gt_l = depth_gt_l[:, :V]
98 | mask_l = depth_gt_l > 0
99 |
100 | loss = loss + self.loss(
101 | depth_pred_l[mask_l], depth_gt_l[mask_l]
102 | ) * 2 ** (1 - l)
103 | else:
104 | mask = targets > 0
105 | loss = loss + (self.loss_ray(inputs, targets) * mask).sum() / len(mask)
106 |
107 | return loss
108 |
109 |
110 | def self_supervision_loss(
111 | loss_fn,
112 | rays_pixs,
113 | rendered_depth,
114 | depth_map,
115 | rays_gt_rgb,
116 | unpre_imgs,
117 | rendered_rgb,
118 | intrinsics,
119 | c2ws,
120 | w2cs,
121 | ):
122 | loss = 0
123 | target_points = torch.stack(
124 | [rays_pixs[1], rays_pixs[0], torch.ones(rays_pixs[0].shape[0]).cuda()], dim=-1
125 | )
126 | target_points = rendered_depth.view(-1, 1) * (
127 | target_points @ torch.inverse(intrinsics[0, -1]).t()
128 | )
129 | target_points = target_points @ c2ws[0, -1][:3, :3].t() + c2ws[0, -1][:3, 3]
130 |
131 | rgb_mask = (rendered_rgb - rays_gt_rgb).abs().mean(dim=-1) < 0.02
132 |
133 | for v in range(len(w2cs[0]) - 1):
134 | points_v = target_points @ w2cs[0, v][:3, :3].t() + w2cs[0, v][:3, 3]
135 | points_v = points_v @ intrinsics[0, v].t()
136 | z_pred = points_v[:, -1].clone()
137 | points_v = points_v[:, :2] / points_v[:, -1:]
138 |
139 | points_unit = points_v.clone()
140 | H, W = depth_map["level_0"].shape[-2:]
141 | points_unit[:, 0] = points_unit[:, 0] / W
142 | points_unit[:, 1] = points_unit[:, 1] / H
143 | grid = 2 * points_unit - 1
144 |
145 | warped_rgbs = F.grid_sample(
146 | unpre_imgs[:, v],
147 | grid.view(1, -1, 1, 2),
148 | align_corners=True,
149 | mode="bilinear",
150 | padding_mode="zeros",
151 | ).squeeze()
152 | photo_mask = (warped_rgbs.t() - rays_gt_rgb).abs().mean(dim=-1) < 0.02
153 |
154 | pixel_coor = points_v.round().long()
155 | k = 5
156 | pixel_coor[:, 0] = pixel_coor[:, 0].clip(k // 2, W - (k // 2) - 1)
157 | pixel_coor[:, 1] = pixel_coor[:, 1].clip(2, H - (k // 2) - 1)
158 | lower_b = pixel_coor - (k // 2)
159 | higher_b = pixel_coor + (k // 2)
160 |
161 | ind_h = (
162 | lower_b[:, 1:] * torch.arange(k - 1, -1, -1).view(1, -1).cuda()
163 | + higher_b[:, 1:] * torch.arange(0, k).view(1, -1).cuda()
164 | ) // (k - 1)
165 | ind_w = (
166 | lower_b[:, 0:1] * torch.arange(k - 1, -1, -1).view(1, -1).cuda()
167 | + higher_b[:, 0:1] * torch.arange(0, k).view(1, -1).cuda()
168 | ) // (k - 1)
169 |
170 | patches_h = torch.gather(
171 | unpre_imgs[:, v].mean(dim=1).expand(ind_h.shape[0], -1, -1),
172 | 1,
173 | ind_h.unsqueeze(-1).expand(-1, -1, W),
174 | )
175 | patches = torch.gather(patches_h, 2, ind_w.unsqueeze(1).expand(-1, k, -1))
176 | ent_mask = patches.view(-1, k * k).std(dim=-1) > 0.05
177 |
178 | for l in range(3):
179 | depth = F.grid_sample(
180 | depth_map[f"level_{l}"][:, v : v + 1],
181 | grid.view(1, -1, 1, 2),
182 | align_corners=True,
183 | mode="bilinear",
184 | padding_mode="zeros",
185 | ).squeeze()
186 | in_mask = (grid > -1.0) * (grid < 1.0)
187 | in_mask = (in_mask[..., 0] * in_mask[..., 1]).float()
188 | loss = loss + loss_fn(
189 | depth, z_pred * in_mask * photo_mask * ent_mask * rgb_mask
190 | ) * 2 ** (1 - l)
191 | loss = loss / (len(w2cs[0]) - 1)
192 |
193 | return loss
194 |
195 |
196 | def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET):
197 | if type(depth) is not np.ndarray:
198 | depth = depth.cpu().numpy()
199 |
200 | x = np.nan_to_num(depth) # change nan to 0
201 | if minmax is None:
202 | mi = np.min(x[x > 0]) # get minimum positive depth (ignore background)
203 | ma = np.max(x)
204 | else:
205 | mi, ma = minmax
206 |
207 | x = (x - mi) / (ma - mi + 1e-8) # normalize to 0~1
208 | x = (255 * x).astype(np.uint8)
209 | x_ = Image.fromarray(cv2.applyColorMap(x, cmap))
210 | x_ = T.ToTensor()(x_) # (3, H, W)
211 | return x_, [mi, ma]
212 |
213 |
214 | def abs_error(depth_pred, depth_gt, mask):
215 | depth_pred, depth_gt = depth_pred[mask], depth_gt[mask]
216 | err = depth_pred - depth_gt
217 | return np.abs(err) if type(depth_pred) is np.ndarray else err.abs()
218 |
219 |
220 | def acc_threshold(depth_pred, depth_gt, mask, threshold):
221 | errors = abs_error(depth_pred, depth_gt, mask)
222 | acc_mask = errors < threshold
223 | return (
224 | acc_mask.astype("float") if type(depth_pred) is np.ndarray else acc_mask.float()
225 | )
226 |
227 |
228 | # Ray helpers
229 | def get_rays(
230 | H,
231 | W,
232 | intrinsics_target,
233 | c2w_target,
234 | chunk=-1,
235 | chunk_id=-1,
236 | train=True,
237 | train_batch_size=-1,
238 | mask=None,
239 | ):
240 | if train:
241 | if mask is None:
242 | xs, ys = (
243 | torch.randint(0, W, (train_batch_size,)).float().cuda(),
244 | torch.randint(0, H, (train_batch_size,)).float().cuda(),
245 | )
246 | else: # Sample 8 times more points to get mask points as much as possible
247 | xs, ys = (
248 | torch.randint(0, W, (8 * train_batch_size,)).float().cuda(),
249 | torch.randint(0, H, (8 * train_batch_size,)).float().cuda(),
250 | )
251 | masked_points = mask[ys.long(), xs.long()]
252 | xs_, ys_ = xs[~masked_points], ys[~masked_points]
253 | xs, ys = xs[masked_points], ys[masked_points]
254 | xs, ys = torch.cat([xs, xs_]), torch.cat([ys, ys_])
255 | xs, ys = xs[:train_batch_size], ys[:train_batch_size]
256 | else:
257 | ys, xs = torch.meshgrid(
258 | torch.linspace(0, H - 1, H), torch.linspace(0, W - 1, W)
259 | ) # pytorch's meshgrid has indexing='ij'
260 | ys, xs = ys.cuda().reshape(-1), xs.cuda().reshape(-1)
261 | if chunk > 0:
262 | ys, xs = (
263 | ys[chunk_id * chunk : (chunk_id + 1) * chunk],
264 | xs[chunk_id * chunk : (chunk_id + 1) * chunk],
265 | )
266 |
267 | dirs = torch.stack(
268 | [
269 | (xs - intrinsics_target[0, 2]) / intrinsics_target[0, 0],
270 | (ys - intrinsics_target[1, 2]) / intrinsics_target[1, 1],
271 | torch.ones_like(xs),
272 | ],
273 | -1,
274 | ) # use 1 instead of -1
275 |
276 | # Translate camera frame's origin to the world frame. It is the origin of all rays.
277 | rays_dir = (
278 | dirs @ c2w_target[:3, :3].t()
279 | ) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
280 | rays_orig = c2w_target[:3, -1].clone().reshape(1, 3).expand(rays_dir.shape[0], -1)
281 |
282 | rays_pixs = torch.stack((ys, xs)) # row col
283 |
284 | return rays_orig, rays_dir, rays_pixs
285 |
286 |
287 | def conver_to_ndc(ray_pts, w2c_ref, intrinsics_ref, W_H, depth_values):
288 | nb_rays, nb_samples = ray_pts.shape[:2]
289 | ray_pts = ray_pts.reshape(-1, 3)
290 |
291 | R = w2c_ref[:3, :3] # (3, 3)
292 | T = w2c_ref[:3, 3:] # (3, 1)
293 | ray_pts = torch.matmul(ray_pts, R.t()) + T.reshape(1, 3)
294 |
295 | ray_pts_ndc = ray_pts @ intrinsics_ref.t()
296 | ray_pts_ndc[:, :2] = ray_pts_ndc[:, :2] / (
297 | ray_pts_ndc[:, -1:] * W_H.reshape(1, 2)
298 | ) # normalize x,y to 0~1
299 |
300 | grid = ray_pts_ndc[None, None, :, :2] * 2 - 1
301 | near = F.grid_sample(
302 | depth_values[:, :1],
303 | grid,
304 | align_corners=True,
305 | mode="bilinear",
306 | padding_mode="border",
307 | ).squeeze()
308 | far = F.grid_sample(
309 | depth_values[:, -1:],
310 | grid,
311 | align_corners=True,
312 | mode="bilinear",
313 | padding_mode="border",
314 | ).squeeze()
315 | ray_pts_ndc[:, 2] = (ray_pts_ndc[:, 2] - near) / (far - near) # normalize z to 0~1
316 |
317 | ray_pts_ndc = ray_pts_ndc.view(nb_rays, nb_samples, 3)
318 |
319 | return ray_pts_ndc
320 |
321 |
322 | def get_sample_points(
323 | nb_coarse,
324 | nb_fine,
325 | near,
326 | far,
327 | rays_o,
328 | rays_d,
329 | nb_views,
330 | w2cs,
331 | intrinsics,
332 | depth_values,
333 | W_H,
334 | with_noise=False,
335 | ):
336 | device = rays_o.device
337 | nb_rays = rays_o.shape[0]
338 |
339 | with torch.no_grad():
340 | t_vals = torch.linspace(0.0, 1.0, steps=nb_coarse).view(1, nb_coarse).to(device)
341 | pts_depth = near * (1.0 - t_vals) + far * (t_vals)
342 | pts_depth = pts_depth.expand([nb_rays, nb_coarse])
343 | ray_pts = rays_o.unsqueeze(1) + pts_depth.unsqueeze(-1) * rays_d.unsqueeze(1)
344 |
345 | ## Counting the number of source views for which the points are valid
346 | valid_points = torch.zeros([nb_rays, nb_coarse]).to(device)
347 | for idx in range(nb_views):
348 | w2c_ref, intrinsic_ref = w2cs[0, idx], intrinsics[0, idx]
349 | ray_pts_ndc = conver_to_ndc(
350 | ray_pts,
351 | w2c_ref,
352 | intrinsic_ref,
353 | W_H,
354 | depth_values=depth_values[f"level_0"][:, idx],
355 | )
356 | valid_points += (
357 | ((ray_pts_ndc >= 0) & (ray_pts_ndc <= 1)).sum(dim=-1) == 3
358 | ).float()
359 |
360 | ## Creating a distribution based on the counted values and sample more points
361 | if nb_fine > 0:
362 | point_distr = torch.distributions.categorical.Categorical(
363 | logits=valid_points
364 | )
365 | t_vals = (
366 | point_distr.sample([nb_fine]).t()
367 | - torch.rand([nb_rays, nb_fine]).cuda()
368 | ) / (nb_coarse - 1)
369 | pts_depth_fine = near * (1.0 - t_vals) + far * (t_vals)
370 |
371 | pts_depth = torch.cat([pts_depth, pts_depth_fine], dim=-1)
372 | pts_depth, _ = torch.sort(pts_depth)
373 |
374 | if with_noise: ## Add noise to sample points during training
375 | # get intervals between samples
376 | mids = 0.5 * (pts_depth[..., 1:] + pts_depth[..., :-1])
377 | upper = torch.cat([mids, pts_depth[..., -1:]], -1)
378 | lower = torch.cat([pts_depth[..., :1], mids], -1)
379 | # stratified samples in those intervals
380 | t_rand = torch.rand(pts_depth.shape, device=device)
381 | pts_depth = lower + (upper - lower) * t_rand
382 |
383 | ray_pts = rays_o.unsqueeze(1) + pts_depth.unsqueeze(-1) * rays_d.unsqueeze(1)
384 |
385 | ray_pts_ndc = {"level_0": [], "level_1": [], "level_2": []}
386 | for idx in range(nb_views):
387 | w2c_ref, intrinsic_ref = w2cs[0, idx], intrinsics[0, idx]
388 | for l in range(3):
389 | ray_pts_ndc[f"level_{l}"].append(
390 | conver_to_ndc(
391 | ray_pts,
392 | w2c_ref,
393 | intrinsic_ref,
394 | W_H,
395 | depth_values=depth_values[f"level_{l}"][:, idx],
396 | )
397 | )
398 | for l in range(3):
399 | ray_pts_ndc[f"level_{l}"] = torch.stack(ray_pts_ndc[f"level_{l}"], dim=2)
400 |
401 | return pts_depth, ray_pts, ray_pts_ndc
402 |
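# Rough sketch of the fine-sampling step above (illustrative only): each coarse bin gets a
# logit equal to the number of source views whose frustum contains it, so a Categorical over
# those logits concentrates the extra `nb_fine` samples on geometrically valid depths; the
# uniform jitter spreads each sample inside its chosen bin before mapping back to metric depth.
#
#   bins = torch.distributions.Categorical(logits=valid_points).sample([nb_fine]).t()
#   t = (bins - torch.rand(nb_rays, nb_fine, device=device)) / (nb_coarse - 1)
#   depth_fine = near * (1.0 - t) + far * t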
403 |
404 | def get_rays_pts(
405 | H,
406 | W,
407 | c2ws,
408 | w2cs,
409 | intrinsics,
410 | near_fars,
411 | depth_values,
412 | nb_coarse,
413 | nb_fine,
414 | nb_views,
415 | chunk=-1,
416 | chunk_idx=-1,
417 | train=False,
418 | train_batch_size=-1,
419 | target_img=None,
420 | target_depth=None,
421 | ):
422 | if train:
423 | if target_depth.sum() > 0:
424 | depth_mask = target_depth > 0
425 | else:
426 | depth_mask = None
427 | else:
428 | depth_mask = None
429 |
430 | rays_orig, rays_dir, rays_pixs = get_rays(
431 | H,
432 | W,
433 | intrinsics[0, -1],
434 | c2ws[0, -1],
435 | chunk=chunk,
436 | chunk_id=chunk_idx,
437 | train=train,
438 | train_batch_size=train_batch_size,
439 | mask=depth_mask,
440 | )
441 |
442 |     ## Extracting the ground-truth color and depth of the target view
443 | if train:
444 | rays_pixs_int = rays_pixs.long()
445 | rays_gt_rgb = target_img[:, rays_pixs_int[0], rays_pixs_int[1]].permute(1, 0)
446 | rays_gt_depth = target_depth[rays_pixs_int[0], rays_pixs_int[1]]
447 | else:
448 | rays_gt_rgb = None
449 | rays_gt_depth = None
450 |
451 | # travel along the rays
452 | near, far = near_fars[0, -1, 0], near_fars[0, -1, 1] ## near/far of the target view
453 |     W_H = torch.tensor([W - 1, H - 1], device=rays_orig.device)
454 | pts_depth, ray_pts, ray_pts_ndc = get_sample_points(
455 | nb_coarse,
456 | nb_fine,
457 | near,
458 | far,
459 | rays_orig,
460 | rays_dir,
461 | nb_views,
462 | w2cs,
463 | intrinsics,
464 | depth_values,
465 | W_H,
466 | with_noise=train,
467 | )
468 |
469 | return (
470 | pts_depth,
471 | ray_pts,
472 | ray_pts_ndc,
473 | rays_dir,
474 | rays_gt_rgb,
475 | rays_gt_depth,
476 | rays_pixs,
477 | )
478 |
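# Hedged usage sketch (hypothetical driver loop, not from this file): at inference the target
# view's H*W rays are rendered in chunks via `chunk`/`chunk_idx`, while training instead draws
# `train_batch_size` random rays and additionally returns their ground-truth color/depth.
#
#   for chunk_idx in range(math.ceil(H * W / chunk)):
#       pts_depth, ray_pts, ray_pts_ndc, rays_dir, _, _, rays_pixs = get_rays_pts(
#           H, W, c2ws, w2cs, intrinsics, near_fars, depth_values,
#           nb_coarse, nb_fine, nb_views, chunk=chunk, chunk_idx=chunk_idx,
#       )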
479 |
480 | def normal_vect(vect, dim=-1):
481 | return vect / (torch.sqrt(torch.sum(vect**2, dim=dim, keepdim=True)) + 1e-7)
482 |
483 |
484 | def interpolate_3D(feats, pts_ndc):
485 | H, W = pts_ndc.shape[-3:-1]
486 | grid = pts_ndc.view(-1, 1, H, W, 3) * 2 - 1.0 # [1 1 H W 3] (x,y,z)
487 | features = (
488 | F.grid_sample(
489 | feats, grid, align_corners=True, mode="bilinear", padding_mode="border"
490 | )[:, :, 0]
491 | .permute(2, 3, 0, 1)
492 | .squeeze()
493 | )
494 |
495 | return features
496 |
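# Shape note (inferred from the call above): pts_ndc is treated as (B, N_rays, N_samples, 3)
# and feats as a matching batch of volumes (B, C, D, H_f, W_f); the 5D grid_sample
# trilinearly interpolates each volume at the (x, y, z) samples, and the permute/squeeze
# yields per-point features of roughly (N_rays, N_samples, B, C) with singleton dims dropped.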
497 |
498 | def interpolate_2D(feats, imgs, pts_ndc):
499 | H, W = pts_ndc.shape[-3:-1]
500 | grid = pts_ndc[..., :2].view(-1, H, W, 2) * 2 - 1.0 # [1 H W 2] (x,y)
501 | features = (
502 | F.grid_sample(
503 | feats, grid, align_corners=True, mode="bilinear", padding_mode="border"
504 | )
505 | .permute(2, 3, 1, 0)
506 | .squeeze()
507 | )
508 | images = (
509 | F.grid_sample(
510 | imgs, grid, align_corners=True, mode="bilinear", padding_mode="border"
511 | )
512 | .permute(2, 3, 1, 0)
513 | .squeeze()
514 | )
515 | with torch.no_grad():
516 | in_mask = (grid > -1.0) * (grid < 1.0)
517 | in_mask = (in_mask[..., 0] * in_mask[..., 1]).float().permute(1, 2, 0)
518 |
519 | return features, images, in_mask
520 |
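# Note: interpolate_2D bilinearly samples per-view feature maps and source images at the
# projected (x, y) locations only; `in_mask` flags which projections land inside the source
# image bounds, so out-of-view samples can be discounted downstream.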
521 |
522 | def read_pfm(filename):
523 | file = open(filename, "rb")
524 | color = None
525 | width = None
526 | height = None
527 | scale = None
528 | endian = None
529 |
530 | header = file.readline().decode("utf-8").rstrip()
531 | if header == "PF":
532 | color = True
533 | elif header == "Pf":
534 | color = False
535 | else:
536 | raise Exception("Not a PFM file.")
537 |
538 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("utf-8"))
539 | if dim_match:
540 | width, height = map(int, dim_match.groups())
541 | else:
542 | raise Exception("Malformed PFM header.")
543 |
544 | scale = float(file.readline().rstrip())
545 | if scale < 0: # little-endian
546 | endian = "<"
547 | scale = -scale
548 | else:
549 | endian = ">" # big-endian
550 |
551 | data = np.fromfile(file, endian + "f")
552 | shape = (height, width, 3) if color else (height, width)
553 |
554 | data = np.reshape(data, shape)
555 | data = np.flipud(data)
556 | file.close()
557 | return data, scale
558 |
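# Hedged example (hypothetical path, shown for illustration only): PFM stores rows
# bottom-to-top, hence the np.flipud above; a grayscale "Pf" file yields an (H, W)
# float32 array together with its scale factor.
#
#   depth, scale = read_pfm("Depths/scan1/depth_map_0000.pfm")
#   depth = np.ascontiguousarray(depth)  # flipud returns a reversed view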
559 |
560 | def homo_warp(src_feat, proj_mat, depth_values, src_grid=None, pad=0):
561 |     if src_grid is None:
562 | B, C, H, W = src_feat.shape
563 | device = src_feat.device
564 |
565 | if pad > 0:
566 | H_pad, W_pad = H + pad * 2, W + pad * 2
567 | else:
568 | H_pad, W_pad = H, W
569 |
570 | if depth_values.dim() != 4:
571 | depth_values = depth_values[..., None, None].repeat(1, 1, H_pad, W_pad)
572 | D = depth_values.shape[1]
573 |
574 | R = proj_mat[:, :, :3] # (B, 3, 3)
575 | T = proj_mat[:, :, 3:] # (B, 3, 1)
576 | # create grid from the ref frame
577 | ref_grid = create_meshgrid(
578 | H_pad, W_pad, normalized_coordinates=False, device=device
579 | ) # (1, H, W, 2)
580 | if pad > 0:
581 | ref_grid -= pad
582 |
583 | ref_grid = ref_grid.permute(0, 3, 1, 2) # (1, 2, H, W)
584 | ref_grid = ref_grid.reshape(1, 2, W_pad * H_pad) # (1, 2, H*W)
585 | ref_grid = ref_grid.expand(B, -1, -1) # (B, 2, H*W)
586 | ref_grid = torch.cat(
587 | (ref_grid, torch.ones_like(ref_grid[:, :1])), 1
588 | ) # (B, 3, H*W)
589 | ref_grid_d = ref_grid.repeat(1, 1, D) # (B, 3, D*H*W)
590 | src_grid_d = R @ ref_grid_d + T / depth_values.reshape(B, 1, D * W_pad * H_pad)
591 | del ref_grid_d, ref_grid, proj_mat, R, T, depth_values # release (GPU) memory
592 |
593 | src_grid = (
594 | src_grid_d[:, :2] / src_grid_d[:, 2:]
595 | ) # divide by depth (B, 2, D*H*W)
596 | del src_grid_d
597 | src_grid[:, 0] = src_grid[:, 0] / ((W - 1) / 2) - 1 # scale to -1~1
598 | src_grid[:, 1] = src_grid[:, 1] / ((H - 1) / 2) - 1 # scale to -1~1
599 | src_grid = src_grid.permute(0, 2, 1) # (B, D*H*W, 2)
600 | src_grid = src_grid.view(B, D, W_pad, H_pad, 2)
601 |
602 | B, D, W_pad, H_pad = src_grid.shape[:4]
603 | warped_src_feat = F.grid_sample(
604 | src_feat,
605 | src_grid.view(B, D, W_pad * H_pad, 2),
606 | mode="bilinear",
607 | padding_mode="zeros",
608 | align_corners=True,
609 | ) # (B, C, D, H*W)
610 | warped_src_feat = warped_src_feat.view(B, -1, D, H_pad, W_pad)
611 | # src_grid = src_grid.view(B, 1, D, H_pad, W_pad, 2)
612 | return warped_src_feat, src_grid
613 |
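# Hedged sketch of how `proj_mat` is typically assembled for homo_warp (assuming 4x4
# homogeneous intrinsics K and world-to-camera extrinsics E): the 3x4 matrix composes the
# source projection with the inverse reference projection, so its left 3x3 block and last
# column play the roles of R and T in the plane-sweep warp above.
#
#   ref_proj = K_ref @ E_ref                                  # (4, 4)
#   src_proj = K_src @ E_src                                  # (4, 4)
#   proj_mat = (src_proj @ torch.inverse(ref_proj))[:3]       # (3, 4)
#   warped, grid = homo_warp(src_feat, proj_mat[None], depth_hypotheses)
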
614 | ##### Functions for view selection
615 | TINY_NUMBER = 1e-5 # float32 only has 7 decimal digits precision
616 |
617 | def angular_dist_between_2_vectors(vec1, vec2):
618 | vec1_unit = vec1 / (np.linalg.norm(vec1, axis=1, keepdims=True) + TINY_NUMBER)
619 | vec2_unit = vec2 / (np.linalg.norm(vec2, axis=1, keepdims=True) + TINY_NUMBER)
620 | angular_dists = np.arccos(
621 | np.clip(np.sum(vec1_unit * vec2_unit, axis=-1), -1.0, 1.0)
622 | )
623 | return angular_dists
624 |
625 |
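# Geodesic distance between rotation matrices: theta = arccos((trace(R2^T R1) - 1) / 2),
# clipped away from +/-1 for numerical stability of arccos.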
626 | def batched_angular_dist_rot_matrix(R1, R2):
627 | assert (
628 | R1.shape[-1] == 3
629 | and R2.shape[-1] == 3
630 | and R1.shape[-2] == 3
631 | and R2.shape[-2] == 3
632 | )
633 | return np.arccos(
634 | np.clip(
635 | (np.trace(np.matmul(R2.transpose(0, 2, 1), R1), axis1=1, axis2=2) - 1)
636 | / 2.0,
637 | a_min=-1 + TINY_NUMBER,
638 | a_max=1 - TINY_NUMBER,
639 | )
640 | )
641 |
642 |
643 | def get_nearest_pose_ids(
644 | tar_pose,
645 | ref_poses,
646 | num_select,
647 | tar_id=-1,
648 | angular_dist_method="dist",
649 | scene_center=(0, 0, 0),
650 | ):
651 | num_cams = len(ref_poses)
652 | num_select = min(num_select, num_cams - 1)
653 | batched_tar_pose = tar_pose[None, ...].repeat(num_cams, 0)
654 |
655 | if angular_dist_method == "matrix":
656 | dists = batched_angular_dist_rot_matrix(
657 | batched_tar_pose[:, :3, :3], ref_poses[:, :3, :3]
658 | )
659 | elif angular_dist_method == "vector":
660 | tar_cam_locs = batched_tar_pose[:, :3, 3]
661 | ref_cam_locs = ref_poses[:, :3, 3]
662 | scene_center = np.array(scene_center)[None, ...]
663 | tar_vectors = tar_cam_locs - scene_center
664 | ref_vectors = ref_cam_locs - scene_center
665 | dists = angular_dist_between_2_vectors(tar_vectors, ref_vectors)
666 | elif angular_dist_method == "dist":
667 | tar_cam_locs = batched_tar_pose[:, :3, 3]
668 | ref_cam_locs = ref_poses[:, :3, 3]
669 | dists = np.linalg.norm(tar_cam_locs - ref_cam_locs, axis=1)
670 | else:
671 | raise Exception("unknown angular distance calculation method!")
672 |
673 | if tar_id >= 0:
674 | assert tar_id < num_cams
675 | dists[tar_id] = 1e3 # make sure not to select the target id itself
676 |
677 | sorted_ids = np.argsort(dists)
678 | selected_ids = sorted_ids[:num_select]
679 |
680 | return selected_ids
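# Hedged usage sketch (hypothetical `poses` array and indices): select the `num_select`
# source cameras closest to a target pose by camera-center distance, excluding the target
# itself when it is part of `ref_poses`.
#
#   nearest_ids = get_nearest_pose_ids(poses[target_idx], poses, num_select=nb_views,
#                                      tar_id=target_idx, angular_dist_method="dist")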
--------------------------------------------------------------------------------