├── .gitignore
├── Dockerfile
├── LICENSE
├── MANIFEST.in
├── Makefile
├── OTHER_LICENSES
├── README.md
├── diffuzers.ipynb
├── diffuzers
│   ├── Home.py
│   ├── __init__.py
│   ├── api
│   │   ├── __init__.py
│   │   ├── main.py
│   │   ├── schemas.py
│   │   └── utils.py
│   ├── blip.py
│   ├── cli
│   │   ├── __init__.py
│   │   ├── main.py
│   │   ├── run_api.py
│   │   └── run_app.py
│   ├── clip_interrogator.py
│   ├── data
│   │   ├── artists.txt
│   │   ├── flavors.txt
│   │   ├── mediums.txt
│   │   └── movements.txt
│   ├── gfp_gan.py
│   ├── gradio_app.py
│   ├── image_info.py
│   ├── img2img.py
│   ├── inpainting.py
│   ├── interrogator.py
│   ├── pages
│   │   ├── 1_Inpainting.py
│   │   ├── 2_Utilities.py
│   │   ├── 3_FAQs.py
│   │   └── 4_Code of Conduct.py
│   ├── text2img.py
│   ├── textual_inversion.py
│   ├── upscaler.py
│   ├── utils.py
│   └── x2image.py
├── docs
│   ├── Makefile
│   ├── conf.py
│   ├── index.rst
│   └── make.bat
├── requirements.txt
├── setup.cfg
├── setup.py
└── static
    ├── .keep
    ├── screenshot.jpeg
    ├── screenshot.png
    └── screenshot_st.png
/.gitignore:
--------------------------------------------------------------------------------
1 | # local stuff
2 | .vscode/
3 | examples/.logs/
4 | *.bin
5 | *.csv
6 | input/
7 | models/
8 | logs/
9 | gfpgan/
10 | *.pkl
11 | *.pt
12 | *.pth
13 | abhishek/
14 | diffout/
15 |
16 | # Byte-compiled / optimized / DLL files
17 | __pycache__/
18 | *.py[cod]
19 | *$py.class
20 |
21 | # C extensions
22 | *.so
23 |
24 | # Distribution / packaging
25 | .Python
26 | build/
27 | develop-eggs/
28 | dist/
29 | downloads/
30 | eggs/
31 | .eggs/
32 | lib/
33 | lib64/
34 | parts/
35 | sdist/
36 | var/
37 | pip-wheel-metadata/
38 | share/python-wheels/
39 | *.egg-info/
40 | .installed.cfg
41 | *.egg
42 | MANIFEST
43 | src/
44 |
45 | # PyInstaller
46 | # Usually these files are written by a python script from a template
47 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
48 | *.manifest
49 | *.spec
50 |
51 | # Installer logs
52 | pip-log.txt
53 | pip-delete-this-directory.txt
54 |
55 | # Unit test / coverage reports
56 | htmlcov/
57 | .tox/
58 | .nox/
59 | .coverage
60 | .coverage.*
61 | .cache
62 | nosetests.xml
63 | coverage.xml
64 | *.cover
65 | *.py,cover
66 | .hypothesis/
67 | .pytest_cache/
68 |
69 | # Translations
70 | *.mo
71 | *.pot
72 |
73 | # Django stuff:
74 | *.log
75 | local_settings.py
76 | db.sqlite3
77 | db.sqlite3-journal
78 |
79 | # Flask stuff:
80 | instance/
81 | .webassets-cache
82 |
83 | # Scrapy stuff:
84 | .scrapy
85 |
86 | # Sphinx documentation
87 | docs/_build/
88 |
89 | # PyBuilder
90 | target/
91 |
92 | # Jupyter Notebook
93 | .ipynb_checkpoints
94 |
95 | # IPython
96 | profile_default/
97 | ipython_config.py
98 |
99 | # pyenv
100 | .python-version
101 |
102 | # pipenv
103 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
104 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
105 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
106 | # install all needed dependencies.
107 | #Pipfile.lock
108 |
109 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
110 | __pypackages__/
111 |
112 | # Celery stuff
113 | celerybeat-schedule
114 | celerybeat.pid
115 |
116 | # SageMath parsed files
117 | *.sage.py
118 |
119 | # Environments
120 | .env
121 | .venv
122 | env/
123 | venv/
124 | ENV/
125 | env.bak/
126 | venv.bak/
127 |
128 | # Spyder project settings
129 | .spyderproject
130 | .spyproject
131 |
132 | # Rope project settings
133 | .ropeproject
134 |
135 | # mkdocs documentation
136 | /site
137 |
138 | # mypy
139 | .mypy_cache/
140 | .dmypy.json
141 | dmypy.json
142 |
143 | # Pyre type checker
144 | .pyre/
145 |
146 | # vscode
147 | .vscode/
148 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04
2 |
3 | ARG DEBIAN_FRONTEND=noninteractive
4 | ENV PATH="${HOME}/miniconda3/bin:${PATH}"
5 | ARG PATH="${HOME}/miniconda3/bin:${PATH}"
6 |
7 | RUN apt-get update && \
8 | apt-get upgrade -y && \
9 | apt-get install -y \
10 | build-essential \
11 | cmake \
12 | curl \
13 | ca-certificates \
14 | gcc \
15 | locales \
16 | wget \
17 | git \
18 | git-lfs \
19 | ffmpeg \
20 | libsm6 \
21 | libxext6 \
22 | && rm -rf /var/lib/apt/lists/*
23 |
24 | RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash && \
25 | git lfs install
26 |
27 | WORKDIR /app
28 | ENV HOME=/app
29 |
30 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
31 | && bash Miniconda3-latest-Linux-x86_64.sh -b -p /app/miniconda \
32 | && rm -f Miniconda3-latest-Linux-x86_64.sh
33 | ENV PATH /app/miniconda/bin:$PATH
34 |
35 | RUN conda create -p /app/env -y python=3.8
36 |
37 | SHELL ["conda", "run","--no-capture-output", "-p","/app/env", "/bin/bash", "-c"]
38 |
39 | RUN conda install pytorch torchvision pytorch-cuda=11.7 -c pytorch -c nvidia
40 | RUN pip install https://github.com/abhishekkrthakur/xformers/raw/main/xformers-0.0.15%2B7e05e2c.d20221223-cp38-cp38-linux_x86_64.whl
41 |
42 | COPY requirements.txt /app/requirements.txt
43 | RUN pip install -U --no-cache-dir -r /app/requirements.txt
44 |
45 | COPY . /app
46 | RUN cd /app && pip install -e .
47 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include diffuzers/pages/*.py
2 | include diffuzers/data/*.txt
3 | include OTHER_LICENSES
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: quality style
2 |
3 | quality:
4 | python -m black --check --line-length 119 --target-version py38 .
5 | python -m isort --check-only .
6 | python -m flake8 --max-line-length 119 .
7 |
8 | style:
9 | python -m black --line-length 119 --target-version py38 .
10 | python -m isort .
--------------------------------------------------------------------------------
/OTHER_LICENSES:
--------------------------------------------------------------------------------
1 | clip_interrogator licence
2 | for diffuzers/clip_interrogator.py
3 | for diffuzers/data/*.txt
4 | -------------------------
5 | MIT License
6 |
7 | Copyright (c) 2022 pharmapsychotic
8 |
9 | Permission is hereby granted, free of charge, to any person obtaining a copy
10 | of this software and associated documentation files (the "Software"), to deal
11 | in the Software without restriction, including without limitation the rights
12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 | copies of the Software, and to permit persons to whom the Software is
14 | furnished to do so, subject to the following conditions:
15 |
16 | The above copyright notice and this permission notice shall be included in all
17 | copies or substantial portions of the Software.
18 |
19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 | SOFTWARE.
26 |
27 | blip licence
28 | for diffuzers/blip.py
29 | ------------
30 | Copyright (c) 2022, Salesforce.com, Inc.
31 | All rights reserved.
32 |
33 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
34 |
35 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
36 |
37 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
38 |
39 | * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
40 |
41 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # diffuzers
2 |
3 | A web ui and deployable API for [🤗 diffusers](https://github.com/huggingface/diffusers).
4 |
5 | < under development, request features using issues, prs not accepted atm >
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | 
16 |
17 |
18 | If something doesn't work as expected, or if you need a feature that is not available yet, please create a request using [github issues](https://github.com/abhishekkrthakur/diffuzers/issues)
19 |
20 |
21 | ## Features available in the app:
22 |
23 | - text to image
24 | - image to image
25 | - instruct pix2pix
26 | - textual inversion
27 | - inpainting
28 | - outpainting (coming soon)
29 | - image info
30 | - stable diffusion upscaler
31 | - gfpgan
32 | - clip interrogator
33 | - more coming soon!
34 |
35 | ## Features available in the api:
36 |
37 | - text to image
38 | - image to image
39 | - instruct pix2pix
40 | - textual inversion
41 | - inpainting
42 | - outpainting (via inpainting)
43 | - more coming soon!
44 |
45 |
46 | ## Installation
47 |
48 | To install the bleeding-edge version of diffuzers, clone the repo and install it using pip.
49 |
50 | ```bash
51 | git clone https://github.com/abhishekkrthakur/diffuzers
52 | cd diffuzers
53 | pip install -e .
54 | ```
55 |
56 | Installation using pip:
57 |
58 | ```bash
59 | pip install diffuzers
60 | ```
61 |
62 | ## Usage
63 |
64 | ### Web App
65 | To run the web app, run the following command:
66 |
67 | ```bash
68 | diffuzers app
69 | ```
70 |
71 | ### API
72 |
73 | To run the api, run the following command:
74 |
75 |
76 | ```bash
77 | diffuzers api
78 | ```
79 |
80 | Starting the API requires the following environment variables:
81 |
82 | ```
83 | export X2IMG_MODEL=stabilityai/stable-diffusion-2-1
84 | export DEVICE=cuda
85 | ```
86 |
87 | If you want to use inpainting:
88 |
89 | ```
90 | export INPAINTING_MODEL=stabilityai/stable-diffusion-2-inpainting
91 | ```
92 |
93 | To use long prompt weighting, use:
94 |
95 | ```
96 | export PIPELINE=lpw_stable_diffusion
97 | ```
98 |
99 | If `OUTPUT_PATH` is set in your environment variables, all generations will be saved to `OUTPUT_PATH`. You can also use other (or private) huggingface models. To use private models, you must log in using `huggingface-cli login`.
100 |
101 | API docs are available at `host:port/docs`. For example, with default settings, you can access docs at: `127.0.0.1:10000/docs`.
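As a rough illustration, a minimal Python client for the `/text2img` endpoint could look like the sketch below. This snippet is not part of the repo: it assumes the API was started with `diffuzers api` on the default `127.0.0.1:10000` (with `X2IMG_MODEL` and `DEVICE` exported beforehand) and uses the request fields defined in `diffuzers/api/schemas.py`.

```python
import base64
import io

import requests
from PIL import Image

# Request fields mirror Text2ImgParams in diffuzers/api/schemas.py;
# only "prompt" is required, the rest have server-side defaults.
payload = {
    "prompt": "a photograph of an astronaut riding a horse",
    "negative_prompt": "blurry, low quality",
    "scheduler": "EulerAncestralDiscreteScheduler",
    "image_height": 512,
    "image_width": 512,
    "num_images": 1,
    "guidance_scale": 7,
    "steps": 50,
    "seed": 42,
}

response = requests.post("http://127.0.0.1:10000/text2img", json=payload)
response.raise_for_status()

# The API returns PNGs as base64 strings (see convert_to_b64_list in diffuzers/api/utils.py).
for idx, b64_image in enumerate(response.json()["images"]):
    image = Image.open(io.BytesIO(base64.b64decode(b64_image)))
    image.save(f"text2img_{idx}.png")
```

The image-based endpoints (`/img2img`, `/instruct-pix2pix`, `/inpainting`) instead take their parameters as query parameters and the image (plus a mask, for inpainting) as multipart file uploads, as defined in `diffuzers/api/main.py`.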
102 |
103 |
104 | ## All CLI Options for running the app:
105 |
106 | ```bash
107 | ❯ diffuzers app --help
108 | usage: diffuzers <command> [<args>] app [-h] [--output OUTPUT] [--share] [--port PORT] [--host HOST]
109 | [--device DEVICE] [--ngrok_key NGROK_KEY]
110 |
111 | ✨ Run diffuzers app
112 |
113 | optional arguments:
114 | -h, --help show this help message and exit
115 | --output OUTPUT Output path is optional, but if provided, all generations will automatically be saved to this
116 | path.
117 | --share Share the app
118 | --port PORT Port to run the app on
119 | --host HOST Host to run the app on
120 | --device DEVICE Device to use, e.g. cpu, cuda, cuda:0, mps (for m1 mac) etc.
121 | --ngrok_key NGROK_KEY
122 | Ngrok key to use for sharing the app. Only required if you want to share the app
123 | ```
124 |
125 | ## All CLI Options for running the api:
126 |
127 | ```bash
128 | ❯ diffuzers api --help
129 | usage: diffuzers <command> [<args>] api [-h] [--output OUTPUT] [--port PORT] [--host HOST] [--device DEVICE]
130 | [--workers WORKERS]
131 |
132 | ✨ Run diffuzers api
133 |
134 | optional arguments:
135 | -h, --help show this help message and exit
136 | --output OUTPUT Output path is optional, but if provided, all generations will automatically be saved to this
137 | path.
138 | --port PORT Port to run the app on
139 | --host HOST Host to run the app on
140 | --device DEVICE Device to use, e.g. cpu, cuda, cuda:0, mps (for m1 mac) etc.
141 | --workers WORKERS Number of workers to use
142 | ```
143 |
144 | ## Using private models from huggingface hub
145 |
146 | If you want to use private models from the huggingface hub, you need to log in using the `huggingface-cli login` command.
147 |
148 | Note: You can also save your generations directly to the huggingface hub if your output path points to a huggingface hub dataset repo and you have push access to that repository. This way, you can save a lot of local disk space.
149 |
--------------------------------------------------------------------------------
/diffuzers.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "!pip install diffuzers -qqqq"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "!diffuzers app --port 10000 --ngrok_key YOUR_NGROK_AUTHTOKEN --share"
19 | ]
20 | }
21 | ],
22 | "metadata": {
23 | "kernelspec": {
24 | "display_name": "ml",
25 | "language": "python",
26 | "name": "python3"
27 | },
28 | "language_info": {
29 | "name": "python",
30 | "version": "3.8.13 (default, Mar 28 2022, 11:38:47) \n[GCC 7.5.0]"
31 | },
32 | "orig_nbformat": 4,
33 | "vscode": {
34 | "interpreter": {
35 | "hash": "a6cdb91f3f78c48bad72cb110b9a247c2b670f553f33c17c34bce0072b5c1357"
36 | }
37 | }
38 | },
39 | "nbformat": 4,
40 | "nbformat_minor": 2
41 | }
42 |
--------------------------------------------------------------------------------
/diffuzers/Home.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import streamlit as st
4 | from loguru import logger
5 |
6 | from diffuzers import utils
7 | from diffuzers.x2image import X2Image
8 |
9 |
10 | def parse_args():
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument(
13 | "--output",
14 | type=str,
15 | required=False,
16 | default=None,
17 | help="Output path",
18 | )
19 | parser.add_argument(
20 | "--device",
21 | type=str,
22 | required=True,
23 | help="Device to use, e.g. cpu, cuda, cuda:0, mps etc.",
24 | )
25 | return parser.parse_args()
26 |
27 |
28 | def x2img_app():
29 | with st.form("x2img_model_form"):
30 | col1, col2 = st.columns(2)
31 | with col1:
32 | model = st.text_input(
33 | "Which model do you want to use?",
34 | value="stabilityai/stable-diffusion-2-base"
35 | if st.session_state.get("x2img_model") is None
36 | else st.session_state.x2img_model,
37 | )
38 | with col2:
39 | custom_pipeline = st.selectbox(
40 | "Custom pipeline",
41 | options=[
42 | "Vanilla",
43 | "Long Prompt Weighting",
44 | ],
45 | index=0 if st.session_state.get("x2img_custom_pipeline") in (None, "Vanilla") else 1,
46 | )
47 |
48 | with st.expander("Textual Inversion (Optional)"):
49 | token_identifier = st.text_input(
50 | "Token identifier",
51 | placeholder=""
52 | if st.session_state.get("textual_inversion_token_identifier") is None
53 | else st.session_state.textual_inversion_token_identifier,
54 | )
55 | embeddings = st.text_input(
56 | "Embeddings",
57 | placeholder="https://huggingface.co/sd-concepts-library/axe-tattoo/resolve/main/learned_embeds.bin"
58 | if st.session_state.get("textual_inversion_embeddings") is None
59 | else st.session_state.textual_inversion_embeddings,
60 | )
61 | submit = st.form_submit_button("Load model")
62 |
63 | if submit:
64 | st.session_state.x2img_model = model
65 | st.session_state.x2img_custom_pipeline = custom_pipeline
66 | st.session_state.textual_inversion_token_identifier = token_identifier
67 | st.session_state.textual_inversion_embeddings = embeddings
68 | cpipe = "lpw_stable_diffusion" if custom_pipeline == "Long Prompt Weighting" else None
69 | with st.spinner("Loading model..."):
70 | x2img = X2Image(
71 | model=model,
72 | device=st.session_state.device,
73 | output_path=st.session_state.output_path,
74 | custom_pipeline=cpipe,
75 | token_identifier=token_identifier,
76 | embeddings_url=embeddings,
77 | )
78 | st.session_state.x2img = x2img
79 | if "x2img" in st.session_state:
80 | st.write(f"Current model: {st.session_state.x2img}")
81 | st.session_state.x2img.app()
82 |
83 |
84 | def run_app():
85 | utils.create_base_page()
86 | x2img_app()
87 |
88 |
89 | if __name__ == "__main__":
90 | args = parse_args()
91 | logger.info(f"Args: {args}")
92 | logger.info(st.session_state)
93 | st.session_state.device = args.device
94 | st.session_state.output_path = args.output
95 | run_app()
96 |
--------------------------------------------------------------------------------
/diffuzers/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from loguru import logger
4 |
5 |
6 | logger.configure(handlers=[dict(sink=sys.stderr, format="> {level:<7} {message}")])
7 |
8 | __version__ = "0.3.5"
9 |
--------------------------------------------------------------------------------
/diffuzers/api/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekkrthakur/diffuzers/71f710a1940c22637561809387b1c8622be1a48b/diffuzers/api/__init__.py
--------------------------------------------------------------------------------
/diffuzers/api/main.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os
3 |
4 | from fastapi import Depends, FastAPI, File, UploadFile
5 | from loguru import logger
6 | from PIL import Image
7 | from starlette.middleware.cors import CORSMiddleware
8 |
9 | from diffuzers.api.schemas import Img2ImgParams, ImgResponse, InpaintingParams, InstructPix2PixParams, Text2ImgParams
10 | from diffuzers.api.utils import convert_to_b64_list
11 | from diffuzers.inpainting import Inpainting
12 | from diffuzers.x2image import X2Image
13 |
14 |
15 | app = FastAPI(
16 | title="diffuzers api",
17 | license_info={
18 | "name": "Apache 2.0",
19 | "url": "https://www.apache.org/licenses/LICENSE-2.0.html",
20 | },
21 | )
22 | app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
23 |
24 |
25 | @app.on_event("startup")
26 | async def startup_event():
27 |
28 | x2img_model = os.environ.get("X2IMG_MODEL")
29 | x2img_pipeline = os.environ.get("X2IMG_PIPELINE")
30 | inpainting_model = os.environ.get("INPAINTING_MODEL")
31 | device = os.environ.get("DEVICE")
32 | output_path = os.environ.get("OUTPUT_PATH")
33 | ti_identifier = os.environ.get("TOKEN_IDENTIFIER", "")
34 | ti_embeddings_url = os.environ.get("TOKEN_EMBEDDINGS_URL", "")
35 | logger.info("@@@@@ Starting Diffuzers API @@@@@ ")
36 | logger.info(f"Text2Image Model: {x2img_model}")
37 | logger.info(f"Text2Image Pipeline: {x2img_pipeline if x2img_pipeline is not None else 'Vanilla'}")
38 | logger.info(f"Inpainting Model: {inpainting_model}")
39 | logger.info(f"Device: {device}")
40 | logger.info(f"Output Path: {output_path}")
41 | logger.info(f"Token Identifier: {ti_identifier}")
42 | logger.info(f"Token Embeddings URL: {ti_embeddings_url}")
43 |
44 | logger.info("Loading x2img model...")
45 | if x2img_model is not None:
46 | app.state.x2img_model = X2Image(
47 | model=x2img_model,
48 | device=device,
49 | output_path=output_path,
50 | custom_pipeline=x2img_pipeline,
51 | token_identifier=ti_identifier,
52 | embeddings_url=ti_embeddings_url,
53 | )
54 | else:
55 | app.state.x2img_model = None
56 | logger.info("Loading inpainting model...")
57 | if inpainting_model is not None:
58 | app.state.inpainting_model = Inpainting(
59 | model=inpainting_model,
60 | device=device,
61 | output_path=output_path,
62 | )
63 | logger.info("API is ready to use!")
64 |
65 |
66 | @app.post("/text2img")
67 | async def text2img(params: Text2ImgParams) -> ImgResponse:
68 | logger.info(f"Params: {params}")
69 | if app.state.x2img_model is None:
70 | return {"error": "x2img model is not loaded"}
71 |
72 | images, _ = app.state.x2img_model.text2img_generate(
73 | params.prompt,
74 | num_images=params.num_images,
75 | steps=params.steps,
76 | seed=params.seed,
77 | negative_prompt=params.negative_prompt,
78 | scheduler=params.scheduler,
79 | image_size=(params.image_height, params.image_width),
80 | guidance_scale=params.guidance_scale,
81 | )
82 | base64images = convert_to_b64_list(images)
83 | return ImgResponse(images=base64images, metadata=params.dict())
84 |
85 |
86 | @app.post("/img2img")
87 | async def img2img(params: Img2ImgParams = Depends(), image: UploadFile = File(...)) -> ImgResponse:
88 | if app.state.x2img_model is None:
89 | return {"error": "x2img model is not loaded"}
90 | image = Image.open(io.BytesIO(image.file.read()))
91 | images, _ = app.state.x2img_model.img2img_generate(
92 | image=image,
93 | prompt=params.prompt,
94 | negative_prompt=params.negative_prompt,
95 | num_images=params.num_images,
96 | steps=params.steps,
97 | seed=params.seed,
98 | scheduler=params.scheduler,
99 | guidance_scale=params.guidance_scale,
100 | strength=params.strength,
101 | )
102 | base64images = convert_to_b64_list(images)
103 | return ImgResponse(images=base64images, metadata=params.dict())
104 |
105 |
106 | @app.post("/instruct-pix2pix")
107 | async def instruct_pix2pix(params: InstructPix2PixParams = Depends(), image: UploadFile = File(...)) -> ImgResponse:
108 | if app.state.x2img_model is None:
109 | return {"error": "x2img model is not loaded"}
110 | image = Image.open(io.BytesIO(image.file.read()))
111 | images, _ = app.state.x2img_model.pix2pix_generate(
112 | image=image,
113 | prompt=params.prompt,
114 | negative_prompt=params.negative_prompt,
115 | num_images=params.num_images,
116 | steps=params.steps,
117 | seed=params.seed,
118 | scheduler=params.scheduler,
119 | guidance_scale=params.guidance_scale,
120 | image_guidance_scale=params.image_guidance_scale,
121 | )
122 | base64images = convert_to_b64_list(images)
123 | return ImgResponse(images=base64images, metadata=params.dict())
124 |
125 |
126 | @app.post("/inpainting")
127 | async def inpainting(
128 | params: InpaintingParams = Depends(), image: UploadFile = File(...), mask: UploadFile = File(...)
129 | ) -> ImgResponse:
130 | if app.state.inpainting_model is None:
131 | return {"error": "inpainting model is not loaded"}
132 | image = Image.open(io.BytesIO(image.file.read()))
133 | mask = Image.open(io.BytesIO(mask.file.read()))
134 | images, _ = app.state.inpainting_model.generate_image(
135 | image=image,
136 | mask=mask,
137 | prompt=params.prompt,
138 | negative_prompt=params.negative_prompt,
139 | scheduler=params.scheduler,
140 | height=params.image_height,
141 | width=params.image_width,
142 | num_images=params.num_images,
143 | guidance_scale=params.guidance_scale,
144 | steps=params.steps,
145 | seed=params.seed,
146 | )
147 | base64images = convert_to_b64_list(images)
148 | return ImgResponse(images=base64images, metadata=params.dict())
149 |
150 |
151 | @app.get("/")
152 | def read_root():
153 | return {"Hello": "World"}
154 |
--------------------------------------------------------------------------------
/diffuzers/api/schemas.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 |
3 | from pydantic import BaseModel, Field
4 |
5 |
6 | class Text2ImgParams(BaseModel):
7 | prompt: str = Field(..., description="Text prompt for the model")
8 | negative_prompt: str = Field(None, description="Negative text prompt for the model")
9 | scheduler: str = Field("EulerAncestralDiscreteScheduler", description="Scheduler to use for the model")
10 | image_height: int = Field(512, description="Image height")
11 | image_width: int = Field(512, description="Image width")
12 | num_images: int = Field(1, description="Number of images to generate")
13 | guidance_scale: float = Field(7, description="Guidance scale")
14 | steps: int = Field(50, description="Number of steps to run the model for")
15 | seed: int = Field(42, description="Seed for the model")
16 |
17 |
18 | class Img2ImgParams(BaseModel):
19 | prompt: str = Field(..., description="Text prompt for the model")
20 | negative_prompt: str = Field(None, description="Negative text prompt for the model")
21 | scheduler: str = Field("EulerAncestralDiscreteScheduler", description="Scheduler to use for the model")
22 | strength: float = Field(0.7, description="Strength")
23 | num_images: int = Field(1, description="Number of images to generate")
24 | guidance_scale: float = Field(7, description="Guidance scale")
25 | steps: int = Field(50, description="Number of steps to run the model for")
26 | seed: int = Field(42, description="Seed for the model")
27 |
28 |
29 | class InstructPix2PixParams(BaseModel):
30 | prompt: str = Field(..., description="Text prompt for the model")
31 | negative_prompt: str = Field(None, description="Negative text prompt for the model")
32 | scheduler: str = Field("EulerAncestralDiscreteScheduler", description="Scheduler to use for the model")
33 | num_images: int = Field(1, description="Number of images to generate")
34 | guidance_scale: float = Field(7, description="Guidance scale")
35 | image_guidance_scale: float = Field(1.5, description="Image guidance scale")
36 | steps: int = Field(50, description="Number of steps to run the model for")
37 | seed: int = Field(42, description="Seed for the model")
38 |
39 |
40 | class ImgResponse(BaseModel):
41 | images: List[str] = Field(..., description="List of images in base64 format")
42 | metadata: Dict = Field(..., description="Metadata")
43 |
44 |
45 | class InpaintingParams(BaseModel):
46 | prompt: str = Field(..., description="Text prompt for the model")
47 | negative_prompt: str = Field(None, description="Negative text prompt for the model")
48 | scheduler: str = Field("EulerAncestralDiscreteScheduler", description="Scheduler to use for the model")
49 | image_height: int = Field(512, description="Image height")
50 | image_width: int = Field(512, description="Image width")
51 | num_images: int = Field(1, description="Number of images to generate")
52 | guidance_scale: float = Field(7, description="Guidance scale")
53 | steps: int = Field(50, description="Number of steps to run the model for")
54 | seed: int = Field(42, description="Seed for the model")
55 |
--------------------------------------------------------------------------------
/diffuzers/api/utils.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import io
3 |
4 |
5 | def convert_to_b64_list(images):
6 | base64images = []
7 | for image in images:
8 | buf = io.BytesIO()
9 | image.save(buf, format="PNG")
10 | byte_im = base64.b64encode(buf.getvalue())
11 | base64images.append(byte_im)
12 | return base64images
13 |
--------------------------------------------------------------------------------
/diffuzers/cli/__init__.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from argparse import ArgumentParser
3 |
4 |
5 | class BaseDiffuzersCommand(ABC):
6 | @staticmethod
7 | @abstractmethod
8 | def register_subcommand(parser: ArgumentParser):
9 | raise NotImplementedError()
10 |
11 | @abstractmethod
12 | def run(self):
13 | raise NotImplementedError()
14 |
--------------------------------------------------------------------------------
/diffuzers/cli/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from .. import __version__
4 | from .run_api import RunDiffuzersAPICommand
5 | from .run_app import RunDiffuzersAppCommand
6 |
7 |
8 | def main():
9 | parser = argparse.ArgumentParser(
10 | "Diffuzers CLI",
11 | usage="diffuzers <command> [<args>]",
12 | epilog="For more information about a command, run: `diffuzers <command> --help`",
13 | )
14 | parser.add_argument("--version", "-v", help="Display diffuzers version", action="store_true")
15 | commands_parser = parser.add_subparsers(help="commands")
16 |
17 | # Register commands
18 | RunDiffuzersAppCommand.register_subcommand(commands_parser)
19 | RunDiffuzersAPICommand.register_subcommand(commands_parser)
20 |
21 | args = parser.parse_args()
22 |
23 | if args.version:
24 | print(__version__)
25 | exit(0)
26 |
27 | if not hasattr(args, "func"):
28 | parser.print_help()
29 | exit(1)
30 |
31 | command = args.func(args)
32 | command.run()
33 |
34 |
35 | if __name__ == "__main__":
36 | main()
37 |
--------------------------------------------------------------------------------
/diffuzers/cli/run_api.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from argparse import ArgumentParser
3 |
4 | import torch
5 |
6 | from . import BaseDiffuzersCommand
7 |
8 |
9 | def run_api_command_factory(args):
10 | return RunDiffuzersAPICommand(
11 | args.output,
12 | args.port,
13 | args.host,
14 | args.device,
15 | args.workers,
16 | args.ssl_certfile,
17 | args.ssl_keyfile,
18 | )
19 |
20 |
21 | class RunDiffuzersAPICommand(BaseDiffuzersCommand):
22 | @staticmethod
23 | def register_subcommand(parser: ArgumentParser):
24 | run_api_parser = parser.add_parser(
25 | "api",
26 | description="✨ Run diffuzers api",
27 | )
28 | run_api_parser.add_argument(
29 | "--output",
30 | type=str,
31 | required=False,
32 | help="Output path is optional, but if provided, all generations will automatically be saved to this path.",
33 | )
34 | run_api_parser.add_argument(
35 | "--port",
36 | type=int,
37 | default=10000,
38 | help="Port to run the app on",
39 | required=False,
40 | )
41 | run_api_parser.add_argument(
42 | "--host",
43 | type=str,
44 | default="127.0.0.1",
45 | help="Host to run the app on",
46 | required=False,
47 | )
48 | run_api_parser.add_argument(
49 | "--device",
50 | type=str,
51 | required=False,
52 | help="Device to use, e.g. cpu, cuda, cuda:0, mps (for m1 mac) etc.",
53 | )
54 | run_api_parser.add_argument(
55 | "--workers",
56 | type=int,
57 | required=False,
58 | default=1,
59 | help="Number of workers to use",
60 | )
61 | run_api_parser.add_argument(
62 | "--ssl_certfile",
63 | type=str,
64 | required=False,
65 | help="the path to your ssl cert",
66 | )
67 | run_api_parser.add_argument(
68 | "--ssl_keyfile",
69 | type=str,
70 | required=False,
71 | help="the path to your ssl key",
72 | )
73 | run_api_parser.set_defaults(func=run_api_command_factory)
74 |
75 | def __init__(self, output, port, host, device, workers, ssl_certfile, ssl_keyfile):
76 | self.output = output
77 | self.port = port
78 | self.host = host
79 | self.device = device
80 | self.workers = workers
81 | self.ssl_certfile = ssl_certfile
82 | self.ssl_keyfile = ssl_keyfile
83 |
84 | if self.device is None:
85 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
86 |
87 | self.port = str(self.port)
88 | self.workers = str(self.workers)
89 |
90 | def run(self):
91 | cmd = [
92 | "uvicorn",
93 | "diffuzers.api.main:app",
94 | "--host",
95 | self.host,
96 | "--port",
97 | self.port,
98 | "--workers",
99 | self.workers,
100 | ]
101 | if self.ssl_certfile is not None:
102 | cmd.extend(["--ssl-certfile", self.ssl_certfile])
103 | if self.ssl_keyfile is not None:
104 | cmd.extend(["--ssl-keyfile", self.ssl_keyfile])
105 |
106 | proc = subprocess.Popen(
107 | cmd,
108 | stdout=subprocess.PIPE,
109 | stderr=subprocess.STDOUT,
110 | shell=False,
111 | universal_newlines=True,
112 | bufsize=1,
113 | )
114 | with proc as p:
115 | try:
116 | for line in p.stdout:
117 | print(line, end="")
118 | except KeyboardInterrupt:
119 | print("Killing api")
120 | p.kill()
121 | p.wait()
122 | raise
123 |
--------------------------------------------------------------------------------
/diffuzers/cli/run_app.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from argparse import ArgumentParser
3 |
4 | import torch
5 | from pyngrok import ngrok
6 |
7 | from . import BaseDiffuzersCommand
8 |
9 |
10 | def run_app_command_factory(args):
11 | return RunDiffuzersAppCommand(
12 | args.output,
13 | args.share,
14 | args.port,
15 | args.host,
16 | args.device,
17 | args.ngrok_key,
18 | )
19 |
20 |
21 | class RunDiffuzersAppCommand(BaseDiffuzersCommand):
22 | @staticmethod
23 | def register_subcommand(parser: ArgumentParser):
24 | run_app_parser = parser.add_parser(
25 | "app",
26 | description="✨ Run diffuzers app",
27 | )
28 | run_app_parser.add_argument(
29 | "--output",
30 | type=str,
31 | required=False,
32 | help="Output path is optional, but if provided, all generations will automatically be saved to this path.",
33 | )
34 | run_app_parser.add_argument(
35 | "--share",
36 | action="store_true",
37 | help="Share the app",
38 | )
39 | run_app_parser.add_argument(
40 | "--port",
41 | type=int,
42 | default=10000,
43 | help="Port to run the app on",
44 | required=False,
45 | )
46 | run_app_parser.add_argument(
47 | "--host",
48 | type=str,
49 | default="127.0.0.1",
50 | help="Host to run the app on",
51 | required=False,
52 | )
53 | run_app_parser.add_argument(
54 | "--device",
55 | type=str,
56 | required=False,
57 | help="Device to use, e.g. cpu, cuda, cuda:0, mps (for m1 mac) etc.",
58 | )
59 | run_app_parser.add_argument(
60 | "--ngrok_key",
61 | type=str,
62 | required=False,
63 | help="Ngrok key to use for sharing the app. Only required if you want to share the app",
64 | )
65 |
66 | run_app_parser.set_defaults(func=run_app_command_factory)
67 |
68 | def __init__(self, output, share, port, host, device, ngrok_key):
69 | self.output = output
70 | self.share = share
71 | self.port = port
72 | self.host = host
73 | self.device = device
74 | self.ngrok_key = ngrok_key
75 |
76 | if self.device is None:
77 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
78 |
79 | if self.share:
80 | if self.ngrok_key is None:
81 | raise ValueError(
82 | "ngrok key is required if you want to share the app. Get it for free from https://dashboard.ngrok.com/get-started/your-authtoken"
83 | )
84 |
85 | def run(self):
86 | # from ..app import Diffuzers
87 |
88 | # print(self.share)
89 | # app = Diffuzers(self.model, self.output).app()
90 | # app.launch(show_api=False, share=self.share, server_port=self.port, server_name=self.host)
91 | import os
92 |
93 | dirname = os.path.dirname(__file__)
94 | filename = os.path.join(dirname, "..", "Home.py")
95 | cmd = [
96 | "streamlit",
97 | "run",
98 | filename,
99 | "--browser.gatherUsageStats",
100 | "false",
101 | "--browser.serverAddress",
102 | self.host,
103 | "--server.port",
104 | str(self.port),
105 | "--theme.base",
106 | "light",
107 | "--",
108 | "--device",
109 | self.device,
110 | ]
111 | if self.output is not None:
112 | cmd.extend(["--output", self.output])
113 |
114 | if self.share:
115 | ngrok.set_auth_token(self.ngrok_key)
116 | public_url = ngrok.connect(self.port).public_url
117 | print(f"Sharing app at {public_url}")
118 |
119 | proc = subprocess.Popen(
120 | cmd,
121 | stdout=subprocess.PIPE,
122 | stderr=subprocess.STDOUT,
123 | shell=False,
124 | universal_newlines=True,
125 | bufsize=1,
126 | )
127 | with proc as p:
128 | try:
129 | for line in p.stdout:
130 | print(line, end="")
131 | except KeyboardInterrupt:
132 | print("Killing streamlit app")
133 | p.kill()
134 | if self.share:
135 | print("Killing ngrok tunnel")
136 | ngrok.kill()
137 | p.wait()
138 | raise
139 |
--------------------------------------------------------------------------------
/diffuzers/clip_interrogator.py:
--------------------------------------------------------------------------------
1 | """
2 | Code taken from: https://github.com/pharmapsychotic/clip-interrogator
3 | """
4 |
5 | import hashlib
6 | import inspect
7 | import math
8 | import os
9 | import pickle
10 | import time
11 | from dataclasses import dataclass
12 | from typing import List
13 |
14 | import numpy as np
15 | import open_clip
16 | import torch
17 | from PIL import Image
18 | from torchvision import transforms
19 | from torchvision.transforms.functional import InterpolationMode
20 | from tqdm import tqdm
21 |
22 | from diffuzers.blip import BLIP_Decoder, blip_decoder
23 |
24 |
25 | @dataclass
26 | class Config:
27 | # models can optionally be passed in directly
28 | blip_model: BLIP_Decoder = None
29 | clip_model = None
30 | clip_preprocess = None
31 |
32 | # blip settings
33 | blip_image_eval_size: int = 384
34 | blip_max_length: int = 32
35 | blip_model_url: str = (
36 | "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth"
37 | )
38 | blip_num_beams: int = 8
39 | blip_offload: bool = False
40 |
41 | # clip settings
42 | clip_model_name: str = "ViT-L-14/openai"
43 | clip_model_path: str = None
44 |
45 | # interrogator settings
46 | cache_path: str = "cache"
47 | chunk_size: int = 2048
48 | data_path: str = os.path.join(os.path.dirname(__file__), "data")
49 | device: str = "cuda" if torch.cuda.is_available() else "cpu"
50 | flavor_intermediate_count: int = 2048
51 | quiet: bool = False # when quiet progress bars are not shown
52 |
53 |
54 | class Interrogator:
55 | def __init__(self, config: Config):
56 | self.config = config
57 | self.device = config.device
58 |
59 | if config.blip_model is None:
60 | if not config.quiet:
61 | print("Loading BLIP model...")
62 | blip_path = os.path.dirname(inspect.getfile(blip_decoder))
63 | configs_path = os.path.join(os.path.dirname(blip_path), "configs")
64 | med_config = os.path.join(configs_path, "med_config.json")
65 | blip_model = blip_decoder(
66 | pretrained=config.blip_model_url,
67 | image_size=config.blip_image_eval_size,
68 | vit="large",
69 | med_config=med_config,
70 | )
71 | blip_model.eval()
72 | blip_model = blip_model.to(config.device)
73 | self.blip_model = blip_model
74 | else:
75 | self.blip_model = config.blip_model
76 |
77 | self.load_clip_model()
78 |
79 | def load_clip_model(self):
80 | start_time = time.time()
81 | config = self.config
82 |
83 | if config.clip_model is None:
84 | if not config.quiet:
85 | print("Loading CLIP model...")
86 |
87 | clip_model_name, clip_model_pretrained_name = config.clip_model_name.split("/", 2)
88 | self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
89 | clip_model_name,
90 | pretrained=clip_model_pretrained_name,
91 | precision="fp16" if config.device == "cuda" else "fp32",
92 | device=config.device,
93 | jit=False,
94 | cache_dir=config.clip_model_path,
95 | )
96 | self.clip_model.to(config.device).eval()
97 | else:
98 | self.clip_model = config.clip_model
99 | self.clip_preprocess = config.clip_preprocess
100 | self.tokenize = open_clip.get_tokenizer(clip_model_name)
101 |
102 | sites = [
103 | "Artstation",
104 | "behance",
105 | "cg society",
106 | "cgsociety",
107 | "deviantart",
108 | "dribble",
109 | "flickr",
110 | "instagram",
111 | "pexels",
112 | "pinterest",
113 | "pixabay",
114 | "pixiv",
115 | "polycount",
116 | "reddit",
117 | "shutterstock",
118 | "tumblr",
119 | "unsplash",
120 | "zbrush central",
121 | ]
122 | trending_list = [site for site in sites]
123 | trending_list.extend(["trending on " + site for site in sites])
124 | trending_list.extend(["featured on " + site for site in sites])
125 | trending_list.extend([site + " contest winner" for site in sites])
126 |
127 | raw_artists = _load_list(config.data_path, "artists.txt")
128 | artists = [f"by {a}" for a in raw_artists]
129 | artists.extend([f"inspired by {a}" for a in raw_artists])
130 |
131 | self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config)
132 | self.flavors = LabelTable(
133 | _load_list(config.data_path, "flavors.txt"), "flavors", self.clip_model, self.tokenize, config
134 | )
135 | self.mediums = LabelTable(
136 | _load_list(config.data_path, "mediums.txt"), "mediums", self.clip_model, self.tokenize, config
137 | )
138 | self.movements = LabelTable(
139 | _load_list(config.data_path, "movements.txt"), "movements", self.clip_model, self.tokenize, config
140 | )
141 | self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config)
142 |
143 | end_time = time.time()
144 | if not config.quiet:
145 | print(f"Loaded CLIP model and data in {end_time-start_time:.2f} seconds.")
146 |
147 | def generate_caption(self, pil_image: Image) -> str:
148 | if self.config.blip_offload:
149 | self.blip_model = self.blip_model.to(self.device)
150 | size = self.config.blip_image_eval_size
151 | gpu_image = (
152 | transforms.Compose(
153 | [
154 | transforms.Resize((size, size), interpolation=InterpolationMode.BICUBIC),
155 | transforms.ToTensor(),
156 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
157 | ]
158 | )(pil_image)
159 | .unsqueeze(0)
160 | .to(self.device)
161 | )
162 |
163 | with torch.no_grad():
164 | caption = self.blip_model.generate(
165 | gpu_image,
166 | sample=False,
167 | num_beams=self.config.blip_num_beams,
168 | max_length=self.config.blip_max_length,
169 | min_length=5,
170 | )
171 | if self.config.blip_offload:
172 | self.blip_model = self.blip_model.to("cpu")
173 | return caption[0]
174 |
175 | def image_to_features(self, image: Image) -> torch.Tensor:
176 | images = self.clip_preprocess(image).unsqueeze(0).to(self.device)
177 | with torch.no_grad(), torch.cuda.amp.autocast():
178 | image_features = self.clip_model.encode_image(images)
179 | image_features /= image_features.norm(dim=-1, keepdim=True)
180 | return image_features
181 |
182 | def interrogate_classic(self, image: Image, max_flavors: int = 3) -> str:
183 | caption = self.generate_caption(image)
184 | image_features = self.image_to_features(image)
185 |
186 | medium = self.mediums.rank(image_features, 1)[0]
187 | artist = self.artists.rank(image_features, 1)[0]
188 | trending = self.trendings.rank(image_features, 1)[0]
189 | movement = self.movements.rank(image_features, 1)[0]
190 | flaves = ", ".join(self.flavors.rank(image_features, max_flavors))
191 |
192 | if caption.startswith(medium):
193 | prompt = f"{caption} {artist}, {trending}, {movement}, {flaves}"
194 | else:
195 | prompt = f"{caption}, {medium} {artist}, {trending}, {movement}, {flaves}"
196 |
197 | return _truncate_to_fit(prompt, self.tokenize)
198 |
199 | def interrogate_fast(self, image: Image, max_flavors: int = 32) -> str:
200 | caption = self.generate_caption(image)
201 | image_features = self.image_to_features(image)
202 | merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
203 | tops = merged.rank(image_features, max_flavors)
204 | return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize)
205 |
206 | def interrogate(self, image: Image, max_flavors: int = 32) -> str:
207 | caption = self.generate_caption(image)
208 | image_features = self.image_to_features(image)
209 |
210 | flaves = self.flavors.rank(image_features, self.config.flavor_intermediate_count)
211 | best_medium = self.mediums.rank(image_features, 1)[0]
212 | best_artist = self.artists.rank(image_features, 1)[0]
213 | best_trending = self.trendings.rank(image_features, 1)[0]
214 | best_movement = self.movements.rank(image_features, 1)[0]
215 |
216 | best_prompt = caption
217 | best_sim = self.similarity(image_features, best_prompt)
218 |
219 | def check(addition: str) -> bool:
220 | nonlocal best_prompt, best_sim
221 | prompt = best_prompt + ", " + addition
222 | sim = self.similarity(image_features, prompt)
223 | if sim > best_sim:
224 | best_sim = sim
225 | best_prompt = prompt
226 | return True
227 | return False
228 |
229 | def check_multi_batch(opts: List[str]):
230 | nonlocal best_prompt, best_sim
231 | prompts = []
232 | for i in range(2 ** len(opts)):
233 | prompt = best_prompt
234 | for bit in range(len(opts)):
235 | if i & (1 << bit):
236 | prompt += ", " + opts[bit]
237 | prompts.append(prompt)
238 |
239 | t = LabelTable(prompts, None, self.clip_model, self.tokenize, self.config)
240 | best_prompt = t.rank(image_features, 1)[0]
241 | best_sim = self.similarity(image_features, best_prompt)
242 |
243 | check_multi_batch([best_medium, best_artist, best_trending, best_movement])
244 |
245 | extended_flavors = set(flaves)
246 | for _ in tqdm(range(max_flavors), desc="Flavor chain", disable=self.config.quiet):
247 | best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors])
248 | flave = best[len(best_prompt) + 2 :]
249 | if not check(flave):
250 | break
251 | if _prompt_at_max_len(best_prompt, self.tokenize):
252 | break
253 | extended_flavors.remove(flave)
254 |
255 | return best_prompt
256 |
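    # How interrogate() searches: check_multi_batch() scores every subset of the
    # four single-best candidates (2**4 = 16 prompts) and keeps the best-scoring
    # combination; the flavor loop above then greedily appends one flavor at a
    # time and stops once CLIP similarity no longer improves or the prompt hits
    # the token limit.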
257 | def rank_top(self, image_features: torch.Tensor, text_array: List[str]) -> str:
258 |         text_tokens = self.tokenize(text_array).to(self.device)
259 | with torch.no_grad(), torch.cuda.amp.autocast():
260 | text_features = self.clip_model.encode_text(text_tokens)
261 | text_features /= text_features.norm(dim=-1, keepdim=True)
262 | similarity = text_features @ image_features.T
263 | return text_array[similarity.argmax().item()]
264 |
265 | def similarity(self, image_features: torch.Tensor, text: str) -> float:
266 | text_tokens = self.tokenize([text]).to(self.device)
267 | with torch.no_grad(), torch.cuda.amp.autocast():
268 | text_features = self.clip_model.encode_text(text_tokens)
269 | text_features /= text_features.norm(dim=-1, keepdim=True)
270 | similarity = text_features @ image_features.T
271 | return similarity[0][0].item()
272 |
273 |
274 | class LabelTable:
275 | def __init__(self, labels: List[str], desc: str, clip_model, tokenize, config: Config):
276 | self.chunk_size = config.chunk_size
277 | self.config = config
278 | self.device = config.device
279 | self.embeds = []
280 | self.labels = labels
281 | self.tokenize = tokenize
282 |
283 | hash = hashlib.sha256(",".join(labels).encode()).hexdigest()
284 |
285 | cache_filepath = None
286 | if config.cache_path is not None and desc is not None:
287 | os.makedirs(config.cache_path, exist_ok=True)
288 | sanitized_name = config.clip_model_name.replace("/", "_").replace("@", "_")
289 | cache_filepath = os.path.join(config.cache_path, f"{sanitized_name}_{desc}.pkl")
290 | if desc is not None and os.path.exists(cache_filepath):
291 | with open(cache_filepath, "rb") as f:
292 | try:
293 | data = pickle.load(f)
294 | if data.get("hash") == hash:
295 | self.labels = data["labels"]
296 | self.embeds = data["embeds"]
297 | except Exception as e:
298 | print(f"Error loading cached table {desc}: {e}")
299 |
300 | if len(self.labels) != len(self.embeds):
301 | self.embeds = []
302 |             chunks = np.array_split(self.labels, max(1, len(self.labels) // config.chunk_size))
303 | for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None, disable=self.config.quiet):
304 | text_tokens = self.tokenize(chunk).to(self.device)
305 | with torch.no_grad(), torch.cuda.amp.autocast():
306 | text_features = clip_model.encode_text(text_tokens)
307 | text_features /= text_features.norm(dim=-1, keepdim=True)
308 | text_features = text_features.half().cpu().numpy()
309 | for i in range(text_features.shape[0]):
310 | self.embeds.append(text_features[i])
311 |
312 | if cache_filepath is not None:
313 | with open(cache_filepath, "wb") as f:
314 | pickle.dump(
315 | {"labels": self.labels, "embeds": self.embeds, "hash": hash, "model": config.clip_model_name},
316 | f,
317 | )
318 |
319 | if self.device == "cpu" or self.device == torch.device("cpu"):
320 | self.embeds = [e.astype(np.float32) for e in self.embeds]
321 |
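    # Embeddings computed in __init__ are cached as
    # <cache_path>/<model>_<desc>.pkl together with a hash of the label list, so
    # editing the label files invalidates the cache automatically; on CPU the
    # half-precision embeddings are widened back to float32.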
322 |     def _rank(self, image_features: torch.Tensor, text_embeds: List[np.ndarray], top_count: int = 1) -> List[int]:
323 | top_count = min(top_count, len(text_embeds))
324 | text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device)
325 | with torch.cuda.amp.autocast():
326 | similarity = image_features @ text_embeds.T
327 | _, top_labels = similarity.float().cpu().topk(top_count, dim=-1)
328 | return [top_labels[0][i].numpy() for i in range(top_count)]
329 |
330 | def rank(self, image_features: torch.Tensor, top_count: int = 1) -> List[str]:
331 | if len(self.labels) <= self.chunk_size:
332 | tops = self._rank(image_features, self.embeds, top_count=top_count)
333 | return [self.labels[i] for i in tops]
334 |
335 | num_chunks = int(math.ceil(len(self.labels) / self.chunk_size))
336 | keep_per_chunk = int(self.chunk_size / num_chunks)
337 |
338 | top_labels, top_embeds = [], []
339 | for chunk_idx in tqdm(range(num_chunks), disable=self.config.quiet):
340 | start = chunk_idx * self.chunk_size
341 | stop = min(start + self.chunk_size, len(self.embeds))
342 | tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk)
343 | top_labels.extend([self.labels[start + i] for i in tops])
344 | top_embeds.extend([self.embeds[start + i] for i in tops])
345 |
346 | tops = self._rank(image_features, top_embeds, top_count=top_count)
347 | return [top_labels[i] for i in tops]
348 |
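    # rank() keeps memory bounded: for tables larger than chunk_size, each chunk
    # contributes its own top candidates (keep_per_chunk of them) and only those
    # survivors are re-ranked together for the final top_count.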
349 |
350 | def _load_list(data_path: str, filename: str) -> List[str]:
351 | with open(os.path.join(data_path, filename), "r", encoding="utf-8", errors="replace") as f:
352 | items = [line.strip() for line in f.readlines()]
353 | return items
354 |
355 |
356 | def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
357 | m = LabelTable([], None, None, None, config)
358 | for table in tables:
359 | m.labels.extend(table.labels)
360 | m.embeds.extend(table.embeds)
361 | return m
362 |
363 |
364 | def _prompt_at_max_len(text: str, tokenize) -> bool:
365 | tokens = tokenize([text])
366 | return tokens[0][-1] != 0
367 |
368 |
369 | def _truncate_to_fit(text: str, tokenize) -> str:
370 | parts = text.split(", ")
371 | new_text = parts[0]
372 | for part in parts[1:]:
373 | if _prompt_at_max_len(new_text + part, tokenize):
374 | break
375 | new_text += ", " + part
376 | return new_text
377 |
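# A minimal usage sketch (illustrative; assumes the Config and Interrogator
# classes defined earlier in this module and a local "example.jpg"):
#
#   ci = Interrogator(Config(clip_model_name="ViT-L-14/openai"))
#   image = Image.open("example.jpg").convert("RGB")
#   print(ci.interrogate(image))          # best mode
#   print(ci.interrogate_fast(image))     # faster, single merged table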
--------------------------------------------------------------------------------
/diffuzers/data/mediums.txt:
--------------------------------------------------------------------------------
1 | a 3D render
2 | a black and white photo
3 | a bronze sculpture
4 | a cartoon
5 | a cave painting
6 | a character portrait
7 | a charcoal drawing
8 | a child's drawing
9 | a color pencil sketch
10 | a colorized photo
11 | a comic book panel
12 | a computer rendering
13 | a cross stitch
14 | a cubist painting
15 | a detailed drawing
16 | a detailed matte painting
17 | a detailed painting
18 | a diagram
19 | a digital painting
20 | a digital rendering
21 | a drawing
22 | a fine art painting
23 | a flemish Baroque
24 | a gouache
25 | a hologram
26 | a hyperrealistic painting
27 | a jigsaw puzzle
28 | a low poly render
29 | a macro photograph
30 | a manga drawing
31 | a marble sculpture
32 | a matte painting
33 | a microscopic photo
34 | a mid-nineteenth century engraving
35 | a minimalist painting
36 | a mosaic
37 | a painting
38 | a pastel
39 | a pencil sketch
40 | a photo
41 | a photocopy
42 | a photorealistic painting
43 | a picture
44 | a pointillism painting
45 | a polaroid photo
46 | a pop art painting
47 | a portrait
48 | a poster
49 | a raytraced image
50 | a renaissance painting
51 | a screenprint
52 | a screenshot
53 | a silk screen
54 | a sketch
55 | a statue
56 | a still life
57 | a stipple
58 | a stock photo
59 | a storybook illustration
60 | a surrealist painting
61 | a surrealist sculpture
62 | a tattoo
63 | a tilt shift photo
64 | a watercolor painting
65 | a wireframe diagram
66 | a woodcut
67 | an abstract drawing
68 | an abstract painting
69 | an abstract sculpture
70 | an acrylic painting
71 | an airbrush painting
72 | an album cover
73 | an ambient occlusion render
74 | an anime drawing
75 | an art deco painting
76 | an art deco sculpture
77 | an engraving
78 | an etching
79 | an illustration of
80 | an impressionist painting
81 | an ink drawing
82 | an oil on canvas painting
83 | an oil painting
84 | an ultrafine detailed painting
85 | chalk art
86 | computer graphics
87 | concept art
88 | cyberpunk art
89 | digital art
90 | egyptian art
91 | graffiti art
92 | lineart
93 | pixel art
94 | poster art
95 | vector art
96 |
--------------------------------------------------------------------------------
/diffuzers/data/movements.txt:
--------------------------------------------------------------------------------
1 | abstract art
2 | abstract expressionism
3 | abstract illusionism
4 | academic art
5 | action painting
6 | aestheticism
7 | afrofuturism
8 | altermodern
9 | american barbizon school
10 | american impressionism
11 | american realism
12 | american romanticism
13 | american scene painting
14 | analytical art
15 | antipodeans
16 | arabesque
17 | arbeitsrat für kunst
18 | art & language
19 | art brut
20 | art deco
21 | art informel
22 | art nouveau
23 | art photography
24 | arte povera
25 | arts and crafts movement
26 | ascii art
27 | ashcan school
28 | assemblage
29 | australian tonalism
30 | auto-destructive art
31 | barbizon school
32 | baroque
33 | bauhaus
34 | bengal school of art
35 | berlin secession
36 | black arts movement
37 | brutalism
38 | classical realism
39 | cloisonnism
40 | cobra
41 | color field
42 | computer art
43 | conceptual art
44 | concrete art
45 | constructivism
46 | context art
47 | crayon art
48 | crystal cubism
49 | cubism
50 | cubo-futurism
51 | cynical realism
52 | dada
53 | danube school
54 | dau-al-set
55 | de stijl
56 | deconstructivism
57 | digital art
58 | ecological art
59 | environmental art
60 | excessivism
61 | expressionism
62 | fantastic realism
63 | fantasy art
64 | fauvism
65 | feminist art
66 | figuration libre
67 | figurative art
68 | figurativism
69 | fine art
70 | fluxus
71 | folk art
72 | funk art
73 | furry art
74 | futurism
75 | generative art
76 | geometric abstract art
77 | german romanticism
78 | gothic art
79 | graffiti
80 | gutai group
81 | happening
82 | harlem renaissance
83 | heidelberg school
84 | holography
85 | hudson river school
86 | hurufiyya
87 | hypermodernism
88 | hyperrealism
89 | impressionism
90 | incoherents
91 | institutional critique
92 | interactive art
93 | international gothic
94 | international typographic style
95 | kinetic art
96 | kinetic pointillism
97 | kitsch movement
98 | land art
99 | les automatistes
100 | les nabis
101 | letterism
102 | light and space
103 | lowbrow
104 | lyco art
105 | lyrical abstraction
106 | magic realism
107 | magical realism
108 | mail art
109 | mannerism
110 | massurrealism
111 | maximalism
112 | metaphysical painting
113 | mingei
114 | minimalism
115 | modern european ink painting
116 | modernism
117 | modular constructivism
118 | naive art
119 | naturalism
120 | neo-dada
121 | neo-expressionism
122 | neo-fauvism
123 | neo-figurative
124 | neo-primitivism
125 | neo-romanticism
126 | neoclassicism
127 | neogeo
128 | neoism
129 | neoplasticism
130 | net art
131 | new objectivity
132 | new sculpture
133 | northwest school
134 | nuclear art
135 | objective abstraction
136 | op art
137 | optical illusion
138 | orphism
139 | panfuturism
140 | paris school
141 | photorealism
142 | pixel art
143 | plasticien
144 | plein air
145 | pointillism
146 | pop art
147 | pop surrealism
148 | post-impressionism
149 | postminimalism
150 | pre-raphaelitism
151 | precisionism
152 | primitivism
153 | private press
154 | process art
155 | psychedelic art
156 | purism
157 | qajar art
158 | quito school
159 | rasquache
160 | rayonism
161 | realism
162 | regionalism
163 | remodernism
164 | renaissance
165 | retrofuturism
166 | rococo
167 | romanesque
168 | romanticism
169 | samikshavad
170 | serial art
171 | shin hanga
172 | shock art
173 | socialist realism
174 | sots art
175 | space art
176 | street art
177 | stuckism
178 | sumatraism
179 | superflat
180 | suprematism
181 | surrealism
182 | symbolism
183 | synchromism
184 | synthetism
185 | sōsaku hanga
186 | tachisme
187 | temporary art
188 | tonalism
189 | toyism
190 | transgressive art
191 | ukiyo-e
192 | underground comix
193 | unilalianism
194 | vancouver school
195 | vanitas
196 | verdadism
197 | video art
198 | viennese actionism
199 | visual art
200 | vorticism
201 |
--------------------------------------------------------------------------------
/diffuzers/gfp_gan.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import os
3 | import shutil
4 | from dataclasses import dataclass
5 | from io import BytesIO
6 | from typing import Optional
7 |
8 | import cv2
9 | import numpy as np
10 | import streamlit as st
11 | import torch
12 | from basicsr.archs.srvgg_arch import SRVGGNetCompact
13 | from gfpgan.utils import GFPGANer
14 | from loguru import logger
15 | from PIL import Image
16 | from realesrgan.utils import RealESRGANer
17 |
18 | from diffuzers import utils
19 |
20 |
21 | @dataclass
22 | class GFPGAN:
23 | device: Optional[str] = None
24 | output_path: Optional[str] = None
25 |
26 | def __str__(self) -> str:
27 | return f"GFPGAN(device={self.device}, output_path={self.output_path})"
28 |
29 | def __post_init__(self):
30 | files = {
31 | "realesr-general-x4v3.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
32 | "v1.2": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth",
33 | "v1.3": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
34 | "v1.4": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
35 | "RestoreFormer": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth",
36 | "CodeFormer": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth",
37 | }
38 | _ = utils.cache_folder()
39 | self.model_paths = {}
40 | for file_key, file in files.items():
41 |             basename = os.path.basename(file)
42 |             output_path = os.path.join(utils.cache_folder(), basename)
43 |             if os.path.exists(output_path):
44 |                 self.model_paths[file_key] = output_path
45 |                 continue
46 |             logger.info(f"Downloading {file_key} from {file}")
47 |             temp_file = utils.download_file(file)
48 | shutil.move(temp_file, output_path)
49 | self.model_paths[file_key] = output_path
50 |
51 | self.model = SRVGGNetCompact(
52 | num_in_ch=3,
53 | num_out_ch=3,
54 | num_feat=64,
55 | num_conv=32,
56 | upscale=4,
57 | act_type="prelu",
58 | )
59 |         model_path = self.model_paths["realesr-general-x4v3.pth"]
60 | half = True if torch.cuda.is_available() else False
61 | self.upsampler = RealESRGANer(
62 | scale=4,
63 | model_path=model_path,
64 | model=self.model,
65 | tile=0,
66 | tile_pad=10,
67 | pre_pad=0,
68 | half=half,
69 | )
70 |
71 | def inference(self, img, version, scale):
72 | # taken from: https://huggingface.co/spaces/Xintao/GFPGAN/blob/main/app.py
73 | # weight /= 100
74 | if scale > 4:
75 | scale = 4 # avoid too large scale value
76 |
77 | file_bytes = np.asarray(bytearray(img.read()), dtype=np.uint8)
78 | img = cv2.imdecode(file_bytes, 1)
79 | # img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
80 | if len(img.shape) == 2: # for gray inputs
81 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
82 |
83 | h, w = img.shape[0:2]
84 | if h < 300:
85 | img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
86 |
87 | if version == "v1.2":
88 | face_enhancer = GFPGANer(
89 | model_path=self.model_paths["v1.2"],
90 | upscale=2,
91 | arch="clean",
92 | channel_multiplier=2,
93 | bg_upsampler=self.upsampler,
94 | )
95 | elif version == "v1.3":
96 | face_enhancer = GFPGANer(
97 | model_path=self.model_paths["v1.3"],
98 | upscale=2,
99 | arch="clean",
100 | channel_multiplier=2,
101 | bg_upsampler=self.upsampler,
102 | )
103 | elif version == "v1.4":
104 | face_enhancer = GFPGANer(
105 | model_path=self.model_paths["v1.4"],
106 | upscale=2,
107 | arch="clean",
108 | channel_multiplier=2,
109 | bg_upsampler=self.upsampler,
110 | )
111 | elif version == "RestoreFormer":
112 | face_enhancer = GFPGANer(
113 | model_path=self.model_paths["RestoreFormer"],
114 | upscale=2,
115 | arch="RestoreFormer",
116 | channel_multiplier=2,
117 | bg_upsampler=self.upsampler,
118 | )
119 | # elif version == 'CodeFormer':
120 | # face_enhancer = GFPGANer(
121 | # model_path='CodeFormer.pth', upscale=2, arch='CodeFormer', channel_multiplier=2, bg_upsampler=upsampler)
122 | try:
123 | # _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True, weight=weight)
124 | _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
125 | except RuntimeError as error:
126 |             logger.error(f"Face enhancement failed: {error}")
127 |
128 | try:
129 | if scale != 2:
130 | interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
131 | h, w = img.shape[0:2]
132 | output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
133 | except Exception as error:
134 |             logger.error(f"Wrong scale input: {error}")
135 |
136 | output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
137 | return output
138 |
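    # inference() above: decode the uploaded bytes with OpenCV, convert grayscale
    # inputs to BGR, pre-upscale very small images, restore faces with the chosen
    # GFPGAN variant (Real-ESRGAN handles the background), optionally rescale to
    # the requested factor, and return an RGB numpy array.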
139 | def app(self):
140 | input_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
141 | if input_image is not None:
142 | st.image(input_image, width=512)
143 | with st.form(key="gfpgan"):
144 | version = st.selectbox("GFPGAN version", ["v1.2", "v1.3", "v1.4", "RestoreFormer"])
145 | scale = st.slider("Scale", 2, 4, 4, 1)
146 | submit = st.form_submit_button("Upscale")
147 | if submit:
148 | if input_image is not None:
149 | with st.spinner("Upscaling image..."):
150 | output_img = self.inference(input_image, version, scale)
151 | st.image(output_img, width=512)
152 | # add image download button
153 | output_img = Image.fromarray(output_img)
154 | buffered = BytesIO()
155 | output_img.save(buffered, format="PNG")
156 | img_str = base64.b64encode(buffered.getvalue()).decode()
157 |                 href = f'<a href="data:image/png;base64,{img_str}" download="output.png">Download Image</a>'
158 | st.markdown(href, unsafe_allow_html=True)
159 |
--------------------------------------------------------------------------------
/diffuzers/gradio_app.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass
3 | from typing import Optional
4 |
5 | import gradio as gr
6 | import torch
7 | from PIL.PngImagePlugin import PngInfo
8 |
9 | from .img2img import Img2Img
10 | from .text2img import Text2Image
11 |
12 |
13 | @dataclass
14 | class Diffuzers:
15 | model: str
16 | output_path: str
17 | img2img_model: Optional[str] = None
18 | inpainting_model: Optional[str] = None
19 |
20 | def __post_init__(self):
21 | device = "cuda" if torch.cuda.is_available() else "cpu"
22 | self.text2img = Text2Image(
23 | model=self.model,
24 | device=device,
25 | )
26 | self.img2img = Img2Img(
27 | model=self.img2img_model,
28 | device=device,
29 | text2img_model=self.text2img.pipeline,
30 | )
31 | os.makedirs(self.output_path, exist_ok=True)
32 | os.makedirs(os.path.join(self.output_path, "text2img"), exist_ok=True)
33 | os.makedirs(os.path.join(self.output_path, "img2img"), exist_ok=True)
34 |
35 | def _text2img_input(self):
36 | with gr.Column():
37 | project_name_text = "Project name (optional; used to save the images, if provided)"
38 | project_name = gr.Textbox(label=project_name_text, lines=1, max_lines=1, placeholder="my_project")
39 | # TODO: add batch support
40 | # with gr.Tabs():
41 | # with gr.TabItem("single"):
42 | prompt = gr.Textbox(label="Prompt", lines=3, max_lines=3)
43 | # with gr.TabItem("batch"):
44 | # prompt = gr.File(file_types=["text"])
45 | # with gr.Tabs():
46 | # with gr.TabItem("single"):
47 | negative_prompt = gr.Textbox(label="Negative prompt (optional)", lines=3, max_lines=3)
48 | # with gr.TabItem("batch"):
49 | # negative_prompt = gr.File(file_types=["text"])
50 | with gr.Column():
51 | available_schedulers = list(self.text2img.compatible_schedulers.keys())
52 | scheduler = gr.Dropdown(choices=available_schedulers, label="Scheduler", value=available_schedulers[0])
53 | image_size = gr.Number(512, label="Image size", precision=0)
54 | guidance_scale = gr.Slider(1, maximum=20, value=7.5, step=0.5, label="Guidance scale")
55 | num_images = gr.Slider(1, 128, 1, label="Number of images per prompt", step=4)
56 | steps = gr.Slider(1, 150, 50, label="Steps")
57 | seed = gr.Slider(minimum=1, step=1, maximum=999999, randomize=True, label="Seed")
58 | generate_button = gr.Button("Generate")
59 | params_dict = {
60 | "prompt": prompt,
61 | "negative_prompt": negative_prompt,
62 | "scheduler": scheduler,
63 | "image_size": image_size,
64 | "guidance_scale": guidance_scale,
65 | "num_images": num_images,
66 | "steps": steps,
67 | "seed": seed,
68 | "project_name": project_name,
69 | "generate_button": generate_button,
70 | }
71 | return params_dict
72 |
73 | def _text2img_output(self):
74 | with gr.Column():
75 | text2img_output = gr.Gallery()
76 | text2img_output.style(grid=[4], container=False)
77 | with gr.Column():
78 | text2img_output_params = gr.Markdown()
79 | params_dict = {
80 | "output": text2img_output,
81 | "markdown": text2img_output_params,
82 | }
83 | return params_dict
84 |
85 | def _text2img_generate(
86 | self, prompt, negative_prompt, scheduler, image_size, guidance_scale, num_images, steps, seed, project_name
87 | ):
88 | output_images = self.text2img.generate_image(
89 | prompt=prompt,
90 | negative_prompt=negative_prompt,
91 | scheduler=scheduler,
92 | image_size=image_size,
93 | guidance_scale=guidance_scale,
94 | steps=steps,
95 | seed=seed,
96 | num_images=num_images,
97 | )
98 | params_used = f" - **Prompt:** {prompt}"
99 | params_used += f"\n - **Negative prompt:** {negative_prompt}"
100 | params_used += f"\n - **Scheduler:** {scheduler}"
101 | params_used += f"\n - **Image size:** {image_size}"
102 | params_used += f"\n - **Guidance scale:** {guidance_scale}"
103 | params_used += f"\n - **Steps:** {steps}"
104 | params_used += f"\n - **Seed:** {seed}"
105 |
106 | if len(project_name.strip()) > 0:
107 | self._save_images(
108 | images=output_images,
109 | metadata=params_used,
110 | folder_name=project_name,
111 | prefix="text2img",
112 | )
113 |
114 | return [output_images, params_used]
115 |
116 | def _img2img_input(self):
117 | with gr.Column():
118 | input_image = gr.Image(source="upload", label="input image | size must match model", type="pil")
119 | strength = gr.Slider(0, 1, 0.8, label="Strength")
120 | available_schedulers = list(self.img2img.compatible_schedulers.keys())
121 | scheduler = gr.Dropdown(choices=available_schedulers, label="Scheduler", value=available_schedulers[0])
122 | image_size = gr.Number(512, label="Image size (image will be resized to this)", precision=0)
123 | guidance_scale = gr.Slider(1, maximum=20, value=7.5, step=0.5, label="Guidance scale")
124 | num_images = gr.Slider(4, 128, 4, label="Number of images", step=4)
125 | steps = gr.Slider(1, 150, 50, label="Steps")
126 | with gr.Column():
127 | project_name_text = "Project name (optional; used to save the images, if provided)"
128 | project_name = gr.Textbox(label=project_name_text, lines=1, max_lines=1, placeholder="my_project")
129 | prompt = gr.Textbox(label="Prompt", lines=3, max_lines=3)
130 | negative_prompt = gr.Textbox(label="Negative prompt (optional)", lines=3, max_lines=3)
131 | seed = gr.Slider(minimum=1, step=1, maximum=999999, randomize=True, label="Seed")
132 | generate_button = gr.Button("Generate")
133 | params_dict = {
134 | "input_image": input_image,
135 | "prompt": prompt,
136 | "negative_prompt": negative_prompt,
137 | "strength": strength,
138 | "scheduler": scheduler,
139 | "image_size": image_size,
140 | "guidance_scale": guidance_scale,
141 | "num_images": num_images,
142 | "steps": steps,
143 | "seed": seed,
144 | "project_name": project_name,
145 | "generate_button": generate_button,
146 | }
147 | return params_dict
148 |
149 | def _img2img_output(self):
150 | with gr.Column():
151 | img2img_output = gr.Gallery()
152 | img2img_output.style(grid=[4], container=False)
153 | with gr.Column():
154 | img2img_output_params = gr.Markdown()
155 | params_dict = {
156 | "output": img2img_output,
157 | "markdown": img2img_output_params,
158 | }
159 | return params_dict
160 |
161 | def _img2img_generate(
162 | self,
163 | input_image,
164 | prompt,
165 | negative_prompt,
166 | strength,
167 | scheduler,
168 | image_size,
169 | guidance_scale,
170 | num_images,
171 | steps,
172 | seed,
173 | project_name,
174 | ):
175 | output_images = self.img2img.generate_image(
176 | image=input_image,
177 | prompt=prompt,
178 | negative_prompt=negative_prompt,
179 | strength=strength,
180 | scheduler=scheduler,
181 | image_size=image_size,
182 | guidance_scale=guidance_scale,
183 | steps=steps,
184 | seed=seed,
185 | num_images=num_images,
186 | )
187 | params_used = f" - **Prompt:** {prompt}"
188 | params_used += f"\n - **Negative prompt:** {negative_prompt}"
189 | params_used += f"\n - **Scheduler:** {scheduler}"
190 | params_used += f"\n - **Strength:** {strength}"
191 | params_used += f"\n - **Image size:** {image_size}"
192 | params_used += f"\n - **Guidance scale:** {guidance_scale}"
193 | params_used += f"\n - **Steps:** {steps}"
194 | params_used += f"\n - **Seed:** {seed}"
195 |
196 | if len(project_name.strip()) > 0:
197 | self._save_images(
198 | images=output_images,
199 | metadata=params_used,
200 | folder_name=project_name,
201 | prefix="img2img",
202 | )
203 |
204 | return [output_images, params_used]
205 |
206 | def _save_images(self, images, metadata, folder_name, prefix):
207 | folder_path = os.path.join(self.output_path, prefix, folder_name)
208 | os.makedirs(folder_path, exist_ok=True)
209 | for idx, image in enumerate(images):
210 | text2img_metadata = PngInfo()
211 | text2img_metadata.add_text(prefix, metadata)
212 | image.save(os.path.join(folder_path, f"{idx}.png"), format="PNG", pnginfo=text2img_metadata)
213 | with open(os.path.join(folder_path, "metadata.txt"), "w") as f:
214 | f.write(metadata)
215 |
216 | def _png_info(self, img):
217 | text2img_md = img.info.get("text2img", "")
218 | img2img_md = img.info.get("img2img", "")
219 | return_text = ""
220 | if len(text2img_md) > 0:
221 | return_text += f"## Text2Img\n{text2img_md}\n"
222 | if len(img2img_md) > 0:
223 | return_text += f"## Img2Img\n{img2img_md}\n"
224 | return return_text
225 |
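    # _save_images() embeds the generation parameters in each PNG as a text chunk
    # keyed "text2img" or "img2img"; _png_info() reads those chunks back from an
    # uploaded image, which is what powers the ImageInfo tab below.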
226 | def app(self):
227 | with gr.Blocks() as demo:
228 | with gr.Blocks():
229 | gr.Markdown("# Diffuzers")
230 | gr.Markdown(f"Text2Img Model: {self.model}")
231 | if self.img2img_model:
232 | gr.Markdown(f"Img2Img Model: {self.img2img_model}")
233 | else:
234 | gr.Markdown(f"Img2Img Model: {self.model}")
235 |
236 | with gr.Tabs():
237 | with gr.TabItem("Text2Image", id="text2image"):
238 | with gr.Row():
239 | text2img_input = self._text2img_input()
240 | with gr.Row():
241 | text2img_output = self._text2img_output()
242 | text2img_input["generate_button"].click(
243 | fn=self._text2img_generate,
244 | inputs=[
245 | text2img_input["prompt"],
246 | text2img_input["negative_prompt"],
247 | text2img_input["scheduler"],
248 | text2img_input["image_size"],
249 | text2img_input["guidance_scale"],
250 | text2img_input["num_images"],
251 | text2img_input["steps"],
252 | text2img_input["seed"],
253 | text2img_input["project_name"],
254 | ],
255 | outputs=[text2img_output["output"], text2img_output["markdown"]],
256 | )
257 | with gr.TabItem("Image2Image", id="img2img"):
258 | with gr.Row():
259 | img2img_input = self._img2img_input()
260 | with gr.Row():
261 | img2img_output = self._img2img_output()
262 | img2img_input["generate_button"].click(
263 | fn=self._img2img_generate,
264 | inputs=[
265 | img2img_input["input_image"],
266 | img2img_input["prompt"],
267 | img2img_input["negative_prompt"],
268 | img2img_input["strength"],
269 | img2img_input["scheduler"],
270 | img2img_input["image_size"],
271 | img2img_input["guidance_scale"],
272 | img2img_input["num_images"],
273 | img2img_input["steps"],
274 | img2img_input["seed"],
275 | img2img_input["project_name"],
276 | ],
277 | outputs=[img2img_output["output"], img2img_output["markdown"]],
278 | )
279 | with gr.TabItem("Inpainting", id="inpainting"):
280 | gr.Markdown("# coming soon!")
281 | with gr.TabItem("ImageInfo", id="imginfo"):
282 | with gr.Column():
283 | img_info_input_file = gr.Image(label="Input image", source="upload", type="pil")
284 | with gr.Column():
285 | img_info_output_md = gr.Markdown()
286 | img_info_input_file.change(
287 | fn=self._png_info,
288 | inputs=[img_info_input_file],
289 | outputs=[img_info_output_md],
290 | )
291 |
292 | return demo
293 |
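# Launch sketch (illustrative only; the model id and output path are
# placeholders, not defaults of this module):
#
#   app = Diffuzers(model="runwayml/stable-diffusion-v1-5", output_path="./outputs")
#   app.app().launch()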
--------------------------------------------------------------------------------
/diffuzers/image_info.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import streamlit as st
4 | from PIL import Image
5 |
6 |
7 | @dataclass
8 | class ImageInfo:
9 | def app(self):
10 | # upload image
11 | uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
12 | if uploaded_file is not None:
13 | # read image using pil
14 | pil_image = Image.open(uploaded_file)
15 | st.image(uploaded_file, use_column_width=True)
16 | image_info = pil_image.info
17 | # display image info
18 | st.write(image_info)
19 |
--------------------------------------------------------------------------------
/diffuzers/img2img.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import json
3 | from dataclasses import dataclass
4 | from io import BytesIO
5 | from typing import Optional, Union
6 |
7 | import requests
8 | import streamlit as st
9 | import torch
10 | from diffusers import (
11 | AltDiffusionImg2ImgPipeline,
12 | AltDiffusionPipeline,
13 | DiffusionPipeline,
14 | StableDiffusionImg2ImgPipeline,
15 | StableDiffusionPipeline,
16 | )
17 | from loguru import logger
18 | from PIL import Image
19 | from PIL.PngImagePlugin import PngInfo
20 |
21 | from diffuzers import utils
22 |
23 |
24 | @dataclass
25 | class Img2Img:
26 | model: Optional[str] = None
27 | device: Optional[str] = None
28 | output_path: Optional[str] = None
29 | text2img_model: Optional[Union[StableDiffusionPipeline, AltDiffusionPipeline]] = None
30 |
31 | def __str__(self) -> str:
32 | return f"Img2Img(model={self.model}, device={self.device}, output_path={self.output_path})"
33 |
34 | def __post_init__(self):
35 | if self.model is not None:
36 | self.text2img_model = DiffusionPipeline.from_pretrained(
37 | self.model,
38 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
39 | )
40 | components = self.text2img_model.components
41 | print(components)
42 | if isinstance(self.text2img_model, StableDiffusionPipeline):
43 | self.pipeline = StableDiffusionImg2ImgPipeline(**components)
44 | elif isinstance(self.text2img_model, AltDiffusionPipeline):
45 | self.pipeline = AltDiffusionImg2ImgPipeline(**components)
46 | else:
47 | raise ValueError("Model type not supported")
48 |
49 | self.pipeline.to(self.device)
50 | self.pipeline.safety_checker = utils.no_safety_checker
51 | self._compatible_schedulers = self.pipeline.scheduler.compatibles
52 | self.scheduler_config = self.pipeline.scheduler.config
53 | self.compatible_schedulers = {scheduler.__name__: scheduler for scheduler in self._compatible_schedulers}
54 |
55 | if self.device == "mps":
56 | self.pipeline.enable_attention_slicing()
57 | # warmup
58 | url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
59 | response = requests.get(url)
60 | init_image = Image.open(BytesIO(response.content)).convert("RGB")
61 | init_image.thumbnail((768, 768))
62 | prompt = "A fantasy landscape, trending on artstation"
63 | _ = self.pipeline(
64 | prompt=prompt,
65 | image=init_image,
66 | strength=0.75,
67 | guidance_scale=7.5,
68 | num_inference_steps=2,
69 | )
70 |
71 | def _set_scheduler(self, scheduler_name):
72 | scheduler = self.compatible_schedulers[scheduler_name].from_config(self.scheduler_config)
73 | self.pipeline.scheduler = scheduler
74 |
75 | def generate_image(
76 | self, prompt, image, strength, negative_prompt, scheduler, num_images, guidance_scale, steps, seed
77 | ):
78 | self._set_scheduler(scheduler)
79 | logger.info(self.pipeline.scheduler)
80 | if self.device == "mps":
81 | generator = torch.manual_seed(seed)
82 | num_images = 1
83 | else:
84 | generator = torch.Generator(device=self.device).manual_seed(seed)
85 | num_images = int(num_images)
86 | output_images = self.pipeline(
87 | prompt=prompt,
88 | image=image,
89 | strength=strength,
90 | negative_prompt=negative_prompt,
91 | num_inference_steps=steps,
92 | guidance_scale=guidance_scale,
93 | num_images_per_prompt=num_images,
94 | generator=generator,
95 | ).images
96 | torch.cuda.empty_cache()
97 | gc.collect()
98 | metadata = {
99 | "prompt": prompt,
100 | "negative_prompt": negative_prompt,
101 | "scheduler": scheduler,
102 | "num_images": num_images,
103 | "guidance_scale": guidance_scale,
104 | "steps": steps,
105 | "seed": seed,
106 | }
107 | metadata = json.dumps(metadata)
108 | _metadata = PngInfo()
109 | _metadata.add_text("img2img", metadata)
110 |
111 | utils.save_images(
112 | images=output_images,
113 | module="img2img",
114 | metadata=metadata,
115 | output_path=self.output_path,
116 | )
117 | return output_images, _metadata
118 |
119 | def app(self):
120 | available_schedulers = list(self.compatible_schedulers.keys())
121 | # if EulerAncestralDiscreteScheduler is available in available_schedulers, move it to the first position
122 | if "EulerAncestralDiscreteScheduler" in available_schedulers:
123 | available_schedulers.insert(
124 | 0, available_schedulers.pop(available_schedulers.index("EulerAncestralDiscreteScheduler"))
125 | )
126 |
127 | input_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
128 | if input_image is not None:
129 | input_image = Image.open(input_image)
130 | st.image(input_image, use_column_width=True)
131 |
132 | # with st.form(key="img2img"):
133 | col1, col2 = st.columns(2)
134 | with col1:
135 | prompt = st.text_area("Prompt", "")
136 | with col2:
137 | negative_prompt = st.text_area("Negative Prompt", "")
138 |
139 | scheduler = st.sidebar.selectbox("Scheduler", available_schedulers, index=0)
140 | guidance_scale = st.sidebar.slider("Guidance scale", 1.0, 40.0, 7.5, 0.5)
141 | strength = st.sidebar.slider("Strength", 0.0, 1.0, 0.8, 0.05)
142 | num_images = st.sidebar.slider("Number of images per prompt", 1, 30, 1, 1)
143 | steps = st.sidebar.slider("Steps", 1, 150, 50, 1)
144 | seed_placeholder = st.sidebar.empty()
145 | seed = seed_placeholder.number_input("Seed", value=42, min_value=1, max_value=999999, step=1)
146 | random_seed = st.sidebar.button("Random seed")
147 | _seed = torch.randint(1, 999999, (1,)).item()
148 | if random_seed:
149 | seed = seed_placeholder.number_input("Seed", value=_seed, min_value=1, max_value=999999, step=1)
150 | # seed = st.sidebar.number_input("Seed", 1, 999999, 1, 1)
151 | sub_col, download_col = st.columns(2)
152 | with sub_col:
153 | submit = st.button("Generate")
154 |
155 | if submit:
156 | with st.spinner("Generating images..."):
157 | output_images, metadata = self.generate_image(
158 | prompt=prompt,
159 | image=input_image,
160 | negative_prompt=negative_prompt,
161 | scheduler=scheduler,
162 | num_images=num_images,
163 | guidance_scale=guidance_scale,
164 | steps=steps,
165 | seed=seed,
166 | strength=strength,
167 | )
168 |
169 | utils.display_and_download_images(output_images, metadata, download_col)
170 |
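# Standalone usage sketch (illustrative; the model id is a placeholder):
#
#   i2i = Img2Img(model="runwayml/stable-diffusion-v1-5", device="cuda", output_path="./outputs")
#   images, png_info = i2i.generate_image(
#       prompt="A fantasy landscape, trending on artstation",
#       image=Image.open("input.jpg").convert("RGB"),
#       strength=0.75,
#       negative_prompt="",
#       scheduler=list(i2i.compatible_schedulers.keys())[0],
#       num_images=1,
#       guidance_scale=7.5,
#       steps=30,
#       seed=42,
#   )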
--------------------------------------------------------------------------------
/diffuzers/inpainting.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import json
3 | import random
4 | from dataclasses import dataclass
5 | from io import BytesIO
6 | from typing import Optional
7 |
8 | import requests
9 | import streamlit as st
10 | import torch
11 | from diffusers import StableDiffusionInpaintPipeline
12 | from loguru import logger
13 | from PIL import Image
14 | from PIL.PngImagePlugin import PngInfo
15 | from streamlit_drawable_canvas import st_canvas
16 |
17 | from diffuzers import utils
18 |
19 |
20 | @dataclass
21 | class Inpainting:
22 | model: Optional[str] = None
23 | device: Optional[str] = None
24 | output_path: Optional[str] = None
25 |
26 | def __str__(self) -> str:
27 | return f"Inpainting(model={self.model}, device={self.device}, output_path={self.output_path})"
28 |
29 | def __post_init__(self):
30 | self.pipeline = StableDiffusionInpaintPipeline.from_pretrained(
31 | self.model,
32 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
33 | use_auth_token=utils.use_auth_token(),
34 | )
35 |
36 | self.pipeline.to(self.device)
37 | self.pipeline.safety_checker = utils.no_safety_checker
38 | self._compatible_schedulers = self.pipeline.scheduler.compatibles
39 | self.scheduler_config = self.pipeline.scheduler.config
40 | self.compatible_schedulers = {scheduler.__name__: scheduler for scheduler in self._compatible_schedulers}
41 |
42 | if self.device == "mps":
43 | self.pipeline.enable_attention_slicing()
44 | # warmup
45 |
46 | def download_image(url):
47 | response = requests.get(url)
48 | return Image.open(BytesIO(response.content)).convert("RGB")
49 |
50 | img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
51 | mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
52 |
53 | init_image = download_image(img_url).resize((512, 512))
54 | mask_image = download_image(mask_url).resize((512, 512))
55 |
56 | prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
57 | _ = self.pipeline(
58 | prompt=prompt,
59 | image=init_image,
60 | mask_image=mask_image,
61 | num_inference_steps=2,
62 | )
63 |
64 | def _set_scheduler(self, scheduler_name):
65 | scheduler = self.compatible_schedulers[scheduler_name].from_config(self.scheduler_config)
66 | self.pipeline.scheduler = scheduler
67 |
68 | def generate_image(
69 | self, prompt, negative_prompt, image, mask, guidance_scale, scheduler, steps, seed, height, width, num_images
70 | ):
71 |
72 | if seed == -1:
73 | # generate random seed
74 | seed = random.randint(0, 999999)
75 |
76 | self._set_scheduler(scheduler)
77 | logger.info(self.pipeline.scheduler)
78 |
79 | if self.device == "mps":
80 | generator = torch.manual_seed(seed)
81 | num_images = 1
82 | else:
83 | generator = torch.Generator(device=self.device).manual_seed(seed)
84 |
85 | output_images = self.pipeline(
86 | prompt=prompt,
87 | negative_prompt=negative_prompt,
88 | image=image,
89 | mask_image=mask,
90 | num_inference_steps=steps,
91 | guidance_scale=guidance_scale,
92 | num_images_per_prompt=num_images,
93 | generator=generator,
94 | height=height,
95 | width=width,
96 | ).images
97 | metadata = {
98 | "prompt": prompt,
99 | "negative_prompt": negative_prompt,
100 | "guidance_scale": guidance_scale,
101 | "scheduler": scheduler,
102 | "steps": steps,
103 | "seed": seed,
104 | }
105 | metadata = json.dumps(metadata)
106 | _metadata = PngInfo()
107 | _metadata.add_text("inpainting", metadata)
108 |
109 | utils.save_images(
110 | images=output_images,
111 | module="inpainting",
112 | metadata=metadata,
113 | output_path=self.output_path,
114 | )
115 |
116 | torch.cuda.empty_cache()
117 | gc.collect()
118 | return output_images, _metadata
119 |
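    # generate_image() above treats seed == -1 as "pick a fresh random seed"; the
    # mask it receives from app() is the white-on-black alpha channel of the
    # drawable canvas, resized to the uploaded image's dimensions.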
120 | def app(self):
121 | stroke_color = "#FFF"
122 | bg_color = "#000"
123 | col1, col2 = st.columns(2)
124 | # with col1:
125 | with col1:
126 | prompt = st.text_area("Prompt", "", key="inpainting_prompt", help="Prompt for the image generation")
127 | # with col2:
128 | negative_prompt = st.text_area(
129 | "Negative Prompt",
130 | "",
131 | key="inpainting_negative_prompt",
132 |                 help="The prompt not to guide image generation. Write things that you don't want to see in the image.",
133 | )
134 | with col2:
135 | uploaded_file = st.file_uploader(
136 | "Image:",
137 | type=["png", "jpg", "jpeg"],
138 | help="Image size must match model's image size. Usually: 512 or 768",
139 | key="inpainting_uploaded_file",
140 | )
141 |
142 | # sidebar options
143 | available_schedulers = list(self.compatible_schedulers.keys())
144 | if "EulerAncestralDiscreteScheduler" in available_schedulers:
145 | available_schedulers.insert(
146 | 0, available_schedulers.pop(available_schedulers.index("EulerAncestralDiscreteScheduler"))
147 | )
148 | scheduler = st.sidebar.selectbox(
149 | "Scheduler",
150 | available_schedulers,
151 | index=0,
152 | key="inpainting_scheduler",
153 | help="Scheduler to use for generation",
154 | )
155 | guidance_scale = st.sidebar.slider(
156 | "Guidance scale",
157 | 1.0,
158 | 40.0,
159 | 7.5,
160 | 0.5,
161 | key="inpainting_guidance_scale",
162 |             help="Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.",
163 | )
164 | num_images = st.sidebar.slider(
165 | "Number of images per prompt",
166 | 1,
167 | 30,
168 | 1,
169 | 1,
170 | key="inpainting_num_images",
171 |             help="Number of images you want to generate. More images require more time and use more GPU memory.",
172 | )
173 | steps = st.sidebar.slider(
174 | "Steps",
175 | 1,
176 | 150,
177 | 50,
178 | 1,
179 | key="inpainting_steps",
180 | help="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.",
181 | )
182 | seed = st.sidebar.number_input(
183 | "Seed",
184 | value=42,
185 | min_value=-1,
186 | max_value=999999,
187 | step=1,
188 |             help="Random seed. Change it for different results with the same parameters. Set to -1 to use a new random seed on every run.",
189 | )
190 |
191 | if uploaded_file is not None:
192 | with col2:
193 | drawing_mode = st.selectbox(
194 | "Drawing tool:", ("freedraw", "rect", "circle"), key="inpainting_drawing_mode"
195 | )
196 | stroke_width = st.slider("Stroke width: ", 1, 25, 8, key="inpainting_stroke_width")
197 |
198 | pil_image = Image.open(uploaded_file).convert("RGB")
199 |                 img_width, img_height = pil_image.size  # PIL's Image.size is (width, height)
200 | canvas_result = st_canvas(
201 | fill_color="rgb(255, 255, 255)",
202 | stroke_width=stroke_width,
203 | stroke_color=stroke_color,
204 | background_color=bg_color,
205 | background_image=pil_image,
206 | update_streamlit=True,
207 | drawing_mode=drawing_mode,
208 | height=768,
209 | width=768,
210 | key="inpainting_canvas",
211 | )
212 |
213 | with col1:
214 | submit = st.button("Generate", key="inpainting_submit")
215 | if (
216 | canvas_result.image_data is not None
217 | and pil_image
218 | and len(canvas_result.json_data["objects"]) > 0
219 | and submit
220 | ):
221 | mask_npy = canvas_result.image_data[:, :, 3]
222 | # convert mask npy to PIL image
223 | mask_pil = Image.fromarray(mask_npy).convert("RGB")
224 | # resize mask to match image size
225 | mask_pil = mask_pil.resize((img_width, img_height), resample=Image.LANCZOS)
226 | with st.spinner("Generating..."):
227 | output_images, metadata = self.generate_image(
228 | prompt=prompt,
229 | negative_prompt=negative_prompt,
230 | image=pil_image,
231 | mask=mask_pil,
232 | guidance_scale=guidance_scale,
233 | scheduler=scheduler,
234 | steps=steps,
235 | seed=seed,
236 | height=img_height,
237 | width=img_width,
238 | num_images=num_images,
239 | )
240 |
241 | utils.display_and_download_images(output_images, metadata)
242 |
--------------------------------------------------------------------------------
/diffuzers/interrogator.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass
3 | from typing import Optional
4 |
5 | import streamlit as st
6 | from clip_interrogator import Config, Interrogator
7 | from huggingface_hub import hf_hub_download
8 | from loguru import logger
9 | from PIL import Image
10 |
11 | from diffuzers import utils
12 |
13 |
14 | @dataclass
15 | class ImageInterrogator:
16 | device: Optional[str] = None
17 | output_path: Optional[str] = None
18 |
19 | def inference(self, model, image, mode):
20 | preprocess_files = [
21 | "ViT-H-14_laion2b_s32b_b79k_artists.pkl",
22 | "ViT-H-14_laion2b_s32b_b79k_flavors.pkl",
23 | "ViT-H-14_laion2b_s32b_b79k_mediums.pkl",
24 | "ViT-H-14_laion2b_s32b_b79k_movements.pkl",
25 | "ViT-H-14_laion2b_s32b_b79k_trendings.pkl",
26 | "ViT-L-14_openai_artists.pkl",
27 | "ViT-L-14_openai_flavors.pkl",
28 | "ViT-L-14_openai_mediums.pkl",
29 | "ViT-L-14_openai_movements.pkl",
30 | "ViT-L-14_openai_trendings.pkl",
31 | ]
32 |
33 | logger.info("Downloading preprocessed cache files...")
34 | for file in preprocess_files:
35 | path = hf_hub_download(repo_id="pharma/ci-preprocess", filename=file, cache_dir=utils.cache_folder())
36 | cache_path = os.path.dirname(path)
37 |
38 | config = Config(cache_path=cache_path, clip_model_path=utils.cache_folder(), clip_model_name=model)
39 | pipeline = Interrogator(config)
40 |
41 | pipeline.config.blip_num_beams = 64
42 | pipeline.config.chunk_size = 2048
43 | pipeline.config.flavor_intermediate_count = 2048 if model == "ViT-L-14/openai" else 1024
44 |
45 | if mode == "best":
46 | prompt = pipeline.interrogate(image)
47 | elif mode == "classic":
48 | prompt = pipeline.interrogate_classic(image)
49 | else:
50 | prompt = pipeline.interrogate_fast(image)
51 | return prompt
52 |
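    # Example call (illustrative; mirrors what app() below does):
    #
    #   prompt = ImageInterrogator().inference(
    #       model="ViT-L-14/openai",
    #       image=Image.open("example.jpg").convert("RGB"),
    #       mode="best",
    #   )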
53 | def app(self):
54 | # upload image
55 | input_image = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
56 | with st.form(key="image_interrogator"):
57 | clip_model = st.selectbox("CLIP Model", ["ViT-L (Best for SD1.X)", "ViT-H (Best for SD2.X)"])
58 | mode = st.selectbox("Mode", ["Best", "Classic"])
59 | submit = st.form_submit_button("Interrogate")
60 | if input_image is not None:
61 | # read image using pil
62 | pil_image = Image.open(input_image).convert("RGB")
63 | if submit:
64 | with st.spinner("Interrogating image..."):
65 | if clip_model == "ViT-L (Best for SD1.X)":
66 | model = "ViT-L-14/openai"
67 | else:
68 | model = "ViT-H-14/laion2b_s32b_b79k"
69 | prompt = self.inference(model, pil_image, mode.lower())
70 | col1, col2 = st.columns(2)
71 | with col1:
72 | st.image(input_image, use_column_width=True)
73 | with col2:
74 | st.write(prompt)
75 |
--------------------------------------------------------------------------------
/diffuzers/pages/1_Inpainting.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 |
3 | from diffuzers import utils
4 | from diffuzers.inpainting import Inpainting
5 |
6 |
7 | def app():
8 | utils.create_base_page()
9 | with st.form("inpainting_model_form"):
10 | model = st.text_input(
11 | "Which model do you want to use for inpainting?",
12 | value="runwayml/stable-diffusion-inpainting"
13 | if st.session_state.get("inpainting_model") is None
14 | else st.session_state.inpainting_model,
15 | )
16 | submit = st.form_submit_button("Load model")
17 | if submit:
18 | st.session_state.inpainting_model = model
19 | with st.spinner("Loading model..."):
20 | inpainting = Inpainting(
21 | model=model,
22 | device=st.session_state.device,
23 | output_path=st.session_state.output_path,
24 | )
25 | st.session_state.inpainting = inpainting
26 | if "inpainting" in st.session_state:
27 | st.write(f"Current model: {st.session_state.inpainting}")
28 | st.session_state.inpainting.app()
29 |
30 |
31 | if __name__ == "__main__":
32 | app()
33 |
--------------------------------------------------------------------------------
/diffuzers/pages/2_Utilities.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 |
3 | from diffuzers import utils
4 | from diffuzers.gfp_gan import GFPGAN
5 | from diffuzers.image_info import ImageInfo
6 | from diffuzers.interrogator import ImageInterrogator
7 | from diffuzers.upscaler import Upscaler
8 |
9 |
10 | def app():
11 | utils.create_base_page()
12 | task = st.selectbox(
13 | "Choose a utility",
14 | [
15 | "ImageInfo",
16 | "SD Upscaler",
17 | "GFPGAN",
18 | "CLIP Interrogator",
19 | ],
20 | )
21 | if task == "ImageInfo":
22 | ImageInfo().app()
23 | elif task == "SD Upscaler":
24 | with st.form("upscaler_model"):
25 | upscaler_model = st.text_input("Model", "stabilityai/stable-diffusion-x4-upscaler")
26 | submit = st.form_submit_button("Load model")
27 | if submit:
28 | with st.spinner("Loading model..."):
29 | ups = Upscaler(
30 | model=upscaler_model,
31 | device=st.session_state.device,
32 | output_path=st.session_state.output_path,
33 | )
34 | st.session_state.ups = ups
35 | if "ups" in st.session_state:
36 | st.write(f"Current model: {st.session_state.ups}")
37 | st.session_state.ups.app()
38 |
39 | elif task == "GFPGAN":
40 | with st.spinner("Loading model..."):
41 | gfpgan = GFPGAN(
42 | device=st.session_state.device,
43 | output_path=st.session_state.output_path,
44 | )
45 | gfpgan.app()
46 | elif task == "CLIP Interrogator":
47 | interrogator = ImageInterrogator(
48 | device=st.session_state.device,
49 | output_path=st.session_state.output_path,
50 | )
51 | interrogator.app()
52 |
53 |
54 | if __name__ == "__main__":
55 | app()
56 |
--------------------------------------------------------------------------------
/diffuzers/pages/3_FAQs.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 |
3 | from diffuzers import utils
4 |
5 |
6 | def app():
7 | utils.create_base_page()
8 | st.markdown("## FAQs")
9 |
10 |
11 | if __name__ == "__main__":
12 | app()
13 |
--------------------------------------------------------------------------------
/diffuzers/pages/4_Code of Conduct.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 |
3 | from diffuzers import utils
4 |
5 |
6 | def app():
7 | utils.create_base_page()
8 | st.markdown(utils.CODE_OF_CONDUCT)
9 |
10 |
11 | if __name__ == "__main__":
12 | app()
13 |
--------------------------------------------------------------------------------
/diffuzers/text2img.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import json
3 | from dataclasses import dataclass
4 | from typing import Optional
5 |
6 | import streamlit as st
7 | import torch
8 | from diffusers import DiffusionPipeline
9 | from loguru import logger
10 | from PIL.PngImagePlugin import PngInfo
11 |
12 | from diffuzers import utils
13 |
14 |
15 | @dataclass
16 | class Text2Image:
17 | device: Optional[str] = None
18 | model: Optional[str] = None
19 | output_path: Optional[str] = None
20 |
21 | def __str__(self) -> str:
22 | return f"Text2Image(model={self.model}, device={self.device}, output_path={self.output_path})"
23 |
24 | def __post_init__(self):
25 | self.pipeline = DiffusionPipeline.from_pretrained(
26 | self.model,
27 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
28 | )
29 | self.pipeline.to(self.device)
30 | self.pipeline.safety_checker = utils.no_safety_checker
31 | self._compatible_schedulers = self.pipeline.scheduler.compatibles
32 | self.scheduler_config = self.pipeline.scheduler.config
33 | self.compatible_schedulers = {scheduler.__name__: scheduler for scheduler in self._compatible_schedulers}
34 |
35 | if self.device == "mps":
36 | self.pipeline.enable_attention_slicing()
37 | # warmup
38 | prompt = "a photo of an astronaut riding a horse on mars"
39 | _ = self.pipeline(prompt, num_inference_steps=2)
40 |
41 | def _set_scheduler(self, scheduler_name):
42 | scheduler = self.compatible_schedulers[scheduler_name].from_config(self.scheduler_config)
43 | self.pipeline.scheduler = scheduler
44 |
45 | def generate_image(self, prompt, negative_prompt, scheduler, image_size, num_images, guidance_scale, steps, seed):
46 | self._set_scheduler(scheduler)
47 | logger.info(self.pipeline.scheduler)
48 | if self.device == "mps":
49 | generator = torch.manual_seed(seed)
50 | num_images = 1
51 | else:
52 | generator = torch.Generator(device=self.device).manual_seed(seed)
53 | num_images = int(num_images)
54 | output_images = self.pipeline(
55 | prompt,
56 | negative_prompt=negative_prompt,
57 | width=image_size[1],
58 | height=image_size[0],
59 | num_inference_steps=steps,
60 | guidance_scale=guidance_scale,
61 | num_images_per_prompt=num_images,
62 | generator=generator,
63 | ).images
64 | torch.cuda.empty_cache()
65 | gc.collect()
66 | metadata = {
67 | "prompt": prompt,
68 | "negative_prompt": negative_prompt,
69 | "scheduler": scheduler,
70 | "image_size": image_size,
71 | "num_images": num_images,
72 | "guidance_scale": guidance_scale,
73 | "steps": steps,
74 | "seed": seed,
75 | }
76 | metadata = json.dumps(metadata)
77 | _metadata = PngInfo()
78 | _metadata.add_text("text2img", metadata)
79 |
80 | utils.save_images(
81 | images=output_images,
82 | module="text2img",
83 | metadata=metadata,
84 | output_path=self.output_path,
85 | )
86 |
87 | return output_images, _metadata
88 |
89 | def app(self):
90 | available_schedulers = list(self.compatible_schedulers.keys())
91 | if "EulerAncestralDiscreteScheduler" in available_schedulers:
92 | available_schedulers.insert(
93 | 0, available_schedulers.pop(available_schedulers.index("EulerAncestralDiscreteScheduler"))
94 | )
95 | # with st.form(key="text2img"):
96 | col1, col2 = st.columns(2)
97 | with col1:
98 | prompt = st.text_area("Prompt", "Blue elephant")
99 | with col2:
100 | negative_prompt = st.text_area("Negative Prompt", "")
101 | # sidebar options
102 | scheduler = st.sidebar.selectbox("Scheduler", available_schedulers, index=0)
103 | image_height = st.sidebar.slider("Image height", 128, 1024, 512, 128)
104 | image_width = st.sidebar.slider("Image width", 128, 1024, 512, 128)
105 | guidance_scale = st.sidebar.slider("Guidance scale", 1.0, 40.0, 7.5, 0.5)
106 | num_images = st.sidebar.slider("Number of images per prompt", 1, 30, 1, 1)
107 | steps = st.sidebar.slider("Steps", 1, 150, 50, 1)
108 |
109 | seed_placeholder = st.sidebar.empty()
110 | seed = seed_placeholder.number_input("Seed", value=42, min_value=1, max_value=999999, step=1)
111 | random_seed = st.sidebar.button("Random seed")
112 | _seed = torch.randint(1, 999999, (1,)).item()
113 | if random_seed:
114 | seed = seed_placeholder.number_input("Seed", value=_seed, min_value=1, max_value=999999, step=1)
115 |
116 | sub_col, download_col = st.columns(2)
117 | with sub_col:
118 | submit = st.button("Generate")
119 |
120 | if submit:
121 | with st.spinner("Generating images..."):
122 | output_images, metadata = self.generate_image(
123 | prompt=prompt,
124 | negative_prompt=negative_prompt,
125 | scheduler=scheduler,
126 | image_size=(image_height, image_width),
127 | num_images=num_images,
128 | guidance_scale=guidance_scale,
129 | steps=steps,
130 | seed=seed,
131 | )
132 |
133 | utils.display_and_download_images(output_images, metadata, download_col)
134 |
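# Standalone usage sketch (illustrative; the model id is a placeholder and
# image_size is (height, width)):
#
#   t2i = Text2Image(model="runwayml/stable-diffusion-v1-5", device="cuda", output_path="./outputs")
#   images, png_info = t2i.generate_image(
#       prompt="Blue elephant",
#       negative_prompt="",
#       scheduler=list(t2i.compatible_schedulers.keys())[0],
#       image_size=(512, 512),
#       num_images=1,
#       guidance_scale=7.5,
#       steps=30,
#       seed=42,
#   )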
--------------------------------------------------------------------------------
/diffuzers/textual_inversion.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import json
3 | from dataclasses import dataclass
4 | from typing import Optional
5 |
6 | import streamlit as st
7 | import torch
8 | from diffusers import DiffusionPipeline
9 | from loguru import logger
10 | from PIL.PngImagePlugin import PngInfo
11 |
12 | from diffuzers import utils
13 |
14 |
15 | def load_embed(learned_embeds_path, text_encoder, tokenizer, token=None):
16 | loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
17 | if len(loaded_learned_embeds) > 2:
18 | embeds = loaded_learned_embeds["string_to_param"]["*"][-1, :]
19 | else:
20 | # separate token and the embeds
21 | trained_token = list(loaded_learned_embeds.keys())[0]
22 | embeds = loaded_learned_embeds[trained_token]
23 |
24 | # add the token in tokenizer
25 | token = token if token is not None else trained_token
26 | num_added_tokens = tokenizer.add_tokens(token)
27 | i = 1
28 | while num_added_tokens == 0:
29 | logger.warning(f"The tokenizer already contains the token {token}.")
30 | token = f"{token[:-1]}-{i}>"
31 | logger.info(f"Attempting to add the token {token}.")
32 | num_added_tokens = tokenizer.add_tokens(token)
33 | i += 1
34 |
35 | # resize the token embeddings
36 | text_encoder.resize_token_embeddings(len(tokenizer))
37 |
38 | # get the id for the token and assign the embeds
39 | token_id = tokenizer.convert_tokens_to_ids(token)
40 | text_encoder.get_input_embeddings().weight.data[token_id] = embeds
41 | return token
42 |
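# load_embed() handles both single-token .bin embeddings ({token: tensor}) and
# Automatic1111-style .pt files (vector stored under loaded["string_to_param"]["*"]);
# if the requested token already exists in the tokenizer, it keeps deriving a new
# name with a numeric suffix until one can be added.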
43 |
44 | @dataclass
45 | class TextualInversion:
46 | model: str
47 | embeddings_url: str
48 | token_identifier: str
49 | device: Optional[str] = None
50 | output_path: Optional[str] = None
51 |
52 | def __str__(self) -> str:
53 | return f"TextualInversion(model={self.model}, embeddings={self.embeddings_url}, token_identifier={self.token_identifier}, device={self.device}, output_path={self.output_path})"
54 |
55 | def __post_init__(self):
56 | self.pipeline = DiffusionPipeline.from_pretrained(
57 | self.model,
58 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
59 | )
60 | self.pipeline.to(self.device)
61 | self.pipeline.safety_checker = utils.no_safety_checker
62 | self._compatible_schedulers = self.pipeline.scheduler.compatibles
63 | self.scheduler_config = self.pipeline.scheduler.config
64 | self.compatible_schedulers = {scheduler.__name__: scheduler for scheduler in self._compatible_schedulers}
65 |
66 | # download the embeddings
67 | self.embeddings_path = utils.download_file(self.embeddings_url)
68 | load_embed(
69 | learned_embeds_path=self.embeddings_path,
70 | text_encoder=self.pipeline.text_encoder,
71 | tokenizer=self.pipeline.tokenizer,
72 | token=self.token_identifier,
73 | )
74 |
75 | if self.device == "mps":
76 | self.pipeline.enable_attention_slicing()
77 | # warmup
78 | prompt = "a photo of an astronaut riding a horse on mars"
79 | _ = self.pipeline(prompt, num_inference_steps=1)
80 |
81 | def _set_scheduler(self, scheduler_name):
82 | scheduler = self.compatible_schedulers[scheduler_name].from_config(self.scheduler_config)
83 | self.pipeline.scheduler = scheduler
84 |
85 | def generate_image(self, prompt, negative_prompt, scheduler, image_size, num_images, guidance_scale, steps, seed):
86 | self._set_scheduler(scheduler)
87 | logger.info(self.pipeline.scheduler)
88 | if self.device == "mps":
89 | generator = torch.manual_seed(seed)
90 | num_images = 1
91 | else:
92 | generator = torch.Generator(device=self.device).manual_seed(seed)
93 | num_images = int(num_images)
94 | output_images = self.pipeline(
95 | prompt,
96 | negative_prompt=negative_prompt,
97 | width=image_size[1],
98 | height=image_size[0],
99 | num_inference_steps=steps,
100 | guidance_scale=guidance_scale,
101 | num_images_per_prompt=num_images,
102 | generator=generator,
103 | ).images
104 | torch.cuda.empty_cache()
105 | gc.collect()
106 | metadata = {
107 | "prompt": prompt,
108 | "negative_prompt": negative_prompt,
109 | "scheduler": scheduler,
110 | "image_size": image_size,
111 | "num_images": num_images,
112 | "guidance_scale": guidance_scale,
113 | "steps": steps,
114 | "seed": seed,
115 | }
116 | metadata = json.dumps(metadata)
117 | _metadata = PngInfo()
118 | _metadata.add_text("textual_inversion", metadata)
119 |
120 | utils.save_images(
121 | images=output_images,
122 | module="textual_inversion",
123 | metadata=metadata,
124 | output_path=self.output_path,
125 | )
126 |
127 | return output_images, _metadata
128 |
129 | def app(self):
130 | available_schedulers = list(self.compatible_schedulers.keys())
131 | if "EulerAncestralDiscreteScheduler" in available_schedulers:
132 | available_schedulers.insert(
133 | 0, available_schedulers.pop(available_schedulers.index("EulerAncestralDiscreteScheduler"))
134 | )
135 | col1, col2 = st.columns(2)
136 | with col1:
137 | prompt = st.text_area("Prompt", "Blue elephant")
138 | with col2:
139 | negative_prompt = st.text_area("Negative Prompt", "")
140 | # sidebar options
141 | scheduler = st.sidebar.selectbox("Scheduler", available_schedulers, index=0)
142 | image_height = st.sidebar.slider("Image height", 128, 1024, 512, 128)
143 | image_width = st.sidebar.slider("Image width", 128, 1024, 512, 128)
144 | guidance_scale = st.sidebar.slider("Guidance scale", 1.0, 40.0, 7.5, 0.5)
145 | num_images = st.sidebar.slider("Number of images per prompt", 1, 30, 1, 1)
146 | steps = st.sidebar.slider("Steps", 1, 150, 50, 1)
147 | seed_placeholder = st.sidebar.empty()
148 | seed = seed_placeholder.number_input("Seed", value=42, min_value=1, max_value=999999, step=1)
149 | random_seed = st.sidebar.button("Random seed")
150 | _seed = torch.randint(1, 999999, (1,)).item()
151 | if random_seed:
152 | seed = seed_placeholder.number_input("Seed", value=_seed, min_value=1, max_value=999999, step=1)
153 | sub_col, download_col = st.columns(2)
154 | with sub_col:
155 | submit = st.button("Generate")
156 |
157 | if submit:
158 | with st.spinner("Generating images..."):
159 | output_images, metadata = self.generate_image(
160 | prompt=prompt,
161 | negative_prompt=negative_prompt,
162 | scheduler=scheduler,
163 | image_size=(image_height, image_width),
164 | num_images=num_images,
165 | guidance_scale=guidance_scale,
166 | steps=steps,
167 | seed=seed,
168 | )
169 |
170 | utils.display_and_download_images(output_images, metadata, download_col)
171 |
--------------------------------------------------------------------------------
/diffuzers/upscaler.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import json
3 | from dataclasses import dataclass
4 | from typing import Optional
5 |
6 | import streamlit as st
7 | import torch
8 | from diffusers import StableDiffusionUpscalePipeline
9 | from loguru import logger
10 | from PIL import Image
11 | from PIL.PngImagePlugin import PngInfo
12 |
13 | from diffuzers import utils
14 |
15 |
16 | @dataclass
17 | class Upscaler:
18 | model: str = "stabilityai/stable-diffusion-x4-upscaler"
19 | device: Optional[str] = None
20 | output_path: Optional[str] = None
21 |
22 | def __str__(self) -> str:
23 | return f"Upscaler(model={self.model}, device={self.device}, output_path={self.output_path})"
24 |
25 | def __post_init__(self):
26 | self.pipeline = StableDiffusionUpscalePipeline.from_pretrained(
27 | self.model,
28 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
29 | )
30 | self.pipeline.to(self.device)
31 | self.pipeline.safety_checker = utils.no_safety_checker
32 | self._compatible_schedulers = self.pipeline.scheduler.compatibles
33 | self.scheduler_config = self.pipeline.scheduler.config
34 | self.compatible_schedulers = {scheduler.__name__: scheduler for scheduler in self._compatible_schedulers}
35 |
36 | def _set_scheduler(self, scheduler_name):
37 | scheduler = self.compatible_schedulers[scheduler_name].from_config(self.scheduler_config)
38 | self.pipeline.scheduler = scheduler
39 |
40 | def generate_image(
41 | self, image, prompt, negative_prompt, guidance_scale, noise_level, num_images, eta, scheduler, steps, seed
42 | ):
43 | self._set_scheduler(scheduler)
44 | logger.info(self.pipeline.scheduler)
45 | if self.device == "mps":
46 | generator = torch.manual_seed(seed)
47 | num_images = 1
48 | else:
49 | generator = torch.Generator(device=self.device).manual_seed(seed)
50 | num_images = int(num_images)
51 | output_images = self.pipeline(
52 | image=image,
53 | prompt=prompt,
54 | negative_prompt=negative_prompt,
55 | noise_level=noise_level,
56 | num_inference_steps=steps,
57 | eta=eta,
58 | num_images_per_prompt=num_images,
59 | generator=generator,
60 | guidance_scale=guidance_scale,
61 | ).images
62 |
63 | torch.cuda.empty_cache()
64 | gc.collect()
65 | metadata = {
66 | "prompt": prompt,
67 | "negative_prompt": negative_prompt,
68 | "noise_level": noise_level,
69 | "num_images": num_images,
70 | "eta": eta,
71 | "scheduler": scheduler,
72 | "steps": steps,
73 | "seed": seed,
74 | }
75 |
76 | metadata = json.dumps(metadata)
77 | _metadata = PngInfo()
78 | _metadata.add_text("upscaler", metadata)
79 |
80 | utils.save_images(
81 | images=output_images,
82 | module="upscaler",
83 | metadata=metadata,
84 | output_path=self.output_path,
85 | )
86 |
87 | return output_images, _metadata
88 |
89 | def app(self):
90 | available_schedulers = list(self.compatible_schedulers.keys())
91 | # if EulerAncestralDiscreteScheduler is available in available_schedulers, move it to the first position
92 | if "EulerAncestralDiscreteScheduler" in available_schedulers:
93 | available_schedulers.insert(
94 | 0, available_schedulers.pop(available_schedulers.index("EulerAncestralDiscreteScheduler"))
95 | )
96 |
97 | input_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
98 | if input_image is not None:
99 | input_image = Image.open(input_image)
100 | input_image = input_image.convert("RGB").resize((128, 128), resample=Image.LANCZOS)
101 | st.image(input_image, use_column_width=True)
102 |
103 | col1, col2 = st.columns(2)
104 | with col1:
105 | prompt = st.text_area("Prompt (Optional)", "")
106 | with col2:
107 | negative_prompt = st.text_area("Negative Prompt (Optional)", "")
108 |
109 | scheduler = st.sidebar.selectbox("Scheduler", available_schedulers, index=0)
110 | guidance_scale = st.sidebar.slider("Guidance scale", 1.0, 40.0, 7.5, 0.5)
111 | noise_level = st.sidebar.slider("Noise level", 0, 100, 20, 1)
112 | eta = st.sidebar.slider("Eta", 0.0, 1.0, 0.0, 0.1)
113 | num_images = st.sidebar.slider("Number of images per prompt", 1, 30, 1, 1)
114 | steps = st.sidebar.slider("Steps", 1, 150, 50, 1)
115 | seed_placeholder = st.sidebar.empty()
116 | seed = seed_placeholder.number_input("Seed", value=42, min_value=1, max_value=999999, step=1)
117 | random_seed = st.sidebar.button("Random seed")
118 | _seed = torch.randint(1, 999999, (1,)).item()
119 | if random_seed:
120 | seed = seed_placeholder.number_input("Seed", value=_seed, min_value=1, max_value=999999, step=1)
121 | sub_col, download_col = st.columns(2)
122 | with sub_col:
123 | submit = st.button("Generate")
124 |
125 | if submit:
126 | with st.spinner("Generating images..."):
127 | output_images, metadata = self.generate_image(
128 | prompt=prompt,
129 | image=input_image,
130 | negative_prompt=negative_prompt,
131 | scheduler=scheduler,
132 | num_images=num_images,
133 | guidance_scale=guidance_scale,
134 | steps=steps,
135 | seed=seed,
136 | noise_level=noise_level,
137 | eta=eta,
138 | )
139 |
140 | utils.display_and_download_images(output_images, metadata, download_col)
141 |
--------------------------------------------------------------------------------
/diffuzers/utils.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import gc
3 | import io
4 | import os
5 | import tempfile
6 | import zipfile
7 | from datetime import datetime
8 | from threading import Thread
9 |
10 | import requests
11 | import streamlit as st
12 | import torch
13 | from huggingface_hub import HfApi
14 | from huggingface_hub.utils._errors import RepositoryNotFoundError
15 | from huggingface_hub.utils._validators import HFValidationError
16 | from loguru import logger
17 | from PIL.PngImagePlugin import PngInfo
18 | from st_clickable_images import clickable_images
19 |
20 |
21 | no_safety_checker = None
22 |
23 |
24 | CODE_OF_CONDUCT = """
25 | ## Code of conduct
26 | The app should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
27 |
28 | Using the app to generate content that is cruel to individuals is a misuse of this app, whether or not the cruelty is obvious to the viewer.
29 | This includes, but is not limited to:
30 | - Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
31 | - Intentionally promoting or propagating discriminatory content or harmful stereotypes.
32 | - Impersonating individuals without their consent.
33 | - Sexual content without consent of the people who might see it.
34 | - Mis- and disinformation.
35 | - Representations of egregious violence and gore.
36 | - Sharing of copyrighted or licensed material in violation of its terms of use.
37 | - Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use.
38 |
39 | By using this app, you agree to the above code of conduct.
40 |
41 | """
42 |
43 |
44 | def use_auth_token():
45 | token_path = os.path.join(os.path.expanduser("~"), ".huggingface", "token")
46 | if os.path.exists(token_path):
47 | return True
48 | if "HF_TOKEN" in os.environ:
49 | return os.environ["HF_TOKEN"]
50 | return False
51 |
52 |
53 | def create_base_page():
54 | st.set_page_config(layout="wide")
55 | st.title("Diffuzers")
56 | st.markdown("Welcome to Diffuzers! A web app for [🤗 Diffusers](https://github.com/huggingface/diffusers)")
57 |
58 |
59 | def download_file(file_url):
60 | r = requests.get(file_url, stream=True)
61 | with tempfile.NamedTemporaryFile(delete=False) as tmp:
62 | for chunk in r.iter_content(chunk_size=1024):
63 | if chunk: # filter out keep-alive new chunks
64 | tmp.write(chunk)
65 | return tmp.name
66 |
67 |
68 | def cache_folder():
69 | _cache_folder = os.path.join(os.path.expanduser("~"), ".diffuzers")
70 | os.makedirs(_cache_folder, exist_ok=True)
71 | return _cache_folder
72 |
73 |
74 | def clear_memory(preserve):
75 | torch.cuda.empty_cache()
76 | gc.collect()
77 | to_clear = ["inpainting", "text2img", "img2text"]
78 | for key in to_clear:
79 | if key not in preserve and key in st.session_state:
80 | del st.session_state[key]
81 |
82 |
83 | def save_to_hub(api, images, module, current_datetime, metadata, output_path):
84 | logger.info(f"Saving images to hub: {output_path}")
85 | _metadata = PngInfo()
86 | _metadata.add_text("text2img", metadata)
87 | for i, img in enumerate(images):
88 | img_byte_arr = io.BytesIO()
89 | img.save(img_byte_arr, format="PNG", pnginfo=_metadata)
90 | img_byte_arr = img_byte_arr.getvalue()
91 | api.upload_file(
92 | path_or_fileobj=img_byte_arr,
93 | path_in_repo=f"{module}/{current_datetime}/{i}.png",
94 | repo_id=output_path,
95 | repo_type="dataset",
96 | )
97 |
98 | api.upload_file(
99 | path_or_fileobj=str.encode(metadata),
100 | path_in_repo=f"{module}/{current_datetime}/metadata.json",
101 | repo_id=output_path,
102 | repo_type="dataset",
103 | )
104 | logger.info(f"Saved images to hub: {output_path}")
105 |
106 |
107 | def save_to_local(images, module, current_datetime, metadata, output_path):
108 | _metadata = PngInfo()
109 | _metadata.add_text("text2img", metadata)
110 | os.makedirs(output_path, exist_ok=True)
111 | os.makedirs(f"{output_path}/{module}", exist_ok=True)
112 | os.makedirs(f"{output_path}/{module}/{current_datetime}", exist_ok=True)
113 |
114 | for i, img in enumerate(images):
115 | img.save(
116 | f"{output_path}/{module}/{current_datetime}/{i}.png",
117 | pnginfo=_metadata,
118 | )
119 |
120 | # save metadata as text file
121 | with open(f"{output_path}/{module}/{current_datetime}/metadata.txt", "w") as f:
122 | f.write(metadata)
123 | logger.info(f"Saved images to {output_path}/{module}/{current_datetime}")
124 |
125 |
126 | def save_images(images, module, metadata, output_path):
127 | if output_path is None:
128 | logger.warning("No output path specified, skipping saving images")
129 | return
130 |
131 | api = HfApi()
132 | dset_info = None
133 | try:
134 | dset_info = api.dataset_info(output_path)
135 | except (HFValidationError, RepositoryNotFoundError):
136 | logger.warning("No valid hugging face repo. Saving locally...")
137 |
138 | current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
139 |
140 | if not dset_info:
141 | save_to_local(images, module, current_datetime, metadata, output_path)
142 | else:
143 | Thread(target=save_to_hub, args=(api, images, module, current_datetime, metadata, output_path)).start()
144 |
145 |
146 | def display_and_download_images(output_images, metadata, download_col=None):
147 | # st.image(output_images, width=128, output_format="PNG")
148 |
149 | with st.spinner("Preparing images for download..."):
150 | # save images to a temporary directory
151 | with tempfile.TemporaryDirectory() as tmpdir:
152 | gallery_images = []
153 | for i, image in enumerate(output_images):
154 | image.save(os.path.join(tmpdir, f"{i + 1}.png"), pnginfo=metadata)
155 | with open(os.path.join(tmpdir, f"{i + 1}.png"), "rb") as img:
156 | encoded = base64.b64encode(img.read()).decode()
157 | gallery_images.append(f"data:image/png;base64,{encoded}")
158 |
159 | # zip the images
160 | zip_path = os.path.join(tmpdir, "images.zip")
161 | with zipfile.ZipFile(zip_path, "w") as zip:
162 | for filename in os.listdir(tmpdir):
163 | if filename.endswith(".png"):
164 | zip.write(os.path.join(tmpdir, filename), filename)
165 |
166 | # convert zip to base64
167 | with open(zip_path, "rb") as f:
168 | encoded = base64.b64encode(f.read()).decode()
169 |
170 | _ = clickable_images(
171 | gallery_images,
172 | titles=[f"Image #{str(i)}" for i in range(len(gallery_images))],
173 | div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
174 | img_style={"margin": "5px", "height": "200px"},
175 | )
176 |
177 | # add download link
178 | st.markdown(
179 | f"""
180 | <a href="data:application/zip;base64,{encoded}" download="images.zip">
181 | Download Images
182 | </a>
183 | """,
184 | unsafe_allow_html=True,
185 | )
186 |
--------------------------------------------------------------------------------
/diffuzers/x2image.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import gc
3 | import json
4 | import os
5 | import random
6 | import tempfile
7 | from dataclasses import dataclass
8 | from io import BytesIO
9 | from typing import Optional
10 |
11 | import requests
12 | import streamlit as st
13 | import torch
14 | from diffusers import (
15 | AltDiffusionImg2ImgPipeline,
16 | AltDiffusionPipeline,
17 | DiffusionPipeline,
18 | StableDiffusionImg2ImgPipeline,
19 | StableDiffusionInstructPix2PixPipeline,
20 | StableDiffusionPipeline,
21 | )
22 | from loguru import logger
23 | from PIL import Image
24 | from PIL.PngImagePlugin import PngInfo
25 | from st_clickable_images import clickable_images
26 |
27 | from diffuzers import utils
28 |
29 |
30 | def load_embed(learned_embeds_path, text_encoder, tokenizer, token=None):
31 | loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
32 | if len(loaded_learned_embeds) > 2:
33 | embeds = loaded_learned_embeds["string_to_param"]["*"][-1, :]
34 | else:
35 | # separate token and the embeds
36 | trained_token = list(loaded_learned_embeds.keys())[0]
37 | embeds = loaded_learned_embeds[trained_token]
38 |
39 | # add the token in tokenizer
40 | token = token if token is not None else trained_token
41 | num_added_tokens = tokenizer.add_tokens(token)
42 | i = 1
43 | while num_added_tokens == 0:
44 | logger.warning(f"The tokenizer already contains the token {token}.")
45 | token = f"{token[:-1]}-{i}>"
46 | logger.info(f"Attempting to add the token {token}.")
47 | num_added_tokens = tokenizer.add_tokens(token)
48 | i += 1
49 |
50 | # resize the token embeddings
51 | text_encoder.resize_token_embeddings(len(tokenizer))
52 |
53 | # get the id for the token and assign the embeds
54 | token_id = tokenizer.convert_tokens_to_ids(token)
55 | text_encoder.get_input_embeddings().weight.data[token_id] = embeds
56 | return token
57 |
58 |
59 | @dataclass
60 | class X2Image:
61 | device: Optional[str] = None
62 | model: Optional[str] = None
63 | output_path: Optional[str] = None
64 | custom_pipeline: Optional[str] = None
65 | embeddings_url: Optional[str] = None
66 | token_identifier: Optional[str] = None
67 |
68 | def __str__(self) -> str:
69 | return f"X2Image(model={self.model}, pipeline={self.custom_pipeline})"
70 |
71 | def __post_init__(self):
72 | self.text2img_pipeline = DiffusionPipeline.from_pretrained(
73 | self.model,
74 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
75 | custom_pipeline=self.custom_pipeline,
76 | use_auth_token=utils.use_auth_token(),
77 | )
78 | components = self.text2img_pipeline.components
79 | self.pix2pix_pipeline = None
80 | if isinstance(self.text2img_pipeline, StableDiffusionPipeline):
81 | self.img2img_pipeline = StableDiffusionImg2ImgPipeline(**components)
82 | self.pix2pix_pipeline = StableDiffusionInstructPix2PixPipeline(**components)
83 | elif isinstance(self.text2img_pipeline, AltDiffusionPipeline):
84 | self.img2img_pipeline = AltDiffusionImg2ImgPipeline(**components)
85 | else:
86 | self.img2img_pipeline = None
87 | logger.error("Model type not supported, img2img pipeline not created")
88 |
89 | self.text2img_pipeline.to(self.device)
90 | self.text2img_pipeline.safety_checker = utils.no_safety_checker
91 | self.img2img_pipeline.to(self.device)
92 | self.img2img_pipeline.safety_checker = utils.no_safety_checker
93 | if self.pix2pix_pipeline is not None:
94 | self.pix2pix_pipeline.to(self.device)
95 | self.pix2pix_pipeline.safety_checker = utils.no_safety_checker
96 |
97 | self.compatible_schedulers = {
98 | scheduler.__name__: scheduler for scheduler in self.text2img_pipeline.scheduler.compatibles
99 | }
100 |
101 | if self.embeddings_url and self.token_identifier:  # skip when no embedding url/token was provided
102 | # download the embeddings
103 | self.embeddings_path = utils.download_file(self.embeddings_url)
104 | load_embed(
105 | learned_embeds_path=self.embeddings_path,
106 | text_encoder=self.text2img_pipeline.text_encoder,
107 | tokenizer=self.text2img_pipeline.tokenizer,
108 | token=self.token_identifier,
109 | )
110 |
111 | if self.device == "mps":
112 | self.text2img_pipeline.enable_attention_slicing()
113 | prompt = "a photo of an astronaut riding a horse on mars"
114 | _ = self.text2img_pipeline(prompt, num_inference_steps=2)
115 |
116 | self.img2img_pipeline.enable_attention_slicing()
117 | url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
118 | response = requests.get(url)
119 | init_image = Image.open(BytesIO(response.content)).convert("RGB")
120 | init_image.thumbnail((768, 768))
121 | prompt = "A fantasy landscape, trending on artstation"
122 | _ = self.img2img_pipeline(
123 | prompt=prompt,
124 | image=init_image,
125 | strength=0.75,
126 | guidance_scale=7.5,
127 | num_inference_steps=2,
128 | )
129 | if self.pix2pix_pipeline is not None:
130 | self.pix2pix_pipeline.enable_attention_slicing()
131 | prompt = "turn him into a cyborg"
132 | _ = self.pix2pix_pipeline(prompt, image=init_image, num_inference_steps=2)
133 |
134 | def _set_scheduler(self, pipeline_name, scheduler_name):
135 | if pipeline_name == "text2img":
136 | scheduler_config = self.text2img_pipeline.scheduler.config
137 | elif pipeline_name == "img2img":
138 | scheduler_config = self.img2img_pipeline.scheduler.config
139 | elif pipeline_name == "pix2pix":
140 | scheduler_config = self.pix2pix_pipeline.scheduler.config
141 | else:
142 | raise ValueError(f"Pipeline {pipeline_name} not supported")
143 |
144 | scheduler = self.compatible_schedulers[scheduler_name].from_config(scheduler_config)
145 |
146 | if pipeline_name == "text2img":
147 | self.text2img_pipeline.scheduler = scheduler
148 | elif pipeline_name == "img2img":
149 | self.img2img_pipeline.scheduler = scheduler
150 | elif pipeline_name == "pix2pix":
151 | self.pix2pix_pipeline.scheduler = scheduler
152 |
151 | def _pregen(self, pipeline_name, scheduler, num_images, seed):
152 | self._set_scheduler(scheduler_name=scheduler, pipeline_name=pipeline_name)
153 | if self.device == "mps":
154 | generator = torch.manual_seed(seed)
155 | num_images = 1
156 | else:
157 | generator = torch.Generator(device=self.device).manual_seed(seed)
158 | num_images = int(num_images)
159 | return generator, num_images
160 |
161 | def _postgen(self, metadata, output_images, pipeline_name):
162 | torch.cuda.empty_cache()
163 | gc.collect()
164 | metadata = json.dumps(metadata)
165 | _metadata = PngInfo()
166 | _metadata.add_text(pipeline_name, metadata)
167 | utils.save_images(
168 | images=output_images,
169 | module=pipeline_name,
170 | metadata=metadata,
171 | output_path=self.output_path,
172 | )
173 | return output_images, _metadata
174 |
175 | def text2img_generate(
176 | self, prompt, negative_prompt, scheduler, image_size, num_images, guidance_scale, steps, seed
177 | ):
178 |
179 | if seed == -1:
180 | # generate random seed
181 | seed = random.randint(0, 999999)
182 |
183 | generator, num_images = self._pregen(
184 | pipeline_name="text2img",
185 | scheduler=scheduler,
186 | num_images=num_images,
187 | seed=seed,
188 | )
189 | output_images = self.text2img_pipeline(
190 | prompt,
191 | negative_prompt=negative_prompt,
192 | width=image_size[1],
193 | height=image_size[0],
194 | num_inference_steps=steps,
195 | guidance_scale=guidance_scale,
196 | num_images_per_prompt=num_images,
197 | generator=generator,
198 | ).images
199 | metadata = {
200 | "prompt": prompt,
201 | "negative_prompt": negative_prompt,
202 | "scheduler": scheduler,
203 | "image_size": image_size,
204 | "num_images": num_images,
205 | "guidance_scale": guidance_scale,
206 | "steps": steps,
207 | "seed": seed,
208 | }
209 |
210 | output_images, _metadata = self._postgen(
211 | metadata=metadata,
212 | output_images=output_images,
213 | pipeline_name="text2img",
214 | )
215 | return output_images, _metadata
216 |
217 | def img2img_generate(
218 | self, prompt, image, strength, negative_prompt, scheduler, num_images, guidance_scale, steps, seed
219 | ):
220 |
221 | if seed == -1:
222 | # generate random seed
223 | seed = random.randint(0, 999999)
224 |
225 | generator, num_images = self._pregen(
226 | pipeline_name="img2img",
227 | scheduler=scheduler,
228 | num_images=num_images,
229 | seed=seed,
230 | )
231 | output_images = self.img2img_pipeline(
232 | prompt=prompt,
233 | image=image,
234 | strength=strength,
235 | negative_prompt=negative_prompt,
236 | num_inference_steps=steps,
237 | guidance_scale=guidance_scale,
238 | num_images_per_prompt=num_images,
239 | generator=generator,
240 | ).images
241 | metadata = {
242 | "prompt": prompt,
243 | "negative_prompt": negative_prompt,
244 | "scheduler": scheduler,
245 | "num_images": num_images,
246 | "guidance_scale": guidance_scale,
247 | "steps": steps,
248 | "seed": seed,
249 | }
250 | output_images, _metadata = self._postgen(
251 | metadata=metadata,
252 | output_images=output_images,
253 | pipeline_name="img2img",
254 | )
255 | return output_images, _metadata
256 |
257 | def pix2pix_generate(
258 | self, prompt, image, negative_prompt, scheduler, num_images, guidance_scale, image_guidance_scale, steps, seed
259 | ):
260 | if seed == -1:
261 | # generate random seed
262 | seed = random.randint(0, 999999)
263 |
264 | generator, num_images = self._pregen(
265 | pipeline_name="pix2pix",
266 | scheduler=scheduler,
267 | num_images=num_images,
268 | seed=seed,
269 | )
270 | output_images = self.pix2pix_pipeline(
271 | prompt=prompt,
272 | image=image,
273 | negative_prompt=negative_prompt,
274 | num_inference_steps=steps,
275 | guidance_scale=guidance_scale,
276 | image_guidance_scale=image_guidance_scale,
277 | num_images_per_prompt=num_images,
278 | generator=generator,
279 | ).images
280 | metadata = {
281 | "prompt": prompt,
282 | "negative_prompt": negative_prompt,
283 | "scheduler": scheduler,
284 | "num_images": num_images,
285 | "guidance_scale": guidance_scale,
286 | "image_guidance_scale": image_guidance_scale,
287 | "steps": steps,
288 | "seed": seed,
289 | }
290 | output_images, _metadata = self._postgen(
291 | metadata=metadata,
292 | output_images=output_images,
293 | pipeline_name="pix2pix",
294 | )
295 | return output_images, _metadata
296 |
297 | def app(self):
298 | available_schedulers = list(self.compatible_schedulers.keys())
299 | if "EulerAncestralDiscreteScheduler" in available_schedulers:
300 | available_schedulers.insert(
301 | 0, available_schedulers.pop(available_schedulers.index("EulerAncestralDiscreteScheduler"))
302 | )
303 | # col3, col4 = st.columns(2)
304 | # with col3:
305 | input_image = st.file_uploader(
306 | "Upload an image to use image2image or pix2pix",
307 | type=["png", "jpg", "jpeg"],
308 | help="Upload an image to use image2image. If left blank, text2image will be used instead.",
309 | )
310 | use_pix2pix = st.checkbox("Use pix2pix", value=False)
311 | if input_image is not None:
312 | input_image = Image.open(input_image)
313 | if use_pix2pix:
314 | pipeline_name = "pix2pix"
315 | else:
316 | pipeline_name = "img2img"
317 | # display image using html
318 | # convert image to base64
319 | # st.markdown(f"
", unsafe_allow_html=True)
320 | # st.image(input_image, use_column_width=True)
321 | with tempfile.TemporaryDirectory() as tmpdir:
322 | gallery_images = []
323 | input_image.save(os.path.join(tmpdir, "img.png"))
324 | with open(os.path.join(tmpdir, "img.png"), "rb") as img:
325 | encoded = base64.b64encode(img.read()).decode()
326 | gallery_images.append(f"data:image/png;base64,{encoded}")
327 |
328 | _ = clickable_images(
329 | gallery_images,
330 | titles=[f"Image #{str(i)}" for i in range(len(gallery_images))],
331 | div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
332 | img_style={"margin": "5px", "height": "200px"},
333 | )
334 | else:
335 | pipeline_name = "text2img"
336 | # prompt = st.text_area("Prompt", "Blue elephant")
337 | # negative_prompt = st.text_area("Negative Prompt", "")
338 | # with col4:
339 | col1, col2 = st.columns(2)
340 | with col1:
341 | prompt = st.text_area("Prompt", "Blue elephant", help="Prompt to guide image generation")
342 | with col2:
343 | negative_prompt = st.text_area(
344 | "Negative Prompt",
345 | "",
346 | help="The prompt not to guide image generation. Write things that you dont want to see in the image.",
347 | )
348 | # sidebar options
349 | if input_image is None:
350 | image_height = st.sidebar.slider(
351 | "Image height", 128, 1024, 512, 128, help="The height in pixels of the generated image."
352 | )
353 | image_width = st.sidebar.slider(
354 | "Image width", 128, 1024, 512, 128, help="The width in pixels of the generated image."
355 | )
356 |
357 | num_images = st.sidebar.slider(
358 | "Number of images per prompt",
359 | 1,
360 | 30,
361 | 1,
362 | 1,
363 | help="Number of images you want to generate. More images requires more time and uses more GPU memory.",
364 | )
365 |
366 | # add section advanced options
367 | st.sidebar.markdown("### Advanced options")
368 | scheduler = st.sidebar.selectbox(
369 | "Scheduler", available_schedulers, index=0, help="Scheduler to use for generation"
370 | )
371 | guidance_scale = st.sidebar.slider(
372 | "Guidance scale",
373 | 1.0,
374 | 40.0,
375 | 7.5,
376 | 0.5,
377 | help="Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.",
378 | )
379 | if use_pix2pix and input_image is not None:
380 | image_guidance_scale = st.sidebar.slider(
381 | "Image guidance scale",
382 | 1.0,
383 | 40.0,
384 | 1.5,
385 | 0.5,
386 | help="Image guidance scale is to push the generated image towards the inital image `image`. Image guidance scale is enabled by setting `image_guidance_scale > 1`. Higher image guidance scale encourages to generate images that are closely linked to the source image `image`, usually at the expense of lower image quality.",
387 | )
388 | if input_image is not None and not use_pix2pix:
389 | strength = st.sidebar.slider(
390 | "Denoising strength",
391 | 0.0,
392 | 1.0,
393 | 0.8,
394 | 0.05,
395 | help="Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.",
396 | )
397 | steps = st.sidebar.slider(
398 | "Steps",
399 | 1,
400 | 150,
401 | 50,
402 | 1,
403 | help="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.",
404 | )
405 | seed = st.sidebar.number_input(
406 | "Seed",
407 | value=42,
408 | min_value=-1,
409 | max_value=999999,
410 | step=1,
411 | help="Random seed. Change for different results using same parameters.",
412 | )
413 |
414 | sub_col, download_col = st.columns(2)
415 | with sub_col:
416 | submit = st.button("Generate")
417 |
418 | if submit:
419 | with st.spinner("Generating images..."):
420 | if pipeline_name == "text2img":
421 | output_images, metadata = self.text2img_generate(
422 | prompt=prompt,
423 | negative_prompt=negative_prompt,
424 | scheduler=scheduler,
425 | image_size=(image_height, image_width),
426 | num_images=num_images,
427 | guidance_scale=guidance_scale,
428 | steps=steps,
429 | seed=seed,
430 | )
431 | elif pipeline_name == "img2img":
432 | output_images, metadata = self.img2img_generate(
433 | prompt=prompt,
434 | image=input_image,
435 | strength=strength,
436 | negative_prompt=negative_prompt,
437 | scheduler=scheduler,
438 | num_images=num_images,
439 | guidance_scale=guidance_scale,
440 | steps=steps,
441 | seed=seed,
442 | )
443 | elif pipeline_name == "pix2pix":
444 | output_images, metadata = self.pix2pix_generate(
445 | prompt=prompt,
446 | image=input_image,
447 | negative_prompt=negative_prompt,
448 | scheduler=scheduler,
449 | num_images=num_images,
450 | guidance_scale=guidance_scale,
451 | image_guidance_scale=image_guidance_scale,
452 | steps=steps,
453 | seed=seed,
454 | )
455 | utils.display_and_download_images(output_images, metadata, download_col)
456 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | # import os
14 | # import sys
15 | # sys.path.insert(0, os.path.abspath('.'))
16 |
17 |
18 | # -- Project information -----------------------------------------------------
19 |
20 | project = "diffuzers: webapp and api for 🤗 diffusers"
21 | copyright = "2023, abhishek thakur"
22 | author = "abhishek thakur"
23 |
24 |
25 | # -- General configuration ---------------------------------------------------
26 |
27 | # Add any Sphinx extension module names here, as strings. They can be
28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
29 | # ones.
30 | extensions = []
31 |
32 | # Add any paths that contain templates here, relative to this directory.
33 | templates_path = ["_templates"]
34 |
35 | # List of patterns, relative to source directory, that match files and
36 | # directories to ignore when looking for source files.
37 | # This pattern also affects html_static_path and html_extra_path.
38 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
39 |
40 |
41 | # -- Options for HTML output -------------------------------------------------
42 |
43 | # The theme to use for HTML and HTML Help pages. See the documentation for
44 | # a list of builtin themes.
45 | #
46 | html_theme = "alabaster"
47 |
48 | # Add any paths that contain custom static files (such as style sheets) here,
49 | # relative to this directory. They are copied after the builtin static files,
50 | # so a file named "default.css" will overwrite the builtin "default.css".
51 | html_static_path = ["_static"]
52 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. diffuzers: webapp and api for 🤗 diffusers documentation master file, created by
2 | sphinx-quickstart on Thu Jan 12 14:10:58 2023.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to diffuzers' documentation!
7 | ======================================================================
8 |
9 | Diffuzers offers a web app and an API for 🤗 diffusers. Installation is simple.
10 | You can install it via pip:
11 |
12 | .. code-block:: bash
13 |
14 | pip install diffuzers
15 |
16 | .. toctree::
17 | :maxdepth: 2
18 | :caption: Contents:
19 |
20 |
21 |
22 | Indices and tables
23 | ==================
24 |
25 | * :ref:`genindex`
26 | * :ref:`modindex`
27 | * :ref:`search`
28 |
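29 | Python usage
30 | ============
31 |
32 | The same pipeline wrappers that power the web app can also be driven directly from
33 | Python. The snippet below is a minimal, untested sketch based on the ``X2Image``
34 | class in ``diffuzers/x2image.py``; the checkpoint id is only an example and any
35 | Stable Diffusion model from the Hugging Face Hub should work.
36 |
37 | .. code-block:: python
38 |
39 |     # minimal sketch, assuming a CUDA device and an example checkpoint id
40 |     from diffuzers.x2image import X2Image
41 |
42 |     x2img = X2Image(
43 |         model="runwayml/stable-diffusion-v1-5",  # example model id (assumption)
44 |         device="cuda",
45 |         output_path=None,       # None skips saving; a local dir or HF dataset repo enables it
46 |         custom_pipeline=None,
47 |         embeddings_url="",      # no textual inversion embedding
48 |         token_identifier="",
49 |     )
50 |     images, png_info = x2img.text2img_generate(
51 |         prompt="a photo of an astronaut riding a horse on mars",
52 |         negative_prompt="",
53 |         scheduler="EulerAncestralDiscreteScheduler",
54 |         image_size=(512, 512),
55 |         num_images=1,
56 |         guidance_scale=7.5,
57 |         steps=30,
58 |         seed=42,
59 |     )
60 |     images[0].save("astronaut.png", pnginfo=png_info)
61 |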
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.https://www.sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.15.0
2 | basicsr>=1.4.2
3 | diffusers>=0.12.1
4 | facexlib>=0.2.5
5 | fairscale==0.4.4
6 | fastapi==0.88.0
7 | gfpgan>=1.3.7
8 | huggingface_hub>=0.11.1
9 | loguru==0.6.0
10 | open_clip_torch==2.9.1
11 | opencv-python
12 | protobuf==3.20
13 | pyngrok==5.2.1
14 | python-multipart==0.0.5
15 | realesrgan>=0.2.5
16 | streamlit==1.16.0
17 | streamlit-drawable-canvas==0.9.0
18 | st-clickable-images==0.0.3
19 | timm==0.4.12
20 | torch>=1.12.0
21 | torchvision>=0.13.0
22 | transformers==4.25.1
23 | uvicorn==0.15.0
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | version = attr: diffuzers.__version__
3 |
4 | [isort]
5 | ensure_newline_before_comments = True
6 | force_grid_wrap = 0
7 | include_trailing_comma = True
8 | line_length = 119
9 | lines_after_imports = 2
10 | multi_line_output = 3
11 | use_parentheses = True
12 |
13 | [flake8]
14 | ignore = E203, E501, W503
15 | max-line-length = 119
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Lint as: python3
2 | # pylint: enable=line-too-long
3 | """Hugging Face Competitions
4 | """
5 | from setuptools import find_packages, setup
6 |
7 |
8 | with open("README.md") as f:
9 | long_description = f.read()
10 |
11 | QUALITY_REQUIRE = [
12 | "black~=22.0",
13 | "isort==5.8.0",
14 | "flake8==3.9.2",
15 | "mypy==0.901",
16 | ]
17 |
18 | TEST_REQUIRE = ["pytest", "pytest-cov"]
19 |
20 | EXTRAS_REQUIRE = {
21 | "dev": QUALITY_REQUIRE,
22 | "quality": QUALITY_REQUIRE,
23 | "test": TEST_REQUIRE,
24 | "docs": [
25 | "recommonmark",
26 | "sphinx==3.1.2",
27 | "sphinx-markdown-tables",
28 | "sphinx-rtd-theme==0.4.3",
29 | "sphinx-copybutton",
30 | ],
31 | }
32 |
33 | with open("requirements.txt") as f:
34 | INSTALL_REQUIRES = f.read().splitlines()
35 |
36 | setup(
37 | name="diffuzers",
38 | description="diffuzers",
39 | long_description=long_description,
40 | long_description_content_type="text/markdown",
41 | author="Abhishek Thakur",
42 | url="https://github.com/abhishekkrthakur/diffuzers",
43 | packages=find_packages("."),
44 | entry_points={"console_scripts": ["diffuzers=diffuzers.cli.main:main"]},
45 | install_requires=INSTALL_REQUIRES,
46 | extras_require=EXTRAS_REQUIRE,
47 | python_requires=">=3.7",
48 | classifiers=[
49 | "Intended Audience :: Developers",
50 | "Intended Audience :: Education",
51 | "Intended Audience :: Science/Research",
52 | "License :: OSI Approved :: Apache Software License",
53 | "Operating System :: OS Independent",
54 | "Programming Language :: Python :: 3.8",
55 | "Programming Language :: Python :: 3.9",
56 | "Programming Language :: Python :: 3.10",
57 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
58 | ],
59 | keywords="diffuzers diffusers",
60 | include_package_data=True,
61 | )
62 |
--------------------------------------------------------------------------------
/static/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekkrthakur/diffuzers/71f710a1940c22637561809387b1c8622be1a48b/static/.keep
--------------------------------------------------------------------------------
/static/screenshot.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekkrthakur/diffuzers/71f710a1940c22637561809387b1c8622be1a48b/static/screenshot.jpeg
--------------------------------------------------------------------------------
/static/screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekkrthakur/diffuzers/71f710a1940c22637561809387b1c8622be1a48b/static/screenshot.png
--------------------------------------------------------------------------------
/static/screenshot_st.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekkrthakur/diffuzers/71f710a1940c22637561809387b1c8622be1a48b/static/screenshot_st.png
--------------------------------------------------------------------------------