├── .gitignore ├── LICENSE.txt ├── README.md ├── __init__.py ├── config.yaml ├── logo-dark.svg ├── logo.svg ├── requirements.txt ├── tsr ├── models │ ├── isosurface.py │ ├── nerf_renderer.py │ ├── network_utils.py │ ├── tokenizers │ │ ├── image.py │ │ └── triplane.py │ └── transformer │ │ ├── attention.py │ │ ├── basic_transformer_block.py │ │ └── transformer_1d.py ├── system.py └── utils.py ├── web ├── html │ └── threeVisualizer.html ├── js │ └── threeVisualizer.js ├── style │ ├── progressStyle.css │ └── threeStyle.css └── visualization.js ├── workflow-sample.png ├── workflow_rembg.json └── workflow_simple.json /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .DS_Store -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 
49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. 
"Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 
174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 
234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 
296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 
414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. 
The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. 
You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 
583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | 635 | Copyright (C) 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 649 | 650 | Also add information on how to contact you by electronic and paper mail. 
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | Copyright (C) 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | . 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | . 675 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI-Flowty-TripoSR 2 | 3 | This is a custom node that lets you use TripoSR right from ComfyUI. 4 | 5 | [TripoSR](https://github.com/VAST-AI-Research/TripoSR) is a state-of-the-art open-source model for fast feedforward 3D reconstruction from a single image, collaboratively developed by Tripo AI and Stability AI. (TL;DR it creates a 3d model from an image.) 6 | 7 | ![example](workflow-sample.png) 8 | 9 | I've created this node for experimentation, feel free to submit PRs for performance improvements etc. 10 | 11 | ### Installation: 12 | * Install ComfyUI 13 | * Clone this repo into ```custom_nodes```: 14 | ```shell 15 | $ cd ComfyUI/custom_nodes 16 | $ git clone https://github.com/flowtyone/ComfyUI-Flowty-TripoSR.git 17 | ``` 18 | * Install dependencies: 19 | ```shell 20 | $ cd ComfyUI-Flowty-TripoSR 21 | $ pip install -r requirements.txt 22 | ``` 23 | * [Download TripoSR](https://huggingface.co/stabilityai/TripoSR/blob/main/model.ckpt) and place it in ```ComfyUI/models/checkpoints``` 24 | * Start ComfyUI (or restart) 25 | 26 | Special thanks to MrForExample for creating [ComfyUI-3D-Pack](https://github.com/MrForExample/ComfyUI-3D-Pack). Code from that node pack was used to display 3d models in comfyui. 27 | 28 | This is a community project from [flowt.ai](https://flowt.ai). If you like it, check us out! 
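For the curious, the sketch below shows roughly what the Model Loader and Sampler nodes do under the hood, pieced together from the node code in `__init__.py`. It is a minimal standalone example, not part of the normal ComfyUI workflow: the checkpoint path, image file, and `"cuda"` device are placeholders, and it assumes you have already downloaded the TripoSR checkpoint as described above.

```python
# Hypothetical standalone sketch of the TripoSR pipeline used by these nodes.
# Paths, resolution and threshold values below are illustrative placeholders.
from PIL import Image
from tsr.system import TSR

model = TSR.from_pretrained_custom(
    weight_path="ComfyUI/models/checkpoints/model.ckpt",  # downloaded TripoSR checkpoint
    config_path="config.yaml",                            # config shipped with this node pack
)
model.renderer.set_chunk_size(8192)  # lower this if you run out of VRAM; 0 disables chunking
model.to("cuda")                     # the nodes fall back to "cpu" when CUDA is unavailable

image = Image.open("reference.png").convert("RGB")  # ideally with the background removed/filled
scene_codes = model([image], "cuda")                # image tokens -> triplane scene codes
meshes = model.extract_mesh(scene_codes, resolution=256, threshold=25.0)
meshes[0].export("output.obj")                      # meshes are trimesh objects
```

The Sampler node performs the last three steps for each queued image, and the Viewer node then rotates the resulting mesh and exports it as an `.obj` file into ComfyUI's output folder.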
29 | 30 | 31 | 32 | 33 | flowt.ai logo 34 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from os import path 3 | 4 | sys.path.insert(0, path.dirname(__file__)) 5 | from folder_paths import get_filename_list, get_full_path, get_save_image_path, get_output_directory 6 | from comfy.model_management import get_torch_device 7 | from tsr.system import TSR 8 | from PIL import Image 9 | import numpy as np 10 | import torch 11 | 12 | 13 | def fill_background(image): 14 | image = np.array(image).astype(np.float32) / 255.0 15 | image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5 16 | image = Image.fromarray((image * 255.0).astype(np.uint8)) 17 | return image 18 | 19 | 20 | class TripoSRModelLoader: 21 | def __init__(self): 22 | self.initialized_model = None 23 | 24 | @classmethod 25 | def INPUT_TYPES(s): 26 | return { 27 | "required": { 28 | "model": (get_filename_list("checkpoints"),), 29 | "chunk_size": ("INT", {"default": 8192, "min": 1, "max": 10000}) 30 | } 31 | } 32 | 33 | RETURN_TYPES = ("TRIPOSR_MODEL",) 34 | FUNCTION = "load" 35 | CATEGORY = "Flowty TripoSR" 36 | 37 | def load(self, model, chunk_size): 38 | device = get_torch_device() 39 | 40 | if not torch.cuda.is_available(): 41 | device = "cpu" 42 | 43 | if not self.initialized_model: 44 | print("Loading TripoSR model") 45 | self.initialized_model = TSR.from_pretrained_custom( 46 | weight_path=get_full_path("checkpoints", model), 47 | config_path=path.join(path.dirname(__file__), "config.yaml") 48 | ) 49 | self.initialized_model.renderer.set_chunk_size(chunk_size) 50 | self.initialized_model.to(device) 51 | 52 | return (self.initialized_model,) 53 | 54 | 55 | class TripoSRSampler: 56 | 57 | def __init__(self): 58 | self.initialized_model = None 59 | 60 | @classmethod 61 | def INPUT_TYPES(s): 62 | return { 63 | "required": { 64 | "model": ("TRIPOSR_MODEL",), 65 | "reference_image": ("IMAGE",), 66 | "geometry_resolution": ("INT", {"default": 256, "min": 128, "max": 12288}), 67 | "threshold": ("FLOAT", {"default": 25.0, "min": 0.0, "step": 0.01}), 68 | }, 69 | "optional": { 70 | "reference_mask": ("MASK",) 71 | } 72 | } 73 | 74 | RETURN_TYPES = ("MESH",) 75 | FUNCTION = "sample" 76 | CATEGORY = "Flowty TripoSR" 77 | 78 | def sample(self, model, reference_image, geometry_resolution, threshold, reference_mask=None): 79 | device = get_torch_device() 80 | 81 | if not torch.cuda.is_available(): 82 | device = "cpu" 83 | 84 | image = reference_image[0] 85 | 86 | if reference_mask is not None: 87 | mask = reference_mask[0].unsqueeze(2) 88 | image = torch.cat((image, mask), dim=2).detach().cpu().numpy() 89 | else: 90 | image = image.detach().cpu().numpy() 91 | 92 | image = Image.fromarray(np.clip(255. 
* image, 0, 255).astype(np.uint8)) 93 | if reference_mask is not None: 94 | image = fill_background(image) 95 | image = image.convert('RGB') 96 | scene_codes = model([image], device) 97 | meshes = model.extract_mesh(scene_codes, resolution=geometry_resolution, threshold=threshold) 98 | return ([meshes[0]],) 99 | 100 | 101 | class TripoSRViewer: 102 | @classmethod 103 | def INPUT_TYPES(s): 104 | return { 105 | "required": { 106 | "mesh": ("MESH",) 107 | } 108 | } 109 | 110 | RETURN_TYPES = () 111 | OUTPUT_NODE = True 112 | FUNCTION = "display" 113 | CATEGORY = "Flowty TripoSR" 114 | 115 | def display(self, mesh): 116 | saved = list() 117 | full_output_folder, filename, counter, subfolder, filename_prefix = get_save_image_path("meshsave", 118 | get_output_directory()) 119 | 120 | for (batch_number, single_mesh) in enumerate(mesh): 121 | filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) 122 | file = f"{filename_with_batch_num}_{counter:05}_.obj" 123 | single_mesh.apply_transform(np.array([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]])) 124 | single_mesh.export(path.join(full_output_folder, file)) 125 | saved.append({ 126 | "filename": file, 127 | "type": "output", 128 | "subfolder": subfolder 129 | }) 130 | 131 | return {"ui": {"mesh": saved}} 132 | 133 | 134 | NODE_CLASS_MAPPINGS = { 135 | "TripoSRModelLoader": TripoSRModelLoader, 136 | "TripoSRSampler": TripoSRSampler, 137 | "TripoSRViewer": TripoSRViewer 138 | } 139 | 140 | NODE_DISPLAY_NAME_MAPPINGS = { 141 | "TripoSRModelLoader": "TripoSR Model Loader", 142 | "TripoSRSampler": "TripoSR Sampler", 143 | "TripoSRViewer": "TripoSR Viewer" 144 | } 145 | 146 | WEB_DIRECTORY = "./web" 147 | 148 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS', 'WEB_DIRECTORY'] 149 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | cond_image_size: 512 2 | 3 | image_tokenizer_cls: tsr.models.tokenizers.image.DINOSingleImageTokenizer 4 | image_tokenizer: 5 | pretrained_model_name_or_path: "facebook/dino-vitb16" 6 | 7 | tokenizer_cls: tsr.models.tokenizers.triplane.Triplane1DTokenizer 8 | tokenizer: 9 | plane_size: 32 10 | num_channels: 1024 11 | 12 | backbone_cls: tsr.models.transformer.transformer_1d.Transformer1D 13 | backbone: 14 | in_channels: ${tokenizer.num_channels} 15 | num_attention_heads: 16 16 | attention_head_dim: 64 17 | num_layers: 16 18 | cross_attention_dim: 768 19 | 20 | post_processor_cls: tsr.models.network_utils.TriplaneUpsampleNetwork 21 | post_processor: 22 | in_channels: 1024 23 | out_channels: 40 24 | 25 | decoder_cls: tsr.models.network_utils.NeRFMLP 26 | decoder: 27 | in_channels: 120 # 3 * 40 28 | n_neurons: 64 29 | n_hidden_layers: 9 30 | activation: silu 31 | 32 | renderer_cls: tsr.models.nerf_renderer.TriplaneNeRFRenderer 33 | renderer: 34 | radius: 0.87 # slightly larger than 0.5 * sqrt(3) 35 | feature_reduction: concat 36 | density_activation: exp 37 | density_bias: -1.0 38 | num_samples_per_ray: 128 -------------------------------------------------------------------------------- /logo-dark.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | 9 | 10 | 12 | 14 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 
| 4 | 6 | 7 | 9 | 10 | 12 | 14 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | omegaconf==2.3.0 2 | Pillow==10.1.0 3 | einops==0.7.0 4 | transformers==4.35.0 5 | trimesh==4.0.5 6 | huggingface-hub 7 | imageio[ffmpeg] 8 | scikit-image -------------------------------------------------------------------------------- /tsr/models/isosurface.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, Tuple 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from skimage import measure 7 | 8 | 9 | class IsosurfaceHelper(nn.Module): 10 | points_range: Tuple[float, float] = (0, 1) 11 | 12 | @property 13 | def grid_vertices(self) -> torch.FloatTensor: 14 | raise NotImplementedError 15 | 16 | 17 | class MarchingCubeHelper(IsosurfaceHelper): 18 | def __init__(self, resolution: int) -> None: 19 | super().__init__() 20 | self.resolution = resolution 21 | #self.mc_func: Callable = marching_cubes 22 | self._grid_vertices: Optional[torch.FloatTensor] = None 23 | 24 | @property 25 | def grid_vertices(self) -> torch.FloatTensor: 26 | if self._grid_vertices is None: 27 | # keep the vertices on CPU so that we can support very large resolution 28 | x, y, z = ( 29 | torch.linspace(*self.points_range, self.resolution), 30 | torch.linspace(*self.points_range, self.resolution), 31 | torch.linspace(*self.points_range, self.resolution), 32 | ) 33 | x, y, z = torch.meshgrid(x, y, z, indexing="ij") 34 | verts = torch.cat( 35 | [x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1 36 | ).reshape(-1, 3) 37 | self._grid_vertices = verts 38 | return self._grid_vertices 39 | 40 | def forward( 41 | self, 42 | level: torch.FloatTensor, 43 | ) -> Tuple[torch.FloatTensor, torch.LongTensor]: 44 | level = -level.view(self.resolution, self.resolution, self.resolution) 45 | v_pos, t_pos_idx, _, __ = measure.marching_cubes((level.detach().cpu() if level.is_cuda else level.detach()).numpy(), 0.0) #self.mc_func(level.detach(), 0.0) 46 | v_pos = torch.from_numpy(v_pos.copy()).type(torch.FloatTensor).to(level.device) 47 | t_pos_idx = torch.from_numpy(t_pos_idx.copy()).type(torch.LongTensor).to(level.device) 48 | v_pos = v_pos[..., [0, 1, 2]] 49 | t_pos_idx = t_pos_idx[..., [1, 0, 2]] 50 | v_pos = v_pos / (self.resolution - 1.0) 51 | return v_pos, t_pos_idx 52 | -------------------------------------------------------------------------------- /tsr/models/nerf_renderer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from einops import rearrange, reduce 7 | 8 | from ..utils import ( 9 | BaseModule, 10 | chunk_batch, 11 | get_activation, 12 | rays_intersect_bbox, 13 | scale_tensor, 14 | ) 15 | 16 | 17 | class TriplaneNeRFRenderer(BaseModule): 18 | @dataclass 19 | class Config(BaseModule.Config): 20 | radius: float 21 | 22 | feature_reduction: str = "concat" 23 | density_activation: str = "trunc_exp" 24 | density_bias: float = -1.0 25 | color_activation: str = "sigmoid" 26 | num_samples_per_ray: int = 128 27 | randomized: bool = False 28 | 29 | cfg: Config 30 | 31 | def configure(self) -> None: 32 | assert self.cfg.feature_reduction in ["concat", "mean"] 33 | self.chunk_size = 0 34 | 35 | def 
set_chunk_size(self, chunk_size: int): 36 | assert ( 37 | chunk_size >= 0 38 | ), "chunk_size must be a non-negative integer (0 for no chunking)." 39 | self.chunk_size = chunk_size 40 | 41 | def query_triplane( 42 | self, 43 | decoder: torch.nn.Module, 44 | positions: torch.Tensor, 45 | triplane: torch.Tensor, 46 | ) -> Dict[str, torch.Tensor]: 47 | input_shape = positions.shape[:-1] 48 | positions = positions.view(-1, 3) 49 | 50 | # positions in (-radius, radius) 51 | # normalized to (-1, 1) for grid sample 52 | positions = scale_tensor( 53 | positions, (-self.cfg.radius, self.cfg.radius), (-1, 1) 54 | ) 55 | 56 | def _query_chunk(x): 57 | indices2D: torch.Tensor = torch.stack( 58 | (x[..., [0, 1]], x[..., [0, 2]], x[..., [1, 2]]), 59 | dim=-3, 60 | ) 61 | out: torch.Tensor = F.grid_sample( 62 | rearrange(triplane, "Np Cp Hp Wp -> Np Cp Hp Wp", Np=3), 63 | rearrange(indices2D, "Np N Nd -> Np () N Nd", Np=3), 64 | align_corners=False, 65 | mode="bilinear", 66 | ) 67 | if self.cfg.feature_reduction == "concat": 68 | out = rearrange(out, "Np Cp () N -> N (Np Cp)", Np=3) 69 | elif self.cfg.feature_reduction == "mean": 70 | out = reduce(out, "Np Cp () N -> N Cp", Np=3, reduction="mean") 71 | else: 72 | raise NotImplementedError 73 | 74 | net_out: Dict[str, torch.Tensor] = decoder(out) 75 | return net_out 76 | 77 | if self.chunk_size > 0: 78 | net_out = chunk_batch(_query_chunk, self.chunk_size, positions) 79 | else: 80 | net_out = _query_chunk(positions) 81 | 82 | net_out["density_act"] = get_activation(self.cfg.density_activation)( 83 | net_out["density"] + self.cfg.density_bias 84 | ) 85 | net_out["color"] = get_activation(self.cfg.color_activation)( 86 | net_out["features"] 87 | ) 88 | 89 | net_out = {k: v.view(*input_shape, -1) for k, v in net_out.items()} 90 | 91 | return net_out 92 | 93 | def _forward( 94 | self, 95 | decoder: torch.nn.Module, 96 | triplane: torch.Tensor, 97 | rays_o: torch.Tensor, 98 | rays_d: torch.Tensor, 99 | **kwargs, 100 | ): 101 | rays_shape = rays_o.shape[:-1] 102 | rays_o = rays_o.view(-1, 3) 103 | rays_d = rays_d.view(-1, 3) 104 | n_rays = rays_o.shape[0] 105 | 106 | t_near, t_far, rays_valid = rays_intersect_bbox(rays_o, rays_d, self.cfg.radius) 107 | t_near, t_far = t_near[rays_valid], t_far[rays_valid] 108 | 109 | t_vals = torch.linspace( 110 | 0, 1, self.cfg.num_samples_per_ray + 1, device=triplane.device 111 | ) 112 | t_mid = (t_vals[:-1] + t_vals[1:]) / 2.0 113 | z_vals = t_near * (1 - t_mid[None]) + t_far * t_mid[None] # (N_rays, N_samples) 114 | 115 | xyz = ( 116 | rays_o[:, None, :] + z_vals[..., None] * rays_d[..., None, :] 117 | ) # (N_rays, N_sample, 3) 118 | 119 | mlp_out = self.query_triplane( 120 | decoder=decoder, 121 | positions=xyz, 122 | triplane=triplane, 123 | ) 124 | 125 | eps = 1e-10 126 | # deltas = z_vals[:, 1:] - z_vals[:, :-1] # (N_rays, N_samples) 127 | deltas = t_vals[1:] - t_vals[:-1] # (N_rays, N_samples) 128 | alpha = 1 - torch.exp( 129 | -deltas * mlp_out["density_act"][..., 0] 130 | ) # (N_rays, N_samples) 131 | accum_prod = torch.cat( 132 | [ 133 | torch.ones_like(alpha[:, :1]), 134 | torch.cumprod(1 - alpha[:, :-1] + eps, dim=-1), 135 | ], 136 | dim=-1, 137 | ) 138 | weights = alpha * accum_prod # (N_rays, N_samples) 139 | comp_rgb_ = (weights[..., None] * mlp_out["color"]).sum(dim=-2) # (N_rays, 3) 140 | opacity_ = weights.sum(dim=-1) # (N_rays) 141 | 142 | comp_rgb = torch.zeros( 143 | n_rays, 3, dtype=comp_rgb_.dtype, device=comp_rgb_.device 144 | ) 145 | opacity = torch.zeros(n_rays, dtype=opacity_.dtype, 
device=opacity_.device) 146 | comp_rgb[rays_valid] = comp_rgb_ 147 | opacity[rays_valid] = opacity_ 148 | 149 | comp_rgb += 1 - opacity[..., None] 150 | comp_rgb = comp_rgb.view(*rays_shape, 3) 151 | 152 | return comp_rgb 153 | 154 | def forward( 155 | self, 156 | decoder: torch.nn.Module, 157 | triplane: torch.Tensor, 158 | rays_o: torch.Tensor, 159 | rays_d: torch.Tensor, 160 | ) -> Dict[str, torch.Tensor]: 161 | if triplane.ndim == 4: 162 | comp_rgb = self._forward(decoder, triplane, rays_o, rays_d) 163 | else: 164 | comp_rgb = torch.stack( 165 | [ 166 | self._forward(decoder, triplane[i], rays_o[i], rays_d[i]) 167 | for i in range(triplane.shape[0]) 168 | ], 169 | dim=0, 170 | ) 171 | 172 | return comp_rgb 173 | 174 | def train(self, mode=True): 175 | self.randomized = mode and self.cfg.randomized 176 | return super().train(mode=mode) 177 | 178 | def eval(self): 179 | self.randomized = False 180 | return super().eval() 181 | -------------------------------------------------------------------------------- /tsr/models/network_utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | from einops import rearrange 7 | 8 | from ..utils import BaseModule 9 | 10 | 11 | class TriplaneUpsampleNetwork(BaseModule): 12 | @dataclass 13 | class Config(BaseModule.Config): 14 | in_channels: int 15 | out_channels: int 16 | 17 | cfg: Config 18 | 19 | def configure(self) -> None: 20 | self.upsample = nn.ConvTranspose2d( 21 | self.cfg.in_channels, self.cfg.out_channels, kernel_size=2, stride=2 22 | ) 23 | 24 | def forward(self, triplanes: torch.Tensor) -> torch.Tensor: 25 | triplanes_up = rearrange( 26 | self.upsample( 27 | rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3) 28 | ), 29 | "(B Np) Co Hp Wp -> B Np Co Hp Wp", 30 | Np=3, 31 | ) 32 | return triplanes_up 33 | 34 | 35 | class NeRFMLP(BaseModule): 36 | @dataclass 37 | class Config(BaseModule.Config): 38 | in_channels: int 39 | n_neurons: int 40 | n_hidden_layers: int 41 | activation: str = "relu" 42 | bias: bool = True 43 | weight_init: Optional[str] = "kaiming_uniform" 44 | bias_init: Optional[str] = None 45 | 46 | cfg: Config 47 | 48 | def configure(self) -> None: 49 | layers = [ 50 | self.make_linear( 51 | self.cfg.in_channels, 52 | self.cfg.n_neurons, 53 | bias=self.cfg.bias, 54 | weight_init=self.cfg.weight_init, 55 | bias_init=self.cfg.bias_init, 56 | ), 57 | self.make_activation(self.cfg.activation), 58 | ] 59 | for i in range(self.cfg.n_hidden_layers - 1): 60 | layers += [ 61 | self.make_linear( 62 | self.cfg.n_neurons, 63 | self.cfg.n_neurons, 64 | bias=self.cfg.bias, 65 | weight_init=self.cfg.weight_init, 66 | bias_init=self.cfg.bias_init, 67 | ), 68 | self.make_activation(self.cfg.activation), 69 | ] 70 | layers += [ 71 | self.make_linear( 72 | self.cfg.n_neurons, 73 | 4, # density 1 + features 3 74 | bias=self.cfg.bias, 75 | weight_init=self.cfg.weight_init, 76 | bias_init=self.cfg.bias_init, 77 | ) 78 | ] 79 | self.layers = nn.Sequential(*layers) 80 | 81 | def make_linear( 82 | self, 83 | dim_in, 84 | dim_out, 85 | bias=True, 86 | weight_init=None, 87 | bias_init=None, 88 | ): 89 | layer = nn.Linear(dim_in, dim_out, bias=bias) 90 | 91 | if weight_init is None: 92 | pass 93 | elif weight_init == "kaiming_uniform": 94 | torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity="relu") 95 | else: 96 | raise NotImplementedError 97 | 98 | if bias: 99 | if bias_init is None: 100 
| pass 101 | elif bias_init == "zero": 102 | torch.nn.init.zeros_(layer.bias) 103 | else: 104 | raise NotImplementedError 105 | 106 | return layer 107 | 108 | def make_activation(self, activation): 109 | if activation == "relu": 110 | return nn.ReLU(inplace=True) 111 | elif activation == "silu": 112 | return nn.SiLU(inplace=True) 113 | else: 114 | raise NotImplementedError 115 | 116 | def forward(self, x): 117 | inp_shape = x.shape[:-1] 118 | x = x.reshape(-1, x.shape[-1]) 119 | 120 | features = self.layers(x) 121 | features = features.reshape(*inp_shape, -1) 122 | out = {"density": features[..., 0:1], "features": features[..., 1:4]} 123 | 124 | return out 125 | -------------------------------------------------------------------------------- /tsr/models/tokenizers/image.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn as nn 5 | from einops import rearrange 6 | from huggingface_hub import hf_hub_download 7 | from transformers.models.vit.modeling_vit import ViTModel 8 | 9 | from ...utils import BaseModule 10 | 11 | 12 | class DINOSingleImageTokenizer(BaseModule): 13 | @dataclass 14 | class Config(BaseModule.Config): 15 | pretrained_model_name_or_path: str = "facebook/dino-vitb16" 16 | enable_gradient_checkpointing: bool = False 17 | 18 | cfg: Config 19 | 20 | def configure(self) -> None: 21 | self.model: ViTModel = ViTModel( 22 | ViTModel.config_class.from_pretrained( 23 | hf_hub_download( 24 | repo_id=self.cfg.pretrained_model_name_or_path, 25 | filename="config.json", 26 | ) 27 | ) 28 | ) 29 | 30 | if self.cfg.enable_gradient_checkpointing: 31 | self.model.encoder.gradient_checkpointing = True 32 | 33 | self.register_buffer( 34 | "image_mean", 35 | torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1), 36 | persistent=False, 37 | ) 38 | self.register_buffer( 39 | "image_std", 40 | torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1), 41 | persistent=False, 42 | ) 43 | 44 | def forward(self, images: torch.FloatTensor, **kwargs) -> torch.FloatTensor: 45 | packed = False 46 | if images.ndim == 4: 47 | packed = True 48 | images = images.unsqueeze(1) 49 | 50 | batch_size, n_input_views = images.shape[:2] 51 | images = (images - self.image_mean) / self.image_std 52 | out = self.model( 53 | rearrange(images, "B N C H W -> (B N) C H W"), interpolate_pos_encoding=True 54 | ) 55 | local_features, global_features = out.last_hidden_state, out.pooler_output 56 | local_features = local_features.permute(0, 2, 1) 57 | local_features = rearrange( 58 | local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size 59 | ) 60 | if packed: 61 | local_features = local_features.squeeze(1) 62 | 63 | return local_features 64 | 65 | def detokenize(self, *args, **kwargs): 66 | raise NotImplementedError 67 | -------------------------------------------------------------------------------- /tsr/models/tokenizers/triplane.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | 4 | import torch 5 | import torch.nn as nn 6 | from einops import rearrange, repeat 7 | 8 | from ...utils import BaseModule 9 | 10 | 11 | class Triplane1DTokenizer(BaseModule): 12 | @dataclass 13 | class Config(BaseModule.Config): 14 | plane_size: int 15 | num_channels: int 16 | 17 | cfg: Config 18 | 19 | def configure(self) -> None: 20 | self.embeddings = nn.Parameter( 21 | torch.randn( 22 | (3, self.cfg.num_channels, 
self.cfg.plane_size, self.cfg.plane_size), 23 | dtype=torch.float32, 24 | ) 25 | * 1 26 | / math.sqrt(self.cfg.num_channels) 27 | ) 28 | 29 | def forward(self, batch_size: int) -> torch.Tensor: 30 | return rearrange( 31 | repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size), 32 | "B Np Ct Hp Wp -> B Ct (Np Hp Wp)", 33 | ) 34 | 35 | def detokenize(self, tokens: torch.Tensor) -> torch.Tensor: 36 | batch_size, Ct, Nt = tokens.shape 37 | assert Nt == self.cfg.plane_size**2 * 3 38 | assert Ct == self.cfg.num_channels 39 | return rearrange( 40 | tokens, 41 | "B Ct (Np Hp Wp) -> B Np Ct Hp Wp", 42 | Np=3, 43 | Hp=self.cfg.plane_size, 44 | Wp=self.cfg.plane_size, 45 | ) 46 | -------------------------------------------------------------------------------- /tsr/models/transformer/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # -------- 16 | # 17 | # Modified 2024 by the Tripo AI and Stability AI Team. 18 | # 19 | # Copyright (c) 2024 Tripo AI & Stability AI 20 | # 21 | # Permission is hereby granted, free of charge, to any person obtaining a copy 22 | # of this software and associated documentation files (the "Software"), to deal 23 | # in the Software without restriction, including without limitation the rights 24 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 25 | # copies of the Software, and to permit persons to whom the Software is 26 | # furnished to do so, subject to the following conditions: 27 | # 28 | # The above copyright notice and this permission notice shall be included in all 29 | # copies or substantial portions of the Software. 30 | # 31 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 32 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 33 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 34 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 35 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 36 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 37 | # SOFTWARE. 38 | 39 | from typing import Optional 40 | 41 | import torch 42 | import torch.nn.functional as F 43 | from torch import nn 44 | 45 | 46 | class Attention(nn.Module): 47 | r""" 48 | A cross attention layer. 49 | 50 | Parameters: 51 | query_dim (`int`): 52 | The number of channels in the query. 53 | cross_attention_dim (`int`, *optional*): 54 | The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. 55 | heads (`int`, *optional*, defaults to 8): 56 | The number of heads to use for multi-head attention. 57 | dim_head (`int`, *optional*, defaults to 64): 58 | The number of channels in each head. 
59 | dropout (`float`, *optional*, defaults to 0.0): 60 | The dropout probability to use. 61 | bias (`bool`, *optional*, defaults to False): 62 | Set to `True` for the query, key, and value linear layers to contain a bias parameter. 63 | upcast_attention (`bool`, *optional*, defaults to False): 64 | Set to `True` to upcast the attention computation to `float32`. 65 | upcast_softmax (`bool`, *optional*, defaults to False): 66 | Set to `True` to upcast the softmax computation to `float32`. 67 | cross_attention_norm (`str`, *optional*, defaults to `None`): 68 | The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. 69 | cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): 70 | The number of groups to use for the group norm in the cross attention. 71 | added_kv_proj_dim (`int`, *optional*, defaults to `None`): 72 | The number of channels to use for the added key and value projections. If `None`, no projection is used. 73 | norm_num_groups (`int`, *optional*, defaults to `None`): 74 | The number of groups to use for the group norm in the attention. 75 | spatial_norm_dim (`int`, *optional*, defaults to `None`): 76 | The number of channels to use for the spatial normalization. 77 | out_bias (`bool`, *optional*, defaults to `True`): 78 | Set to `True` to use a bias in the output linear layer. 79 | scale_qk (`bool`, *optional*, defaults to `True`): 80 | Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. 81 | only_cross_attention (`bool`, *optional*, defaults to `False`): 82 | Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if 83 | `added_kv_proj_dim` is not `None`. 84 | eps (`float`, *optional*, defaults to 1e-5): 85 | An additional value added to the denominator in group normalization that is used for numerical stability. 86 | rescale_output_factor (`float`, *optional*, defaults to 1.0): 87 | A factor to rescale the output by dividing it with this value. 88 | residual_connection (`bool`, *optional*, defaults to `False`): 89 | Set to `True` to add the residual connection to the output. 90 | _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): 91 | Set to `True` if the attention block is loaded from a deprecated state dict. 92 | processor (`AttnProcessor`, *optional*, defaults to `None`): 93 | The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and 94 | `AttnProcessor` otherwise. 
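        Example (a hedged, minimal usage sketch; the dimensions below are
        illustrative placeholders, not values taken from this repository's configs):

            attn = Attention(query_dim=768, heads=16, dim_head=48)
            x = torch.randn(2, 1024, 768)              # (batch, seq_len, query_dim)
            ctx = torch.randn(2, 257, 768)             # optional encoder hidden states
            out = attn(x, encoder_hidden_states=ctx)   # -> (2, 1024, 768)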
95 | """ 96 | 97 | def __init__( 98 | self, 99 | query_dim: int, 100 | cross_attention_dim: Optional[int] = None, 101 | heads: int = 8, 102 | dim_head: int = 64, 103 | dropout: float = 0.0, 104 | bias: bool = False, 105 | upcast_attention: bool = False, 106 | upcast_softmax: bool = False, 107 | cross_attention_norm: Optional[str] = None, 108 | cross_attention_norm_num_groups: int = 32, 109 | added_kv_proj_dim: Optional[int] = None, 110 | norm_num_groups: Optional[int] = None, 111 | out_bias: bool = True, 112 | scale_qk: bool = True, 113 | only_cross_attention: bool = False, 114 | eps: float = 1e-5, 115 | rescale_output_factor: float = 1.0, 116 | residual_connection: bool = False, 117 | _from_deprecated_attn_block: bool = False, 118 | processor: Optional["AttnProcessor"] = None, 119 | out_dim: int = None, 120 | ): 121 | super().__init__() 122 | self.inner_dim = out_dim if out_dim is not None else dim_head * heads 123 | self.query_dim = query_dim 124 | self.cross_attention_dim = ( 125 | cross_attention_dim if cross_attention_dim is not None else query_dim 126 | ) 127 | self.upcast_attention = upcast_attention 128 | self.upcast_softmax = upcast_softmax 129 | self.rescale_output_factor = rescale_output_factor 130 | self.residual_connection = residual_connection 131 | self.dropout = dropout 132 | self.fused_projections = False 133 | self.out_dim = out_dim if out_dim is not None else query_dim 134 | 135 | # we make use of this private variable to know whether this class is loaded 136 | # with an deprecated state dict so that we can convert it on the fly 137 | self._from_deprecated_attn_block = _from_deprecated_attn_block 138 | 139 | self.scale_qk = scale_qk 140 | self.scale = dim_head**-0.5 if self.scale_qk else 1.0 141 | 142 | self.heads = out_dim // dim_head if out_dim is not None else heads 143 | # for slice_size > 0 the attention score computation 144 | # is split across the batch axis to save memory 145 | # You can set slice_size with `set_attention_slice` 146 | self.sliceable_head_dim = heads 147 | 148 | self.added_kv_proj_dim = added_kv_proj_dim 149 | self.only_cross_attention = only_cross_attention 150 | 151 | if self.added_kv_proj_dim is None and self.only_cross_attention: 152 | raise ValueError( 153 | "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." 154 | ) 155 | 156 | if norm_num_groups is not None: 157 | self.group_norm = nn.GroupNorm( 158 | num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True 159 | ) 160 | else: 161 | self.group_norm = None 162 | 163 | self.spatial_norm = None 164 | 165 | if cross_attention_norm is None: 166 | self.norm_cross = None 167 | elif cross_attention_norm == "layer_norm": 168 | self.norm_cross = nn.LayerNorm(self.cross_attention_dim) 169 | elif cross_attention_norm == "group_norm": 170 | if self.added_kv_proj_dim is not None: 171 | # The given `encoder_hidden_states` are initially of shape 172 | # (batch_size, seq_len, added_kv_proj_dim) before being projected 173 | # to (batch_size, seq_len, cross_attention_dim). The norm is applied 174 | # before the projection, so we need to use `added_kv_proj_dim` as 175 | # the number of channels for the group norm. 
176 | norm_cross_num_channels = added_kv_proj_dim 177 | else: 178 | norm_cross_num_channels = self.cross_attention_dim 179 | 180 | self.norm_cross = nn.GroupNorm( 181 | num_channels=norm_cross_num_channels, 182 | num_groups=cross_attention_norm_num_groups, 183 | eps=1e-5, 184 | affine=True, 185 | ) 186 | else: 187 | raise ValueError( 188 | f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" 189 | ) 190 | 191 | linear_cls = nn.Linear 192 | 193 | self.linear_cls = linear_cls 194 | self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias) 195 | 196 | if not self.only_cross_attention: 197 | # only relevant for the `AddedKVProcessor` classes 198 | self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) 199 | self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) 200 | else: 201 | self.to_k = None 202 | self.to_v = None 203 | 204 | if self.added_kv_proj_dim is not None: 205 | self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim) 206 | self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim) 207 | 208 | self.to_out = nn.ModuleList([]) 209 | self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias)) 210 | self.to_out.append(nn.Dropout(dropout)) 211 | 212 | # set attention processor 213 | # We use the AttnProcessor2_0 by default when torch 2.x is used which uses 214 | # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention 215 | # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 216 | if processor is None: 217 | processor = ( 218 | AttnProcessor2_0() 219 | if hasattr(F, "scaled_dot_product_attention") and self.scale_qk 220 | else AttnProcessor() 221 | ) 222 | self.set_processor(processor) 223 | 224 | def set_processor(self, processor: "AttnProcessor") -> None: 225 | self.processor = processor 226 | 227 | def forward( 228 | self, 229 | hidden_states: torch.FloatTensor, 230 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 231 | attention_mask: Optional[torch.FloatTensor] = None, 232 | **cross_attention_kwargs, 233 | ) -> torch.Tensor: 234 | r""" 235 | The forward method of the `Attention` class. 236 | 237 | Args: 238 | hidden_states (`torch.Tensor`): 239 | The hidden states of the query. 240 | encoder_hidden_states (`torch.Tensor`, *optional*): 241 | The hidden states of the encoder. 242 | attention_mask (`torch.Tensor`, *optional*): 243 | The attention mask to use. If `None`, no mask is applied. 244 | **cross_attention_kwargs: 245 | Additional keyword arguments to pass along to the cross attention. 246 | 247 | Returns: 248 | `torch.Tensor`: The output of the attention layer. 249 | """ 250 | # The `Attention` class can call different attention processors / attention functions 251 | # here we simply pass along all tensors to the selected processor class 252 | # For standard processors that are defined here, `**cross_attention_kwargs` is empty 253 | return self.processor( 254 | self, 255 | hidden_states, 256 | encoder_hidden_states=encoder_hidden_states, 257 | attention_mask=attention_mask, 258 | **cross_attention_kwargs, 259 | ) 260 | 261 | def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: 262 | r""" 263 | Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` 264 | is the number of heads initialized while constructing the `Attention` class. 
265 | 266 | Args: 267 | tensor (`torch.Tensor`): The tensor to reshape. 268 | 269 | Returns: 270 | `torch.Tensor`: The reshaped tensor. 271 | """ 272 | head_size = self.heads 273 | batch_size, seq_len, dim = tensor.shape 274 | tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) 275 | tensor = tensor.permute(0, 2, 1, 3).reshape( 276 | batch_size // head_size, seq_len, dim * head_size 277 | ) 278 | return tensor 279 | 280 | def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: 281 | r""" 282 | Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is 283 | the number of heads initialized while constructing the `Attention` class. 284 | 285 | Args: 286 | tensor (`torch.Tensor`): The tensor to reshape. 287 | out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is 288 | reshaped to `[batch_size * heads, seq_len, dim // heads]`. 289 | 290 | Returns: 291 | `torch.Tensor`: The reshaped tensor. 292 | """ 293 | head_size = self.heads 294 | batch_size, seq_len, dim = tensor.shape 295 | tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) 296 | tensor = tensor.permute(0, 2, 1, 3) 297 | 298 | if out_dim == 3: 299 | tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) 300 | 301 | return tensor 302 | 303 | def get_attention_scores( 304 | self, 305 | query: torch.Tensor, 306 | key: torch.Tensor, 307 | attention_mask: torch.Tensor = None, 308 | ) -> torch.Tensor: 309 | r""" 310 | Compute the attention scores. 311 | 312 | Args: 313 | query (`torch.Tensor`): The query tensor. 314 | key (`torch.Tensor`): The key tensor. 315 | attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. 316 | 317 | Returns: 318 | `torch.Tensor`: The attention probabilities/scores. 319 | """ 320 | dtype = query.dtype 321 | if self.upcast_attention: 322 | query = query.float() 323 | key = key.float() 324 | 325 | if attention_mask is None: 326 | baddbmm_input = torch.empty( 327 | query.shape[0], 328 | query.shape[1], 329 | key.shape[1], 330 | dtype=query.dtype, 331 | device=query.device, 332 | ) 333 | beta = 0 334 | else: 335 | baddbmm_input = attention_mask 336 | beta = 1 337 | 338 | attention_scores = torch.baddbmm( 339 | baddbmm_input, 340 | query, 341 | key.transpose(-1, -2), 342 | beta=beta, 343 | alpha=self.scale, 344 | ) 345 | del baddbmm_input 346 | 347 | if self.upcast_softmax: 348 | attention_scores = attention_scores.float() 349 | 350 | attention_probs = attention_scores.softmax(dim=-1) 351 | del attention_scores 352 | 353 | attention_probs = attention_probs.to(dtype) 354 | 355 | return attention_probs 356 | 357 | def prepare_attention_mask( 358 | self, 359 | attention_mask: torch.Tensor, 360 | target_length: int, 361 | batch_size: int, 362 | out_dim: int = 3, 363 | ) -> torch.Tensor: 364 | r""" 365 | Prepare the attention mask for the attention computation. 366 | 367 | Args: 368 | attention_mask (`torch.Tensor`): 369 | The attention mask to prepare. 370 | target_length (`int`): 371 | The target length of the attention mask. This is the length of the attention mask after padding. 372 | batch_size (`int`): 373 | The batch size, which is used to repeat the attention mask. 374 | out_dim (`int`, *optional*, defaults to `3`): 375 | The output dimension of the attention mask. Can be either `3` or `4`. 376 | 377 | Returns: 378 | `torch.Tensor`: The prepared attention mask. 
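        Example (illustrative only; the shapes and values are hypothetical):

            # mask uses 1 = keep, 0 = discard, shape (batch, key_tokens)
            mask = torch.ones(2, 77)
            prepared = attn.prepare_attention_mask(mask, target_length=77, batch_size=2)
            # with the default out_dim=3 the mask is repeated once per head:
            # prepared.shape == (2 * attn.heads, 77)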
379 | """ 380 | head_size = self.heads 381 | if attention_mask is None: 382 | return attention_mask 383 | 384 | current_length: int = attention_mask.shape[-1] 385 | if current_length != target_length: 386 | if attention_mask.device.type == "mps": 387 | # HACK: MPS: Does not support padding by greater than dimension of input tensor. 388 | # Instead, we can manually construct the padding tensor. 389 | padding_shape = ( 390 | attention_mask.shape[0], 391 | attention_mask.shape[1], 392 | target_length, 393 | ) 394 | padding = torch.zeros( 395 | padding_shape, 396 | dtype=attention_mask.dtype, 397 | device=attention_mask.device, 398 | ) 399 | attention_mask = torch.cat([attention_mask, padding], dim=2) 400 | else: 401 | # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: 402 | # we want to instead pad by (0, remaining_length), where remaining_length is: 403 | # remaining_length: int = target_length - current_length 404 | # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding 405 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 406 | 407 | if out_dim == 3: 408 | if attention_mask.shape[0] < batch_size * head_size: 409 | attention_mask = attention_mask.repeat_interleave(head_size, dim=0) 410 | elif out_dim == 4: 411 | attention_mask = attention_mask.unsqueeze(1) 412 | attention_mask = attention_mask.repeat_interleave(head_size, dim=1) 413 | 414 | return attention_mask 415 | 416 | def norm_encoder_hidden_states( 417 | self, encoder_hidden_states: torch.Tensor 418 | ) -> torch.Tensor: 419 | r""" 420 | Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the 421 | `Attention` class. 422 | 423 | Args: 424 | encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. 425 | 426 | Returns: 427 | `torch.Tensor`: The normalized encoder hidden states. 428 | """ 429 | assert ( 430 | self.norm_cross is not None 431 | ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states" 432 | 433 | if isinstance(self.norm_cross, nn.LayerNorm): 434 | encoder_hidden_states = self.norm_cross(encoder_hidden_states) 435 | elif isinstance(self.norm_cross, nn.GroupNorm): 436 | # Group norm norms along the channels dimension and expects 437 | # input to be in the shape of (N, C, *). In this case, we want 438 | # to norm along the hidden dimension, so we need to move 439 | # (batch_size, sequence_length, hidden_size) -> 440 | # (batch_size, hidden_size, sequence_length) 441 | encoder_hidden_states = encoder_hidden_states.transpose(1, 2) 442 | encoder_hidden_states = self.norm_cross(encoder_hidden_states) 443 | encoder_hidden_states = encoder_hidden_states.transpose(1, 2) 444 | else: 445 | assert False 446 | 447 | return encoder_hidden_states 448 | 449 | @torch.no_grad() 450 | def fuse_projections(self, fuse=True): 451 | is_cross_attention = self.cross_attention_dim != self.query_dim 452 | device = self.to_q.weight.data.device 453 | dtype = self.to_q.weight.data.dtype 454 | 455 | if not is_cross_attention: 456 | # fetch weight matrices. 457 | concatenated_weights = torch.cat( 458 | [self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data] 459 | ) 460 | in_features = concatenated_weights.shape[1] 461 | out_features = concatenated_weights.shape[0] 462 | 463 | # create a new single projection layer and copy over the weights. 
464 | self.to_qkv = self.linear_cls( 465 | in_features, out_features, bias=False, device=device, dtype=dtype 466 | ) 467 | self.to_qkv.weight.copy_(concatenated_weights) 468 | 469 | else: 470 | concatenated_weights = torch.cat( 471 | [self.to_k.weight.data, self.to_v.weight.data] 472 | ) 473 | in_features = concatenated_weights.shape[1] 474 | out_features = concatenated_weights.shape[0] 475 | 476 | self.to_kv = self.linear_cls( 477 | in_features, out_features, bias=False, device=device, dtype=dtype 478 | ) 479 | self.to_kv.weight.copy_(concatenated_weights) 480 | 481 | self.fused_projections = fuse 482 | 483 | 484 | class AttnProcessor: 485 | r""" 486 | Default processor for performing attention-related computations. 487 | """ 488 | 489 | def __call__( 490 | self, 491 | attn: Attention, 492 | hidden_states: torch.FloatTensor, 493 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 494 | attention_mask: Optional[torch.FloatTensor] = None, 495 | ) -> torch.Tensor: 496 | residual = hidden_states 497 | 498 | input_ndim = hidden_states.ndim 499 | 500 | if input_ndim == 4: 501 | batch_size, channel, height, width = hidden_states.shape 502 | hidden_states = hidden_states.view( 503 | batch_size, channel, height * width 504 | ).transpose(1, 2) 505 | 506 | batch_size, sequence_length, _ = ( 507 | hidden_states.shape 508 | if encoder_hidden_states is None 509 | else encoder_hidden_states.shape 510 | ) 511 | attention_mask = attn.prepare_attention_mask( 512 | attention_mask, sequence_length, batch_size 513 | ) 514 | 515 | if attn.group_norm is not None: 516 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( 517 | 1, 2 518 | ) 519 | 520 | query = attn.to_q(hidden_states) 521 | 522 | if encoder_hidden_states is None: 523 | encoder_hidden_states = hidden_states 524 | elif attn.norm_cross: 525 | encoder_hidden_states = attn.norm_encoder_hidden_states( 526 | encoder_hidden_states 527 | ) 528 | 529 | key = attn.to_k(encoder_hidden_states) 530 | value = attn.to_v(encoder_hidden_states) 531 | 532 | query = attn.head_to_batch_dim(query) 533 | key = attn.head_to_batch_dim(key) 534 | value = attn.head_to_batch_dim(value) 535 | 536 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 537 | hidden_states = torch.bmm(attention_probs, value) 538 | hidden_states = attn.batch_to_head_dim(hidden_states) 539 | 540 | # linear proj 541 | hidden_states = attn.to_out[0](hidden_states) 542 | # dropout 543 | hidden_states = attn.to_out[1](hidden_states) 544 | 545 | if input_ndim == 4: 546 | hidden_states = hidden_states.transpose(-1, -2).reshape( 547 | batch_size, channel, height, width 548 | ) 549 | 550 | if attn.residual_connection: 551 | hidden_states = hidden_states + residual 552 | 553 | hidden_states = hidden_states / attn.rescale_output_factor 554 | 555 | return hidden_states 556 | 557 | 558 | class AttnProcessor2_0: 559 | r""" 560 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 561 | """ 562 | 563 | def __init__(self): 564 | if not hasattr(F, "scaled_dot_product_attention"): 565 | raise ImportError( 566 | "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
567 | ) 568 | 569 | def __call__( 570 | self, 571 | attn: Attention, 572 | hidden_states: torch.FloatTensor, 573 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 574 | attention_mask: Optional[torch.FloatTensor] = None, 575 | ) -> torch.FloatTensor: 576 | residual = hidden_states 577 | 578 | input_ndim = hidden_states.ndim 579 | 580 | if input_ndim == 4: 581 | batch_size, channel, height, width = hidden_states.shape 582 | hidden_states = hidden_states.view( 583 | batch_size, channel, height * width 584 | ).transpose(1, 2) 585 | 586 | batch_size, sequence_length, _ = ( 587 | hidden_states.shape 588 | if encoder_hidden_states is None 589 | else encoder_hidden_states.shape 590 | ) 591 | 592 | if attention_mask is not None: 593 | attention_mask = attn.prepare_attention_mask( 594 | attention_mask, sequence_length, batch_size 595 | ) 596 | # scaled_dot_product_attention expects attention_mask shape to be 597 | # (batch, heads, source_length, target_length) 598 | attention_mask = attention_mask.view( 599 | batch_size, attn.heads, -1, attention_mask.shape[-1] 600 | ) 601 | 602 | if attn.group_norm is not None: 603 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( 604 | 1, 2 605 | ) 606 | 607 | query = attn.to_q(hidden_states) 608 | 609 | if encoder_hidden_states is None: 610 | encoder_hidden_states = hidden_states 611 | elif attn.norm_cross: 612 | encoder_hidden_states = attn.norm_encoder_hidden_states( 613 | encoder_hidden_states 614 | ) 615 | 616 | key = attn.to_k(encoder_hidden_states) 617 | value = attn.to_v(encoder_hidden_states) 618 | 619 | inner_dim = key.shape[-1] 620 | head_dim = inner_dim // attn.heads 621 | 622 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 623 | 624 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 625 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 626 | 627 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 628 | # TODO: add support for attn.scale when we move to Torch 2.1 629 | hidden_states = F.scaled_dot_product_attention( 630 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 631 | ) 632 | 633 | hidden_states = hidden_states.transpose(1, 2).reshape( 634 | batch_size, -1, attn.heads * head_dim 635 | ) 636 | hidden_states = hidden_states.to(query.dtype) 637 | 638 | # linear proj 639 | hidden_states = attn.to_out[0](hidden_states) 640 | # dropout 641 | hidden_states = attn.to_out[1](hidden_states) 642 | 643 | if input_ndim == 4: 644 | hidden_states = hidden_states.transpose(-1, -2).reshape( 645 | batch_size, channel, height, width 646 | ) 647 | 648 | if attn.residual_connection: 649 | hidden_states = hidden_states + residual 650 | 651 | hidden_states = hidden_states / attn.rescale_output_factor 652 | 653 | return hidden_states 654 | -------------------------------------------------------------------------------- /tsr/models/transformer/basic_transformer_block.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # -------- 16 | # 17 | # Modified 2024 by the Tripo AI and Stability AI Team. 18 | # 19 | # Copyright (c) 2024 Tripo AI & Stability AI 20 | # 21 | # Permission is hereby granted, free of charge, to any person obtaining a copy 22 | # of this software and associated documentation files (the "Software"), to deal 23 | # in the Software without restriction, including without limitation the rights 24 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 25 | # copies of the Software, and to permit persons to whom the Software is 26 | # furnished to do so, subject to the following conditions: 27 | # 28 | # The above copyright notice and this permission notice shall be included in all 29 | # copies or substantial portions of the Software. 30 | # 31 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 32 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 33 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 34 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 35 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 36 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 37 | # SOFTWARE. 38 | 39 | from typing import Optional 40 | 41 | import torch 42 | import torch.nn.functional as F 43 | from torch import nn 44 | 45 | from .attention import Attention 46 | 47 | 48 | class BasicTransformerBlock(nn.Module): 49 | r""" 50 | A basic Transformer block. 51 | 52 | Parameters: 53 | dim (`int`): The number of channels in the input and output. 54 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 55 | attention_head_dim (`int`): The number of channels in each head. 56 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 57 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 58 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 59 | attention_bias (: 60 | obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 61 | only_cross_attention (`bool`, *optional*): 62 | Whether to use only cross-attention layers. In this case two cross attention layers are used. 63 | double_self_attention (`bool`, *optional*): 64 | Whether to use two self-attention layers. In this case no cross attention layers are used. 65 | upcast_attention (`bool`, *optional*): 66 | Whether to upcast the attention computation to float32. This is useful for mixed precision training. 67 | norm_elementwise_affine (`bool`, *optional*, defaults to `True`): 68 | Whether to use learnable elementwise affine parameters for normalization. 69 | norm_type (`str`, *optional*, defaults to `"layer_norm"`): 70 | The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. 
71 | final_dropout (`bool` *optional*, defaults to False): 72 | Whether to apply a final dropout after the last feed-forward layer. 73 | """ 74 | 75 | def __init__( 76 | self, 77 | dim: int, 78 | num_attention_heads: int, 79 | attention_head_dim: int, 80 | dropout=0.0, 81 | cross_attention_dim: Optional[int] = None, 82 | activation_fn: str = "geglu", 83 | attention_bias: bool = False, 84 | only_cross_attention: bool = False, 85 | double_self_attention: bool = False, 86 | upcast_attention: bool = False, 87 | norm_elementwise_affine: bool = True, 88 | norm_type: str = "layer_norm", 89 | final_dropout: bool = False, 90 | ): 91 | super().__init__() 92 | self.only_cross_attention = only_cross_attention 93 | 94 | assert norm_type == "layer_norm" 95 | 96 | # Define 3 blocks. Each block has its own normalization layer. 97 | # 1. Self-Attn 98 | self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 99 | self.attn1 = Attention( 100 | query_dim=dim, 101 | heads=num_attention_heads, 102 | dim_head=attention_head_dim, 103 | dropout=dropout, 104 | bias=attention_bias, 105 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 106 | upcast_attention=upcast_attention, 107 | ) 108 | 109 | # 2. Cross-Attn 110 | if cross_attention_dim is not None or double_self_attention: 111 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. 112 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during 113 | # the second cross attention block. 114 | self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 115 | 116 | self.attn2 = Attention( 117 | query_dim=dim, 118 | cross_attention_dim=( 119 | cross_attention_dim if not double_self_attention else None 120 | ), 121 | heads=num_attention_heads, 122 | dim_head=attention_head_dim, 123 | dropout=dropout, 124 | bias=attention_bias, 125 | upcast_attention=upcast_attention, 126 | ) # is self-attn if encoder_hidden_states is none 127 | else: 128 | self.norm2 = None 129 | self.attn2 = None 130 | 131 | # 3. Feed-forward 132 | self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 133 | self.ff = FeedForward( 134 | dim, 135 | dropout=dropout, 136 | activation_fn=activation_fn, 137 | final_dropout=final_dropout, 138 | ) 139 | 140 | # let chunk size default to None 141 | self._chunk_size = None 142 | self._chunk_dim = 0 143 | 144 | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): 145 | # Sets chunk feed-forward 146 | self._chunk_size = chunk_size 147 | self._chunk_dim = dim 148 | 149 | def forward( 150 | self, 151 | hidden_states: torch.FloatTensor, 152 | attention_mask: Optional[torch.FloatTensor] = None, 153 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 154 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 155 | ) -> torch.FloatTensor: 156 | # Notice that normalization is always applied before the real computation in the following blocks. 157 | # 0. Self-Attention 158 | norm_hidden_states = self.norm1(hidden_states) 159 | 160 | attn_output = self.attn1( 161 | norm_hidden_states, 162 | encoder_hidden_states=( 163 | encoder_hidden_states if self.only_cross_attention else None 164 | ), 165 | attention_mask=attention_mask, 166 | ) 167 | 168 | hidden_states = attn_output + hidden_states 169 | 170 | # 3. 
Cross-Attention 171 | if self.attn2 is not None: 172 | norm_hidden_states = self.norm2(hidden_states) 173 | 174 | attn_output = self.attn2( 175 | norm_hidden_states, 176 | encoder_hidden_states=encoder_hidden_states, 177 | attention_mask=encoder_attention_mask, 178 | ) 179 | hidden_states = attn_output + hidden_states 180 | 181 | # 4. Feed-forward 182 | norm_hidden_states = self.norm3(hidden_states) 183 | 184 | if self._chunk_size is not None: 185 | # "feed_forward_chunk_size" can be used to save memory 186 | if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: 187 | raise ValueError( 188 | f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." 189 | ) 190 | 191 | num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size 192 | ff_output = torch.cat( 193 | [ 194 | self.ff(hid_slice) 195 | for hid_slice in norm_hidden_states.chunk( 196 | num_chunks, dim=self._chunk_dim 197 | ) 198 | ], 199 | dim=self._chunk_dim, 200 | ) 201 | else: 202 | ff_output = self.ff(norm_hidden_states) 203 | 204 | hidden_states = ff_output + hidden_states 205 | 206 | return hidden_states 207 | 208 | 209 | class FeedForward(nn.Module): 210 | r""" 211 | A feed-forward layer. 212 | 213 | Parameters: 214 | dim (`int`): The number of channels in the input. 215 | dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. 216 | mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. 217 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 218 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 219 | final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. 220 | """ 221 | 222 | def __init__( 223 | self, 224 | dim: int, 225 | dim_out: Optional[int] = None, 226 | mult: int = 4, 227 | dropout: float = 0.0, 228 | activation_fn: str = "geglu", 229 | final_dropout: bool = False, 230 | ): 231 | super().__init__() 232 | inner_dim = int(dim * mult) 233 | dim_out = dim_out if dim_out is not None else dim 234 | linear_cls = nn.Linear 235 | 236 | if activation_fn == "gelu": 237 | act_fn = GELU(dim, inner_dim) 238 | if activation_fn == "gelu-approximate": 239 | act_fn = GELU(dim, inner_dim, approximate="tanh") 240 | elif activation_fn == "geglu": 241 | act_fn = GEGLU(dim, inner_dim) 242 | elif activation_fn == "geglu-approximate": 243 | act_fn = ApproximateGELU(dim, inner_dim) 244 | 245 | self.net = nn.ModuleList([]) 246 | # project in 247 | self.net.append(act_fn) 248 | # project dropout 249 | self.net.append(nn.Dropout(dropout)) 250 | # project out 251 | self.net.append(linear_cls(inner_dim, dim_out)) 252 | # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout 253 | if final_dropout: 254 | self.net.append(nn.Dropout(dropout)) 255 | 256 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 257 | for module in self.net: 258 | hidden_states = module(hidden_states) 259 | return hidden_states 260 | 261 | 262 | class GELU(nn.Module): 263 | r""" 264 | GELU activation function with tanh approximation support with `approximate="tanh"`. 265 | 266 | Parameters: 267 | dim_in (`int`): The number of channels in the input. 268 | dim_out (`int`): The number of channels in the output. 
269 | approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. 270 | """ 271 | 272 | def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"): 273 | super().__init__() 274 | self.proj = nn.Linear(dim_in, dim_out) 275 | self.approximate = approximate 276 | 277 | def gelu(self, gate: torch.Tensor) -> torch.Tensor: 278 | if gate.device.type != "mps": 279 | return F.gelu(gate, approximate=self.approximate) 280 | # mps: gelu is not implemented for float16 281 | return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to( 282 | dtype=gate.dtype 283 | ) 284 | 285 | def forward(self, hidden_states): 286 | hidden_states = self.proj(hidden_states) 287 | hidden_states = self.gelu(hidden_states) 288 | return hidden_states 289 | 290 | 291 | class GEGLU(nn.Module): 292 | r""" 293 | A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. 294 | 295 | Parameters: 296 | dim_in (`int`): The number of channels in the input. 297 | dim_out (`int`): The number of channels in the output. 298 | """ 299 | 300 | def __init__(self, dim_in: int, dim_out: int): 301 | super().__init__() 302 | linear_cls = nn.Linear 303 | 304 | self.proj = linear_cls(dim_in, dim_out * 2) 305 | 306 | def gelu(self, gate: torch.Tensor) -> torch.Tensor: 307 | if gate.device.type != "mps": 308 | return F.gelu(gate) 309 | # mps: gelu is not implemented for float16 310 | return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) 311 | 312 | def forward(self, hidden_states, scale: float = 1.0): 313 | args = () 314 | hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1) 315 | return hidden_states * self.gelu(gate) 316 | 317 | 318 | class ApproximateGELU(nn.Module): 319 | r""" 320 | The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2: 321 | https://arxiv.org/abs/1606.08415. 322 | 323 | Parameters: 324 | dim_in (`int`): The number of channels in the input. 325 | dim_out (`int`): The number of channels in the output. 326 | """ 327 | 328 | def __init__(self, dim_in: int, dim_out: int): 329 | super().__init__() 330 | self.proj = nn.Linear(dim_in, dim_out) 331 | 332 | def forward(self, x: torch.Tensor) -> torch.Tensor: 333 | x = self.proj(x) 334 | return x * torch.sigmoid(1.702 * x) 335 | -------------------------------------------------------------------------------- /tsr/models/transformer/transformer_1d.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # -------- 16 | # 17 | # Modified 2024 by the Tripo AI and Stability AI Team. 
18 | # 19 | # Copyright (c) 2024 Tripo AI & Stability AI 20 | # 21 | # Permission is hereby granted, free of charge, to any person obtaining a copy 22 | # of this software and associated documentation files (the "Software"), to deal 23 | # in the Software without restriction, including without limitation the rights 24 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 25 | # copies of the Software, and to permit persons to whom the Software is 26 | # furnished to do so, subject to the following conditions: 27 | # 28 | # The above copyright notice and this permission notice shall be included in all 29 | # copies or substantial portions of the Software. 30 | # 31 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 32 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 33 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 34 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 35 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 36 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 37 | # SOFTWARE. 38 | 39 | from dataclasses import dataclass 40 | from typing import Optional 41 | 42 | import torch 43 | import torch.nn.functional as F 44 | from torch import nn 45 | 46 | from ...utils import BaseModule 47 | from .basic_transformer_block import BasicTransformerBlock 48 | 49 | 50 | class Transformer1D(BaseModule): 51 | @dataclass 52 | class Config(BaseModule.Config): 53 | num_attention_heads: int = 16 54 | attention_head_dim: int = 88 55 | in_channels: Optional[int] = None 56 | out_channels: Optional[int] = None 57 | num_layers: int = 1 58 | dropout: float = 0.0 59 | norm_num_groups: int = 32 60 | cross_attention_dim: Optional[int] = None 61 | attention_bias: bool = False 62 | activation_fn: str = "geglu" 63 | only_cross_attention: bool = False 64 | double_self_attention: bool = False 65 | upcast_attention: bool = False 66 | norm_type: str = "layer_norm" 67 | norm_elementwise_affine: bool = True 68 | gradient_checkpointing: bool = False 69 | 70 | cfg: Config 71 | 72 | def configure(self) -> None: 73 | self.num_attention_heads = self.cfg.num_attention_heads 74 | self.attention_head_dim = self.cfg.attention_head_dim 75 | inner_dim = self.num_attention_heads * self.attention_head_dim 76 | 77 | linear_cls = nn.Linear 78 | 79 | # 2. Define input layers 80 | self.in_channels = self.cfg.in_channels 81 | 82 | self.norm = torch.nn.GroupNorm( 83 | num_groups=self.cfg.norm_num_groups, 84 | num_channels=self.cfg.in_channels, 85 | eps=1e-6, 86 | affine=True, 87 | ) 88 | self.proj_in = linear_cls(self.cfg.in_channels, inner_dim) 89 | 90 | # 3. Define transformers blocks 91 | self.transformer_blocks = nn.ModuleList( 92 | [ 93 | BasicTransformerBlock( 94 | inner_dim, 95 | self.num_attention_heads, 96 | self.attention_head_dim, 97 | dropout=self.cfg.dropout, 98 | cross_attention_dim=self.cfg.cross_attention_dim, 99 | activation_fn=self.cfg.activation_fn, 100 | attention_bias=self.cfg.attention_bias, 101 | only_cross_attention=self.cfg.only_cross_attention, 102 | double_self_attention=self.cfg.double_self_attention, 103 | upcast_attention=self.cfg.upcast_attention, 104 | norm_type=self.cfg.norm_type, 105 | norm_elementwise_affine=self.cfg.norm_elementwise_affine, 106 | ) 107 | for d in range(self.cfg.num_layers) 108 | ] 109 | ) 110 | 111 | # 4. 
Define output layers 112 | self.out_channels = ( 113 | self.cfg.in_channels 114 | if self.cfg.out_channels is None 115 | else self.cfg.out_channels 116 | ) 117 | 118 | self.proj_out = linear_cls(inner_dim, self.cfg.in_channels) 119 | 120 | self.gradient_checkpointing = self.cfg.gradient_checkpointing 121 | 122 | def forward( 123 | self, 124 | hidden_states: torch.Tensor, 125 | encoder_hidden_states: Optional[torch.Tensor] = None, 126 | attention_mask: Optional[torch.Tensor] = None, 127 | encoder_attention_mask: Optional[torch.Tensor] = None, 128 | ): 129 | """ 130 | The [`Transformer1DModel`] forward method. 131 | 132 | Args: 133 | hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): 134 | Input `hidden_states`. 135 | encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): 136 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to 137 | self-attention. 138 | attention_mask ( `torch.Tensor`, *optional*): 139 | An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask 140 | is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large 141 | negative values to the attention scores corresponding to "discard" tokens. 142 | encoder_attention_mask ( `torch.Tensor`, *optional*): 143 | Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: 144 | 145 | * Mask `(batch, sequence_length)` True = keep, False = discard. 146 | * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. 147 | 148 | If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format 149 | above. This bias will be added to the cross-attention scores. 150 | 151 | Returns: 152 | torch.FloatTensor 153 | """ 154 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. 155 | # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. 156 | # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. 157 | # expects mask of shape: 158 | # [batch, key_tokens] 159 | # adds singleton query_tokens dimension: 160 | # [batch, 1, key_tokens] 161 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: 162 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) 163 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) 164 | if attention_mask is not None and attention_mask.ndim == 2: 165 | # assume that mask is expressed as: 166 | # (1 = keep, 0 = discard) 167 | # convert mask into a bias that can be added to attention scores: 168 | # (keep = +0, discard = -10000.0) 169 | attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 170 | attention_mask = attention_mask.unsqueeze(1) 171 | 172 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 173 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: 174 | encoder_attention_mask = ( 175 | 1 - encoder_attention_mask.to(hidden_states.dtype) 176 | ) * -10000.0 177 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 178 | 179 | # 1. 
Input 180 | batch, _, seq_len = hidden_states.shape 181 | residual = hidden_states 182 | 183 | hidden_states = self.norm(hidden_states) 184 | inner_dim = hidden_states.shape[1] 185 | hidden_states = hidden_states.permute(0, 2, 1).reshape( 186 | batch, seq_len, inner_dim 187 | ) 188 | hidden_states = self.proj_in(hidden_states) 189 | 190 | # 2. Blocks 191 | for block in self.transformer_blocks: 192 | if self.training and self.gradient_checkpointing: 193 | hidden_states = torch.utils.checkpoint.checkpoint( 194 | block, 195 | hidden_states, 196 | attention_mask, 197 | encoder_hidden_states, 198 | encoder_attention_mask, 199 | use_reentrant=False, 200 | ) 201 | else: 202 | hidden_states = block( 203 | hidden_states, 204 | attention_mask=attention_mask, 205 | encoder_hidden_states=encoder_hidden_states, 206 | encoder_attention_mask=encoder_attention_mask, 207 | ) 208 | 209 | # 3. Output 210 | hidden_states = self.proj_out(hidden_states) 211 | hidden_states = ( 212 | hidden_states.reshape(batch, seq_len, inner_dim) 213 | .permute(0, 2, 1) 214 | .contiguous() 215 | ) 216 | 217 | output = hidden_states + residual 218 | 219 | return output 220 | -------------------------------------------------------------------------------- /tsr/system.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from dataclasses import dataclass, field 4 | from typing import List, Union 5 | 6 | import numpy as np 7 | import PIL.Image 8 | import torch 9 | import torch.nn.functional as F 10 | import trimesh 11 | from einops import rearrange 12 | from huggingface_hub import hf_hub_download 13 | from omegaconf import OmegaConf 14 | from PIL import Image 15 | 16 | from .models.isosurface import MarchingCubeHelper 17 | from .utils import ( 18 | BaseModule, 19 | ImagePreprocessor, 20 | find_class, 21 | get_spherical_cameras, 22 | scale_tensor, 23 | ) 24 | 25 | 26 | class TSR(BaseModule): 27 | @dataclass 28 | class Config(BaseModule.Config): 29 | cond_image_size: int 30 | 31 | image_tokenizer_cls: str 32 | image_tokenizer: dict 33 | 34 | tokenizer_cls: str 35 | tokenizer: dict 36 | 37 | backbone_cls: str 38 | backbone: dict 39 | 40 | post_processor_cls: str 41 | post_processor: dict 42 | 43 | decoder_cls: str 44 | decoder: dict 45 | 46 | renderer_cls: str 47 | renderer: dict 48 | 49 | cfg: Config 50 | 51 | @classmethod 52 | def from_pretrained( 53 | cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str 54 | ): 55 | if os.path.isdir(pretrained_model_name_or_path): 56 | config_path = os.path.join(pretrained_model_name_or_path, config_name) 57 | weight_path = os.path.join(pretrained_model_name_or_path, weight_name) 58 | else: 59 | config_path = hf_hub_download( 60 | repo_id=pretrained_model_name_or_path, filename=config_name 61 | ) 62 | weight_path = hf_hub_download( 63 | repo_id=pretrained_model_name_or_path, filename=weight_name 64 | ) 65 | 66 | cfg = OmegaConf.load(config_path) 67 | OmegaConf.resolve(cfg) 68 | model = cls(cfg) 69 | ckpt = torch.load(weight_path, map_location="cpu") 70 | model.load_state_dict(ckpt) 71 | return model 72 | 73 | @classmethod 74 | def from_pretrained_custom( 75 | cls, weight_path: str, config_path: str 76 | ): 77 | cfg = OmegaConf.load(config_path) 78 | OmegaConf.resolve(cfg) 79 | model = cls(cfg) 80 | ckpt = torch.load(weight_path, map_location="cpu") 81 | model.load_state_dict(ckpt) 82 | return model 83 | 84 | def configure(self): 85 | self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)( 86 | 
self.cfg.image_tokenizer 87 | ) 88 | self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer) 89 | self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone) 90 | self.post_processor = find_class(self.cfg.post_processor_cls)( 91 | self.cfg.post_processor 92 | ) 93 | self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder) 94 | self.renderer = find_class(self.cfg.renderer_cls)(self.cfg.renderer) 95 | self.image_processor = ImagePreprocessor() 96 | self.isosurface_helper = None 97 | 98 | def forward( 99 | self, 100 | image: Union[ 101 | PIL.Image.Image, 102 | np.ndarray, 103 | torch.FloatTensor, 104 | List[PIL.Image.Image], 105 | List[np.ndarray], 106 | List[torch.FloatTensor], 107 | ], 108 | device: str, 109 | ) -> torch.FloatTensor: 110 | rgb_cond = self.image_processor(image, self.cfg.cond_image_size)[:, None].to( 111 | device 112 | ) 113 | batch_size = rgb_cond.shape[0] 114 | 115 | input_image_tokens: torch.Tensor = self.image_tokenizer( 116 | rearrange(rgb_cond, "B Nv H W C -> B Nv C H W", Nv=1), 117 | ) 118 | 119 | input_image_tokens = rearrange( 120 | input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=1 121 | ) 122 | 123 | tokens: torch.Tensor = self.tokenizer(batch_size) 124 | 125 | tokens = self.backbone( 126 | tokens, 127 | encoder_hidden_states=input_image_tokens, 128 | ) 129 | 130 | scene_codes = self.post_processor(self.tokenizer.detokenize(tokens)) 131 | return scene_codes 132 | 133 | def render( 134 | self, 135 | scene_codes, 136 | n_views: int, 137 | elevation_deg: float = 0.0, 138 | camera_distance: float = 1.9, 139 | fovy_deg: float = 40.0, 140 | height: int = 256, 141 | width: int = 256, 142 | return_type: str = "pil", 143 | ): 144 | rays_o, rays_d = get_spherical_cameras( 145 | n_views, elevation_deg, camera_distance, fovy_deg, height, width 146 | ) 147 | rays_o, rays_d = rays_o.to(scene_codes.device), rays_d.to(scene_codes.device) 148 | 149 | def process_output(image: torch.FloatTensor): 150 | if return_type == "pt": 151 | return image 152 | elif return_type == "np": 153 | return image.detach().cpu().numpy() 154 | elif return_type == "pil": 155 | return Image.fromarray( 156 | (image.detach().cpu().numpy() * 255.0).astype(np.uint8) 157 | ) 158 | else: 159 | raise NotImplementedError 160 | 161 | images = [] 162 | for scene_code in scene_codes: 163 | images_ = [] 164 | for i in range(n_views): 165 | with torch.no_grad(): 166 | image = self.renderer( 167 | self.decoder, scene_code, rays_o[i], rays_d[i] 168 | ) 169 | images_.append(process_output(image)) 170 | images.append(images_) 171 | 172 | return images 173 | 174 | def set_marching_cubes_resolution(self, resolution: int): 175 | if ( 176 | self.isosurface_helper is not None 177 | and self.isosurface_helper.resolution == resolution 178 | ): 179 | return 180 | self.isosurface_helper = MarchingCubeHelper(resolution) 181 | 182 | def extract_mesh(self, scene_codes, resolution: int = 256, threshold: float = 25.0): 183 | self.set_marching_cubes_resolution(resolution) 184 | meshes = [] 185 | for scene_code in scene_codes: 186 | with torch.no_grad(): 187 | density = self.renderer.query_triplane( 188 | self.decoder, 189 | scale_tensor( 190 | self.isosurface_helper.grid_vertices.to(scene_codes.device), 191 | self.isosurface_helper.points_range, 192 | (-self.renderer.cfg.radius, self.renderer.cfg.radius), 193 | ), 194 | scene_code, 195 | )["density_act"] 196 | v_pos, t_pos_idx = self.isosurface_helper(-(density - threshold)) 197 | v_pos = scale_tensor( 198 | v_pos, 199 | 
self.isosurface_helper.points_range, 200 | (-self.renderer.cfg.radius, self.renderer.cfg.radius), 201 | ) 202 | with torch.no_grad(): 203 | color = self.renderer.query_triplane( 204 | self.decoder, 205 | v_pos, 206 | scene_code, 207 | )["color"] 208 | mesh = trimesh.Trimesh( 209 | vertices=v_pos.cpu().numpy(), 210 | faces=t_pos_idx.cpu().numpy(), 211 | vertex_colors=color.cpu().numpy(), 212 | ) 213 | meshes.append(mesh) 214 | return meshes 215 | -------------------------------------------------------------------------------- /tsr/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import math 3 | from collections import defaultdict 4 | from dataclasses import dataclass 5 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 6 | 7 | import imageio 8 | import numpy as np 9 | import PIL.Image 10 | #import rembg 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import trimesh 15 | from omegaconf import DictConfig, OmegaConf 16 | #from PIL import Image 17 | 18 | 19 | def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any: 20 | scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg) 21 | return scfg 22 | 23 | 24 | def find_class(cls_string): 25 | module_string = ".".join(cls_string.split(".")[:-1]) 26 | cls_name = cls_string.split(".")[-1] 27 | module = importlib.import_module(module_string, package=None) 28 | cls = getattr(module, cls_name) 29 | return cls 30 | 31 | 32 | def get_intrinsic_from_fov(fov, H, W, bs=-1): 33 | focal_length = 0.5 * H / np.tan(0.5 * fov) 34 | intrinsic = np.identity(3, dtype=np.float32) 35 | intrinsic[0, 0] = focal_length 36 | intrinsic[1, 1] = focal_length 37 | intrinsic[0, 2] = W / 2.0 38 | intrinsic[1, 2] = H / 2.0 39 | 40 | if bs > 0: 41 | intrinsic = intrinsic[None].repeat(bs, axis=0) 42 | 43 | return torch.from_numpy(intrinsic) 44 | 45 | 46 | class BaseModule(nn.Module): 47 | @dataclass 48 | class Config: 49 | pass 50 | 51 | cfg: Config # add this to every subclass of BaseModule to enable static type checking 52 | 53 | def __init__( 54 | self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs 55 | ) -> None: 56 | super().__init__() 57 | self.cfg = parse_structured(self.Config, cfg) 58 | self.configure(*args, **kwargs) 59 | 60 | def configure(self, *args, **kwargs) -> None: 61 | raise NotImplementedError 62 | 63 | 64 | class ImagePreprocessor: 65 | def convert_and_resize( 66 | self, 67 | image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], 68 | size: int, 69 | ): 70 | if isinstance(image, PIL.Image.Image): 71 | image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0) 72 | elif isinstance(image, np.ndarray): 73 | if image.dtype == np.uint8: 74 | image = torch.from_numpy(image.astype(np.float32) / 255.0) 75 | else: 76 | image = torch.from_numpy(image) 77 | elif isinstance(image, torch.Tensor): 78 | pass 79 | 80 | batched = image.ndim == 4 81 | 82 | if not batched: 83 | image = image[None, ...] 
84 | image = F.interpolate( 85 | image.permute(0, 3, 1, 2), 86 | (size, size), 87 | mode="bilinear", 88 | align_corners=False, 89 | antialias=True, 90 | ).permute(0, 2, 3, 1) 91 | if not batched: 92 | image = image[0] 93 | return image 94 | 95 | def __call__( 96 | self, 97 | image: Union[ 98 | PIL.Image.Image, 99 | np.ndarray, 100 | torch.FloatTensor, 101 | List[PIL.Image.Image], 102 | List[np.ndarray], 103 | List[torch.FloatTensor], 104 | ], 105 | size: int, 106 | ) -> Any: 107 | if isinstance(image, (np.ndarray, torch.FloatTensor)) and image.ndim == 4: 108 | image = self.convert_and_resize(image, size) 109 | else: 110 | if not isinstance(image, list): 111 | image = [image] 112 | image = [self.convert_and_resize(im, size) for im in image] 113 | image = torch.stack(image, dim=0) 114 | return image 115 | 116 | 117 | def rays_intersect_bbox( 118 | rays_o: torch.Tensor, 119 | rays_d: torch.Tensor, 120 | radius: float, 121 | near: float = 0.0, 122 | valid_thresh: float = 0.01, 123 | ): 124 | input_shape = rays_o.shape[:-1] 125 | rays_o, rays_d = rays_o.view(-1, 3), rays_d.view(-1, 3) 126 | rays_d_valid = torch.where( 127 | rays_d.abs() < 1e-6, torch.full_like(rays_d, 1e-6), rays_d 128 | ) 129 | if type(radius) in [int, float]: 130 | radius = torch.FloatTensor( 131 | [[-radius, radius], [-radius, radius], [-radius, radius]] 132 | ).to(rays_o.device) 133 | radius = ( 134 | 1.0 - 1.0e-3 135 | ) * radius # tighten the radius to make sure the intersection point lies in the bounding box 136 | interx0 = (radius[..., 1] - rays_o) / rays_d_valid 137 | interx1 = (radius[..., 0] - rays_o) / rays_d_valid 138 | t_near = torch.minimum(interx0, interx1).amax(dim=-1).clamp_min(near) 139 | t_far = torch.maximum(interx0, interx1).amin(dim=-1) 140 | 141 | # check wheter a ray intersects the bbox or not 142 | rays_valid = t_far - t_near > valid_thresh 143 | 144 | t_near[torch.where(~rays_valid)] = 0.0 145 | t_far[torch.where(~rays_valid)] = 0.0 146 | 147 | t_near = t_near.view(*input_shape, 1) 148 | t_far = t_far.view(*input_shape, 1) 149 | rays_valid = rays_valid.view(*input_shape) 150 | 151 | return t_near, t_far, rays_valid 152 | 153 | 154 | def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any: 155 | if chunk_size <= 0: 156 | return func(*args, **kwargs) 157 | B = None 158 | for arg in list(args) + list(kwargs.values()): 159 | if isinstance(arg, torch.Tensor): 160 | B = arg.shape[0] 161 | break 162 | assert ( 163 | B is not None 164 | ), "No tensor found in args or kwargs, cannot determine batch size." 165 | out = defaultdict(list) 166 | out_type = None 167 | # max(1, B) to support B == 0 168 | for i in range(0, max(1, B), chunk_size): 169 | out_chunk = func( 170 | *[ 171 | arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg 172 | for arg in args 173 | ], 174 | **{ 175 | k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg 176 | for k, arg in kwargs.items() 177 | }, 178 | ) 179 | if out_chunk is None: 180 | continue 181 | out_type = type(out_chunk) 182 | if isinstance(out_chunk, torch.Tensor): 183 | out_chunk = {0: out_chunk} 184 | elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list): 185 | chunk_length = len(out_chunk) 186 | out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)} 187 | elif isinstance(out_chunk, dict): 188 | pass 189 | else: 190 | print( 191 | f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}." 
192 | ) 193 | exit(1) 194 | for k, v in out_chunk.items(): 195 | v = v if torch.is_grad_enabled() else v.detach() 196 | out[k].append(v) 197 | 198 | if out_type is None: 199 | return None 200 | 201 | out_merged: Dict[Any, Optional[torch.Tensor]] = {} 202 | for k, v in out.items(): 203 | if all([vv is None for vv in v]): 204 | # allow None in return value 205 | out_merged[k] = None 206 | elif all([isinstance(vv, torch.Tensor) for vv in v]): 207 | out_merged[k] = torch.cat(v, dim=0) 208 | else: 209 | raise TypeError( 210 | f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}" 211 | ) 212 | 213 | if out_type is torch.Tensor: 214 | return out_merged[0] 215 | elif out_type in [tuple, list]: 216 | return out_type([out_merged[i] for i in range(chunk_length)]) 217 | elif out_type is dict: 218 | return out_merged 219 | 220 | 221 | ValidScale = Union[Tuple[float, float], torch.FloatTensor] 222 | 223 | 224 | def scale_tensor(dat: torch.FloatTensor, inp_scale: ValidScale, tgt_scale: ValidScale): 225 | if inp_scale is None: 226 | inp_scale = (0, 1) 227 | if tgt_scale is None: 228 | tgt_scale = (0, 1) 229 | if isinstance(tgt_scale, torch.FloatTensor): 230 | assert dat.shape[-1] == tgt_scale.shape[-1] 231 | dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0]) 232 | dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0] 233 | return dat 234 | 235 | 236 | def get_activation(name) -> Callable: 237 | if name is None: 238 | return lambda x: x 239 | name = name.lower() 240 | if name == "none": 241 | return lambda x: x 242 | elif name == "exp": 243 | return lambda x: torch.exp(x) 244 | elif name == "sigmoid": 245 | return lambda x: torch.sigmoid(x) 246 | elif name == "tanh": 247 | return lambda x: torch.tanh(x) 248 | elif name == "softplus": 249 | return lambda x: F.softplus(x) 250 | else: 251 | try: 252 | return getattr(F, name) 253 | except AttributeError: 254 | raise ValueError(f"Unknown activation function: {name}") 255 | 256 | 257 | def get_ray_directions( 258 | H: int, 259 | W: int, 260 | focal: Union[float, Tuple[float, float]], 261 | principal: Optional[Tuple[float, float]] = None, 262 | use_pixel_centers: bool = True, 263 | normalize: bool = True, 264 | ) -> torch.FloatTensor: 265 | """ 266 | Get ray directions for all pixels in camera coordinate. 
267 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 268 | ray-tracing-generating-camera-rays/standard-coordinate-systems 269 | 270 | Inputs: 271 | H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point and whether use pixel centers 272 | Outputs: 273 | directions: (H, W, 3), the direction of the rays in camera coordinate 274 | """ 275 | pixel_center = 0.5 if use_pixel_centers else 0 276 | 277 | if isinstance(focal, float): 278 | fx, fy = focal, focal 279 | cx, cy = W / 2, H / 2 280 | else: 281 | fx, fy = focal 282 | assert principal is not None 283 | cx, cy = principal 284 | 285 | i, j = torch.meshgrid( 286 | torch.arange(W, dtype=torch.float32) + pixel_center, 287 | torch.arange(H, dtype=torch.float32) + pixel_center, 288 | indexing="xy", 289 | ) 290 | 291 | directions = torch.stack([(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1) 292 | 293 | if normalize: 294 | directions = F.normalize(directions, dim=-1) 295 | 296 | return directions 297 | 298 | 299 | def get_rays( 300 | directions, 301 | c2w, 302 | keepdim=False, 303 | normalize=False, 304 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: 305 | # Rotate ray directions from camera coordinate to the world coordinate 306 | assert directions.shape[-1] == 3 307 | 308 | if directions.ndim == 2: # (N_rays, 3) 309 | if c2w.ndim == 2: # (4, 4) 310 | c2w = c2w[None, :, :] 311 | assert c2w.ndim == 3 # (N_rays, 4, 4) or (1, 4, 4) 312 | rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1) # (N_rays, 3) 313 | rays_o = c2w[:, :3, 3].expand(rays_d.shape) 314 | elif directions.ndim == 3: # (H, W, 3) 315 | assert c2w.ndim in [2, 3] 316 | if c2w.ndim == 2: # (4, 4) 317 | rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum( 318 | -1 319 | ) # (H, W, 3) 320 | rays_o = c2w[None, None, :3, 3].expand(rays_d.shape) 321 | elif c2w.ndim == 3: # (B, 4, 4) 322 | rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum( 323 | -1 324 | ) # (B, H, W, 3) 325 | rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape) 326 | elif directions.ndim == 4: # (B, H, W, 3) 327 | assert c2w.ndim == 3 # (B, 4, 4) 328 | rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum( 329 | -1 330 | ) # (B, H, W, 3) 331 | rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape) 332 | 333 | if normalize: 334 | rays_d = F.normalize(rays_d, dim=-1) 335 | if not keepdim: 336 | rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3) 337 | 338 | return rays_o, rays_d 339 | 340 | 341 | def get_spherical_cameras( 342 | n_views: int, 343 | elevation_deg: float, 344 | camera_distance: float, 345 | fovy_deg: float, 346 | height: int, 347 | width: int, 348 | ): 349 | azimuth_deg = torch.linspace(0, 360.0, n_views + 1)[:n_views] 350 | elevation_deg = torch.full_like(azimuth_deg, elevation_deg) 351 | camera_distances = torch.full_like(elevation_deg, camera_distance) 352 | 353 | elevation = elevation_deg * math.pi / 180 354 | azimuth = azimuth_deg * math.pi / 180 355 | 356 | # convert spherical coordinates to cartesian coordinates 357 | # right hand coordinate system, x back, y right, z up 358 | # elevation in (-90, 90), azimuth from +x to +y in (-180, 180) 359 | camera_positions = torch.stack( 360 | [ 361 | camera_distances * torch.cos(elevation) * torch.cos(azimuth), 362 | camera_distances * torch.cos(elevation) * torch.sin(azimuth), 363 | camera_distances * torch.sin(elevation), 364 | ], 365 | dim=-1, 366 | ) 367 | 368 | # default scene center at origin 369 | center = 
torch.zeros_like(camera_positions) 370 | # default camera up direction as +z 371 | up = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None, :].repeat(n_views, 1) 372 | 373 | fovy = torch.full_like(elevation_deg, fovy_deg) * math.pi / 180 374 | 375 | lookat = F.normalize(center - camera_positions, dim=-1) 376 | right = F.normalize(torch.cross(lookat, up), dim=-1) 377 | up = F.normalize(torch.cross(right, lookat), dim=-1) 378 | c2w3x4 = torch.cat( 379 | [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]], 380 | dim=-1, 381 | ) 382 | c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1) 383 | c2w[:, 3, 3] = 1.0 384 | 385 | # get directions by dividing directions_unit_focal by focal length 386 | focal_length = 0.5 * height / torch.tan(0.5 * fovy) 387 | directions_unit_focal = get_ray_directions( 388 | H=height, 389 | W=width, 390 | focal=1.0, 391 | ) 392 | directions = directions_unit_focal[None, :, :, :].repeat(n_views, 1, 1, 1) 393 | directions[:, :, :, :2] = ( 394 | directions[:, :, :, :2] / focal_length[:, None, None, None] 395 | ) 396 | # must use normalize=True to normalize directions here 397 | rays_o, rays_d = get_rays(directions, c2w, keepdim=True, normalize=True) 398 | 399 | return rays_o, rays_d 400 | 401 | 402 | # def remove_background( 403 | # image: PIL.Image.Image, 404 | # rembg_session: Any = None, 405 | # force: bool = False, 406 | # **rembg_kwargs, 407 | # ) -> PIL.Image.Image: 408 | # do_remove = True 409 | # if image.mode == "RGBA" and image.getextrema()[3][0] < 255: 410 | # do_remove = False 411 | # do_remove = do_remove or force 412 | # if do_remove: 413 | # image = rembg.remove(image, session=rembg_session, **rembg_kwargs) 414 | # return image 415 | 416 | 417 | def resize_foreground( 418 | image: PIL.Image.Image, 419 | ratio: float, 420 | ) -> PIL.Image.Image: 421 | image = np.array(image) 422 | assert image.shape[-1] == 4 423 | alpha = np.where(image[..., 3] > 0) 424 | y1, y2, x1, x2 = ( 425 | alpha[0].min(), 426 | alpha[0].max(), 427 | alpha[1].min(), 428 | alpha[1].max(), 429 | ) 430 | # crop the foreground 431 | fg = image[y1:y2, x1:x2] 432 | # pad to square 433 | size = max(fg.shape[0], fg.shape[1]) 434 | ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2 435 | ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0 436 | new_image = np.pad( 437 | fg, 438 | ((ph0, ph1), (pw0, pw1), (0, 0)), 439 | mode="constant", 440 | constant_values=((0, 0), (0, 0), (0, 0)), 441 | ) 442 | 443 | # compute padding according to the ratio 444 | new_size = int(new_image.shape[0] / ratio) 445 | # pad to size, double side 446 | ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2 447 | ph1, pw1 = new_size - size - ph0, new_size - size - pw0 448 | new_image = np.pad( 449 | new_image, 450 | ((ph0, ph1), (pw0, pw1), (0, 0)), 451 | mode="constant", 452 | constant_values=((0, 0), (0, 0), (0, 0)), 453 | ) 454 | new_image = PIL.Image.fromarray(new_image) 455 | return new_image 456 | 457 | 458 | def save_video( 459 | frames: List[PIL.Image.Image], 460 | output_path: str, 461 | fps: int = 30, 462 | ): 463 | # use imageio to save video 464 | frames = [np.array(frame) for frame in frames] 465 | writer = imageio.get_writer(output_path, fps=fps) 466 | for frame in frames: 467 | writer.append_data(frame) 468 | writer.close() 469 | 470 | 471 | def to_gradio_3d_orientation(mesh): 472 | mesh.apply_transform(trimesh.transformations.rotation_matrix(-np.pi/2, [1, 0, 0])) 473 | mesh.apply_scale([1, 1, -1]) 474 | 
mesh.apply_transform(trimesh.transformations.rotation_matrix(np.pi/2, [0, 1, 0])) 475 | return mesh 476 | -------------------------------------------------------------------------------- /web/html/threeVisualizer.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 |
21 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /web/js/threeVisualizer.js: -------------------------------------------------------------------------------- 1 | import * as THREE from 'three'; 2 | import { api } from '../../../scripts/api.js' 3 | 4 | import { OrbitControls } from 'three/addons/controls/OrbitControls.js'; 5 | import { RoomEnvironment } from 'three/addons/environments/RoomEnvironment.js'; 6 | 7 | import { OBJLoader } from 'three/addons/loaders/OBJLoader.js'; 8 | 9 | const visualizer = document.getElementById("visualizer"); 10 | const container = document.getElementById( 'container' ); 11 | const progressDialog = document.getElementById("progress-dialog"); 12 | const progressIndicator = document.getElementById("progress-indicator"); 13 | 14 | const renderer = new THREE.WebGLRenderer( { antialias: true } ); 15 | renderer.setPixelRatio( window.devicePixelRatio ); 16 | renderer.setSize( window.innerWidth, window.innerHeight ); 17 | container.appendChild( renderer.domElement ); 18 | 19 | const pmremGenerator = new THREE.PMREMGenerator( renderer ); 20 | 21 | // scene 22 | const scene = new THREE.Scene(); 23 | scene.background = new THREE.Color( 0x000000 ); 24 | scene.environment = pmremGenerator.fromScene( new RoomEnvironment( renderer ), 0.04 ).texture; 25 | 26 | const ambientLight = new THREE.AmbientLight( 0xffffff ); 27 | 28 | const camera = new THREE.PerspectiveCamera( 40, window.innerWidth / window.innerHeight, 1, 100 ); 29 | camera.position.set( 5, 2, 8 ); 30 | const pointLight = new THREE.PointLight( 0xffffff, 15 ); 31 | camera.add( pointLight ); 32 | 33 | const controls = new OrbitControls( camera, renderer.domElement ); 34 | controls.target.set( 0, 0.5, 0 ); 35 | controls.update(); 36 | controls.enablePan = true; 37 | controls.enableDamping = true; 38 | 39 | // Handle window reseize event 40 | window.onresize = function () { 41 | 42 | camera.aspect = window.innerWidth / window.innerHeight; 43 | camera.updateProjectionMatrix(); 44 | 45 | renderer.setSize( window.innerWidth, window.innerHeight ); 46 | 47 | }; 48 | 49 | 50 | var lastFilepath = ""; 51 | var needUpdate = false; 52 | 53 | function frameUpdate() { 54 | 55 | var filepath = visualizer.getAttribute("filepath"); 56 | if (filepath == lastFilepath){ 57 | if (needUpdate){ 58 | controls.update(); 59 | renderer.render( scene, camera ); 60 | } 61 | requestAnimationFrame( frameUpdate ); 62 | } else { 63 | needUpdate = false; 64 | scene.clear(); 65 | progressDialog.open = true; 66 | lastFilepath = filepath; 67 | main(JSON.parse(lastFilepath)); 68 | } 69 | } 70 | 71 | const onProgress = function ( xhr ) { 72 | if ( xhr.lengthComputable ) { 73 | progressIndicator.value = xhr.loaded / xhr.total * 100; 74 | } 75 | }; 76 | const onError = function ( e ) { 77 | console.error( e ); 78 | }; 79 | 80 | async function main(params) { 81 | if(params?.filename){ 82 | const url = api.apiURL('/view?' 
+ new URLSearchParams(params)).replace(/extensions.*\//,""); 83 | const fileExt = params.filename.slice(params.filename.lastIndexOf(".")+1) 84 | 85 | if (fileExt == "obj"){ 86 | const loader = new OBJLoader(); 87 | 88 | loader.load( url, function ( obj ) { 89 | obj.scale.setScalar( 5 ); 90 | console.log(obj) 91 | scene.add( obj ); 92 | obj.traverse(node => { 93 | if (node.material && node.material.map == null) { 94 | node.material.vertexColors = true; 95 | } 96 | }); 97 | 98 | }, onProgress, onError ); 99 | } 100 | 101 | needUpdate = true; 102 | } 103 | 104 | scene.add( ambientLight ); 105 | scene.add( camera ); 106 | 107 | progressDialog.close(); 108 | 109 | frameUpdate(); 110 | } 111 | 112 | main(); -------------------------------------------------------------------------------- /web/style/progressStyle.css: -------------------------------------------------------------------------------- 1 | dialog { 2 | width: 100%; 3 | text-align: center; 4 | max-width: 20em; 5 | color: white; 6 | background-color: #000; 7 | border: none; 8 | position: relative; 9 | transform: translate(-50%, -50%); 10 | } 11 | 12 | #progress-container { 13 | position: absolute; 14 | top: 50%; 15 | left: 50%; 16 | } 17 | 18 | progress { 19 | width: 100%; 20 | height: 1em; 21 | border: none; 22 | background-color: #fff; 23 | color: #eee; 24 | } 25 | 26 | progress::-webkit-progress-bar { 27 | background-color: #333; 28 | } 29 | 30 | progress::-webkit-progress-value { 31 | background-color: #eee; 32 | } 33 | 34 | progress::-moz-progress-bar { 35 | background-color: #eee; 36 | } -------------------------------------------------------------------------------- /web/style/threeStyle.css: -------------------------------------------------------------------------------- 1 | body { 2 | margin: 0; 3 | background-color: #000; 4 | color: #fff; 5 | font-family: Monospace; 6 | font-size: 13px; 7 | line-height: 24px; 8 | overscroll-behavior: none; 9 | } 10 | 11 | a { 12 | color: #ff0; 13 | text-decoration: none; 14 | } 15 | 16 | a:hover { 17 | text-decoration: underline; 18 | } 19 | 20 | button { 21 | cursor: pointer; 22 | text-transform: uppercase; 23 | } 24 | 25 | #info { 26 | position: absolute; 27 | top: 0px; 28 | width: 100%; 29 | padding: 10px; 30 | box-sizing: border-box; 31 | text-align: center; 32 | -moz-user-select: none; 33 | -webkit-user-select: none; 34 | -ms-user-select: none; 35 | user-select: none; 36 | pointer-events: none; 37 | z-index: 1; /* TODO Solve this in HTML */ 38 | } 39 | 40 | a, button, input, select { 41 | pointer-events: auto; 42 | } 43 | 44 | .lil-gui { 45 | z-index: 2 !important; /* TODO Solve this in HTML */ 46 | } 47 | 48 | @media all and ( max-width: 640px ) { 49 | .lil-gui.root { 50 | right: auto; 51 | top: auto; 52 | max-height: 50%; 53 | max-width: 80%; 54 | bottom: 0; 55 | left: 0; 56 | } 57 | } 58 | 59 | #overlay { 60 | position: absolute; 61 | font-size: 16px; 62 | z-index: 2; 63 | top: 0; 64 | left: 0; 65 | width: 100%; 66 | height: 100%; 67 | display: flex; 68 | align-items: center; 69 | justify-content: center; 70 | flex-direction: column; 71 | background: rgba(0,0,0,0.7); 72 | } 73 | 74 | #overlay button { 75 | background: transparent; 76 | border: 0; 77 | border: 1px solid rgb(255, 255, 255); 78 | border-radius: 4px; 79 | color: #ffffff; 80 | padding: 12px 18px; 81 | text-transform: uppercase; 82 | cursor: pointer; 83 | } 84 | 85 | #notSupported { 86 | width: 50%; 87 | margin: auto; 88 | background-color: #f00; 89 | margin-top: 20px; 90 | padding: 10px; 91 | } 
-------------------------------------------------------------------------------- /web/visualization.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../scripts/app.js" 2 | 3 | class Visualizer { 4 | constructor(node, container, visualSrc) { 5 | this.node = node 6 | 7 | this.iframe = document.createElement('iframe') 8 | Object.assign(this.iframe, { 9 | scrolling: "no", 10 | overflow: "hidden", 11 | }) 12 | this.iframe.src = "/extensions/ComfyUI-Flowty-TripoSR/html/" + visualSrc + ".html" 13 | container.appendChild(this.iframe) 14 | } 15 | 16 | updateVisual(params) { 17 | const iframeDocument = this.iframe.contentWindow.document 18 | const previewScript = iframeDocument.getElementById('visualizer') 19 | previewScript.setAttribute("filepath", JSON.stringify(params)) 20 | } 21 | 22 | remove() { 23 | this.container.remove() 24 | } 25 | } 26 | 27 | function createVisualizer(node, inputName, typeName, inputData, app) { 28 | node.name = inputName 29 | 30 | const widget = { 31 | type: typeName, 32 | name: "preview3d", 33 | callback: () => {}, 34 | draw : function(ctx, node, widgetWidth, widgetY, widgetHeight) { 35 | const margin = 10 36 | const top_offset = 5 37 | const visible = app.canvas.ds.scale > 0.5 && this.type === typeName 38 | const w = widgetWidth - margin * 4 39 | const clientRectBound = ctx.canvas.getBoundingClientRect() 40 | const transform = new DOMMatrix() 41 | .scaleSelf( 42 | clientRectBound.width / ctx.canvas.width, 43 | clientRectBound.height / ctx.canvas.height 44 | ) 45 | .multiplySelf(ctx.getTransform()) 46 | .translateSelf(margin, margin + widgetY) 47 | 48 | Object.assign(this.visualizer.style, { 49 | left: `${transform.a * margin + transform.e}px`, 50 | top: `${transform.d + transform.f + top_offset}px`, 51 | width: `${(w * transform.a)}px`, 52 | height: `${(w * transform.d - widgetHeight - (margin * 15) * transform.d)}px`, 53 | position: "absolute", 54 | overflow: "hidden", 55 | zIndex: app.graph._nodes.indexOf(node), 56 | }) 57 | 58 | Object.assign(this.visualizer.children[0].style, { 59 | transformOrigin: "50% 50%", 60 | width: '100%', 61 | height: '100%', 62 | border: '0 none', 63 | }) 64 | 65 | this.visualizer.hidden = !visible 66 | }, 67 | } 68 | 69 | const container = document.createElement('div') 70 | container.id = `Comfy3D_${inputName}` 71 | 72 | node.visualizer = new Visualizer(node, container, typeName) 73 | widget.visualizer = container 74 | widget.parent = node 75 | 76 | document.body.appendChild(widget.visualizer) 77 | 78 | node.addCustomWidget(widget) 79 | 80 | node.updateParameters = (params) => { 81 | node.visualizer.updateVisual(params) 82 | } 83 | 84 | // Events for drawing backgound 85 | node.onDrawBackground = function (ctx) { 86 | if (!this.flags.collapsed) { 87 | node.visualizer.iframe.hidden = false 88 | } else { 89 | node.visualizer.iframe.hidden = true 90 | } 91 | } 92 | 93 | // Make sure visualization iframe is always inside the node when resize the node 94 | node.onResize = function () { 95 | let [w, h] = this.size 96 | if (w <= 600) w = 600 97 | if (h <= 500) h = 500 98 | 99 | if (w > 600) { 100 | h = w - 100 101 | } 102 | 103 | this.size = [w, h] 104 | } 105 | 106 | // Events for remove nodes 107 | node.onRemoved = () => { 108 | for (let w in node.widgets) { 109 | if (node.widgets[w].visualizer) { 110 | node.widgets[w].visualizer.remove() 111 | } 112 | } 113 | } 114 | 115 | 116 | return { 117 | widget: widget, 118 | } 119 | } 120 | 121 | function registerVisualizer(nodeType, 
nodeData, nodeClassName, typeName){ 122 | if (nodeData.name == nodeClassName) { 123 | console.log("[3D Visualizer] Registering node: " + nodeData.name) 124 | 125 | const onNodeCreated = nodeType.prototype.onNodeCreated 126 | 127 | nodeType.prototype.onNodeCreated = async function() { 128 | const r = onNodeCreated 129 | ? onNodeCreated.apply(this, arguments) 130 | : undefined 131 | 132 | let Preview3DNode = app.graph._nodes.filter( 133 | (wi) => wi.type == nodeClassName 134 | ) 135 | let nodeName = `Preview3DNode_${Preview3DNode.length}` 136 | 137 | console.log(`[Comfy3D] Create: ${nodeName}`) 138 | 139 | const result = await createVisualizer.apply(this, [this, nodeName, typeName, {}, app]) 140 | 141 | this.setSize([600, 500]) 142 | 143 | return r 144 | } 145 | 146 | nodeType.prototype.onExecuted = async function(message) { 147 | if (message?.mesh) { 148 | this.updateParameters(message.mesh[0]) 149 | } 150 | } 151 | } 152 | } 153 | 154 | app.registerExtension({ 155 | name: "Mr.ForExample.Visualizer.GS", 156 | 157 | async init (app) { 158 | 159 | }, 160 | 161 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 162 | registerVisualizer(nodeType, nodeData, "TripoSRViewer", "threeVisualizer") 163 | }, 164 | }) -------------------------------------------------------------------------------- /workflow-sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flowtyone/ComfyUI-Flowty-TripoSR/a0c94ac60a7cc062604f61aeeea6d0d493521de3/workflow-sample.png -------------------------------------------------------------------------------- /workflow_rembg.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 19, 3 | "last_link_id": 26, 4 | "nodes": [ 5 | { 6 | "id": 17, 7 | "type": "ImageRemoveBackground+", 8 | "pos": [ 9 | -457, 10 | 656 11 | ], 12 | "size": { 13 | "0": 241.79998779296875, 14 | "1": 46 15 | }, 16 | "flags": {}, 17 | "order": 3, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "rembg_session", 22 | "type": "REMBG_SESSION", 23 | "link": 22, 24 | "slot_index": 0 25 | }, 26 | { 27 | "name": "image", 28 | "type": "IMAGE", 29 | "link": 21 30 | } 31 | ], 32 | "outputs": [ 33 | { 34 | "name": "IMAGE", 35 | "type": "IMAGE", 36 | "links": [ 37 | 23 38 | ], 39 | "shape": 3, 40 | "slot_index": 0 41 | }, 42 | { 43 | "name": "MASK", 44 | "type": "MASK", 45 | "links": [ 46 | 26 47 | ], 48 | "shape": 3, 49 | "slot_index": 1 50 | } 51 | ], 52 | "properties": { 53 | "Node name for S&R": "ImageRemoveBackground+" 54 | } 55 | }, 56 | { 57 | "id": 14, 58 | "type": "TripoSRModelLoader", 59 | "pos": [ 60 | -870, 61 | 421 62 | ], 63 | "size": { 64 | "0": 315, 65 | "1": 82 66 | }, 67 | "flags": {}, 68 | "order": 0, 69 | "mode": 0, 70 | "outputs": [ 71 | { 72 | "name": "TRIPOSR_MODEL", 73 | "type": "TRIPOSR_MODEL", 74 | "links": [ 75 | 16 76 | ], 77 | "shape": 3 78 | } 79 | ], 80 | "properties": { 81 | "Node name for S&R": "TripoSRModelLoader" 82 | }, 83 | "widgets_values": [ 84 | "model.ckpt", 85 | 8192 86 | ] 87 | }, 88 | { 89 | "id": 18, 90 | "type": "RemBGSession+", 91 | "pos": [ 92 | -868, 93 | 559 94 | ], 95 | "size": { 96 | "0": 315, 97 | "1": 82 98 | }, 99 | "flags": {}, 100 | "order": 1, 101 | "mode": 0, 102 | "outputs": [ 103 | { 104 | "name": "REMBG_SESSION", 105 | "type": "REMBG_SESSION", 106 | "links": [ 107 | 22 108 | ], 109 | "shape": 3 110 | } 111 | ], 112 | "properties": { 113 | "Node name for S&R": "RemBGSession+" 114 | }, 115 | "widgets_values": [ 116 | "u2net: 
general purpose", 117 | "CPU" 118 | ] 119 | }, 120 | { 121 | "id": 12, 122 | "type": "TripoSRSampler", 123 | "pos": [ 124 | -129, 125 | 598 126 | ], 127 | "size": { 128 | "0": 315, 129 | "1": 122 130 | }, 131 | "flags": {}, 132 | "order": 4, 133 | "mode": 0, 134 | "inputs": [ 135 | { 136 | "name": "model", 137 | "type": "TRIPOSR_MODEL", 138 | "link": 16, 139 | "slot_index": 0 140 | }, 141 | { 142 | "name": "reference_image", 143 | "type": "IMAGE", 144 | "link": 23, 145 | "slot_index": 1 146 | }, 147 | { 148 | "name": "reference_mask", 149 | "type": "MASK", 150 | "link": 26 151 | } 152 | ], 153 | "outputs": [ 154 | { 155 | "name": "MESH", 156 | "type": "MESH", 157 | "links": [ 158 | 15 159 | ], 160 | "shape": 3 161 | } 162 | ], 163 | "properties": { 164 | "Node name for S&R": "TripoSRSampler" 165 | }, 166 | "widgets_values": [ 167 | 256, 168 | 25 169 | ] 170 | }, 171 | { 172 | "id": 13, 173 | "type": "TripoSRViewer", 174 | "pos": [ 175 | -128, 176 | 772 177 | ], 178 | "size": [ 179 | 600, 180 | 500 181 | ], 182 | "flags": {}, 183 | "order": 5, 184 | "mode": 0, 185 | "inputs": [ 186 | { 187 | "name": "mesh", 188 | "type": "MESH", 189 | "link": 15, 190 | "slot_index": 0 191 | } 192 | ], 193 | "properties": { 194 | "Node name for S&R": "TripoSRViewer" 195 | }, 196 | "widgets_values": [ 197 | null 198 | ] 199 | }, 200 | { 201 | "id": 15, 202 | "type": "LoadImage", 203 | "pos": [ 204 | -869, 205 | 696 206 | ], 207 | "size": { 208 | "0": 315, 209 | "1": 314 210 | }, 211 | "flags": {}, 212 | "order": 2, 213 | "mode": 0, 214 | "outputs": [ 215 | { 216 | "name": "IMAGE", 217 | "type": "IMAGE", 218 | "links": [ 219 | 21 220 | ], 221 | "shape": 3, 222 | "slot_index": 0 223 | }, 224 | { 225 | "name": "MASK", 226 | "type": "MASK", 227 | "links": [], 228 | "shape": 3, 229 | "slot_index": 1 230 | } 231 | ], 232 | "properties": { 233 | "Node name for S&R": "LoadImage" 234 | }, 235 | "widgets_values": [ 236 | "marble (1).png", 237 | "image" 238 | ] 239 | } 240 | ], 241 | "links": [ 242 | [ 243 | 15, 244 | 12, 245 | 0, 246 | 13, 247 | 0, 248 | "MESH" 249 | ], 250 | [ 251 | 16, 252 | 14, 253 | 0, 254 | 12, 255 | 0, 256 | "TRIPOSR_MODEL" 257 | ], 258 | [ 259 | 21, 260 | 15, 261 | 0, 262 | 17, 263 | 1, 264 | "IMAGE" 265 | ], 266 | [ 267 | 22, 268 | 18, 269 | 0, 270 | 17, 271 | 0, 272 | "REMBG_SESSION" 273 | ], 274 | [ 275 | 23, 276 | 17, 277 | 0, 278 | 12, 279 | 1, 280 | "IMAGE" 281 | ], 282 | [ 283 | 26, 284 | 17, 285 | 1, 286 | 12, 287 | 2, 288 | "MASK" 289 | ] 290 | ], 291 | "groups": [], 292 | "config": {}, 293 | "extra": {}, 294 | "version": 0.4 295 | } -------------------------------------------------------------------------------- /workflow_simple.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 16, 3 | "last_link_id": 20, 4 | "nodes": [ 5 | { 6 | "id": 13, 7 | "type": "TripoSRViewer", 8 | "pos": [ 9 | 103, 10 | 796 11 | ], 12 | "size": [ 13 | 600, 14 | 500 15 | ], 16 | "flags": {}, 17 | "order": 3, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "mesh", 22 | "type": "MESH", 23 | "link": 15, 24 | "slot_index": 0 25 | } 26 | ], 27 | "properties": { 28 | "Node name for S&R": "TripoSRViewer" 29 | }, 30 | "widgets_values": [ 31 | null 32 | ] 33 | }, 34 | { 35 | "id": 14, 36 | "type": "TripoSRModelLoader", 37 | "pos": [ 38 | -596, 39 | 464 40 | ], 41 | "size": { 42 | "0": 315, 43 | "1": 82 44 | }, 45 | "flags": {}, 46 | "order": 0, 47 | "mode": 0, 48 | "outputs": [ 49 | { 50 | "name": "TRIPOSR_MODEL", 51 | "type": "TRIPOSR_MODEL", 52 | 
"links": [ 53 | 16 54 | ], 55 | "shape": 3 56 | } 57 | ], 58 | "properties": { 59 | "Node name for S&R": "TripoSRModelLoader" 60 | }, 61 | "widgets_values": [ 62 | "model.ckpt", 63 | 8192 64 | ] 65 | }, 66 | { 67 | "id": 15, 68 | "type": "LoadImage", 69 | "pos": [ 70 | -869, 71 | 696 72 | ], 73 | "size": { 74 | "0": 315, 75 | "1": 314 76 | }, 77 | "flags": {}, 78 | "order": 1, 79 | "mode": 0, 80 | "outputs": [ 81 | { 82 | "name": "IMAGE", 83 | "type": "IMAGE", 84 | "links": [ 85 | 17 86 | ], 87 | "shape": 3 88 | }, 89 | { 90 | "name": "MASK", 91 | "type": "MASK", 92 | "links": [], 93 | "shape": 3, 94 | "slot_index": 1 95 | } 96 | ], 97 | "properties": { 98 | "Node name for S&R": "LoadImage" 99 | }, 100 | "widgets_values": [ 101 | "robot.png", 102 | "image" 103 | ] 104 | }, 105 | { 106 | "id": 12, 107 | "type": "TripoSRSampler", 108 | "pos": [ 109 | -103, 110 | 592 111 | ], 112 | "size": { 113 | "0": 315, 114 | "1": 122 115 | }, 116 | "flags": {}, 117 | "order": 2, 118 | "mode": 0, 119 | "inputs": [ 120 | { 121 | "name": "model", 122 | "type": "TRIPOSR_MODEL", 123 | "link": 16, 124 | "slot_index": 0 125 | }, 126 | { 127 | "name": "reference_image", 128 | "type": "IMAGE", 129 | "link": 17, 130 | "slot_index": 1 131 | }, 132 | { 133 | "name": "reference_mask", 134 | "type": "MASK", 135 | "link": null 136 | } 137 | ], 138 | "outputs": [ 139 | { 140 | "name": "MESH", 141 | "type": "MESH", 142 | "links": [ 143 | 15 144 | ], 145 | "shape": 3 146 | } 147 | ], 148 | "properties": { 149 | "Node name for S&R": "TripoSRSampler" 150 | }, 151 | "widgets_values": [ 152 | 256, 153 | 25 154 | ] 155 | } 156 | ], 157 | "links": [ 158 | [ 159 | 15, 160 | 12, 161 | 0, 162 | 13, 163 | 0, 164 | "MESH" 165 | ], 166 | [ 167 | 16, 168 | 14, 169 | 0, 170 | 12, 171 | 0, 172 | "TRIPOSR_MODEL" 173 | ], 174 | [ 175 | 17, 176 | 15, 177 | 0, 178 | 12, 179 | 1, 180 | "IMAGE" 181 | ] 182 | ], 183 | "groups": [], 184 | "config": {}, 185 | "extra": {}, 186 | "version": 0.4 187 | } --------------------------------------------------------------------------------