├── .gitignore
├── LICENSE.txt
├── README.md
├── __init__.py
├── config.yaml
├── logo-dark.svg
├── logo.svg
├── requirements.txt
├── tsr
│   ├── models
│   │   ├── isosurface.py
│   │   ├── nerf_renderer.py
│   │   ├── network_utils.py
│   │   ├── tokenizers
│   │   │   ├── image.py
│   │   │   └── triplane.py
│   │   └── transformer
│   │       ├── attention.py
│   │       ├── basic_transformer_block.py
│   │       └── transformer_1d.py
│   ├── system.py
│   └── utils.py
├── web
│   ├── html
│   │   └── threeVisualizer.html
│   ├── js
│   │   └── threeVisualizer.js
│   ├── style
│   │   ├── progressStyle.css
│   │   └── threeStyle.css
│   └── visualization.js
├── workflow-sample.png
├── workflow_rembg.json
└── workflow_simple.json
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .DS_Store
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 |
635 | Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ComfyUI-Flowty-TripoSR
2 |
3 | This is a custom node that lets you use TripoSR right from ComfyUI.
4 |
5 | [TripoSR](https://github.com/VAST-AI-Research/TripoSR) is a state-of-the-art open-source model for fast feedforward 3D reconstruction from a single image, collaboratively developed by Tripo AI and Stability AI. (TL;DR: it creates a 3D model from an image.)
6 |
7 | 
8 |
9 | I've created this node for experimentation; feel free to submit PRs for performance improvements, etc.
10 |
11 | ### Installation:
12 | * Install ComfyUI
13 | * Clone this repo into ```custom_nodes```:
14 | ```shell
15 | $ cd ComfyUI/custom_nodes
16 | $ git clone https://github.com/flowtyone/ComfyUI-Flowty-TripoSR.git
17 | ```
18 | * Install dependencies:
19 | ```shell
20 | $ cd ComfyUI-Flowty-TripoSR
21 | $ pip install -r requirements.txt
22 | ```
23 | * [Download TripoSR](https://huggingface.co/stabilityai/TripoSR/blob/main/model.ckpt) and place it in ```ComfyUI/models/checkpoints```
24 | * Start ComfyUI (or restart)
25 |
26 | Special thanks to MrForExample for creating [ComfyUI-3D-Pack](https://github.com/MrForExample/ComfyUI-3D-Pack). Code from that node pack was used to display 3D models in ComfyUI.
27 |
28 | This is a community project from [flowt.ai](https://flowt.ai). If you like it, check us out!
29 |
30 |
31 |
32 |
33 |
34 |
--------------------------------------------------------------------------------
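For the checkpoint download step in the README above, a minimal Python sketch using `huggingface_hub` (already listed in requirements.txt). The repo id and filename come from the README link; the `local_dir` argument assumes a reasonably recent `huggingface_hub` release and should point at your own ComfyUI checkpoints folder:

```python
# Sketch only: fetch the TripoSR checkpoint with huggingface_hub.
# Adjust local_dir to wherever your ComfyUI/models/checkpoints directory lives.
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="stabilityai/TripoSR",
    filename="model.ckpt",
    local_dir="ComfyUI/models/checkpoints",
)
print(ckpt_path)
```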
/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from os import path
3 |
4 | sys.path.insert(0, path.dirname(__file__))
5 | from folder_paths import get_filename_list, get_full_path, get_save_image_path, get_output_directory
6 | from comfy.model_management import get_torch_device
7 | from tsr.system import TSR
8 | from PIL import Image
9 | import numpy as np
10 | import torch
11 |
12 |
13 | def fill_background(image):
14 | image = np.array(image).astype(np.float32) / 255.0
15 |     image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5  # composite RGB over 50% gray using the alpha channel
16 | image = Image.fromarray((image * 255.0).astype(np.uint8))
17 | return image
18 |
19 |
20 | class TripoSRModelLoader:
21 | def __init__(self):
22 | self.initialized_model = None
23 |
24 | @classmethod
25 | def INPUT_TYPES(s):
26 | return {
27 | "required": {
28 | "model": (get_filename_list("checkpoints"),),
29 | "chunk_size": ("INT", {"default": 8192, "min": 1, "max": 10000})
30 | }
31 | }
32 |
33 | RETURN_TYPES = ("TRIPOSR_MODEL",)
34 | FUNCTION = "load"
35 | CATEGORY = "Flowty TripoSR"
36 |
37 | def load(self, model, chunk_size):
38 | device = get_torch_device()
39 |
40 | if not torch.cuda.is_available():
41 | device = "cpu"
42 |
43 | if not self.initialized_model:
44 | print("Loading TripoSR model")
45 | self.initialized_model = TSR.from_pretrained_custom(
46 | weight_path=get_full_path("checkpoints", model),
47 | config_path=path.join(path.dirname(__file__), "config.yaml")
48 | )
49 | self.initialized_model.renderer.set_chunk_size(chunk_size)
50 | self.initialized_model.to(device)
51 |
52 | return (self.initialized_model,)
53 |
54 |
55 | class TripoSRSampler:
56 |
57 | def __init__(self):
58 | self.initialized_model = None
59 |
60 | @classmethod
61 | def INPUT_TYPES(s):
62 | return {
63 | "required": {
64 | "model": ("TRIPOSR_MODEL",),
65 | "reference_image": ("IMAGE",),
66 | "geometry_resolution": ("INT", {"default": 256, "min": 128, "max": 12288}),
67 | "threshold": ("FLOAT", {"default": 25.0, "min": 0.0, "step": 0.01}),
68 | },
69 | "optional": {
70 | "reference_mask": ("MASK",)
71 | }
72 | }
73 |
74 | RETURN_TYPES = ("MESH",)
75 | FUNCTION = "sample"
76 | CATEGORY = "Flowty TripoSR"
77 |
78 | def sample(self, model, reference_image, geometry_resolution, threshold, reference_mask=None):
79 | device = get_torch_device()
80 |
81 | if not torch.cuda.is_available():
82 | device = "cpu"
83 |
84 | image = reference_image[0]
85 |
86 | if reference_mask is not None:
87 | mask = reference_mask[0].unsqueeze(2)
88 | image = torch.cat((image, mask), dim=2).detach().cpu().numpy()
89 | else:
90 | image = image.detach().cpu().numpy()
91 |
92 | image = Image.fromarray(np.clip(255. * image, 0, 255).astype(np.uint8))
93 | if reference_mask is not None:
94 | image = fill_background(image)
95 | image = image.convert('RGB')
96 | scene_codes = model([image], device)
97 | meshes = model.extract_mesh(scene_codes, resolution=geometry_resolution, threshold=threshold)
98 | return ([meshes[0]],)
99 |
100 |
101 | class TripoSRViewer:
102 | @classmethod
103 | def INPUT_TYPES(s):
104 | return {
105 | "required": {
106 | "mesh": ("MESH",)
107 | }
108 | }
109 |
110 | RETURN_TYPES = ()
111 | OUTPUT_NODE = True
112 | FUNCTION = "display"
113 | CATEGORY = "Flowty TripoSR"
114 |
115 | def display(self, mesh):
116 | saved = list()
117 | full_output_folder, filename, counter, subfolder, filename_prefix = get_save_image_path("meshsave",
118 | get_output_directory())
119 |
120 | for (batch_number, single_mesh) in enumerate(mesh):
121 | filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
122 | file = f"{filename_with_batch_num}_{counter:05}_.obj"
123 |             single_mesh.apply_transform(np.array([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]]))  # rotate -90 degrees about X (swap the up axis) for the web viewer
124 | single_mesh.export(path.join(full_output_folder, file))
125 | saved.append({
126 | "filename": file,
127 | "type": "output",
128 | "subfolder": subfolder
129 | })
130 |
131 | return {"ui": {"mesh": saved}}
132 |
133 |
134 | NODE_CLASS_MAPPINGS = {
135 | "TripoSRModelLoader": TripoSRModelLoader,
136 | "TripoSRSampler": TripoSRSampler,
137 | "TripoSRViewer": TripoSRViewer
138 | }
139 |
140 | NODE_DISPLAY_NAME_MAPPINGS = {
141 | "TripoSRModelLoader": "TripoSR Model Loader",
142 | "TripoSRSampler": "TripoSR Sampler",
143 | "TripoSRViewer": "TripoSR Viewer"
144 | }
145 |
146 | WEB_DIRECTORY = "./web"
147 |
148 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS', 'WEB_DIRECTORY']
149 |
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
1 | cond_image_size: 512
2 |
3 | image_tokenizer_cls: tsr.models.tokenizers.image.DINOSingleImageTokenizer
4 | image_tokenizer:
5 | pretrained_model_name_or_path: "facebook/dino-vitb16"
6 |
7 | tokenizer_cls: tsr.models.tokenizers.triplane.Triplane1DTokenizer
8 | tokenizer:
9 | plane_size: 32
10 | num_channels: 1024
11 |
12 | backbone_cls: tsr.models.transformer.transformer_1d.Transformer1D
13 | backbone:
14 | in_channels: ${tokenizer.num_channels}
15 | num_attention_heads: 16
16 | attention_head_dim: 64
17 | num_layers: 16
18 | cross_attention_dim: 768
19 |
20 | post_processor_cls: tsr.models.network_utils.TriplaneUpsampleNetwork
21 | post_processor:
22 | in_channels: 1024
23 | out_channels: 40
24 |
25 | decoder_cls: tsr.models.network_utils.NeRFMLP
26 | decoder:
27 | in_channels: 120 # 3 * 40
28 | n_neurons: 64
29 | n_hidden_layers: 9
30 | activation: silu
31 |
32 | renderer_cls: tsr.models.nerf_renderer.TriplaneNeRFRenderer
33 | renderer:
34 | radius: 0.87 # slightly larger than 0.5 * sqrt(3)
35 | feature_reduction: concat
36 | density_activation: exp
37 | density_bias: -1.0
38 | num_samples_per_ray: 128
--------------------------------------------------------------------------------
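The `${tokenizer.num_channels}` reference in the backbone block above is an OmegaConf interpolation (omegaconf is pinned in requirements.txt), so the backbone input width always tracks the triplane tokenizer's channel count. A minimal sketch of how it resolves, assuming the file is loaded from the working directory:

```python
# Sketch only: load config.yaml and resolve the ${tokenizer.num_channels} interpolation.
from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")
resolved = OmegaConf.to_container(cfg, resolve=True)
print(resolved["backbone"]["in_channels"])  # 1024, same as tokenizer.num_channels
```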
/logo-dark.svg:
--------------------------------------------------------------------------------
1 |
26 |
--------------------------------------------------------------------------------
/logo.svg:
--------------------------------------------------------------------------------
1 |
26 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | omegaconf==2.3.0
2 | Pillow==10.1.0
3 | einops==0.7.0
4 | transformers==4.35.0
5 | trimesh==4.0.5
6 | huggingface-hub
7 | imageio[ffmpeg]
8 | scikit-image
--------------------------------------------------------------------------------
/tsr/models/isosurface.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Optional, Tuple
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from skimage import measure
7 |
8 |
9 | class IsosurfaceHelper(nn.Module):
10 | points_range: Tuple[float, float] = (0, 1)
11 |
12 | @property
13 | def grid_vertices(self) -> torch.FloatTensor:
14 | raise NotImplementedError
15 |
16 |
17 | class MarchingCubeHelper(IsosurfaceHelper):
18 | def __init__(self, resolution: int) -> None:
19 | super().__init__()
20 | self.resolution = resolution
21 | #self.mc_func: Callable = marching_cubes
22 | self._grid_vertices: Optional[torch.FloatTensor] = None
23 |
24 | @property
25 | def grid_vertices(self) -> torch.FloatTensor:
26 | if self._grid_vertices is None:
27 | # keep the vertices on CPU so that we can support very large resolution
28 | x, y, z = (
29 | torch.linspace(*self.points_range, self.resolution),
30 | torch.linspace(*self.points_range, self.resolution),
31 | torch.linspace(*self.points_range, self.resolution),
32 | )
33 | x, y, z = torch.meshgrid(x, y, z, indexing="ij")
34 | verts = torch.cat(
35 | [x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1
36 | ).reshape(-1, 3)
37 | self._grid_vertices = verts
38 | return self._grid_vertices
39 |
40 | def forward(
41 | self,
42 | level: torch.FloatTensor,
43 | ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
44 | level = -level.view(self.resolution, self.resolution, self.resolution)
45 | v_pos, t_pos_idx, _, __ = measure.marching_cubes((level.detach().cpu() if level.is_cuda else level.detach()).numpy(), 0.0) #self.mc_func(level.detach(), 0.0)
46 | v_pos = torch.from_numpy(v_pos.copy()).type(torch.FloatTensor).to(level.device)
47 | t_pos_idx = torch.from_numpy(t_pos_idx.copy()).type(torch.LongTensor).to(level.device)
48 | v_pos = v_pos[..., [0, 1, 2]]
49 | t_pos_idx = t_pos_idx[..., [1, 0, 2]]
50 | v_pos = v_pos / (self.resolution - 1.0)
51 | return v_pos, t_pos_idx
52 |
--------------------------------------------------------------------------------
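For orientation, a hypothetical standalone use of `MarchingCubeHelper` on a synthetic sphere field; in the node itself the field comes from the NeRF density, but any flattened grid of resolution-cubed signed values that crosses zero at the surface works. The import path assumes the repo layout shown above:

```python
# Sketch only: run marching cubes on a sphere of radius 0.4 centred in the unit cube.
import torch
from tsr.models.isosurface import MarchingCubeHelper

helper = MarchingCubeHelper(resolution=64)
pts = helper.grid_vertices                   # (64**3, 3), coordinates in [0, 1]
level = 0.4 - (pts - 0.5).norm(dim=-1)       # positive inside the sphere, negative outside
v_pos, t_pos_idx = helper(level)             # vertices in [0, 1], triangle indices
print(v_pos.shape, t_pos_idx.shape)
```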
/tsr/models/nerf_renderer.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Dict
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from einops import rearrange, reduce
7 |
8 | from ..utils import (
9 | BaseModule,
10 | chunk_batch,
11 | get_activation,
12 | rays_intersect_bbox,
13 | scale_tensor,
14 | )
15 |
16 |
17 | class TriplaneNeRFRenderer(BaseModule):
18 | @dataclass
19 | class Config(BaseModule.Config):
20 | radius: float
21 |
22 | feature_reduction: str = "concat"
23 | density_activation: str = "trunc_exp"
24 | density_bias: float = -1.0
25 | color_activation: str = "sigmoid"
26 | num_samples_per_ray: int = 128
27 | randomized: bool = False
28 |
29 | cfg: Config
30 |
31 | def configure(self) -> None:
32 | assert self.cfg.feature_reduction in ["concat", "mean"]
33 | self.chunk_size = 0
34 |
35 | def set_chunk_size(self, chunk_size: int):
36 | assert (
37 | chunk_size >= 0
38 | ), "chunk_size must be a non-negative integer (0 for no chunking)."
39 | self.chunk_size = chunk_size
40 |
41 | def query_triplane(
42 | self,
43 | decoder: torch.nn.Module,
44 | positions: torch.Tensor,
45 | triplane: torch.Tensor,
46 | ) -> Dict[str, torch.Tensor]:
47 | input_shape = positions.shape[:-1]
48 | positions = positions.view(-1, 3)
49 |
50 | # positions in (-radius, radius)
51 | # normalized to (-1, 1) for grid sample
52 | positions = scale_tensor(
53 | positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
54 | )
55 |
56 | def _query_chunk(x):
57 | indices2D: torch.Tensor = torch.stack(
58 | (x[..., [0, 1]], x[..., [0, 2]], x[..., [1, 2]]),
59 | dim=-3,
60 | )
61 | out: torch.Tensor = F.grid_sample(
62 |                 rearrange(triplane, "Np Cp Hp Wp -> Np Cp Hp Wp", Np=3),  # identity rearrange; only asserts the plane axis is 3
63 | rearrange(indices2D, "Np N Nd -> Np () N Nd", Np=3),
64 | align_corners=False,
65 | mode="bilinear",
66 | )
67 | if self.cfg.feature_reduction == "concat":
68 | out = rearrange(out, "Np Cp () N -> N (Np Cp)", Np=3)
69 | elif self.cfg.feature_reduction == "mean":
70 | out = reduce(out, "Np Cp () N -> N Cp", Np=3, reduction="mean")
71 | else:
72 | raise NotImplementedError
73 |
74 | net_out: Dict[str, torch.Tensor] = decoder(out)
75 | return net_out
76 |
77 | if self.chunk_size > 0:
78 | net_out = chunk_batch(_query_chunk, self.chunk_size, positions)
79 | else:
80 | net_out = _query_chunk(positions)
81 |
82 | net_out["density_act"] = get_activation(self.cfg.density_activation)(
83 | net_out["density"] + self.cfg.density_bias
84 | )
85 | net_out["color"] = get_activation(self.cfg.color_activation)(
86 | net_out["features"]
87 | )
88 |
89 | net_out = {k: v.view(*input_shape, -1) for k, v in net_out.items()}
90 |
91 | return net_out
92 |
93 | def _forward(
94 | self,
95 | decoder: torch.nn.Module,
96 | triplane: torch.Tensor,
97 | rays_o: torch.Tensor,
98 | rays_d: torch.Tensor,
99 | **kwargs,
100 | ):
101 | rays_shape = rays_o.shape[:-1]
102 | rays_o = rays_o.view(-1, 3)
103 | rays_d = rays_d.view(-1, 3)
104 | n_rays = rays_o.shape[0]
105 |
106 | t_near, t_far, rays_valid = rays_intersect_bbox(rays_o, rays_d, self.cfg.radius)
107 | t_near, t_far = t_near[rays_valid], t_far[rays_valid]
108 |
109 | t_vals = torch.linspace(
110 | 0, 1, self.cfg.num_samples_per_ray + 1, device=triplane.device
111 | )
112 | t_mid = (t_vals[:-1] + t_vals[1:]) / 2.0
113 | z_vals = t_near * (1 - t_mid[None]) + t_far * t_mid[None] # (N_rays, N_samples)
114 |
115 | xyz = (
116 | rays_o[:, None, :] + z_vals[..., None] * rays_d[..., None, :]
117 | ) # (N_rays, N_sample, 3)
118 |
119 | mlp_out = self.query_triplane(
120 | decoder=decoder,
121 | positions=xyz,
122 | triplane=triplane,
123 | )
124 |
125 | eps = 1e-10
126 | # deltas = z_vals[:, 1:] - z_vals[:, :-1] # (N_rays, N_samples)
127 |         deltas = t_vals[1:] - t_vals[:-1]  # (N_samples,)
128 | alpha = 1 - torch.exp(
129 | -deltas * mlp_out["density_act"][..., 0]
130 | ) # (N_rays, N_samples)
131 | accum_prod = torch.cat(
132 | [
133 | torch.ones_like(alpha[:, :1]),
134 | torch.cumprod(1 - alpha[:, :-1] + eps, dim=-1),
135 | ],
136 | dim=-1,
137 | )
138 | weights = alpha * accum_prod # (N_rays, N_samples)
139 | comp_rgb_ = (weights[..., None] * mlp_out["color"]).sum(dim=-2) # (N_rays, 3)
140 | opacity_ = weights.sum(dim=-1) # (N_rays)
141 |
142 | comp_rgb = torch.zeros(
143 | n_rays, 3, dtype=comp_rgb_.dtype, device=comp_rgb_.device
144 | )
145 | opacity = torch.zeros(n_rays, dtype=opacity_.dtype, device=opacity_.device)
146 | comp_rgb[rays_valid] = comp_rgb_
147 | opacity[rays_valid] = opacity_
148 |
149 | comp_rgb += 1 - opacity[..., None]
150 | comp_rgb = comp_rgb.view(*rays_shape, 3)
151 |
152 | return comp_rgb
153 |
154 | def forward(
155 | self,
156 | decoder: torch.nn.Module,
157 | triplane: torch.Tensor,
158 | rays_o: torch.Tensor,
159 | rays_d: torch.Tensor,
160 | ) -> Dict[str, torch.Tensor]:
161 | if triplane.ndim == 4:
162 | comp_rgb = self._forward(decoder, triplane, rays_o, rays_d)
163 | else:
164 | comp_rgb = torch.stack(
165 | [
166 | self._forward(decoder, triplane[i], rays_o[i], rays_d[i])
167 | for i in range(triplane.shape[0])
168 | ],
169 | dim=0,
170 | )
171 |
172 | return comp_rgb
173 |
174 | def train(self, mode=True):
175 | self.randomized = mode and self.cfg.randomized
176 | return super().train(mode=mode)
177 |
178 | def eval(self):
179 | self.randomized = False
180 | return super().eval()
181 |
--------------------------------------------------------------------------------
/tsr/models/network_utils.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional
3 |
4 | import torch
5 | import torch.nn as nn
6 | from einops import rearrange
7 |
8 | from ..utils import BaseModule
9 |
10 |
11 | class TriplaneUpsampleNetwork(BaseModule):
12 | @dataclass
13 | class Config(BaseModule.Config):
14 | in_channels: int
15 | out_channels: int
16 |
17 | cfg: Config
18 |
19 | def configure(self) -> None:
20 | self.upsample = nn.ConvTranspose2d(
21 | self.cfg.in_channels, self.cfg.out_channels, kernel_size=2, stride=2
22 | )
23 |
24 | def forward(self, triplanes: torch.Tensor) -> torch.Tensor:
25 | triplanes_up = rearrange(
26 | self.upsample(
27 | rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
28 | ),
29 | "(B Np) Co Hp Wp -> B Np Co Hp Wp",
30 | Np=3,
31 | )
32 | return triplanes_up
33 |
34 |
35 | class NeRFMLP(BaseModule):
36 | @dataclass
37 | class Config(BaseModule.Config):
38 | in_channels: int
39 | n_neurons: int
40 | n_hidden_layers: int
41 | activation: str = "relu"
42 | bias: bool = True
43 | weight_init: Optional[str] = "kaiming_uniform"
44 | bias_init: Optional[str] = None
45 |
46 | cfg: Config
47 |
48 | def configure(self) -> None:
49 | layers = [
50 | self.make_linear(
51 | self.cfg.in_channels,
52 | self.cfg.n_neurons,
53 | bias=self.cfg.bias,
54 | weight_init=self.cfg.weight_init,
55 | bias_init=self.cfg.bias_init,
56 | ),
57 | self.make_activation(self.cfg.activation),
58 | ]
59 | for i in range(self.cfg.n_hidden_layers - 1):
60 | layers += [
61 | self.make_linear(
62 | self.cfg.n_neurons,
63 | self.cfg.n_neurons,
64 | bias=self.cfg.bias,
65 | weight_init=self.cfg.weight_init,
66 | bias_init=self.cfg.bias_init,
67 | ),
68 | self.make_activation(self.cfg.activation),
69 | ]
70 | layers += [
71 | self.make_linear(
72 | self.cfg.n_neurons,
73 | 4, # density 1 + features 3
74 | bias=self.cfg.bias,
75 | weight_init=self.cfg.weight_init,
76 | bias_init=self.cfg.bias_init,
77 | )
78 | ]
79 | self.layers = nn.Sequential(*layers)
80 |
81 | def make_linear(
82 | self,
83 | dim_in,
84 | dim_out,
85 | bias=True,
86 | weight_init=None,
87 | bias_init=None,
88 | ):
89 | layer = nn.Linear(dim_in, dim_out, bias=bias)
90 |
91 | if weight_init is None:
92 | pass
93 | elif weight_init == "kaiming_uniform":
94 | torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity="relu")
95 | else:
96 | raise NotImplementedError
97 |
98 | if bias:
99 | if bias_init is None:
100 | pass
101 | elif bias_init == "zero":
102 | torch.nn.init.zeros_(layer.bias)
103 | else:
104 | raise NotImplementedError
105 |
106 | return layer
107 |
108 | def make_activation(self, activation):
109 | if activation == "relu":
110 | return nn.ReLU(inplace=True)
111 | elif activation == "silu":
112 | return nn.SiLU(inplace=True)
113 | else:
114 | raise NotImplementedError
115 |
116 | def forward(self, x):
117 | inp_shape = x.shape[:-1]
118 | x = x.reshape(-1, x.shape[-1])
119 |
120 | features = self.layers(x)
121 | features = features.reshape(*inp_shape, -1)
122 | out = {"density": features[..., 0:1], "features": features[..., 1:4]}
123 |
124 | return out
125 |
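126 | # Shape sketch (illustrative comment, not part of the upstream file; the channel sizes
127 | # below are placeholders, not the values from config.yaml). The upsampler doubles the
128 | # spatial resolution of each of the three planes; the MLP maps per-point features to a
129 | # raw density plus a 3-channel colour feature.
130 | #
131 | #   up = TriplaneUpsampleNetwork({"in_channels": 1024, "out_channels": 40})
132 | #   up(torch.randn(2, 3, 1024, 32, 32)).shape          # -> (2, 3, 40, 64, 64)
133 | #
134 | #   mlp = NeRFMLP({"in_channels": 120, "n_neurons": 64, "n_hidden_layers": 9})
135 | #   mlp(torch.randn(5, 120))["density"].shape          # -> (5, 1); "features" is (5, 3)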
--------------------------------------------------------------------------------
/tsr/models/tokenizers/image.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import torch
4 | import torch.nn as nn
5 | from einops import rearrange
6 | from huggingface_hub import hf_hub_download
7 | from transformers.models.vit.modeling_vit import ViTModel
8 |
9 | from ...utils import BaseModule
10 |
11 |
12 | class DINOSingleImageTokenizer(BaseModule):
13 | @dataclass
14 | class Config(BaseModule.Config):
15 | pretrained_model_name_or_path: str = "facebook/dino-vitb16"
16 | enable_gradient_checkpointing: bool = False
17 |
18 | cfg: Config
19 |
20 | def configure(self) -> None:
21 | self.model: ViTModel = ViTModel(
22 | ViTModel.config_class.from_pretrained(
23 | hf_hub_download(
24 | repo_id=self.cfg.pretrained_model_name_or_path,
25 | filename="config.json",
26 | )
27 | )
28 | )
29 |
30 | if self.cfg.enable_gradient_checkpointing:
31 | self.model.encoder.gradient_checkpointing = True
32 |
33 | self.register_buffer(
34 | "image_mean",
35 | torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
36 | persistent=False,
37 | )
38 | self.register_buffer(
39 | "image_std",
40 | torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
41 | persistent=False,
42 | )
43 |
44 | def forward(self, images: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
45 | packed = False
46 | if images.ndim == 4:
47 | packed = True
48 | images = images.unsqueeze(1)
49 |
50 | batch_size, n_input_views = images.shape[:2]
51 | images = (images - self.image_mean) / self.image_std
52 | out = self.model(
53 | rearrange(images, "B N C H W -> (B N) C H W"), interpolate_pos_encoding=True
54 | )
55 | local_features, global_features = out.last_hidden_state, out.pooler_output
56 | local_features = local_features.permute(0, 2, 1)
57 | local_features = rearrange(
58 | local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
59 | )
60 | if packed:
61 | local_features = local_features.squeeze(1)
62 |
63 | return local_features
64 |
65 | def detokenize(self, *args, **kwargs):
66 | raise NotImplementedError
67 |
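68 | # Usage sketch (illustrative comment, not part of the upstream file): inputs are images
69 | # in [0, 1] of shape (B, C, H, W) or (B, Nv, C, H, W); the tokenizer normalizes them with
70 | # the ImageNet statistics above and returns channel-first local DINO features. Note that
71 | # configure() only builds the ViT from its config; the pretrained weights arrive with the
72 | # checkpoint loaded in TSR.from_pretrained.
73 | #
74 | #   tok = DINOSingleImageTokenizer({})
75 | #   tok(torch.rand(1, 3, 224, 224)).shape   # -> (1, Ct, Nt), e.g. (1, 768, 197) for ViT-B/16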
--------------------------------------------------------------------------------
/tsr/models/tokenizers/triplane.py:
--------------------------------------------------------------------------------
1 | import math
2 | from dataclasses import dataclass
3 |
4 | import torch
5 | import torch.nn as nn
6 | from einops import rearrange, repeat
7 |
8 | from ...utils import BaseModule
9 |
10 |
11 | class Triplane1DTokenizer(BaseModule):
12 | @dataclass
13 | class Config(BaseModule.Config):
14 | plane_size: int
15 | num_channels: int
16 |
17 | cfg: Config
18 |
19 | def configure(self) -> None:
20 | self.embeddings = nn.Parameter(
21 | torch.randn(
22 | (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
23 | dtype=torch.float32,
24 | )
25 | * 1
26 | / math.sqrt(self.cfg.num_channels)
27 | )
28 |
29 | def forward(self, batch_size: int) -> torch.Tensor:
30 | return rearrange(
31 | repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size),
32 | "B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
33 | )
34 |
35 | def detokenize(self, tokens: torch.Tensor) -> torch.Tensor:
36 | batch_size, Ct, Nt = tokens.shape
37 | assert Nt == self.cfg.plane_size**2 * 3
38 | assert Ct == self.cfg.num_channels
39 | return rearrange(
40 | tokens,
41 | "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
42 | Np=3,
43 | Hp=self.cfg.plane_size,
44 | Wp=self.cfg.plane_size,
45 | )
46 |
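47 | # Usage sketch (illustrative comment, not part of the upstream file; sizes are placeholders):
48 | # the learned plane embeddings are flattened into one token sequence for the backbone and
49 | # reshaped back into three planes afterwards.
50 | #
51 | #   tok = Triplane1DTokenizer({"plane_size": 32, "num_channels": 1024})
52 | #   seq = tok(batch_size=2)            # -> (2, 1024, 3 * 32 * 32)
53 | #   tok.detokenize(seq).shape          # -> (2, 3, 1024, 32, 32)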
--------------------------------------------------------------------------------
/tsr/models/transformer/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # --------
16 | #
17 | # Modified 2024 by the Tripo AI and Stability AI Team.
18 | #
19 | # Copyright (c) 2024 Tripo AI & Stability AI
20 | #
21 | # Permission is hereby granted, free of charge, to any person obtaining a copy
22 | # of this software and associated documentation files (the "Software"), to deal
23 | # in the Software without restriction, including without limitation the rights
24 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25 | # copies of the Software, and to permit persons to whom the Software is
26 | # furnished to do so, subject to the following conditions:
27 | #
28 | # The above copyright notice and this permission notice shall be included in all
29 | # copies or substantial portions of the Software.
30 | #
31 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37 | # SOFTWARE.
38 |
39 | from typing import Optional
40 |
41 | import torch
42 | import torch.nn.functional as F
43 | from torch import nn
44 |
45 |
46 | class Attention(nn.Module):
47 | r"""
48 | A cross attention layer.
49 |
50 | Parameters:
51 | query_dim (`int`):
52 | The number of channels in the query.
53 | cross_attention_dim (`int`, *optional*):
54 | The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
55 | heads (`int`, *optional*, defaults to 8):
56 | The number of heads to use for multi-head attention.
57 | dim_head (`int`, *optional*, defaults to 64):
58 | The number of channels in each head.
59 | dropout (`float`, *optional*, defaults to 0.0):
60 | The dropout probability to use.
61 | bias (`bool`, *optional*, defaults to False):
62 | Set to `True` for the query, key, and value linear layers to contain a bias parameter.
63 | upcast_attention (`bool`, *optional*, defaults to False):
64 | Set to `True` to upcast the attention computation to `float32`.
65 | upcast_softmax (`bool`, *optional*, defaults to False):
66 | Set to `True` to upcast the softmax computation to `float32`.
67 | cross_attention_norm (`str`, *optional*, defaults to `None`):
68 | The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
69 | cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
70 | The number of groups to use for the group norm in the cross attention.
71 | added_kv_proj_dim (`int`, *optional*, defaults to `None`):
72 | The number of channels to use for the added key and value projections. If `None`, no projection is used.
73 | norm_num_groups (`int`, *optional*, defaults to `None`):
74 | The number of groups to use for the group norm in the attention.
75 | spatial_norm_dim (`int`, *optional*, defaults to `None`):
76 | The number of channels to use for the spatial normalization.
77 | out_bias (`bool`, *optional*, defaults to `True`):
78 | Set to `True` to use a bias in the output linear layer.
79 | scale_qk (`bool`, *optional*, defaults to `True`):
80 | Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
81 | only_cross_attention (`bool`, *optional*, defaults to `False`):
82 | Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
83 | `added_kv_proj_dim` is not `None`.
84 | eps (`float`, *optional*, defaults to 1e-5):
85 | An additional value added to the denominator in group normalization that is used for numerical stability.
86 | rescale_output_factor (`float`, *optional*, defaults to 1.0):
87 | A factor to rescale the output by dividing it with this value.
88 | residual_connection (`bool`, *optional*, defaults to `False`):
89 | Set to `True` to add the residual connection to the output.
90 | _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
91 | Set to `True` if the attention block is loaded from a deprecated state dict.
92 | processor (`AttnProcessor`, *optional*, defaults to `None`):
93 | The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
94 | `AttnProcessor` otherwise.
95 | """
96 |
97 | def __init__(
98 | self,
99 | query_dim: int,
100 | cross_attention_dim: Optional[int] = None,
101 | heads: int = 8,
102 | dim_head: int = 64,
103 | dropout: float = 0.0,
104 | bias: bool = False,
105 | upcast_attention: bool = False,
106 | upcast_softmax: bool = False,
107 | cross_attention_norm: Optional[str] = None,
108 | cross_attention_norm_num_groups: int = 32,
109 | added_kv_proj_dim: Optional[int] = None,
110 | norm_num_groups: Optional[int] = None,
111 | out_bias: bool = True,
112 | scale_qk: bool = True,
113 | only_cross_attention: bool = False,
114 | eps: float = 1e-5,
115 | rescale_output_factor: float = 1.0,
116 | residual_connection: bool = False,
117 | _from_deprecated_attn_block: bool = False,
118 | processor: Optional["AttnProcessor"] = None,
119 |         out_dim: Optional[int] = None,
120 | ):
121 | super().__init__()
122 | self.inner_dim = out_dim if out_dim is not None else dim_head * heads
123 | self.query_dim = query_dim
124 | self.cross_attention_dim = (
125 | cross_attention_dim if cross_attention_dim is not None else query_dim
126 | )
127 | self.upcast_attention = upcast_attention
128 | self.upcast_softmax = upcast_softmax
129 | self.rescale_output_factor = rescale_output_factor
130 | self.residual_connection = residual_connection
131 | self.dropout = dropout
132 | self.fused_projections = False
133 | self.out_dim = out_dim if out_dim is not None else query_dim
134 |
135 | # we make use of this private variable to know whether this class is loaded
136 |         # with a deprecated state dict so that we can convert it on the fly
137 | self._from_deprecated_attn_block = _from_deprecated_attn_block
138 |
139 | self.scale_qk = scale_qk
140 | self.scale = dim_head**-0.5 if self.scale_qk else 1.0
141 |
142 | self.heads = out_dim // dim_head if out_dim is not None else heads
143 | # for slice_size > 0 the attention score computation
144 | # is split across the batch axis to save memory
145 | # You can set slice_size with `set_attention_slice`
146 | self.sliceable_head_dim = heads
147 |
148 | self.added_kv_proj_dim = added_kv_proj_dim
149 | self.only_cross_attention = only_cross_attention
150 |
151 | if self.added_kv_proj_dim is None and self.only_cross_attention:
152 | raise ValueError(
153 | "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
154 | )
155 |
156 | if norm_num_groups is not None:
157 | self.group_norm = nn.GroupNorm(
158 | num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True
159 | )
160 | else:
161 | self.group_norm = None
162 |
163 | self.spatial_norm = None
164 |
165 | if cross_attention_norm is None:
166 | self.norm_cross = None
167 | elif cross_attention_norm == "layer_norm":
168 | self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
169 | elif cross_attention_norm == "group_norm":
170 | if self.added_kv_proj_dim is not None:
171 | # The given `encoder_hidden_states` are initially of shape
172 | # (batch_size, seq_len, added_kv_proj_dim) before being projected
173 | # to (batch_size, seq_len, cross_attention_dim). The norm is applied
174 | # before the projection, so we need to use `added_kv_proj_dim` as
175 | # the number of channels for the group norm.
176 | norm_cross_num_channels = added_kv_proj_dim
177 | else:
178 | norm_cross_num_channels = self.cross_attention_dim
179 |
180 | self.norm_cross = nn.GroupNorm(
181 | num_channels=norm_cross_num_channels,
182 | num_groups=cross_attention_norm_num_groups,
183 | eps=1e-5,
184 | affine=True,
185 | )
186 | else:
187 | raise ValueError(
188 | f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
189 | )
190 |
191 | linear_cls = nn.Linear
192 |
193 | self.linear_cls = linear_cls
194 | self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
195 |
196 | if not self.only_cross_attention:
197 | # only relevant for the `AddedKVProcessor` classes
198 | self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
199 | self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
200 | else:
201 | self.to_k = None
202 | self.to_v = None
203 |
204 | if self.added_kv_proj_dim is not None:
205 | self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
206 | self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
207 |
208 | self.to_out = nn.ModuleList([])
209 | self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
210 | self.to_out.append(nn.Dropout(dropout))
211 |
212 | # set attention processor
213 | # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
214 | # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
215 | # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
216 | if processor is None:
217 | processor = (
218 | AttnProcessor2_0()
219 | if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
220 | else AttnProcessor()
221 | )
222 | self.set_processor(processor)
223 |
224 | def set_processor(self, processor: "AttnProcessor") -> None:
225 | self.processor = processor
226 |
227 | def forward(
228 | self,
229 | hidden_states: torch.FloatTensor,
230 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
231 | attention_mask: Optional[torch.FloatTensor] = None,
232 | **cross_attention_kwargs,
233 | ) -> torch.Tensor:
234 | r"""
235 | The forward method of the `Attention` class.
236 |
237 | Args:
238 | hidden_states (`torch.Tensor`):
239 | The hidden states of the query.
240 | encoder_hidden_states (`torch.Tensor`, *optional*):
241 | The hidden states of the encoder.
242 | attention_mask (`torch.Tensor`, *optional*):
243 | The attention mask to use. If `None`, no mask is applied.
244 | **cross_attention_kwargs:
245 | Additional keyword arguments to pass along to the cross attention.
246 |
247 | Returns:
248 | `torch.Tensor`: The output of the attention layer.
249 | """
250 | # The `Attention` class can call different attention processors / attention functions
251 | # here we simply pass along all tensors to the selected processor class
252 | # For standard processors that are defined here, `**cross_attention_kwargs` is empty
253 | return self.processor(
254 | self,
255 | hidden_states,
256 | encoder_hidden_states=encoder_hidden_states,
257 | attention_mask=attention_mask,
258 | **cross_attention_kwargs,
259 | )
260 |
261 | def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
262 | r"""
263 | Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
264 | is the number of heads initialized while constructing the `Attention` class.
265 |
266 | Args:
267 | tensor (`torch.Tensor`): The tensor to reshape.
268 |
269 | Returns:
270 | `torch.Tensor`: The reshaped tensor.
271 | """
272 | head_size = self.heads
273 | batch_size, seq_len, dim = tensor.shape
274 | tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
275 | tensor = tensor.permute(0, 2, 1, 3).reshape(
276 | batch_size // head_size, seq_len, dim * head_size
277 | )
278 | return tensor
279 |
280 | def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
281 | r"""
282 | Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
283 | the number of heads initialized while constructing the `Attention` class.
284 |
285 | Args:
286 | tensor (`torch.Tensor`): The tensor to reshape.
287 | out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
288 | reshaped to `[batch_size * heads, seq_len, dim // heads]`.
289 |
290 | Returns:
291 | `torch.Tensor`: The reshaped tensor.
292 | """
293 | head_size = self.heads
294 | batch_size, seq_len, dim = tensor.shape
295 | tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
296 | tensor = tensor.permute(0, 2, 1, 3)
297 |
298 | if out_dim == 3:
299 | tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
300 |
301 | return tensor
302 |
303 | def get_attention_scores(
304 | self,
305 | query: torch.Tensor,
306 | key: torch.Tensor,
307 |         attention_mask: Optional[torch.Tensor] = None,
308 | ) -> torch.Tensor:
309 | r"""
310 | Compute the attention scores.
311 |
312 | Args:
313 | query (`torch.Tensor`): The query tensor.
314 | key (`torch.Tensor`): The key tensor.
315 | attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
316 |
317 | Returns:
318 | `torch.Tensor`: The attention probabilities/scores.
319 | """
320 | dtype = query.dtype
321 | if self.upcast_attention:
322 | query = query.float()
323 | key = key.float()
324 |
325 | if attention_mask is None:
326 | baddbmm_input = torch.empty(
327 | query.shape[0],
328 | query.shape[1],
329 | key.shape[1],
330 | dtype=query.dtype,
331 | device=query.device,
332 | )
333 | beta = 0
334 | else:
335 | baddbmm_input = attention_mask
336 | beta = 1
337 |
338 | attention_scores = torch.baddbmm(
339 | baddbmm_input,
340 | query,
341 | key.transpose(-1, -2),
342 | beta=beta,
343 | alpha=self.scale,
344 | )
345 | del baddbmm_input
346 |
347 | if self.upcast_softmax:
348 | attention_scores = attention_scores.float()
349 |
350 | attention_probs = attention_scores.softmax(dim=-1)
351 | del attention_scores
352 |
353 | attention_probs = attention_probs.to(dtype)
354 |
355 | return attention_probs
356 |
357 | def prepare_attention_mask(
358 | self,
359 | attention_mask: torch.Tensor,
360 | target_length: int,
361 | batch_size: int,
362 | out_dim: int = 3,
363 | ) -> torch.Tensor:
364 | r"""
365 | Prepare the attention mask for the attention computation.
366 |
367 | Args:
368 | attention_mask (`torch.Tensor`):
369 | The attention mask to prepare.
370 | target_length (`int`):
371 | The target length of the attention mask. This is the length of the attention mask after padding.
372 | batch_size (`int`):
373 | The batch size, which is used to repeat the attention mask.
374 | out_dim (`int`, *optional*, defaults to `3`):
375 | The output dimension of the attention mask. Can be either `3` or `4`.
376 |
377 | Returns:
378 | `torch.Tensor`: The prepared attention mask.
379 | """
380 | head_size = self.heads
381 | if attention_mask is None:
382 | return attention_mask
383 |
384 | current_length: int = attention_mask.shape[-1]
385 | if current_length != target_length:
386 | if attention_mask.device.type == "mps":
387 | # HACK: MPS: Does not support padding by greater than dimension of input tensor.
388 | # Instead, we can manually construct the padding tensor.
389 | padding_shape = (
390 | attention_mask.shape[0],
391 | attention_mask.shape[1],
392 | target_length,
393 | )
394 | padding = torch.zeros(
395 | padding_shape,
396 | dtype=attention_mask.dtype,
397 | device=attention_mask.device,
398 | )
399 | attention_mask = torch.cat([attention_mask, padding], dim=2)
400 | else:
401 | # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
402 | # we want to instead pad by (0, remaining_length), where remaining_length is:
403 | # remaining_length: int = target_length - current_length
404 | # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
405 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
406 |
407 | if out_dim == 3:
408 | if attention_mask.shape[0] < batch_size * head_size:
409 | attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
410 | elif out_dim == 4:
411 | attention_mask = attention_mask.unsqueeze(1)
412 | attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
413 |
414 | return attention_mask
415 |
416 | def norm_encoder_hidden_states(
417 | self, encoder_hidden_states: torch.Tensor
418 | ) -> torch.Tensor:
419 | r"""
420 | Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
421 | `Attention` class.
422 |
423 | Args:
424 | encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
425 |
426 | Returns:
427 | `torch.Tensor`: The normalized encoder hidden states.
428 | """
429 | assert (
430 | self.norm_cross is not None
431 | ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
432 |
433 | if isinstance(self.norm_cross, nn.LayerNorm):
434 | encoder_hidden_states = self.norm_cross(encoder_hidden_states)
435 | elif isinstance(self.norm_cross, nn.GroupNorm):
436 | # Group norm norms along the channels dimension and expects
437 | # input to be in the shape of (N, C, *). In this case, we want
438 | # to norm along the hidden dimension, so we need to move
439 | # (batch_size, sequence_length, hidden_size) ->
440 | # (batch_size, hidden_size, sequence_length)
441 | encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
442 | encoder_hidden_states = self.norm_cross(encoder_hidden_states)
443 | encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
444 | else:
445 | assert False
446 |
447 | return encoder_hidden_states
448 |
449 | @torch.no_grad()
450 | def fuse_projections(self, fuse=True):
451 | is_cross_attention = self.cross_attention_dim != self.query_dim
452 | device = self.to_q.weight.data.device
453 | dtype = self.to_q.weight.data.dtype
454 |
455 | if not is_cross_attention:
456 | # fetch weight matrices.
457 | concatenated_weights = torch.cat(
458 | [self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]
459 | )
460 | in_features = concatenated_weights.shape[1]
461 | out_features = concatenated_weights.shape[0]
462 |
463 | # create a new single projection layer and copy over the weights.
464 | self.to_qkv = self.linear_cls(
465 | in_features, out_features, bias=False, device=device, dtype=dtype
466 | )
467 | self.to_qkv.weight.copy_(concatenated_weights)
468 |
469 | else:
470 | concatenated_weights = torch.cat(
471 | [self.to_k.weight.data, self.to_v.weight.data]
472 | )
473 | in_features = concatenated_weights.shape[1]
474 | out_features = concatenated_weights.shape[0]
475 |
476 | self.to_kv = self.linear_cls(
477 | in_features, out_features, bias=False, device=device, dtype=dtype
478 | )
479 | self.to_kv.weight.copy_(concatenated_weights)
480 |
481 | self.fused_projections = fuse
482 |
483 |
484 | class AttnProcessor:
485 | r"""
486 | Default processor for performing attention-related computations.
487 | """
488 |
489 | def __call__(
490 | self,
491 | attn: Attention,
492 | hidden_states: torch.FloatTensor,
493 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
494 | attention_mask: Optional[torch.FloatTensor] = None,
495 | ) -> torch.Tensor:
496 | residual = hidden_states
497 |
498 | input_ndim = hidden_states.ndim
499 |
500 | if input_ndim == 4:
501 | batch_size, channel, height, width = hidden_states.shape
502 | hidden_states = hidden_states.view(
503 | batch_size, channel, height * width
504 | ).transpose(1, 2)
505 |
506 | batch_size, sequence_length, _ = (
507 | hidden_states.shape
508 | if encoder_hidden_states is None
509 | else encoder_hidden_states.shape
510 | )
511 | attention_mask = attn.prepare_attention_mask(
512 | attention_mask, sequence_length, batch_size
513 | )
514 |
515 | if attn.group_norm is not None:
516 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
517 | 1, 2
518 | )
519 |
520 | query = attn.to_q(hidden_states)
521 |
522 | if encoder_hidden_states is None:
523 | encoder_hidden_states = hidden_states
524 | elif attn.norm_cross:
525 | encoder_hidden_states = attn.norm_encoder_hidden_states(
526 | encoder_hidden_states
527 | )
528 |
529 | key = attn.to_k(encoder_hidden_states)
530 | value = attn.to_v(encoder_hidden_states)
531 |
532 | query = attn.head_to_batch_dim(query)
533 | key = attn.head_to_batch_dim(key)
534 | value = attn.head_to_batch_dim(value)
535 |
536 | attention_probs = attn.get_attention_scores(query, key, attention_mask)
537 | hidden_states = torch.bmm(attention_probs, value)
538 | hidden_states = attn.batch_to_head_dim(hidden_states)
539 |
540 | # linear proj
541 | hidden_states = attn.to_out[0](hidden_states)
542 | # dropout
543 | hidden_states = attn.to_out[1](hidden_states)
544 |
545 | if input_ndim == 4:
546 | hidden_states = hidden_states.transpose(-1, -2).reshape(
547 | batch_size, channel, height, width
548 | )
549 |
550 | if attn.residual_connection:
551 | hidden_states = hidden_states + residual
552 |
553 | hidden_states = hidden_states / attn.rescale_output_factor
554 |
555 | return hidden_states
556 |
557 |
558 | class AttnProcessor2_0:
559 | r"""
560 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
561 | """
562 |
563 | def __init__(self):
564 | if not hasattr(F, "scaled_dot_product_attention"):
565 | raise ImportError(
566 | "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
567 | )
568 |
569 | def __call__(
570 | self,
571 | attn: Attention,
572 | hidden_states: torch.FloatTensor,
573 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
574 | attention_mask: Optional[torch.FloatTensor] = None,
575 | ) -> torch.FloatTensor:
576 | residual = hidden_states
577 |
578 | input_ndim = hidden_states.ndim
579 |
580 | if input_ndim == 4:
581 | batch_size, channel, height, width = hidden_states.shape
582 | hidden_states = hidden_states.view(
583 | batch_size, channel, height * width
584 | ).transpose(1, 2)
585 |
586 | batch_size, sequence_length, _ = (
587 | hidden_states.shape
588 | if encoder_hidden_states is None
589 | else encoder_hidden_states.shape
590 | )
591 |
592 | if attention_mask is not None:
593 | attention_mask = attn.prepare_attention_mask(
594 | attention_mask, sequence_length, batch_size
595 | )
596 | # scaled_dot_product_attention expects attention_mask shape to be
597 | # (batch, heads, source_length, target_length)
598 | attention_mask = attention_mask.view(
599 | batch_size, attn.heads, -1, attention_mask.shape[-1]
600 | )
601 |
602 | if attn.group_norm is not None:
603 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
604 | 1, 2
605 | )
606 |
607 | query = attn.to_q(hidden_states)
608 |
609 | if encoder_hidden_states is None:
610 | encoder_hidden_states = hidden_states
611 | elif attn.norm_cross:
612 | encoder_hidden_states = attn.norm_encoder_hidden_states(
613 | encoder_hidden_states
614 | )
615 |
616 | key = attn.to_k(encoder_hidden_states)
617 | value = attn.to_v(encoder_hidden_states)
618 |
619 | inner_dim = key.shape[-1]
620 | head_dim = inner_dim // attn.heads
621 |
622 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
623 |
624 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
625 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
626 |
627 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
628 | # TODO: add support for attn.scale when we move to Torch 2.1
629 | hidden_states = F.scaled_dot_product_attention(
630 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
631 | )
632 |
633 | hidden_states = hidden_states.transpose(1, 2).reshape(
634 | batch_size, -1, attn.heads * head_dim
635 | )
636 | hidden_states = hidden_states.to(query.dtype)
637 |
638 | # linear proj
639 | hidden_states = attn.to_out[0](hidden_states)
640 | # dropout
641 | hidden_states = attn.to_out[1](hidden_states)
642 |
643 | if input_ndim == 4:
644 | hidden_states = hidden_states.transpose(-1, -2).reshape(
645 | batch_size, channel, height, width
646 | )
647 |
648 | if attn.residual_connection:
649 | hidden_states = hidden_states + residual
650 |
651 | hidden_states = hidden_states / attn.rescale_output_factor
652 |
653 | return hidden_states
654 |
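655 | # Usage sketch (illustrative comment, not part of the upstream file; the dimensions are
656 | # placeholders): a plain self-attention layer over a token sequence. With PyTorch >= 2.0
657 | # the call is dispatched to AttnProcessor2_0 (scaled_dot_product_attention), as selected
658 | # in __init__ above.
659 | #
660 | #   attn = Attention(query_dim=768, heads=12, dim_head=64)
661 | #   y = attn(torch.randn(2, 196, 768))          # -> (2, 196, 768)
662 | #   # pass encoder_hidden_states=... to use it as a cross-attention layer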
--------------------------------------------------------------------------------
/tsr/models/transformer/basic_transformer_block.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # --------
16 | #
17 | # Modified 2024 by the Tripo AI and Stability AI Team.
18 | #
19 | # Copyright (c) 2024 Tripo AI & Stability AI
20 | #
21 | # Permission is hereby granted, free of charge, to any person obtaining a copy
22 | # of this software and associated documentation files (the "Software"), to deal
23 | # in the Software without restriction, including without limitation the rights
24 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25 | # copies of the Software, and to permit persons to whom the Software is
26 | # furnished to do so, subject to the following conditions:
27 | #
28 | # The above copyright notice and this permission notice shall be included in all
29 | # copies or substantial portions of the Software.
30 | #
31 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37 | # SOFTWARE.
38 |
39 | from typing import Optional
40 |
41 | import torch
42 | import torch.nn.functional as F
43 | from torch import nn
44 |
45 | from .attention import Attention
46 |
47 |
48 | class BasicTransformerBlock(nn.Module):
49 | r"""
50 | A basic Transformer block.
51 |
52 | Parameters:
53 | dim (`int`): The number of channels in the input and output.
54 | num_attention_heads (`int`): The number of heads to use for multi-head attention.
55 | attention_head_dim (`int`): The number of channels in each head.
56 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
57 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
58 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
59 | attention_bias (:
60 | obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
61 | only_cross_attention (`bool`, *optional*):
62 | Whether to use only cross-attention layers. In this case two cross attention layers are used.
63 | double_self_attention (`bool`, *optional*):
64 | Whether to use two self-attention layers. In this case no cross attention layers are used.
65 | upcast_attention (`bool`, *optional*):
66 | Whether to upcast the attention computation to float32. This is useful for mixed precision training.
67 | norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
68 | Whether to use learnable elementwise affine parameters for normalization.
69 | norm_type (`str`, *optional*, defaults to `"layer_norm"`):
70 | The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
71 | final_dropout (`bool` *optional*, defaults to False):
72 | Whether to apply a final dropout after the last feed-forward layer.
73 | """
74 |
75 | def __init__(
76 | self,
77 | dim: int,
78 | num_attention_heads: int,
79 | attention_head_dim: int,
80 | dropout=0.0,
81 | cross_attention_dim: Optional[int] = None,
82 | activation_fn: str = "geglu",
83 | attention_bias: bool = False,
84 | only_cross_attention: bool = False,
85 | double_self_attention: bool = False,
86 | upcast_attention: bool = False,
87 | norm_elementwise_affine: bool = True,
88 | norm_type: str = "layer_norm",
89 | final_dropout: bool = False,
90 | ):
91 | super().__init__()
92 | self.only_cross_attention = only_cross_attention
93 |
94 | assert norm_type == "layer_norm"
95 |
96 | # Define 3 blocks. Each block has its own normalization layer.
97 | # 1. Self-Attn
98 | self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
99 | self.attn1 = Attention(
100 | query_dim=dim,
101 | heads=num_attention_heads,
102 | dim_head=attention_head_dim,
103 | dropout=dropout,
104 | bias=attention_bias,
105 | cross_attention_dim=cross_attention_dim if only_cross_attention else None,
106 | upcast_attention=upcast_attention,
107 | )
108 |
109 | # 2. Cross-Attn
110 | if cross_attention_dim is not None or double_self_attention:
111 |             # Only plain LayerNorm is used here (norm_type is asserted to be
112 |             # "layer_norm" above), so the cross-attention block simply gets its own
113 |             # pre-normalization layer.
114 | self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
115 |
116 | self.attn2 = Attention(
117 | query_dim=dim,
118 | cross_attention_dim=(
119 | cross_attention_dim if not double_self_attention else None
120 | ),
121 | heads=num_attention_heads,
122 | dim_head=attention_head_dim,
123 | dropout=dropout,
124 | bias=attention_bias,
125 | upcast_attention=upcast_attention,
126 | ) # is self-attn if encoder_hidden_states is none
127 | else:
128 | self.norm2 = None
129 | self.attn2 = None
130 |
131 | # 3. Feed-forward
132 | self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
133 | self.ff = FeedForward(
134 | dim,
135 | dropout=dropout,
136 | activation_fn=activation_fn,
137 | final_dropout=final_dropout,
138 | )
139 |
140 | # let chunk size default to None
141 | self._chunk_size = None
142 | self._chunk_dim = 0
143 |
144 | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
145 | # Sets chunk feed-forward
146 | self._chunk_size = chunk_size
147 | self._chunk_dim = dim
148 |
149 | def forward(
150 | self,
151 | hidden_states: torch.FloatTensor,
152 | attention_mask: Optional[torch.FloatTensor] = None,
153 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
154 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
155 | ) -> torch.FloatTensor:
156 | # Notice that normalization is always applied before the real computation in the following blocks.
157 |         # 1. Self-Attention
158 | norm_hidden_states = self.norm1(hidden_states)
159 |
160 | attn_output = self.attn1(
161 | norm_hidden_states,
162 | encoder_hidden_states=(
163 | encoder_hidden_states if self.only_cross_attention else None
164 | ),
165 | attention_mask=attention_mask,
166 | )
167 |
168 | hidden_states = attn_output + hidden_states
169 |
170 |         # 2. Cross-Attention
171 | if self.attn2 is not None:
172 | norm_hidden_states = self.norm2(hidden_states)
173 |
174 | attn_output = self.attn2(
175 | norm_hidden_states,
176 | encoder_hidden_states=encoder_hidden_states,
177 | attention_mask=encoder_attention_mask,
178 | )
179 | hidden_states = attn_output + hidden_states
180 |
181 |         # 3. Feed-forward
182 | norm_hidden_states = self.norm3(hidden_states)
183 |
184 | if self._chunk_size is not None:
185 | # "feed_forward_chunk_size" can be used to save memory
186 | if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
187 | raise ValueError(
188 | f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
189 | )
190 |
191 | num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
192 | ff_output = torch.cat(
193 | [
194 | self.ff(hid_slice)
195 | for hid_slice in norm_hidden_states.chunk(
196 | num_chunks, dim=self._chunk_dim
197 | )
198 | ],
199 | dim=self._chunk_dim,
200 | )
201 | else:
202 | ff_output = self.ff(norm_hidden_states)
203 |
204 | hidden_states = ff_output + hidden_states
205 |
206 | return hidden_states
207 |
208 |
209 | class FeedForward(nn.Module):
210 | r"""
211 | A feed-forward layer.
212 |
213 | Parameters:
214 | dim (`int`): The number of channels in the input.
215 | dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
216 | mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
217 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
218 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
219 | final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
220 | """
221 |
222 | def __init__(
223 | self,
224 | dim: int,
225 | dim_out: Optional[int] = None,
226 | mult: int = 4,
227 | dropout: float = 0.0,
228 | activation_fn: str = "geglu",
229 | final_dropout: bool = False,
230 | ):
231 | super().__init__()
232 | inner_dim = int(dim * mult)
233 | dim_out = dim_out if dim_out is not None else dim
234 | linear_cls = nn.Linear
235 |
236 | if activation_fn == "gelu":
237 | act_fn = GELU(dim, inner_dim)
238 |         elif activation_fn == "gelu-approximate":
239 |             act_fn = GELU(dim, inner_dim, approximate="tanh")
240 | elif activation_fn == "geglu":
241 | act_fn = GEGLU(dim, inner_dim)
242 | elif activation_fn == "geglu-approximate":
243 | act_fn = ApproximateGELU(dim, inner_dim)
244 |
245 | self.net = nn.ModuleList([])
246 | # project in
247 | self.net.append(act_fn)
248 | # project dropout
249 | self.net.append(nn.Dropout(dropout))
250 | # project out
251 | self.net.append(linear_cls(inner_dim, dim_out))
252 | # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
253 | if final_dropout:
254 | self.net.append(nn.Dropout(dropout))
255 |
256 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
257 | for module in self.net:
258 | hidden_states = module(hidden_states)
259 | return hidden_states
260 |
261 |
262 | class GELU(nn.Module):
263 | r"""
264 | GELU activation function with tanh approximation support with `approximate="tanh"`.
265 |
266 | Parameters:
267 | dim_in (`int`): The number of channels in the input.
268 | dim_out (`int`): The number of channels in the output.
269 | approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
270 | """
271 |
272 | def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
273 | super().__init__()
274 | self.proj = nn.Linear(dim_in, dim_out)
275 | self.approximate = approximate
276 |
277 | def gelu(self, gate: torch.Tensor) -> torch.Tensor:
278 | if gate.device.type != "mps":
279 | return F.gelu(gate, approximate=self.approximate)
280 | # mps: gelu is not implemented for float16
281 | return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(
282 | dtype=gate.dtype
283 | )
284 |
285 | def forward(self, hidden_states):
286 | hidden_states = self.proj(hidden_states)
287 | hidden_states = self.gelu(hidden_states)
288 | return hidden_states
289 |
290 |
291 | class GEGLU(nn.Module):
292 | r"""
293 | A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
294 |
295 | Parameters:
296 | dim_in (`int`): The number of channels in the input.
297 | dim_out (`int`): The number of channels in the output.
298 | """
299 |
300 | def __init__(self, dim_in: int, dim_out: int):
301 | super().__init__()
302 | linear_cls = nn.Linear
303 |
304 | self.proj = linear_cls(dim_in, dim_out * 2)
305 |
306 | def gelu(self, gate: torch.Tensor) -> torch.Tensor:
307 | if gate.device.type != "mps":
308 | return F.gelu(gate)
309 | # mps: gelu is not implemented for float16
310 | return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
311 |
312 |     def forward(self, hidden_states):
313 |         # project to 2 * dim_out, then gate one half with the GELU of the other
314 |         hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
315 | return hidden_states * self.gelu(gate)
316 |
317 |
318 | class ApproximateGELU(nn.Module):
319 | r"""
320 | The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2:
321 | https://arxiv.org/abs/1606.08415.
322 |
323 | Parameters:
324 | dim_in (`int`): The number of channels in the input.
325 | dim_out (`int`): The number of channels in the output.
326 | """
327 |
328 | def __init__(self, dim_in: int, dim_out: int):
329 | super().__init__()
330 | self.proj = nn.Linear(dim_in, dim_out)
331 |
332 | def forward(self, x: torch.Tensor) -> torch.Tensor:
333 | x = self.proj(x)
334 | return x * torch.sigmoid(1.702 * x)
335 |
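336 | # Usage sketch (illustrative comment, not part of the upstream file; sizes are placeholders):
337 | # one block is pre-norm self-attention, optional cross-attention, then a GEGLU feed-forward,
338 | # each wrapped in a residual connection.
339 | #
340 | #   block = BasicTransformerBlock(dim=512, num_attention_heads=8, attention_head_dim=64,
341 | #                                 cross_attention_dim=768)
342 | #   x = torch.randn(2, 1024, 512)                    # e.g. triplane tokens
343 | #   ctx = torch.randn(2, 197, 768)                   # e.g. image tokens
344 | #   block(x, encoder_hidden_states=ctx).shape        # -> (2, 1024, 512)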
--------------------------------------------------------------------------------
/tsr/models/transformer/transformer_1d.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # --------
16 | #
17 | # Modified 2024 by the Tripo AI and Stability AI Team.
18 | #
19 | # Copyright (c) 2024 Tripo AI & Stability AI
20 | #
21 | # Permission is hereby granted, free of charge, to any person obtaining a copy
22 | # of this software and associated documentation files (the "Software"), to deal
23 | # in the Software without restriction, including without limitation the rights
24 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25 | # copies of the Software, and to permit persons to whom the Software is
26 | # furnished to do so, subject to the following conditions:
27 | #
28 | # The above copyright notice and this permission notice shall be included in all
29 | # copies or substantial portions of the Software.
30 | #
31 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37 | # SOFTWARE.
38 |
39 | from dataclasses import dataclass
40 | from typing import Optional
41 |
42 | import torch
43 | import torch.nn.functional as F
44 | from torch import nn
45 |
46 | from ...utils import BaseModule
47 | from .basic_transformer_block import BasicTransformerBlock
48 |
49 |
50 | class Transformer1D(BaseModule):
51 | @dataclass
52 | class Config(BaseModule.Config):
53 | num_attention_heads: int = 16
54 | attention_head_dim: int = 88
55 | in_channels: Optional[int] = None
56 | out_channels: Optional[int] = None
57 | num_layers: int = 1
58 | dropout: float = 0.0
59 | norm_num_groups: int = 32
60 | cross_attention_dim: Optional[int] = None
61 | attention_bias: bool = False
62 | activation_fn: str = "geglu"
63 | only_cross_attention: bool = False
64 | double_self_attention: bool = False
65 | upcast_attention: bool = False
66 | norm_type: str = "layer_norm"
67 | norm_elementwise_affine: bool = True
68 | gradient_checkpointing: bool = False
69 |
70 | cfg: Config
71 |
72 | def configure(self) -> None:
73 | self.num_attention_heads = self.cfg.num_attention_heads
74 | self.attention_head_dim = self.cfg.attention_head_dim
75 | inner_dim = self.num_attention_heads * self.attention_head_dim
76 |
77 | linear_cls = nn.Linear
78 |
79 |         # 1. Define input layers
80 | self.in_channels = self.cfg.in_channels
81 |
82 | self.norm = torch.nn.GroupNorm(
83 | num_groups=self.cfg.norm_num_groups,
84 | num_channels=self.cfg.in_channels,
85 | eps=1e-6,
86 | affine=True,
87 | )
88 | self.proj_in = linear_cls(self.cfg.in_channels, inner_dim)
89 |
90 |         # 2. Define transformer blocks
91 | self.transformer_blocks = nn.ModuleList(
92 | [
93 | BasicTransformerBlock(
94 | inner_dim,
95 | self.num_attention_heads,
96 | self.attention_head_dim,
97 | dropout=self.cfg.dropout,
98 | cross_attention_dim=self.cfg.cross_attention_dim,
99 | activation_fn=self.cfg.activation_fn,
100 | attention_bias=self.cfg.attention_bias,
101 | only_cross_attention=self.cfg.only_cross_attention,
102 | double_self_attention=self.cfg.double_self_attention,
103 | upcast_attention=self.cfg.upcast_attention,
104 | norm_type=self.cfg.norm_type,
105 | norm_elementwise_affine=self.cfg.norm_elementwise_affine,
106 | )
107 | for d in range(self.cfg.num_layers)
108 | ]
109 | )
110 |
111 |         # 3. Define output layers
112 | self.out_channels = (
113 | self.cfg.in_channels
114 | if self.cfg.out_channels is None
115 | else self.cfg.out_channels
116 | )
117 |
118 | self.proj_out = linear_cls(inner_dim, self.cfg.in_channels)
119 |
120 | self.gradient_checkpointing = self.cfg.gradient_checkpointing
121 |
122 | def forward(
123 | self,
124 | hidden_states: torch.Tensor,
125 | encoder_hidden_states: Optional[torch.Tensor] = None,
126 | attention_mask: Optional[torch.Tensor] = None,
127 | encoder_attention_mask: Optional[torch.Tensor] = None,
128 | ):
129 | """
130 |         The [`Transformer1D`] forward method.
131 |
132 | Args:
133 |             hidden_states (`torch.FloatTensor` of shape `(batch size, channels, sequence length)`):
134 | Input `hidden_states`.
135 | encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
136 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
137 | self-attention.
138 | attention_mask ( `torch.Tensor`, *optional*):
139 | An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
140 | is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
141 | negative values to the attention scores corresponding to "discard" tokens.
142 | encoder_attention_mask ( `torch.Tensor`, *optional*):
143 | Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
144 |
145 | * Mask `(batch, sequence_length)` True = keep, False = discard.
146 | * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
147 |
148 | If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
149 | above. This bias will be added to the cross-attention scores.
150 |
151 | Returns:
152 | torch.FloatTensor
153 | """
154 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
155 | # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
156 | # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
157 | # expects mask of shape:
158 | # [batch, key_tokens]
159 | # adds singleton query_tokens dimension:
160 | # [batch, 1, key_tokens]
161 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
162 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
163 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
164 | if attention_mask is not None and attention_mask.ndim == 2:
165 | # assume that mask is expressed as:
166 | # (1 = keep, 0 = discard)
167 | # convert mask into a bias that can be added to attention scores:
168 | # (keep = +0, discard = -10000.0)
169 | attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
170 | attention_mask = attention_mask.unsqueeze(1)
171 |
172 | # convert encoder_attention_mask to a bias the same way we do for attention_mask
173 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
174 | encoder_attention_mask = (
175 | 1 - encoder_attention_mask.to(hidden_states.dtype)
176 | ) * -10000.0
177 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
178 |
179 | # 1. Input
180 | batch, _, seq_len = hidden_states.shape
181 | residual = hidden_states
182 |
183 | hidden_states = self.norm(hidden_states)
184 | inner_dim = hidden_states.shape[1]
185 | hidden_states = hidden_states.permute(0, 2, 1).reshape(
186 | batch, seq_len, inner_dim
187 | )
188 | hidden_states = self.proj_in(hidden_states)
189 |
190 | # 2. Blocks
191 | for block in self.transformer_blocks:
192 | if self.training and self.gradient_checkpointing:
193 | hidden_states = torch.utils.checkpoint.checkpoint(
194 | block,
195 | hidden_states,
196 | attention_mask,
197 | encoder_hidden_states,
198 | encoder_attention_mask,
199 | use_reentrant=False,
200 | )
201 | else:
202 | hidden_states = block(
203 | hidden_states,
204 | attention_mask=attention_mask,
205 | encoder_hidden_states=encoder_hidden_states,
206 | encoder_attention_mask=encoder_attention_mask,
207 | )
208 |
209 | # 3. Output
210 | hidden_states = self.proj_out(hidden_states)
211 | hidden_states = (
212 | hidden_states.reshape(batch, seq_len, inner_dim)
213 | .permute(0, 2, 1)
214 | .contiguous()
215 | )
216 |
217 | output = hidden_states + residual
218 |
219 | return output
220 |
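221 | # Usage sketch (illustrative comment, not part of the upstream file; sizes are placeholders):
222 | # the backbone takes a channel-first token sequence (B, C, L), applies GroupNorm and a
223 | # linear projection, runs the transformer blocks (cross-attending to the image tokens),
224 | # projects back, and adds the input as a residual.
225 | #
226 | #   backbone = Transformer1D({
227 | #       "in_channels": 1024, "num_attention_heads": 16, "attention_head_dim": 64,
228 | #       "num_layers": 2, "cross_attention_dim": 768,
229 | #   })
230 | #   tokens = torch.randn(1, 1024, 3 * 32 * 32)
231 | #   image_tokens = torch.randn(1, 197, 768)
232 | #   backbone(tokens, encoder_hidden_states=image_tokens).shape   # -> (1, 1024, 3072)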
--------------------------------------------------------------------------------
/tsr/system.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | from dataclasses import dataclass, field
4 | from typing import List, Union
5 |
6 | import numpy as np
7 | import PIL.Image
8 | import torch
9 | import torch.nn.functional as F
10 | import trimesh
11 | from einops import rearrange
12 | from huggingface_hub import hf_hub_download
13 | from omegaconf import OmegaConf
14 | from PIL import Image
15 |
16 | from .models.isosurface import MarchingCubeHelper
17 | from .utils import (
18 | BaseModule,
19 | ImagePreprocessor,
20 | find_class,
21 | get_spherical_cameras,
22 | scale_tensor,
23 | )
24 |
25 |
26 | class TSR(BaseModule):
27 | @dataclass
28 | class Config(BaseModule.Config):
29 | cond_image_size: int
30 |
31 | image_tokenizer_cls: str
32 | image_tokenizer: dict
33 |
34 | tokenizer_cls: str
35 | tokenizer: dict
36 |
37 | backbone_cls: str
38 | backbone: dict
39 |
40 | post_processor_cls: str
41 | post_processor: dict
42 |
43 | decoder_cls: str
44 | decoder: dict
45 |
46 | renderer_cls: str
47 | renderer: dict
48 |
49 | cfg: Config
50 |
51 | @classmethod
52 | def from_pretrained(
53 | cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
54 | ):
55 | if os.path.isdir(pretrained_model_name_or_path):
56 | config_path = os.path.join(pretrained_model_name_or_path, config_name)
57 | weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
58 | else:
59 | config_path = hf_hub_download(
60 | repo_id=pretrained_model_name_or_path, filename=config_name
61 | )
62 | weight_path = hf_hub_download(
63 | repo_id=pretrained_model_name_or_path, filename=weight_name
64 | )
65 |
66 | cfg = OmegaConf.load(config_path)
67 | OmegaConf.resolve(cfg)
68 | model = cls(cfg)
69 | ckpt = torch.load(weight_path, map_location="cpu")
70 | model.load_state_dict(ckpt)
71 | return model
72 |
73 | @classmethod
74 | def from_pretrained_custom(
75 | cls, weight_path: str, config_path: str
76 | ):
77 | cfg = OmegaConf.load(config_path)
78 | OmegaConf.resolve(cfg)
79 | model = cls(cfg)
80 | ckpt = torch.load(weight_path, map_location="cpu")
81 | model.load_state_dict(ckpt)
82 | return model
83 |
84 | def configure(self):
85 | self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
86 | self.cfg.image_tokenizer
87 | )
88 | self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
89 | self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
90 | self.post_processor = find_class(self.cfg.post_processor_cls)(
91 | self.cfg.post_processor
92 | )
93 | self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
94 | self.renderer = find_class(self.cfg.renderer_cls)(self.cfg.renderer)
95 | self.image_processor = ImagePreprocessor()
96 | self.isosurface_helper = None
97 |
98 | def forward(
99 | self,
100 | image: Union[
101 | PIL.Image.Image,
102 | np.ndarray,
103 | torch.FloatTensor,
104 | List[PIL.Image.Image],
105 | List[np.ndarray],
106 | List[torch.FloatTensor],
107 | ],
108 | device: str,
109 | ) -> torch.FloatTensor:
110 | rgb_cond = self.image_processor(image, self.cfg.cond_image_size)[:, None].to(
111 | device
112 | )
113 | batch_size = rgb_cond.shape[0]
114 |
115 | input_image_tokens: torch.Tensor = self.image_tokenizer(
116 | rearrange(rgb_cond, "B Nv H W C -> B Nv C H W", Nv=1),
117 | )
118 |
119 | input_image_tokens = rearrange(
120 | input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=1
121 | )
122 |
123 | tokens: torch.Tensor = self.tokenizer(batch_size)
124 |
125 | tokens = self.backbone(
126 | tokens,
127 | encoder_hidden_states=input_image_tokens,
128 | )
129 |
130 | scene_codes = self.post_processor(self.tokenizer.detokenize(tokens))
131 | return scene_codes
132 |
133 | def render(
134 | self,
135 | scene_codes,
136 | n_views: int,
137 | elevation_deg: float = 0.0,
138 | camera_distance: float = 1.9,
139 | fovy_deg: float = 40.0,
140 | height: int = 256,
141 | width: int = 256,
142 | return_type: str = "pil",
143 | ):
144 | rays_o, rays_d = get_spherical_cameras(
145 | n_views, elevation_deg, camera_distance, fovy_deg, height, width
146 | )
147 | rays_o, rays_d = rays_o.to(scene_codes.device), rays_d.to(scene_codes.device)
148 |
149 | def process_output(image: torch.FloatTensor):
150 | if return_type == "pt":
151 | return image
152 | elif return_type == "np":
153 | return image.detach().cpu().numpy()
154 | elif return_type == "pil":
155 | return Image.fromarray(
156 | (image.detach().cpu().numpy() * 255.0).astype(np.uint8)
157 | )
158 | else:
159 | raise NotImplementedError
160 |
161 | images = []
162 | for scene_code in scene_codes:
163 | images_ = []
164 | for i in range(n_views):
165 | with torch.no_grad():
166 | image = self.renderer(
167 | self.decoder, scene_code, rays_o[i], rays_d[i]
168 | )
169 | images_.append(process_output(image))
170 | images.append(images_)
171 |
172 | return images
173 |
174 | def set_marching_cubes_resolution(self, resolution: int):
175 | if (
176 | self.isosurface_helper is not None
177 | and self.isosurface_helper.resolution == resolution
178 | ):
179 | return
180 | self.isosurface_helper = MarchingCubeHelper(resolution)
181 |
182 | def extract_mesh(self, scene_codes, resolution: int = 256, threshold: float = 25.0):  # run marching cubes on the queried density field, then bake per-vertex colors
183 | self.set_marching_cubes_resolution(resolution)
184 | meshes = []
185 | for scene_code in scene_codes:
186 | with torch.no_grad():
187 | density = self.renderer.query_triplane(
188 | self.decoder,
189 | scale_tensor(
190 | self.isosurface_helper.grid_vertices.to(scene_codes.device),
191 | self.isosurface_helper.points_range,
192 | (-self.renderer.cfg.radius, self.renderer.cfg.radius),
193 | ),
194 | scene_code,
195 | )["density_act"]
196 | v_pos, t_pos_idx = self.isosurface_helper(-(density - threshold))
197 | v_pos = scale_tensor(
198 | v_pos,
199 | self.isosurface_helper.points_range,
200 | (-self.renderer.cfg.radius, self.renderer.cfg.radius),
201 | )
202 | with torch.no_grad():
203 | color = self.renderer.query_triplane(
204 | self.decoder,
205 | v_pos,
206 | scene_code,
207 | )["color"]
208 | mesh = trimesh.Trimesh(
209 | vertices=v_pos.cpu().numpy(),
210 | faces=t_pos_idx.cpu().numpy(),
211 | vertex_colors=color.cpu().numpy(),
212 | )
213 | meshes.append(mesh)
214 | return meshes
215 |
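216 | # Illustrative usage sketch: assumes this system class is exported as `TSR`
217 | # elsewhere in the file and that `input.png` is an RGB image with the
218 | # background already handled; paths and view counts are placeholders.
219 | #
220 | # model = TSR.from_pretrained_custom(weight_path="model.ckpt", config_path="config.yaml")
221 | # model.to("cuda")
222 | # scene_codes = model(PIL.Image.open("input.png"), device="cuda")
223 | # views = model.render(scene_codes, n_views=4, return_type="pil")
224 | # meshes = model.extract_mesh(scene_codes, resolution=256)
225 | # meshes[0].export("mesh.obj")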
--------------------------------------------------------------------------------
/tsr/utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import math
3 | from collections import defaultdict
4 | from dataclasses import dataclass
5 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
6 |
7 | import imageio
8 | import numpy as np
9 | import PIL.Image
10 | #import rembg
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import trimesh
15 | from omegaconf import DictConfig, OmegaConf
16 | #from PIL import Image
17 |
18 |
19 | def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:  # merge a user-provided cfg into the structured defaults defined by the fields dataclass
20 | scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg)
21 | return scfg
22 |
23 |
24 | def find_class(cls_string):  # resolve a dotted "module.ClassName" string to the class object
25 | module_string = ".".join(cls_string.split(".")[:-1])
26 | cls_name = cls_string.split(".")[-1]
27 | module = importlib.import_module(module_string, package=None)
28 | cls = getattr(module, cls_name)
29 | return cls
30 |
31 |
32 | def get_intrinsic_from_fov(fov, H, W, bs=-1):  # 3x3 pinhole intrinsics from a vertical FOV in radians; bs > 0 adds a batch dimension
33 | focal_length = 0.5 * H / np.tan(0.5 * fov)
34 | intrinsic = np.identity(3, dtype=np.float32)
35 | intrinsic[0, 0] = focal_length
36 | intrinsic[1, 1] = focal_length
37 | intrinsic[0, 2] = W / 2.0
38 | intrinsic[1, 2] = H / 2.0
39 |
40 | if bs > 0:
41 | intrinsic = intrinsic[None].repeat(bs, axis=0)
42 |
43 | return torch.from_numpy(intrinsic)
44 |
45 |
46 | class BaseModule(nn.Module):
47 | @dataclass
48 | class Config:
49 | pass
50 |
51 | cfg: Config # add this to every subclass of BaseModule to enable static type checking
52 |
53 | def __init__(
54 | self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
55 | ) -> None:
56 | super().__init__()
57 | self.cfg = parse_structured(self.Config, cfg)
58 | self.configure(*args, **kwargs)
59 |
60 | def configure(self, *args, **kwargs) -> None:
61 | raise NotImplementedError
62 |
63 |
64 | class ImagePreprocessor:
65 | def convert_and_resize(
66 | self,
67 | image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
68 | size: int,
69 | ):
70 | if isinstance(image, PIL.Image.Image):
71 | image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
72 | elif isinstance(image, np.ndarray):
73 | if image.dtype == np.uint8:
74 | image = torch.from_numpy(image.astype(np.float32) / 255.0)
75 | else:
76 | image = torch.from_numpy(image)
77 | elif isinstance(image, torch.Tensor):
78 | pass
79 |
80 | batched = image.ndim == 4
81 |
82 | if not batched:
83 | image = image[None, ...]
84 | image = F.interpolate(
85 | image.permute(0, 3, 1, 2),
86 | (size, size),
87 | mode="bilinear",
88 | align_corners=False,
89 | antialias=True,
90 | ).permute(0, 2, 3, 1)
91 | if not batched:
92 | image = image[0]
93 | return image
94 |
95 | def __call__(
96 | self,
97 | image: Union[
98 | PIL.Image.Image,
99 | np.ndarray,
100 | torch.FloatTensor,
101 | List[PIL.Image.Image],
102 | List[np.ndarray],
103 | List[torch.FloatTensor],
104 | ],
105 | size: int,
106 | ) -> Any:
107 | if isinstance(image, (np.ndarray, torch.FloatTensor)) and image.ndim == 4:
108 | image = self.convert_and_resize(image, size)
109 | else:
110 | if not isinstance(image, list):
111 | image = [image]
112 | image = [self.convert_and_resize(im, size) for im in image]
113 | image = torch.stack(image, dim=0)
114 | return image
115 |
116 |
117 | def rays_intersect_bbox(
118 | rays_o: torch.Tensor,
119 | rays_d: torch.Tensor,
120 | radius: float,
121 | near: float = 0.0,
122 | valid_thresh: float = 0.01,
123 | ):
124 | input_shape = rays_o.shape[:-1]
125 | rays_o, rays_d = rays_o.view(-1, 3), rays_d.view(-1, 3)
126 | rays_d_valid = torch.where(
127 | rays_d.abs() < 1e-6, torch.full_like(rays_d, 1e-6), rays_d
128 | )
129 | if type(radius) in [int, float]:
130 | radius = torch.FloatTensor(
131 | [[-radius, radius], [-radius, radius], [-radius, radius]]
132 | ).to(rays_o.device)
133 | radius = (
134 | 1.0 - 1.0e-3
135 | ) * radius # tighten the radius to make sure the intersection point lies in the bounding box
136 | interx0 = (radius[..., 1] - rays_o) / rays_d_valid
137 | interx1 = (radius[..., 0] - rays_o) / rays_d_valid
138 | t_near = torch.minimum(interx0, interx1).amax(dim=-1).clamp_min(near)
139 | t_far = torch.maximum(interx0, interx1).amin(dim=-1)
140 |
141 | # check whether a ray intersects the bbox or not
142 | rays_valid = t_far - t_near > valid_thresh
143 |
144 | t_near[torch.where(~rays_valid)] = 0.0
145 | t_far[torch.where(~rays_valid)] = 0.0
146 |
147 | t_near = t_near.view(*input_shape, 1)
148 | t_far = t_far.view(*input_shape, 1)
149 | rays_valid = rays_valid.view(*input_shape)
150 |
151 | return t_near, t_far, rays_valid
152 |
153 |
154 | def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any:  # run func over the batch dimension in chunks of chunk_size and concatenate the outputs
155 | if chunk_size <= 0:
156 | return func(*args, **kwargs)
157 | B = None
158 | for arg in list(args) + list(kwargs.values()):
159 | if isinstance(arg, torch.Tensor):
160 | B = arg.shape[0]
161 | break
162 | assert (
163 | B is not None
164 | ), "No tensor found in args or kwargs, cannot determine batch size."
165 | out = defaultdict(list)
166 | out_type = None
167 | # max(1, B) to support B == 0
168 | for i in range(0, max(1, B), chunk_size):
169 | out_chunk = func(
170 | *[
171 | arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
172 | for arg in args
173 | ],
174 | **{
175 | k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
176 | for k, arg in kwargs.items()
177 | },
178 | )
179 | if out_chunk is None:
180 | continue
181 | out_type = type(out_chunk)
182 | if isinstance(out_chunk, torch.Tensor):
183 | out_chunk = {0: out_chunk}
184 | elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list):
185 | chunk_length = len(out_chunk)
186 | out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)}
187 | elif isinstance(out_chunk, dict):
188 | pass
189 | else:
190 | print(
191 | f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}."
192 | )
193 | exit(1)
194 | for k, v in out_chunk.items():
195 | v = v if torch.is_grad_enabled() else v.detach()
196 | out[k].append(v)
197 |
198 | if out_type is None:
199 | return None
200 |
201 | out_merged: Dict[Any, Optional[torch.Tensor]] = {}
202 | for k, v in out.items():
203 | if all([vv is None for vv in v]):
204 | # allow None in return value
205 | out_merged[k] = None
206 | elif all([isinstance(vv, torch.Tensor) for vv in v]):
207 | out_merged[k] = torch.cat(v, dim=0)
208 | else:
209 | raise TypeError(
210 | f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}"
211 | )
212 |
213 | if out_type is torch.Tensor:
214 | return out_merged[0]
215 | elif out_type in [tuple, list]:
216 | return out_type([out_merged[i] for i in range(chunk_length)])
217 | elif out_type is dict:
218 | return out_merged
219 |
220 |
221 | ValidScale = Union[Tuple[float, float], torch.FloatTensor]
222 |
223 |
224 | def scale_tensor(dat: torch.FloatTensor, inp_scale: ValidScale, tgt_scale: ValidScale):  # linearly remap dat from the inp_scale range to the tgt_scale range
225 | if inp_scale is None:
226 | inp_scale = (0, 1)
227 | if tgt_scale is None:
228 | tgt_scale = (0, 1)
229 | if isinstance(tgt_scale, torch.FloatTensor):
230 | assert dat.shape[-1] == tgt_scale.shape[-1]
231 | dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
232 | dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
233 | return dat
234 |
235 |
236 | def get_activation(name) -> Callable:
237 | if name is None:
238 | return lambda x: x
239 | name = name.lower()
240 | if name == "none":
241 | return lambda x: x
242 | elif name == "exp":
243 | return lambda x: torch.exp(x)
244 | elif name == "sigmoid":
245 | return lambda x: torch.sigmoid(x)
246 | elif name == "tanh":
247 | return lambda x: torch.tanh(x)
248 | elif name == "softplus":
249 | return lambda x: F.softplus(x)
250 | else:
251 | try:
252 | return getattr(F, name)
253 | except AttributeError:
254 | raise ValueError(f"Unknown activation function: {name}")
255 |
256 |
257 | def get_ray_directions(
258 | H: int,
259 | W: int,
260 | focal: Union[float, Tuple[float, float]],
261 | principal: Optional[Tuple[float, float]] = None,
262 | use_pixel_centers: bool = True,
263 | normalize: bool = True,
264 | ) -> torch.FloatTensor:
265 | """
266 | Get ray directions for all pixels in camera coordinates.
267 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
268 | ray-tracing-generating-camera-rays/standard-coordinate-systems
269 |
270 | Inputs:
271 | H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point, and whether to use pixel centers
272 | Outputs:
273 | directions: (H, W, 3), the direction of the rays in camera coordinates
274 | """
275 | pixel_center = 0.5 if use_pixel_centers else 0
276 |
277 | if isinstance(focal, float):
278 | fx, fy = focal, focal
279 | cx, cy = W / 2, H / 2
280 | else:
281 | fx, fy = focal
282 | assert principal is not None
283 | cx, cy = principal
284 |
285 | i, j = torch.meshgrid(
286 | torch.arange(W, dtype=torch.float32) + pixel_center,
287 | torch.arange(H, dtype=torch.float32) + pixel_center,
288 | indexing="xy",
289 | )
290 |
291 | directions = torch.stack([(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1)
292 |
293 | if normalize:
294 | directions = F.normalize(directions, dim=-1)
295 |
296 | return directions
297 |
298 |
299 | def get_rays(
300 | directions,
301 | c2w,
302 | keepdim=False,
303 | normalize=False,
304 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
305 | # Rotate ray directions from camera coordinate to the world coordinate
306 | assert directions.shape[-1] == 3
307 |
308 | if directions.ndim == 2: # (N_rays, 3)
309 | if c2w.ndim == 2: # (4, 4)
310 | c2w = c2w[None, :, :]
311 | assert c2w.ndim == 3 # (N_rays, 4, 4) or (1, 4, 4)
312 | rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1) # (N_rays, 3)
313 | rays_o = c2w[:, :3, 3].expand(rays_d.shape)
314 | elif directions.ndim == 3: # (H, W, 3)
315 | assert c2w.ndim in [2, 3]
316 | if c2w.ndim == 2: # (4, 4)
317 | rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum(
318 | -1
319 | ) # (H, W, 3)
320 | rays_o = c2w[None, None, :3, 3].expand(rays_d.shape)
321 | elif c2w.ndim == 3: # (B, 4, 4)
322 | rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
323 | -1
324 | ) # (B, H, W, 3)
325 | rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
326 | elif directions.ndim == 4: # (B, H, W, 3)
327 | assert c2w.ndim == 3 # (B, 4, 4)
328 | rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
329 | -1
330 | ) # (B, H, W, 3)
331 | rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
332 |
333 | if normalize:
334 | rays_d = F.normalize(rays_d, dim=-1)
335 | if not keepdim:
336 | rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)
337 |
338 | return rays_o, rays_d
339 |
340 |
341 | def get_spherical_cameras(
342 | n_views: int,
343 | elevation_deg: float,
344 | camera_distance: float,
345 | fovy_deg: float,
346 | height: int,
347 | width: int,
348 | ):
349 | azimuth_deg = torch.linspace(0, 360.0, n_views + 1)[:n_views]
350 | elevation_deg = torch.full_like(azimuth_deg, elevation_deg)
351 | camera_distances = torch.full_like(elevation_deg, camera_distance)
352 |
353 | elevation = elevation_deg * math.pi / 180
354 | azimuth = azimuth_deg * math.pi / 180
355 |
356 | # convert spherical coordinates to cartesian coordinates
357 | # right hand coordinate system, x back, y right, z up
358 | # elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
359 | camera_positions = torch.stack(
360 | [
361 | camera_distances * torch.cos(elevation) * torch.cos(azimuth),
362 | camera_distances * torch.cos(elevation) * torch.sin(azimuth),
363 | camera_distances * torch.sin(elevation),
364 | ],
365 | dim=-1,
366 | )
367 |
368 | # default scene center at origin
369 | center = torch.zeros_like(camera_positions)
370 | # default camera up direction as +z
371 | up = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None, :].repeat(n_views, 1)
372 |
373 | fovy = torch.full_like(elevation_deg, fovy_deg) * math.pi / 180
374 |
375 | lookat = F.normalize(center - camera_positions, dim=-1)
376 | right = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)  # pass dim explicitly: the implicit default picks the first size-3 dim and misbehaves when n_views == 3
377 | up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
378 | c2w3x4 = torch.cat(
379 | [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
380 | dim=-1,
381 | )
382 | c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
383 | c2w[:, 3, 3] = 1.0
384 |
385 | # get directions by dividing directions_unit_focal by focal length
386 | focal_length = 0.5 * height / torch.tan(0.5 * fovy)
387 | directions_unit_focal = get_ray_directions(
388 | H=height,
389 | W=width,
390 | focal=1.0,
391 | )
392 | directions = directions_unit_focal[None, :, :, :].repeat(n_views, 1, 1, 1)
393 | directions[:, :, :, :2] = (
394 | directions[:, :, :, :2] / focal_length[:, None, None, None]
395 | )
396 | # must use normalize=True to normalize directions here
397 | rays_o, rays_d = get_rays(directions, c2w, keepdim=True, normalize=True)
398 |
399 | return rays_o, rays_d
400 |
401 |
402 | # def remove_background(
403 | # image: PIL.Image.Image,
404 | # rembg_session: Any = None,
405 | # force: bool = False,
406 | # **rembg_kwargs,
407 | # ) -> PIL.Image.Image:
408 | # do_remove = True
409 | # if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
410 | # do_remove = False
411 | # do_remove = do_remove or force
412 | # if do_remove:
413 | # image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
414 | # return image
415 |
416 |
417 | def resize_foreground(
418 | image: PIL.Image.Image,
419 | ratio: float,
420 | ) -> PIL.Image.Image:
421 | image = np.array(image)
422 | assert image.shape[-1] == 4
423 | alpha = np.where(image[..., 3] > 0)
424 | y1, y2, x1, x2 = (
425 | alpha[0].min(),
426 | alpha[0].max(),
427 | alpha[1].min(),
428 | alpha[1].max(),
429 | )
430 | # crop the foreground
431 | fg = image[y1:y2, x1:x2]
432 | # pad to square
433 | size = max(fg.shape[0], fg.shape[1])
434 | ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
435 | ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
436 | new_image = np.pad(
437 | fg,
438 | ((ph0, ph1), (pw0, pw1), (0, 0)),
439 | mode="constant",
440 | constant_values=((0, 0), (0, 0), (0, 0)),
441 | )
442 |
443 | # compute padding according to the ratio
444 | new_size = int(new_image.shape[0] / ratio)
445 | # pad to size, double side
446 | ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
447 | ph1, pw1 = new_size - size - ph0, new_size - size - pw0
448 | new_image = np.pad(
449 | new_image,
450 | ((ph0, ph1), (pw0, pw1), (0, 0)),
451 | mode="constant",
452 | constant_values=((0, 0), (0, 0), (0, 0)),
453 | )
454 | new_image = PIL.Image.fromarray(new_image)
455 | return new_image
456 |
457 |
458 | def save_video(
459 | frames: List[PIL.Image.Image],
460 | output_path: str,
461 | fps: int = 30,
462 | ):
463 | # use imageio to save video
464 | frames = [np.array(frame) for frame in frames]
465 | writer = imageio.get_writer(output_path, fps=fps)
466 | for frame in frames:
467 | writer.append_data(frame)
468 | writer.close()
469 |
470 |
471 | def to_gradio_3d_orientation(mesh):  # rotate/flip the mesh into the orientation expected by y-up web 3D viewers
472 | mesh.apply_transform(trimesh.transformations.rotation_matrix(-np.pi/2, [1, 0, 0]))
473 | mesh.apply_scale([1, 1, -1])
474 | mesh.apply_transform(trimesh.transformations.rotation_matrix(np.pi/2, [0, 1, 0]))
475 | return mesh
476 |
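477 | # Illustrative usage sketch: generate rays for a 4-view turntable using the
478 | # same camera defaults that tsr/system.py passes to the renderer; the numeric
479 | # values below are placeholders.
480 | #
481 | # rays_o, rays_d = get_spherical_cameras(
482 | #     n_views=4, elevation_deg=0.0, camera_distance=1.9,
483 | #     fovy_deg=40.0, height=256, width=256,
484 | # )
485 | # # rays_o and rays_d have shape (n_views, height, width, 3); rays_d is unit-length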
--------------------------------------------------------------------------------
/web/html/threeVisualizer.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |