├── .env.example ├── .github └── workflows │ └── release.yml ├── .gitignore ├── LICENSE.md ├── README.md ├── consts.py ├── esrgan_upscaler.py ├── exceptions.py ├── face_restoration.py ├── gobig.py ├── logging_settings.py ├── main.py ├── poetry.lock ├── pyinstaller_hooks ├── __init__.py ├── hook-diffusers.py └── hook-transformers.py ├── pyproject.toml ├── request_models.py ├── settings.py ├── universal_pipeline.py └── utils.py /.env.example: -------------------------------------------------------------------------------- 1 | IMAGE_AI_UTILS_USERNAME=user 2 | IMAGE_AI_UTILS_PASSWORD=password 3 | HOST=0.0.0.0 4 | PORT=7331 5 | PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:1024 6 | DIFFUSERS_CACHE_PATH=./diffusers_cache 7 | HUGGING_FACE_HUB_TOKEN= 8 | USE_OPTIMIZED_MODE=true 9 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | # Triggers the workflow on push or pull request events but only for the "master" branch 5 | push: 6 | tags: 7 | - "v*" 8 | 9 | workflow_dispatch: 10 | 11 | jobs: 12 | build-windows: 13 | name: "Build Windows" 14 | runs-on: windows-latest 15 | steps: 16 | - uses: "actions/checkout@v3" 17 | - uses: "actions/setup-python@v4" 18 | with: 19 | python-version: "3.9" 20 | - name: build 21 | run: | 22 | python -m venv .venv 23 | .venv\Scripts\activate 24 | pip install poetry 25 | poetry install 26 | pip install pyinstaller 27 | pyinstaller main.py --hidden-import colorlog --collect-all huggingface_hub --additional-hooks-dir pyinstaller_hooks 28 | copy .env.example dist\main\.env 29 | cd dist\main\ 30 | 31 | mkdir basicsr\archs, basicsr\data, basicsr\losses, basicsr\models 32 | mkdir realesrgan\archs, realesrgan\data, realesrgan\losses, realesrgan\models 33 | mkdir gfpgan\archs, gfpgan\data, gfpgan\losses, gfpgan\models 34 | 35 | & "C:/Program Files/7-Zip/7z.exe" a -r ..\image_ai_utils_windows.7z * 36 | - uses: actions/upload-artifact@v3 37 | with: 38 | name: windows-build 39 | path: dist/image_ai_utils_windows.7z 40 | 41 | release: 42 | name: "Release" 43 | needs: [build-windows] 44 | runs-on: ubuntu-latest 45 | steps: 46 | - uses: actions/download-artifact@v3 47 | with: 48 | name: windows-build 49 | - uses: "marvinpinto/action-automatic-releases@latest" 50 | with: 51 | repo_token: "${{ secrets.GITHUB_TOKEN }}" 52 | prerelease: false 53 | draft: true 54 | files: | 55 | *.7z 56 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .venv 2 | .idea 3 | __pycache__ 4 | test.py 5 | *.prof 6 | models 7 | .env 8 | build 9 | dist 10 | diffusers_cache 11 | main.spec 12 | messages.log 13 | gfpgan -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. 
By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 
76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. 
However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 
196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 
256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 
309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 
360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 
421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 
476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. 
You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 
583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | 635 | Copyright (C) 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 649 | 650 | Also add information on how to contact you by electronic and paper mail. 
651 | 
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 | 
655 |     <program>  Copyright (C) <year>  <name of author>
656 |     This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 |     This is free software, and you are welcome to redistribute it
658 |     under certain conditions; type `show c' for details.
659 | 
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License.  Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 | 
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 | 
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs.  If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library.  If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License.  But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Image AI Utils Server
2 | ## Installation
3 | ### Requirements
4 | - An NVIDIA GPU with at least 10 GB of VRAM is highly recommended
5 | - You need to create your Hugging Face token [here](https://huggingface.co/docs/hub/security-tokens)
6 | and accept the terms of service [here](https://huggingface.co/CompVis/stable-diffusion-v1-4)
7 | 
8 | ### Windows
9 | - Download and extract `image_ai_utils_windows.7z` from the [releases](https://github.com/qweryty/image-ai-utils-server/releases) page
10 | - Set `HUGGING_FACE_HUB_TOKEN` in the `.env` file to your Hugging Face token (see the [.env section](#env-file-fields-description) for details)
11 | - Run `main.exe`
12 | - During the first run it will download the Stable Diffusion models to the directory specified in the `.env` file
13 | 
14 | ### Linux
15 | Python 3.9 must be installed. Newer versions are not supported yet, and older versions
16 | may have some unexpected problems.
17 | 
18 | ```shell
19 | python3.9 -m venv .venv
20 | source .venv/bin/activate
21 | pip install poetry
22 | poetry install
23 | mkdir models
24 | curl -L https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -o models/realesr-general-x4v3.pth
25 | curl -L https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -o models/RealESRGAN_x4plus.pth
26 | curl -L https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth -o models/RealESRGAN_x2plus.pth
27 | curl -L https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth -o models/RealESRNet_x4plus.pth
28 | curl -L https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth -o models/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth
29 | curl -L https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -o models/RealESRGAN_x4plus_anime_6B.pth
30 | curl -L https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -o models/realesr-animevideov3.pth
31 | cp .env.example .env
32 | ```
33 | 
34 | Set `HUGGING_FACE_HUB_TOKEN` in the `.env` file to your Hugging Face token (see the [.env section](#env-file-fields-description) for details)
35 | 
36 | Run the server with `python main.py`
37 | 
38 | During the first run it will download the Stable Diffusion models to the directory specified in the `.env` file
39 | 
40 | ### Docker
41 | TODO
42 | 
43 | ## `.env` File Fields Description
44 | - `IMAGE_AI_UTILS_USERNAME` - the username the plugin uses to access the server (you don't need to change this field for a local installation)
45 | - `IMAGE_AI_UTILS_PASSWORD` - the password the plugin uses to access the server (you don't need to change this field for a local installation)
46 | - `HOST` - URL or IP address of the server; one machine can have multiple URLs or IPs, and `0.0.0.0` will listen on all of them (you don't need to change this field for a local installation)
47 | - `PORT` - the server port (you don't need to change this field for a local installation unless it conflicts with some other service)
48 | - `PYTORCH_CUDA_ALLOC_CONF` - see https://pytorch.org/docs/stable/notes/cuda.html#memory-management
49 | - `DIFFUSERS_CACHE_PATH` - the path where downloaded Stable Diffusion models will be stored
50 | - `HUGGING_FACE_HUB_TOKEN` - the token required to download Stable Diffusion models
51 | - `USE_OPTIMIZED_MODE` - when enabled, Stable Diffusion will consume less VRAM at the expense of roughly 10% of generation speed
52 | 
53 | ## Common Problems
54 | ### main.exe closes shortly after startup
55 | Look into the `messages.log` file; it contains all the errors encountered while the program was running. If you still can't solve the problem, please [report an issue](https://github.com/qweryty/image-ai-utils-server/issues/new) and attach this file to it
56 | 
57 | ### `RuntimeError: CUDA out of memory`
58 | Try enabling `USE_OPTIMIZED_MODE` in the `.env` file.
59 | 
60 | If that didn't help and you have less than 4 GB of VRAM, you are probably out of luck and need better hardware.
61 | 
62 | Another option would be to rent a VPS with a GPU and run the server there.
63 | 
64 | ### `OSError: Windows requires Developer Mode to be activated`
65 | During the first run the `diffusers` library needs to create some symlinks, which requires Developer Mode or admin rights.
66 | If you don't want to activate Developer Mode, right-click `main.exe` and choose "Run as administrator";
67 | you only need to do this once, subsequent runs will work without extra privileges.
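### Missing model files under `models/`
The `curl` commands in the Linux section above fetch only the Real-ESRGAN checkpoints, while `face_restoration.py` also expects the GFPGAN checkpoints (`GFPGANv1.3.pth`, `GFPGANCleanv1-NoCE-C2.pth`, `GFPGANv1.pth`) in the same `models/` directory. Whether the server fetches missing checkpoints by itself on first run is not documented here, so if you hit a missing-file error, the sketch below (a hypothetical helper, not part of the repository) downloads everything referenced by the `ESRGAN_URLS` and `GFPGAN_URLS` lists in the source. Run it from the project root with the project's dependencies installed.

```python
# download_models.py - hypothetical helper script, not shipped with the repository.
# Fetches every checkpoint referenced by esrgan_upscaler.py and face_restoration.py
# into the models/ directory that MODEL_PATHS resolves against.
import os
import urllib.request

from esrgan_upscaler import ESRGAN_URLS
from face_restoration import GFPGAN_URLS

os.makedirs('models', exist_ok=True)
for url in ESRGAN_URLS + GFPGAN_URLS:
    destination = os.path.join('models', os.path.basename(url))
    if os.path.exists(destination):
        continue  # skip checkpoints that are already present
    print(f'Downloading {url} -> {destination}')
    urllib.request.urlretrieve(url, destination)
```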
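For a quick end-to-end check that the checkpoints are in place, the upscaling and face-restoration helpers defined in `esrgan_upscaler.py` and `face_restoration.py` can be called directly from Python. This is a minimal sketch, assuming a CUDA-capable GPU, the project's dependencies installed, and an image of your own; the `input.png`/`output.png` file names are just examples.

```python
# Sketch: exercise the repository's upscaling and face-restoration helpers directly.
from PIL import Image

from consts import ESRGANModel, GFPGANModel
from esrgan_upscaler import upscale
from face_restoration import restore_face

image = Image.open('input.png')  # example input; any RGB image works

# 4x upscale with the general-purpose Real-ESRGAN model
upscaled = upscale(image, model_type=ESRGANModel.GENERAL_X4_V3, outscale=4)

# GFPGAN face restoration, with Real-ESRGAN x2 upsampling the background
restored = restore_face(upscaled, model_type=GFPGANModel.V1_3, use_real_esrgan=True)
restored.save('output.png')
```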
68 | 
--------------------------------------------------------------------------------
/consts.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | 
3 | 
4 | class ESRGANModel(str, Enum):
5 |     # General
6 |     GENERAL_X4_V3 = 'general_x4_v3'
7 |     X4_PLUS = 'x4_plus'
8 |     X2_PLUS = 'x2_plus'
9 |     ESRNET_X4_PLUS = 'esrnet_x4_plus'  # must be distinct; 'x4_plus' would make this an alias of X4_PLUS
10 |     OFFICIAL_X4 = 'official_x4'
11 | 
12 |     # Anime/Illustrations
13 |     X4_PLUS_ANIME_6B = 'x4_plus_anime_6b'
14 | 
15 |     # Anime video
16 |     ANIME_VIDEO_V3 = 'anime_video_v3'
17 | 
18 | 
19 | class GFPGANModel(str, Enum):
20 |     V1_3 = 'V1.3'
21 |     V1_2 = 'V1.2'
22 |     V1 = 'V1'
23 | 
24 | 
25 | class ScalingMode(str, Enum):
26 |     SHRINK = 'shrink'
27 |     GROW = 'grow'
28 | 
29 | 
30 | class ImageFormat(str, Enum):
31 |     PNG = 'PNG'
32 |     JPEG = 'JPEG'
33 |     BMP = 'BMP'
34 | 
35 | 
36 | class WebSocketResponseStatus(str, Enum):
37 |     FINISHED = 'finished'
38 |     PROGRESS = 'progress'
39 | 
40 | 
41 | MIN_SEED = -0x8000_0000_0000_0000
42 | MAX_SEED = 0xffff_ffff_ffff_ffff
43 | 
--------------------------------------------------------------------------------
/esrgan_upscaler.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from enum import Enum
3 | 
4 | import numpy as np
5 | from PIL import Image
6 | from basicsr.archs.rrdbnet_arch import RRDBNet
7 | from realesrgan import RealESRGANer
8 | from realesrgan.archs.srvgg_arch import SRVGGNetCompact
9 | 
10 | from request_models import ESRGANModel
11 | from utils import resolve_path
12 | 
13 | ESRGAN_URLS = [
14 |     'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
15 |     'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
16 |     'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
17 |     'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth',
18 |     'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth',
19 |     'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth',
20 |     'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth',
21 | ]
22 | 
23 | MODEL_PATHS = {
24 |     ESRGANModel.GENERAL_X4_V3: 'models/realesr-general-x4v3.pth',
25 |     ESRGANModel.X4_PLUS: 'models/RealESRGAN_x4plus.pth',
26 |     ESRGANModel.X2_PLUS: 'models/RealESRGAN_x2plus.pth',
27 |     ESRGANModel.ESRNET_X4_PLUS: 'models/RealESRNet_x4plus.pth',
28 |     ESRGANModel.OFFICIAL_X4: 'models/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth',
29 |     ESRGANModel.ANIME_VIDEO_V3: 'models/realesr-animevideov3.pth',
30 |     ESRGANModel.X4_PLUS_ANIME_6B: 'models/RealESRGAN_x4plus_anime_6B.pth',
31 | }
32 | 
33 | for key, value in MODEL_PATHS.items():
34 |     MODEL_PATHS[key] = resolve_path(value)
35 | 
36 | 
37 | def get_upsampler(
38 |     model_type: ESRGANModel,
39 |     tile: int = 0,
40 |     tile_pad: int = 10,
41 |     pre_pad: int = 0,
42 |     half: bool = True,
43 | ) -> RealESRGANer:
44 |     # x4 RRDBNet model
45 |     if model_type in [ESRGANModel.X4_PLUS, ESRGANModel.ESRNET_X4_PLUS, ESRGANModel.OFFICIAL_X4]:
46 |         model = RRDBNet(
47 |             num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4
48 |         )
49 |         netscale = 4
50 |     elif model_type in [ESRGANModel.X4_PLUS_ANIME_6B]:  # x4 RRDBNet model with 6 blocks
51 |         model = RRDBNet(
52 |             num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4
53 |         )
54 |         netscale = 4
55 |     elif model_type in 
[ESRGANModel.X2_PLUS]:  # x2 RRDBNet model
56 |         model = RRDBNet(
57 |             num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2
58 |         )
59 |         netscale = 2
60 |     elif model_type in [ESRGANModel.ANIME_VIDEO_V3]:  # x4 VGG-style model (XS size)
61 |         model = SRVGGNetCompact(
62 |             num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'
63 |         )
64 |         netscale = 4
65 |     elif model_type in [ESRGANModel.GENERAL_X4_V3]:
66 |         model = SRVGGNetCompact(
67 |             num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu'
68 |         )
69 |         netscale = 4
70 |     else:
71 |         raise ValueError('Incorrect model')  # TODO custom exception
72 | 
73 |     # determine model paths
74 |     model_path = MODEL_PATHS[model_type]
75 |     # restorer
76 |     # TODO use gpu (available in newer version)
77 |     # TODO cache upsampler
78 |     return RealESRGANer(
79 |         scale=netscale,
80 |         model_path=model_path,
81 |         model=model,
82 |         tile=tile,
83 |         tile_pad=tile_pad,
84 |         pre_pad=pre_pad,
85 |         half=half,  # forward the caller's half-precision flag to RealESRGANer unchanged
86 |     )
87 | 
88 | 
89 | def upscale(
90 |     image: Image.Image,
91 |     model_type: ESRGANModel = ESRGANModel.GENERAL_X4_V3,
92 |     tile: int = 0,
93 |     tile_pad: int = 10,
94 |     pre_pad: int = 0,
95 |     half: bool = True,
96 |     outscale: float = 4
97 | ):
98 |     upsampler = get_upsampler(
99 |         model_type=model_type, tile=tile, tile_pad=tile_pad, pre_pad=pre_pad, half=half
100 |     )
101 | 
102 |     numpy_image = np.asarray(image)
103 |     output, _ = upsampler.enhance(numpy_image, outscale=outscale)
104 |     return Image.fromarray(np.uint8(output))
105 | 
--------------------------------------------------------------------------------
/exceptions.py:
--------------------------------------------------------------------------------
1 | from fastapi import HTTPException, status
2 | 
3 | 
4 | class BaseWebSocketException(Exception):
5 |     message = 'Unexpected exception, check server logs for details'
6 | 
7 | 
8 | class BatchSizeIsTooLargeException(BaseWebSocketException):
9 |     def __init__(self, batch_size):
10 |         self.message = f'Couldn\'t fit {batch_size} images with such aspect ratio into ' \
11 |                        f'memory. Try using a smaller batch size or enabling the ' \
12 |                        f'try_smaller_batch_on_fail option'
13 | 
14 | 
15 | class AspectRatioTooWideException(BaseWebSocketException):
16 |     message = 'Couldn\'t fit image with such aspect ratio into memory, ' \
17 |               'try using another scaling mode'
18 | 
19 | 
20 | class CouldntFixFaceException(HTTPException):
21 |     def __init__(self):
22 |         super().__init__(
23 |             status_code=status.HTTP_400_BAD_REQUEST,
24 |             detail='Couldn\'t fix faces. 
GFPGANer returned None' 25 | ) 26 | -------------------------------------------------------------------------------- /face_restoration.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | from PIL import Image 5 | 6 | from exceptions import CouldntFixFaceException 7 | from gfpgan import GFPGANer 8 | 9 | from esrgan_upscaler import get_upsampler 10 | from request_models import ESRGANModel, GFPGANModel 11 | from utils import resolve_path 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | GFPGAN_URLS = [ 17 | 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth', 18 | 'https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth', 19 | 'https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth' 20 | ] 21 | 22 | MODEL_PATHS = { 23 | # path, channel_multiplier, arch 24 | GFPGANModel.V1_3: ['models/GFPGANv1.3.pth', 2, 'clean'], 25 | GFPGANModel.V1_2: ['models/GFPGANCleanv1-NoCE-C2.pth', 2, 'clean'], 26 | GFPGANModel.V1: ['models/GFPGANv1.pth', 1, 'original'], 27 | } 28 | 29 | for key, value in MODEL_PATHS.items(): 30 | MODEL_PATHS[key][0] = resolve_path(value[0]) 31 | 32 | 33 | def restore_face( 34 | image: Image.Image, 35 | model_type: GFPGANModel = GFPGANModel.V1_3, 36 | use_real_esrgan: bool = True, 37 | bg_tile: int = 400, 38 | upscale: int = 2, 39 | aligned: bool = False, 40 | only_center_face: bool = False 41 | ): 42 | # ------------------------ set up background upsampler ------------------------ 43 | if use_real_esrgan: 44 | bg_upsampler = get_upsampler(model_type=ESRGANModel.X2_PLUS, tile=bg_tile) 45 | else: 46 | bg_upsampler = None 47 | 48 | # ------------------------ set up GFPGAN restorer ------------------------ 49 | # determine model paths 50 | model_path, channel_multiplier, arch = MODEL_PATHS[model_type] 51 | 52 | restorer = GFPGANer( 53 | model_path=model_path, 54 | upscale=upscale, 55 | arch=arch, 56 | channel_multiplier=channel_multiplier, 57 | bg_upsampler=bg_upsampler 58 | ) 59 | 60 | # ------------------------ restore ------------------------ 61 | input_img = np.asarray(image.convert('RGB')) 62 | 63 | # restore faces and background if necessary 64 | cropped_faces, restored_faces, restored_img = restorer.enhance( 65 | input_img, 66 | has_aligned=aligned, 67 | only_center_face=only_center_face, 68 | paste_back=True 69 | ) 70 | 71 | if restored_img is None: 72 | raise CouldntFixFaceException 73 | 74 | return Image.fromarray(restored_img) 75 | -------------------------------------------------------------------------------- /gobig.py: -------------------------------------------------------------------------------- 1 | # https://github.com/lowfuel/progrock-stable 2 | from typing import Tuple, Optional, List, Callable, Awaitable 3 | 4 | import torch 5 | from PIL import Image, ImageDraw 6 | from PIL.Image import Resampling 7 | from torch import autocast 8 | 9 | from esrgan_upscaler import upscale 10 | from request_models import ESRGANModel 11 | from universal_pipeline import StableDiffusionUniversalPipeline, preprocess 12 | 13 | 14 | # Alternative method composites a grid of images at the positions provided 15 | def grid_merge(source: Image.Image, slices: List[Tuple[Image.Image, int, int]]): 16 | source = source.convert('RGBA') 17 | for image_slice, posx, posy in slices: # go in reverse to get proper stacking 18 | source.alpha_composite(image_slice, (posx, posy)) 19 | return source 20 | 21 | 22 | def grid_coords(target: 
Tuple[int, int], slice_size: Tuple[int, int], overlap: int): 23 | # generate a list of coordinate tuples for our sections, in order of how they'll be rendered 24 | # target should be the size for the gobig result, original is the size of each chunk being 25 | # rendered 26 | center = [] 27 | target_x, target_y = target 28 | center_x = int(target_x / 2) 29 | center_y = int(target_y / 2) 30 | slice_x, slice_y = slice_size 31 | x = center_x - int(slice_x / 2) 32 | y = center_y - int(slice_y / 2) 33 | center.append((x, y)) # center chunk 34 | uy = y # up 35 | uy_list = [] 36 | dy = y # down 37 | dy_list = [] 38 | lx = x # left 39 | lx_list = [] 40 | rx = x # right 41 | rx_list = [] 42 | while uy > 0: # center row vertical up 43 | uy = uy - slice_y + overlap 44 | uy_list.append((lx, uy)) 45 | while (dy + slice_y) <= target_y: # center row vertical down 46 | dy = dy + slice_y - overlap 47 | dy_list.append((rx, dy)) 48 | while lx > 0: 49 | lx = lx - slice_x + overlap 50 | lx_list.append((lx, y)) 51 | uy = y 52 | while uy > 0: 53 | uy = uy - slice_y + overlap 54 | uy_list.append((lx, uy)) 55 | dy = y 56 | while (dy + slice_y) <= target_y: 57 | dy = dy + slice_y - overlap 58 | dy_list.append((lx, dy)) 59 | while (rx + slice_x) <= target_x: 60 | rx = rx + slice_x - overlap 61 | rx_list.append((rx, y)) 62 | uy = y 63 | while uy > 0: 64 | uy = uy - slice_y + overlap 65 | uy_list.append((rx, uy)) 66 | dy = y 67 | while (dy + slice_y) <= target_y: 68 | dy = dy + slice_y - overlap 69 | dy_list.append((rx, dy)) 70 | # calculate a new size that will fill the canvas, which will be optionally used in grid_slice and go_big 71 | last_coordx, last_coordy = dy_list[-1:][0] 72 | render_edgey = last_coordy + slice_y # outer bottom edge of the render canvas 73 | render_edgex = last_coordx + slice_x # outer side edge of the render canvas 74 | scalarx = render_edgex / target_x 75 | scalary = render_edgey / target_y 76 | if scalarx <= scalary: 77 | new_edgex = int(target_x * scalarx) 78 | new_edgey = int(target_y * scalarx) 79 | else: 80 | new_edgex = int(target_x * scalary) 81 | new_edgey = int(target_y * scalary) 82 | # now put all the chunks into one master list of coordinates (essentially reverse of how we calculated them so that the central slices will be on top) 83 | result = [] 84 | for coords in dy_list[::-1]: 85 | result.append(coords) 86 | for coords in uy_list[::-1]: 87 | result.append(coords) 88 | for coords in rx_list[::-1]: 89 | result.append(coords) 90 | for coords in lx_list[::-1]: 91 | result.append(coords) 92 | result.append(center[0]) 93 | return result, (new_edgex, new_edgey) 94 | 95 | 96 | def grid_slice(source: Image.Image, overlap: int, slice_size: Tuple[int, int]): 97 | width, height = slice_size 98 | coordinates, new_size = grid_coords(source.size, slice_size, overlap) 99 | slices = [] 100 | for coordinate in coordinates: 101 | x, y = coordinate 102 | slices.append(((source.crop((x, y, x + width, y + height))), x, y)) 103 | return slices, new_size 104 | 105 | 106 | async def do_gobig( 107 | input_image: Image.Image, 108 | prompt: str, 109 | maximize: bool, 110 | target_width: int, 111 | target_height: int, 112 | overlap: int, 113 | use_real_esrgan: bool, 114 | esrgan_model: ESRGANModel, 115 | pipeline: StableDiffusionUniversalPipeline, 116 | resampling_mode: Resampling = Resampling.LANCZOS, 117 | strength: float = 0.8, 118 | num_inference_steps: Optional[int] = 50, 119 | guidance_scale: Optional[float] = 7.5, 120 | generator: Optional[torch.Generator] = None, 121 | progress_callback: 
Optional[Callable[[float], Awaitable]] = None 122 | ) -> Image.Image: 123 | # get our render size for each slice, and our target size 124 | slice_width = slice_height = 512 125 | if use_real_esrgan: 126 | input_image = upscale(input_image, esrgan_model) 127 | target_image = input_image.resize((target_width, target_height), resampling_mode) 128 | slices, new_canvas_size = grid_slice(target_image, overlap, (slice_width, slice_height)) 129 | if maximize: 130 | # increase our final image size to use up blank space 131 | target_image = input_image.resize(new_canvas_size, resampling_mode) 132 | slices, new_canvas_size = grid_slice(target_image, overlap, (slice_width, slice_height)) 133 | input_image.close() 134 | # now we trigger a do_run for each slice 135 | better_slices = [] 136 | 137 | count = 0 138 | if progress_callback is not None: 139 | async def chunk_progress_callback(batch_step: int, total_batch_steps: int): 140 | current_step = count * total_batch_steps + batch_step 141 | total_steps = len(slices) * total_batch_steps 142 | progress = current_step / total_steps 143 | await progress_callback(progress) 144 | else: 145 | chunk_progress_callback = None 146 | 147 | with autocast('cuda'): 148 | with torch.inference_mode(): 149 | # TODO run in batches 150 | for count, (chunk, coord_x, coord_y) in enumerate(slices): 151 | result_slice = (await pipeline.image_to_image( 152 | prompt=prompt, 153 | init_image=preprocess(chunk).to(pipeline.device), 154 | strength=strength, 155 | num_inference_steps=num_inference_steps, 156 | guidance_scale=guidance_scale, 157 | generator=generator, 158 | progress_callback=chunk_progress_callback 159 | ))[0] 160 | # result_slice.copy? 161 | better_slices.append((result_slice, coord_x, coord_y)) 162 | 163 | # create an alpha channel for compositing the slices 164 | alpha = Image.new('L', (slice_width, slice_height), color=0xFF) 165 | alpha_gradient = ImageDraw.Draw(alpha) 166 | # we want the alpha gradient to be half the size of the overlap, 167 | # otherwise we always see some of the original background underneath 168 | alpha_overlap = int(overlap / 2) 169 | for i in range(overlap): 170 | shape = ((slice_width - i, slice_height - i), (i, i)) 171 | fill = min(int(i * (255 / alpha_overlap)), 255) 172 | alpha_gradient.rectangle(shape, fill=fill) 173 | # now composite the slices together 174 | finished_slices = [] 175 | for better_slice, x, y in better_slices: 176 | better_slice.putalpha(alpha) 177 | finished_slices.append((better_slice, x, y)) 178 | final_output = grid_merge(target_image, finished_slices) 179 | 180 | return final_output 181 | -------------------------------------------------------------------------------- /logging_settings.py: -------------------------------------------------------------------------------- 1 | import logging.config 2 | import os.path 3 | 4 | from settings import settings 5 | from utils import resolve_path 6 | 7 | LOGGING = { 8 | 'version': 1, 9 | 'disable_existing_loggers': False, 10 | 'handlers': { 11 | 'default': { 12 | 'level': settings.LOG_LEVEL.upper(), 13 | 'class': 'logging.StreamHandler', 14 | 'stream': 'ext://sys.stdout', 15 | 'formatter': 'verbose', 16 | }, 17 | 'file_handler': { 18 | 'level': settings.FILE_LOG_LEVEL.upper(), 19 | 'formatter': 'verbose', 20 | 'class': 'logging.FileHandler', 21 | 'filename': resolve_path('messages.log') 22 | }, 23 | 'blackhole': {'level': 'DEBUG', 'class': 'logging.NullHandler'}, 24 | }, 25 | 'formatters': { 26 | 'verbose': { 27 | 'format': '%(log_color)s%(asctime)s [%(levelname)s] [%(name)s] 
%(message)s (%(filename)s:%(lineno)d)', 28 | '()': 'colorlog.ColoredFormatter', 29 | 'log_colors': { 30 | 'DEBUG': 'cyan', 31 | 'INFO': 'green', 32 | 'WARNING': 'yellow', 33 | 'ERROR': 'red', 34 | 'CRITICAL': 'bold_red', 35 | }, 36 | } 37 | }, 38 | 'loggers': { 39 | 'fastapi': {'level': 'INFO', 'handlers': ['default', 'file_handler']}, 40 | 'uvicorn.error': { 41 | 'level': 'INFO', 'handlers': ['default', 'file_handler'], 'propagate': False 42 | }, 43 | 'uvicorn.access': { 44 | 'level': 'INFO', 'handlers': ['default', 'file_handler'], 'propagate': False 45 | }, 46 | 'uvicorn': { 47 | 'level': 'INFO', 'handlers': ['default', 'file_handler'], 'propagate': False 48 | }, 49 | '': { 50 | 'level': settings.LOG_LEVEL.upper(), 51 | 'handlers': ['default', 'file_handler'], 52 | 'propagate': True, 53 | }, 54 | 55 | } 56 | } 57 | 58 | logging.config.dictConfig(LOGGING) 59 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageChops, Image, ImageDraw 2 | 3 | from settings import settings # noqa 4 | from logging_settings import LOGGING # noqa 5 | import asyncio 6 | import functools 7 | from json import JSONDecodeError 8 | 9 | from consts import WebSocketResponseStatus, GFPGANModel 10 | from exceptions import BatchSizeIsTooLargeException, AspectRatioTooWideException, \ 11 | BaseWebSocketException 12 | import face_restoration 13 | from gobig import do_gobig 14 | 15 | import json 16 | import logging 17 | from typing import Callable, List, Dict, Any, Optional, Union 18 | 19 | import torch 20 | import uvicorn 21 | from fastapi import FastAPI, HTTPException, Depends, WebSocket, status, WebSocketDisconnect 22 | from fastapi.middleware.gzip import GZipMiddleware 23 | from fastapi.security import HTTPBasic, HTTPBasicCredentials 24 | from torch import autocast 25 | 26 | import esrgan_upscaler 27 | from request_models import BaseImageGenerationRequest, ImageArrayResponse, ImageToImageRequest, \ 28 | TextToImageRequest, GoBigRequest, ImageResponse, UpscaleRequest, InpaintingRequest, \ 29 | FaceRestorationRequest, MakeTilableRequest, MakeTilableResponse 30 | from universal_pipeline import StableDiffusionUniversalPipeline, preprocess, preprocess_mask 31 | from utils import base64url_to_image, image_to_base64url, size_from_aspect_ratio, download_models 32 | 33 | logger = logging.getLogger(__name__) 34 | security = HTTPBasic() 35 | 36 | 37 | async def authorize(credentials: HTTPBasicCredentials = Depends(security)): 38 | if credentials.username != settings.USERNAME or credentials.password != settings.PASSWORD: 39 | raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) 40 | 41 | 42 | async def authorize_web_socket(websocket: WebSocket) -> bool: 43 | credentials = await websocket.receive_json() 44 | if credentials.get('username') != settings.USERNAME or \ 45 | credentials.get('password') != settings.PASSWORD: 46 | await websocket.close(status.WS_1008_POLICY_VIOLATION, 'Authorization error') 47 | return False 48 | 49 | return True 50 | 51 | 52 | def websocket_handler(path, app): 53 | def decorator(handler): 54 | @functools.wraps(handler) 55 | async def wrapper(websocket: WebSocket): 56 | await websocket.accept() 57 | try: 58 | if not await authorize_web_socket(websocket): 59 | return 60 | 61 | await handler(websocket) 62 | except BaseWebSocketException as e: 63 | await websocket.close(status.WS_1008_POLICY_VIOLATION, e.message) 64 | except JSONDecodeError: 65 | await 
websocket.close( 66 | status.WS_1008_POLICY_VIOLATION, 67 | 'Server received message that is not in json format' 68 | ) 69 | except WebSocketDisconnect: 70 | return 71 | 72 | return app.websocket(path)(wrapper) 73 | 74 | return decorator 75 | 76 | 77 | app = FastAPI(dependencies=[Depends(authorize)]) 78 | app.add_middleware(GZipMiddleware, minimum_size=1000) 79 | 80 | try: 81 | pipeline = StableDiffusionUniversalPipeline.from_pretrained( 82 | 'CompVis/stable-diffusion-v1-4', 83 | revision='fp16', 84 | torch_dtype=torch.bfloat16, 85 | use_auth_token=True, 86 | cache_dir=settings.DIFFUSERS_CACHE_PATH, 87 | ).to('cuda') 88 | if settings.USE_OPTIMIZED_MODE: 89 | pipeline.enable_attention_slicing() 90 | except Exception as e: 91 | logger.exception(e) 92 | raise e 93 | 94 | 95 | async def do_diffusion( 96 | request: BaseImageGenerationRequest, 97 | diffusion_method: Callable, 98 | websocket: WebSocket, 99 | batched_params: Optional[Dict[str, List[Any]]] = None, 100 | return_images: bool = False, 101 | progress_multiplier: float = 1., 102 | progress_offset: float = 0., 103 | **kwargs 104 | ) -> Union[ImageArrayResponse, List[Image.Image]]: 105 | if request.seed is not None: 106 | generator = torch.Generator('cuda').manual_seed(request.seed) 107 | else: 108 | generator = None 109 | 110 | if batched_params is None: 111 | batched_params = {} 112 | with autocast('cuda'): 113 | with torch.inference_mode(): 114 | for batch_size in range(min(request.batch_size, request.num_variants), 0, -1): 115 | try: 116 | num_batches = request.num_variants // batch_size 117 | prompts = [request.prompt] * batch_size 118 | last_batch_size = request.num_variants - batch_size * num_batches 119 | images = [] 120 | 121 | for i in range(num_batches): 122 | batch_kwargs = {} 123 | for key, value in batched_params.items(): 124 | batch_kwargs[key] = value[i * batch_size: (i + 1) * batch_size] 125 | 126 | async def progress_callback(batch_step: int, total_batch_steps: int): 127 | current_step = i * total_batch_steps + batch_step 128 | total_steps = num_batches * total_batch_steps 129 | if last_batch_size: 130 | total_steps += total_batch_steps 131 | progress = progress_multiplier * current_step / total_steps + \ 132 | progress_offset 133 | await websocket.send_json( 134 | {'status': WebSocketResponseStatus.PROGRESS, 'progress': progress} 135 | ) 136 | # https://github.com/aaugustin/websockets/issues/867 137 | # Probably will be fixed if/when diffusers will implement asynchronous 138 | # pipeline 139 | # https://github.com/huggingface/diffusers/issues/374 140 | await asyncio.sleep(0) 141 | 142 | new_images = (await diffusion_method( 143 | prompt=prompts, 144 | num_inference_steps=request.num_inference_steps, 145 | generator=generator, 146 | guidance_scale=request.guidance_scale, 147 | progress_callback=progress_callback, 148 | **batch_kwargs, 149 | **kwargs 150 | )) 151 | images.extend(new_images) 152 | 153 | if last_batch_size: 154 | batch_kwargs = {} 155 | for key, value in batched_params.items(): 156 | batch_kwargs[key] = value[-last_batch_size:] 157 | 158 | async def progress_callback(batch_step: int, total_batch_steps: int): 159 | current_step = num_batches * total_batch_steps + batch_step 160 | total_steps = (num_batches + 1) * total_batch_steps 161 | progress = progress_multiplier * current_step / total_steps + \ 162 | progress_offset 163 | await websocket.send_json( 164 | {'status': WebSocketResponseStatus.PROGRESS, 'progress': progress} 165 | ) 166 | 167 | new_images = (await diffusion_method( 168 | 
prompt=[request.prompt] * last_batch_size, 169 | num_inference_steps=request.num_inference_steps, 170 | generator=generator, 171 | guidance_scale=request.guidance_scale, 172 | progress_callback=progress_callback, 173 | **batch_kwargs, 174 | **kwargs 175 | )) 176 | images.extend(new_images) 177 | break 178 | except RuntimeError as e: 179 | if request.try_smaller_batch_on_fail: 180 | logger.warning(f'Batch size {batch_size} was too large, trying smaller') 181 | else: 182 | raise BatchSizeIsTooLargeException(batch_size) 183 | else: 184 | raise AspectRatioTooWideException 185 | 186 | if return_images: 187 | return images 188 | else: 189 | return ImageArrayResponse(images=images) 190 | 191 | 192 | # TODO task queue? 193 | # (or can set up an external scheduler and use this as internal endpoint) 194 | @websocket_handler('/text_to_image', app) 195 | async def text_to_image(websocket: WebSocket): 196 | request = TextToImageRequest(**(await websocket.receive_json())) 197 | width, height = size_from_aspect_ratio(request.aspect_ratio, request.scaling_mode) 198 | response = await do_diffusion( 199 | request, pipeline.text_to_image, websocket, height=height, width=width 200 | ) 201 | await websocket.send_json( 202 | {'status': WebSocketResponseStatus.FINISHED, 'result': json.loads(response.json())} 203 | ) 204 | 205 | 206 | @websocket_handler('/image_to_image', app) 207 | async def image_to_image(websocket: WebSocket): 208 | request = ImageToImageRequest(**(await websocket.receive_json())) 209 | source_image = base64url_to_image(request.source_image) 210 | aspect_ratio = source_image.width / source_image.height 211 | size = size_from_aspect_ratio(aspect_ratio, request.scaling_mode) 212 | 213 | source_image = source_image.resize(size) 214 | 215 | with autocast('cuda'): 216 | preprocessed_source_image = preprocess(source_image).to(pipeline.device) 217 | preprocessed_alpha = None 218 | if source_image.mode == 'RGBA': 219 | preprocessed_alpha = 1 - preprocess_mask( 220 | source_image.getchannel('A') 221 | ).to(pipeline.device) 222 | 223 | if preprocessed_alpha is not None and not preprocessed_alpha.any(): 224 | preprocessed_alpha = None 225 | 226 | response = await do_diffusion( 227 | request, 228 | pipeline.image_to_image, 229 | websocket, 230 | init_image=preprocessed_source_image, 231 | strength=request.strength, 232 | alpha=preprocessed_alpha, 233 | ) 234 | await websocket.send_json( 235 | {'status': WebSocketResponseStatus.FINISHED, 'result': json.loads(response.json())} 236 | ) 237 | 238 | 239 | @websocket_handler('/inpainting', app) 240 | async def inpainting(websocket: WebSocket): 241 | request = InpaintingRequest(**(await websocket.receive_json())) 242 | source_image = base64url_to_image(request.source_image) 243 | aspect_ratio = source_image.width / source_image.height 244 | size = size_from_aspect_ratio(aspect_ratio, request.scaling_mode) 245 | 246 | source_image = source_image.resize(size) 247 | mask = None 248 | if request.mask: 249 | mask = base64url_to_image(request.mask).resize(size) 250 | 251 | with autocast('cuda'): 252 | preprocessed_source_image = preprocess(source_image).to(pipeline.device) 253 | preprocessed_mask = None 254 | if mask is not None: 255 | preprocessed_mask = preprocess_mask(mask).to(pipeline.device) 256 | 257 | if preprocessed_mask is not None and not preprocessed_mask.any(): 258 | preprocessed_mask = None 259 | 260 | preprocessed_alpha = None 261 | if source_image.mode == 'RGBA': 262 | preprocessed_alpha = 1 - preprocess_mask( 263 | source_image.getchannel('A') 264 | 
).to(pipeline.device) 265 | 266 | if preprocessed_alpha is not None and not preprocessed_alpha.any(): 267 | preprocessed_alpha = None 268 | 269 | if preprocessed_alpha is not None: 270 | if preprocessed_mask is not None: 271 | preprocessed_mask = torch.max(preprocessed_mask, preprocessed_alpha) 272 | else: 273 | preprocessed_mask = preprocessed_alpha 274 | 275 | # TODO return error if mask empty 276 | response = await do_diffusion( 277 | request, 278 | pipeline.image_to_image, 279 | websocket, 280 | init_image=preprocessed_source_image, 281 | strength=request.strength, 282 | mask=preprocessed_mask, 283 | alpha=preprocessed_alpha, 284 | ) 285 | await websocket.send_json( 286 | {'status': WebSocketResponseStatus.FINISHED, 'result': json.loads(response.json())} 287 | ) 288 | 289 | 290 | @websocket_handler('/gobig', app) 291 | async def gobig(websocket: WebSocket): 292 | request = GoBigRequest(**(await websocket.receive_json())) 293 | 294 | if request.seed is not None: 295 | generator = torch.Generator('cuda').manual_seed(request.seed) 296 | else: 297 | generator = None 298 | 299 | async def progress_callback(progress: float): 300 | await websocket.send_json( 301 | {'status': WebSocketResponseStatus.PROGRESS, 'progress': progress} 302 | ) 303 | 304 | upscaled = await do_gobig( 305 | input_image=base64url_to_image(request.image), 306 | prompt=request.prompt, 307 | maximize=request.maximize, 308 | target_width=request.target_width, 309 | target_height=request.target_height, 310 | overlap=request.overlap, 311 | use_real_esrgan=request.use_real_esrgan, 312 | esrgan_model=request.esrgan_model, 313 | pipeline=pipeline, 314 | strength=request.strength, 315 | num_inference_steps=request.num_inference_steps, 316 | guidance_scale=request.guidance_scale, 317 | generator=generator, 318 | progress_callback=progress_callback 319 | ) 320 | response = ImageResponse(image=image_to_base64url(upscaled)) 321 | await websocket.send_json( 322 | {'status': WebSocketResponseStatus.FINISHED, 'result': json.loads(response.json())} 323 | ) 324 | 325 | 326 | @app.post('/upscale') 327 | async def upscale(request: UpscaleRequest) -> ImageResponse: 328 | try: 329 | source_image = base64url_to_image(request.image) 330 | while source_image.width < request.target_width or \ 331 | source_image.height < request.target_height: 332 | source_image = esrgan_upscaler.upscale(image=source_image, model_type=request.model) 333 | 334 | if not request.maximize: 335 | source_image = source_image.resize((request.target_width, request.target_height)) 336 | 337 | return ImageResponse(image=image_to_base64url(source_image)) 338 | except RuntimeError: 339 | raise HTTPException( 340 | status_code=status.HTTP_400_BAD_REQUEST, detail='Scaling factor or image is too large' 341 | ) 342 | 343 | 344 | @app.post('/restore_face') 345 | async def restore_face(request: FaceRestorationRequest) -> ImageResponse: 346 | if request.model_type == GFPGANModel.V1: 347 | raise HTTPException( 348 | status_code=status.HTTP_400_BAD_REQUEST, 349 | detail='GFPGAN v1 model is not supported' 350 | ) 351 | return ImageResponse( 352 | image=face_restoration.restore_face( 353 | image=base64url_to_image(request.image), 354 | model_type=request.model_type, 355 | use_real_esrgan=request.use_real_esrgan, 356 | bg_tile=request.bg_tile, 357 | upscale=request.upscale, 358 | aligned=request.aligned, 359 | only_center_face=request.only_center_face 360 | ) 361 | ) 362 | 363 | 364 | @websocket_handler('/make_tilable', app) 365 | async def make_tilable(websocket: WebSocket): 366 | 
request = MakeTilableRequest(**(await websocket.receive_json())) 367 | source_image = base64url_to_image(request.source_image) 368 | aspect_ratio = source_image.width / source_image.height 369 | size = size_from_aspect_ratio(aspect_ratio, request.scaling_mode) 370 | horizontal_offset_image = ImageChops.offset(source_image.resize(size), int(size[0] / 2), 0) 371 | 372 | # Horizontal offset 373 | with autocast('cuda'): 374 | preprocessed_horizontal_offset_image = preprocess( 375 | horizontal_offset_image 376 | ).to(pipeline.device) 377 | preprocessed_alpha = None 378 | if horizontal_offset_image.mode == 'RGBA': 379 | preprocessed_alpha = 1 - preprocess_mask( 380 | horizontal_offset_image.getchannel('A') 381 | ).to(pipeline.device) 382 | 383 | if preprocessed_alpha is not None and not preprocessed_alpha.any(): 384 | preprocessed_alpha = None 385 | 386 | gradient_width = request.border_width * request.border_softness 387 | if int(gradient_width) != 0: 388 | gradient_step = 255 / gradient_width 389 | else: 390 | gradient_step = 255 391 | 392 | horizontal_mask = Image.new('L', size, color=0x00) 393 | start_gradient_x = size[0] / 2 - request.border_width 394 | horizontal_draw = ImageDraw.Draw(horizontal_mask) 395 | for i in range(int(gradient_width)): 396 | fill_color = min(int(i * gradient_step), 255) 397 | x = int(start_gradient_x + i) 398 | width = (request.border_width - i) * 2 399 | horizontal_draw.rectangle(((x, 0), (x + width, size[1])), fill=fill_color) 400 | x = int(start_gradient_x + gradient_width) 401 | width = (request.border_width - gradient_width) * 2 402 | horizontal_draw.rectangle(((x, 0), (x + width, size[1])), fill=255) 403 | horizontal_preprocessed_mask = preprocess_mask(horizontal_mask).to(pipeline.device) 404 | 405 | horizontal_offset_result = await do_diffusion( 406 | request, 407 | pipeline.image_to_image, 408 | websocket, 409 | return_images=True, 410 | progress_multiplier=1/3, 411 | init_image=preprocessed_horizontal_offset_image, 412 | mask=horizontal_preprocessed_mask, 413 | strength=request.strength, 414 | alpha=preprocessed_alpha, 415 | ) 416 | '''horizontal_offset_result = [ 417 | Image.composite(image, horizontal_offset_image, horizontal_mask) 418 | for image in horizontal_offset_result 419 | ]''' 420 | 421 | # Vertical offset 422 | with autocast('cuda'): 423 | preprocessed_vertical_offset_images = [] 424 | for image in horizontal_offset_result: 425 | vertical_offset_image = ImageChops.offset(image, 0, int(size[1] / 2)) 426 | preprocessed_vertical_offset_images.append( 427 | preprocess(vertical_offset_image).to(pipeline.device) 428 | ) 429 | 430 | vertical_mask = Image.new('L', size, color=0x00) 431 | start_gradient_y = size[1] / 2 - request.border_width 432 | vertical_draw = ImageDraw.Draw(vertical_mask) 433 | for i in range(int(gradient_width)): 434 | fill_color = min(int(i * gradient_step), 255) 435 | y = int(start_gradient_y + i) 436 | height = (request.border_width - i) * 2 437 | vertical_draw.rectangle(((0, y), (size[0], y + height)), fill=fill_color) 438 | 439 | y = int(start_gradient_y + gradient_width) 440 | height = (request.border_width - gradient_width) * 2 441 | vertical_draw.rectangle(((0, y), (size[0], y + height)), fill=255) 442 | vertical_preprocessed_mask = preprocess_mask(vertical_mask).to(pipeline.device) 443 | 444 | vertical_offset_result = await do_diffusion( 445 | request, 446 | pipeline.image_to_image, 447 | websocket, 448 | return_images=True, 449 | progress_multiplier=1/3, 450 | progress_offset=1/3, 451 | batched_params={'init_image': 
preprocessed_vertical_offset_images}, 452 | strength=request.strength, 453 | mask=vertical_preprocessed_mask, 454 | ) 455 | 456 | # Center 457 | with autocast('cuda'): 458 | preprocessed_center_offset_images = [] 459 | for image in vertical_offset_result: 460 | center_offset_image = ImageChops.offset(image, -int(size[0] / 2), 0) 461 | preprocessed_center_offset_images.append( 462 | preprocess(center_offset_image).to(pipeline.device) 463 | ) 464 | 465 | center_mask = Image.new('L', size, color=0x00) 466 | center_draw = ImageDraw.Draw(center_mask) 467 | for i in range(int(gradient_width)): 468 | fill_color = min(int(i * gradient_step), 255) 469 | y = int(start_gradient_y + i) 470 | x = int(start_gradient_x + i) 471 | offset = (request.border_width - i) * 2 472 | center_draw.rectangle(((x, y), (x + offset, y + offset)), fill=fill_color) 473 | 474 | y = int(start_gradient_y + gradient_width) 475 | x = int(start_gradient_x + gradient_width) 476 | offset = (request.border_width - gradient_width) * 2 477 | center_draw.rectangle(((x, y), (x + offset, y + offset)), fill=255) 478 | center_preprocessed_mask = preprocess_mask(center_mask).to(pipeline.device) 479 | 480 | images = await do_diffusion( 481 | request, 482 | pipeline.image_to_image, 483 | websocket, 484 | return_images=True, 485 | progress_multiplier=1 / 3, 486 | progress_offset=2 / 3, 487 | batched_params={'init_image': preprocessed_center_offset_images}, 488 | strength=request.strength, 489 | mask=center_preprocessed_mask, 490 | ) 491 | 492 | response = MakeTilableResponse( 493 | images=[ 494 | ImageChops.offset(image, 0, -int(size[1] / 2)) for image in images 495 | ], 496 | mask=ImageChops.lighter( 497 | ImageChops.offset(horizontal_mask, -int(size[0] / 2), 0), 498 | ImageChops.offset(vertical_mask, 0, -int(size[1] / 2)), 499 | ) 500 | ) 501 | await websocket.send_json( 502 | {'status': WebSocketResponseStatus.FINISHED, 'result': json.loads(response.json())} 503 | ) 504 | 505 | 506 | @app.get('/ping') 507 | async def ping(): 508 | return 509 | 510 | 511 | async def setup(): 512 | await download_models(face_restoration.GFPGAN_URLS + esrgan_upscaler.ESRGAN_URLS) 513 | 514 | 515 | if __name__ == '__main__': 516 | asyncio.run(setup()) 517 | uvicorn.run(app, host=settings.HOST, port=settings.PORT, log_config=LOGGING) 518 | -------------------------------------------------------------------------------- /pyinstaller_hooks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qweryty/image-ai-utils-server/c439e2df0c2e010af8b5ff35f8696fadc250c0f5/pyinstaller_hooks/__init__.py -------------------------------------------------------------------------------- /pyinstaller_hooks/hook-diffusers.py: -------------------------------------------------------------------------------- 1 | from PyInstaller.utils.hooks import collect_all 2 | from diffusers.dependency_versions_table import deps 3 | 4 | 5 | def hook(hook_api): 6 | packages = deps.keys() 7 | print('---------------------- IMPORTING DIFFUSERS DEPS ----------------------') 8 | print(list(packages)) 9 | 10 | for package in packages: 11 | datas, _, hidden_imports = collect_all(package) 12 | hook_api.add_datas(datas) 13 | hook_api.add_imports(*hidden_imports) 14 | -------------------------------------------------------------------------------- /pyinstaller_hooks/hook-transformers.py: -------------------------------------------------------------------------------- 1 | from PyInstaller.utils.hooks import collect_all 2 | from 
transformers.dependency_versions_table import deps 3 | 4 | 5 | def hook(hook_api): 6 | packages = deps.keys() 7 | print('---------------------- IMPORTING TRANSFORMERS DEPS ----------------------') 8 | print(list(packages)) 9 | 10 | for package in packages: 11 | datas, binaries, hidden_imports = collect_all(package) 12 | hook_api.add_datas(datas) 13 | #hook_api.add_binaries(binaries) 14 | hook_api.add_imports(*hidden_imports) 15 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "image-ai-utils-server" 3 | version = "0.0.3" 4 | description = "" 5 | authors = ["Sergey Morozov "] 6 | 7 | [tool.poetry.dependencies] 8 | python = ">=3.9,<3.10" # 3.9 required for gfpgan which requires older version of numpy 9 | diffusers = "^0.3.0" 10 | transformers = "^4.21.1" 11 | scipy = "^1.9.0" 12 | fastapi = "^0.79.1" 13 | gunicorn = "^20.1.0" 14 | uvicorn = "^0.18.2" 15 | torch = { version="^1.12.1", source="torch" } 16 | pydantic = "^1.9.2" 17 | Pillow = "^9.2.0" 18 | gfpgan = "^1.3.4" 19 | realesrgan = "^0.2.5.0" 20 | numpy = "1.20.3" 21 | python-dotenv = "^0.20.0" 22 | colorlog = "^6.7.0" 23 | websockets = "^10.3" 24 | httpx = "^0.23.0" 25 | aiofiles = "^22.1.0" 26 | 27 | [tool.poetry.dev-dependencies] 28 | snakeviz = "^2.1.1" 29 | 30 | [[tool.poetry.source]] 31 | name = "torch" 32 | url = "https://download.pytorch.org/whl/cu116" 33 | secondary = true 34 | 35 | [build-system] 36 | requires = ["poetry-core>=1.0.0"] 37 | build-backend = "poetry.core.masonry.api" 38 | -------------------------------------------------------------------------------- /request_models.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | from PIL import Image 4 | from pydantic import BaseModel, Field, validator 5 | 6 | from consts import ImageFormat, MIN_SEED, MAX_SEED, ScalingMode, ESRGANModel, GFPGANModel 7 | from utils import image_to_base64url 8 | 9 | 10 | class BaseDiffusionRequest(BaseModel): 11 | prompt: str = Field(...) 
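# Sampling parameters shared by every diffusion endpoint; the request models
# further down extend this base class with endpoint-specific fields such as
# source images, masks, strength and aspect ratio.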
12 | output_format: ImageFormat = ImageFormat.PNG 13 | num_inference_steps: int = Field(50, gt=0) 14 | guidance_scale: float = Field(7.5) 15 | seed: Optional[int] = Field(None, ge=MIN_SEED, le=MAX_SEED) 16 | batch_size: int = Field(7, gt=0) 17 | try_smaller_batch_on_fail: bool = True 18 | 19 | 20 | class BaseImageGenerationRequest(BaseDiffusionRequest): 21 | num_variants: int = Field(4, gt=0) 22 | scaling_mode: ScalingMode = ScalingMode.GROW 23 | 24 | 25 | class TextToImageRequest(BaseImageGenerationRequest): 26 | aspect_ratio: float = Field(1., gt=0) # width/height 27 | 28 | 29 | class ImageToImageRequest(BaseImageGenerationRequest): 30 | source_image: bytes 31 | strength: float = Field(0.8, ge=0, le=1) 32 | 33 | 34 | class InpaintingRequest(ImageToImageRequest): 35 | mask: Optional[bytes] = None 36 | 37 | 38 | class GoBigRequest(BaseDiffusionRequest): 39 | image: bytes 40 | use_real_esrgan: bool = True 41 | esrgan_model: ESRGANModel = ESRGANModel.GENERAL_X4_V3 42 | maximize: bool = True 43 | strength: float = Field(.5, ge=0, le=1) 44 | target_width: int = Field(..., gt=0) 45 | target_height: int = Field(..., gt=0) 46 | overlap: int = Field(64, gt=0, lt=512) 47 | 48 | 49 | class MakeTilableRequest(BaseImageGenerationRequest): 50 | source_image: bytes 51 | border_width: int = Field(50, gt=0, lt=256) 52 | border_softness: float = Field(.5, ge=0, le=1) 53 | strength: float = Field(0.8, ge=0, le=1) 54 | 55 | 56 | class UpscaleRequest(BaseModel): 57 | image: bytes 58 | model: ESRGANModel = ESRGANModel.GENERAL_X4_V3 59 | target_width: int = Field(..., gt=0) 60 | target_height: int = Field(..., gt=0) 61 | maximize: bool = True 62 | 63 | 64 | class ImageArrayResponse(BaseModel): 65 | images: List[bytes] 66 | 67 | @validator('images', pre=True) 68 | def images_to_bytes(cls, v: List): 69 | return [ 70 | image_to_base64url(image) if isinstance(image, Image.Image) else image for image in v 71 | ] 72 | 73 | 74 | class MakeTilableResponse(ImageArrayResponse): 75 | mask: bytes 76 | 77 | @validator('mask', pre=True) 78 | def mask_to_bytes(cls, v): 79 | if isinstance(v, Image.Image): 80 | v = image_to_base64url(v) 81 | return v 82 | 83 | 84 | class ImageResponse(BaseModel): 85 | image: bytes 86 | 87 | @validator('image', pre=True) 88 | def image_to_bytes(cls, v): 89 | if isinstance(v, Image.Image): 90 | v = image_to_base64url(v) 91 | return v 92 | 93 | 94 | class FaceRestorationRequest(BaseModel): 95 | image: bytes 96 | model_type: GFPGANModel 97 | use_real_esrgan: bool = True 98 | bg_tile: int = 400 99 | upscale: int = 2 100 | aligned: bool = False 101 | only_center_face: bool = False 102 | -------------------------------------------------------------------------------- /settings.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from typing import Optional 3 | 4 | from diffusers.utils import DIFFUSERS_CACHE 5 | from dotenv import load_dotenv 6 | from pydantic import BaseSettings, Field, validator 7 | 8 | from utils import resolve_path 9 | 10 | 11 | class Settings(BaseSettings): 12 | USERNAME: str = Field(..., env='IMAGE_AI_UTILS_USERNAME') 13 | PASSWORD: str = Field(..., env='IMAGE_AI_UTILS_PASSWORD') 14 | HOST: str = Field('0.0.0.0', env='HOST') 15 | PORT: int = Field(7331, env='PORT') 16 | LOG_LEVEL: str = Field('DEBUG', env='LOG_LEVEL') 17 | FILE_LOG_LEVEL: str = Field('ERROR', env='FILE_LOG_LEVEL') 18 | DIFFUSERS_CACHE_PATH: str = Field(DIFFUSERS_CACHE, env='DIFFUSERS_CACHE_PATH') 19 | USE_OPTIMIZED_MODE: bool = Field(True, 
env='USE_OPTIMIZED_MODE') 20 | 21 | # TODO make abspath from current file 22 | @validator('DIFFUSERS_CACHE_PATH') 23 | def make_abspath(cls, path: Optional[str]) -> Optional[str]: 24 | if path is None or os.path.isabs(path): 25 | return path 26 | 27 | return resolve_path(path) 28 | 29 | 30 | load_dotenv() 31 | settings = Settings() 32 | -------------------------------------------------------------------------------- /universal_pipeline.py: -------------------------------------------------------------------------------- 1 | # This file was created by contributors of Diffusers library. 2 | # The original code can be found here: 3 | # https://github.com/huggingface/diffusers/blob/main/examples/inference/image_to_image.py 4 | import inspect 5 | from typing import List, Optional, Union, Callable, Awaitable 6 | 7 | import numpy as np 8 | import torch 9 | from PIL import Image 10 | from diffusers import AutoencoderKL, DDIMScheduler, PNDMScheduler, \ 11 | UNet2DConditionModel, LMSDiscreteScheduler, DiffusionPipeline 12 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor 13 | 14 | 15 | def preprocess(image: Image.Image) -> torch.FloatTensor: 16 | image = image.convert('RGB') 17 | w, h = image.size 18 | w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 19 | image = image.resize((w, h), resample=Image.LANCZOS) 20 | image = np.array(image).astype(np.float32) / 255.0 21 | image = image[None].transpose(0, 3, 1, 2) 22 | image = torch.from_numpy(image) 23 | return 2.0 * image - 1.0 24 | 25 | 26 | def preprocess_mask(mask: Image.Image) -> torch.FloatTensor: 27 | mask = mask.convert('L') 28 | w, h = mask.size 29 | w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 30 | if w < h: 31 | h = int(h / (w / 64)) 32 | w = 64 33 | else: 34 | w = int(w / (h / 64)) 35 | h = 64 36 | 37 | mask = mask.resize((w, h), resample=Image.LANCZOS) 38 | mask = np.array(mask).astype(np.float32) / 255.0 39 | mask = torch.from_numpy(mask) 40 | return mask 41 | 42 | 43 | def mask_overlay( 44 | first: torch.FloatTensor, second: torch.FloatTensor, mask: torch.FloatTensor 45 | ) -> torch.FloatTensor: 46 | return first * (1 - mask) + second * mask 47 | 48 | 49 | class StableDiffusionUniversalPipeline(DiffusionPipeline): 50 | vae: AutoencoderKL 51 | text_encoder: CLIPTextModel 52 | tokenizer: CLIPTokenizer 53 | unet: UNet2DConditionModel 54 | scheduler: Union[DDIMScheduler, PNDMScheduler] 55 | 56 | def __init__( 57 | self, 58 | vae: AutoencoderKL, 59 | text_encoder: CLIPTextModel, 60 | tokenizer: CLIPTokenizer, 61 | unet: UNet2DConditionModel, 62 | scheduler: Union[DDIMScheduler, PNDMScheduler], 63 | feature_extractor: CLIPFeatureExtractor, 64 | ): 65 | super().__init__() 66 | scheduler = scheduler.set_format('pt') 67 | self.register_modules( 68 | vae=vae, 69 | text_encoder=text_encoder, 70 | tokenizer=tokenizer, 71 | unet=unet, 72 | scheduler=scheduler, 73 | feature_extractor=feature_extractor, 74 | ) 75 | 76 | def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): 77 | r""" 78 | Enable sliced attention computation. 79 | 80 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention 81 | in several steps. This is useful to save some memory in exchange for a small speed decrease. 82 | 83 | Args: 84 | slice_size (`str` or `int`, *optional*, defaults to `"auto"`): 85 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. 
If 86 | a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, 87 | `attention_head_dim` must be a multiple of `slice_size`. 88 | """ 89 | if slice_size == "auto": 90 | # half the attention head size is usually a good trade-off between 91 | # speed and memory 92 | slice_size = self.unet.config.attention_head_dim // 2 93 | self.unet.set_attention_slice(slice_size) 94 | 95 | def disable_attention_slicing(self): 96 | r""" 97 | Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go 98 | back to computing attention in one step. 99 | """ 100 | # set slice_size = `None` to disable `attention slicing` 101 | self.enable_attention_slicing(None) 102 | 103 | def _scale_and_encode(self, image: torch.FloatTensor, generator: Optional[torch.Generator]): 104 | latent_dist = self.vae.encode(image).latent_dist 105 | return 0.18215 * latent_dist.sample(generator=generator) 106 | 107 | def _scale_and_decode(self, latents): 108 | return self.vae.decode(1 / 0.18215 * latents).sample 109 | 110 | async def text_to_image( 111 | self, 112 | prompt: Union[str, List[str]], 113 | height: Optional[int] = 512, 114 | width: Optional[int] = 512, 115 | num_inference_steps: Optional[int] = 50, 116 | guidance_scale: Optional[float] = 7.5, 117 | eta: Optional[float] = 0.0, 118 | generator: Optional[torch.Generator] = None, 119 | latents: Optional[torch.FloatTensor] = None, 120 | progress_callback: Optional[Callable[[int, int], Awaitable]] = None 121 | ) -> List[Image.Image]: 122 | if isinstance(prompt, str): 123 | batch_size = 1 124 | elif isinstance(prompt, list): 125 | batch_size = len(prompt) 126 | else: 127 | raise ValueError(f'`prompt` has to be of type `str` or `list` but is {type(prompt)}') 128 | 129 | if height % 8 != 0 or width % 8 != 0: 130 | raise ValueError( 131 | f'`height` and `width` have to be divisible by 8 but are {height} and {width}.') 132 | 133 | # get prompt text embeddings 134 | text_input = self.tokenizer( 135 | prompt, 136 | padding='max_length', 137 | max_length=self.tokenizer.model_max_length, 138 | truncation=True, 139 | return_tensors='pt', 140 | ) 141 | text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] 142 | 143 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 144 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 145 | # corresponds to doing no classifier free guidance. 146 | do_classifier_free_guidance = guidance_scale > 1.0 147 | # get unconditional embeddings for classifier free guidance 148 | if do_classifier_free_guidance: 149 | max_length = text_input.input_ids.shape[-1] 150 | uncond_input = self.tokenizer( 151 | [''] * batch_size, padding='max_length', max_length=max_length, return_tensors='pt' 152 | ) 153 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] 154 | 155 | # For classifier free guidance, we need to do two forward passes. 
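# (one pass conditioned on the prompt embeddings and one on the empty-prompt
# unconditional embeddings built just above).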
156 | # Here we concatenate the unconditional and text embeddings into a single batch 157 | # to avoid doing two forward passes 158 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 159 | 160 | # get the initial random noise unless the user supplied it 161 | latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) 162 | if latents is None: 163 | latents = torch.randn( 164 | latents_shape, 165 | generator=generator, 166 | device=self.device, 167 | ) 168 | else: 169 | if latents.shape != latents_shape: 170 | raise ValueError( 171 | f'Unexpected latents shape, got {latents.shape}, expected {latents_shape}') 172 | latents = latents.to(self.device) 173 | 174 | # set timesteps 175 | accepts_offset = 'offset' in set( 176 | inspect.signature(self.scheduler.set_timesteps).parameters.keys()) 177 | extra_set_kwargs = {} 178 | if accepts_offset: 179 | extra_set_kwargs['offset'] = 1 180 | 181 | self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) 182 | 183 | # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas 184 | if isinstance(self.scheduler, LMSDiscreteScheduler): 185 | latents = latents * self.scheduler.sigmas[0] 186 | 187 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 188 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 189 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 190 | # and should be between [0, 1] 191 | accepts_eta = 'eta' in set(inspect.signature(self.scheduler.step).parameters.keys()) 192 | extra_step_kwargs = {} 193 | if accepts_eta: 194 | extra_step_kwargs['eta'] = eta 195 | 196 | for i, t in enumerate(self.scheduler.timesteps): 197 | if progress_callback is not None: 198 | await progress_callback(i, len(self.scheduler.timesteps)) 199 | # expand the latents if we are doing classifier free guidance 200 | latent_model_input = torch.cat( 201 | [latents] * 2) if do_classifier_free_guidance else latents 202 | if isinstance(self.scheduler, LMSDiscreteScheduler): 203 | sigma = self.scheduler.sigmas[i] 204 | # the model input needs to be scaled to match the continuous ODE formulation in K-LMS 205 | latent_model_input = latent_model_input / ((sigma ** 2 + 1) ** 0.5) 206 | 207 | # predict the noise residual 208 | noise_pred = self.unet( 209 | latent_model_input, t, encoder_hidden_states=text_embeddings 210 | )['sample'] 211 | 212 | # perform guidance 213 | if do_classifier_free_guidance: 214 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 215 | noise_pred = noise_pred_uncond + guidance_scale * ( 216 | noise_pred_text - noise_pred_uncond) 217 | 218 | # compute the previous noisy sample x_t -> x_t-1 219 | if isinstance(self.scheduler, LMSDiscreteScheduler): 220 | latents = self.scheduler.step( 221 | noise_pred, i, latents, **extra_step_kwargs 222 | )['prev_sample'] 223 | else: 224 | latents = self.scheduler.step( 225 | noise_pred, t, latents, **extra_step_kwargs 226 | )['prev_sample'] 227 | 228 | # scale and decode the image latents with vae 229 | image = self._scale_and_decode(latents) 230 | 231 | image = (image / 2 + 0.5).clamp(0, 1) 232 | image = image.cpu().permute(0, 2, 3, 1).numpy() 233 | 234 | image = self.numpy_to_pil(image) 235 | 236 | return image 237 | 238 | async def image_to_image( 239 | self, 240 | prompt: Union[str, List[str]], 241 | init_image: Union[torch.FloatTensor, List[torch.FloatTensor]], 242 | mask: Optional[torch.FloatTensor] = None, 243 | alpha: 
Optional[torch.FloatTensor] = None, 244 | strength: float = 0.8, 245 | num_inference_steps: Optional[int] = 50, 246 | guidance_scale: Optional[float] = 7.5, 247 | eta: Optional[float] = 0.0, 248 | generator: Optional[torch.Generator] = None, 249 | progress_callback: Optional[Callable[[int, int], Awaitable]] = None 250 | ) -> List[Image.Image]: 251 | if isinstance(prompt, str): 252 | batch_size = 1 253 | elif isinstance(prompt, list): 254 | batch_size = len(prompt) 255 | else: 256 | raise ValueError(f'`prompt` has to be of type `str` or `list` but is {type(prompt)}') 257 | 258 | if strength < 0 or strength > 1: 259 | raise ValueError(f'The value of strength should in [0.0, 1.0] but is {strength}') 260 | 261 | if isinstance(init_image, list) and len(init_image) != batch_size: 262 | raise ValueError( 263 | f'Length of list of init images({len(init_image)}) ' 264 | f'should be equal to batch_size({batch_size})' 265 | ) 266 | 267 | # set timesteps 268 | accepts_offset = 'offset' in set( 269 | inspect.signature(self.scheduler.set_timesteps).parameters.keys() 270 | ) 271 | extra_set_kwargs = {} 272 | offset = 0 273 | if accepts_offset: 274 | offset = 1 275 | extra_set_kwargs['offset'] = 1 276 | 277 | self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) 278 | 279 | # encode the init image into latents and scale the latents 280 | if isinstance(init_image, list): 281 | init_latents = [] 282 | for image in init_image: 283 | latents = self._scale_and_encode(image, generator) 284 | if alpha is not None: 285 | # Replacing transparent area with noise 286 | latents = mask_overlay( 287 | latents, 288 | torch.randn(latents.shape, generator=generator, device=self.device), 289 | alpha 290 | ) 291 | init_latents.append(latents) 292 | 293 | init_latents = torch.cat(init_latents) 294 | else: 295 | init_latents = self._scale_and_encode(init_image, generator) 296 | if alpha is not None: 297 | # Replacing transparent area with noise 298 | init_latents = mask_overlay( 299 | init_latents, 300 | torch.randn(init_latents.shape, generator=generator, device=self.device), 301 | alpha 302 | ) 303 | 304 | # Expand init_latents for batch_size 305 | init_latents = torch.cat([init_latents] * batch_size) 306 | 307 | init_latents_orig = init_latents 308 | 309 | # get the original timestep using init_timestep 310 | init_timestep = int(num_inference_steps * strength) + offset 311 | init_timestep = min(init_timestep, num_inference_steps) 312 | timesteps = self.scheduler.timesteps[-init_timestep] 313 | timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) 314 | 315 | # add noise to latents using the timesteps 316 | noise = torch.randn(init_latents.shape, generator=generator, device=self.device) 317 | init_latents = self.scheduler.add_noise(init_latents, noise, timesteps) 318 | 319 | # get prompt text embeddings 320 | text_input = self.tokenizer( 321 | prompt, 322 | padding='max_length', 323 | max_length=self.tokenizer.model_max_length, 324 | truncation=True, 325 | return_tensors='pt', 326 | ) 327 | text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] 328 | 329 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 330 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 331 | # corresponds to doing no classifier free guidance. 
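# The guided prediction assembled in the denoising loop below is
#     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# so guidance_scale = 1 reduces to the plain text-conditioned prediction, while larger
# values push each step further away from the unconditional prediction.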
332 | do_classifier_free_guidance = guidance_scale > 1.0 333 | # get unconditional embeddings for classifier free guidance 334 | if do_classifier_free_guidance: 335 | max_length = text_input.input_ids.shape[-1] 336 | uncond_input = self.tokenizer( 337 | [''] * batch_size, padding='max_length', max_length=max_length, return_tensors='pt' 338 | ) 339 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] 340 | 341 | # For classifier free guidance, we need to do two forward passes. 342 | # Here we concatenate the unconditional and text embeddings into a single batch 343 | # to avoid doing two forward passes 344 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 345 | 346 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 347 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 348 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 349 | # and should be between [0, 1] 350 | accepts_eta = 'eta' in set(inspect.signature(self.scheduler.step).parameters.keys()) 351 | extra_step_kwargs = {} 352 | if accepts_eta: 353 | extra_step_kwargs['eta'] = eta 354 | 355 | latents = init_latents 356 | t_start = max(num_inference_steps - init_timestep + offset, 0) 357 | time_steps = self.scheduler.timesteps[t_start:] 358 | for i, t in enumerate(time_steps): 359 | if progress_callback is not None: 360 | await progress_callback(i, len(time_steps)) 361 | # expand the latents if we are doing classifier free guidance 362 | if do_classifier_free_guidance: 363 | latent_model_input = torch.cat([latents] * 2) 364 | else: 365 | latent_model_input = latents 366 | 367 | # predict the noise residual 368 | noise_pred = self.unet( 369 | latent_model_input, t, encoder_hidden_states=text_embeddings 370 | ).sample 371 | 372 | # perform guidance 373 | if do_classifier_free_guidance: 374 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 375 | noise_pred = noise_pred_uncond + guidance_scale * ( 376 | noise_pred_text - noise_pred_uncond 377 | ) 378 | 379 | # compute the previous noisy sample x_t -> x_t-1 380 | latents = self.scheduler.step( 381 | noise_pred, t, latents, **extra_step_kwargs 382 | ).prev_sample 383 | 384 | if mask is not None: 385 | init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t) 386 | latents = mask_overlay(init_latents_proper, latents, mask) 387 | 388 | # scale and decode the image latents with vae 389 | image = self._scale_and_decode(latents) 390 | 391 | image = (image / 2 + 0.5).clamp(0, 1) 392 | image = image.cpu().permute(0, 2, 3, 1).numpy() 393 | image = self.numpy_to_pil(image) 394 | 395 | return image 396 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import mimetypes 4 | import os 5 | from base64 import b64decode, b64encode 6 | from io import BytesIO 7 | from typing import Tuple, List 8 | 9 | import aiofiles 10 | import httpx 11 | from PIL import Image 12 | 13 | # TODO use common utils 14 | from tqdm import tqdm 15 | 16 | from consts import ScalingMode 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | def image_to_base64url(image: Image.Image, output_format: str = 'PNG') -> bytes: 22 | data_string = f'data:{mimetypes.types_map[f".{output_format.lower()}"]};base64,'.encode() 23 | buffer = BytesIO() 24 | image.save(buffer, format=output_format) 25 | 
return data_string + b64encode(buffer.getvalue()) 26 | 27 | 28 | def base64url_to_image(source: bytes) -> Image.Image: 29 | _, data = source.split(b',') 30 | return Image.open(BytesIO(b64decode(data))) 31 | 32 | 33 | def size_from_aspect_ratio(aspect_ratio: float, scaling_mode: ScalingMode) -> Tuple[int, int]: 34 | if (scaling_mode == ScalingMode.GROW) == (aspect_ratio > 1): 35 | height = 512 36 | width = int(height * aspect_ratio) 37 | width -= width % 64 38 | else: 39 | width = 512 40 | height = int(width / aspect_ratio) 41 | height -= height % 64 42 | return width, height 43 | 44 | 45 | def resolve_path(path: str) -> str: 46 | return os.path.abspath(os.path.join(os.path.dirname(__file__), path)) 47 | 48 | 49 | async def download_models(models: List[str]): 50 | models_dir = resolve_path('models') 51 | os.makedirs(models_dir, exist_ok=True) 52 | 53 | async def download(url: str, position: int): 54 | file_name = url.split('/')[-1] 55 | file_path = os.path.join(models_dir, file_name) 56 | if os.path.exists(file_path): 57 | return 58 | 59 | logger.info(f'Downloading {file_path} {position}') 60 | async with aiofiles.open(file_path, mode='wb') as f: 61 | async with httpx.AsyncClient() as client: 62 | async with client.stream('GET', url, follow_redirects=True) as response: 63 | response.raise_for_status() 64 | total = int(response.headers['Content-Length']) 65 | with tqdm( 66 | desc=f'Downloading {file_name}', 67 | total=total, 68 | unit_scale=True, 69 | unit_divisor=1024, 70 | unit='B', 71 | position=position 72 | ) as progress: 73 | num_bytes_downloaded = response.num_bytes_downloaded 74 | async for chunk in response.aiter_bytes(): 75 | await f.write(chunk) 76 | progress.update(response.num_bytes_downloaded - num_bytes_downloaded) 77 | num_bytes_downloaded = response.num_bytes_downloaded 78 | 79 | await asyncio.gather(*[download(model, i) for i, model in enumerate(models)]) 80 | --------------------------------------------------------------------------------
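A minimal round-trip sketch for the base64url helpers defined in utils.py above. It assumes Pillow is installed and that the snippet is run from the repository root so that utils is importable; the test image size and colour are arbitrary stand-ins.

# Encode a Pillow image to a base64 data URL and decode it back (usage sketch).
from PIL import Image

from utils import base64url_to_image, image_to_base64url

# A small solid-colour stand-in for a real generation result.
original = Image.new('RGB', (64, 64), color=(128, 64, 32))

# image_to_base64url returns bytes of the form b'data:image/png;base64,...'.
encoded = image_to_base64url(original, output_format='PNG')
assert encoded.startswith(b'data:image/png;base64,')

# base64url_to_image strips the data-URL prefix and re-opens the PNG payload.
decoded = base64url_to_image(encoded)
assert decoded.size == original.size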