├── .gitignore ├── Dockerfile ├── LICENSE ├── MANIFEST.in ├── Makefile ├── OTHER_LICENSES ├── README.md ├── diffuzers.ipynb ├── diffuzers ├── Home.py ├── __init__.py ├── api │ ├── __init__.py │ ├── main.py │ ├── schemas.py │ └── utils.py ├── blip.py ├── cli │ ├── __init__.py │ ├── main.py │ ├── run_api.py │ └── run_app.py ├── clip_interrogator.py ├── data │ ├── artists.txt │ ├── flavors.txt │ ├── mediums.txt │ └── movements.txt ├── gfp_gan.py ├── gradio_app.py ├── image_info.py ├── img2img.py ├── inpainting.py ├── interrogator.py ├── pages │ ├── 1_Inpainting.py │ ├── 2_Utilities.py │ ├── 3_FAQs.py │ └── 4_Code of Conduct.py ├── text2img.py ├── textual_inversion.py ├── upscaler.py ├── utils.py └── x2image.py ├── docs ├── Makefile ├── conf.py ├── index.rst └── make.bat ├── requirements.txt ├── setup.cfg ├── setup.py └── static ├── .keep ├── screenshot.jpeg ├── screenshot.png └── screenshot_st.png /.gitignore: -------------------------------------------------------------------------------- 1 | # local stuff 2 | .vscode/ 3 | examples/.logs/ 4 | *.bin 5 | *.csv 6 | input/ 7 | models/ 8 | logs/ 9 | gfpgan/ 10 | *.pkl 11 | *.pt 12 | *.pth 13 | abhishek/ 14 | diffout/ 15 | 16 | # Byte-compiled / optimized / DLL files 17 | __pycache__/ 18 | *.py[cod] 19 | *$py.class 20 | 21 | # C extensions 22 | *.so 23 | 24 | # Distribution / packaging 25 | .Python 26 | build/ 27 | develop-eggs/ 28 | dist/ 29 | downloads/ 30 | eggs/ 31 | .eggs/ 32 | lib/ 33 | lib64/ 34 | parts/ 35 | sdist/ 36 | var/ 37 | pip-wheel-metadata/ 38 | share/python-wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | MANIFEST 43 | src/ 44 | 45 | # PyInstaller 46 | # Usually these files are written by a python script from a template 47 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 48 | *.manifest 49 | *.spec 50 | 51 | # Installer logs 52 | pip-log.txt 53 | pip-delete-this-directory.txt 54 | 55 | # Unit test / coverage reports 56 | htmlcov/ 57 | .tox/ 58 | .nox/ 59 | .coverage 60 | .coverage.* 61 | .cache 62 | nosetests.xml 63 | coverage.xml 64 | *.cover 65 | *.py,cover 66 | .hypothesis/ 67 | .pytest_cache/ 68 | 69 | # Translations 70 | *.mo 71 | *.pot 72 | 73 | # Django stuff: 74 | *.log 75 | local_settings.py 76 | db.sqlite3 77 | db.sqlite3-journal 78 | 79 | # Flask stuff: 80 | instance/ 81 | .webassets-cache 82 | 83 | # Scrapy stuff: 84 | .scrapy 85 | 86 | # Sphinx documentation 87 | docs/_build/ 88 | 89 | # PyBuilder 90 | target/ 91 | 92 | # Jupyter Notebook 93 | .ipynb_checkpoints 94 | 95 | # IPython 96 | profile_default/ 97 | ipython_config.py 98 | 99 | # pyenv 100 | .python-version 101 | 102 | # pipenv 103 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 104 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 105 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 106 | # install all needed dependencies. 107 | #Pipfile.lock 108 | 109 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 110 | __pypackages__/ 111 | 112 | # Celery stuff 113 | celerybeat-schedule 114 | celerybeat.pid 115 | 116 | # SageMath parsed files 117 | *.sage.py 118 | 119 | # Environments 120 | .env 121 | .venv 122 | env/ 123 | venv/ 124 | ENV/ 125 | env.bak/ 126 | venv.bak/ 127 | 128 | # Spyder project settings 129 | .spyderproject 130 | .spyproject 131 | 132 | # Rope project settings 133 | .ropeproject 134 | 135 | # mkdocs documentation 136 | /site 137 | 138 | # mypy 139 | .mypy_cache/ 140 | .dmypy.json 141 | dmypy.json 142 | 143 | # Pyre type checker 144 | .pyre/ 145 | 146 | # vscode 147 | .vscode/ 148 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04 2 | 3 | ARG DEBIAN_FRONTEND=noninteractive 4 | ENV PATH="${HOME}/miniconda3/bin:${PATH}" 5 | ARG PATH="${HOME}/miniconda3/bin:${PATH}" 6 | 7 | RUN apt-get update && \ 8 | apt-get upgrade -y && \ 9 | apt-get install -y \ 10 | build-essential \ 11 | cmake \ 12 | curl \ 13 | ca-certificates \ 14 | gcc \ 15 | locales \ 16 | wget \ 17 | git \ 18 | git-lfs \ 19 | ffmpeg \ 20 | libsm6 \ 21 | libxext6 \ 22 | && rm -rf /var/lib/apt/lists/* 23 | 24 | RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash && \ 25 | git lfs install 26 | 27 | WORKDIR /app 28 | ENV HOME=/app 29 | 30 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ 31 | && bash Miniconda3-latest-Linux-x86_64.sh -b -p /app/miniconda \ 32 | && rm -f Miniconda3-latest-Linux-x86_64.sh 33 | ENV PATH /app/miniconda/bin:$PATH 34 | 35 | RUN conda create -p /app/env -y python=3.8 36 | 37 | SHELL ["conda", "run","--no-capture-output", "-p","/app/env", "/bin/bash", "-c"] 38 | 39 | RUN conda install pytorch torchvision pytorch-cuda=11.7 -c pytorch -c nvidia 40 | RUN pip install https://github.com/abhishekkrthakur/xformers/raw/main/xformers-0.0.15%2B7e05e2c.d20221223-cp38-cp38-linux_x86_64.whl 41 | 42 | COPY requirements.txt /app/requirements.txt 43 | RUN pip install -U --no-cache-dir -r /app/requirements.txt 44 | 45 | COPY . /app 46 | RUN cd /app && pip install -e . 47 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include diffuzers/pages/*.py 2 | include diffuzers/data/*.txt 3 | include OTHER_LICENSES -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: quality style 2 | 3 | quality: 4 | python -m black --check --line-length 119 --target-version py38 . 5 | python -m isort --check-only . 6 | python -m flake8 --max-line-length 119 . 7 | 8 | style: 9 | python -m black --line-length 119 --target-version py38 . 10 | python -m isort . -------------------------------------------------------------------------------- /OTHER_LICENSES: -------------------------------------------------------------------------------- 1 | clip_interrogator licence 2 | for diffuzes/clip_interrogator.py 3 | for diffuzes/data/*.txt 4 | ------------------------- 5 | MIT License 6 | 7 | Copyright (c) 2022 pharmapsychotic 8 | 9 | Permission is hereby granted, free of charge, to any person obtaining a copy 10 | of this software and associated documentation files (the "Software"), to deal 11 | in the Software without restriction, including without limitation the rights 12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | copies of the Software, and to permit persons to whom the Software is 14 | furnished to do so, subject to the following conditions: 15 | 16 | The above copyright notice and this permission notice shall be included in all 17 | copies or substantial portions of the Software. 18 | 19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | SOFTWARE. 26 | 27 | blip licence 28 | for diffuzes/blip.py 29 | ------------ 30 | Copyright (c) 2022, Salesforce.com, Inc. 31 | All rights reserved. 32 | 33 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 34 | 35 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 36 | 37 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 38 | 39 | * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 40 | 41 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # diffuzers 2 | 3 | A web ui and deployable API for [🤗 diffusers](https://github.com/huggingface/diffusers). 4 | 5 | < under development, request features using issues, prs not accepted atm > 6 | 7 | 8 | Open In Colab 9 | 10 | 11 | 12 | Documentation Status 13 | 14 | 15 | ![image](https://github.com/abhishekkrthakur/diffuzers/raw/main/static/screenshot.jpeg) 16 | 17 | 18 | If something doesnt work as expected, or if you need some features which are not available, then create request using [github issues](https://github.com/abhishekkrthakur/diffuzers/issues) 19 | 20 | 21 | ## Features available in the app: 22 | 23 | - text to image 24 | - image to image 25 | - instruct pix2pix 26 | - textual inversion 27 | - inpainting 28 | - outpainting (coming soon) 29 | - image info 30 | - stable diffusion upscaler 31 | - gfpgan 32 | - clip interrogator 33 | - more coming soon! 34 | 35 | ## Features available in the api: 36 | 37 | - text to image 38 | - image to image 39 | - instruct pix2pix 40 | - textual inversion 41 | - inpainting 42 | - outpainting (via inpainting) 43 | - more coming soon! 44 | 45 | 46 | ## Installation 47 | 48 | To install bleeding edge version of diffuzers, clone the repo and install it using pip. 49 | 50 | ```bash 51 | git clone https://github.com/abhishekkrthakur/diffuzers 52 | cd diffuzers 53 | pip install -e . 54 | ``` 55 | 56 | Installation using pip: 57 | 58 | ```bash 59 | pip install diffuzers 60 | ``` 61 | 62 | ## Usage 63 | 64 | ### Web App 65 | To run the web app, run the following command: 66 | 67 | ```bash 68 | diffuzers app 69 | ``` 70 | 71 | ### API 72 | 73 | To run the api, run the following command: 74 | 75 | 76 | ```bash 77 | diffuzers api 78 | ``` 79 | 80 | Starting the API requires the following environment variables: 81 | 82 | ``` 83 | export X2IMG_MODEL=stabilityai/stable-diffusion-2-1 84 | export DEVICE=cuda 85 | ``` 86 | 87 | If you want to use inpainting: 88 | 89 | ``` 90 | export INPAINTING_MODEL=stabilityai/stable-diffusion-2-inpainting 91 | ``` 92 | 93 | To use long prompt weighting, use: 94 | 95 | ``` 96 | export PIPELINE=lpw_stable_diffusion 97 | ``` 98 | 99 | If you have `OUTPUT_PATH` in environment variables, all generations will be saved in `OUTPUT_PATH`. You can also use other (or private) huggingface models. To use private models, you must login using `huggingface-cli login`. 100 | 101 | API docs are available at `host:port/docs`. For example, with default settings, you can access docs at: `127.0.0.1:10000/docs`. 
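
Since the API is a plain FastAPI app, any HTTP client can call it. The following is a minimal sketch (not part of the repository) that posts to the `/text2img` endpoint using the fields defined in `diffuzers/api/schemas.py` and decodes the base64-encoded PNGs returned in the `images` field. It assumes the API is running locally on the default `127.0.0.1:10000` and that the `requests` package is installed:

```python
import base64

import requests  # assumed to be available; any HTTP client works

# JSON body matching Text2ImgParams (see diffuzers/api/schemas.py)
payload = {
    "prompt": "a watercolor painting of a lighthouse at sunset",
    "negative_prompt": "blurry, low quality",
    "num_images": 1,
    "steps": 30,
    "seed": 42,
}
response = requests.post("http://127.0.0.1:10000/text2img", json=payload, timeout=600)
response.raise_for_status()

# the response contains base64-encoded PNGs (see diffuzers/api/utils.py)
for i, b64_image in enumerate(response.json()["images"]):
    with open(f"text2img_{i}.png", "wb") as f:
        f.write(base64.b64decode(b64_image))
```

The `/img2img`, `/instruct-pix2pix`, and `/inpainting` endpoints work similarly, except that their parameters are passed as query parameters and the input image (plus a mask, for inpainting) is uploaded as a multipart file.
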
102 | 103 | 104 | ## All CLI Options for running the app: 105 | 106 | ```bash 107 | ❯ diffuzers app --help 108 | usage: diffuzers [] app [-h] [--output OUTPUT] [--share] [--port PORT] [--host HOST] 109 | [--device DEVICE] [--ngrok_key NGROK_KEY] 110 | 111 | ✨ Run diffuzers app 112 | 113 | optional arguments: 114 | -h, --help show this help message and exit 115 | --output OUTPUT Output path is optional, but if provided, all generations will automatically be saved to this 116 | path. 117 | --share Share the app 118 | --port PORT Port to run the app on 119 | --host HOST Host to run the app on 120 | --device DEVICE Device to use, e.g. cpu, cuda, cuda:0, mps (for m1 mac) etc. 121 | --ngrok_key NGROK_KEY 122 | Ngrok key to use for sharing the app. Only required if you want to share the app 123 | ``` 124 | 125 | ## All CLI Options for running the api: 126 | 127 | ```bash 128 | ❯ diffuzers api --help 129 | usage: diffuzers [] api [-h] [--output OUTPUT] [--port PORT] [--host HOST] [--device DEVICE] 130 | [--workers WORKERS] 131 | 132 | ✨ Run diffuzers api 133 | 134 | optional arguments: 135 | -h, --help show this help message and exit 136 | --output OUTPUT Output path is optional, but if provided, all generations will automatically be saved to this 137 | path. 138 | --port PORT Port to run the app on 139 | --host HOST Host to run the app on 140 | --device DEVICE Device to use, e.g. cpu, cuda, cuda:0, mps (for m1 mac) etc. 141 | --workers WORKERS Number of workers to use 142 | ``` 143 | 144 | ## Using private models from huggingface hub 145 | 146 | If you want to use private models from huggingface hub, then you need to login using `huggingface-cli login` command. 147 | 148 | Note: You can also save your generations directly to huggingface hub if your output path points to a huggingface hub dataset repo and you have access to push to that repository. Thus, you will end up saving a lot of disk space. 
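
For example, a minimal sketch of serving a private text-to-image model through the API (the model id below is hypothetical; substitute your own):

```bash
# one-time login; stores a token that diffuzers and the huggingface hub libraries pick up
huggingface-cli login

# hypothetical private model id on the hub; replace with your own
export X2IMG_MODEL=your-username/your-private-model
export DEVICE=cuda
diffuzers api
```
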
149 | -------------------------------------------------------------------------------- /diffuzers.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "!pip install diffuzers -qqqq" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "!diffuzers app --port 10000 --ngrok_key YOUR_NGROK_AUTHTOKEN --share" 19 | ] 20 | } 21 | ], 22 | "metadata": { 23 | "kernelspec": { 24 | "display_name": "ml", 25 | "language": "python", 26 | "name": "python3" 27 | }, 28 | "language_info": { 29 | "name": "python", 30 | "version": "3.8.13 (default, Mar 28 2022, 11:38:47) \n[GCC 7.5.0]" 31 | }, 32 | "orig_nbformat": 4, 33 | "vscode": { 34 | "interpreter": { 35 | "hash": "a6cdb91f3f78c48bad72cb110b9a247c2b670f553f33c17c34bce0072b5c1357" 36 | } 37 | } 38 | }, 39 | "nbformat": 4, 40 | "nbformat_minor": 2 41 | } 42 | -------------------------------------------------------------------------------- /diffuzers/Home.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import streamlit as st 4 | from loguru import logger 5 | 6 | from diffuzers import utils 7 | from diffuzers.x2image import X2Image 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument( 13 | "--output", 14 | type=str, 15 | required=False, 16 | default=None, 17 | help="Output path", 18 | ) 19 | parser.add_argument( 20 | "--device", 21 | type=str, 22 | required=True, 23 | help="Device to use, e.g. cpu, cuda, cuda:0, mps etc.", 24 | ) 25 | return parser.parse_args() 26 | 27 | 28 | def x2img_app(): 29 | with st.form("x2img_model_form"): 30 | col1, col2 = st.columns(2) 31 | with col1: 32 | model = st.text_input( 33 | "Which model do you want to use?", 34 | value="stabilityai/stable-diffusion-2-base" 35 | if st.session_state.get("x2img_model") is None 36 | else st.session_state.x2img_model, 37 | ) 38 | with col2: 39 | custom_pipeline = st.selectbox( 40 | "Custom pipeline", 41 | options=[ 42 | "Vanilla", 43 | "Long Prompt Weighting", 44 | ], 45 | index=0 if st.session_state.get("x2img_custom_pipeline") in (None, "Vanilla") else 1, 46 | ) 47 | 48 | with st.expander("Textual Inversion (Optional)"): 49 | token_identifier = st.text_input( 50 | "Token identifier", 51 | placeholder="" 52 | if st.session_state.get("textual_inversion_token_identifier") is None 53 | else st.session_state.textual_inversion_token_identifier, 54 | ) 55 | embeddings = st.text_input( 56 | "Embeddings", 57 | placeholder="https://huggingface.co/sd-concepts-library/axe-tattoo/resolve/main/learned_embeds.bin" 58 | if st.session_state.get("textual_inversion_embeddings") is None 59 | else st.session_state.textual_inversion_embeddings, 60 | ) 61 | submit = st.form_submit_button("Load model") 62 | 63 | if submit: 64 | st.session_state.x2img_model = model 65 | st.session_state.x2img_custom_pipeline = custom_pipeline 66 | st.session_state.textual_inversion_token_identifier = token_identifier 67 | st.session_state.textual_inversion_embeddings = embeddings 68 | cpipe = "lpw_stable_diffusion" if custom_pipeline == "Long Prompt Weighting" else None 69 | with st.spinner("Loading model..."): 70 | x2img = X2Image( 71 | model=model, 72 | device=st.session_state.device, 73 | output_path=st.session_state.output_path, 74 | custom_pipeline=cpipe, 75 | 
token_identifier=token_identifier, 76 | embeddings_url=embeddings, 77 | ) 78 | st.session_state.x2img = x2img 79 | if "x2img" in st.session_state: 80 | st.write(f"Current model: {st.session_state.x2img}") 81 | st.session_state.x2img.app() 82 | 83 | 84 | def run_app(): 85 | utils.create_base_page() 86 | x2img_app() 87 | 88 | 89 | if __name__ == "__main__": 90 | args = parse_args() 91 | logger.info(f"Args: {args}") 92 | logger.info(st.session_state) 93 | st.session_state.device = args.device 94 | st.session_state.output_path = args.output 95 | run_app() 96 | -------------------------------------------------------------------------------- /diffuzers/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from loguru import logger 4 | 5 | 6 | logger.configure(handlers=[dict(sink=sys.stderr, format="> {level:<7} {message}")]) 7 | 8 | __version__ = "0.3.5" 9 | -------------------------------------------------------------------------------- /diffuzers/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abhishekkrthakur/diffuzers/71f710a1940c22637561809387b1c8622be1a48b/diffuzers/api/__init__.py -------------------------------------------------------------------------------- /diffuzers/api/main.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | 4 | from fastapi import Depends, FastAPI, File, UploadFile 5 | from loguru import logger 6 | from PIL import Image 7 | from starlette.middleware.cors import CORSMiddleware 8 | 9 | from diffuzers.api.schemas import Img2ImgParams, ImgResponse, InpaintingParams, InstructPix2PixParams, Text2ImgParams 10 | from diffuzers.api.utils import convert_to_b64_list 11 | from diffuzers.inpainting import Inpainting 12 | from diffuzers.x2image import X2Image 13 | 14 | 15 | app = FastAPI( 16 | title="diffuzers api", 17 | license_info={ 18 | "name": "Apache 2.0", 19 | "url": "https://www.apache.org/licenses/LICENSE-2.0.html", 20 | }, 21 | ) 22 | app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]) 23 | 24 | 25 | @app.on_event("startup") 26 | async def startup_event(): 27 | 28 | x2img_model = os.environ.get("X2IMG_MODEL") 29 | x2img_pipeline = os.environ.get("X2IMG_PIPELINE") 30 | inpainting_model = os.environ.get("INPAINTING_MODEL") 31 | device = os.environ.get("DEVICE") 32 | output_path = os.environ.get("OUTPUT_PATH") 33 | ti_identifier = os.environ.get("TOKEN_IDENTIFIER", "") 34 | ti_embeddings_url = os.environ.get("TOKEN_EMBEDDINGS_URL", "") 35 | logger.info("@@@@@ Starting Diffuzes API @@@@@ ") 36 | logger.info(f"Text2Image Model: {x2img_model}") 37 | logger.info(f"Text2Image Pipeline: {x2img_pipeline if x2img_pipeline is not None else 'Vanilla'}") 38 | logger.info(f"Inpainting Model: {inpainting_model}") 39 | logger.info(f"Device: {device}") 40 | logger.info(f"Output Path: {output_path}") 41 | logger.info(f"Token Identifier: {ti_identifier}") 42 | logger.info(f"Token Embeddings URL: {ti_embeddings_url}") 43 | 44 | logger.info("Loading x2img model...") 45 | if x2img_model is not None: 46 | app.state.x2img_model = X2Image( 47 | model=x2img_model, 48 | device=device, 49 | output_path=output_path, 50 | custom_pipeline=x2img_pipeline, 51 | token_identifier=ti_identifier, 52 | embeddings_url=ti_embeddings_url, 53 | ) 54 | else: 55 | app.state.x2img_model = None 56 | logger.info("Loading inpainting model...") 57 | if 
inpainting_model is not None: 58 | app.state.inpainting_model = Inpainting( 59 | model=inpainting_model, 60 | device=device, 61 | output_path=output_path, 62 | ) 63 | logger.info("API is ready to use!") 64 | 65 | 66 | @app.post("/text2img") 67 | async def text2img(params: Text2ImgParams) -> ImgResponse: 68 | logger.info(f"Params: {params}") 69 | if app.state.x2img_model is None: 70 | return {"error": "x2img model is not loaded"} 71 | 72 | images, _ = app.state.x2img_model.text2img_generate( 73 | params.prompt, 74 | num_images=params.num_images, 75 | steps=params.steps, 76 | seed=params.seed, 77 | negative_prompt=params.negative_prompt, 78 | scheduler=params.scheduler, 79 | image_size=(params.image_height, params.image_width), 80 | guidance_scale=params.guidance_scale, 81 | ) 82 | base64images = convert_to_b64_list(images) 83 | return ImgResponse(images=base64images, metadata=params.dict()) 84 | 85 | 86 | @app.post("/img2img") 87 | async def img2img(params: Img2ImgParams = Depends(), image: UploadFile = File(...)) -> ImgResponse: 88 | if app.state.x2img_model is None: 89 | return {"error": "x2img model is not loaded"} 90 | image = Image.open(io.BytesIO(image.file.read())) 91 | images, _ = app.state.x2img_model.img2img_generate( 92 | image=image, 93 | prompt=params.prompt, 94 | negative_prompt=params.negative_prompt, 95 | num_images=params.num_images, 96 | steps=params.steps, 97 | seed=params.seed, 98 | scheduler=params.scheduler, 99 | guidance_scale=params.guidance_scale, 100 | strength=params.strength, 101 | ) 102 | base64images = convert_to_b64_list(images) 103 | return ImgResponse(images=base64images, metadata=params.dict()) 104 | 105 | 106 | @app.post("/instruct-pix2pix") 107 | async def instruct_pix2pix(params: InstructPix2PixParams = Depends(), image: UploadFile = File(...)) -> ImgResponse: 108 | if app.state.x2img_model is None: 109 | return {"error": "x2img model is not loaded"} 110 | image = Image.open(io.BytesIO(image.file.read())) 111 | images, _ = app.state.x2img_model.pix2pix_generate( 112 | image=image, 113 | prompt=params.prompt, 114 | negative_prompt=params.negative_prompt, 115 | num_images=params.num_images, 116 | steps=params.steps, 117 | seed=params.seed, 118 | scheduler=params.scheduler, 119 | guidance_scale=params.guidance_scale, 120 | image_guidance_scale=params.image_guidance_scale, 121 | ) 122 | base64images = convert_to_b64_list(images) 123 | return ImgResponse(images=base64images, metadata=params.dict()) 124 | 125 | 126 | @app.post("/inpainting") 127 | async def inpainting( 128 | params: InpaintingParams = Depends(), image: UploadFile = File(...), mask: UploadFile = File(...) 
129 | ) -> ImgResponse: 130 | if app.state.inpainting_model is None: 131 | return {"error": "inpainting model is not loaded"} 132 | image = Image.open(io.BytesIO(image.file.read())) 133 | mask = Image.open(io.BytesIO(mask.file.read())) 134 | images, _ = app.state.inpainting_model.generate_image( 135 | image=image, 136 | mask=mask, 137 | prompt=params.prompt, 138 | negative_prompt=params.negative_prompt, 139 | scheduler=params.scheduler, 140 | height=params.image_height, 141 | width=params.image_width, 142 | num_images=params.num_images, 143 | guidance_scale=params.guidance_scale, 144 | steps=params.steps, 145 | seed=params.seed, 146 | ) 147 | base64images = convert_to_b64_list(images) 148 | return ImgResponse(images=base64images, metadata=params.dict()) 149 | 150 | 151 | @app.get("/") 152 | def read_root(): 153 | return {"Hello": "World"} 154 | -------------------------------------------------------------------------------- /diffuzers/api/schemas.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | from pydantic import BaseModel, Field 4 | 5 | 6 | class Text2ImgParams(BaseModel): 7 | prompt: str = Field(..., description="Text prompt for the model") 8 | negative_prompt: str = Field(None, description="Negative text prompt for the model") 9 | scheduler: str = Field("EulerAncestralDiscreteScheduler", description="Scheduler to use for the model") 10 | image_height: int = Field(512, description="Image height") 11 | image_width: int = Field(512, description="Image width") 12 | num_images: int = Field(1, description="Number of images to generate") 13 | guidance_scale: float = Field(7, description="Guidance scale") 14 | steps: int = Field(50, description="Number of steps to run the model for") 15 | seed: int = Field(42, description="Seed for the model") 16 | 17 | 18 | class Img2ImgParams(BaseModel): 19 | prompt: str = Field(..., description="Text prompt for the model") 20 | negative_prompt: str = Field(None, description="Negative text prompt for the model") 21 | scheduler: str = Field("EulerAncestralDiscreteScheduler", description="Scheduler to use for the model") 22 | strength: float = Field(0.7, description="Strength") 23 | num_images: int = Field(1, description="Number of images to generate") 24 | guidance_scale: float = Field(7, description="Guidance scale") 25 | steps: int = Field(50, description="Number of steps to run the model for") 26 | seed: int = Field(42, description="Seed for the model") 27 | 28 | 29 | class InstructPix2PixParams(BaseModel): 30 | prompt: str = Field(..., description="Text prompt for the model") 31 | negative_prompt: str = Field(None, description="Negative text prompt for the model") 32 | scheduler: str = Field("EulerAncestralDiscreteScheduler", description="Scheduler to use for the model") 33 | num_images: int = Field(1, description="Number of images to generate") 34 | guidance_scale: float = Field(7, description="Guidance scale") 35 | image_guidance_scale: float = Field(1.5, description="Image guidance scale") 36 | steps: int = Field(50, description="Number of steps to run the model for") 37 | seed: int = Field(42, description="Seed for the model") 38 | 39 | 40 | class ImgResponse(BaseModel): 41 | images: List[str] = Field(..., description="List of images in base64 format") 42 | metadata: Dict = Field(..., description="Metadata") 43 | 44 | 45 | class InpaintingParams(BaseModel): 46 | prompt: str = Field(..., description="Text prompt for the model") 47 | negative_prompt: str = Field(None, 
description="Negative text prompt for the model") 48 | scheduler: str = Field("EulerAncestralDiscreteScheduler", description="Scheduler to use for the model") 49 | image_height: int = Field(512, description="Image height") 50 | image_width: int = Field(512, description="Image width") 51 | num_images: int = Field(1, description="Number of images to generate") 52 | guidance_scale: float = Field(7, description="Guidance scale") 53 | steps: int = Field(50, description="Number of steps to run the model for") 54 | seed: int = Field(42, description="Seed for the model") 55 | -------------------------------------------------------------------------------- /diffuzers/api/utils.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import io 3 | 4 | 5 | def convert_to_b64_list(images): 6 | base64images = [] 7 | for image in images: 8 | buf = io.BytesIO() 9 | image.save(buf, format="PNG") 10 | byte_im = base64.b64encode(buf.getvalue()) 11 | base64images.append(byte_im) 12 | return base64images 13 | -------------------------------------------------------------------------------- /diffuzers/cli/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from argparse import ArgumentParser 3 | 4 | 5 | class BaseDiffuzersCommand(ABC): 6 | @staticmethod 7 | @abstractmethod 8 | def register_subcommand(parser: ArgumentParser): 9 | raise NotImplementedError() 10 | 11 | @abstractmethod 12 | def run(self): 13 | raise NotImplementedError() 14 | -------------------------------------------------------------------------------- /diffuzers/cli/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from .. import __version__ 4 | from .run_api import RunDiffuzersAPICommand 5 | from .run_app import RunDiffuzersAppCommand 6 | 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser( 10 | "Diffuzers CLI", 11 | usage="diffuzers []", 12 | epilog="For more information about a command, run: `diffuzers --help`", 13 | ) 14 | parser.add_argument("--version", "-v", help="Display diffuzers version", action="store_true") 15 | commands_parser = parser.add_subparsers(help="commands") 16 | 17 | # Register commands 18 | RunDiffuzersAppCommand.register_subcommand(commands_parser) 19 | RunDiffuzersAPICommand.register_subcommand(commands_parser) 20 | 21 | args = parser.parse_args() 22 | 23 | if args.version: 24 | print(__version__) 25 | exit(0) 26 | 27 | if not hasattr(args, "func"): 28 | parser.print_help() 29 | exit(1) 30 | 31 | command = args.func(args) 32 | command.run() 33 | 34 | 35 | if __name__ == "__main__": 36 | main() 37 | -------------------------------------------------------------------------------- /diffuzers/cli/run_api.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from argparse import ArgumentParser 3 | 4 | import torch 5 | 6 | from . 
import BaseDiffuzersCommand 7 | 8 | 9 | def run_api_command_factory(args): 10 | return RunDiffuzersAPICommand( 11 | args.output, 12 | args.port, 13 | args.host, 14 | args.device, 15 | args.workers, 16 | args.ssl_certfile, 17 | args.ssl_keyfile, 18 | ) 19 | 20 | 21 | class RunDiffuzersAPICommand(BaseDiffuzersCommand): 22 | @staticmethod 23 | def register_subcommand(parser: ArgumentParser): 24 | run_api_parser = parser.add_parser( 25 | "api", 26 | description="✨ Run diffuzers api", 27 | ) 28 | run_api_parser.add_argument( 29 | "--output", 30 | type=str, 31 | required=False, 32 | help="Output path is optional, but if provided, all generations will automatically be saved to this path.", 33 | ) 34 | run_api_parser.add_argument( 35 | "--port", 36 | type=int, 37 | default=10000, 38 | help="Port to run the app on", 39 | required=False, 40 | ) 41 | run_api_parser.add_argument( 42 | "--host", 43 | type=str, 44 | default="127.0.0.1", 45 | help="Host to run the app on", 46 | required=False, 47 | ) 48 | run_api_parser.add_argument( 49 | "--device", 50 | type=str, 51 | required=False, 52 | help="Device to use, e.g. cpu, cuda, cuda:0, mps (for m1 mac) etc.", 53 | ) 54 | run_api_parser.add_argument( 55 | "--workers", 56 | type=int, 57 | required=False, 58 | default=1, 59 | help="Number of workers to use", 60 | ) 61 | run_api_parser.add_argument( 62 | "--ssl_certfile", 63 | type=str, 64 | required=False, 65 | help="the path to your ssl cert", 66 | ) 67 | run_api_parser.add_argument( 68 | "--ssl_keyfile", 69 | type=str, 70 | required=False, 71 | help="the path to your ssl key", 72 | ) 73 | run_api_parser.set_defaults(func=run_api_command_factory) 74 | 75 | def __init__(self, output, port, host, device, workers, ssl_certfile, ssl_keyfile): 76 | self.output = output 77 | self.port = port 78 | self.host = host 79 | self.device = device 80 | self.workers = workers 81 | self.ssl_certfile = ssl_certfile 82 | self.ssl_keyfile = ssl_keyfile 83 | 84 | if self.device is None: 85 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 86 | 87 | self.port = str(self.port) 88 | self.workers = str(self.workers) 89 | 90 | def run(self): 91 | cmd = [ 92 | "uvicorn", 93 | "diffuzers.api.main:app", 94 | "--host", 95 | self.host, 96 | "--port", 97 | self.port, 98 | "--workers", 99 | self.workers, 100 | ] 101 | if self.ssl_certfile is not None: 102 | cmd.extend(["--ssl-certfile", self.ssl_certfile]) 103 | if self.ssl_keyfile is not None: 104 | cmd.extend(["--ssl-keyfile", self.ssl_keyfile]) 105 | 106 | proc = subprocess.Popen( 107 | cmd, 108 | stdout=subprocess.PIPE, 109 | stderr=subprocess.STDOUT, 110 | shell=False, 111 | universal_newlines=True, 112 | bufsize=1, 113 | ) 114 | with proc as p: 115 | try: 116 | for line in p.stdout: 117 | print(line, end="") 118 | except KeyboardInterrupt: 119 | print("Killing api") 120 | p.kill() 121 | p.wait() 122 | raise 123 | -------------------------------------------------------------------------------- /diffuzers/cli/run_app.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from argparse import ArgumentParser 3 | 4 | import torch 5 | from pyngrok import ngrok 6 | 7 | from . 
import BaseDiffuzersCommand 8 | 9 | 10 | def run_app_command_factory(args): 11 | return RunDiffuzersAppCommand( 12 | args.output, 13 | args.share, 14 | args.port, 15 | args.host, 16 | args.device, 17 | args.ngrok_key, 18 | ) 19 | 20 | 21 | class RunDiffuzersAppCommand(BaseDiffuzersCommand): 22 | @staticmethod 23 | def register_subcommand(parser: ArgumentParser): 24 | run_app_parser = parser.add_parser( 25 | "app", 26 | description="✨ Run diffuzers app", 27 | ) 28 | run_app_parser.add_argument( 29 | "--output", 30 | type=str, 31 | required=False, 32 | help="Output path is optional, but if provided, all generations will automatically be saved to this path.", 33 | ) 34 | run_app_parser.add_argument( 35 | "--share", 36 | action="store_true", 37 | help="Share the app", 38 | ) 39 | run_app_parser.add_argument( 40 | "--port", 41 | type=int, 42 | default=10000, 43 | help="Port to run the app on", 44 | required=False, 45 | ) 46 | run_app_parser.add_argument( 47 | "--host", 48 | type=str, 49 | default="127.0.0.1", 50 | help="Host to run the app on", 51 | required=False, 52 | ) 53 | run_app_parser.add_argument( 54 | "--device", 55 | type=str, 56 | required=False, 57 | help="Device to use, e.g. cpu, cuda, cuda:0, mps (for m1 mac) etc.", 58 | ) 59 | run_app_parser.add_argument( 60 | "--ngrok_key", 61 | type=str, 62 | required=False, 63 | help="Ngrok key to use for sharing the app. Only required if you want to share the app", 64 | ) 65 | 66 | run_app_parser.set_defaults(func=run_app_command_factory) 67 | 68 | def __init__(self, output, share, port, host, device, ngrok_key): 69 | self.output = output 70 | self.share = share 71 | self.port = port 72 | self.host = host 73 | self.device = device 74 | self.ngrok_key = ngrok_key 75 | 76 | if self.device is None: 77 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 78 | 79 | if self.share: 80 | if self.ngrok_key is None: 81 | raise ValueError( 82 | "ngrok key is required if you want to share the app. 
Get it for free from https://dashboard.ngrok.com/get-started/your-authtoken" 83 | ) 84 | 85 | def run(self): 86 | # from ..app import Diffuzers 87 | 88 | # print(self.share) 89 | # app = Diffuzers(self.model, self.output).app() 90 | # app.launch(show_api=False, share=self.share, server_port=self.port, server_name=self.host) 91 | import os 92 | 93 | dirname = os.path.dirname(__file__) 94 | filename = os.path.join(dirname, "..", "Home.py") 95 | cmd = [ 96 | "streamlit", 97 | "run", 98 | filename, 99 | "--browser.gatherUsageStats", 100 | "false", 101 | "--browser.serverAddress", 102 | self.host, 103 | "--server.port", 104 | str(self.port), 105 | "--theme.base", 106 | "light", 107 | "--", 108 | "--device", 109 | self.device, 110 | ] 111 | if self.output is not None: 112 | cmd.extend(["--output", self.output]) 113 | 114 | if self.share: 115 | ngrok.set_auth_token(self.ngrok_key) 116 | public_url = ngrok.connect(self.port).public_url 117 | print(f"Sharing app at {public_url}") 118 | 119 | proc = subprocess.Popen( 120 | cmd, 121 | stdout=subprocess.PIPE, 122 | stderr=subprocess.STDOUT, 123 | shell=False, 124 | universal_newlines=True, 125 | bufsize=1, 126 | ) 127 | with proc as p: 128 | try: 129 | for line in p.stdout: 130 | print(line, end="") 131 | except KeyboardInterrupt: 132 | print("Killing streamlit app") 133 | p.kill() 134 | if self.share: 135 | print("Killing ngrok tunnel") 136 | ngrok.kill() 137 | p.wait() 138 | raise 139 | -------------------------------------------------------------------------------- /diffuzers/clip_interrogator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code taken from: https://github.com/pharmapsychotic/clip-interrogator 3 | """ 4 | 5 | import hashlib 6 | import inspect 7 | import math 8 | import os 9 | import pickle 10 | import time 11 | from dataclasses import dataclass 12 | from typing import List 13 | 14 | import numpy as np 15 | import open_clip 16 | import torch 17 | from PIL import Image 18 | from torchvision import transforms 19 | from torchvision.transforms.functional import InterpolationMode 20 | from tqdm import tqdm 21 | 22 | from diffuzers.blip import BLIP_Decoder, blip_decoder 23 | 24 | 25 | @dataclass 26 | class Config: 27 | # models can optionally be passed in directly 28 | blip_model: BLIP_Decoder = None 29 | clip_model = None 30 | clip_preprocess = None 31 | 32 | # blip settings 33 | blip_image_eval_size: int = 384 34 | blip_max_length: int = 32 35 | blip_model_url: str = ( 36 | "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth" 37 | ) 38 | blip_num_beams: int = 8 39 | blip_offload: bool = False 40 | 41 | # clip settings 42 | clip_model_name: str = "ViT-L-14/openai" 43 | clip_model_path: str = None 44 | 45 | # interrogator settings 46 | cache_path: str = "cache" 47 | chunk_size: int = 2048 48 | data_path: str = os.path.join(os.path.dirname(__file__), "data") 49 | device: str = "cuda" if torch.cuda.is_available() else "cpu" 50 | flavor_intermediate_count: int = 2048 51 | quiet: bool = False # when quiet progress bars are not shown 52 | 53 | 54 | class Interrogator: 55 | def __init__(self, config: Config): 56 | self.config = config 57 | self.device = config.device 58 | 59 | if config.blip_model is None: 60 | if not config.quiet: 61 | print("Loading BLIP model...") 62 | blip_path = os.path.dirname(inspect.getfile(blip_decoder)) 63 | configs_path = os.path.join(os.path.dirname(blip_path), "configs") 64 | med_config = os.path.join(configs_path, 
"med_config.json") 65 | blip_model = blip_decoder( 66 | pretrained=config.blip_model_url, 67 | image_size=config.blip_image_eval_size, 68 | vit="large", 69 | med_config=med_config, 70 | ) 71 | blip_model.eval() 72 | blip_model = blip_model.to(config.device) 73 | self.blip_model = blip_model 74 | else: 75 | self.blip_model = config.blip_model 76 | 77 | self.load_clip_model() 78 | 79 | def load_clip_model(self): 80 | start_time = time.time() 81 | config = self.config 82 | 83 | if config.clip_model is None: 84 | if not config.quiet: 85 | print("Loading CLIP model...") 86 | 87 | clip_model_name, clip_model_pretrained_name = config.clip_model_name.split("/", 2) 88 | self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms( 89 | clip_model_name, 90 | pretrained=clip_model_pretrained_name, 91 | precision="fp16" if config.device == "cuda" else "fp32", 92 | device=config.device, 93 | jit=False, 94 | cache_dir=config.clip_model_path, 95 | ) 96 | self.clip_model.to(config.device).eval() 97 | else: 98 | self.clip_model = config.clip_model 99 | self.clip_preprocess = config.clip_preprocess 100 | self.tokenize = open_clip.get_tokenizer(clip_model_name) 101 | 102 | sites = [ 103 | "Artstation", 104 | "behance", 105 | "cg society", 106 | "cgsociety", 107 | "deviantart", 108 | "dribble", 109 | "flickr", 110 | "instagram", 111 | "pexels", 112 | "pinterest", 113 | "pixabay", 114 | "pixiv", 115 | "polycount", 116 | "reddit", 117 | "shutterstock", 118 | "tumblr", 119 | "unsplash", 120 | "zbrush central", 121 | ] 122 | trending_list = [site for site in sites] 123 | trending_list.extend(["trending on " + site for site in sites]) 124 | trending_list.extend(["featured on " + site for site in sites]) 125 | trending_list.extend([site + " contest winner" for site in sites]) 126 | 127 | raw_artists = _load_list(config.data_path, "artists.txt") 128 | artists = [f"by {a}" for a in raw_artists] 129 | artists.extend([f"inspired by {a}" for a in raw_artists]) 130 | 131 | self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config) 132 | self.flavors = LabelTable( 133 | _load_list(config.data_path, "flavors.txt"), "flavors", self.clip_model, self.tokenize, config 134 | ) 135 | self.mediums = LabelTable( 136 | _load_list(config.data_path, "mediums.txt"), "mediums", self.clip_model, self.tokenize, config 137 | ) 138 | self.movements = LabelTable( 139 | _load_list(config.data_path, "movements.txt"), "movements", self.clip_model, self.tokenize, config 140 | ) 141 | self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config) 142 | 143 | end_time = time.time() 144 | if not config.quiet: 145 | print(f"Loaded CLIP model and data in {end_time-start_time:.2f} seconds.") 146 | 147 | def generate_caption(self, pil_image: Image) -> str: 148 | if self.config.blip_offload: 149 | self.blip_model = self.blip_model.to(self.device) 150 | size = self.config.blip_image_eval_size 151 | gpu_image = ( 152 | transforms.Compose( 153 | [ 154 | transforms.Resize((size, size), interpolation=InterpolationMode.BICUBIC), 155 | transforms.ToTensor(), 156 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 157 | ] 158 | )(pil_image) 159 | .unsqueeze(0) 160 | .to(self.device) 161 | ) 162 | 163 | with torch.no_grad(): 164 | caption = self.blip_model.generate( 165 | gpu_image, 166 | sample=False, 167 | num_beams=self.config.blip_num_beams, 168 | max_length=self.config.blip_max_length, 169 | min_length=5, 170 | ) 171 | if 
self.config.blip_offload: 172 | self.blip_model = self.blip_model.to("cpu") 173 | return caption[0] 174 | 175 | def image_to_features(self, image: Image) -> torch.Tensor: 176 | images = self.clip_preprocess(image).unsqueeze(0).to(self.device) 177 | with torch.no_grad(), torch.cuda.amp.autocast(): 178 | image_features = self.clip_model.encode_image(images) 179 | image_features /= image_features.norm(dim=-1, keepdim=True) 180 | return image_features 181 | 182 | def interrogate_classic(self, image: Image, max_flavors: int = 3) -> str: 183 | caption = self.generate_caption(image) 184 | image_features = self.image_to_features(image) 185 | 186 | medium = self.mediums.rank(image_features, 1)[0] 187 | artist = self.artists.rank(image_features, 1)[0] 188 | trending = self.trendings.rank(image_features, 1)[0] 189 | movement = self.movements.rank(image_features, 1)[0] 190 | flaves = ", ".join(self.flavors.rank(image_features, max_flavors)) 191 | 192 | if caption.startswith(medium): 193 | prompt = f"{caption} {artist}, {trending}, {movement}, {flaves}" 194 | else: 195 | prompt = f"{caption}, {medium} {artist}, {trending}, {movement}, {flaves}" 196 | 197 | return _truncate_to_fit(prompt, self.tokenize) 198 | 199 | def interrogate_fast(self, image: Image, max_flavors: int = 32) -> str: 200 | caption = self.generate_caption(image) 201 | image_features = self.image_to_features(image) 202 | merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config) 203 | tops = merged.rank(image_features, max_flavors) 204 | return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize) 205 | 206 | def interrogate(self, image: Image, max_flavors: int = 32) -> str: 207 | caption = self.generate_caption(image) 208 | image_features = self.image_to_features(image) 209 | 210 | flaves = self.flavors.rank(image_features, self.config.flavor_intermediate_count) 211 | best_medium = self.mediums.rank(image_features, 1)[0] 212 | best_artist = self.artists.rank(image_features, 1)[0] 213 | best_trending = self.trendings.rank(image_features, 1)[0] 214 | best_movement = self.movements.rank(image_features, 1)[0] 215 | 216 | best_prompt = caption 217 | best_sim = self.similarity(image_features, best_prompt) 218 | 219 | def check(addition: str) -> bool: 220 | nonlocal best_prompt, best_sim 221 | prompt = best_prompt + ", " + addition 222 | sim = self.similarity(image_features, prompt) 223 | if sim > best_sim: 224 | best_sim = sim 225 | best_prompt = prompt 226 | return True 227 | return False 228 | 229 | def check_multi_batch(opts: List[str]): 230 | nonlocal best_prompt, best_sim 231 | prompts = [] 232 | for i in range(2 ** len(opts)): 233 | prompt = best_prompt 234 | for bit in range(len(opts)): 235 | if i & (1 << bit): 236 | prompt += ", " + opts[bit] 237 | prompts.append(prompt) 238 | 239 | t = LabelTable(prompts, None, self.clip_model, self.tokenize, self.config) 240 | best_prompt = t.rank(image_features, 1)[0] 241 | best_sim = self.similarity(image_features, best_prompt) 242 | 243 | check_multi_batch([best_medium, best_artist, best_trending, best_movement]) 244 | 245 | extended_flavors = set(flaves) 246 | for _ in tqdm(range(max_flavors), desc="Flavor chain", disable=self.config.quiet): 247 | best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors]) 248 | flave = best[len(best_prompt) + 2 :] 249 | if not check(flave): 250 | break 251 | if _prompt_at_max_len(best_prompt, self.tokenize): 252 | break 253 | extended_flavors.remove(flave) 254 | 255 | 
return best_prompt 256 | 257 | def rank_top(self, image_features: torch.Tensor, text_array: List[str]) -> str: 258 | text_tokens = self.tokenize([text for text in text_array]).to(self.device) 259 | with torch.no_grad(), torch.cuda.amp.autocast(): 260 | text_features = self.clip_model.encode_text(text_tokens) 261 | text_features /= text_features.norm(dim=-1, keepdim=True) 262 | similarity = text_features @ image_features.T 263 | return text_array[similarity.argmax().item()] 264 | 265 | def similarity(self, image_features: torch.Tensor, text: str) -> float: 266 | text_tokens = self.tokenize([text]).to(self.device) 267 | with torch.no_grad(), torch.cuda.amp.autocast(): 268 | text_features = self.clip_model.encode_text(text_tokens) 269 | text_features /= text_features.norm(dim=-1, keepdim=True) 270 | similarity = text_features @ image_features.T 271 | return similarity[0][0].item() 272 | 273 | 274 | class LabelTable: 275 | def __init__(self, labels: List[str], desc: str, clip_model, tokenize, config: Config): 276 | self.chunk_size = config.chunk_size 277 | self.config = config 278 | self.device = config.device 279 | self.embeds = [] 280 | self.labels = labels 281 | self.tokenize = tokenize 282 | 283 | hash = hashlib.sha256(",".join(labels).encode()).hexdigest() 284 | 285 | cache_filepath = None 286 | if config.cache_path is not None and desc is not None: 287 | os.makedirs(config.cache_path, exist_ok=True) 288 | sanitized_name = config.clip_model_name.replace("/", "_").replace("@", "_") 289 | cache_filepath = os.path.join(config.cache_path, f"{sanitized_name}_{desc}.pkl") 290 | if desc is not None and os.path.exists(cache_filepath): 291 | with open(cache_filepath, "rb") as f: 292 | try: 293 | data = pickle.load(f) 294 | if data.get("hash") == hash: 295 | self.labels = data["labels"] 296 | self.embeds = data["embeds"] 297 | except Exception as e: 298 | print(f"Error loading cached table {desc}: {e}") 299 | 300 | if len(self.labels) != len(self.embeds): 301 | self.embeds = [] 302 | chunks = np.array_split(self.labels, max(1, len(self.labels) / config.chunk_size)) 303 | for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None, disable=self.config.quiet): 304 | text_tokens = self.tokenize(chunk).to(self.device) 305 | with torch.no_grad(), torch.cuda.amp.autocast(): 306 | text_features = clip_model.encode_text(text_tokens) 307 | text_features /= text_features.norm(dim=-1, keepdim=True) 308 | text_features = text_features.half().cpu().numpy() 309 | for i in range(text_features.shape[0]): 310 | self.embeds.append(text_features[i]) 311 | 312 | if cache_filepath is not None: 313 | with open(cache_filepath, "wb") as f: 314 | pickle.dump( 315 | {"labels": self.labels, "embeds": self.embeds, "hash": hash, "model": config.clip_model_name}, 316 | f, 317 | ) 318 | 319 | if self.device == "cpu" or self.device == torch.device("cpu"): 320 | self.embeds = [e.astype(np.float32) for e in self.embeds] 321 | 322 | def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int = 1) -> str: 323 | top_count = min(top_count, len(text_embeds)) 324 | text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device) 325 | with torch.cuda.amp.autocast(): 326 | similarity = image_features @ text_embeds.T 327 | _, top_labels = similarity.float().cpu().topk(top_count, dim=-1) 328 | return [top_labels[0][i].numpy() for i in range(top_count)] 329 | 330 | def rank(self, image_features: torch.Tensor, top_count: int = 1) -> List[str]: 331 | if len(self.labels) <= 
self.chunk_size: 332 | tops = self._rank(image_features, self.embeds, top_count=top_count) 333 | return [self.labels[i] for i in tops] 334 | 335 | num_chunks = int(math.ceil(len(self.labels) / self.chunk_size)) 336 | keep_per_chunk = int(self.chunk_size / num_chunks) 337 | 338 | top_labels, top_embeds = [], [] 339 | for chunk_idx in tqdm(range(num_chunks), disable=self.config.quiet): 340 | start = chunk_idx * self.chunk_size 341 | stop = min(start + self.chunk_size, len(self.embeds)) 342 | tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk) 343 | top_labels.extend([self.labels[start + i] for i in tops]) 344 | top_embeds.extend([self.embeds[start + i] for i in tops]) 345 | 346 | tops = self._rank(image_features, top_embeds, top_count=top_count) 347 | return [top_labels[i] for i in tops] 348 | 349 | 350 | def _load_list(data_path: str, filename: str) -> List[str]: 351 | with open(os.path.join(data_path, filename), "r", encoding="utf-8", errors="replace") as f: 352 | items = [line.strip() for line in f.readlines()] 353 | return items 354 | 355 | 356 | def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable: 357 | m = LabelTable([], None, None, None, config) 358 | for table in tables: 359 | m.labels.extend(table.labels) 360 | m.embeds.extend(table.embeds) 361 | return m 362 | 363 | 364 | def _prompt_at_max_len(text: str, tokenize) -> bool: 365 | tokens = tokenize([text]) 366 | return tokens[0][-1] != 0 367 | 368 | 369 | def _truncate_to_fit(text: str, tokenize) -> str: 370 | parts = text.split(", ") 371 | new_text = parts[0] 372 | for part in parts[1:]: 373 | if _prompt_at_max_len(new_text + part, tokenize): 374 | break 375 | new_text += ", " + part 376 | return new_text 377 | -------------------------------------------------------------------------------- /diffuzers/data/mediums.txt: -------------------------------------------------------------------------------- 1 | a 3D render 2 | a black and white photo 3 | a bronze sculpture 4 | a cartoon 5 | a cave painting 6 | a character portrait 7 | a charcoal drawing 8 | a child's drawing 9 | a color pencil sketch 10 | a colorized photo 11 | a comic book panel 12 | a computer rendering 13 | a cross stitch 14 | a cubist painting 15 | a detailed drawing 16 | a detailed matte painting 17 | a detailed painting 18 | a diagram 19 | a digital painting 20 | a digital rendering 21 | a drawing 22 | a fine art painting 23 | a flemish Baroque 24 | a gouache 25 | a hologram 26 | a hyperrealistic painting 27 | a jigsaw puzzle 28 | a low poly render 29 | a macro photograph 30 | a manga drawing 31 | a marble sculpture 32 | a matte painting 33 | a microscopic photo 34 | a mid-nineteenth century engraving 35 | a minimalist painting 36 | a mosaic 37 | a painting 38 | a pastel 39 | a pencil sketch 40 | a photo 41 | a photocopy 42 | a photorealistic painting 43 | a picture 44 | a pointillism painting 45 | a polaroid photo 46 | a pop art painting 47 | a portrait 48 | a poster 49 | a raytraced image 50 | a renaissance painting 51 | a screenprint 52 | a screenshot 53 | a silk screen 54 | a sketch 55 | a statue 56 | a still life 57 | a stipple 58 | a stock photo 59 | a storybook illustration 60 | a surrealist painting 61 | a surrealist sculpture 62 | a tattoo 63 | a tilt shift photo 64 | a watercolor painting 65 | a wireframe diagram 66 | a woodcut 67 | an abstract drawing 68 | an abstract painting 69 | an abstract sculpture 70 | an acrylic painting 71 | an airbrush painting 72 | an album cover 73 | an ambient occlusion 
render 74 | an anime drawing 75 | an art deco painting 76 | an art deco sculpture 77 | an engraving 78 | an etching 79 | an illustration of 80 | an impressionist painting 81 | an ink drawing 82 | an oil on canvas painting 83 | an oil painting 84 | an ultrafine detailed painting 85 | chalk art 86 | computer graphics 87 | concept art 88 | cyberpunk art 89 | digital art 90 | egyptian art 91 | graffiti art 92 | lineart 93 | pixel art 94 | poster art 95 | vector art 96 | -------------------------------------------------------------------------------- /diffuzers/data/movements.txt: -------------------------------------------------------------------------------- 1 | abstract art 2 | abstract expressionism 3 | abstract illusionism 4 | academic art 5 | action painting 6 | aestheticism 7 | afrofuturism 8 | altermodern 9 | american barbizon school 10 | american impressionism 11 | american realism 12 | american romanticism 13 | american scene painting 14 | analytical art 15 | antipodeans 16 | arabesque 17 | arbeitsrat für kunst 18 | art & language 19 | art brut 20 | art deco 21 | art informel 22 | art nouveau 23 | art photography 24 | arte povera 25 | arts and crafts movement 26 | ascii art 27 | ashcan school 28 | assemblage 29 | australian tonalism 30 | auto-destructive art 31 | barbizon school 32 | baroque 33 | bauhaus 34 | bengal school of art 35 | berlin secession 36 | black arts movement 37 | brutalism 38 | classical realism 39 | cloisonnism 40 | cobra 41 | color field 42 | computer art 43 | conceptual art 44 | concrete art 45 | constructivism 46 | context art 47 | crayon art 48 | crystal cubism 49 | cubism 50 | cubo-futurism 51 | cynical realism 52 | dada 53 | danube school 54 | dau-al-set 55 | de stijl 56 | deconstructivism 57 | digital art 58 | ecological art 59 | environmental art 60 | excessivism 61 | expressionism 62 | fantastic realism 63 | fantasy art 64 | fauvism 65 | feminist art 66 | figuration libre 67 | figurative art 68 | figurativism 69 | fine art 70 | fluxus 71 | folk art 72 | funk art 73 | furry art 74 | futurism 75 | generative art 76 | geometric abstract art 77 | german romanticism 78 | gothic art 79 | graffiti 80 | gutai group 81 | happening 82 | harlem renaissance 83 | heidelberg school 84 | holography 85 | hudson river school 86 | hurufiyya 87 | hypermodernism 88 | hyperrealism 89 | impressionism 90 | incoherents 91 | institutional critique 92 | interactive art 93 | international gothic 94 | international typographic style 95 | kinetic art 96 | kinetic pointillism 97 | kitsch movement 98 | land art 99 | les automatistes 100 | les nabis 101 | letterism 102 | light and space 103 | lowbrow 104 | lyco art 105 | lyrical abstraction 106 | magic realism 107 | magical realism 108 | mail art 109 | mannerism 110 | massurrealism 111 | maximalism 112 | metaphysical painting 113 | mingei 114 | minimalism 115 | modern european ink painting 116 | modernism 117 | modular constructivism 118 | naive art 119 | naturalism 120 | neo-dada 121 | neo-expressionism 122 | neo-fauvism 123 | neo-figurative 124 | neo-primitivism 125 | neo-romanticism 126 | neoclassicism 127 | neogeo 128 | neoism 129 | neoplasticism 130 | net art 131 | new objectivity 132 | new sculpture 133 | northwest school 134 | nuclear art 135 | objective abstraction 136 | op art 137 | optical illusion 138 | orphism 139 | panfuturism 140 | paris school 141 | photorealism 142 | pixel art 143 | plasticien 144 | plein air 145 | pointillism 146 | pop art 147 | pop surrealism 148 | post-impressionism 149 | postminimalism 150 | 
pre-raphaelitism 151 | precisionism 152 | primitivism 153 | private press 154 | process art 155 | psychedelic art 156 | purism 157 | qajar art 158 | quito school 159 | rasquache 160 | rayonism 161 | realism 162 | regionalism 163 | remodernism 164 | renaissance 165 | retrofuturism 166 | rococo 167 | romanesque 168 | romanticism 169 | samikshavad 170 | serial art 171 | shin hanga 172 | shock art 173 | socialist realism 174 | sots art 175 | space art 176 | street art 177 | stuckism 178 | sumatraism 179 | superflat 180 | suprematism 181 | surrealism 182 | symbolism 183 | synchromism 184 | synthetism 185 | sōsaku hanga 186 | tachisme 187 | temporary art 188 | tonalism 189 | toyism 190 | transgressive art 191 | ukiyo-e 192 | underground comix 193 | unilalianism 194 | vancouver school 195 | vanitas 196 | verdadism 197 | video art 198 | viennese actionism 199 | visual art 200 | vorticism 201 | -------------------------------------------------------------------------------- /diffuzers/gfp_gan.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import os 3 | import shutil 4 | from dataclasses import dataclass 5 | from io import BytesIO 6 | from typing import Optional 7 | 8 | import cv2 9 | import numpy as np 10 | import streamlit as st 11 | import torch 12 | from basicsr.archs.srvgg_arch import SRVGGNetCompact 13 | from gfpgan.utils import GFPGANer 14 | from loguru import logger 15 | from PIL import Image 16 | from realesrgan.utils import RealESRGANer 17 | 18 | from diffuzers import utils 19 | 20 | 21 | @dataclass 22 | class GFPGAN: 23 | device: Optional[str] = None 24 | output_path: Optional[str] = None 25 | 26 | def __str__(self) -> str: 27 | return f"GFPGAN(device={self.device}, output_path={self.output_path})" 28 | 29 | def __post_init__(self): 30 | files = { 31 | "realesr-general-x4v3.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", 32 | "v1.2": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth", 33 | "v1.3": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth", 34 | "v1.4": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth", 35 | "RestoreFormer": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth", 36 | "CodeFormer": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth", 37 | } 38 | _ = utils.cache_folder() 39 | self.model_paths = {} 40 | for file_key, file in files.items(): 41 | logger.info(f"Downloading {file_key} from {file}") 42 | basename = os.path.basename(file) 43 | output_path = os.path.join(utils.cache_folder(), basename) 44 | if os.path.exists(output_path): 45 | self.model_paths[file_key] = output_path 46 | continue 47 | temp_file = utils.download_file(file) 48 | shutil.move(temp_file, output_path) 49 | self.model_paths[file_key] = output_path 50 | 51 | self.model = SRVGGNetCompact( 52 | num_in_ch=3, 53 | num_out_ch=3, 54 | num_feat=64, 55 | num_conv=32, 56 | upscale=4, 57 | act_type="prelu", 58 | ) 59 | model_path = os.path.join(utils.cache_folder(), self.model_paths["realesr-general-x4v3.pth"]) 60 | half = True if torch.cuda.is_available() else False 61 | self.upsampler = RealESRGANer( 62 | scale=4, 63 | model_path=model_path, 64 | model=self.model, 65 | tile=0, 66 | tile_pad=10, 67 | pre_pad=0, 68 | half=half, 69 | ) 70 | 71 | def inference(self, img, version, scale): 72 | # taken from: 
https://huggingface.co/spaces/Xintao/GFPGAN/blob/main/app.py 73 | # weight /= 100 74 | if scale > 4: 75 | scale = 4 # avoid too large scale value 76 | 77 | file_bytes = np.asarray(bytearray(img.read()), dtype=np.uint8) 78 | img = cv2.imdecode(file_bytes, 1) 79 | # img = cv2.imread(img, cv2.IMREAD_UNCHANGED) 80 | if len(img.shape) == 2: # for gray inputs 81 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 82 | 83 | h, w = img.shape[0:2] 84 | if h < 300: 85 | img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4) 86 | 87 | if version == "v1.2": 88 | face_enhancer = GFPGANer( 89 | model_path=self.model_paths["v1.2"], 90 | upscale=2, 91 | arch="clean", 92 | channel_multiplier=2, 93 | bg_upsampler=self.upsampler, 94 | ) 95 | elif version == "v1.3": 96 | face_enhancer = GFPGANer( 97 | model_path=self.model_paths["v1.3"], 98 | upscale=2, 99 | arch="clean", 100 | channel_multiplier=2, 101 | bg_upsampler=self.upsampler, 102 | ) 103 | elif version == "v1.4": 104 | face_enhancer = GFPGANer( 105 | model_path=self.model_paths["v1.4"], 106 | upscale=2, 107 | arch="clean", 108 | channel_multiplier=2, 109 | bg_upsampler=self.upsampler, 110 | ) 111 | elif version == "RestoreFormer": 112 | face_enhancer = GFPGANer( 113 | model_path=self.model_paths["RestoreFormer"], 114 | upscale=2, 115 | arch="RestoreFormer", 116 | channel_multiplier=2, 117 | bg_upsampler=self.upsampler, 118 | ) 119 | # elif version == 'CodeFormer': 120 | # face_enhancer = GFPGANer( 121 | # model_path='CodeFormer.pth', upscale=2, arch='CodeFormer', channel_multiplier=2, bg_upsampler=upsampler) 122 | try: 123 | # _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True, weight=weight) 124 | _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True) 125 | except RuntimeError as error: 126 | logger.error("Error", error) 127 | 128 | try: 129 | if scale != 2: 130 | interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4 131 | h, w = img.shape[0:2] 132 | output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation) 133 | except Exception as error: 134 | logger.error("wrong scale input.", error) 135 | 136 | output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB) 137 | return output 138 | 139 | def app(self): 140 | input_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"]) 141 | if input_image is not None: 142 | st.image(input_image, width=512) 143 | with st.form(key="gfpgan"): 144 | version = st.selectbox("GFPGAN version", ["v1.2", "v1.3", "v1.4", "RestoreFormer"]) 145 | scale = st.slider("Scale", 2, 4, 4, 1) 146 | submit = st.form_submit_button("Upscale") 147 | if submit: 148 | if input_image is not None: 149 | with st.spinner("Upscaling image..."): 150 | output_img = self.inference(input_image, version, scale) 151 | st.image(output_img, width=512) 152 | # add image download button 153 | output_img = Image.fromarray(output_img) 154 | buffered = BytesIO() 155 | output_img.save(buffered, format="PNG") 156 | img_str = base64.b64encode(buffered.getvalue()).decode() 157 | href = f'<a href="data:image/png;base64,{img_str}" download>Download Image</a>' 158 | st.markdown(href, unsafe_allow_html=True) 159 | -------------------------------------------------------------------------------- /diffuzers/gradio_app.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | 5 | import gradio as gr 6 | import torch 7 | from PIL.PngImagePlugin
import PngInfo 8 | 9 | from .img2img import Img2Img 10 | from .text2img import Text2Image 11 | 12 | 13 | @dataclass 14 | class Diffuzers: 15 | model: str 16 | output_path: str 17 | img2img_model: Optional[str] = None 18 | inpainting_model: Optional[str] = None 19 | 20 | def __post_init__(self): 21 | device = "cuda" if torch.cuda.is_available() else "cpu" 22 | self.text2img = Text2Image( 23 | model=self.model, 24 | device=device, 25 | ) 26 | self.img2img = Img2Img( 27 | model=self.img2img_model, 28 | device=device, 29 | text2img_model=self.text2img.pipeline, 30 | ) 31 | os.makedirs(self.output_path, exist_ok=True) 32 | os.makedirs(os.path.join(self.output_path, "text2img"), exist_ok=True) 33 | os.makedirs(os.path.join(self.output_path, "img2img"), exist_ok=True) 34 | 35 | def _text2img_input(self): 36 | with gr.Column(): 37 | project_name_text = "Project name (optional; used to save the images, if provided)" 38 | project_name = gr.Textbox(label=project_name_text, lines=1, max_lines=1, placeholder="my_project") 39 | # TODO: add batch support 40 | # with gr.Tabs(): 41 | # with gr.TabItem("single"): 42 | prompt = gr.Textbox(label="Prompt", lines=3, max_lines=3) 43 | # with gr.TabItem("batch"): 44 | # prompt = gr.File(file_types=["text"]) 45 | # with gr.Tabs(): 46 | # with gr.TabItem("single"): 47 | negative_prompt = gr.Textbox(label="Negative prompt (optional)", lines=3, max_lines=3) 48 | # with gr.TabItem("batch"): 49 | # negative_prompt = gr.File(file_types=["text"]) 50 | with gr.Column(): 51 | available_schedulers = list(self.text2img.compatible_schedulers.keys()) 52 | scheduler = gr.Dropdown(choices=available_schedulers, label="Scheduler", value=available_schedulers[0]) 53 | image_size = gr.Number(512, label="Image size", precision=0) 54 | guidance_scale = gr.Slider(1, maximum=20, value=7.5, step=0.5, label="Guidance scale") 55 | num_images = gr.Slider(1, 128, 1, label="Number of images per prompt", step=4) 56 | steps = gr.Slider(1, 150, 50, label="Steps") 57 | seed = gr.Slider(minimum=1, step=1, maximum=999999, randomize=True, label="Seed") 58 | generate_button = gr.Button("Generate") 59 | params_dict = { 60 | "prompt": prompt, 61 | "negative_prompt": negative_prompt, 62 | "scheduler": scheduler, 63 | "image_size": image_size, 64 | "guidance_scale": guidance_scale, 65 | "num_images": num_images, 66 | "steps": steps, 67 | "seed": seed, 68 | "project_name": project_name, 69 | "generate_button": generate_button, 70 | } 71 | return params_dict 72 | 73 | def _text2img_output(self): 74 | with gr.Column(): 75 | text2img_output = gr.Gallery() 76 | text2img_output.style(grid=[4], container=False) 77 | with gr.Column(): 78 | text2img_output_params = gr.Markdown() 79 | params_dict = { 80 | "output": text2img_output, 81 | "markdown": text2img_output_params, 82 | } 83 | return params_dict 84 | 85 | def _text2img_generate( 86 | self, prompt, negative_prompt, scheduler, image_size, guidance_scale, num_images, steps, seed, project_name 87 | ): 88 | output_images = self.text2img.generate_image( 89 | prompt=prompt, 90 | negative_prompt=negative_prompt, 91 | scheduler=scheduler, 92 | image_size=image_size, 93 | guidance_scale=guidance_scale, 94 | steps=steps, 95 | seed=seed, 96 | num_images=num_images, 97 | ) 98 | params_used = f" - **Prompt:** {prompt}" 99 | params_used += f"\n - **Negative prompt:** {negative_prompt}" 100 | params_used += f"\n - **Scheduler:** {scheduler}" 101 | params_used += f"\n - **Image size:** {image_size}" 102 | params_used += f"\n - **Guidance scale:** {guidance_scale}" 103 | 
params_used += f"\n - **Steps:** {steps}" 104 | params_used += f"\n - **Seed:** {seed}" 105 | 106 | if len(project_name.strip()) > 0: 107 | self._save_images( 108 | images=output_images, 109 | metadata=params_used, 110 | folder_name=project_name, 111 | prefix="text2img", 112 | ) 113 | 114 | return [output_images, params_used] 115 | 116 | def _img2img_input(self): 117 | with gr.Column(): 118 | input_image = gr.Image(source="upload", label="input image | size must match model", type="pil") 119 | strength = gr.Slider(0, 1, 0.8, label="Strength") 120 | available_schedulers = list(self.img2img.compatible_schedulers.keys()) 121 | scheduler = gr.Dropdown(choices=available_schedulers, label="Scheduler", value=available_schedulers[0]) 122 | image_size = gr.Number(512, label="Image size (image will be resized to this)", precision=0) 123 | guidance_scale = gr.Slider(1, maximum=20, value=7.5, step=0.5, label="Guidance scale") 124 | num_images = gr.Slider(4, 128, 4, label="Number of images", step=4) 125 | steps = gr.Slider(1, 150, 50, label="Steps") 126 | with gr.Column(): 127 | project_name_text = "Project name (optional; used to save the images, if provided)" 128 | project_name = gr.Textbox(label=project_name_text, lines=1, max_lines=1, placeholder="my_project") 129 | prompt = gr.Textbox(label="Prompt", lines=3, max_lines=3) 130 | negative_prompt = gr.Textbox(label="Negative prompt (optional)", lines=3, max_lines=3) 131 | seed = gr.Slider(minimum=1, step=1, maximum=999999, randomize=True, label="Seed") 132 | generate_button = gr.Button("Generate") 133 | params_dict = { 134 | "input_image": input_image, 135 | "prompt": prompt, 136 | "negative_prompt": negative_prompt, 137 | "strength": strength, 138 | "scheduler": scheduler, 139 | "image_size": image_size, 140 | "guidance_scale": guidance_scale, 141 | "num_images": num_images, 142 | "steps": steps, 143 | "seed": seed, 144 | "project_name": project_name, 145 | "generate_button": generate_button, 146 | } 147 | return params_dict 148 | 149 | def _img2img_output(self): 150 | with gr.Column(): 151 | img2img_output = gr.Gallery() 152 | img2img_output.style(grid=[4], container=False) 153 | with gr.Column(): 154 | img2img_output_params = gr.Markdown() 155 | params_dict = { 156 | "output": img2img_output, 157 | "markdown": img2img_output_params, 158 | } 159 | return params_dict 160 | 161 | def _img2img_generate( 162 | self, 163 | input_image, 164 | prompt, 165 | negative_prompt, 166 | strength, 167 | scheduler, 168 | image_size, 169 | guidance_scale, 170 | num_images, 171 | steps, 172 | seed, 173 | project_name, 174 | ): 175 | output_images = self.img2img.generate_image( 176 | image=input_image, 177 | prompt=prompt, 178 | negative_prompt=negative_prompt, 179 | strength=strength, 180 | scheduler=scheduler, 181 | image_size=image_size, 182 | guidance_scale=guidance_scale, 183 | steps=steps, 184 | seed=seed, 185 | num_images=num_images, 186 | ) 187 | params_used = f" - **Prompt:** {prompt}" 188 | params_used += f"\n - **Negative prompt:** {negative_prompt}" 189 | params_used += f"\n - **Scheduler:** {scheduler}" 190 | params_used += f"\n - **Strength:** {strength}" 191 | params_used += f"\n - **Image size:** {image_size}" 192 | params_used += f"\n - **Guidance scale:** {guidance_scale}" 193 | params_used += f"\n - **Steps:** {steps}" 194 | params_used += f"\n - **Seed:** {seed}" 195 | 196 | if len(project_name.strip()) > 0: 197 | self._save_images( 198 | images=output_images, 199 | metadata=params_used, 200 | folder_name=project_name, 201 | prefix="img2img", 202 
| ) 203 | 204 | return [output_images, params_used] 205 | 206 | def _save_images(self, images, metadata, folder_name, prefix): 207 | folder_path = os.path.join(self.output_path, prefix, folder_name) 208 | os.makedirs(folder_path, exist_ok=True) 209 | for idx, image in enumerate(images): 210 | text2img_metadata = PngInfo() 211 | text2img_metadata.add_text(prefix, metadata) 212 | image.save(os.path.join(folder_path, f"{idx}.png"), format="PNG", pnginfo=text2img_metadata) 213 | with open(os.path.join(folder_path, "metadata.txt"), "w") as f: 214 | f.write(metadata) 215 | 216 | def _png_info(self, img): 217 | text2img_md = img.info.get("text2img", "") 218 | img2img_md = img.info.get("img2img", "") 219 | return_text = "" 220 | if len(text2img_md) > 0: 221 | return_text += f"## Text2Img\n{text2img_md}\n" 222 | if len(img2img_md) > 0: 223 | return_text += f"## Img2Img\n{img2img_md}\n" 224 | return return_text 225 | 226 | def app(self): 227 | with gr.Blocks() as demo: 228 | with gr.Blocks(): 229 | gr.Markdown("# Diffuzers") 230 | gr.Markdown(f"Text2Img Model: {self.model}") 231 | if self.img2img_model: 232 | gr.Markdown(f"Img2Img Model: {self.img2img_model}") 233 | else: 234 | gr.Markdown(f"Img2Img Model: {self.model}") 235 | 236 | with gr.Tabs(): 237 | with gr.TabItem("Text2Image", id="text2image"): 238 | with gr.Row(): 239 | text2img_input = self._text2img_input() 240 | with gr.Row(): 241 | text2img_output = self._text2img_output() 242 | text2img_input["generate_button"].click( 243 | fn=self._text2img_generate, 244 | inputs=[ 245 | text2img_input["prompt"], 246 | text2img_input["negative_prompt"], 247 | text2img_input["scheduler"], 248 | text2img_input["image_size"], 249 | text2img_input["guidance_scale"], 250 | text2img_input["num_images"], 251 | text2img_input["steps"], 252 | text2img_input["seed"], 253 | text2img_input["project_name"], 254 | ], 255 | outputs=[text2img_output["output"], text2img_output["markdown"]], 256 | ) 257 | with gr.TabItem("Image2Image", id="img2img"): 258 | with gr.Row(): 259 | img2img_input = self._img2img_input() 260 | with gr.Row(): 261 | img2img_output = self._img2img_output() 262 | img2img_input["generate_button"].click( 263 | fn=self._img2img_generate, 264 | inputs=[ 265 | img2img_input["input_image"], 266 | img2img_input["prompt"], 267 | img2img_input["negative_prompt"], 268 | img2img_input["strength"], 269 | img2img_input["scheduler"], 270 | img2img_input["image_size"], 271 | img2img_input["guidance_scale"], 272 | img2img_input["num_images"], 273 | img2img_input["steps"], 274 | img2img_input["seed"], 275 | img2img_input["project_name"], 276 | ], 277 | outputs=[img2img_output["output"], img2img_output["markdown"]], 278 | ) 279 | with gr.TabItem("Inpainting", id="inpainting"): 280 | gr.Markdown("# coming soon!") 281 | with gr.TabItem("ImageInfo", id="imginfo"): 282 | with gr.Column(): 283 | img_info_input_file = gr.Image(label="Input image", source="upload", type="pil") 284 | with gr.Column(): 285 | img_info_output_md = gr.Markdown() 286 | img_info_input_file.change( 287 | fn=self._png_info, 288 | inputs=[img_info_input_file], 289 | outputs=[img_info_output_md], 290 | ) 291 | 292 | return demo 293 | -------------------------------------------------------------------------------- /diffuzers/image_info.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import streamlit as st 4 | from PIL import Image 5 | 6 | 7 | @dataclass 8 | class ImageInfo: 9 | def app(self): 10 | # upload image 11 | 
uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"]) 12 | if uploaded_file is not None: 13 | # read image using pil 14 | pil_image = Image.open(uploaded_file) 15 | st.image(uploaded_file, use_column_width=True) 16 | image_info = pil_image.info 17 | # display image info 18 | st.write(image_info) 19 | -------------------------------------------------------------------------------- /diffuzers/img2img.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import json 3 | from dataclasses import dataclass 4 | from io import BytesIO 5 | from typing import Optional, Union 6 | 7 | import requests 8 | import streamlit as st 9 | import torch 10 | from diffusers import ( 11 | AltDiffusionImg2ImgPipeline, 12 | AltDiffusionPipeline, 13 | DiffusionPipeline, 14 | StableDiffusionImg2ImgPipeline, 15 | StableDiffusionPipeline, 16 | ) 17 | from loguru import logger 18 | from PIL import Image 19 | from PIL.PngImagePlugin import PngInfo 20 | 21 | from diffuzers import utils 22 | 23 | 24 | @dataclass 25 | class Img2Img: 26 | model: Optional[str] = None 27 | device: Optional[str] = None 28 | output_path: Optional[str] = None 29 | text2img_model: Optional[Union[StableDiffusionPipeline, AltDiffusionPipeline]] = None 30 | 31 | def __str__(self) -> str: 32 | return f"Img2Img(model={self.model}, device={self.device}, output_path={self.output_path})" 33 | 34 | def __post_init__(self): 35 | if self.model is not None: 36 | self.text2img_model = DiffusionPipeline.from_pretrained( 37 | self.model, 38 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, 39 | ) 40 | components = self.text2img_model.components 41 | print(components) 42 | if isinstance(self.text2img_model, StableDiffusionPipeline): 43 | self.pipeline = StableDiffusionImg2ImgPipeline(**components) 44 | elif isinstance(self.text2img_model, AltDiffusionPipeline): 45 | self.pipeline = AltDiffusionImg2ImgPipeline(**components) 46 | else: 47 | raise ValueError("Model type not supported") 48 | 49 | self.pipeline.to(self.device) 50 | self.pipeline.safety_checker = utils.no_safety_checker 51 | self._compatible_schedulers = self.pipeline.scheduler.compatibles 52 | self.scheduler_config = self.pipeline.scheduler.config 53 | self.compatible_schedulers = {scheduler.__name__: scheduler for scheduler in self._compatible_schedulers} 54 | 55 | if self.device == "mps": 56 | self.pipeline.enable_attention_slicing() 57 | # warmup 58 | url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" 59 | response = requests.get(url) 60 | init_image = Image.open(BytesIO(response.content)).convert("RGB") 61 | init_image.thumbnail((768, 768)) 62 | prompt = "A fantasy landscape, trending on artstation" 63 | _ = self.pipeline( 64 | prompt=prompt, 65 | image=init_image, 66 | strength=0.75, 67 | guidance_scale=7.5, 68 | num_inference_steps=2, 69 | ) 70 | 71 | def _set_scheduler(self, scheduler_name): 72 | scheduler = self.compatible_schedulers[scheduler_name].from_config(self.scheduler_config) 73 | self.pipeline.scheduler = scheduler 74 | 75 | def generate_image( 76 | self, prompt, image, strength, negative_prompt, scheduler, num_images, guidance_scale, steps, seed 77 | ): 78 | self._set_scheduler(scheduler) 79 | logger.info(self.pipeline.scheduler) 80 | if self.device == "mps": 81 | generator = torch.manual_seed(seed) 82 | num_images = 1 83 | else: 84 | generator = torch.Generator(device=self.device).manual_seed(seed) 85 | 
num_images = int(num_images) 86 | output_images = self.pipeline( 87 | prompt=prompt, 88 | image=image, 89 | strength=strength, 90 | negative_prompt=negative_prompt, 91 | num_inference_steps=steps, 92 | guidance_scale=guidance_scale, 93 | num_images_per_prompt=num_images, 94 | generator=generator, 95 | ).images 96 | torch.cuda.empty_cache() 97 | gc.collect() 98 | metadata = { 99 | "prompt": prompt, 100 | "negative_prompt": negative_prompt, 101 | "scheduler": scheduler, 102 | "num_images": num_images, 103 | "guidance_scale": guidance_scale, 104 | "steps": steps, 105 | "seed": seed, 106 | } 107 | metadata = json.dumps(metadata) 108 | _metadata = PngInfo() 109 | _metadata.add_text("img2img", metadata) 110 | 111 | utils.save_images( 112 | images=output_images, 113 | module="img2img", 114 | metadata=metadata, 115 | output_path=self.output_path, 116 | ) 117 | return output_images, _metadata 118 | 119 | def app(self): 120 | available_schedulers = list(self.compatible_schedulers.keys()) 121 | # if EulerAncestralDiscreteScheduler is available in available_schedulers, move it to the first position 122 | if "EulerAncestralDiscreteScheduler" in available_schedulers: 123 | available_schedulers.insert( 124 | 0, available_schedulers.pop(available_schedulers.index("EulerAncestralDiscreteScheduler")) 125 | ) 126 | 127 | input_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"]) 128 | if input_image is not None: 129 | input_image = Image.open(input_image) 130 | st.image(input_image, use_column_width=True) 131 | 132 | # with st.form(key="img2img"): 133 | col1, col2 = st.columns(2) 134 | with col1: 135 | prompt = st.text_area("Prompt", "") 136 | with col2: 137 | negative_prompt = st.text_area("Negative Prompt", "") 138 | 139 | scheduler = st.sidebar.selectbox("Scheduler", available_schedulers, index=0) 140 | guidance_scale = st.sidebar.slider("Guidance scale", 1.0, 40.0, 7.5, 0.5) 141 | strength = st.sidebar.slider("Strength", 0.0, 1.0, 0.8, 0.05) 142 | num_images = st.sidebar.slider("Number of images per prompt", 1, 30, 1, 1) 143 | steps = st.sidebar.slider("Steps", 1, 150, 50, 1) 144 | seed_placeholder = st.sidebar.empty() 145 | seed = seed_placeholder.number_input("Seed", value=42, min_value=1, max_value=999999, step=1) 146 | random_seed = st.sidebar.button("Random seed") 147 | _seed = torch.randint(1, 999999, (1,)).item() 148 | if random_seed: 149 | seed = seed_placeholder.number_input("Seed", value=_seed, min_value=1, max_value=999999, step=1) 150 | # seed = st.sidebar.number_input("Seed", 1, 999999, 1, 1) 151 | sub_col, download_col = st.columns(2) 152 | with sub_col: 153 | submit = st.button("Generate") 154 | 155 | if submit: 156 | with st.spinner("Generating images..."): 157 | output_images, metadata = self.generate_image( 158 | prompt=prompt, 159 | image=input_image, 160 | negative_prompt=negative_prompt, 161 | scheduler=scheduler, 162 | num_images=num_images, 163 | guidance_scale=guidance_scale, 164 | steps=steps, 165 | seed=seed, 166 | strength=strength, 167 | ) 168 | 169 | utils.display_and_download_images(output_images, metadata, download_col) 170 | -------------------------------------------------------------------------------- /diffuzers/inpainting.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import json 3 | import random 4 | from dataclasses import dataclass 5 | from io import BytesIO 6 | from typing import Optional 7 | 8 | import requests 9 | import streamlit as st 10 | import torch 11 | from diffusers import 
StableDiffusionInpaintPipeline 12 | from loguru import logger 13 | from PIL import Image 14 | from PIL.PngImagePlugin import PngInfo 15 | from streamlit_drawable_canvas import st_canvas 16 | 17 | from diffuzers import utils 18 | 19 | 20 | @dataclass 21 | class Inpainting: 22 | model: Optional[str] = None 23 | device: Optional[str] = None 24 | output_path: Optional[str] = None 25 | 26 | def __str__(self) -> str: 27 | return f"Inpainting(model={self.model}, device={self.device}, output_path={self.output_path})" 28 | 29 | def __post_init__(self): 30 | self.pipeline = StableDiffusionInpaintPipeline.from_pretrained( 31 | self.model, 32 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, 33 | use_auth_token=utils.use_auth_token(), 34 | ) 35 | 36 | self.pipeline.to(self.device) 37 | self.pipeline.safety_checker = utils.no_safety_checker 38 | self._compatible_schedulers = self.pipeline.scheduler.compatibles 39 | self.scheduler_config = self.pipeline.scheduler.config 40 | self.compatible_schedulers = {scheduler.__name__: scheduler for scheduler in self._compatible_schedulers} 41 | 42 | if self.device == "mps": 43 | self.pipeline.enable_attention_slicing() 44 | # warmup 45 | 46 | def download_image(url): 47 | response = requests.get(url) 48 | return Image.open(BytesIO(response.content)).convert("RGB") 49 | 50 | img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" 51 | mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" 52 | 53 | init_image = download_image(img_url).resize((512, 512)) 54 | mask_image = download_image(mask_url).resize((512, 512)) 55 | 56 | prompt = "Face of a yellow cat, high resolution, sitting on a park bench" 57 | _ = self.pipeline( 58 | prompt=prompt, 59 | image=init_image, 60 | mask_image=mask_image, 61 | num_inference_steps=2, 62 | ) 63 | 64 | def _set_scheduler(self, scheduler_name): 65 | scheduler = self.compatible_schedulers[scheduler_name].from_config(self.scheduler_config) 66 | self.pipeline.scheduler = scheduler 67 | 68 | def generate_image( 69 | self, prompt, negative_prompt, image, mask, guidance_scale, scheduler, steps, seed, height, width, num_images 70 | ): 71 | 72 | if seed == -1: 73 | # generate random seed 74 | seed = random.randint(0, 999999) 75 | 76 | self._set_scheduler(scheduler) 77 | logger.info(self.pipeline.scheduler) 78 | 79 | if self.device == "mps": 80 | generator = torch.manual_seed(seed) 81 | num_images = 1 82 | else: 83 | generator = torch.Generator(device=self.device).manual_seed(seed) 84 | 85 | output_images = self.pipeline( 86 | prompt=prompt, 87 | negative_prompt=negative_prompt, 88 | image=image, 89 | mask_image=mask, 90 | num_inference_steps=steps, 91 | guidance_scale=guidance_scale, 92 | num_images_per_prompt=num_images, 93 | generator=generator, 94 | height=height, 95 | width=width, 96 | ).images 97 | metadata = { 98 | "prompt": prompt, 99 | "negative_prompt": negative_prompt, 100 | "guidance_scale": guidance_scale, 101 | "scheduler": scheduler, 102 | "steps": steps, 103 | "seed": seed, 104 | } 105 | metadata = json.dumps(metadata) 106 | _metadata = PngInfo() 107 | _metadata.add_text("inpainting", metadata) 108 | 109 | utils.save_images( 110 | images=output_images, 111 | module="inpainting", 112 | metadata=metadata, 113 | output_path=self.output_path, 114 | ) 115 | 116 | torch.cuda.empty_cache() 117 | gc.collect() 118 | return output_images, _metadata 
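# generate_image() (above) draws a random seed when seed == -1, applies the scheduler chosen
# in the UI, runs the inpainting pipeline on the uploaded image and drawn mask, stores the
# generation parameters as JSON PNG-info under the "inpainting" key, and saves the outputs via
# utils.save_images(). app() (below) builds the Streamlit prompt fields, sidebar settings, and
# drawable mask canvas that feed it.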
119 | 120 | def app(self): 121 | stroke_color = "#FFF" 122 | bg_color = "#000" 123 | col1, col2 = st.columns(2) 124 | # with col1: 125 | with col1: 126 | prompt = st.text_area("Prompt", "", key="inpainting_prompt", help="Prompt for the image generation") 127 | # with col2: 128 | negative_prompt = st.text_area( 129 | "Negative Prompt", 130 | "", 131 | key="inpainting_negative_prompt", 132 | help="The prompt not to guide image generation. Write things that you don't want to see in the image.", 133 | ) 134 | with col2: 135 | uploaded_file = st.file_uploader( 136 | "Image:", 137 | type=["png", "jpg", "jpeg"], 138 | help="Image size must match model's image size. Usually: 512 or 768", 139 | key="inpainting_uploaded_file", 140 | ) 141 | 142 | # sidebar options 143 | available_schedulers = list(self.compatible_schedulers.keys()) 144 | if "EulerAncestralDiscreteScheduler" in available_schedulers: 145 | available_schedulers.insert( 146 | 0, available_schedulers.pop(available_schedulers.index("EulerAncestralDiscreteScheduler")) 147 | ) 148 | scheduler = st.sidebar.selectbox( 149 | "Scheduler", 150 | available_schedulers, 151 | index=0, 152 | key="inpainting_scheduler", 153 | help="Scheduler to use for generation", 154 | ) 155 | guidance_scale = st.sidebar.slider( 156 | "Guidance scale", 157 | 1.0, 158 | 40.0, 159 | 7.5, 160 | 0.5, 161 | key="inpainting_guidance_scale", 162 | help="Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.", 163 | ) 164 | num_images = st.sidebar.slider( 165 | "Number of images per prompt", 166 | 1, 167 | 30, 168 | 1, 169 | 1, 170 | key="inpainting_num_images", 171 | help="Number of images you want to generate. More images require more time and use more GPU memory.", 172 | ) 173 | steps = st.sidebar.slider( 174 | "Steps", 175 | 1, 176 | 150, 177 | 50, 178 | 1, 179 | key="inpainting_steps", 180 | help="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.", 181 | ) 182 | seed = st.sidebar.number_input( 183 | "Seed", 184 | value=42, 185 | min_value=-1, 186 | max_value=999999, 187 | step=1, 188 | help="Random seed. 
Change for different results using the same parameters.", 189 | ) 190 | 191 | if uploaded_file is not None: 192 | with col2: 193 | drawing_mode = st.selectbox( 194 | "Drawing tool:", ("freedraw", "rect", "circle"), key="inpainting_drawing_mode" 195 | ) 196 | stroke_width = st.slider("Stroke width: ", 1, 25, 8, key="inpainting_stroke_width") 197 | 198 | pil_image = Image.open(uploaded_file).convert("RGB") 199 | img_width, img_height = pil_image.size # PIL's Image.size is (width, height) 200 | canvas_result = st_canvas( 201 | fill_color="rgb(255, 255, 255)", 202 | stroke_width=stroke_width, 203 | stroke_color=stroke_color, 204 | background_color=bg_color, 205 | background_image=pil_image, 206 | update_streamlit=True, 207 | drawing_mode=drawing_mode, 208 | height=768, 209 | width=768, 210 | key="inpainting_canvas", 211 | ) 212 | 213 | with col1: 214 | submit = st.button("Generate", key="inpainting_submit") 215 | if ( 216 | canvas_result.image_data is not None 217 | and pil_image 218 | and len(canvas_result.json_data["objects"]) > 0 219 | and submit 220 | ): 221 | mask_npy = canvas_result.image_data[:, :, 3] 222 | # convert mask npy to PIL image 223 | mask_pil = Image.fromarray(mask_npy).convert("RGB") 224 | # resize mask to match image size 225 | mask_pil = mask_pil.resize((img_width, img_height), resample=Image.LANCZOS) 226 | with st.spinner("Generating..."): 227 | output_images, metadata = self.generate_image( 228 | prompt=prompt, 229 | negative_prompt=negative_prompt, 230 | image=pil_image, 231 | mask=mask_pil, 232 | guidance_scale=guidance_scale, 233 | scheduler=scheduler, 234 | steps=steps, 235 | seed=seed, 236 | height=img_height, 237 | width=img_width, 238 | num_images=num_images, 239 | ) 240 | 241 | utils.display_and_download_images(output_images, metadata) 242 | -------------------------------------------------------------------------------- /diffuzers/interrogator.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | 5 | import streamlit as st 6 | from clip_interrogator import Config, Interrogator 7 | from huggingface_hub import hf_hub_download 8 | from loguru import logger 9 | from PIL import Image 10 | 11 | from diffuzers import utils 12 | 13 | 14 | @dataclass 15 | class ImageInterrogator: 16 | device: Optional[str] = None 17 | output_path: Optional[str] = None 18 | 19 | def inference(self, model, image, mode): 20 | preprocess_files = [ 21 | "ViT-H-14_laion2b_s32b_b79k_artists.pkl", 22 | "ViT-H-14_laion2b_s32b_b79k_flavors.pkl", 23 | "ViT-H-14_laion2b_s32b_b79k_mediums.pkl", 24 | "ViT-H-14_laion2b_s32b_b79k_movements.pkl", 25 | "ViT-H-14_laion2b_s32b_b79k_trendings.pkl", 26 | "ViT-L-14_openai_artists.pkl", 27 | "ViT-L-14_openai_flavors.pkl", 28 | "ViT-L-14_openai_mediums.pkl", 29 | "ViT-L-14_openai_movements.pkl", 30 | "ViT-L-14_openai_trendings.pkl", 31 | ] 32 | 33 | logger.info("Downloading preprocessed cache files...") 34 | for file in preprocess_files: 35 | path = hf_hub_download(repo_id="pharma/ci-preprocess", filename=file, cache_dir=utils.cache_folder()) 36 | cache_path = os.path.dirname(path) 37 | 38 | config = Config(cache_path=cache_path, clip_model_path=utils.cache_folder(), clip_model_name=model) 39 | pipeline = Interrogator(config) 40 | 41 | pipeline.config.blip_num_beams = 64 42 | pipeline.config.chunk_size = 2048 43 | pipeline.config.flavor_intermediate_count = 2048 if model == "ViT-L-14/openai" else 1024 44 | 45 | if mode == "best": 46 | prompt = pipeline.interrogate(image) 47 | elif mode
== "classic": 48 | prompt = pipeline.interrogate_classic(image) 49 | else: 50 | prompt = pipeline.interrogate_fast(image) 51 | return prompt 52 | 53 | def app(self): 54 | # upload image 55 | input_image = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"]) 56 | with st.form(key="image_interrogator"): 57 | clip_model = st.selectbox("CLIP Model", ["ViT-L (Best for SD1.X)", "ViT-H (Best for SD2.X)"]) 58 | mode = st.selectbox("Mode", ["Best", "Classic"]) 59 | submit = st.form_submit_button("Interrogate") 60 | if input_image is not None: 61 | # read image using pil 62 | pil_image = Image.open(input_image).convert("RGB") 63 | if submit: 64 | with st.spinner("Interrogating image..."): 65 | if clip_model == "ViT-L (Best for SD1.X)": 66 | model = "ViT-L-14/openai" 67 | else: 68 | model = "ViT-H-14/laion2b_s32b_b79k" 69 | prompt = self.inference(model, pil_image, mode.lower()) 70 | col1, col2 = st.columns(2) 71 | with col1: 72 | st.image(input_image, use_column_width=True) 73 | with col2: 74 | st.write(prompt) 75 | -------------------------------------------------------------------------------- /diffuzers/pages/1_Inpainting.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | from diffuzers import utils 4 | from diffuzers.inpainting import Inpainting 5 | 6 | 7 | def app(): 8 | utils.create_base_page() 9 | with st.form("inpainting_model_form"): 10 | model = st.text_input( 11 | "Which model do you want to use for inpainting?", 12 | value="runwayml/stable-diffusion-inpainting" 13 | if st.session_state.get("inpainting_model") is None 14 | else st.session_state.inpainting_model, 15 | ) 16 | submit = st.form_submit_button("Load model") 17 | if submit: 18 | st.session_state.inpainting_model = model 19 | with st.spinner("Loading model..."): 20 | inpainting = Inpainting( 21 | model=model, 22 | device=st.session_state.device, 23 | output_path=st.session_state.output_path, 24 | ) 25 | st.session_state.inpainting = inpainting 26 | if "inpainting" in st.session_state: 27 | st.write(f"Current model: {st.session_state.inpainting}") 28 | st.session_state.inpainting.app() 29 | 30 | 31 | if __name__ == "__main__": 32 | app() 33 | -------------------------------------------------------------------------------- /diffuzers/pages/2_Utilities.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | from diffuzers import utils 4 | from diffuzers.gfp_gan import GFPGAN 5 | from diffuzers.image_info import ImageInfo 6 | from diffuzers.interrogator import ImageInterrogator 7 | from diffuzers.upscaler import Upscaler 8 | 9 | 10 | def app(): 11 | utils.create_base_page() 12 | task = st.selectbox( 13 | "Choose a utility", 14 | [ 15 | "ImageInfo", 16 | "SD Upscaler", 17 | "GFPGAN", 18 | "CLIP Interrogator", 19 | ], 20 | ) 21 | if task == "ImageInfo": 22 | ImageInfo().app() 23 | elif task == "SD Upscaler": 24 | with st.form("upscaler_model"): 25 | upscaler_model = st.text_input("Model", "stabilityai/stable-diffusion-x4-upscaler") 26 | submit = st.form_submit_button("Load model") 27 | if submit: 28 | with st.spinner("Loading model..."): 29 | ups = Upscaler( 30 | model=upscaler_model, 31 | device=st.session_state.device, 32 | output_path=st.session_state.output_path, 33 | ) 34 | st.session_state.ups = ups 35 | if "ups" in st.session_state: 36 | st.write(f"Current model: {st.session_state.ups}") 37 | st.session_state.ups.app() 38 | 39 | elif task == "GFPGAN": 40 | with st.spinner("Loading 
model..."): 41 | gfpgan = GFPGAN( 42 | device=st.session_state.device, 43 | output_path=st.session_state.output_path, 44 | ) 45 | gfpgan.app() 46 | elif task == "CLIP Interrogator": 47 | interrogator = ImageInterrogator( 48 | device=st.session_state.device, 49 | output_path=st.session_state.output_path, 50 | ) 51 | interrogator.app() 52 | 53 | 54 | if __name__ == "__main__": 55 | app() 56 | -------------------------------------------------------------------------------- /diffuzers/pages/3_FAQs.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | from diffuzers import utils 4 | 5 | 6 | def app(): 7 | utils.create_base_page() 8 | st.markdown("## FAQs") 9 | 10 | 11 | if __name__ == "__main__": 12 | app() 13 | -------------------------------------------------------------------------------- /diffuzers/pages/4_Code of Conduct.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | from diffuzers import utils 4 | 5 | 6 | def app(): 7 | utils.create_base_page() 8 | st.markdown(utils.CODE_OF_CONDUCT) 9 | 10 | 11 | if __name__ == "__main__": 12 | app() 13 | -------------------------------------------------------------------------------- /diffuzers/text2img.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import json 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | import streamlit as st 7 | import torch 8 | from diffusers import DiffusionPipeline 9 | from loguru import logger 10 | from PIL.PngImagePlugin import PngInfo 11 | 12 | from diffuzers import utils 13 | 14 | 15 | @dataclass 16 | class Text2Image: 17 | device: Optional[str] = None 18 | model: Optional[str] = None 19 | output_path: Optional[str] = None 20 | 21 | def __str__(self) -> str: 22 | return f"Text2Image(model={self.model}, device={self.device}, output_path={self.output_path})" 23 | 24 | def __post_init__(self): 25 | self.pipeline = DiffusionPipeline.from_pretrained( 26 | self.model, 27 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, 28 | ) 29 | self.pipeline.to(self.device) 30 | self.pipeline.safety_checker = utils.no_safety_checker 31 | self._compatible_schedulers = self.pipeline.scheduler.compatibles 32 | self.scheduler_config = self.pipeline.scheduler.config 33 | self.compatible_schedulers = {scheduler.__name__: scheduler for scheduler in self._compatible_schedulers} 34 | 35 | if self.device == "mps": 36 | self.pipeline.enable_attention_slicing() 37 | # warmup 38 | prompt = "a photo of an astronaut riding a horse on mars" 39 | _ = self.pipeline(prompt, num_inference_steps=2) 40 | 41 | def _set_scheduler(self, scheduler_name): 42 | scheduler = self.compatible_schedulers[scheduler_name].from_config(self.scheduler_config) 43 | self.pipeline.scheduler = scheduler 44 | 45 | def generate_image(self, prompt, negative_prompt, scheduler, image_size, num_images, guidance_scale, steps, seed): 46 | self._set_scheduler(scheduler) 47 | logger.info(self.pipeline.scheduler) 48 | if self.device == "mps": 49 | generator = torch.manual_seed(seed) 50 | num_images = 1 51 | else: 52 | generator = torch.Generator(device=self.device).manual_seed(seed) 53 | num_images = int(num_images) 54 | output_images = self.pipeline( 55 | prompt, 56 | negative_prompt=negative_prompt, 57 | width=image_size[1], 58 | height=image_size[0], 59 | num_inference_steps=steps, 60 | guidance_scale=guidance_scale, 61 | 
num_images_per_prompt=num_images, 62 | generator=generator, 63 | ).images 64 | torch.cuda.empty_cache() 65 | gc.collect() 66 | metadata = { 67 | "prompt": prompt, 68 | "negative_prompt": negative_prompt, 69 | "scheduler": scheduler, 70 | "image_size": image_size, 71 | "num_images": num_images, 72 | "guidance_scale": guidance_scale, 73 | "steps": steps, 74 | "seed": seed, 75 | } 76 | metadata = json.dumps(metadata) 77 | _metadata = PngInfo() 78 | _metadata.add_text("text2img", metadata) 79 | 80 | utils.save_images( 81 | images=output_images, 82 | module="text2img", 83 | metadata=metadata, 84 | output_path=self.output_path, 85 | ) 86 | 87 | return output_images, _metadata 88 | 89 | def app(self): 90 | available_schedulers = list(self.compatible_schedulers.keys()) 91 | if "EulerAncestralDiscreteScheduler" in available_schedulers: 92 | available_schedulers.insert( 93 | 0, available_schedulers.pop(available_schedulers.index("EulerAncestralDiscreteScheduler")) 94 | ) 95 | # with st.form(key="text2img"): 96 | col1, col2 = st.columns(2) 97 | with col1: 98 | prompt = st.text_area("Prompt", "Blue elephant") 99 | with col2: 100 | negative_prompt = st.text_area("Negative Prompt", "") 101 | # sidebar options 102 | scheduler = st.sidebar.selectbox("Scheduler", available_schedulers, index=0) 103 | image_height = st.sidebar.slider("Image height", 128, 1024, 512, 128) 104 | image_width = st.sidebar.slider("Image width", 128, 1024, 512, 128) 105 | guidance_scale = st.sidebar.slider("Guidance scale", 1.0, 40.0, 7.5, 0.5) 106 | num_images = st.sidebar.slider("Number of images per prompt", 1, 30, 1, 1) 107 | steps = st.sidebar.slider("Steps", 1, 150, 50, 1) 108 | 109 | seed_placeholder = st.sidebar.empty() 110 | seed = seed_placeholder.number_input("Seed", value=42, min_value=1, max_value=999999, step=1) 111 | random_seed = st.sidebar.button("Random seed") 112 | _seed = torch.randint(1, 999999, (1,)).item() 113 | if random_seed: 114 | seed = seed_placeholder.number_input("Seed", value=_seed, min_value=1, max_value=999999, step=1) 115 | 116 | sub_col, download_col = st.columns(2) 117 | with sub_col: 118 | submit = st.button("Generate") 119 | 120 | if submit: 121 | with st.spinner("Generating images..."): 122 | output_images, metadata = self.generate_image( 123 | prompt=prompt, 124 | negative_prompt=negative_prompt, 125 | scheduler=scheduler, 126 | image_size=(image_height, image_width), 127 | num_images=num_images, 128 | guidance_scale=guidance_scale, 129 | steps=steps, 130 | seed=seed, 131 | ) 132 | 133 | utils.display_and_download_images(output_images, metadata, download_col) 134 | -------------------------------------------------------------------------------- /diffuzers/textual_inversion.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import json 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | import streamlit as st 7 | import torch 8 | from diffusers import DiffusionPipeline 9 | from loguru import logger 10 | from PIL.PngImagePlugin import PngInfo 11 | 12 | from diffuzers import utils 13 | 14 | 15 | def load_embed(learned_embeds_path, text_encoder, tokenizer, token=None): 16 | loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu") 17 | if len(loaded_learned_embeds) > 2: 18 | embeds = loaded_learned_embeds["string_to_param"]["*"][-1, :] 19 | else: 20 | # separate token and the embeds 21 | trained_token = list(loaded_learned_embeds.keys())[0] 22 | embeds = loaded_learned_embeds[trained_token] 23 | 24 | # 
add the token in tokenizer 25 | token = token if token is not None else trained_token 26 | num_added_tokens = tokenizer.add_tokens(token) 27 | i = 1 28 | while num_added_tokens == 0: 29 | logger.warning(f"The tokenizer already contains the token {token}.") 30 | token = f"{token[:-1]}-{i}>" 31 | logger.info(f"Attempting to add the token {token}.") 32 | num_added_tokens = tokenizer.add_tokens(token) 33 | i += 1 34 | 35 | # resize the token embeddings 36 | text_encoder.resize_token_embeddings(len(tokenizer)) 37 | 38 | # get the id for the token and assign the embeds 39 | token_id = tokenizer.convert_tokens_to_ids(token) 40 | text_encoder.get_input_embeddings().weight.data[token_id] = embeds 41 | return token 42 | 43 | 44 | @dataclass 45 | class TextualInversion: 46 | model: str 47 | embeddings_url: str 48 | token_identifier: str 49 | device: Optional[str] = None 50 | output_path: Optional[str] = None 51 | 52 | def __str__(self) -> str: 53 | return f"TextualInversion(model={self.model}, embeddings={self.embeddings_url}, token_identifier={self.token_identifier}, device={self.device}, output_path={self.output_path})" 54 | 55 | def __post_init__(self): 56 | self.pipeline = DiffusionPipeline.from_pretrained( 57 | self.model, 58 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, 59 | ) 60 | self.pipeline.to(self.device) 61 | self.pipeline.safety_checker = utils.no_safety_checker 62 | self._compatible_schedulers = self.pipeline.scheduler.compatibles 63 | self.scheduler_config = self.pipeline.scheduler.config 64 | self.compatible_schedulers = {scheduler.__name__: scheduler for scheduler in self._compatible_schedulers} 65 | 66 | # download the embeddings 67 | self.embeddings_path = utils.download_file(self.embeddings_url) 68 | load_embed( 69 | learned_embeds_path=self.embeddings_path, 70 | text_encoder=self.pipeline.text_encoder, 71 | tokenizer=self.pipeline.tokenizer, 72 | token=self.token_identifier, 73 | ) 74 | 75 | if self.device == "mps": 76 | self.pipeline.enable_attention_slicing() 77 | # warmup 78 | prompt = "a photo of an astronaut riding a horse on mars" 79 | _ = self.pipeline(prompt, num_inference_steps=1) 80 | 81 | def _set_scheduler(self, scheduler_name): 82 | scheduler = self.compatible_schedulers[scheduler_name].from_config(self.scheduler_config) 83 | self.pipeline.scheduler = scheduler 84 | 85 | def generate_image(self, prompt, negative_prompt, scheduler, image_size, num_images, guidance_scale, steps, seed): 86 | self._set_scheduler(scheduler) 87 | logger.info(self.pipeline.scheduler) 88 | if self.device == "mps": 89 | generator = torch.manual_seed(seed) 90 | num_images = 1 91 | else: 92 | generator = torch.Generator(device=self.device).manual_seed(seed) 93 | num_images = int(num_images) 94 | output_images = self.pipeline( 95 | prompt, 96 | negative_prompt=negative_prompt, 97 | width=image_size[1], 98 | height=image_size[0], 99 | num_inference_steps=steps, 100 | guidance_scale=guidance_scale, 101 | num_images_per_prompt=num_images, 102 | generator=generator, 103 | ).images 104 | torch.cuda.empty_cache() 105 | gc.collect() 106 | metadata = { 107 | "prompt": prompt, 108 | "negative_prompt": negative_prompt, 109 | "scheduler": scheduler, 110 | "image_size": image_size, 111 | "num_images": num_images, 112 | "guidance_scale": guidance_scale, 113 | "steps": steps, 114 | "seed": seed, 115 | } 116 | metadata = json.dumps(metadata) 117 | _metadata = PngInfo() 118 | _metadata.add_text("textual_inversion", metadata) 119 | 120 | utils.save_images( 121 | images=output_images, 
122 | module="textual_inversion", 123 | metadata=metadata, 124 | output_path=self.output_path, 125 | ) 126 | 127 | return output_images, _metadata 128 | 129 | def app(self): 130 | available_schedulers = list(self.compatible_schedulers.keys()) 131 | if "EulerAncestralDiscreteScheduler" in available_schedulers: 132 | available_schedulers.insert( 133 | 0, available_schedulers.pop(available_schedulers.index("EulerAncestralDiscreteScheduler")) 134 | ) 135 | col1, col2 = st.columns(2) 136 | with col1: 137 | prompt = st.text_area("Prompt", "Blue elephant") 138 | with col2: 139 | negative_prompt = st.text_area("Negative Prompt", "") 140 | # sidebar options 141 | scheduler = st.sidebar.selectbox("Scheduler", available_schedulers, index=0) 142 | image_height = st.sidebar.slider("Image height", 128, 1024, 512, 128) 143 | image_width = st.sidebar.slider("Image width", 128, 1024, 512, 128) 144 | guidance_scale = st.sidebar.slider("Guidance scale", 1.0, 40.0, 7.5, 0.5) 145 | num_images = st.sidebar.slider("Number of images per prompt", 1, 30, 1, 1) 146 | steps = st.sidebar.slider("Steps", 1, 150, 50, 1) 147 | seed_placeholder = st.sidebar.empty() 148 | seed = seed_placeholder.number_input("Seed", value=42, min_value=1, max_value=999999, step=1) 149 | random_seed = st.sidebar.button("Random seed") 150 | _seed = torch.randint(1, 999999, (1,)).item() 151 | if random_seed: 152 | seed = seed_placeholder.number_input("Seed", value=_seed, min_value=1, max_value=999999, step=1) 153 | sub_col, download_col = st.columns(2) 154 | with sub_col: 155 | submit = st.button("Generate") 156 | 157 | if submit: 158 | with st.spinner("Generating images..."): 159 | output_images, metadata = self.generate_image( 160 | prompt=prompt, 161 | negative_prompt=negative_prompt, 162 | scheduler=scheduler, 163 | image_size=(image_height, image_width), 164 | num_images=num_images, 165 | guidance_scale=guidance_scale, 166 | steps=steps, 167 | seed=seed, 168 | ) 169 | 170 | utils.display_and_download_images(output_images, metadata, download_col) 171 | -------------------------------------------------------------------------------- /diffuzers/upscaler.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import json 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | import streamlit as st 7 | import torch 8 | from diffusers import StableDiffusionUpscalePipeline 9 | from loguru import logger 10 | from PIL import Image 11 | from PIL.PngImagePlugin import PngInfo 12 | 13 | from diffuzers import utils 14 | 15 | 16 | @dataclass 17 | class Upscaler: 18 | model: str = "stabilityai/stable-diffusion-x4-upscaler" 19 | device: Optional[str] = None 20 | output_path: Optional[str] = None 21 | 22 | def __str__(self) -> str: 23 | return f"Upscaler(model={self.model}, device={self.device}, output_path={self.output_path})" 24 | 25 | def __post_init__(self): 26 | self.pipeline = StableDiffusionUpscalePipeline.from_pretrained( 27 | self.model, 28 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, 29 | ) 30 | self.pipeline.to(self.device) 31 | self.pipeline.safety_checker = utils.no_safety_checker 32 | self._compatible_schedulers = self.pipeline.scheduler.compatibles 33 | self.scheduler_config = self.pipeline.scheduler.config 34 | self.compatible_schedulers = {scheduler.__name__: scheduler for scheduler in self._compatible_schedulers} 35 | 36 | def _set_scheduler(self, scheduler_name): 37 | scheduler = 
self.compatible_schedulers[scheduler_name].from_config(self.scheduler_config) 38 | self.pipeline.scheduler = scheduler 39 | 40 | def generate_image( 41 | self, image, prompt, negative_prompt, guidance_scale, noise_level, num_images, eta, scheduler, steps, seed 42 | ): 43 | self._set_scheduler(scheduler) 44 | logger.info(self.pipeline.scheduler) 45 | if self.device == "mps": 46 | generator = torch.manual_seed(seed) 47 | num_images = 1 48 | else: 49 | generator = torch.Generator(device=self.device).manual_seed(seed) 50 | num_images = int(num_images) 51 | output_images = self.pipeline( 52 | image=image, 53 | prompt=prompt, 54 | negative_prompt=negative_prompt, 55 | noise_level=noise_level, 56 | num_inference_steps=steps, 57 | eta=eta, 58 | num_images_per_prompt=num_images, 59 | generator=generator, 60 | guidance_scale=guidance_scale, 61 | ).images 62 | 63 | torch.cuda.empty_cache() 64 | gc.collect() 65 | metadata = { 66 | "prompt": prompt, 67 | "negative_prompt": negative_prompt, 68 | "noise_level": noise_level, 69 | "num_images": num_images, 70 | "eta": eta, 71 | "scheduler": scheduler, 72 | "steps": steps, 73 | "seed": seed, 74 | } 75 | 76 | metadata = json.dumps(metadata) 77 | _metadata = PngInfo() 78 | _metadata.add_text("upscaler", metadata) 79 | 80 | utils.save_images( 81 | images=output_images, 82 | module="upscaler", 83 | metadata=metadata, 84 | output_path=self.output_path, 85 | ) 86 | 87 | return output_images, _metadata 88 | 89 | def app(self): 90 | available_schedulers = list(self.compatible_schedulers.keys()) 91 | # if EulerAncestralDiscreteScheduler is available in available_schedulers, move it to the first position 92 | if "EulerAncestralDiscreteScheduler" in available_schedulers: 93 | available_schedulers.insert( 94 | 0, available_schedulers.pop(available_schedulers.index("EulerAncestralDiscreteScheduler")) 95 | ) 96 | 97 | input_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"]) 98 | if input_image is not None: 99 | input_image = Image.open(input_image) 100 | input_image = input_image.convert("RGB").resize((128, 128), resample=Image.LANCZOS) 101 | st.image(input_image, use_column_width=True) 102 | 103 | col1, col2 = st.columns(2) 104 | with col1: 105 | prompt = st.text_area("Prompt (Optional)", "") 106 | with col2: 107 | negative_prompt = st.text_area("Negative Prompt (Optional)", "") 108 | 109 | scheduler = st.sidebar.selectbox("Scheduler", available_schedulers, index=0) 110 | guidance_scale = st.sidebar.slider("Guidance scale", 1.0, 40.0, 7.5, 0.5) 111 | noise_level = st.sidebar.slider("Noise level", 0, 100, 20, 1) 112 | eta = st.sidebar.slider("Eta", 0.0, 1.0, 0.0, 0.1) 113 | num_images = st.sidebar.slider("Number of images per prompt", 1, 30, 1, 1) 114 | steps = st.sidebar.slider("Steps", 1, 150, 50, 1) 115 | seed_placeholder = st.sidebar.empty() 116 | seed = seed_placeholder.number_input("Seed", value=42, min_value=1, max_value=999999, step=1) 117 | random_seed = st.sidebar.button("Random seed") 118 | _seed = torch.randint(1, 999999, (1,)).item() 119 | if random_seed: 120 | seed = seed_placeholder.number_input("Seed", value=_seed, min_value=1, max_value=999999, step=1) 121 | sub_col, download_col = st.columns(2) 122 | with sub_col: 123 | submit = st.button("Generate") 124 | 125 | if submit: 126 | with st.spinner("Generating images..."): 127 | output_images, metadata = self.generate_image( 128 | prompt=prompt, 129 | image=input_image, 130 | negative_prompt=negative_prompt, 131 | scheduler=scheduler, 132 | num_images=num_images, 133 | 
guidance_scale=guidance_scale, 134 | steps=steps, 135 | seed=seed, 136 | noise_level=noise_level, 137 | eta=eta, 138 | ) 139 | 140 | utils.display_and_download_images(output_images, metadata, download_col) 141 | -------------------------------------------------------------------------------- /diffuzers/utils.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import gc 3 | import io 4 | import os 5 | import tempfile 6 | import zipfile 7 | from datetime import datetime 8 | from threading import Thread 9 | 10 | import requests 11 | import streamlit as st 12 | import torch 13 | from huggingface_hub import HfApi 14 | from huggingface_hub.utils._errors import RepositoryNotFoundError 15 | from huggingface_hub.utils._validators import HFValidationError 16 | from loguru import logger 17 | from PIL.PngImagePlugin import PngInfo 18 | from st_clickable_images import clickable_images 19 | 20 | 21 | no_safety_checker = None 22 | 23 | 24 | CODE_OF_CONDUCT = """ 25 | ## Code of conduct 26 | The app should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes. 27 | 28 | Using the app to generate content that is cruel to individuals is a misuse of this app. One shall not use this app to generate content that is intended to be cruel to individuals, or to generate content that is intended to be cruel to individuals in a way that is not obvious to the viewer. 29 | This includes, but is not limited to: 30 | - Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc. 31 | - Intentionally promoting or propagating discriminatory content or harmful stereotypes. 32 | - Impersonating individuals without their consent. 33 | - Sexual content without consent of the people who might see it. 34 | - Mis- and disinformation 35 | - Representations of egregious violence and gore 36 | - Sharing of copyrighted or licensed material in violation of its terms of use. 37 | - Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use. 38 | 39 | By using this app, you agree to the above code of conduct. 40 | 41 | """ 42 | 43 | 44 | def use_auth_token(): 45 | token_path = os.path.join(os.path.expanduser("~"), ".huggingface", "token") 46 | if os.path.exists(token_path): 47 | return True 48 | if "HF_TOKEN" in os.environ: 49 | return os.environ["HF_TOKEN"] 50 | return False 51 | 52 | 53 | def create_base_page(): 54 | st.set_page_config(layout="wide") 55 | st.title("Diffuzers") 56 | st.markdown("Welcome to Diffuzers! 
A web app for [🤗 Diffusers](https://github.com/huggingface/diffusers)") 57 | 58 | 59 | def download_file(file_url): 60 | r = requests.get(file_url, stream=True) 61 | with tempfile.NamedTemporaryFile(delete=False) as tmp: 62 | for chunk in r.iter_content(chunk_size=1024): 63 | if chunk: # filter out keep-alive new chunks 64 | tmp.write(chunk) 65 | return tmp.name 66 | 67 | 68 | def cache_folder(): 69 | _cache_folder = os.path.join(os.path.expanduser("~"), ".diffuzers") 70 | os.makedirs(_cache_folder, exist_ok=True) 71 | return _cache_folder 72 | 73 | 74 | def clear_memory(preserve): 75 | torch.cuda.empty_cache() 76 | gc.collect() 77 | to_clear = ["inpainting", "text2img", "img2text"] 78 | for key in to_clear: 79 | if key not in preserve and key in st.session_state: 80 | del st.session_state[key] 81 | 82 | 83 | def save_to_hub(api, images, module, current_datetime, metadata, output_path): 84 | logger.info(f"Saving images to hub: {output_path}") 85 | _metadata = PngInfo() 86 | _metadata.add_text(module, metadata) 87 | for i, img in enumerate(images): 88 | img_byte_arr = io.BytesIO() 89 | img.save(img_byte_arr, format="PNG", pnginfo=_metadata) 90 | img_byte_arr = img_byte_arr.getvalue() 91 | api.upload_file( 92 | path_or_fileobj=img_byte_arr, 93 | path_in_repo=f"{module}/{current_datetime}/{i}.png", 94 | repo_id=output_path, 95 | repo_type="dataset", 96 | ) 97 | 98 | api.upload_file( 99 | path_or_fileobj=str.encode(metadata), 100 | path_in_repo=f"{module}/{current_datetime}/metadata.json", 101 | repo_id=output_path, 102 | repo_type="dataset", 103 | ) 104 | logger.info(f"Saved images to hub: {output_path}") 105 | 106 | 107 | def save_to_local(images, module, current_datetime, metadata, output_path): 108 | _metadata = PngInfo() 109 | _metadata.add_text(module, metadata) 110 | os.makedirs(output_path, exist_ok=True) 111 | os.makedirs(f"{output_path}/{module}", exist_ok=True) 112 | os.makedirs(f"{output_path}/{module}/{current_datetime}", exist_ok=True) 113 | 114 | for i, img in enumerate(images): 115 | img.save( 116 | f"{output_path}/{module}/{current_datetime}/{i}.png", 117 | pnginfo=_metadata, 118 | ) 119 | 120 | # save metadata as text file 121 | with open(f"{output_path}/{module}/{current_datetime}/metadata.txt", "w") as f: 122 | f.write(metadata) 123 | logger.info(f"Saved images to {output_path}/{module}/{current_datetime}") 124 | 125 | 126 | def save_images(images, module, metadata, output_path): 127 | if output_path is None: 128 | logger.warning("No output path specified, skipping saving images") 129 | return 130 | 131 | api = HfApi() 132 | dset_info = None 133 | try: 134 | dset_info = api.dataset_info(output_path) 135 | except (HFValidationError, RepositoryNotFoundError): 136 | logger.warning("No valid hugging face repo. 
Saving locally...") 137 | 138 | current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 139 | 140 | if not dset_info: 141 | save_to_local(images, module, current_datetime, metadata, output_path) 142 | else: 143 | Thread(target=save_to_hub, args=(api, images, module, current_datetime, metadata, output_path)).start() 144 | 145 | 146 | def display_and_download_images(output_images, metadata, download_col=None): 147 | # st.image(output_images, width=128, output_format="PNG") 148 | 149 | with st.spinner("Preparing images for download..."): 150 | # save images to a temporary directory 151 | with tempfile.TemporaryDirectory() as tmpdir: 152 | gallery_images = [] 153 | for i, image in enumerate(output_images): 154 | image.save(os.path.join(tmpdir, f"{i + 1}.png"), pnginfo=metadata) 155 | with open(os.path.join(tmpdir, f"{i + 1}.png"), "rb") as img: 156 | encoded = base64.b64encode(img.read()).decode() 157 | gallery_images.append(f"data:image/jpeg;base64,{encoded}") 158 | 159 | # zip the images 160 | zip_path = os.path.join(tmpdir, "images.zip") 161 | with zipfile.ZipFile(zip_path, "w") as zip: 162 | for filename in os.listdir(tmpdir): 163 | if filename.endswith(".png"): 164 | zip.write(os.path.join(tmpdir, filename), filename) 165 | 166 | # convert zip to base64 167 | with open(zip_path, "rb") as f: 168 | encoded = base64.b64encode(f.read()).decode() 169 | 170 | _ = clickable_images( 171 | gallery_images, 172 | titles=[f"Image #{str(i)}" for i in range(len(gallery_images))], 173 | div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"}, 174 | img_style={"margin": "5px", "height": "200px"}, 175 | ) 176 | 177 | # add download link 178 | st.markdown( 179 | f""" 180 | 181 | Download Images 182 | 183 | """, 184 | unsafe_allow_html=True, 185 | ) 186 | -------------------------------------------------------------------------------- /diffuzers/x2image.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import gc 3 | import json 4 | import os 5 | import random 6 | import tempfile 7 | from dataclasses import dataclass 8 | from io import BytesIO 9 | from typing import Optional 10 | 11 | import requests 12 | import streamlit as st 13 | import torch 14 | from diffusers import ( 15 | AltDiffusionImg2ImgPipeline, 16 | AltDiffusionPipeline, 17 | DiffusionPipeline, 18 | StableDiffusionImg2ImgPipeline, 19 | StableDiffusionInstructPix2PixPipeline, 20 | StableDiffusionPipeline, 21 | ) 22 | from loguru import logger 23 | from PIL import Image 24 | from PIL.PngImagePlugin import PngInfo 25 | from st_clickable_images import clickable_images 26 | 27 | from diffuzers import utils 28 | 29 | 30 | def load_embed(learned_embeds_path, text_encoder, tokenizer, token=None): 31 | loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu") 32 | if len(loaded_learned_embeds) > 2: 33 | embeds = loaded_learned_embeds["string_to_param"]["*"][-1, :] 34 | else: 35 | # separate token and the embeds 36 | trained_token = list(loaded_learned_embeds.keys())[0] 37 | embeds = loaded_learned_embeds[trained_token] 38 | 39 | # add the token in tokenizer 40 | token = token if token is not None else trained_token 41 | num_added_tokens = tokenizer.add_tokens(token) 42 | i = 1 43 | while num_added_tokens == 0: 44 | logger.warning(f"The tokenizer already contains the token {token}.") 45 | token = f"{token[:-1]}-{i}>" 46 | logger.info(f"Attempting to add the token {token}.") 47 | num_added_tokens = tokenizer.add_tokens(token) 48 | i += 1 49 | 50 | # 
resize the token embeddings 51 | text_encoder.resize_token_embeddings(len(tokenizer)) 52 | 53 | # get the id for the token and assign the embeds 54 | token_id = tokenizer.convert_tokens_to_ids(token) 55 | text_encoder.get_input_embeddings().weight.data[token_id] = embeds 56 | return token 57 | 58 | 59 | @dataclass 60 | class X2Image: 61 | device: Optional[str] = None 62 | model: Optional[str] = None 63 | output_path: Optional[str] = None 64 | custom_pipeline: Optional[str] = None 65 | embeddings_url: Optional[str] = None 66 | token_identifier: Optional[str] = None 67 | 68 | def __str__(self) -> str: 69 | return f"X2Image(model={self.model}, pipeline={self.custom_pipeline})" 70 | 71 | def __post_init__(self): 72 | self.text2img_pipeline = DiffusionPipeline.from_pretrained( 73 | self.model, 74 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, 75 | custom_pipeline=self.custom_pipeline, 76 | use_auth_token=utils.use_auth_token(), 77 | ) 78 | components = self.text2img_pipeline.components 79 | self.pix2pix_pipeline = None 80 | if isinstance(self.text2img_pipeline, StableDiffusionPipeline): 81 | self.img2img_pipeline = StableDiffusionImg2ImgPipeline(**components) 82 | self.pix2pix_pipeline = StableDiffusionInstructPix2PixPipeline(**components) 83 | elif isinstance(self.text2img_pipeline, AltDiffusionPipeline): 84 | self.img2img_pipeline = AltDiffusionImg2ImgPipeline(**components) 85 | else: 86 | self.img2img_pipeline = None 87 | logger.error("Model type not supported, img2img pipeline not created") 88 | 89 | self.text2img_pipeline.to(self.device) 90 | self.text2img_pipeline.safety_checker = utils.no_safety_checker 91 | self.img2img_pipeline.to(self.device) 92 | self.img2img_pipeline.safety_checker = utils.no_safety_checker 93 | if self.pix2pix_pipeline is not None: 94 | self.pix2pix_pipeline.to(self.device) 95 | self.pix2pix_pipeline.safety_checker = utils.no_safety_checker 96 | 97 | self.compatible_schedulers = { 98 | scheduler.__name__: scheduler for scheduler in self.text2img_pipeline.scheduler.compatibles 99 | } 100 | 101 | if self.embeddings_url and self.token_identifier: 102 | # download the embeddings and load them into the shared text encoder / tokenizer 103 | self.embeddings_path = utils.download_file(self.embeddings_url) 104 | load_embed( 105 | learned_embeds_path=self.embeddings_path, 106 | text_encoder=self.text2img_pipeline.text_encoder, 107 | tokenizer=self.text2img_pipeline.tokenizer, 108 | token=self.token_identifier, 109 | ) 110 | 111 | if self.device == "mps": 112 | self.text2img_pipeline.enable_attention_slicing() 113 | prompt = "a photo of an astronaut riding a horse on mars" 114 | _ = self.text2img_pipeline(prompt, num_inference_steps=2) 115 | 116 | self.img2img_pipeline.enable_attention_slicing() 117 | url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" 118 | response = requests.get(url) 119 | init_image = Image.open(BytesIO(response.content)).convert("RGB") 120 | init_image.thumbnail((768, 768)) 121 | prompt = "A fantasy landscape, trending on artstation" 122 | _ = self.img2img_pipeline( 123 | prompt=prompt, 124 | image=init_image, 125 | strength=0.75, 126 | guidance_scale=7.5, 127 | num_inference_steps=2, 128 | ) 129 | if self.pix2pix_pipeline is not None: 130 | self.pix2pix_pipeline.enable_attention_slicing() 131 | prompt = "turn him into a cyborg" 132 | _ = self.pix2pix_pipeline(prompt, image=init_image, num_inference_steps=2) 133 | 134 | def _set_scheduler(self, pipeline_name, scheduler_name): 135 | if pipeline_name == 
"text2img": 136 | scheduler_config = self.text2img_pipeline.scheduler.config 137 | elif pipeline_name == "img2img": 138 | scheduler_config = self.img2img_pipeline.scheduler.config 139 | elif pipeline_name == "pix2pix": 140 | scheduler_config = self.pix2pix_pipeline.scheduler.config 141 | else: 142 | raise ValueError(f"Pipeline {pipeline_name} not supported") 143 | 144 | scheduler = self.compatible_schedulers[scheduler_name].from_config(scheduler_config) 145 | 146 | if pipeline_name == "text2img": 147 | self.text2img_pipeline.scheduler = scheduler 148 | elif pipeline_name == "img2img": 149 | self.img2img_pipeline.scheduler = scheduler 150 | 151 | def _pregen(self, pipeline_name, scheduler, num_images, seed): 152 | self._set_scheduler(scheduler_name=scheduler, pipeline_name=pipeline_name) 153 | if self.device == "mps": 154 | generator = torch.manual_seed(seed) 155 | num_images = 1 156 | else: 157 | generator = torch.Generator(device=self.device).manual_seed(seed) 158 | num_images = int(num_images) 159 | return generator, num_images 160 | 161 | def _postgen(self, metadata, output_images, pipeline_name): 162 | torch.cuda.empty_cache() 163 | gc.collect() 164 | metadata = json.dumps(metadata) 165 | _metadata = PngInfo() 166 | _metadata.add_text(pipeline_name, metadata) 167 | utils.save_images( 168 | images=output_images, 169 | module=pipeline_name, 170 | metadata=metadata, 171 | output_path=self.output_path, 172 | ) 173 | return output_images, _metadata 174 | 175 | def text2img_generate( 176 | self, prompt, negative_prompt, scheduler, image_size, num_images, guidance_scale, steps, seed 177 | ): 178 | 179 | if seed == -1: 180 | # generate random seed 181 | seed = random.randint(0, 999999) 182 | 183 | generator, num_images = self._pregen( 184 | pipeline_name="text2img", 185 | scheduler=scheduler, 186 | num_images=num_images, 187 | seed=seed, 188 | ) 189 | output_images = self.text2img_pipeline( 190 | prompt, 191 | negative_prompt=negative_prompt, 192 | width=image_size[1], 193 | height=image_size[0], 194 | num_inference_steps=steps, 195 | guidance_scale=guidance_scale, 196 | num_images_per_prompt=num_images, 197 | generator=generator, 198 | ).images 199 | metadata = { 200 | "prompt": prompt, 201 | "negative_prompt": negative_prompt, 202 | "scheduler": scheduler, 203 | "image_size": image_size, 204 | "num_images": num_images, 205 | "guidance_scale": guidance_scale, 206 | "steps": steps, 207 | "seed": seed, 208 | } 209 | 210 | output_images, _metadata = self._postgen( 211 | metadata=metadata, 212 | output_images=output_images, 213 | pipeline_name="text2img", 214 | ) 215 | return output_images, _metadata 216 | 217 | def img2img_generate( 218 | self, prompt, image, strength, negative_prompt, scheduler, num_images, guidance_scale, steps, seed 219 | ): 220 | 221 | if seed == -1: 222 | # generate random seed 223 | seed = random.randint(0, 999999) 224 | 225 | generator, num_images = self._pregen( 226 | pipeline_name="img2img", 227 | scheduler=scheduler, 228 | num_images=num_images, 229 | seed=seed, 230 | ) 231 | output_images = self.img2img_pipeline( 232 | prompt=prompt, 233 | image=image, 234 | strength=strength, 235 | negative_prompt=negative_prompt, 236 | num_inference_steps=steps, 237 | guidance_scale=guidance_scale, 238 | num_images_per_prompt=num_images, 239 | generator=generator, 240 | ).images 241 | metadata = { 242 | "prompt": prompt, 243 | "negative_prompt": negative_prompt, 244 | "scheduler": scheduler, 245 | "num_images": num_images, 246 | "guidance_scale": guidance_scale, 247 | "steps": 
steps, 248 | "seed": seed, 249 | } 250 | output_images, _metadata = self._postgen( 251 | metadata=metadata, 252 | output_images=output_images, 253 | pipeline_name="img2img", 254 | ) 255 | return output_images, _metadata 256 | 257 | def pix2pix_generate( 258 | self, prompt, image, negative_prompt, scheduler, num_images, guidance_scale, image_guidance_scale, steps, seed 259 | ): 260 | if seed == -1: 261 | # generate random seed 262 | seed = random.randint(0, 999999) 263 | 264 | generator, num_images = self._pregen( 265 | pipeline_name="pix2pix", 266 | scheduler=scheduler, 267 | num_images=num_images, 268 | seed=seed, 269 | ) 270 | output_images = self.pix2pix_pipeline( 271 | prompt=prompt, 272 | image=image, 273 | negative_prompt=negative_prompt, 274 | num_inference_steps=steps, 275 | guidance_scale=guidance_scale, 276 | image_guidance_scale=image_guidance_scale, 277 | num_images_per_prompt=num_images, 278 | generator=generator, 279 | ).images 280 | metadata = { 281 | "prompt": prompt, 282 | "negative_prompt": negative_prompt, 283 | "scheduler": scheduler, 284 | "num_images": num_images, 285 | "guidance_scale": guidance_scale, 286 | "image_guidance_scale": image_guidance_scale, 287 | "steps": steps, 288 | "seed": seed, 289 | } 290 | output_images, _metadata = self._postgen( 291 | metadata=metadata, 292 | output_images=output_images, 293 | pipeline_name="pix2pix", 294 | ) 295 | return output_images, _metadata 296 | 297 | def app(self): 298 | available_schedulers = list(self.compatible_schedulers.keys()) 299 | if "EulerAncestralDiscreteScheduler" in available_schedulers: 300 | available_schedulers.insert( 301 | 0, available_schedulers.pop(available_schedulers.index("EulerAncestralDiscreteScheduler")) 302 | ) 303 | # col3, col4 = st.columns(2) 304 | # with col3: 305 | input_image = st.file_uploader( 306 | "Upload an image to use image2image or pix2pix", 307 | type=["png", "jpg", "jpeg"], 308 | help="Upload an image to use image2image. If left blank, text2image will be used instead.", 309 | ) 310 | use_pix2pix = st.checkbox("Use pix2pix", value=False) 311 | if input_image is not None: 312 | input_image = Image.open(input_image) 313 | if use_pix2pix: 314 | pipeline_name = "pix2pix" 315 | else: 316 | pipeline_name = "img2img" 317 | # display image using html 318 | # convert image to base64 319 | # st.markdown(f"", unsafe_allow_html=True) 320 | # st.image(input_image, use_column_width=True) 321 | with tempfile.TemporaryDirectory() as tmpdir: 322 | gallery_images = [] 323 | input_image.save(os.path.join(tmpdir, "img.png")) 324 | with open(os.path.join(tmpdir, "img.png"), "rb") as img: 325 | encoded = base64.b64encode(img.read()).decode() 326 | gallery_images.append(f"data:image/jpeg;base64,{encoded}") 327 | 328 | _ = clickable_images( 329 | gallery_images, 330 | titles=[f"Image #{str(i)}" for i in range(len(gallery_images))], 331 | div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"}, 332 | img_style={"margin": "5px", "height": "200px"}, 333 | ) 334 | else: 335 | pipeline_name = "text2img" 336 | # prompt = st.text_area("Prompt", "Blue elephant") 337 | # negative_prompt = st.text_area("Negative Prompt", "") 338 | # with col4: 339 | col1, col2 = st.columns(2) 340 | with col1: 341 | prompt = st.text_area("Prompt", "Blue elephant", help="Prompt to guide image generation") 342 | with col2: 343 | negative_prompt = st.text_area( 344 | "Negative Prompt", 345 | "", 346 | help="The prompt not to guide image generation. 
Write things that you don't want to see in the image.", 347 | ) 348 | # sidebar options 349 | if input_image is None: 350 | image_height = st.sidebar.slider( 351 | "Image height", 128, 1024, 512, 128, help="The height in pixels of the generated image." 352 | ) 353 | image_width = st.sidebar.slider( 354 | "Image width", 128, 1024, 512, 128, help="The width in pixels of the generated image." 355 | ) 356 | 357 | num_images = st.sidebar.slider( 358 | "Number of images per prompt", 359 | 1, 360 | 30, 361 | 1, 362 | 1, 363 | help="Number of images you want to generate. More images require more time and use more GPU memory.", 364 | ) 365 | 366 | # add section advanced options 367 | st.sidebar.markdown("### Advanced options") 368 | scheduler = st.sidebar.selectbox( 369 | "Scheduler", available_schedulers, index=0, help="Scheduler to use for generation" 370 | ) 371 | guidance_scale = st.sidebar.slider( 372 | "Guidance scale", 373 | 1.0, 374 | 40.0, 375 | 7.5, 376 | 0.5, 377 | help="A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.", 378 | ) 379 | if use_pix2pix and input_image is not None: 380 | image_guidance_scale = st.sidebar.slider( 381 | "Image guidance scale", 382 | 1.0, 383 | 40.0, 384 | 1.5, 385 | 0.5, 386 | help="Image guidance scale pushes the generated image towards the initial image `image`. Image guidance scale is enabled by setting `image_guidance_scale > 1`. A higher image guidance scale encourages the model to generate images that are closely linked to the source image `image`, usually at the expense of lower image quality.", 387 | ) 388 | if input_image is not None and not use_pix2pix: 389 | strength = st.sidebar.slider( 390 | "Denoising strength", 391 | 0.0, 392 | 1.0, 393 | 0.8, 394 | 0.05, 395 | help="Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.", 396 | ) 397 | steps = st.sidebar.slider( 398 | "Steps", 399 | 1, 400 | 150, 401 | 50, 402 | 1, 403 | help="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.", 404 | ) 405 | seed = st.sidebar.number_input( 406 | "Seed", 407 | value=42, 408 | min_value=-1, 409 | max_value=999999, 410 | step=1, 411 | help="Random seed. 
Change for different results using same parameters.", 412 | ) 413 | 414 | sub_col, download_col = st.columns(2) 415 | with sub_col: 416 | submit = st.button("Generate") 417 | 418 | if submit: 419 | with st.spinner("Generating images..."): 420 | if pipeline_name == "text2img": 421 | output_images, metadata = self.text2img_generate( 422 | prompt=prompt, 423 | negative_prompt=negative_prompt, 424 | scheduler=scheduler, 425 | image_size=(image_height, image_width), 426 | num_images=num_images, 427 | guidance_scale=guidance_scale, 428 | steps=steps, 429 | seed=seed, 430 | ) 431 | elif pipeline_name == "img2img": 432 | output_images, metadata = self.img2img_generate( 433 | prompt=prompt, 434 | image=input_image, 435 | strength=strength, 436 | negative_prompt=negative_prompt, 437 | scheduler=scheduler, 438 | num_images=num_images, 439 | guidance_scale=guidance_scale, 440 | steps=steps, 441 | seed=seed, 442 | ) 443 | elif pipeline_name == "pix2pix": 444 | output_images, metadata = self.pix2pix_generate( 445 | prompt=prompt, 446 | image=input_image, 447 | negative_prompt=negative_prompt, 448 | scheduler=scheduler, 449 | num_images=num_images, 450 | guidance_scale=guidance_scale, 451 | image_guidance_scale=image_guidance_scale, 452 | steps=steps, 453 | seed=seed, 454 | ) 455 | utils.display_and_download_images(output_images, metadata, download_col) 456 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = "diffuzers: webapp and api for 🤗 diffusers" 21 | copyright = "2023, abhishek thakur" 22 | author = "abhishek thakur" 23 | 24 | 25 | # -- General configuration --------------------------------------------------- 26 | 27 | # Add any Sphinx extension module names here, as strings. 
They can be 28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 29 | # ones. 30 | extensions = [] 31 | 32 | # Add any paths that contain templates here, relative to this directory. 33 | templates_path = ["_templates"] 34 | 35 | # List of patterns, relative to source directory, that match files and 36 | # directories to ignore when looking for source files. 37 | # This pattern also affects html_static_path and html_extra_path. 38 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 39 | 40 | 41 | # -- Options for HTML output ------------------------------------------------- 42 | 43 | # The theme to use for HTML and HTML Help pages. See the documentation for 44 | # a list of builtin themes. 45 | # 46 | html_theme = "alabaster" 47 | 48 | # Add any paths that contain custom static files (such as style sheets) here, 49 | # relative to this directory. They are copied after the builtin static files, 50 | # so a file named "default.css" will overwrite the builtin "default.css". 51 | html_static_path = ["_static"] 52 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. diffuzers: webapp and api for 🤗 diffusers documentation master file, created by 2 | sphinx-quickstart on Thu Jan 12 14:10:58 2023. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to diffuzers' documentation! 7 | ====================================================================== 8 | 9 | Diffuzers offers a web app and an API for 🤗 diffusers. Installation is very simple. 10 | You can install it via pip: 11 | 12 | .. code-block:: bash 13 | 14 | pip install diffuzers 15 | 16 | .. toctree:: 17 | :maxdepth: 2 18 | :caption: Contents: 19 | 20 | 21 | 22 | Indices and tables 23 | ================== 24 | 25 | * :ref:`genindex` 26 | * :ref:`modindex` 27 | * :ref:`search` 28 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.15.0 2 | basicsr>=1.4.2 3 | diffusers>=0.12.1 4 | facexlib>=0.2.5 5 | fairscale==0.4.4 6 | fastapi==0.88.0 7 | gfpgan>=1.3.7 8 | huggingface_hub>=0.11.1 9 | loguru==0.6.0 10 | open_clip_torch==2.9.1 11 | opencv-python 12 | protobuf==3.20 13 | pyngrok==5.2.1 14 | python-multipart==0.0.5 15 | realesrgan>=0.2.5 16 | streamlit==1.16.0 17 | streamlit-drawable-canvas==0.9.0 18 | st-clickable-images==0.0.3 19 | timm==0.4.12 20 | torch>=1.12.0 21 | torchvision>=0.13.0 22 | transformers==4.25.1 23 | uvicorn==0.15.0 -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | version = attr: diffuzers.__version__ 3 | 4 | [isort] 5 | ensure_newline_before_comments = True 6 | force_grid_wrap = 0 7 | include_trailing_comma = True 8 | line_length = 119 9 | lines_after_imports = 2 10 | multi_line_output = 3 11 | use_parentheses = True 12 | 13 | [flake8] 14 | ignore = E203, E501, W503 15 | max-line-length = 119 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # pylint: enable=line-too-long 3 | """diffuzers: webapp and api for 🤗 diffusers 4 | """ 5 | from setuptools import find_packages, setup 6 | 7 | 8 | with open("README.md") as f: 9 | long_description = f.read() 10 | 11 | QUALITY_REQUIRE = [ 12 | "black~=22.0", 13 | "isort==5.8.0", 14 | "flake8==3.9.2", 15 | "mypy==0.901", 16 | ] 17 | 18 | TEST_REQUIRE = ["pytest", "pytest-cov"] 19 | 20 | EXTRAS_REQUIRE = { 21 | "dev": QUALITY_REQUIRE, 22 | "quality": QUALITY_REQUIRE, 23 | "test": TEST_REQUIRE, 24 | "docs": [ 25 | "recommonmark", 26 | "sphinx==3.1.2", 27 | "sphinx-markdown-tables", 28 | "sphinx-rtd-theme==0.4.3", 29 | "sphinx-copybutton", 30 | ], 31 | } 32 | 33 | with open("requirements.txt") as f: 34 | INSTALL_REQUIRES = f.read().splitlines() 35 | 36 | setup( 37 | name="diffuzers", 38 | description="diffuzers", 39 | long_description=long_description, 40 | long_description_content_type="text/markdown", 41 | author="Abhishek Thakur", 42 | url="https://github.com/abhishekkrthakur/diffuzers", 43 | packages=find_packages("."), 44 | entry_points={"console_scripts": ["diffuzers=diffuzers.cli.main:main"]}, 45 | install_requires=INSTALL_REQUIRES, 46 | extras_require=EXTRAS_REQUIRE, 47 | python_requires=">=3.7", 48 | classifiers=[ 49 | "Intended Audience :: Developers", 50 | "Intended Audience :: Education", 51 | "Intended Audience :: Science/Research", 52 | "License :: OSI Approved :: Apache Software License", 53 | "Operating System :: OS Independent", 54 | "Programming Language :: Python :: 3.8", 55 | "Programming Language :: Python :: 3.9", 56 | "Programming Language :: Python :: 3.10", 57 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 58 | ], 59 | keywords="diffuzers diffusers", 60 | include_package_data=True, 61 | ) 62 | 
-------------------------------------------------------------------------------- /static/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abhishekkrthakur/diffuzers/71f710a1940c22637561809387b1c8622be1a48b/static/.keep -------------------------------------------------------------------------------- /static/screenshot.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abhishekkrthakur/diffuzers/71f710a1940c22637561809387b1c8622be1a48b/static/screenshot.jpeg -------------------------------------------------------------------------------- /static/screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abhishekkrthakur/diffuzers/71f710a1940c22637561809387b1c8622be1a48b/static/screenshot.png -------------------------------------------------------------------------------- /static/screenshot_st.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abhishekkrthakur/diffuzers/71f710a1940c22637561809387b1c8622be1a48b/static/screenshot_st.png --------------------------------------------------------------------------------