├── .github └── workflows │ └── publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── install.py ├── models └── .gitkeep ├── pyproject.toml ├── requirements.txt ├── src ├── __init__.py ├── idm_vton │ ├── __init__.py │ ├── attentionhacked_garmnet.py │ ├── attentionhacked_tryon.py │ ├── transformerhacked_garmnet.py │ ├── transformerhacked_tryon.py │ ├── tryon_pipeline.py │ ├── unet_block_hacked_garmnet.py │ ├── unet_block_hacked_tryon.py │ ├── unet_hacked_garmnet.py │ └── unet_hacked_tryon.py ├── ip_adapter │ ├── __init__.py │ ├── attention_processor.py │ ├── ip_adapter.py │ ├── resampler.py │ ├── test_resampler.py │ └── utils.py ├── logger.py ├── nodes │ ├── __init__.py │ ├── idm_vton.py │ └── pipeline_loader.py └── nodes_mappings.py └── workflow.png /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "pyproject.toml" 9 | 10 | jobs: 11 | publish-node: 12 | name: Publish Custom Node to registry 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Check out code 16 | uses: actions/checkout@v4 17 | - name: Publish Custom Node 18 | uses: Comfy-Org/publish-node-action@main 19 | with: 20 | ## Add your own personal access token to your Github Repository secrets and reference it here. 21 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | models/* 2 | !models/.gitkeep 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 163 | #.idea/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. 
Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 
88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 
146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 
209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. 
This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. 
But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 
387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. 
You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. 
"Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 
564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 
628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | <one line to give the program's name and a brief idea of what it does.> 635 | Copyright (C) <year> <name of author> 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see <https://www.gnu.org/licenses/>. 649 | 650 | Also add information on how to contact you by electronic and paper mail. 651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | <program> Copyright (C) <year> <name of author> 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | <https://www.gnu.org/licenses/>. 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | <https://www.gnu.org/licenses/why-not-lgpl.html>. 675 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI-IDM-VTON 2 | ComfyUI adaptation of [IDM-VTON](https://github.com/yisol/IDM-VTON). 3 | 4 | ![workflow](workflow.png) 5 | 6 | ## Installation 7 | 8 | :warning: The current implementation requires a GPU with at least 16GB of VRAM :warning: 9 | 10 | ### Using ComfyUI Manager: 11 | 12 | - In [ComfyUI Manager](https://github.com/ltdrdata/ComfyUI-Manager), look for ```ComfyUI-IDM-VTON```, and be sure the author is ```TemryL```. Install it. 13 | 14 | ### Manually: 15 | - Clone this repo into the `custom_nodes` folder of ComfyUI and install the dependencies. 16 | ```bash 17 | cd custom_nodes 18 | git clone https://github.com/TemryL/ComfyUI-IDM-VTON.git 19 | cd ComfyUI-IDM-VTON 20 | python install.py 21 | ``` 22 | 23 | Model weights from [yisol/IDM-VTON](https://huggingface.co/yisol/IDM-VTON) on [HuggingFace](https://huggingface.co) will be downloaded into the [models](models/) folder of this repository.
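If the automatic download fails or needs to be re-run, the snippet below is a minimal sketch of what `install.py` does: it resolves ComfyUI's models directory via `folder_paths` and pulls the weights from the HuggingFace Hub with `huggingface_hub`. It assumes it is executed with ComfyUI on the Python path.

```python
# Minimal sketch of the weight download performed by install.py.
# Assumes ComfyUI's folder_paths module is importable (i.e. run from the ComfyUI environment).
import os

from huggingface_hub import snapshot_download
from folder_paths import models_dir  # ComfyUI's models directory

weights_path = os.path.join(models_dir, "IDM-VTON")
os.makedirs(weights_path, exist_ok=True)

# Fetch yisol/IDM-VTON from the HuggingFace Hub into the target folder.
snapshot_download(repo_id="yisol/IDM-VTON", local_dir=weights_path, local_dir_use_symlinks=False)
```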
24 | 25 | ## Mask Generation 26 | The workflow provided above uses [ComfyUI Segment Anything](https://github.com/storyicon/comfyui_segment_anything) to generate the image mask. 27 | 28 | ## DensePose Estimation 29 | DensePose estimation is performed using [ComfyUI's ControlNet Auxiliary Preprocessors](https://github.com/Fannovel16/comfyui_controlnet_aux). 30 | 31 | ## :star: Star Us! 32 | If you find this project useful, please consider giving it a star on GitHub. This helps the project to gain visibility and encourages more contributors to join in. Thank you for your support! 33 | 34 | ## Contribute 35 | Thanks for your interest in contributing to the source code! We welcome help from anyone and appreciate every contribution, no matter how small! 36 | 37 | If you're ready to contribute, please create a fork, make your changes, commit them, and then submit a pull request for review. We'll consider it for integration into the main code base. 38 | 39 | ## Credits 40 | - [ComfyUI](https://github.com/comfyanonymous/ComfyUI) 41 | - [IDM-VTON](https://github.com/yisol/IDM-VTON) 42 | - [ComfyUI Segment Anything](https://github.com/storyicon/comfyui_segment_anything) 43 | - [ComfyUI's ControlNet Auxiliary Preprocessors](https://github.com/Fannovel16/comfyui_controlnet_aux) 44 | 45 | ## License 46 | Original [IDM-VTON](https://github.com/yisol/IDM-VTON) source code under [CC BY-NC-SA 4.0 license](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 47 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .install import * 2 | from .src.nodes_mappings import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 3 | 4 | 5 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] -------------------------------------------------------------------------------- /install.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import subprocess 4 | from huggingface_hub import snapshot_download 5 | 6 | sys.path.append("../../") 7 | from folder_paths import models_dir 8 | 9 | CUSTOM_NODES_PATH = os.path.dirname(os.path.abspath(__file__)) 10 | WEIGHTS_PATH = os.path.join(models_dir, "IDM-VTON") 11 | HF_REPO_ID = "yisol/IDM-VTON" 12 | 13 | os.makedirs(WEIGHTS_PATH, exist_ok=True) 14 | 15 | def build_pip_install_cmds(args): 16 | if "python_embeded" in sys.executable or "python_embedded" in sys.executable: 17 | return [sys.executable, '-s', '-m', 'pip', 'install'] + args 18 | else: 19 | return [sys.executable, '-m', 'pip', 'install'] + args 20 | 21 | def ensure_package(): 22 | cmds = build_pip_install_cmds(['-r', 'requirements.txt']) 23 | subprocess.run(cmds, cwd=CUSTOM_NODES_PATH) 24 | 25 | 26 | if __name__ == "__main__": 27 | ensure_package() 28 | snapshot_download(repo_id=HF_REPO_ID, local_dir=WEIGHTS_PATH, local_dir_use_symlinks=False) 29 | -------------------------------------------------------------------------------- /models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TemryL/ComfyUI-IDM-VTON/5a8334c58c390381e31a8023cb7ba398ade40b39/models/.gitkeep -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui-idm-vton" 3 | description = "ComfyUI adaptation of 
[a/IDM-VTON](https://github.com/yisol/IDM-VTON) for virtual try-on." 4 | version = "1.0.1" 5 | license = { file = "LICENSE" } 6 | dependencies = ["torch", "torchvision", "torchaudio", "accelerate==0.30.0", "torchmetrics==1.4.0", "tqdm==4.66.4", "transformers==4.40.2", "diffusers==0.27.2", "einops==0.8.0", "bitsandbytes==0.42", "scipy==1.13.0", "opencv-python"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/TemryL/ComfyUI-IDM-VTON" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "temryl" 14 | DisplayName = "ComfyUI-IDM-VTON" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | torchaudio 4 | accelerate==0.30.0 5 | torchmetrics==1.4.0 6 | tqdm==4.66.4 7 | transformers==4.40.2 8 | diffusers==0.27.2 9 | einops==0.8.0 10 | bitsandbytes==0.42 11 | scipy==1.13.0 12 | opencv-python -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TemryL/ComfyUI-IDM-VTON/5a8334c58c390381e31a8023cb7ba398ade40b39/src/__init__.py -------------------------------------------------------------------------------- /src/idm_vton/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TemryL/ComfyUI-IDM-VTON/5a8334c58c390381e31a8023cb7ba398ade40b39/src/idm_vton/__init__.py -------------------------------------------------------------------------------- /src/idm_vton/attentionhacked_garmnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
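# Note: this module is adapted ("hacked") from diffusers' models/attention.py.
# The BasicTransformerBlock below additionally records its pre-attention normalized hidden
# states in a `garment_features` list (see the forward pass), so the garment-side UNet can
# expose these reference features to the rest of the IDM-VTON try-on pipeline.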
14 | from typing import Any, Dict, Optional 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | from torch import nn 19 | 20 | from diffusers.utils import USE_PEFT_BACKEND 21 | from diffusers.utils.torch_utils import maybe_allow_in_graph 22 | from diffusers.models.activations import GEGLU, GELU, ApproximateGELU 23 | from diffusers.models.attention_processor import Attention 24 | from diffusers.models.embeddings import SinusoidalPositionalEmbedding 25 | from diffusers.models.lora import LoRACompatibleLinear 26 | from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm 27 | 28 | 29 | def _chunked_feed_forward( 30 | ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None 31 | ): 32 | # "feed_forward_chunk_size" can be used to save memory 33 | if hidden_states.shape[chunk_dim] % chunk_size != 0: 34 | raise ValueError( 35 | f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." 36 | ) 37 | 38 | num_chunks = hidden_states.shape[chunk_dim] // chunk_size 39 | if lora_scale is None: 40 | ff_output = torch.cat( 41 | [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], 42 | dim=chunk_dim, 43 | ) 44 | else: 45 | # TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete 46 | ff_output = torch.cat( 47 | [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], 48 | dim=chunk_dim, 49 | ) 50 | 51 | return ff_output 52 | 53 | 54 | @maybe_allow_in_graph 55 | class GatedSelfAttentionDense(nn.Module): 56 | r""" 57 | A gated self-attention dense layer that combines visual features and object features. 58 | 59 | Parameters: 60 | query_dim (`int`): The number of channels in the query. 61 | context_dim (`int`): The number of channels in the context. 62 | n_heads (`int`): The number of heads to use for attention. 63 | d_head (`int`): The number of channels in each head. 64 | """ 65 | 66 | def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): 67 | super().__init__() 68 | 69 | # we need a linear projection since we need cat visual feature and obj feature 70 | self.linear = nn.Linear(context_dim, query_dim) 71 | 72 | self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) 73 | self.ff = FeedForward(query_dim, activation_fn="geglu") 74 | 75 | self.norm1 = nn.LayerNorm(query_dim) 76 | self.norm2 = nn.LayerNorm(query_dim) 77 | 78 | self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0))) 79 | self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0))) 80 | 81 | self.enabled = True 82 | 83 | def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: 84 | if not self.enabled: 85 | return x 86 | 87 | n_visual = x.shape[1] 88 | objs = self.linear(objs) 89 | 90 | x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :] 91 | x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) 92 | 93 | return x 94 | 95 | 96 | @maybe_allow_in_graph 97 | class BasicTransformerBlock(nn.Module): 98 | r""" 99 | A basic Transformer block. 100 | 101 | Parameters: 102 | dim (`int`): The number of channels in the input and output. 103 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 
104 | attention_head_dim (`int`): The number of channels in each head. 105 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 106 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 107 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 108 | num_embeds_ada_norm (: 109 | obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. 110 | attention_bias (: 111 | obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 112 | only_cross_attention (`bool`, *optional*): 113 | Whether to use only cross-attention layers. In this case two cross attention layers are used. 114 | double_self_attention (`bool`, *optional*): 115 | Whether to use two self-attention layers. In this case no cross attention layers are used. 116 | upcast_attention (`bool`, *optional*): 117 | Whether to upcast the attention computation to float32. This is useful for mixed precision training. 118 | norm_elementwise_affine (`bool`, *optional*, defaults to `True`): 119 | Whether to use learnable elementwise affine parameters for normalization. 120 | norm_type (`str`, *optional*, defaults to `"layer_norm"`): 121 | The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. 122 | final_dropout (`bool` *optional*, defaults to False): 123 | Whether to apply a final dropout after the last feed-forward layer. 124 | attention_type (`str`, *optional*, defaults to `"default"`): 125 | The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. 126 | positional_embeddings (`str`, *optional*, defaults to `None`): 127 | The type of positional embeddings to apply to. 128 | num_positional_embeddings (`int`, *optional*, defaults to `None`): 129 | The maximum number of positional embeddings to apply. 
130 | """ 131 | 132 | def __init__( 133 | self, 134 | dim: int, 135 | num_attention_heads: int, 136 | attention_head_dim: int, 137 | dropout=0.0, 138 | cross_attention_dim: Optional[int] = None, 139 | activation_fn: str = "geglu", 140 | num_embeds_ada_norm: Optional[int] = None, 141 | attention_bias: bool = False, 142 | only_cross_attention: bool = False, 143 | double_self_attention: bool = False, 144 | upcast_attention: bool = False, 145 | norm_elementwise_affine: bool = True, 146 | norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' 147 | norm_eps: float = 1e-5, 148 | final_dropout: bool = False, 149 | attention_type: str = "default", 150 | positional_embeddings: Optional[str] = None, 151 | num_positional_embeddings: Optional[int] = None, 152 | ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, 153 | ada_norm_bias: Optional[int] = None, 154 | ff_inner_dim: Optional[int] = None, 155 | ff_bias: bool = True, 156 | attention_out_bias: bool = True, 157 | ): 158 | super().__init__() 159 | self.only_cross_attention = only_cross_attention 160 | 161 | self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" 162 | self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" 163 | self.use_ada_layer_norm_single = norm_type == "ada_norm_single" 164 | self.use_layer_norm = norm_type == "layer_norm" 165 | self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" 166 | 167 | if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: 168 | raise ValueError( 169 | f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" 170 | f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." 171 | ) 172 | 173 | if positional_embeddings and (num_positional_embeddings is None): 174 | raise ValueError( 175 | "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." 176 | ) 177 | 178 | if positional_embeddings == "sinusoidal": 179 | self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) 180 | else: 181 | self.pos_embed = None 182 | 183 | # Define 3 blocks. Each block has its own normalization layer. 184 | # 1. Self-Attn 185 | if self.use_ada_layer_norm: 186 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) 187 | elif self.use_ada_layer_norm_zero: 188 | self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) 189 | elif self.use_ada_layer_norm_continuous: 190 | self.norm1 = AdaLayerNormContinuous( 191 | dim, 192 | ada_norm_continous_conditioning_embedding_dim, 193 | norm_elementwise_affine, 194 | norm_eps, 195 | ada_norm_bias, 196 | "rms_norm", 197 | ) 198 | else: 199 | self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) 200 | 201 | self.attn1 = Attention( 202 | query_dim=dim, 203 | heads=num_attention_heads, 204 | dim_head=attention_head_dim, 205 | dropout=dropout, 206 | bias=attention_bias, 207 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 208 | upcast_attention=upcast_attention, 209 | out_bias=attention_out_bias, 210 | ) 211 | 212 | # 2. Cross-Attn 213 | if cross_attention_dim is not None or double_self_attention: 214 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. 215 | # I.e. 
the number of returned modulation chunks from AdaLayerZero would not make sense if returned during 216 | # the second cross attention block. 217 | if self.use_ada_layer_norm: 218 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) 219 | elif self.use_ada_layer_norm_continuous: 220 | self.norm2 = AdaLayerNormContinuous( 221 | dim, 222 | ada_norm_continous_conditioning_embedding_dim, 223 | norm_elementwise_affine, 224 | norm_eps, 225 | ada_norm_bias, 226 | "rms_norm", 227 | ) 228 | else: 229 | self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) 230 | 231 | self.attn2 = Attention( 232 | query_dim=dim, 233 | cross_attention_dim=cross_attention_dim if not double_self_attention else None, 234 | heads=num_attention_heads, 235 | dim_head=attention_head_dim, 236 | dropout=dropout, 237 | bias=attention_bias, 238 | upcast_attention=upcast_attention, 239 | out_bias=attention_out_bias, 240 | ) # is self-attn if encoder_hidden_states is none 241 | else: 242 | self.norm2 = None 243 | self.attn2 = None 244 | 245 | # 3. Feed-forward 246 | if self.use_ada_layer_norm_continuous: 247 | self.norm3 = AdaLayerNormContinuous( 248 | dim, 249 | ada_norm_continous_conditioning_embedding_dim, 250 | norm_elementwise_affine, 251 | norm_eps, 252 | ada_norm_bias, 253 | "layer_norm", 254 | ) 255 | elif not self.use_ada_layer_norm_single: 256 | self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) 257 | 258 | self.ff = FeedForward( 259 | dim, 260 | dropout=dropout, 261 | activation_fn=activation_fn, 262 | final_dropout=final_dropout, 263 | inner_dim=ff_inner_dim, 264 | bias=ff_bias, 265 | ) 266 | 267 | # 4. Fuser 268 | if attention_type == "gated" or attention_type == "gated-text-image": 269 | self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) 270 | 271 | # 5. Scale-shift for PixArt-Alpha. 272 | if self.use_ada_layer_norm_single: 273 | self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) 274 | 275 | # let chunk size default to None 276 | self._chunk_size = None 277 | self._chunk_dim = 0 278 | 279 | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): 280 | # Sets chunk feed-forward 281 | self._chunk_size = chunk_size 282 | self._chunk_dim = dim 283 | 284 | def forward( 285 | self, 286 | hidden_states: torch.FloatTensor, 287 | attention_mask: Optional[torch.FloatTensor] = None, 288 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 289 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 290 | timestep: Optional[torch.LongTensor] = None, 291 | cross_attention_kwargs: Dict[str, Any] = None, 292 | class_labels: Optional[torch.LongTensor] = None, 293 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 294 | ) -> torch.FloatTensor: 295 | # Notice that normalization is always applied before the real computation in the following blocks. 296 | # 0. 
Self-Attention 297 | batch_size = hidden_states.shape[0] 298 | if self.use_ada_layer_norm: 299 | norm_hidden_states = self.norm1(hidden_states, timestep) 300 | elif self.use_ada_layer_norm_zero: 301 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 302 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype 303 | ) 304 | elif self.use_layer_norm: 305 | norm_hidden_states = self.norm1(hidden_states) 306 | elif self.use_ada_layer_norm_continuous: 307 | norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) 308 | elif self.use_ada_layer_norm_single: 309 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( 310 | self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) 311 | ).chunk(6, dim=1) 312 | norm_hidden_states = self.norm1(hidden_states) 313 | norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa 314 | norm_hidden_states = norm_hidden_states.squeeze(1) 315 | else: 316 | raise ValueError("Incorrect norm used") 317 | 318 | if self.pos_embed is not None: 319 | norm_hidden_states = self.pos_embed(norm_hidden_states) 320 | 321 | garment_features = [] 322 | garment_features.append(norm_hidden_states) 323 | 324 | # 1. Retrieve lora scale. 325 | lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 326 | 327 | # 2. Prepare GLIGEN inputs 328 | cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} 329 | gligen_kwargs = cross_attention_kwargs.pop("gligen", None) 330 | 331 | attn_output = self.attn1( 332 | norm_hidden_states, 333 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, 334 | attention_mask=attention_mask, 335 | **cross_attention_kwargs, 336 | ) 337 | if self.use_ada_layer_norm_zero: 338 | attn_output = gate_msa.unsqueeze(1) * attn_output 339 | elif self.use_ada_layer_norm_single: 340 | attn_output = gate_msa * attn_output 341 | 342 | hidden_states = attn_output + hidden_states 343 | if hidden_states.ndim == 4: 344 | hidden_states = hidden_states.squeeze(1) 345 | 346 | # 2.5 GLIGEN Control 347 | if gligen_kwargs is not None: 348 | hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) 349 | 350 | # 3. Cross-Attention 351 | if self.attn2 is not None: 352 | if self.use_ada_layer_norm: 353 | norm_hidden_states = self.norm2(hidden_states, timestep) 354 | elif self.use_ada_layer_norm_zero or self.use_layer_norm: 355 | norm_hidden_states = self.norm2(hidden_states) 356 | elif self.use_ada_layer_norm_single: 357 | # For PixArt norm2 isn't applied here: 358 | # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 359 | norm_hidden_states = hidden_states 360 | elif self.use_ada_layer_norm_continuous: 361 | norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) 362 | else: 363 | raise ValueError("Incorrect norm") 364 | 365 | if self.pos_embed is not None and self.use_ada_layer_norm_single is False: 366 | norm_hidden_states = self.pos_embed(norm_hidden_states) 367 | 368 | attn_output = self.attn2( 369 | norm_hidden_states, 370 | encoder_hidden_states=encoder_hidden_states, 371 | attention_mask=encoder_attention_mask, 372 | **cross_attention_kwargs, 373 | ) 374 | hidden_states = attn_output + hidden_states 375 | 376 | # 4. 
Feed-forward 377 | if self.use_ada_layer_norm_continuous: 378 | norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) 379 | elif not self.use_ada_layer_norm_single: 380 | norm_hidden_states = self.norm3(hidden_states) 381 | 382 | if self.use_ada_layer_norm_zero: 383 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 384 | 385 | if self.use_ada_layer_norm_single: 386 | norm_hidden_states = self.norm2(hidden_states) 387 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp 388 | 389 | if self._chunk_size is not None: 390 | # "feed_forward_chunk_size" can be used to save memory 391 | ff_output = _chunked_feed_forward( 392 | self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale 393 | ) 394 | else: 395 | ff_output = self.ff(norm_hidden_states, scale=lora_scale) 396 | 397 | if self.use_ada_layer_norm_zero: 398 | ff_output = gate_mlp.unsqueeze(1) * ff_output 399 | elif self.use_ada_layer_norm_single: 400 | ff_output = gate_mlp * ff_output 401 | 402 | hidden_states = ff_output + hidden_states 403 | if hidden_states.ndim == 4: 404 | hidden_states = hidden_states.squeeze(1) 405 | 406 | return hidden_states, garment_features 407 | 408 | 409 | @maybe_allow_in_graph 410 | class TemporalBasicTransformerBlock(nn.Module): 411 | r""" 412 | A basic Transformer block for video like data. 413 | 414 | Parameters: 415 | dim (`int`): The number of channels in the input and output. 416 | time_mix_inner_dim (`int`): The number of channels for temporal attention. 417 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 418 | attention_head_dim (`int`): The number of channels in each head. 419 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 420 | """ 421 | 422 | def __init__( 423 | self, 424 | dim: int, 425 | time_mix_inner_dim: int, 426 | num_attention_heads: int, 427 | attention_head_dim: int, 428 | cross_attention_dim: Optional[int] = None, 429 | ): 430 | super().__init__() 431 | self.is_res = dim == time_mix_inner_dim 432 | 433 | self.norm_in = nn.LayerNorm(dim) 434 | 435 | # Define 3 blocks. Each block has its own normalization layer. 436 | # 1. Self-Attn 437 | self.norm_in = nn.LayerNorm(dim) 438 | self.ff_in = FeedForward( 439 | dim, 440 | dim_out=time_mix_inner_dim, 441 | activation_fn="geglu", 442 | ) 443 | 444 | self.norm1 = nn.LayerNorm(time_mix_inner_dim) 445 | self.attn1 = Attention( 446 | query_dim=time_mix_inner_dim, 447 | heads=num_attention_heads, 448 | dim_head=attention_head_dim, 449 | cross_attention_dim=None, 450 | ) 451 | 452 | # 2. Cross-Attn 453 | if cross_attention_dim is not None: 454 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. 455 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during 456 | # the second cross attention block. 457 | self.norm2 = nn.LayerNorm(time_mix_inner_dim) 458 | self.attn2 = Attention( 459 | query_dim=time_mix_inner_dim, 460 | cross_attention_dim=cross_attention_dim, 461 | heads=num_attention_heads, 462 | dim_head=attention_head_dim, 463 | ) # is self-attn if encoder_hidden_states is none 464 | else: 465 | self.norm2 = None 466 | self.attn2 = None 467 | 468 | # 3. 
Feed-forward 469 | self.norm3 = nn.LayerNorm(time_mix_inner_dim) 470 | self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu") 471 | 472 | # let chunk size default to None 473 | self._chunk_size = None 474 | self._chunk_dim = None 475 | 476 | def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs): 477 | # Sets chunk feed-forward 478 | self._chunk_size = chunk_size 479 | # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off 480 | self._chunk_dim = 1 481 | 482 | def forward( 483 | self, 484 | hidden_states: torch.FloatTensor, 485 | num_frames: int, 486 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 487 | ) -> torch.FloatTensor: 488 | # Notice that normalization is always applied before the real computation in the following blocks. 489 | # 0. Self-Attention 490 | batch_size = hidden_states.shape[0] 491 | 492 | batch_frames, seq_length, channels = hidden_states.shape 493 | batch_size = batch_frames // num_frames 494 | 495 | hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels) 496 | hidden_states = hidden_states.permute(0, 2, 1, 3) 497 | hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels) 498 | 499 | residual = hidden_states 500 | hidden_states = self.norm_in(hidden_states) 501 | 502 | if self._chunk_size is not None: 503 | hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size) 504 | else: 505 | hidden_states = self.ff_in(hidden_states) 506 | 507 | if self.is_res: 508 | hidden_states = hidden_states + residual 509 | 510 | norm_hidden_states = self.norm1(hidden_states) 511 | attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None) 512 | hidden_states = attn_output + hidden_states 513 | 514 | # 3. Cross-Attention 515 | if self.attn2 is not None: 516 | norm_hidden_states = self.norm2(hidden_states) 517 | attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states) 518 | hidden_states = attn_output + hidden_states 519 | 520 | # 4. 
Feed-forward 521 | norm_hidden_states = self.norm3(hidden_states) 522 | 523 | if self._chunk_size is not None: 524 | ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) 525 | else: 526 | ff_output = self.ff(norm_hidden_states) 527 | 528 | if self.is_res: 529 | hidden_states = ff_output + hidden_states 530 | else: 531 | hidden_states = ff_output 532 | 533 | hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels) 534 | hidden_states = hidden_states.permute(0, 2, 1, 3) 535 | hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels) 536 | 537 | return hidden_states 538 | 539 | 540 | class SkipFFTransformerBlock(nn.Module): 541 | def __init__( 542 | self, 543 | dim: int, 544 | num_attention_heads: int, 545 | attention_head_dim: int, 546 | kv_input_dim: int, 547 | kv_input_dim_proj_use_bias: bool, 548 | dropout=0.0, 549 | cross_attention_dim: Optional[int] = None, 550 | attention_bias: bool = False, 551 | attention_out_bias: bool = True, 552 | ): 553 | super().__init__() 554 | if kv_input_dim != dim: 555 | self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias) 556 | else: 557 | self.kv_mapper = None 558 | 559 | self.norm1 = RMSNorm(dim, 1e-06) 560 | 561 | self.attn1 = Attention( 562 | query_dim=dim, 563 | heads=num_attention_heads, 564 | dim_head=attention_head_dim, 565 | dropout=dropout, 566 | bias=attention_bias, 567 | cross_attention_dim=cross_attention_dim, 568 | out_bias=attention_out_bias, 569 | ) 570 | 571 | self.norm2 = RMSNorm(dim, 1e-06) 572 | 573 | self.attn2 = Attention( 574 | query_dim=dim, 575 | cross_attention_dim=cross_attention_dim, 576 | heads=num_attention_heads, 577 | dim_head=attention_head_dim, 578 | dropout=dropout, 579 | bias=attention_bias, 580 | out_bias=attention_out_bias, 581 | ) 582 | 583 | def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs): 584 | cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} 585 | 586 | if self.kv_mapper is not None: 587 | encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states)) 588 | 589 | norm_hidden_states = self.norm1(hidden_states) 590 | 591 | attn_output = self.attn1( 592 | norm_hidden_states, 593 | encoder_hidden_states=encoder_hidden_states, 594 | **cross_attention_kwargs, 595 | ) 596 | 597 | hidden_states = attn_output + hidden_states 598 | 599 | norm_hidden_states = self.norm2(hidden_states) 600 | 601 | attn_output = self.attn2( 602 | norm_hidden_states, 603 | encoder_hidden_states=encoder_hidden_states, 604 | **cross_attention_kwargs, 605 | ) 606 | 607 | hidden_states = attn_output + hidden_states 608 | 609 | return hidden_states 610 | 611 | 612 | class FeedForward(nn.Module): 613 | r""" 614 | A feed-forward layer. 615 | 616 | Parameters: 617 | dim (`int`): The number of channels in the input. 618 | dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. 619 | mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. 620 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 621 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 622 | final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. 623 | bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
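    Example:
        A minimal illustrative sketch (the dimensions are arbitrary placeholders and the class is assumed to
        be imported from this module). With the defaults, the input is projected to `mult * dim` channels
        through GEGLU and back down to `dim`, so the output shape matches the input:

        ```py
        >>> import torch
        >>> ff = FeedForward(dim=320)
        >>> ff(torch.randn(2, 77, 320)).shape
        torch.Size([2, 77, 320])
        ```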
624 | """ 625 | 626 | def __init__( 627 | self, 628 | dim: int, 629 | dim_out: Optional[int] = None, 630 | mult: int = 4, 631 | dropout: float = 0.0, 632 | activation_fn: str = "geglu", 633 | final_dropout: bool = False, 634 | inner_dim=None, 635 | bias: bool = True, 636 | ): 637 | super().__init__() 638 | if inner_dim is None: 639 | inner_dim = int(dim * mult) 640 | dim_out = dim_out if dim_out is not None else dim 641 | linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear 642 | 643 | if activation_fn == "gelu": 644 | act_fn = GELU(dim, inner_dim, bias=bias) 645 | if activation_fn == "gelu-approximate": 646 | act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) 647 | elif activation_fn == "geglu": 648 | act_fn = GEGLU(dim, inner_dim, bias=bias) 649 | elif activation_fn == "geglu-approximate": 650 | act_fn = ApproximateGELU(dim, inner_dim, bias=bias) 651 | 652 | self.net = nn.ModuleList([]) 653 | # project in 654 | self.net.append(act_fn) 655 | # project dropout 656 | self.net.append(nn.Dropout(dropout)) 657 | # project out 658 | self.net.append(linear_cls(inner_dim, dim_out, bias=bias)) 659 | # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout 660 | if final_dropout: 661 | self.net.append(nn.Dropout(dropout)) 662 | 663 | def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: 664 | compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear) 665 | for module in self.net: 666 | if isinstance(module, compatible_cls): 667 | hidden_states = module(hidden_states, scale) 668 | else: 669 | hidden_states = module(hidden_states) 670 | return hidden_states 671 | -------------------------------------------------------------------------------- /src/idm_vton/attentionhacked_tryon.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from typing import Any, Dict, Optional 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | from torch import nn 19 | 20 | from diffusers.utils import USE_PEFT_BACKEND 21 | from diffusers.utils.torch_utils import maybe_allow_in_graph 22 | from diffusers.models.activations import GEGLU, GELU, ApproximateGELU 23 | from diffusers.models.attention_processor import Attention 24 | from diffusers.models.embeddings import SinusoidalPositionalEmbedding 25 | from diffusers.models.lora import LoRACompatibleLinear 26 | from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm 27 | 28 | 29 | def _chunked_feed_forward( 30 | ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None 31 | ): 32 | # "feed_forward_chunk_size" can be used to save memory 33 | if hidden_states.shape[chunk_dim] % chunk_size != 0: 34 | raise ValueError( 35 | f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." 36 | ) 37 | 38 | num_chunks = hidden_states.shape[chunk_dim] // chunk_size 39 | if lora_scale is None: 40 | ff_output = torch.cat( 41 | [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], 42 | dim=chunk_dim, 43 | ) 44 | else: 45 | # TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete 46 | ff_output = torch.cat( 47 | [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], 48 | dim=chunk_dim, 49 | ) 50 | 51 | return ff_output 52 | 53 | 54 | @maybe_allow_in_graph 55 | class GatedSelfAttentionDense(nn.Module): 56 | r""" 57 | A gated self-attention dense layer that combines visual features and object features. 58 | 59 | Parameters: 60 | query_dim (`int`): The number of channels in the query. 61 | context_dim (`int`): The number of channels in the context. 62 | n_heads (`int`): The number of heads to use for attention. 63 | d_head (`int`): The number of channels in each head. 64 | """ 65 | 66 | def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): 67 | super().__init__() 68 | 69 | # we need a linear projection since we need cat visual feature and obj feature 70 | self.linear = nn.Linear(context_dim, query_dim) 71 | 72 | self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) 73 | self.ff = FeedForward(query_dim, activation_fn="geglu") 74 | 75 | self.norm1 = nn.LayerNorm(query_dim) 76 | self.norm2 = nn.LayerNorm(query_dim) 77 | 78 | self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0))) 79 | self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0))) 80 | 81 | self.enabled = True 82 | 83 | def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: 84 | if not self.enabled: 85 | return x 86 | 87 | n_visual = x.shape[1] 88 | objs = self.linear(objs) 89 | 90 | x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :] 91 | x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) 92 | 93 | return x 94 | 95 | 96 | @maybe_allow_in_graph 97 | class BasicTransformerBlock(nn.Module): 98 | r""" 99 | A basic Transformer block. 100 | 101 | Parameters: 102 | dim (`int`): The number of channels in the input and output. 103 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 
104 | attention_head_dim (`int`): The number of channels in each head. 105 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 106 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 107 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 108 | num_embeds_ada_norm (`int`, *optional*): 109 | The number of diffusion steps used during training. See `Transformer2DModel`. 110 | attention_bias (`bool`, *optional*, defaults to `False`): 111 | Configure if the attention layers should contain a bias parameter. 112 | only_cross_attention (`bool`, *optional*): 113 | Whether to use only cross-attention layers. In this case two cross attention layers are used. 114 | double_self_attention (`bool`, *optional*): 115 | Whether to use two self-attention layers. In this case no cross attention layers are used. 116 | upcast_attention (`bool`, *optional*): 117 | Whether to upcast the attention computation to float32. This is useful for mixed precision training. 118 | norm_elementwise_affine (`bool`, *optional*, defaults to `True`): 119 | Whether to use learnable elementwise affine parameters for normalization. 120 | norm_type (`str`, *optional*, defaults to `"layer_norm"`): 121 | The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"`, `"ada_norm_zero"`, `"ada_norm_single"` or `"ada_norm_continuous"`. 122 | final_dropout (`bool`, *optional*, defaults to `False`): 123 | Whether to apply a final dropout after the last feed-forward layer. 124 | attention_type (`str`, *optional*, defaults to `"default"`): 125 | The type of attention to use. Can be `"default"`, `"gated"` or `"gated-text-image"`. 126 | positional_embeddings (`str`, *optional*, defaults to `None`): 127 | The type of positional embeddings to apply. 128 | num_positional_embeddings (`int`, *optional*, defaults to `None`): 129 | The maximum number of positional embeddings to apply.
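    Example:
        An illustrative sketch of the garment-feature injection performed by this block. The dimensions are
        arbitrary placeholders and the class is assumed to be imported from this module; in practice the
        features come from the matching block in `attentionhacked_garmnet.py` and the block is built by the
        hacked try-on transformer/UNet wrappers. The forward pass returns the hidden states together with
        the advanced `curr_garment_feat_idx`:

        ```py
        >>> import torch
        >>> block = BasicTransformerBlock(dim=320, num_attention_heads=8, attention_head_dim=40)
        >>> person_tokens = torch.randn(1, 64, 320)
        >>> garment_features = [torch.randn(1, 64, 320)]  # e.g. collected during the garment-net pass
        >>> out, next_idx = block(person_tokens, garment_features=garment_features, curr_garment_feat_idx=0)
        >>> out.shape, next_idx
        (torch.Size([1, 64, 320]), 1)
        ```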
130 | """ 131 | 132 | def __init__( 133 | self, 134 | dim: int, 135 | num_attention_heads: int, 136 | attention_head_dim: int, 137 | dropout=0.0, 138 | cross_attention_dim: Optional[int] = None, 139 | activation_fn: str = "geglu", 140 | num_embeds_ada_norm: Optional[int] = None, 141 | attention_bias: bool = False, 142 | only_cross_attention: bool = False, 143 | double_self_attention: bool = False, 144 | upcast_attention: bool = False, 145 | norm_elementwise_affine: bool = True, 146 | norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' 147 | norm_eps: float = 1e-5, 148 | final_dropout: bool = False, 149 | attention_type: str = "default", 150 | positional_embeddings: Optional[str] = None, 151 | num_positional_embeddings: Optional[int] = None, 152 | ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, 153 | ada_norm_bias: Optional[int] = None, 154 | ff_inner_dim: Optional[int] = None, 155 | ff_bias: bool = True, 156 | attention_out_bias: bool = True, 157 | ): 158 | super().__init__() 159 | self.only_cross_attention = only_cross_attention 160 | 161 | self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" 162 | self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" 163 | self.use_ada_layer_norm_single = norm_type == "ada_norm_single" 164 | self.use_layer_norm = norm_type == "layer_norm" 165 | self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" 166 | 167 | if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: 168 | raise ValueError( 169 | f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" 170 | f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." 171 | ) 172 | 173 | if positional_embeddings and (num_positional_embeddings is None): 174 | raise ValueError( 175 | "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." 176 | ) 177 | 178 | if positional_embeddings == "sinusoidal": 179 | self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) 180 | else: 181 | self.pos_embed = None 182 | 183 | # Define 3 blocks. Each block has its own normalization layer. 184 | # 1. Self-Attn 185 | if self.use_ada_layer_norm: 186 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) 187 | elif self.use_ada_layer_norm_zero: 188 | self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) 189 | elif self.use_ada_layer_norm_continuous: 190 | self.norm1 = AdaLayerNormContinuous( 191 | dim, 192 | ada_norm_continous_conditioning_embedding_dim, 193 | norm_elementwise_affine, 194 | norm_eps, 195 | ada_norm_bias, 196 | "rms_norm", 197 | ) 198 | else: 199 | self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) 200 | 201 | self.attn1 = Attention( 202 | query_dim=dim, 203 | heads=num_attention_heads, 204 | dim_head=attention_head_dim, 205 | dropout=dropout, 206 | bias=attention_bias, 207 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 208 | upcast_attention=upcast_attention, 209 | out_bias=attention_out_bias, 210 | ) 211 | 212 | # 2. Cross-Attn 213 | if cross_attention_dim is not None or double_self_attention: 214 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. 215 | # I.e. 
the number of returned modulation chunks from AdaLayerZero would not make sense if returned during 216 | # the second cross attention block. 217 | if self.use_ada_layer_norm: 218 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) 219 | elif self.use_ada_layer_norm_continuous: 220 | self.norm2 = AdaLayerNormContinuous( 221 | dim, 222 | ada_norm_continous_conditioning_embedding_dim, 223 | norm_elementwise_affine, 224 | norm_eps, 225 | ada_norm_bias, 226 | "rms_norm", 227 | ) 228 | else: 229 | self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) 230 | 231 | self.attn2 = Attention( 232 | query_dim=dim, 233 | cross_attention_dim=cross_attention_dim if not double_self_attention else None, 234 | heads=num_attention_heads, 235 | dim_head=attention_head_dim, 236 | dropout=dropout, 237 | bias=attention_bias, 238 | upcast_attention=upcast_attention, 239 | out_bias=attention_out_bias, 240 | ) # is self-attn if encoder_hidden_states is none 241 | else: 242 | self.norm2 = None 243 | self.attn2 = None 244 | 245 | # 3. Feed-forward 246 | if self.use_ada_layer_norm_continuous: 247 | self.norm3 = AdaLayerNormContinuous( 248 | dim, 249 | ada_norm_continous_conditioning_embedding_dim, 250 | norm_elementwise_affine, 251 | norm_eps, 252 | ada_norm_bias, 253 | "layer_norm", 254 | ) 255 | elif not self.use_ada_layer_norm_single: 256 | self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) 257 | 258 | self.ff = FeedForward( 259 | dim, 260 | dropout=dropout, 261 | activation_fn=activation_fn, 262 | final_dropout=final_dropout, 263 | inner_dim=ff_inner_dim, 264 | bias=ff_bias, 265 | ) 266 | 267 | # 4. Fuser 268 | if attention_type == "gated" or attention_type == "gated-text-image": 269 | self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) 270 | 271 | # 5. Scale-shift for PixArt-Alpha. 272 | if self.use_ada_layer_norm_single: 273 | self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) 274 | 275 | # let chunk size default to None 276 | self._chunk_size = None 277 | self._chunk_dim = 0 278 | 279 | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): 280 | # Sets chunk feed-forward 281 | self._chunk_size = chunk_size 282 | self._chunk_dim = dim 283 | 284 | def forward( 285 | self, 286 | hidden_states: torch.FloatTensor, 287 | attention_mask: Optional[torch.FloatTensor] = None, 288 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 289 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 290 | timestep: Optional[torch.LongTensor] = None, 291 | cross_attention_kwargs: Dict[str, Any] = None, 292 | class_labels: Optional[torch.LongTensor] = None, 293 | garment_features=None, 294 | curr_garment_feat_idx=0, 295 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 296 | ) -> torch.FloatTensor: 297 | # Notice that normalization is always applied before the real computation in the following blocks. 298 | # 0. 
Self-Attention 299 | batch_size = hidden_states.shape[0] 300 | 301 | 302 | 303 | if self.use_ada_layer_norm: 304 | norm_hidden_states = self.norm1(hidden_states, timestep) 305 | elif self.use_ada_layer_norm_zero: 306 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 307 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype 308 | ) 309 | elif self.use_layer_norm: 310 | norm_hidden_states = self.norm1(hidden_states) 311 | elif self.use_ada_layer_norm_continuous: 312 | norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) 313 | elif self.use_ada_layer_norm_single: 314 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( 315 | self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) 316 | ).chunk(6, dim=1) 317 | norm_hidden_states = self.norm1(hidden_states) 318 | norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa 319 | norm_hidden_states = norm_hidden_states.squeeze(1) 320 | else: 321 | raise ValueError("Incorrect norm used") 322 | 323 | if self.pos_embed is not None: 324 | norm_hidden_states = self.pos_embed(norm_hidden_states) 325 | 326 | # 1. Retrieve lora scale. 327 | lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 328 | 329 | # 2. Prepare GLIGEN inputs 330 | cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} 331 | gligen_kwargs = cross_attention_kwargs.pop("gligen", None) 332 | 333 | 334 | modify_norm_hidden_states = torch.cat([norm_hidden_states,garment_features[curr_garment_feat_idx]], dim=1) 335 | curr_garment_feat_idx +=1 336 | attn_output = self.attn1( 337 | #norm_hidden_states, 338 | modify_norm_hidden_states, 339 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, 340 | attention_mask=attention_mask, 341 | **cross_attention_kwargs, 342 | ) 343 | if self.use_ada_layer_norm_zero: 344 | attn_output = gate_msa.unsqueeze(1) * attn_output 345 | elif self.use_ada_layer_norm_single: 346 | attn_output = gate_msa * attn_output 347 | 348 | hidden_states = attn_output[:,:hidden_states.shape[-2],:] + hidden_states 349 | 350 | 351 | 352 | 353 | if hidden_states.ndim == 4: 354 | hidden_states = hidden_states.squeeze(1) 355 | 356 | # 2.5 GLIGEN Control 357 | if gligen_kwargs is not None: 358 | hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) 359 | 360 | # 3. 
Cross-Attention 361 | if self.attn2 is not None: 362 | if self.use_ada_layer_norm: 363 | norm_hidden_states = self.norm2(hidden_states, timestep) 364 | elif self.use_ada_layer_norm_zero or self.use_layer_norm: 365 | norm_hidden_states = self.norm2(hidden_states) 366 | elif self.use_ada_layer_norm_single: 367 | # For PixArt norm2 isn't applied here: 368 | # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 369 | norm_hidden_states = hidden_states 370 | elif self.use_ada_layer_norm_continuous: 371 | norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) 372 | else: 373 | raise ValueError("Incorrect norm") 374 | 375 | if self.pos_embed is not None and self.use_ada_layer_norm_single is False: 376 | norm_hidden_states = self.pos_embed(norm_hidden_states) 377 | 378 | attn_output = self.attn2( 379 | norm_hidden_states, 380 | encoder_hidden_states=encoder_hidden_states, 381 | attention_mask=encoder_attention_mask, 382 | **cross_attention_kwargs, 383 | ) 384 | hidden_states = attn_output + hidden_states 385 | 386 | # 4. Feed-forward 387 | if self.use_ada_layer_norm_continuous: 388 | norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) 389 | elif not self.use_ada_layer_norm_single: 390 | norm_hidden_states = self.norm3(hidden_states) 391 | 392 | if self.use_ada_layer_norm_zero: 393 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 394 | 395 | if self.use_ada_layer_norm_single: 396 | norm_hidden_states = self.norm2(hidden_states) 397 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp 398 | 399 | if self._chunk_size is not None: 400 | # "feed_forward_chunk_size" can be used to save memory 401 | ff_output = _chunked_feed_forward( 402 | self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale 403 | ) 404 | else: 405 | ff_output = self.ff(norm_hidden_states, scale=lora_scale) 406 | 407 | if self.use_ada_layer_norm_zero: 408 | ff_output = gate_mlp.unsqueeze(1) * ff_output 409 | elif self.use_ada_layer_norm_single: 410 | ff_output = gate_mlp * ff_output 411 | 412 | hidden_states = ff_output + hidden_states 413 | if hidden_states.ndim == 4: 414 | hidden_states = hidden_states.squeeze(1) 415 | return hidden_states,curr_garment_feat_idx 416 | 417 | 418 | @maybe_allow_in_graph 419 | class TemporalBasicTransformerBlock(nn.Module): 420 | r""" 421 | A basic Transformer block for video like data. 422 | 423 | Parameters: 424 | dim (`int`): The number of channels in the input and output. 425 | time_mix_inner_dim (`int`): The number of channels for temporal attention. 426 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 427 | attention_head_dim (`int`): The number of channels in each head. 428 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 429 | """ 430 | 431 | def __init__( 432 | self, 433 | dim: int, 434 | time_mix_inner_dim: int, 435 | num_attention_heads: int, 436 | attention_head_dim: int, 437 | cross_attention_dim: Optional[int] = None, 438 | ): 439 | super().__init__() 440 | self.is_res = dim == time_mix_inner_dim 441 | 442 | self.norm_in = nn.LayerNorm(dim) 443 | 444 | # Define 3 blocks. Each block has its own normalization layer. 445 | # 1. 
Self-Attn 446 | self.norm_in = nn.LayerNorm(dim) 447 | self.ff_in = FeedForward( 448 | dim, 449 | dim_out=time_mix_inner_dim, 450 | activation_fn="geglu", 451 | ) 452 | 453 | self.norm1 = nn.LayerNorm(time_mix_inner_dim) 454 | self.attn1 = Attention( 455 | query_dim=time_mix_inner_dim, 456 | heads=num_attention_heads, 457 | dim_head=attention_head_dim, 458 | cross_attention_dim=None, 459 | ) 460 | 461 | # 2. Cross-Attn 462 | if cross_attention_dim is not None: 463 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. 464 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during 465 | # the second cross attention block. 466 | self.norm2 = nn.LayerNorm(time_mix_inner_dim) 467 | self.attn2 = Attention( 468 | query_dim=time_mix_inner_dim, 469 | cross_attention_dim=cross_attention_dim, 470 | heads=num_attention_heads, 471 | dim_head=attention_head_dim, 472 | ) # is self-attn if encoder_hidden_states is none 473 | else: 474 | self.norm2 = None 475 | self.attn2 = None 476 | 477 | # 3. Feed-forward 478 | self.norm3 = nn.LayerNorm(time_mix_inner_dim) 479 | self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu") 480 | 481 | # let chunk size default to None 482 | self._chunk_size = None 483 | self._chunk_dim = None 484 | 485 | def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs): 486 | # Sets chunk feed-forward 487 | self._chunk_size = chunk_size 488 | # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off 489 | self._chunk_dim = 1 490 | 491 | def forward( 492 | self, 493 | hidden_states: torch.FloatTensor, 494 | num_frames: int, 495 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 496 | ) -> torch.FloatTensor: 497 | # Notice that normalization is always applied before the real computation in the following blocks. 498 | # 0. Self-Attention 499 | batch_size = hidden_states.shape[0] 500 | 501 | batch_frames, seq_length, channels = hidden_states.shape 502 | batch_size = batch_frames // num_frames 503 | 504 | hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels) 505 | hidden_states = hidden_states.permute(0, 2, 1, 3) 506 | hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels) 507 | 508 | residual = hidden_states 509 | hidden_states = self.norm_in(hidden_states) 510 | 511 | if self._chunk_size is not None: 512 | hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size) 513 | else: 514 | hidden_states = self.ff_in(hidden_states) 515 | 516 | if self.is_res: 517 | hidden_states = hidden_states + residual 518 | 519 | norm_hidden_states = self.norm1(hidden_states) 520 | attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None) 521 | hidden_states = attn_output + hidden_states 522 | 523 | # 3. Cross-Attention 524 | if self.attn2 is not None: 525 | norm_hidden_states = self.norm2(hidden_states) 526 | attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states) 527 | hidden_states = attn_output + hidden_states 528 | 529 | # 4. 
Feed-forward 530 | norm_hidden_states = self.norm3(hidden_states) 531 | 532 | if self._chunk_size is not None: 533 | ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) 534 | else: 535 | ff_output = self.ff(norm_hidden_states) 536 | 537 | if self.is_res: 538 | hidden_states = ff_output + hidden_states 539 | else: 540 | hidden_states = ff_output 541 | 542 | hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels) 543 | hidden_states = hidden_states.permute(0, 2, 1, 3) 544 | hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels) 545 | 546 | return hidden_states 547 | 548 | 549 | class SkipFFTransformerBlock(nn.Module): 550 | def __init__( 551 | self, 552 | dim: int, 553 | num_attention_heads: int, 554 | attention_head_dim: int, 555 | kv_input_dim: int, 556 | kv_input_dim_proj_use_bias: bool, 557 | dropout=0.0, 558 | cross_attention_dim: Optional[int] = None, 559 | attention_bias: bool = False, 560 | attention_out_bias: bool = True, 561 | ): 562 | super().__init__() 563 | if kv_input_dim != dim: 564 | self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias) 565 | else: 566 | self.kv_mapper = None 567 | 568 | self.norm1 = RMSNorm(dim, 1e-06) 569 | 570 | self.attn1 = Attention( 571 | query_dim=dim, 572 | heads=num_attention_heads, 573 | dim_head=attention_head_dim, 574 | dropout=dropout, 575 | bias=attention_bias, 576 | cross_attention_dim=cross_attention_dim, 577 | out_bias=attention_out_bias, 578 | ) 579 | 580 | self.norm2 = RMSNorm(dim, 1e-06) 581 | 582 | self.attn2 = Attention( 583 | query_dim=dim, 584 | cross_attention_dim=cross_attention_dim, 585 | heads=num_attention_heads, 586 | dim_head=attention_head_dim, 587 | dropout=dropout, 588 | bias=attention_bias, 589 | out_bias=attention_out_bias, 590 | ) 591 | 592 | def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs): 593 | cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} 594 | 595 | if self.kv_mapper is not None: 596 | encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states)) 597 | 598 | norm_hidden_states = self.norm1(hidden_states) 599 | 600 | attn_output = self.attn1( 601 | norm_hidden_states, 602 | encoder_hidden_states=encoder_hidden_states, 603 | **cross_attention_kwargs, 604 | ) 605 | 606 | hidden_states = attn_output + hidden_states 607 | 608 | norm_hidden_states = self.norm2(hidden_states) 609 | 610 | attn_output = self.attn2( 611 | norm_hidden_states, 612 | encoder_hidden_states=encoder_hidden_states, 613 | **cross_attention_kwargs, 614 | ) 615 | 616 | hidden_states = attn_output + hidden_states 617 | 618 | return hidden_states 619 | 620 | 621 | class FeedForward(nn.Module): 622 | r""" 623 | A feed-forward layer. 624 | 625 | Parameters: 626 | dim (`int`): The number of channels in the input. 627 | dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. 628 | mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. 629 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 630 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 631 | final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. 632 | bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
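    Example:
        An illustrative sketch of memory-saving chunking with the module-private helper
        `_chunked_feed_forward` defined at the top of this file. The dimensions are arbitrary placeholders;
        the sequence length (here 64) must be divisible by the chunk size:

        ```py
        >>> import torch
        >>> ff = FeedForward(dim=320)
        >>> x = torch.randn(2, 64, 320)
        >>> _chunked_feed_forward(ff, x, chunk_dim=1, chunk_size=16).shape
        torch.Size([2, 64, 320])
        ```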
633 | """ 634 | 635 | def __init__( 636 | self, 637 | dim: int, 638 | dim_out: Optional[int] = None, 639 | mult: int = 4, 640 | dropout: float = 0.0, 641 | activation_fn: str = "geglu", 642 | final_dropout: bool = False, 643 | inner_dim=None, 644 | bias: bool = True, 645 | ): 646 | super().__init__() 647 | if inner_dim is None: 648 | inner_dim = int(dim * mult) 649 | dim_out = dim_out if dim_out is not None else dim 650 | linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear 651 | 652 | if activation_fn == "gelu": 653 | act_fn = GELU(dim, inner_dim, bias=bias) 654 | if activation_fn == "gelu-approximate": 655 | act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) 656 | elif activation_fn == "geglu": 657 | act_fn = GEGLU(dim, inner_dim, bias=bias) 658 | elif activation_fn == "geglu-approximate": 659 | act_fn = ApproximateGELU(dim, inner_dim, bias=bias) 660 | 661 | self.net = nn.ModuleList([]) 662 | # project in 663 | self.net.append(act_fn) 664 | # project dropout 665 | self.net.append(nn.Dropout(dropout)) 666 | # project out 667 | self.net.append(linear_cls(inner_dim, dim_out, bias=bias)) 668 | # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout 669 | if final_dropout: 670 | self.net.append(nn.Dropout(dropout)) 671 | 672 | def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: 673 | compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear) 674 | for module in self.net: 675 | if isinstance(module, compatible_cls): 676 | hidden_states = module(hidden_states, scale) 677 | else: 678 | hidden_states = module(hidden_states) 679 | return hidden_states 680 | -------------------------------------------------------------------------------- /src/idm_vton/transformerhacked_garmnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from dataclasses import dataclass 15 | from typing import Any, Dict, Optional 16 | 17 | import torch 18 | import torch.nn.functional as F 19 | from torch import nn 20 | 21 | from diffusers.configuration_utils import ConfigMixin, register_to_config 22 | from diffusers.models.embeddings import ImagePositionalEmbeddings 23 | from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version 24 | from .attentionhacked_garmnet import BasicTransformerBlock 25 | from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection 26 | from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear 27 | from diffusers.models.modeling_utils import ModelMixin 28 | from diffusers.models.normalization import AdaLayerNormSingle 29 | 30 | 31 | @dataclass 32 | class Transformer2DModelOutput(BaseOutput): 33 | """ 34 | The output of [`Transformer2DModel`]. 
35 | 36 | Args: 37 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): 38 | The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability 39 | distributions for the unnoised latent pixels. 40 | """ 41 | 42 | sample: torch.FloatTensor 43 | 44 | 45 | class Transformer2DModel(ModelMixin, ConfigMixin): 46 | """ 47 | A 2D Transformer model for image-like data. 48 | 49 | Parameters: 50 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. 51 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. 52 | in_channels (`int`, *optional*): 53 | The number of channels in the input and output (specify if the input is **continuous**). 54 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. 55 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 56 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 57 | sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). 58 | This is fixed during training since it is used to learn a number of position embeddings. 59 | num_vector_embeds (`int`, *optional*): 60 | The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). 61 | Includes the class for the masked latent pixel. 62 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. 63 | num_embeds_ada_norm ( `int`, *optional*): 64 | The number of diffusion steps used during training. Pass if at least one of the norm_layers is 65 | `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are 66 | added to the hidden states. 67 | 68 | During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. 69 | attention_bias (`bool`, *optional*): 70 | Configure if the `TransformerBlocks` attention should contain a bias parameter. 
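    Example:
        An illustrative construction sketch; the head and channel counts are arbitrary placeholders and the
        class is assumed to be imported from this module. The blocks instantiated here are the
        garment-feature-collecting `BasicTransformerBlock` from `attentionhacked_garmnet.py`, and the hacked
        `forward` accumulates one feature tensor per block into a `garment_features` list:

        ```py
        >>> import torch
        >>> model = Transformer2DModel(num_attention_heads=8, attention_head_dim=40, in_channels=320, num_layers=2)
        >>> isinstance(model.transformer_blocks[0], BasicTransformerBlock)
        True
        ```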
71 | """ 72 | 73 | _supports_gradient_checkpointing = True 74 | 75 | @register_to_config 76 | def __init__( 77 | self, 78 | num_attention_heads: int = 16, 79 | attention_head_dim: int = 88, 80 | in_channels: Optional[int] = None, 81 | out_channels: Optional[int] = None, 82 | num_layers: int = 1, 83 | dropout: float = 0.0, 84 | norm_num_groups: int = 32, 85 | cross_attention_dim: Optional[int] = None, 86 | attention_bias: bool = False, 87 | sample_size: Optional[int] = None, 88 | num_vector_embeds: Optional[int] = None, 89 | patch_size: Optional[int] = None, 90 | activation_fn: str = "geglu", 91 | num_embeds_ada_norm: Optional[int] = None, 92 | use_linear_projection: bool = False, 93 | only_cross_attention: bool = False, 94 | double_self_attention: bool = False, 95 | upcast_attention: bool = False, 96 | norm_type: str = "layer_norm", 97 | norm_elementwise_affine: bool = True, 98 | norm_eps: float = 1e-5, 99 | attention_type: str = "default", 100 | caption_channels: int = None, 101 | ): 102 | super().__init__() 103 | self.use_linear_projection = use_linear_projection 104 | self.num_attention_heads = num_attention_heads 105 | self.attention_head_dim = attention_head_dim 106 | inner_dim = num_attention_heads * attention_head_dim 107 | 108 | conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv 109 | linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear 110 | 111 | # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` 112 | # Define whether input is continuous or discrete depending on configuration 113 | self.is_input_continuous = (in_channels is not None) and (patch_size is None) 114 | self.is_input_vectorized = num_vector_embeds is not None 115 | self.is_input_patches = in_channels is not None and patch_size is not None 116 | 117 | if norm_type == "layer_norm" and num_embeds_ada_norm is not None: 118 | deprecation_message = ( 119 | f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" 120 | " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." 121 | " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" 122 | " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" 123 | " would be very nice if you could open a Pull request for the `transformer/config.json` file" 124 | ) 125 | deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) 126 | norm_type = "ada_norm" 127 | 128 | if self.is_input_continuous and self.is_input_vectorized: 129 | raise ValueError( 130 | f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" 131 | " sure that either `in_channels` or `num_vector_embeds` is None." 132 | ) 133 | elif self.is_input_vectorized and self.is_input_patches: 134 | raise ValueError( 135 | f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" 136 | " sure that either `num_vector_embeds` or `num_patches` is None." 137 | ) 138 | elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: 139 | raise ValueError( 140 | f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" 141 | f" {patch_size}. 
Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." 142 | ) 143 | 144 | # 2. Define input layers 145 | if self.is_input_continuous: 146 | self.in_channels = in_channels 147 | 148 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 149 | if use_linear_projection: 150 | self.proj_in = linear_cls(in_channels, inner_dim) 151 | else: 152 | self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 153 | elif self.is_input_vectorized: 154 | assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" 155 | assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" 156 | 157 | self.height = sample_size 158 | self.width = sample_size 159 | self.num_vector_embeds = num_vector_embeds 160 | self.num_latent_pixels = self.height * self.width 161 | 162 | self.latent_image_embedding = ImagePositionalEmbeddings( 163 | num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width 164 | ) 165 | elif self.is_input_patches: 166 | assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" 167 | 168 | self.height = sample_size 169 | self.width = sample_size 170 | 171 | self.patch_size = patch_size 172 | interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1 173 | interpolation_scale = max(interpolation_scale, 1) 174 | self.pos_embed = PatchEmbed( 175 | height=sample_size, 176 | width=sample_size, 177 | patch_size=patch_size, 178 | in_channels=in_channels, 179 | embed_dim=inner_dim, 180 | interpolation_scale=interpolation_scale, 181 | ) 182 | 183 | # 3. Define transformers blocks 184 | self.transformer_blocks = nn.ModuleList( 185 | [ 186 | BasicTransformerBlock( 187 | inner_dim, 188 | num_attention_heads, 189 | attention_head_dim, 190 | dropout=dropout, 191 | cross_attention_dim=cross_attention_dim, 192 | activation_fn=activation_fn, 193 | num_embeds_ada_norm=num_embeds_ada_norm, 194 | attention_bias=attention_bias, 195 | only_cross_attention=only_cross_attention, 196 | double_self_attention=double_self_attention, 197 | upcast_attention=upcast_attention, 198 | norm_type=norm_type, 199 | norm_elementwise_affine=norm_elementwise_affine, 200 | norm_eps=norm_eps, 201 | attention_type=attention_type, 202 | ) 203 | for d in range(num_layers) 204 | ] 205 | ) 206 | 207 | # 4. 
Define output layers 208 | self.out_channels = in_channels if out_channels is None else out_channels 209 | if self.is_input_continuous: 210 | # TODO: should use out_channels for continuous projections 211 | if use_linear_projection: 212 | self.proj_out = linear_cls(inner_dim, in_channels) 213 | else: 214 | self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 215 | elif self.is_input_vectorized: 216 | self.norm_out = nn.LayerNorm(inner_dim) 217 | self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) 218 | elif self.is_input_patches and norm_type != "ada_norm_single": 219 | self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) 220 | self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) 221 | self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) 222 | elif self.is_input_patches and norm_type == "ada_norm_single": 223 | self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) 224 | self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) 225 | self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) 226 | 227 | # 5. PixArt-Alpha blocks. 228 | self.adaln_single = None 229 | self.use_additional_conditions = False 230 | if norm_type == "ada_norm_single": 231 | self.use_additional_conditions = self.config.sample_size == 128 232 | # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use 233 | # additional conditions until we find better name 234 | self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions) 235 | 236 | self.caption_projection = None 237 | if caption_channels is not None: 238 | self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) 239 | 240 | self.gradient_checkpointing = False 241 | 242 | def _set_gradient_checkpointing(self, module, value=False): 243 | if hasattr(module, "gradient_checkpointing"): 244 | module.gradient_checkpointing = value 245 | 246 | def forward( 247 | self, 248 | hidden_states: torch.Tensor, 249 | encoder_hidden_states: Optional[torch.Tensor] = None, 250 | timestep: Optional[torch.LongTensor] = None, 251 | added_cond_kwargs: Dict[str, torch.Tensor] = None, 252 | class_labels: Optional[torch.LongTensor] = None, 253 | cross_attention_kwargs: Dict[str, Any] = None, 254 | attention_mask: Optional[torch.Tensor] = None, 255 | encoder_attention_mask: Optional[torch.Tensor] = None, 256 | return_dict: bool = True, 257 | ): 258 | """ 259 | The [`Transformer2DModel`] forward method. 260 | 261 | Args: 262 | hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): 263 | Input `hidden_states`. 264 | encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): 265 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to 266 | self-attention. 267 | timestep ( `torch.LongTensor`, *optional*): 268 | Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. 269 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): 270 | Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in 271 | `AdaLayerZeroNorm`. 
272 | cross_attention_kwargs ( `Dict[str, Any]`, *optional*): 273 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 274 | `self.processor` in 275 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 276 | attention_mask ( `torch.Tensor`, *optional*): 277 | An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask 278 | is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large 279 | negative values to the attention scores corresponding to "discard" tokens. 280 | encoder_attention_mask ( `torch.Tensor`, *optional*): 281 | Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: 282 | 283 | * Mask `(batch, sequence_length)` True = keep, False = discard. 284 | * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. 285 | 286 | If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format 287 | above. This bias will be added to the cross-attention scores. 288 | return_dict (`bool`, *optional*, defaults to `True`): 289 | Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain 290 | tuple. 291 | 292 | Returns: 293 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a 294 | `tuple` where the first element is the sample tensor. 295 | """ 296 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. 297 | # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. 298 | # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. 299 | # expects mask of shape: 300 | # [batch, key_tokens] 301 | # adds singleton query_tokens dimension: 302 | # [batch, 1, key_tokens] 303 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: 304 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) 305 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) 306 | if attention_mask is not None and attention_mask.ndim == 2: 307 | # assume that mask is expressed as: 308 | # (1 = keep, 0 = discard) 309 | # convert mask into a bias that can be added to attention scores: 310 | # (keep = +0, discard = -10000.0) 311 | attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 312 | attention_mask = attention_mask.unsqueeze(1) 313 | 314 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 315 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: 316 | encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 317 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 318 | 319 | # Retrieve lora scale. 320 | lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 321 | 322 | # 1. 
Input 323 | if self.is_input_continuous: 324 | batch, _, height, width = hidden_states.shape 325 | residual = hidden_states 326 | 327 | hidden_states = self.norm(hidden_states) 328 | if not self.use_linear_projection: 329 | hidden_states = ( 330 | self.proj_in(hidden_states, scale=lora_scale) 331 | if not USE_PEFT_BACKEND 332 | else self.proj_in(hidden_states) 333 | ) 334 | inner_dim = hidden_states.shape[1] 335 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 336 | else: 337 | inner_dim = hidden_states.shape[1] 338 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 339 | hidden_states = ( 340 | self.proj_in(hidden_states, scale=lora_scale) 341 | if not USE_PEFT_BACKEND 342 | else self.proj_in(hidden_states) 343 | ) 344 | 345 | elif self.is_input_vectorized: 346 | hidden_states = self.latent_image_embedding(hidden_states) 347 | elif self.is_input_patches: 348 | height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size 349 | hidden_states = self.pos_embed(hidden_states) 350 | 351 | if self.adaln_single is not None: 352 | if self.use_additional_conditions and added_cond_kwargs is None: 353 | raise ValueError( 354 | "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." 355 | ) 356 | batch_size = hidden_states.shape[0] 357 | timestep, embedded_timestep = self.adaln_single( 358 | timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype 359 | ) 360 | 361 | # 2. Blocks 362 | if self.caption_projection is not None: 363 | batch_size = hidden_states.shape[0] 364 | encoder_hidden_states = self.caption_projection(encoder_hidden_states) 365 | encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) 366 | 367 | garment_features = [] 368 | for block in self.transformer_blocks: 369 | if self.training and self.gradient_checkpointing: 370 | 371 | def create_custom_forward(module, return_dict=None): 372 | def custom_forward(*inputs): 373 | if return_dict is not None: 374 | return module(*inputs, return_dict=return_dict) 375 | else: 376 | return module(*inputs) 377 | 378 | return custom_forward 379 | 380 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 381 | hidden_states,out_garment_feat = torch.utils.checkpoint.checkpoint( 382 | create_custom_forward(block), 383 | hidden_states, 384 | attention_mask, 385 | encoder_hidden_states, 386 | encoder_attention_mask, 387 | timestep, 388 | cross_attention_kwargs, 389 | class_labels, 390 | **ckpt_kwargs, 391 | ) 392 | else: 393 | hidden_states,out_garment_feat = block( 394 | hidden_states, 395 | attention_mask=attention_mask, 396 | encoder_hidden_states=encoder_hidden_states, 397 | encoder_attention_mask=encoder_attention_mask, 398 | timestep=timestep, 399 | cross_attention_kwargs=cross_attention_kwargs, 400 | class_labels=class_labels, 401 | ) 402 | garment_features += out_garment_feat 403 | # 3. 
Output 404 | if self.is_input_continuous: 405 | if not self.use_linear_projection: 406 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 407 | hidden_states = ( 408 | self.proj_out(hidden_states, scale=lora_scale) 409 | if not USE_PEFT_BACKEND 410 | else self.proj_out(hidden_states) 411 | ) 412 | else: 413 | hidden_states = ( 414 | self.proj_out(hidden_states, scale=lora_scale) 415 | if not USE_PEFT_BACKEND 416 | else self.proj_out(hidden_states) 417 | ) 418 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 419 | 420 | output = hidden_states + residual 421 | elif self.is_input_vectorized: 422 | hidden_states = self.norm_out(hidden_states) 423 | logits = self.out(hidden_states) 424 | # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) 425 | logits = logits.permute(0, 2, 1) 426 | 427 | # log(p(x_0)) 428 | output = F.log_softmax(logits.double(), dim=1).float() 429 | 430 | if self.is_input_patches: 431 | if self.config.norm_type != "ada_norm_single": 432 | conditioning = self.transformer_blocks[0].norm1.emb( 433 | timestep, class_labels, hidden_dtype=hidden_states.dtype 434 | ) 435 | shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) 436 | hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] 437 | hidden_states = self.proj_out_2(hidden_states) 438 | elif self.config.norm_type == "ada_norm_single": 439 | shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) 440 | hidden_states = self.norm_out(hidden_states) 441 | # Modulation 442 | hidden_states = hidden_states * (1 + scale) + shift 443 | hidden_states = self.proj_out(hidden_states) 444 | hidden_states = hidden_states.squeeze(1) 445 | 446 | # unpatchify 447 | if self.adaln_single is None: 448 | height = width = int(hidden_states.shape[1] ** 0.5) 449 | hidden_states = hidden_states.reshape( 450 | shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) 451 | ) 452 | hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) 453 | output = hidden_states.reshape( 454 | shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) 455 | ) 456 | 457 | if not return_dict: 458 | return (output,) ,garment_features 459 | 460 | return Transformer2DModelOutput(sample=output),garment_features 461 | -------------------------------------------------------------------------------- /src/idm_vton/transformerhacked_tryon.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
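# --- editor's note (added annotation, not part of the original file) ---
# This module is the try-on half of the pair of patched Transformer2DModel copies in this
# package: the garment-net copy above collects per-block garment features and returns them
# from forward(), while this copy accepts them through the extra `garment_features` /
# `curr_garment_feat_idx` parameters, hands them to every BasicTransformerBlock, and returns
# the updated index alongside the usual output. A rough sketch of how the two halves pair up
# (names such as `garmnet_tf`, `tryon_tf`, `garment_latents`, `person_latents` and `txt` are
# hypothetical; in the real pipeline the wiring presumably happens inside the patched UNet
# blocks rather than by calling the transformers directly):
#
#     _, garment_features = garmnet_tf(garment_latents, encoder_hidden_states=txt)
#     out, _ = tryon_tf(
#         person_latents,
#         encoder_hidden_states=txt,
#         garment_features=garment_features,
#         curr_garment_feat_idx=0,
#     )
# --- end editor's note ---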
14 | from dataclasses import dataclass 15 | from typing import Any, Dict, Optional 16 | 17 | import torch 18 | import torch.nn.functional as F 19 | from torch import nn 20 | 21 | from diffusers.configuration_utils import ConfigMixin, register_to_config 22 | from diffusers.models.embeddings import ImagePositionalEmbeddings 23 | from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version 24 | from .attentionhacked_tryon import BasicTransformerBlock 25 | from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection 26 | from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear 27 | from diffusers.models.modeling_utils import ModelMixin 28 | from diffusers.models.normalization import AdaLayerNormSingle 29 | 30 | 31 | @dataclass 32 | class Transformer2DModelOutput(BaseOutput): 33 | """ 34 | The output of [`Transformer2DModel`]. 35 | 36 | Args: 37 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): 38 | The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability 39 | distributions for the unnoised latent pixels. 40 | """ 41 | 42 | sample: torch.FloatTensor 43 | 44 | 45 | class Transformer2DModel(ModelMixin, ConfigMixin): 46 | """ 47 | A 2D Transformer model for image-like data. 48 | 49 | Parameters: 50 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. 51 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. 52 | in_channels (`int`, *optional*): 53 | The number of channels in the input and output (specify if the input is **continuous**). 54 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. 55 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 56 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 57 | sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). 58 | This is fixed during training since it is used to learn a number of position embeddings. 59 | num_vector_embeds (`int`, *optional*): 60 | The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). 61 | Includes the class for the masked latent pixel. 62 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. 63 | num_embeds_ada_norm ( `int`, *optional*): 64 | The number of diffusion steps used during training. Pass if at least one of the norm_layers is 65 | `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are 66 | added to the hidden states. 67 | 68 | During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. 69 | attention_bias (`bool`, *optional*): 70 | Configure if the `TransformerBlocks` attention should contain a bias parameter. 
71 | """ 72 | 73 | _supports_gradient_checkpointing = True 74 | 75 | @register_to_config 76 | def __init__( 77 | self, 78 | num_attention_heads: int = 16, 79 | attention_head_dim: int = 88, 80 | in_channels: Optional[int] = None, 81 | out_channels: Optional[int] = None, 82 | num_layers: int = 1, 83 | dropout: float = 0.0, 84 | norm_num_groups: int = 32, 85 | cross_attention_dim: Optional[int] = None, 86 | attention_bias: bool = False, 87 | sample_size: Optional[int] = None, 88 | num_vector_embeds: Optional[int] = None, 89 | patch_size: Optional[int] = None, 90 | activation_fn: str = "geglu", 91 | num_embeds_ada_norm: Optional[int] = None, 92 | use_linear_projection: bool = False, 93 | only_cross_attention: bool = False, 94 | double_self_attention: bool = False, 95 | upcast_attention: bool = False, 96 | norm_type: str = "layer_norm", 97 | norm_elementwise_affine: bool = True, 98 | norm_eps: float = 1e-5, 99 | attention_type: str = "default", 100 | caption_channels: int = None, 101 | ): 102 | super().__init__() 103 | self.use_linear_projection = use_linear_projection 104 | self.num_attention_heads = num_attention_heads 105 | self.attention_head_dim = attention_head_dim 106 | inner_dim = num_attention_heads * attention_head_dim 107 | 108 | conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv 109 | linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear 110 | 111 | # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` 112 | # Define whether input is continuous or discrete depending on configuration 113 | self.is_input_continuous = (in_channels is not None) and (patch_size is None) 114 | self.is_input_vectorized = num_vector_embeds is not None 115 | self.is_input_patches = in_channels is not None and patch_size is not None 116 | 117 | if norm_type == "layer_norm" and num_embeds_ada_norm is not None: 118 | deprecation_message = ( 119 | f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" 120 | " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." 121 | " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" 122 | " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" 123 | " would be very nice if you could open a Pull request for the `transformer/config.json` file" 124 | ) 125 | deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) 126 | norm_type = "ada_norm" 127 | 128 | if self.is_input_continuous and self.is_input_vectorized: 129 | raise ValueError( 130 | f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" 131 | " sure that either `in_channels` or `num_vector_embeds` is None." 132 | ) 133 | elif self.is_input_vectorized and self.is_input_patches: 134 | raise ValueError( 135 | f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" 136 | " sure that either `num_vector_embeds` or `num_patches` is None." 137 | ) 138 | elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: 139 | raise ValueError( 140 | f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" 141 | f" {patch_size}. 
Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." 142 | ) 143 | 144 | # 2. Define input layers 145 | if self.is_input_continuous: 146 | self.in_channels = in_channels 147 | 148 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 149 | if use_linear_projection: 150 | self.proj_in = linear_cls(in_channels, inner_dim) 151 | else: 152 | self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 153 | elif self.is_input_vectorized: 154 | assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" 155 | assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" 156 | 157 | self.height = sample_size 158 | self.width = sample_size 159 | self.num_vector_embeds = num_vector_embeds 160 | self.num_latent_pixels = self.height * self.width 161 | 162 | self.latent_image_embedding = ImagePositionalEmbeddings( 163 | num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width 164 | ) 165 | elif self.is_input_patches: 166 | assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" 167 | 168 | self.height = sample_size 169 | self.width = sample_size 170 | 171 | self.patch_size = patch_size 172 | interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1 173 | interpolation_scale = max(interpolation_scale, 1) 174 | self.pos_embed = PatchEmbed( 175 | height=sample_size, 176 | width=sample_size, 177 | patch_size=patch_size, 178 | in_channels=in_channels, 179 | embed_dim=inner_dim, 180 | interpolation_scale=interpolation_scale, 181 | ) 182 | 183 | # 3. Define transformers blocks 184 | self.transformer_blocks = nn.ModuleList( 185 | [ 186 | BasicTransformerBlock( 187 | inner_dim, 188 | num_attention_heads, 189 | attention_head_dim, 190 | dropout=dropout, 191 | cross_attention_dim=cross_attention_dim, 192 | activation_fn=activation_fn, 193 | num_embeds_ada_norm=num_embeds_ada_norm, 194 | attention_bias=attention_bias, 195 | only_cross_attention=only_cross_attention, 196 | double_self_attention=double_self_attention, 197 | upcast_attention=upcast_attention, 198 | norm_type=norm_type, 199 | norm_elementwise_affine=norm_elementwise_affine, 200 | norm_eps=norm_eps, 201 | attention_type=attention_type, 202 | ) 203 | for d in range(num_layers) 204 | ] 205 | ) 206 | 207 | # 4. 
Define output layers 208 | self.out_channels = in_channels if out_channels is None else out_channels 209 | if self.is_input_continuous: 210 | # TODO: should use out_channels for continuous projections 211 | if use_linear_projection: 212 | self.proj_out = linear_cls(inner_dim, in_channels) 213 | else: 214 | self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 215 | elif self.is_input_vectorized: 216 | self.norm_out = nn.LayerNorm(inner_dim) 217 | self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) 218 | elif self.is_input_patches and norm_type != "ada_norm_single": 219 | self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) 220 | self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) 221 | self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) 222 | elif self.is_input_patches and norm_type == "ada_norm_single": 223 | self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) 224 | self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) 225 | self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) 226 | 227 | # 5. PixArt-Alpha blocks. 228 | self.adaln_single = None 229 | self.use_additional_conditions = False 230 | if norm_type == "ada_norm_single": 231 | self.use_additional_conditions = self.config.sample_size == 128 232 | # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use 233 | # additional conditions until we find better name 234 | self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions) 235 | 236 | self.caption_projection = None 237 | if caption_channels is not None: 238 | self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) 239 | 240 | self.gradient_checkpointing = False 241 | 242 | def _set_gradient_checkpointing(self, module, value=False): 243 | if hasattr(module, "gradient_checkpointing"): 244 | module.gradient_checkpointing = value 245 | 246 | def forward( 247 | self, 248 | hidden_states: torch.Tensor, 249 | encoder_hidden_states: Optional[torch.Tensor] = None, 250 | timestep: Optional[torch.LongTensor] = None, 251 | added_cond_kwargs: Dict[str, torch.Tensor] = None, 252 | class_labels: Optional[torch.LongTensor] = None, 253 | cross_attention_kwargs: Dict[str, Any] = None, 254 | attention_mask: Optional[torch.Tensor] = None, 255 | encoder_attention_mask: Optional[torch.Tensor] = None, 256 | garment_features=None, 257 | curr_garment_feat_idx=0, 258 | return_dict: bool = True, 259 | ): 260 | """ 261 | The [`Transformer2DModel`] forward method. 262 | 263 | Args: 264 | hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): 265 | Input `hidden_states`. 266 | encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): 267 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to 268 | self-attention. 269 | timestep ( `torch.LongTensor`, *optional*): 270 | Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. 271 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): 272 | Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in 273 | `AdaLayerZeroNorm`. 
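            garment_features (`list`, *optional*):
                Garment reference features consumed by the patched attention blocks; forwarded to every
                `BasicTransformerBlock` in this model. (Added doc note: presumably the list returned by the
                garment-net variant of this transformer.)
            curr_garment_feat_idx (`int`, *optional*, defaults to `0`):
                Running index into `garment_features`; it is passed to and returned by each block, and the
                final value is returned alongside the model output.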
274 | cross_attention_kwargs ( `Dict[str, Any]`, *optional*): 275 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 276 | `self.processor` in 277 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 278 | attention_mask ( `torch.Tensor`, *optional*): 279 | An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask 280 | is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large 281 | negative values to the attention scores corresponding to "discard" tokens. 282 | encoder_attention_mask ( `torch.Tensor`, *optional*): 283 | Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: 284 | 285 | * Mask `(batch, sequence_length)` True = keep, False = discard. 286 | * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. 287 | 288 | If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format 289 | above. This bias will be added to the cross-attention scores. 290 | return_dict (`bool`, *optional*, defaults to `True`): 291 | Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain 292 | tuple. 293 | 294 | Returns: 295 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a 296 | `tuple` where the first element is the sample tensor. 297 | """ 298 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. 299 | # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. 300 | # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. 301 | # expects mask of shape: 302 | # [batch, key_tokens] 303 | # adds singleton query_tokens dimension: 304 | # [batch, 1, key_tokens] 305 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: 306 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) 307 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) 308 | if attention_mask is not None and attention_mask.ndim == 2: 309 | # assume that mask is expressed as: 310 | # (1 = keep, 0 = discard) 311 | # convert mask into a bias that can be added to attention scores: 312 | # (keep = +0, discard = -10000.0) 313 | attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 314 | attention_mask = attention_mask.unsqueeze(1) 315 | 316 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 317 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: 318 | encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 319 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 320 | 321 | # Retrieve lora scale. 322 | lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 323 | 324 | # 1. 
Input 325 | if self.is_input_continuous: 326 | batch, _, height, width = hidden_states.shape 327 | residual = hidden_states 328 | 329 | hidden_states = self.norm(hidden_states) 330 | if not self.use_linear_projection: 331 | hidden_states = ( 332 | self.proj_in(hidden_states, scale=lora_scale) 333 | if not USE_PEFT_BACKEND 334 | else self.proj_in(hidden_states) 335 | ) 336 | inner_dim = hidden_states.shape[1] 337 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 338 | else: 339 | inner_dim = hidden_states.shape[1] 340 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 341 | hidden_states = ( 342 | self.proj_in(hidden_states, scale=lora_scale) 343 | if not USE_PEFT_BACKEND 344 | else self.proj_in(hidden_states) 345 | ) 346 | 347 | elif self.is_input_vectorized: 348 | hidden_states = self.latent_image_embedding(hidden_states) 349 | elif self.is_input_patches: 350 | height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size 351 | hidden_states = self.pos_embed(hidden_states) 352 | 353 | if self.adaln_single is not None: 354 | if self.use_additional_conditions and added_cond_kwargs is None: 355 | raise ValueError( 356 | "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." 357 | ) 358 | batch_size = hidden_states.shape[0] 359 | timestep, embedded_timestep = self.adaln_single( 360 | timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype 361 | ) 362 | 363 | # 2. Blocks 364 | if self.caption_projection is not None: 365 | batch_size = hidden_states.shape[0] 366 | encoder_hidden_states = self.caption_projection(encoder_hidden_states) 367 | encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) 368 | 369 | 370 | for block in self.transformer_blocks: 371 | if self.training and self.gradient_checkpointing: 372 | 373 | def create_custom_forward(module, return_dict=None): 374 | def custom_forward(*inputs): 375 | if return_dict is not None: 376 | return module(*inputs, return_dict=return_dict) 377 | else: 378 | return module(*inputs) 379 | 380 | return custom_forward 381 | 382 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 383 | hidden_states,curr_garment_feat_idx = torch.utils.checkpoint.checkpoint( 384 | create_custom_forward(block), 385 | hidden_states, 386 | attention_mask, 387 | encoder_hidden_states, 388 | encoder_attention_mask, 389 | timestep, 390 | cross_attention_kwargs, 391 | class_labels, 392 | garment_features, 393 | curr_garment_feat_idx, 394 | **ckpt_kwargs, 395 | ) 396 | else: 397 | hidden_states,curr_garment_feat_idx = block( 398 | hidden_states, 399 | attention_mask=attention_mask, 400 | encoder_hidden_states=encoder_hidden_states, 401 | encoder_attention_mask=encoder_attention_mask, 402 | timestep=timestep, 403 | cross_attention_kwargs=cross_attention_kwargs, 404 | class_labels=class_labels, 405 | garment_features=garment_features, 406 | curr_garment_feat_idx=curr_garment_feat_idx, 407 | ) 408 | 409 | 410 | # 3. 
Output 411 | if self.is_input_continuous: 412 | if not self.use_linear_projection: 413 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 414 | hidden_states = ( 415 | self.proj_out(hidden_states, scale=lora_scale) 416 | if not USE_PEFT_BACKEND 417 | else self.proj_out(hidden_states) 418 | ) 419 | else: 420 | hidden_states = ( 421 | self.proj_out(hidden_states, scale=lora_scale) 422 | if not USE_PEFT_BACKEND 423 | else self.proj_out(hidden_states) 424 | ) 425 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 426 | 427 | output = hidden_states + residual 428 | elif self.is_input_vectorized: 429 | hidden_states = self.norm_out(hidden_states) 430 | logits = self.out(hidden_states) 431 | # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) 432 | logits = logits.permute(0, 2, 1) 433 | 434 | # log(p(x_0)) 435 | output = F.log_softmax(logits.double(), dim=1).float() 436 | 437 | if self.is_input_patches: 438 | if self.config.norm_type != "ada_norm_single": 439 | conditioning = self.transformer_blocks[0].norm1.emb( 440 | timestep, class_labels, hidden_dtype=hidden_states.dtype 441 | ) 442 | shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) 443 | hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] 444 | hidden_states = self.proj_out_2(hidden_states) 445 | elif self.config.norm_type == "ada_norm_single": 446 | shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) 447 | hidden_states = self.norm_out(hidden_states) 448 | # Modulation 449 | hidden_states = hidden_states * (1 + scale) + shift 450 | hidden_states = self.proj_out(hidden_states) 451 | hidden_states = hidden_states.squeeze(1) 452 | 453 | # unpatchify 454 | if self.adaln_single is None: 455 | height = width = int(hidden_states.shape[1] ** 0.5) 456 | hidden_states = hidden_states.reshape( 457 | shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) 458 | ) 459 | hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) 460 | output = hidden_states.reshape( 461 | shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) 462 | ) 463 | 464 | if not return_dict: 465 | return (output,),curr_garment_feat_idx 466 | 467 | return Transformer2DModelOutput(sample=output),curr_garment_feat_idx 468 | -------------------------------------------------------------------------------- /src/ip_adapter/__init__.py: -------------------------------------------------------------------------------- 1 | from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull, IPAdapterPlus_Lora, IPAdapterPlus_Lora_up 2 | 3 | __all__ = [ 4 | "IPAdapter", 5 | "IPAdapterPlus", 6 | "IPAdapterPlusXL", 7 | "IPAdapterXL", 8 | "IPAdapterFull", 9 | "IPAdapterPlus_Lora", 10 | 'IPAdapterPlus_Lora_up', 11 | ] 12 | -------------------------------------------------------------------------------- /src/ip_adapter/ip_adapter.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | import torch 5 | from diffusers import StableDiffusionPipeline 6 | from diffusers.pipelines.controlnet import MultiControlNetModel 7 | from PIL import Image 8 | from safetensors import safe_open 9 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection 10 | 11 | from .utils import is_torch2_available 12 | 13 | if is_torch2_available(): 
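    # (editor's note, added) only the torch >= 2.0 attention processor classes
    # (`AttnProcessor2_0`, `CNAttnProcessor2_0`, `IPAttnProcessor2_0`, `IPAttnProcessor2_0_Lora`)
    # are imported here; the pre-2.0 fallback import below is commented out, so the adapter
    # classes in this module appear to require a torch 2.x environment.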
14 | from .attention_processor import ( 15 | AttnProcessor2_0 as AttnProcessor, 16 | ) 17 | from .attention_processor import ( 18 | CNAttnProcessor2_0 as CNAttnProcessor, 19 | ) 20 | from .attention_processor import ( 21 | IPAttnProcessor2_0 as IPAttnProcessor, 22 | ) 23 | from .attention_processor import IPAttnProcessor2_0_Lora 24 | # else: 25 | # from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor 26 | from .resampler import Resampler 27 | from diffusers.models.lora import LoRALinearLayer 28 | 29 | 30 | class ImageProjModel(torch.nn.Module): 31 | """Projection Model""" 32 | 33 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): 34 | super().__init__() 35 | 36 | self.cross_attention_dim = cross_attention_dim 37 | self.clip_extra_context_tokens = clip_extra_context_tokens 38 | self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) 39 | self.norm = torch.nn.LayerNorm(cross_attention_dim) 40 | 41 | def forward(self, image_embeds): 42 | embeds = image_embeds 43 | clip_extra_context_tokens = self.proj(embeds).reshape( 44 | -1, self.clip_extra_context_tokens, self.cross_attention_dim 45 | ) 46 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 47 | return clip_extra_context_tokens 48 | 49 | 50 | class MLPProjModel(torch.nn.Module): 51 | """SD model with image prompt""" 52 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024): 53 | super().__init__() 54 | 55 | self.proj = torch.nn.Sequential( 56 | torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim), 57 | torch.nn.GELU(), 58 | torch.nn.Linear(clip_embeddings_dim, cross_attention_dim), 59 | torch.nn.LayerNorm(cross_attention_dim) 60 | ) 61 | 62 | def forward(self, image_embeds): 63 | clip_extra_context_tokens = self.proj(image_embeds) 64 | return clip_extra_context_tokens 65 | 66 | 67 | class IPAdapter: 68 | def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4): 69 | self.device = device 70 | self.image_encoder_path = image_encoder_path 71 | self.ip_ckpt = ip_ckpt 72 | self.num_tokens = num_tokens 73 | 74 | self.pipe = sd_pipe.to(self.device) 75 | self.set_ip_adapter() 76 | 77 | # load image encoder 78 | self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( 79 | self.device, dtype=torch.float16 80 | ) 81 | self.clip_image_processor = CLIPImageProcessor() 82 | # image proj model 83 | self.image_proj_model = self.init_proj() 84 | 85 | self.load_ip_adapter() 86 | 87 | def init_proj(self): 88 | image_proj_model = ImageProjModel( 89 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim, 90 | clip_embeddings_dim=self.image_encoder.config.projection_dim, 91 | clip_extra_context_tokens=self.num_tokens, 92 | ).to(self.device, dtype=torch.float16) 93 | return image_proj_model 94 | 95 | def set_ip_adapter(self): 96 | unet = self.pipe.unet 97 | attn_procs = {} 98 | for name in unet.attn_processors.keys(): 99 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 100 | if name.startswith("mid_block"): 101 | hidden_size = unet.config.block_out_channels[-1] 102 | elif name.startswith("up_blocks"): 103 | block_id = int(name[len("up_blocks.")]) 104 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 105 | elif name.startswith("down_blocks"): 106 | block_id = int(name[len("down_blocks.")]) 107 | hidden_size = unet.config.block_out_channels[block_id] 108 | 
if cross_attention_dim is None: 109 | attn_procs[name] = AttnProcessor() 110 | else: 111 | attn_procs[name] = IPAttnProcessor( 112 | hidden_size=hidden_size, 113 | cross_attention_dim=cross_attention_dim, 114 | scale=1.0, 115 | num_tokens=self.num_tokens, 116 | ).to(self.device, dtype=torch.float16) 117 | unet.set_attn_processor(attn_procs) 118 | if hasattr(self.pipe, "controlnet"): 119 | if isinstance(self.pipe.controlnet, MultiControlNetModel): 120 | for controlnet in self.pipe.controlnet.nets: 121 | controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) 122 | else: 123 | self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) 124 | 125 | def load_ip_adapter(self): 126 | if self.ip_ckpt is not None: 127 | if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": 128 | state_dict = {"image_proj": {}, "ip_adapter": {}} 129 | with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: 130 | for key in f.keys(): 131 | if key.startswith("image_proj."): 132 | state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) 133 | elif key.startswith("ip_adapter."): 134 | state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) 135 | else: 136 | state_dict = torch.load(self.ip_ckpt, map_location="cpu") 137 | self.image_proj_model.load_state_dict(state_dict["image_proj"]) 138 | ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) 139 | ip_layers.load_state_dict(state_dict["ip_adapter"]) 140 | 141 | 142 | # def load_ip_adapter(self): 143 | # if self.ip_ckpt is not None: 144 | # if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": 145 | # state_dict = {"image_proj_model": {}, "ip_adapter": {}} 146 | # with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: 147 | # for key in f.keys(): 148 | # if key.startswith("image_proj_model."): 149 | # state_dict["image_proj_model"][key.replace("image_proj_model.", "")] = f.get_tensor(key) 150 | # elif key.startswith("ip_adapter."): 151 | # state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) 152 | # else: 153 | # state_dict = torch.load(self.ip_ckpt, map_location="cpu") 154 | 155 | # tmp1 = {} 156 | # for k,v in state_dict.items(): 157 | # if 'image_proj_model' in k: 158 | # tmp1[k.replace('image_proj_model.','')] = v 159 | # self.image_proj_model.load_state_dict(tmp1, strict=True) 160 | # # ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) 161 | # tmp2 = {} 162 | # for k,v in state_dict.ites(): 163 | # if 'adapter_mode' in k: 164 | # tmp1[k] = v 165 | 166 | # print(ip_layers.state_dict()) 167 | # ip_layers.load_state_dict(state_dict,strict=False) 168 | 169 | 170 | @torch.inference_mode() 171 | def get_image_embeds(self, pil_image=None, clip_image_embeds=None): 172 | if pil_image is not None: 173 | if isinstance(pil_image, Image.Image): 174 | pil_image = [pil_image] 175 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 176 | clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds 177 | else: 178 | clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16) 179 | image_prompt_embeds = self.image_proj_model(clip_image_embeds) 180 | uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds)) 181 | return image_prompt_embeds, uncond_image_prompt_embeds 182 | 183 | def get_image_embeds_train(self, pil_image=None, clip_image_embeds=None): 184 | if pil_image is not None: 185 | 
if isinstance(pil_image, Image.Image): 186 | pil_image = [pil_image] 187 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 188 | clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float32)).image_embeds 189 | else: 190 | clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float32) 191 | image_prompt_embeds = self.image_proj_model(clip_image_embeds) 192 | uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds)) 193 | return image_prompt_embeds, uncond_image_prompt_embeds 194 | 195 | 196 | def set_scale(self, scale): 197 | for attn_processor in self.pipe.unet.attn_processors.values(): 198 | if isinstance(attn_processor, IPAttnProcessor): 199 | attn_processor.scale = scale 200 | 201 | def generate( 202 | self, 203 | pil_image=None, 204 | clip_image_embeds=None, 205 | prompt=None, 206 | negative_prompt=None, 207 | scale=1.0, 208 | num_samples=4, 209 | seed=None, 210 | guidance_scale=7.5, 211 | num_inference_steps=50, 212 | **kwargs, 213 | ): 214 | self.set_scale(scale) 215 | 216 | if pil_image is not None: 217 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) 218 | else: 219 | num_prompts = clip_image_embeds.size(0) 220 | 221 | if prompt is None: 222 | prompt = "best quality, high quality" 223 | if negative_prompt is None: 224 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 225 | 226 | if not isinstance(prompt, List): 227 | prompt = [prompt] * num_prompts 228 | if not isinstance(negative_prompt, List): 229 | negative_prompt = [negative_prompt] * num_prompts 230 | 231 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( 232 | pil_image=pil_image, clip_image_embeds=clip_image_embeds 233 | ) 234 | bs_embed, seq_len, _ = image_prompt_embeds.shape 235 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 236 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 237 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 238 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 239 | 240 | with torch.inference_mode(): 241 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( 242 | prompt, 243 | device=self.device, 244 | num_images_per_prompt=num_samples, 245 | do_classifier_free_guidance=True, 246 | negative_prompt=negative_prompt, 247 | ) 248 | prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) 249 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) 250 | 251 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None 252 | images = self.pipe( 253 | prompt_embeds=prompt_embeds, 254 | negative_prompt_embeds=negative_prompt_embeds, 255 | guidance_scale=guidance_scale, 256 | num_inference_steps=num_inference_steps, 257 | generator=generator, 258 | **kwargs, 259 | ).images 260 | 261 | return images 262 | 263 | 264 | class IPAdapterXL(IPAdapter): 265 | """SDXL""" 266 | 267 | def generate_test( 268 | self, 269 | pil_image, 270 | prompt=None, 271 | negative_prompt=None, 272 | scale=1.0, 273 | num_samples=4, 274 | seed=None, 275 | num_inference_steps=30, 276 | **kwargs, 277 | ): 278 | self.set_scale(scale) 279 | 280 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) 281 | 282 | if prompt is None: 283 | prompt = "best quality, high quality" 284 | if 
negative_prompt is None: 285 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 286 | 287 | if not isinstance(prompt, List): 288 | prompt = [prompt] * num_prompts 289 | if not isinstance(negative_prompt, List): 290 | negative_prompt = [negative_prompt] * num_prompts 291 | 292 | 293 | with torch.inference_mode(): 294 | ( 295 | prompt_embeds, 296 | negative_prompt_embeds, 297 | pooled_prompt_embeds, 298 | negative_pooled_prompt_embeds, 299 | ) = self.pipe.encode_prompt( 300 | prompt, 301 | num_images_per_prompt=num_samples, 302 | do_classifier_free_guidance=True, 303 | negative_prompt=negative_prompt, 304 | ) 305 | 306 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None 307 | images = self.pipe( 308 | prompt_embeds=prompt_embeds, 309 | negative_prompt_embeds=negative_prompt_embeds, 310 | pooled_prompt_embeds=pooled_prompt_embeds, 311 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 312 | num_inference_steps=num_inference_steps, 313 | generator=generator, 314 | **kwargs, 315 | ).images 316 | 317 | 318 | # with torch.autocast("cuda"): 319 | # images = self.pipe( 320 | # prompt_embeds=prompt_embeds, 321 | # negative_prompt_embeds=negative_prompt_embeds, 322 | # pooled_prompt_embeds=pooled_prompt_embeds, 323 | # negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 324 | # num_inference_steps=num_inference_steps, 325 | # generator=generator, 326 | # **kwargs, 327 | # ).images 328 | 329 | return images 330 | 331 | 332 | def generate( 333 | self, 334 | pil_image, 335 | prompt=None, 336 | negative_prompt=None, 337 | scale=1.0, 338 | num_samples=4, 339 | seed=None, 340 | num_inference_steps=30, 341 | **kwargs, 342 | ): 343 | self.set_scale(scale) 344 | 345 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) 346 | 347 | if prompt is None: 348 | prompt = "best quality, high quality" 349 | if negative_prompt is None: 350 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 351 | 352 | if not isinstance(prompt, List): 353 | prompt = [prompt] * num_prompts 354 | if not isinstance(negative_prompt, List): 355 | negative_prompt = [negative_prompt] * num_prompts 356 | 357 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) 358 | bs_embed, seq_len, _ = image_prompt_embeds.shape 359 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 360 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 361 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 362 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 363 | 364 | with torch.inference_mode(): 365 | ( 366 | prompt_embeds, 367 | negative_prompt_embeds, 368 | pooled_prompt_embeds, 369 | negative_pooled_prompt_embeds, 370 | ) = self.pipe.encode_prompt( 371 | prompt, 372 | num_images_per_prompt=num_samples, 373 | do_classifier_free_guidance=True, 374 | negative_prompt=negative_prompt, 375 | ) 376 | prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) 377 | negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) 378 | 379 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None 380 | images = self.pipe( 381 | prompt_embeds=prompt_embeds, 382 | negative_prompt_embeds=negative_prompt_embeds, 383 | pooled_prompt_embeds=pooled_prompt_embeds, 384 | 
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 385 | num_inference_steps=num_inference_steps, 386 | generator=generator, 387 | **kwargs, 388 | ).images 389 | 390 | 391 | # with torch.autocast("cuda"): 392 | # images = self.pipe( 393 | # prompt_embeds=prompt_embeds, 394 | # negative_prompt_embeds=negative_prompt_embeds, 395 | # pooled_prompt_embeds=pooled_prompt_embeds, 396 | # negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 397 | # num_inference_steps=num_inference_steps, 398 | # generator=generator, 399 | # **kwargs, 400 | # ).images 401 | 402 | return images 403 | 404 | 405 | class IPAdapterPlus(IPAdapter): 406 | """IP-Adapter with fine-grained features""" 407 | 408 | def generate( 409 | self, 410 | pil_image=None, 411 | clip_image_embeds=None, 412 | prompt=None, 413 | negative_prompt=None, 414 | scale=1.0, 415 | num_samples=4, 416 | seed=None, 417 | guidance_scale=7.5, 418 | num_inference_steps=50, 419 | **kwargs, 420 | ): 421 | self.set_scale(scale) 422 | 423 | if pil_image is not None: 424 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) 425 | else: 426 | num_prompts = clip_image_embeds.size(0) 427 | 428 | if prompt is None: 429 | prompt = "best quality, high quality" 430 | if negative_prompt is None: 431 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 432 | 433 | if not isinstance(prompt, List): 434 | prompt = [prompt] * num_prompts 435 | if not isinstance(negative_prompt, List): 436 | negative_prompt = [negative_prompt] * num_prompts 437 | 438 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( 439 | pil_image=pil_image, clip_image=clip_image_embeds 440 | ) 441 | bs_embed, seq_len, _ = image_prompt_embeds.shape 442 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 443 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 444 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 445 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 446 | 447 | with torch.inference_mode(): 448 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( 449 | prompt, 450 | device=self.device, 451 | num_images_per_prompt=num_samples, 452 | do_classifier_free_guidance=True, 453 | negative_prompt=negative_prompt, 454 | ) 455 | prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) 456 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) 457 | 458 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None 459 | images = self.pipe( 460 | prompt_embeds=prompt_embeds, 461 | negative_prompt_embeds=negative_prompt_embeds, 462 | guidance_scale=guidance_scale, 463 | num_inference_steps=num_inference_steps, 464 | generator=generator, 465 | **kwargs, 466 | ).images 467 | 468 | return images 469 | 470 | 471 | def init_proj(self): 472 | image_proj_model = Resampler( 473 | dim=self.pipe.unet.config.cross_attention_dim, 474 | depth=4, 475 | dim_head=64, 476 | heads=12, 477 | num_queries=self.num_tokens, 478 | embedding_dim=self.image_encoder.config.hidden_size, 479 | output_dim=self.pipe.unet.config.cross_attention_dim, 480 | ff_mult=4, 481 | ).to(self.device, dtype=torch.float16) 482 | return image_proj_model 483 | 484 | @torch.inference_mode() 485 | def get_image_embeds(self, pil_image=None, clip_image=None, uncond= None): 486 | if pil_image is not None: 487 | if 
isinstance(pil_image, Image.Image): 488 | pil_image = [pil_image] 489 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 490 | clip_image = clip_image.to(self.device, dtype=torch.float16) 491 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 492 | else: 493 | clip_image = clip_image.to(self.device, dtype=torch.float16) 494 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 495 | image_prompt_embeds = self.image_proj_model(clip_image_embeds) 496 | uncond_clip_image_embeds = self.image_encoder( 497 | torch.zeros_like(clip_image), output_hidden_states=True 498 | ).hidden_states[-2] 499 | uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) 500 | return image_prompt_embeds, uncond_image_prompt_embeds 501 | 502 | 503 | 504 | 505 | class IPAdapterPlus_Lora(IPAdapter): 506 | """IP-Adapter with fine-grained features""" 507 | 508 | def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, rank=32): 509 | self.rank = rank 510 | super().__init__(sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens) 511 | 512 | 513 | def generate( 514 | self, 515 | pil_image=None, 516 | clip_image_embeds=None, 517 | prompt=None, 518 | negative_prompt=None, 519 | scale=1.0, 520 | num_samples=4, 521 | seed=None, 522 | guidance_scale=7.5, 523 | num_inference_steps=50, 524 | **kwargs, 525 | ): 526 | self.set_scale(scale) 527 | 528 | if pil_image is not None: 529 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) 530 | else: 531 | num_prompts = clip_image_embeds.size(0) 532 | 533 | if prompt is None: 534 | prompt = "best quality, high quality" 535 | if negative_prompt is None: 536 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 537 | 538 | if not isinstance(prompt, List): 539 | prompt = [prompt] * num_prompts 540 | if not isinstance(negative_prompt, List): 541 | negative_prompt = [negative_prompt] * num_prompts 542 | 543 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( 544 | pil_image=pil_image, clip_image=clip_image_embeds 545 | ) 546 | bs_embed, seq_len, _ = image_prompt_embeds.shape 547 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 548 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 549 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 550 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 551 | 552 | with torch.inference_mode(): 553 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( 554 | prompt, 555 | device=self.device, 556 | num_images_per_prompt=num_samples, 557 | do_classifier_free_guidance=True, 558 | negative_prompt=negative_prompt, 559 | ) 560 | prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) 561 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) 562 | 563 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None 564 | images = self.pipe( 565 | prompt_embeds=prompt_embeds, 566 | negative_prompt_embeds=negative_prompt_embeds, 567 | guidance_scale=guidance_scale, 568 | num_inference_steps=num_inference_steps, 569 | generator=generator, 570 | **kwargs, 571 | ).images 572 | 573 | return images 574 | 575 | 576 | def init_proj(self): 577 | image_proj_model = Resampler( 578 | 
dim=self.pipe.unet.config.cross_attention_dim, 579 | depth=4, 580 | dim_head=64, 581 | heads=12, 582 | num_queries=self.num_tokens, 583 | embedding_dim=self.image_encoder.config.hidden_size, 584 | output_dim=self.pipe.unet.config.cross_attention_dim, 585 | ff_mult=4, 586 | ).to(self.device, dtype=torch.float16) 587 | return image_proj_model 588 | 589 | @torch.inference_mode() 590 | def get_image_embeds(self, pil_image=None, clip_image=None, uncond= None): 591 | if pil_image is not None: 592 | if isinstance(pil_image, Image.Image): 593 | pil_image = [pil_image] 594 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 595 | clip_image = clip_image.to(self.device, dtype=torch.float16) 596 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 597 | else: 598 | clip_image = clip_image.to(self.device, dtype=torch.float16) 599 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 600 | image_prompt_embeds = self.image_proj_model(clip_image_embeds) 601 | uncond_clip_image_embeds = self.image_encoder( 602 | torch.zeros_like(clip_image), output_hidden_states=True 603 | ).hidden_states[-2] 604 | uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) 605 | return image_prompt_embeds, uncond_image_prompt_embeds 606 | 607 | def set_ip_adapter(self): 608 | unet = self.pipe.unet 609 | attn_procs = {} 610 | unet_sd = unet.state_dict() 611 | 612 | for attn_processor_name, attn_processor in unet.attn_processors.items(): 613 | # Parse the attention module. 614 | cross_attention_dim = None if attn_processor_name.endswith("attn1.processor") else unet.config.cross_attention_dim 615 | if attn_processor_name.startswith("mid_block"): 616 | hidden_size = unet.config.block_out_channels[-1] 617 | elif attn_processor_name.startswith("up_blocks"): 618 | block_id = int(attn_processor_name[len("up_blocks.")]) 619 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 620 | elif attn_processor_name.startswith("down_blocks"): 621 | block_id = int(attn_processor_name[len("down_blocks.")]) 622 | hidden_size = unet.config.block_out_channels[block_id] 623 | if cross_attention_dim is None: 624 | attn_procs[attn_processor_name] = AttnProcessor() 625 | else: 626 | layer_name = attn_processor_name.split(".processor")[0] 627 | weights = { 628 | "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"], 629 | "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"], 630 | } 631 | attn_procs[attn_processor_name] = IPAttnProcessor2_0_Lora(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=self.num_tokens) 632 | attn_procs[attn_processor_name].load_state_dict(weights,strict=False) 633 | 634 | attn_module = unet 635 | for n in attn_processor_name.split(".")[:-1]: 636 | attn_module = getattr(attn_module, n) 637 | 638 | attn_module.q_lora = LoRALinearLayer(in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=self.rank) 639 | attn_module.k_lora = LoRALinearLayer(in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=self.rank) 640 | attn_module.v_lora = LoRALinearLayer(in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=self.rank) 641 | attn_module.out_lora = LoRALinearLayer(in_features=attn_module.to_out[0].in_features, out_features=attn_module.to_out[0].out_features, rank=self.rank) 642 | 643 | unet.set_attn_processor(attn_procs) 644 | if 
hasattr(self.pipe, "controlnet"): 645 | if isinstance(self.pipe.controlnet, MultiControlNetModel): 646 | for controlnet in self.pipe.controlnet.nets: 647 | controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) 648 | else: 649 | self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) 650 | 651 | 652 | 653 | class IPAdapterPlus_Lora_up(IPAdapter): 654 | """IP-Adapter with fine-grained features""" 655 | 656 | def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, rank=32): 657 | self.rank = rank 658 | super().__init__(sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens) 659 | 660 | 661 | def generate( 662 | self, 663 | pil_image=None, 664 | clip_image_embeds=None, 665 | prompt=None, 666 | negative_prompt=None, 667 | scale=1.0, 668 | num_samples=4, 669 | seed=None, 670 | guidance_scale=7.5, 671 | num_inference_steps=50, 672 | **kwargs, 673 | ): 674 | self.set_scale(scale) 675 | 676 | if pil_image is not None: 677 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) 678 | else: 679 | num_prompts = clip_image_embeds.size(0) 680 | 681 | if prompt is None: 682 | prompt = "best quality, high quality" 683 | if negative_prompt is None: 684 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 685 | 686 | if not isinstance(prompt, List): 687 | prompt = [prompt] * num_prompts 688 | if not isinstance(negative_prompt, List): 689 | negative_prompt = [negative_prompt] * num_prompts 690 | 691 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( 692 | pil_image=pil_image, clip_image=clip_image_embeds 693 | ) 694 | bs_embed, seq_len, _ = image_prompt_embeds.shape 695 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 696 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 697 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 698 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 699 | 700 | with torch.inference_mode(): 701 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( 702 | prompt, 703 | device=self.device, 704 | num_images_per_prompt=num_samples, 705 | do_classifier_free_guidance=True, 706 | negative_prompt=negative_prompt, 707 | ) 708 | prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) 709 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) 710 | 711 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None 712 | images = self.pipe( 713 | prompt_embeds=prompt_embeds, 714 | negative_prompt_embeds=negative_prompt_embeds, 715 | guidance_scale=guidance_scale, 716 | num_inference_steps=num_inference_steps, 717 | generator=generator, 718 | **kwargs, 719 | ).images 720 | 721 | return images 722 | 723 | 724 | def init_proj(self): 725 | image_proj_model = Resampler( 726 | dim=self.pipe.unet.config.cross_attention_dim, 727 | depth=4, 728 | dim_head=64, 729 | heads=12, 730 | num_queries=self.num_tokens, 731 | embedding_dim=self.image_encoder.config.hidden_size, 732 | output_dim=self.pipe.unet.config.cross_attention_dim, 733 | ff_mult=4, 734 | ).to(self.device, dtype=torch.float16) 735 | return image_proj_model 736 | 737 | @torch.inference_mode() 738 | def get_image_embeds(self, pil_image=None, clip_image=None, uncond= None): 739 | if pil_image is not None: 740 | if isinstance(pil_image, Image.Image): 741 | 
pil_image = [pil_image] 742 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 743 | clip_image = clip_image.to(self.device, dtype=torch.float16) 744 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 745 | else: 746 | clip_image = clip_image.to(self.device, dtype=torch.float16) 747 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 748 | image_prompt_embeds = self.image_proj_model(clip_image_embeds) 749 | uncond_clip_image_embeds = self.image_encoder( 750 | torch.zeros_like(clip_image), output_hidden_states=True 751 | ).hidden_states[-2] 752 | uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) 753 | return image_prompt_embeds, uncond_image_prompt_embeds 754 | 755 | def set_ip_adapter(self): 756 | unet = self.pipe.unet 757 | attn_procs = {} 758 | unet_sd = unet.state_dict() 759 | 760 | for attn_processor_name, attn_processor in unet.attn_processors.items(): 761 | # Parse the attention module. 762 | cross_attention_dim = None if attn_processor_name.endswith("attn1.processor") else unet.config.cross_attention_dim 763 | if attn_processor_name.startswith("mid_block"): 764 | hidden_size = unet.config.block_out_channels[-1] 765 | elif attn_processor_name.startswith("up_blocks"): 766 | block_id = int(attn_processor_name[len("up_blocks.")]) 767 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 768 | elif attn_processor_name.startswith("down_blocks"): 769 | block_id = int(attn_processor_name[len("down_blocks.")]) 770 | hidden_size = unet.config.block_out_channels[block_id] 771 | if cross_attention_dim is None: 772 | attn_procs[attn_processor_name] = AttnProcessor() 773 | else: 774 | layer_name = attn_processor_name.split(".processor")[0] 775 | weights = { 776 | "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"], 777 | "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"], 778 | } 779 | attn_procs[attn_processor_name] = IPAttnProcessor2_0_Lora(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=self.num_tokens) 780 | attn_procs[attn_processor_name].load_state_dict(weights,strict=False) 781 | 782 | attn_module = unet 783 | for n in attn_processor_name.split(".")[:-1]: 784 | attn_module = getattr(attn_module, n) 785 | 786 | 787 | if "up_blocks" in attn_processor_name: 788 | attn_module.q_lora = LoRALinearLayer(in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=self.rank) 789 | attn_module.k_lora = LoRALinearLayer(in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=self.rank) 790 | attn_module.v_lora = LoRALinearLayer(in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=self.rank) 791 | attn_module.out_lora = LoRALinearLayer(in_features=attn_module.to_out[0].in_features, out_features=attn_module.to_out[0].out_features, rank=self.rank) 792 | 793 | 794 | 795 | unet.set_attn_processor(attn_procs) 796 | if hasattr(self.pipe, "controlnet"): 797 | if isinstance(self.pipe.controlnet, MultiControlNetModel): 798 | for controlnet in self.pipe.controlnet.nets: 799 | controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) 800 | else: 801 | self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) 802 | 803 | 804 | 805 | class IPAdapterFull(IPAdapterPlus): 806 | """IP-Adapter with full features""" 807 | 808 | def init_proj(self): 809 | 
image_proj_model = MLPProjModel( 810 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim, 811 | clip_embeddings_dim=self.image_encoder.config.hidden_size, 812 | ).to(self.device, dtype=torch.float16) 813 | return image_proj_model 814 | 815 | 816 | class IPAdapterPlusXL(IPAdapter): 817 | """SDXL""" 818 | 819 | def init_proj(self): 820 | image_proj_model = Resampler( 821 | dim=1280, 822 | depth=4, 823 | dim_head=64, 824 | heads=20, 825 | num_queries=self.num_tokens, 826 | embedding_dim=self.image_encoder.config.hidden_size, 827 | output_dim=self.pipe.unet.config.cross_attention_dim, 828 | ff_mult=4, 829 | ).to(self.device, dtype=torch.float16) 830 | return image_proj_model 831 | 832 | @torch.inference_mode() 833 | def get_image_embeds(self, pil_image=None, clip_image_embeds=None): 834 | if pil_image is not None: 835 | if isinstance(pil_image, Image.Image): 836 | pil_image = [pil_image] 837 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 838 | clip_image = clip_image.to(self.device, dtype=torch.float16) 839 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 840 | else: 841 | clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16) 842 | image_prompt_embeds = self.image_proj_model(clip_image_embeds) 843 | uncond_clip_image_embeds = self.image_encoder( 844 | torch.zeros_like(clip_image), output_hidden_states=True 845 | ).hidden_states[-2] 846 | uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) 847 | return image_prompt_embeds, uncond_image_prompt_embeds 848 | 849 | def generate( 850 | self, 851 | pil_image, 852 | prompt=None, 853 | negative_prompt=None, 854 | scale=1.0, 855 | num_samples=4, 856 | seed=None, 857 | num_inference_steps=30, 858 | **kwargs, 859 | ): 860 | self.set_scale(scale) 861 | 862 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) 863 | 864 | if prompt is None: 865 | prompt = "best quality, high quality" 866 | if negative_prompt is None: 867 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 868 | 869 | if not isinstance(prompt, List): 870 | prompt = [prompt] * num_prompts 871 | if not isinstance(negative_prompt, List): 872 | negative_prompt = [negative_prompt] * num_prompts 873 | 874 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) 875 | bs_embed, seq_len, _ = image_prompt_embeds.shape 876 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 877 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 878 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 879 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 880 | 881 | with torch.inference_mode(): 882 | ( 883 | prompt_embeds, 884 | negative_prompt_embeds, 885 | pooled_prompt_embeds, 886 | negative_pooled_prompt_embeds, 887 | ) = self.pipe.encode_prompt( 888 | prompt, 889 | num_images_per_prompt=num_samples, 890 | do_classifier_free_guidance=True, 891 | negative_prompt=negative_prompt, 892 | ) 893 | prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) 894 | negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) 895 | 896 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None 897 | images = self.pipe( 898 | prompt_embeds=prompt_embeds, 899 | 
negative_prompt_embeds=negative_prompt_embeds, 900 | pooled_prompt_embeds=pooled_prompt_embeds, 901 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 902 | num_inference_steps=num_inference_steps, 903 | generator=generator, 904 | **kwargs, 905 | ).images 906 | 907 | return images 908 | -------------------------------------------------------------------------------- /src/ip_adapter/resampler.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py 2 | # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py 3 | 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | from einops import rearrange 9 | from einops.layers.torch import Rearrange 10 | 11 | 12 | # FFN 13 | def FeedForward(dim, mult=4): 14 | inner_dim = int(dim * mult) 15 | return nn.Sequential( 16 | nn.LayerNorm(dim), 17 | nn.Linear(dim, inner_dim, bias=False), 18 | nn.GELU(), 19 | nn.Linear(inner_dim, dim, bias=False), 20 | ) 21 | 22 | 23 | def reshape_tensor(x, heads): 24 | bs, length, width = x.shape 25 | # (bs, length, width) --> (bs, length, n_heads, dim_per_head) 26 | x = x.view(bs, length, heads, -1) 27 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 28 | x = x.transpose(1, 2) 29 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 30 | x = x.reshape(bs, heads, length, -1) 31 | return x 32 | 33 | 34 | class PerceiverAttention(nn.Module): 35 | def __init__(self, *, dim, dim_head=64, heads=8): 36 | super().__init__() 37 | self.scale = dim_head**-0.5 38 | self.dim_head = dim_head 39 | self.heads = heads 40 | inner_dim = dim_head * heads 41 | 42 | self.norm1 = nn.LayerNorm(dim) 43 | self.norm2 = nn.LayerNorm(dim) 44 | 45 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 46 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 47 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 48 | 49 | def forward(self, x, latents): 50 | """ 51 | Args: 52 | x (torch.Tensor): image features 53 | shape (b, n1, D) 54 | latent (torch.Tensor): latent features 55 | shape (b, n2, D) 56 | """ 57 | x = self.norm1(x) 58 | latents = self.norm2(latents) 59 | 60 | b, l, _ = latents.shape 61 | 62 | q = self.to_q(latents) 63 | kv_input = torch.cat((x, latents), dim=-2) 64 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 65 | 66 | q = reshape_tensor(q, self.heads) 67 | k = reshape_tensor(k, self.heads) 68 | v = reshape_tensor(v, self.heads) 69 | 70 | # attention 71 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 72 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 73 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 74 | out = weight @ v 75 | 76 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 77 | 78 | return self.to_out(out) 79 | 80 | 81 | class CrossAttention(nn.Module): 82 | def __init__(self, *, dim, dim_head=64, heads=8): 83 | super().__init__() 84 | self.scale = dim_head**-0.5 85 | self.dim_head = dim_head 86 | self.heads = heads 87 | inner_dim = dim_head * heads 88 | 89 | self.norm1 = nn.LayerNorm(dim) 90 | self.norm2 = nn.LayerNorm(dim) 91 | 92 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 93 | self.to_k = nn.Linear(dim, inner_dim, bias=False) 94 | self.to_v = nn.Linear(dim, inner_dim, bias=False) 95 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 96 | 97 | 98 | def forward(self, x, x2): 99 | """ 100 | Args: 
101 | x (torch.Tensor): image features 102 | shape (b, n1, D) 103 | latent (torch.Tensor): latent features 104 | shape (b, n2, D) 105 | """ 106 | x = self.norm1(x) 107 | x2 = self.norm2(x2) 108 | 109 | b, l, _ = x2.shape 110 | 111 | q = self.to_q(x) 112 | k = self.to_k(x2) 113 | v = self.to_v(x2) 114 | 115 | q = reshape_tensor(q, self.heads) 116 | k = reshape_tensor(k, self.heads) 117 | v = reshape_tensor(v, self.heads) 118 | 119 | # attention 120 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 121 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 122 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 123 | out = weight @ v 124 | 125 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 126 | return self.to_out(out) 127 | 128 | 129 | class Resampler(nn.Module): 130 | def __init__( 131 | self, 132 | dim=1024, 133 | depth=8, 134 | dim_head=64, 135 | heads=16, 136 | num_queries=8, 137 | embedding_dim=768, 138 | output_dim=1024, 139 | ff_mult=4, 140 | max_seq_len: int = 257, # CLIP tokens + CLS token 141 | apply_pos_emb: bool = False, 142 | num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence 143 | ): 144 | super().__init__() 145 | 146 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) 147 | 148 | self.proj_in = nn.Linear(embedding_dim, dim) 149 | 150 | self.proj_out = nn.Linear(dim, output_dim) 151 | self.norm_out = nn.LayerNorm(output_dim) 152 | 153 | self.layers = nn.ModuleList([]) 154 | for _ in range(depth): 155 | self.layers.append( 156 | nn.ModuleList( 157 | [ 158 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 159 | FeedForward(dim=dim, mult=ff_mult), 160 | ] 161 | ) 162 | ) 163 | 164 | def forward(self, x): 165 | 166 | latents = self.latents.repeat(x.size(0), 1, 1) 167 | 168 | x = self.proj_in(x) 169 | 170 | 171 | for attn, ff in self.layers: 172 | latents = attn(x, latents) + latents 173 | latents = ff(latents) + latents 174 | 175 | latents = self.proj_out(latents) 176 | return self.norm_out(latents) 177 | 178 | 179 | 180 | def masked_mean(t, *, dim, mask=None): 181 | if mask is None: 182 | return t.mean(dim=dim) 183 | 184 | denom = mask.sum(dim=dim, keepdim=True) 185 | mask = rearrange(mask, "b n -> b n 1") 186 | masked_t = t.masked_fill(~mask, 0.0) 187 | 188 | return masked_t.sum(dim=dim) / denom.clamp(min=1e-5) 189 | -------------------------------------------------------------------------------- /src/ip_adapter/test_resampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from resampler import Resampler 3 | from transformers import CLIPVisionModel 4 | 5 | BATCH_SIZE = 2 6 | OUTPUT_DIM = 1280 7 | NUM_QUERIES = 8 8 | NUM_LATENTS_MEAN_POOLED = 4 # 0 for no mean pooling (previous behavior) 9 | APPLY_POS_EMB = True # False for no positional embeddings (previous behavior) 10 | IMAGE_ENCODER_NAME_OR_PATH = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" 11 | 12 | 13 | def main(): 14 | image_encoder = CLIPVisionModel.from_pretrained(IMAGE_ENCODER_NAME_OR_PATH) 15 | embedding_dim = image_encoder.config.hidden_size 16 | print(f"image_encoder hidden size: ", embedding_dim) 17 | 18 | image_proj_model = Resampler( 19 | dim=1024, 20 | depth=2, 21 | dim_head=64, 22 | heads=16, 23 | num_queries=NUM_QUERIES, 24 | embedding_dim=embedding_dim, 25 | output_dim=OUTPUT_DIM, 26 | ff_mult=2, 27 | max_seq_len=257, 28 | apply_pos_emb=APPLY_POS_EMB, 29 | 
num_latents_mean_pooled=NUM_LATENTS_MEAN_POOLED, 30 | ) 31 | 32 | dummy_images = torch.randn(BATCH_SIZE, 3, 224, 224) 33 | with torch.no_grad(): 34 | image_embeds = image_encoder(dummy_images, output_hidden_states=True).hidden_states[-2] 35 | print("image_embds shape: ", image_embeds.shape) 36 | 37 | with torch.no_grad(): 38 | ip_tokens = image_proj_model(image_embeds) 39 | print("ip_tokens shape:", ip_tokens.shape) 40 | assert ip_tokens.shape == (BATCH_SIZE, NUM_QUERIES + NUM_LATENTS_MEAN_POOLED, OUTPUT_DIM) 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /src/ip_adapter/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | 3 | 4 | def is_torch2_available(): 5 | return hasattr(F, "scaled_dot_product_attention") 6 | -------------------------------------------------------------------------------- /src/logger.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import copy 3 | import logging 4 | 5 | 6 | class ColoredFormatter(logging.Formatter): 7 | COLORS = { 8 | "DEBUG": "\033[0;36m", # CYAN 9 | "INFO": "\033[0;32m", # GREEN 10 | "WARNING": "\033[0;33m", # YELLOW 11 | "ERROR": "\033[0;31m", # RED 12 | "CRITICAL": "\033[0;37;41m", # WHITE ON RED 13 | "RESET": "\033[0m", # RESET COLOR 14 | } 15 | 16 | def format(self, record): 17 | colored_record = copy.copy(record) 18 | levelname = colored_record.levelname 19 | seq = self.COLORS.get(levelname, self.COLORS["RESET"]) 20 | colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}" 21 | return super().format(colored_record) 22 | 23 | 24 | # Create a new logger 25 | logger = logging.getLogger("ComfyUI-IDM-VTON") 26 | logger.propagate = False 27 | 28 | # Add handler if we don't have one. 
29 | if not logger.handlers: 30 | handler = logging.StreamHandler(sys.stdout) 31 | handler.setFormatter(ColoredFormatter("[%(name)s] - %(levelname)s - %(message)s")) 32 | logger.addHandler(handler) 33 | 34 | # Configure logger 35 | loglevel = logging.INFO 36 | logger.setLevel(loglevel) -------------------------------------------------------------------------------- /src/nodes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TemryL/ComfyUI-IDM-VTON/5a8334c58c390381e31a8023cb7ba398ade40b39/src/nodes/__init__.py -------------------------------------------------------------------------------- /src/nodes/idm_vton.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | sys.path.append('..') 4 | 5 | import torch 6 | from torchvision import transforms 7 | from comfy.model_management import get_torch_device 8 | 9 | 10 | DEVICE = get_torch_device() 11 | MAX_RESOLUTION = 16384 12 | 13 | 14 | class IDM_VTON: 15 | @classmethod 16 | def INPUT_TYPES(s): 17 | return { 18 | "required": { 19 | "pipeline": ("PIPELINE",), 20 | "human_img": ("IMAGE",), 21 | "pose_img": ("IMAGE",), 22 | "mask_img": ("IMAGE",), 23 | "garment_img": ("IMAGE",), 24 | "garment_description": ("STRING", {"multiline": True, "dynamicPrompts": True}), 25 | "negative_prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}), 26 | "width": ("INT", {"default": 768, "min": 0, "max": MAX_RESOLUTION}), 27 | "height": ("INT", {"default": 1024, "min": 0, "max": MAX_RESOLUTION}), 28 | "num_inference_steps": ("INT", {"default": 30}), 29 | "guidance_scale": ("FLOAT", {"default": 2.0}), 30 | "strength": ("FLOAT", {"default": 1.0}), 31 | "seed": ("INT", {"default": 42, "min": 0, "max": 0xffffffffffffffff}), 32 | } 33 | } 34 | 35 | RETURN_TYPES = ("IMAGE", "MASK") 36 | FUNCTION = "make_inference" 37 | CATEGORY = "ComfyUI-IDM-VTON" 38 | 39 | def preprocess_images(self, human_img, garment_img, pose_img, mask_img, height, width): 40 | human_img = human_img.squeeze().permute(2,0,1) 41 | garment_img = garment_img.squeeze().permute(2,0,1) 42 | pose_img = pose_img.squeeze().permute(2,0,1) 43 | mask_img = mask_img.squeeze().permute(2,0,1) 44 | 45 | human_img = transforms.functional.to_pil_image(human_img) 46 | garment_img = transforms.functional.to_pil_image(garment_img) 47 | pose_img = transforms.functional.to_pil_image(pose_img) 48 | mask_img = transforms.functional.to_pil_image(mask_img) 49 | 50 | human_img = human_img.convert("RGB").resize((width, height)) 51 | garment_img = garment_img.convert("RGB").resize((width, height)) 52 | mask_img = mask_img.convert("RGB").resize((width, height)) 53 | pose_img = pose_img.convert("RGB").resize((width, height)) 54 | 55 | return human_img, garment_img, pose_img, mask_img 56 | 57 | def make_inference(self, pipeline, human_img, garment_img, pose_img, mask_img, height, width, garment_description, negative_prompt, num_inference_steps, strength, guidance_scale, seed): 58 | human_img, garment_img, pose_img, mask_img = self.preprocess_images(human_img, garment_img, pose_img, mask_img, height, width) 59 | tensor_transfrom = transforms.Compose( 60 | [ 61 | transforms.ToTensor(), 62 | transforms.Normalize([0.5], [0.5]), 63 | ] 64 | ) 65 | 66 | with torch.no_grad(): 67 | # Extract the images 68 | with torch.cuda.amp.autocast(): 69 | with torch.inference_mode(): 70 | prompt = "model is wearing " + garment_description 71 | ( 72 | prompt_embeds, 73 | negative_prompt_embeds, 
74 | pooled_prompt_embeds, 75 | negative_pooled_prompt_embeds, 76 | ) = pipeline.encode_prompt( 77 | prompt, 78 | num_images_per_prompt=1, 79 | do_classifier_free_guidance=True, 80 | negative_prompt=negative_prompt, 81 | ) 82 | 83 | prompt = ["a photo of " + garment_description] 84 | negative_prompt = [negative_prompt] 85 | ( 86 | prompt_embeds_c, 87 | _, 88 | _, 89 | _, 90 | ) = pipeline.encode_prompt( 91 | prompt, 92 | num_images_per_prompt=1, 93 | do_classifier_free_guidance=False, 94 | negative_prompt=negative_prompt, 95 | ) 96 | 97 | pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(DEVICE, pipeline.dtype) 98 | garment_tensor = tensor_transfrom(garment_img).unsqueeze(0).to(DEVICE, pipeline.dtype) 99 | 100 | images = pipeline( 101 | prompt_embeds=prompt_embeds, 102 | negative_prompt_embeds=negative_prompt_embeds, 103 | pooled_prompt_embeds=pooled_prompt_embeds, 104 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 105 | num_inference_steps=num_inference_steps, 106 | generator=torch.Generator(DEVICE).manual_seed(seed), 107 | strength=strength, 108 | pose_img=pose_img, 109 | text_embeds_cloth=prompt_embeds_c, 110 | cloth=garment_tensor, 111 | mask_image=mask_img, 112 | image=human_img, 113 | height=height, 114 | width=width, 115 | ip_adapter_image=garment_img, 116 | guidance_scale=guidance_scale, 117 | )[0] 118 | 119 | images = [transforms.ToTensor()(image) for image in images] 120 | images = [image.permute(1,2,0) for image in images] 121 | images = torch.stack(images) 122 | return (images, ) -------------------------------------------------------------------------------- /src/nodes/pipeline_loader.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | sys.path.append('..') 4 | 5 | import torch 6 | from diffusers import AutoencoderKL, DDPMScheduler 7 | from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection,CLIPTextModelWithProjection, CLIPTextModel 8 | 9 | from ..idm_vton.unet_hacked_tryon import UNet2DConditionModel 10 | from ..idm_vton.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref 11 | from ..idm_vton.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline 12 | from comfy.model_management import get_torch_device 13 | from ...install import WEIGHTS_PATH 14 | 15 | 16 | DEVICE = get_torch_device() 17 | 18 | 19 | class PipelineLoader: 20 | @classmethod 21 | def INPUT_TYPES(s): 22 | return { 23 | "required": { 24 | "weight_dtype": (("float32", "float16", "bfloat16"), ), 25 | } 26 | } 27 | 28 | CATEGORY = "ComfyUI-IDM-VTON" 29 | INPUT_NODE = True 30 | RETURN_TYPES = ("PIPELINE",) 31 | FUNCTION = "load_pipeline" 32 | 33 | def load_pipeline(self, weight_dtype): 34 | if weight_dtype == "float32": 35 | weight_dtype = torch.float32 36 | elif weight_dtype == "float16": 37 | weight_dtype = torch.float16 38 | elif weight_dtype == "bfloat16": 39 | weight_dtype = torch.bfloat16 40 | 41 | noise_scheduler = DDPMScheduler.from_pretrained( 42 | WEIGHTS_PATH, 43 | subfolder="scheduler" 44 | ) 45 | 46 | vae = AutoencoderKL.from_pretrained( 47 | WEIGHTS_PATH, 48 | subfolder="vae", 49 | torch_dtype=weight_dtype 50 | ).requires_grad_(False).eval().to(DEVICE) 51 | 52 | unet = UNet2DConditionModel.from_pretrained( 53 | WEIGHTS_PATH, 54 | subfolder="unet", 55 | torch_dtype=weight_dtype 56 | ).requires_grad_(False).eval().to(DEVICE) 57 | 58 | image_encoder = CLIPVisionModelWithProjection.from_pretrained( 59 | WEIGHTS_PATH, 60 | 
subfolder="image_encoder", 61 | torch_dtype=weight_dtype 62 | ).requires_grad_(False).eval().to(DEVICE) 63 | 64 | unet_encoder = UNet2DConditionModel_ref.from_pretrained( 65 | WEIGHTS_PATH, 66 | subfolder="unet_encoder", 67 | torch_dtype=weight_dtype 68 | ).requires_grad_(False).eval().to(DEVICE) 69 | 70 | text_encoder_one = CLIPTextModel.from_pretrained( 71 | WEIGHTS_PATH, 72 | subfolder="text_encoder", 73 | torch_dtype=weight_dtype 74 | ).requires_grad_(False).eval().to(DEVICE) 75 | 76 | text_encoder_two = CLIPTextModelWithProjection.from_pretrained( 77 | WEIGHTS_PATH, 78 | subfolder="text_encoder_2", 79 | torch_dtype=weight_dtype 80 | ).requires_grad_(False).eval().to(DEVICE) 81 | 82 | tokenizer_one = AutoTokenizer.from_pretrained( 83 | WEIGHTS_PATH, 84 | subfolder="tokenizer", 85 | revision=None, 86 | use_fast=False, 87 | ) 88 | 89 | tokenizer_two = AutoTokenizer.from_pretrained( 90 | WEIGHTS_PATH, 91 | subfolder="tokenizer_2", 92 | revision=None, 93 | use_fast=False, 94 | ) 95 | 96 | pipe = TryonPipeline.from_pretrained( 97 | WEIGHTS_PATH, 98 | unet=unet, 99 | vae=vae, 100 | feature_extractor=CLIPImageProcessor(), 101 | text_encoder=text_encoder_one, 102 | text_encoder_2=text_encoder_two, 103 | tokenizer=tokenizer_one, 104 | tokenizer_2=tokenizer_two, 105 | scheduler=noise_scheduler, 106 | image_encoder=image_encoder, 107 | torch_dtype=weight_dtype, 108 | ) 109 | pipe.unet_encoder = unet_encoder 110 | pipe = pipe.to(DEVICE) 111 | pipe.weight_dtype = weight_dtype 112 | 113 | return (pipe, ) -------------------------------------------------------------------------------- /src/nodes_mappings.py: -------------------------------------------------------------------------------- 1 | from .nodes.pipeline_loader import PipelineLoader 2 | from .nodes.idm_vton import IDM_VTON 3 | 4 | 5 | NODE_CLASS_MAPPINGS = { 6 | "PipelineLoader": PipelineLoader, 7 | "IDM-VTON": IDM_VTON, 8 | } 9 | 10 | NODE_DISPLAY_NAME_MAPPINGS = { 11 | "PipelineLoader": "Load IDM-VTON Pipeline", 12 | "IDM-VTON": "Run IDM-VTON Inference", 13 | } -------------------------------------------------------------------------------- /workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TemryL/ComfyUI-IDM-VTON/5a8334c58c390381e31a8023cb7ba398ade40b39/workflow.png --------------------------------------------------------------------------------