├── .github
│   └── FUNDING.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── Makefile
├── OTHER_LICENSES
├── README.md
├── diffuzers.ipynb
├── docs
│   ├── Makefile
│   ├── conf.py
│   ├── index.rst
│   └── make.bat
├── requirements.txt
├── setup.cfg
├── setup.py
├── stablefusion
│   ├── Home.py
│   ├── __init__.py
│   ├── api
│   │   ├── __init__.py
│   │   ├── main.py
│   │   ├── schemas.py
│   │   └── utils.py
│   ├── cli
│   │   ├── __init__.py
│   │   ├── main.py
│   │   ├── run_api.py
│   │   └── run_app.py
│   ├── data
│   │   ├── artists.txt
│   │   ├── flavors.txt
│   │   ├── mediums.txt
│   │   └── movements.txt
│   ├── model_list.txt
│   ├── models
│   │   ├── ckpt_interference
│   │   │   └── t
│   │   ├── cpkt_models
│   │   │   └── t
│   │   ├── diffusion_models
│   │   │   └── t.txt
│   │   ├── realesrgan
│   │   │   ├── inference_realesrgan.py
│   │   │   ├── options
│   │   │   │   ├── finetune_realesrgan_x4plus.yml
│   │   │   │   ├── finetune_realesrgan_x4plus_pairdata.yml
│   │   │   │   ├── train_realesrgan_x2plus.yml
│   │   │   │   ├── train_realesrgan_x4plus.yml
│   │   │   │   ├── train_realesrnet_x2plus.yml
│   │   │   │   └── train_realesrnet_x4plus.yml
│   │   │   ├── realesrgan
│   │   │   │   ├── __init__.py
│   │   │   │   ├── archs
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── discriminator_arch.py
│   │   │   │   │   └── srvgg_arch.py
│   │   │   │   ├── data
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── realesrgan_dataset.py
│   │   │   │   │   └── realesrgan_paired_dataset.py
│   │   │   │   ├── models
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── realesrgan_model.py
│   │   │   │   │   └── realesrnet_model.py
│   │   │   │   ├── train.py
│   │   │   │   └── utils.py
│   │   │   └── weights
│   │   │       └── README.md
│   │   └── safetensors_models
│   │       └── t
│   ├── pages
│   │   ├── 10_Utilities.py
│   │   ├── 1_Text2Image.py
│   │   ├── 2_Image2Image.py
│   │   ├── 3_Inpainting.py
│   │   ├── 4_ControlNet.py
│   │   ├── 5_OpenPose_Editor.py
│   │   ├── 6_Textual Inversion.py
│   │   ├── 7_Upscaler.py
│   │   ├── 8_Convertor.py
│   │   └── 9_Train.py
│   ├── scripts
│   │   ├── blip.py
│   │   ├── ckpt_to_diffusion.py
│   │   ├── clip_interrogator.py
│   │   ├── controlnet.py
│   │   ├── dreambooth
│   │   │   ├── README.md
│   │   │   ├── convert_weights_to_ckpt.py
│   │   │   ├── requirements_flax.txt
│   │   │   ├── train_dreambooth.py
│   │   │   ├── train_dreambooth_flax.py
│   │   │   └── train_dreambooth_lora.py
│   │   ├── gfp_gan.py
│   │   ├── gradio_app.py
│   │   ├── image_info.py
│   │   ├── img2img.py
│   │   ├── inpainting.py
│   │   ├── interrogator.py
│   │   ├── model_adding.py
│   │   ├── model_removing.py
│   │   ├── pose_editor.py
│   │   ├── pose_html.py
│   │   ├── safetensors_to_diffusion.py
│   │   ├── text2img.py
│   │   ├── textual_inversion.py
│   │   ├── upscaler.py
│   │   └── x2image.py
│   └── utils.py
└── static
    ├── .keep
    ├── Screenshot1.png
    ├── Screenshot2.png
    ├── Screenshot3.png
    └── screenshot4.png
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | github: everydaycodings
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # local stuff
2 | .vscode/
3 | examples/.logs/
4 | *.bin
5 | *.csv
6 | input/
7 | logs/
8 | gfpgan/
9 | *.pkl
10 | *.pt
11 | *.pth
12 | abhishek/
13 | diffout/
14 | datasets/*
15 | experiments/*
16 | results/*
17 | tb_logger/*
18 | wandb/*
19 | tmp/*
20 | weights/*
21 | # Byte-compiled / optimized / DLL files
22 | __pycache__/
23 | *.py[cod]
24 | *$py.class
25 | .pypirc
26 | # C extensions
27 | *.so
28 | 7_Train.py
29 |
30 | # Distribution / packaging
31 | .Python
32 | build/
33 | develop-eggs/
34 | dist/
35 | downloads/
36 | eggs/
37 | .eggs/
38 | lib/
39 | lib64/
40 | parts/
41 | sdist/
42 | var/
43 | pip-wheel-metadata/
44 | share/python-wheels/
45 | *.egg-info/
46 | .installed.cfg
47 | *.egg
48 | MANIFEST
49 | src/
50 |
51 | # PyInstaller
52 | # Usually these files are written by a python script from a template
53 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
54 | *.manifest
55 | *.spec
56 |
57 | # Installer logs
58 | pip-log.txt
59 | pip-delete-this-directory.txt
60 |
61 | # Unit test / coverage reports
62 | htmlcov/
63 | .tox/
64 | .nox/
65 | .coverage
66 | .coverage.*
67 | .cache
68 | nosetests.xml
69 | coverage.xml
70 | *.cover
71 | *.py,cover
72 | .hypothesis/
73 | .pytest_cache/
74 |
75 | # Translations
76 | *.mo
77 | *.pot
78 |
79 | # Django stuff:
80 | *.log
81 | local_settings.py
82 | db.sqlite3
83 | db.sqlite3-journal
84 |
85 | # Flask stuff:
86 | instance/
87 | .webassets-cache
88 |
89 | # Scrapy stuff:
90 | .scrapy
91 |
92 | # Sphinx documentation
93 | docs/_build/
94 |
95 | # PyBuilder
96 | target/
97 |
98 | # Jupyter Notebook
99 | .ipynb_checkpoints
100 |
101 | # IPython
102 | profile_default/
103 | ipython_config.py
104 |
105 | # pyenv
106 | .python-version
107 |
108 | # pipenv
109 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
110 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
111 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
112 | # install all needed dependencies.
113 | #Pipfile.lock
114 |
115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
116 | __pypackages__/
117 |
118 | # Celery stuff
119 | celerybeat-schedule
120 | celerybeat.pid
121 |
122 | # SageMath parsed files
123 | *.sage.py
124 |
125 | # Environments
126 | .env
127 | .venv
128 | env/
129 | venv/
130 | ENV/
131 | env.bak/
132 | venv.bak/
133 |
134 | # Spyder project settings
135 | .spyderproject
136 | .spyproject
137 |
138 | # Rope project settings
139 | .ropeproject
140 |
141 | # mkdocs documentation
142 | /site
143 |
144 | # mypy
145 | .mypy_cache/
146 | .dmypy.json
147 | dmypy.json
148 |
149 | # Pyre type checker
150 | .pyre/
151 | test.ipynb
152 | # vscode
153 | .vscode/
154 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include stablefusion/pages/*.py
2 | include stablefusion/data/*.txt
3 | include stablefusion/*.txt
4 | recursive-include stablefusion/scripts *
5 | recursive-include stablefusion/models *
6 | include OTHER_LICENSES
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: quality style
2 |
3 | quality:
4 | python -m black --check --line-length 119 --target-version py38 .
5 | python -m isort --check-only .
6 | python -m flake8 --max-line-length 119 .
7 |
8 | style:
9 | python -m black --line-length 119 --target-version py38 .
10 | python -m isort .
--------------------------------------------------------------------------------
/OTHER_LICENSES:
--------------------------------------------------------------------------------
1 | clip_interrogator licence
2 | for diffuzes/clip_interrogator.py
3 | for diffuzes/data/*.txt
4 | -------------------------
5 | MIT License
6 |
7 | Copyright (c) 2022 pharmapsychotic
8 |
9 | Permission is hereby granted, free of charge, to any person obtaining a copy
10 | of this software and associated documentation files (the "Software"), to deal
11 | in the Software without restriction, including without limitation the rights
12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 | copies of the Software, and to permit persons to whom the Software is
14 | furnished to do so, subject to the following conditions:
15 |
16 | The above copyright notice and this permission notice shall be included in all
17 | copies or substantial portions of the Software.
18 |
19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 | SOFTWARE.
26 |
27 | blip licence
28 | for diffuzes/blip.py
29 | ------------
30 | Copyright (c) 2022, Salesforce.com, Inc.
31 | All rights reserved.
32 |
33 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
34 |
35 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
36 |
37 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
38 |
39 | * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
40 |
41 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # StableFusion
2 |
3 | A web UI for **Stable Diffusion Models**.
4 |
5 | ## Update (Version 0.1.4)
6 |
7 | In this version, we've added two major features to the project:
8 |
9 | - **ControlNet Option**
10 | - **OpenPose Editor**
11 |
12 | We've also made the following improvements and bug fixes:
13 |
14 | - **Improved File Arrangement**: We've reorganized the project's files so they are easier to navigate.
15 | - **Fixed Random Seed Generator**: We've fixed a bug in the random seed generator that was causing incorrect results. This fix ensures that the project operates more accurately and reliably.
16 |
17 | We hope these updates will improve the overall functionality and user experience of the project. As always, please feel free to reach out to us if you have any questions or feedback.
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 | 
26 | 
27 |
28 | If something doesn't work as expected, or if you need a feature that isn't available yet, please open a request on [GitHub issues](https://github.com/NeuralRealm/StableFusion/issues).
29 |
30 |
31 | ## Features available in the app:
32 |
33 | - text to image
34 | - image to image
35 | - Inpainting
36 | - instruct pix2pix
37 | - textual inversion
38 | - ControlNet
39 | - OpenPose Editor
40 | - image info
41 | - Upscale Your Image
42 | - clip interrogator
43 | - Convert ckpt file to diffusers
44 | - Convert safetensors file to diffusers
45 | - Add your own diffusers model
46 | - more coming soon!
47 |
48 |
49 |
50 | ## Installation
51 |
52 | To install the bleeding-edge version of StableFusion, clone the repo and install it with pip:
53 |
54 | ```bash
55 | git clone https://github.com/NeuralRealm/StableFusion
56 | cd StableFusion
57 | pip install -e .
58 | ```
59 |
60 | Installation using pip:
61 |
62 | ```bash
63 | pip install stablefusion
64 | ```
65 |
66 | ## Usage
67 |
68 | ### Web App
69 | To run the web app, run the following command:
70 |
71 | To run on localhost:
72 | ```bash
73 | stablefusion app
74 | ```
75 | or
76 |
77 | To get a public shareable link:
78 | ```bash
79 | stablefusion app --port 10000 --ngrok_key YourNgrokAuthtoken --share
80 | ```
81 |
82 | ## All CLI Options for running the app:
83 |
84 | ```bash
85 | ❯ stablefusion app --help
86 | usage: stablefusion <command> [<args>] app [-h] [--output OUTPUT] [--share] [--port PORT] [--host HOST]
87 | [--device DEVICE] [--ngrok_key NGROK_KEY]
88 |
89 | ✨ Run stablefusion app
90 |
91 | optional arguments:
92 | -h, --help show this help message and exit
93 | --output OUTPUT Output path is optional, but if provided, all generations will automatically be saved to this
94 | path.
95 | --share Share the app
96 | --port PORT Port to run the app on
97 | --host HOST Host to run the app on
98 | --device DEVICE Device to use, e.g. cpu, cuda, cuda:0, mps (for m1 mac) etc.
99 | --ngrok_key NGROK_KEY
100 | Ngrok key to use for sharing the app. Only required if you want to share the app
101 | ```
102 |
103 |
104 | ## Using private models from huggingface hub
105 |
106 | If you want to use private models from the Hugging Face Hub, you need to log in first using the `huggingface-cli login` command.
107 |
108 | Note: You can also save your generations directly to the Hugging Face Hub if your output path points to a Hub dataset repo that you have push access to. This saves a lot of local disk space.
109 |
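For example, here is a minimal sketch of that flow; the dataset repo id below is a hypothetical placeholder for one you actually have push access to:

```bash
# Authenticate once; private models (and pushes to your dataset repo) will then work.
huggingface-cli login

# Hypothetical example: point --output at a Hub dataset repo instead of a local folder.
stablefusion app --device cuda --output your-username/stablefusion-outputs
```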
110 | ## Acknowledgements
111 |
112 | I would like to express my gratitude to the following individuals and organizations for sharing their code, which formed the basis of the implementation used in this project:
113 |
114 | - [Tencent ARC](https://github.com/TencentARC) for their code for the [GFPGAN](https://github.com/TencentARC/GFPGAN) package, which was used for face restoration.
115 | - [LexKoin](https://github.com/LexKoin) for their code for the [Real-ESRGAN-UpScale](https://github.com/LexKoin/Real-ESRGAN-UpScale) package, which was used for image upscaling and enhancement.
116 | - [Hugging Face](https://github.com/huggingface) for their [diffusers](https://github.com/huggingface/diffusers) package, which provides the diffusion pipelines used throughout this project.
117 | - [Abhishek Thakur](https://github.com/abhishekkrthakur) for sharing his code for the [diffuzers](https://github.com/abhishekkrthakur/diffuzers) package, on which this project's web app and API are based.
118 |
119 | I am grateful for their contributions to the open source community, which made this project possible.
120 |
121 | ## Contributing
122 |
123 | StableFusion is an open-source project, and we welcome contributions from the community. Whether you're a developer, designer, or user, there are many ways you can help make this project better. Here are a few ways you can get involved:
124 |
125 | - **Report issues:** If you find a bug or have a feature request, please open an issue on our [GitHub repository](https://github.com/NeuralRealm/StableFusion/issues). We appreciate detailed bug reports and constructive feedback.
126 | - **Submit pull requests:** If you're interested in contributing code, we welcome pull requests for bug fixes, new features, and documentation improvements.
127 | - **Spread the word:** If you enjoy using StableFusion, please help us spread the word by sharing it with your friends, colleagues, and social media networks. We appreciate any support you can give us!
128 |
129 | ### We believe that open-source software is the future of technology, and we're excited to have you join us in making StableFusion a success. Thank you for your support!
130 |
--------------------------------------------------------------------------------
/diffuzers.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "!pip3 install stablefusion -q"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 2,
15 | "metadata": {},
16 | "outputs": [
17 | {
18 | "name": "stdout",
19 | "output_type": "stream",
20 | "text": [
21 | "/bin/bash: line 1: diffuzers: command not found\n"
22 | ]
23 | }
24 | ],
25 | "source": [
26 | "!stablefusion app --port 10000 --ngrok_key YOUR_NGROK_AUTHTOKEN --share"
27 | ]
28 | }
29 | ],
30 | "metadata": {
31 | "kernelspec": {
32 | "display_name": "diffuzers",
33 | "language": "python",
34 | "name": "python3"
35 | },
36 | "language_info": {
37 | "codemirror_mode": {
38 | "name": "ipython",
39 | "version": 3
40 | },
41 | "file_extension": ".py",
42 | "mimetype": "text/x-python",
43 | "name": "python",
44 | "nbconvert_exporter": "python",
45 | "pygments_lexer": "ipython3",
46 | "version": "3.9.16"
47 | },
48 | "orig_nbformat": 4,
49 | "vscode": {
50 | "interpreter": {
51 | "hash": "e667c4130092153eeed1faae4ad5e1fb640895671448a3853c00b657af26211d"
52 | }
53 | }
54 | },
55 | "nbformat": 4,
56 | "nbformat_minor": 2
57 | }
58 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | # import os
14 | # import sys
15 | # sys.path.insert(0, os.path.abspath('.'))
16 |
17 |
18 | # -- Project information -----------------------------------------------------
19 |
20 | project = "StableFusion: webapp and api for 🤗 diffusers"
21 | copyright = "2023, NeuralRealm"
22 | author = "NeuralRealm"
23 |
24 |
25 | # -- General configuration ---------------------------------------------------
26 |
27 | # Add any Sphinx extension module names here, as strings. They can be
28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
29 | # ones.
30 | extensions = []
31 |
32 | # Add any paths that contain templates here, relative to this directory.
33 | templates_path = ["_templates"]
34 |
35 | # List of patterns, relative to source directory, that match files and
36 | # directories to ignore when looking for source files.
37 | # This pattern also affects html_static_path and html_extra_path.
38 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
39 |
40 |
41 | # -- Options for HTML output -------------------------------------------------
42 |
43 | # The theme to use for HTML and HTML Help pages. See the documentation for
44 | # a list of builtin themes.
45 | #
46 | html_theme = "alabaster"
47 |
48 | # Add any paths that contain custom static files (such as style sheets) here,
49 | # relative to this directory. They are copied after the builtin static files,
50 | # so a file named "default.css" will overwrite the builtin "default.css".
51 | html_static_path = ["_static"]
52 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | Welcome to StableFusion's documentation!
2 | ======================================================================
3 |
4 | StableFusion offers a web app and an API for 🤗 diffusers. Installation is very simple.
5 | You can install via pip:
6 |
7 | .. code-block:: bash
8 |
9 | pip install stablefusion
10 |
11 | .. toctree::
12 | :maxdepth: 2
13 | :caption: Contents:
14 |
15 |
16 |
17 | Indices and tables
18 | ==================
19 |
20 | * :ref:`genindex`
21 | * :ref:`modindex`
22 | * :ref:`search`
23 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.https://www.sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.15.0
2 | basicsr>=1.4.2
3 | diffusers>=0.14.0
4 | facexlib>=0.2.5
5 | fairscale==0.4.4
6 | fastapi==0.88.0
7 | gfpgan>=1.3.7
8 | huggingface_hub>=0.11.1
9 | loguru==0.6.0
10 | open_clip_torch==2.9.1
11 | opencv-python
12 | protobuf==3.20
13 | pyngrok==5.2.1
14 | python-multipart==0.0.5
15 | realesrgan>=0.2.5
16 | streamlit==1.16.0
17 | streamlit-drawable-canvas==0.9.0
18 | st-clickable-images==0.0.3
19 | timm==0.4.12
20 | torch>=1.12.0
21 | torchvision>=0.13.0
22 | transformers==4.25.1
23 | uvicorn==0.15.0
24 | OmegaConf
25 | tqdl
26 | xformers
27 | bitsandbytes
28 | safetensors
29 | controlnet_aux
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | version = attr: stablefusion.__version__
3 |
4 | [isort]
5 | ensure_newline_before_comments = True
6 | force_grid_wrap = 0
7 | include_trailing_comma = True
8 | line_length = 119
9 | lines_after_imports = 2
10 | multi_line_output = 3
11 | use_parentheses = True
12 |
13 | [flake8]
14 | ignore = E203, E501, W503
15 | max-line-length = 119
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Lint as: python3
2 | # pylint: enable=line-too-long
3 |
4 | from setuptools import find_packages, setup
5 |
6 |
7 | with open("README.md", encoding="utf-8") as f:
8 | long_description = f.read()
9 |
10 | QUALITY_REQUIRE = [
11 | "black~=22.0",
12 | "isort==5.8.0",
13 | "flake8==3.9.2",
14 | "mypy==0.901",
15 | ]
16 |
17 | TEST_REQUIRE = ["pytest", "pytest-cov"]
18 |
19 | EXTRAS_REQUIRE = {
20 | "dev": QUALITY_REQUIRE,
21 | "quality": QUALITY_REQUIRE,
22 | "test": TEST_REQUIRE,
23 | "docs": [
24 | "recommonmark",
25 | "sphinx==3.1.2",
26 | "sphinx-markdown-tables",
27 | "sphinx-rtd-theme==0.4.3",
28 | "sphinx-copybutton",
29 | ],
30 | }
31 |
32 | with open("requirements.txt") as f:
33 | INSTALL_REQUIRES = f.read().splitlines()
34 |
35 | setup(
36 | name="stablefusion",
37 | description="StableFusion",
38 | long_description=long_description,
39 | long_description_content_type="text/markdown",
40 | author="NeuralRealm",
41 | url="https://github.com/NeuralRealm/StableFusion",
42 | packages=find_packages("."),
43 | entry_points={"console_scripts": ["stablefusion=stablefusion.cli.main:main"]},
44 | install_requires=INSTALL_REQUIRES,
45 | extras_require=EXTRAS_REQUIRE,
46 | python_requires=">=3.7",
47 | classifiers=[
48 | "Intended Audience :: Developers",
49 | "Intended Audience :: Education",
50 | "Intended Audience :: Science/Research",
51 | "License :: OSI Approved :: Apache Software License",
52 | "Operating System :: OS Independent",
53 | "Programming Language :: Python :: 3.8",
54 | "Programming Language :: Python :: 3.9",
55 | "Programming Language :: Python :: 3.10",
56 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
57 | ],
58 | keywords="stablefusion",
59 | include_package_data=True,
60 | )
61 |
--------------------------------------------------------------------------------
/stablefusion/Home.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import streamlit as st
4 | from loguru import logger
5 |
6 | from stablefusion import utils
7 | from stablefusion.scripts.x2image import X2Image
8 | import ast
9 | import os
10 |
11 | def read_model_list():
12 |
13 | try:
14 | with open('{}/model_list.txt'.format(os.path.dirname(__file__)), 'r') as f:
15 | contents = f.read()
16 |     except OSError:  # fall back to a relative path when running from the repo checkout
17 | with open('stablefusion/model_list.txt', 'r') as f:
18 | contents = f.read()
19 |
20 | model_list = ast.literal_eval(contents)
21 |
22 | return model_list
23 |
24 | def parse_args():
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument(
27 | "--output",
28 | type=str,
29 | required=False,
30 | default=None,
31 | help="Output path",
32 | )
33 | parser.add_argument(
34 | "--device",
35 | type=str,
36 | required=True,
37 | help="Device to use, e.g. cpu, cuda, cuda:0, mps etc.",
38 | )
39 | return parser.parse_args()
40 |
41 |
42 | def x2img_app():
43 | with st.form("x2img_model_form"):
44 | col1, col2 = st.columns(2)
45 | with col1:
46 | model = st.selectbox(
47 | "Which model do you want to use?",
48 | options=read_model_list(),
49 | )
50 | with col2:
51 | custom_pipeline = st.selectbox(
52 | "Custom pipeline",
53 | options=[
54 | "Vanilla",
55 | "Long Prompt Weighting",
56 | ],
57 | index=0 if st.session_state.get("x2img_custom_pipeline") in (None, "Vanilla") else 1,
58 | )
59 |
60 | with st.expander("Textual Inversion (Optional)"):
61 | token_identifier = st.text_input(
62 | "Token identifier",
63 | placeholder=""
64 | if st.session_state.get("textual_inversion_token_identifier") is None
65 | else st.session_state.textual_inversion_token_identifier,
66 | )
67 | embeddings = st.text_input(
68 | "Embeddings",
69 | placeholder="https://huggingface.co/sd-concepts-library/axe-tattoo/resolve/main/learned_embeds.bin"
70 | if st.session_state.get("textual_inversion_embeddings") is None
71 | else st.session_state.textual_inversion_embeddings,
72 | )
73 | submit = st.form_submit_button("Load model")
74 |
75 | if submit:
76 | st.session_state.x2img_model = model
77 | st.session_state.x2img_custom_pipeline = custom_pipeline
78 | st.session_state.textual_inversion_token_identifier = token_identifier
79 | st.session_state.textual_inversion_embeddings = embeddings
80 | cpipe = "lpw_stable_diffusion" if custom_pipeline == "Long Prompt Weighting" else None
81 | with st.spinner("Loading model..."):
82 | x2img = X2Image(
83 | model=model,
84 | device=st.session_state.device,
85 | output_path=st.session_state.output_path,
86 | custom_pipeline=cpipe,
87 | token_identifier=token_identifier,
88 | embeddings_url=embeddings,
89 | )
90 | st.session_state.x2img = x2img
91 | if "x2img" in st.session_state:
92 | st.write(f"Current model: {st.session_state.x2img}")
93 | st.session_state.x2img.app()
94 |
95 |
96 | def run_app():
97 | utils.create_base_page()
98 | x2img_app()
99 |
100 |
101 | if __name__ == "__main__":
102 | args = parse_args()
103 | logger.info(f"Args: {args}")
104 | logger.info(st.session_state)
105 | st.session_state.device = args.device
106 | st.session_state.output_path = args.output
107 | run_app()
108 |
--------------------------------------------------------------------------------
/stablefusion/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from loguru import logger
4 |
5 |
6 | logger.configure(handlers=[dict(sink=sys.stderr, format="> {level:<7} {message}")])
7 |
8 | __version__ = "0.1.4"
9 |
--------------------------------------------------------------------------------
/stablefusion/api/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuralRealm/StableFusion/9268d686f49b45194a2251c6728e74191f298654/stablefusion/api/__init__.py
--------------------------------------------------------------------------------
/stablefusion/api/main.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os
3 |
4 | from fastapi import Depends, FastAPI, File, UploadFile
5 | from loguru import logger
6 | from PIL import Image
7 | from starlette.middleware.cors import CORSMiddleware
8 |
9 | from stablefusion.api.schemas import Img2ImgParams, ImgResponse, InpaintingParams, InstructPix2PixParams, Text2ImgParams
10 | from stablefusion.api.utils import convert_to_b64_list
11 | from stablefusion.scripts.inpainting import Inpainting
12 | from stablefusion.scripts.x2image import X2Image
13 |
14 |
15 | app = FastAPI(
16 | title="StableFusion api",
17 | license_info={
18 | "name": "Apache 2.0",
19 | "url": "https://www.apache.org/licenses/LICENSE-2.0.html",
20 | },
21 | )
22 | app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
23 |
24 |
25 | @app.on_event("startup")
26 | async def startup_event():
27 |
28 | x2img_model = os.environ.get("X2IMG_MODEL")
29 | x2img_pipeline = os.environ.get("X2IMG_PIPELINE")
30 | inpainting_model = os.environ.get("INPAINTING_MODEL")
31 | device = os.environ.get("DEVICE")
32 | output_path = os.environ.get("OUTPUT_PATH")
33 | ti_identifier = os.environ.get("TOKEN_IDENTIFIER", "")
34 | ti_embeddings_url = os.environ.get("TOKEN_EMBEDDINGS_URL", "")
35 |     logger.info("@@@@@ Starting StableFusion API @@@@@")
36 | logger.info(f"Text2Image Model: {x2img_model}")
37 | logger.info(f"Text2Image Pipeline: {x2img_pipeline if x2img_pipeline is not None else 'Vanilla'}")
38 | logger.info(f"Inpainting Model: {inpainting_model}")
39 | logger.info(f"Device: {device}")
40 | logger.info(f"Output Path: {output_path}")
41 | logger.info(f"Token Identifier: {ti_identifier}")
42 | logger.info(f"Token Embeddings URL: {ti_embeddings_url}")
43 |
44 | logger.info("Loading x2img model...")
45 | if x2img_model is not None:
46 | app.state.x2img_model = X2Image(
47 | model=x2img_model,
48 | device=device,
49 | output_path=output_path,
50 | custom_pipeline=x2img_pipeline,
51 | token_identifier=ti_identifier,
52 | embeddings_url=ti_embeddings_url,
53 | )
54 | else:
55 | app.state.x2img_model = None
56 | logger.info("Loading inpainting model...")
57 |     if inpainting_model is not None:
58 |         app.state.inpainting_model = Inpainting(
59 |             model=inpainting_model, device=device, output_path=output_path
60 |         )
61 |     else:
62 |         app.state.inpainting_model = None  # mirror the x2img branch so the /inpainting check cannot raise AttributeError
63 |     logger.info("API is ready to use!")
64 |
65 |
66 | @app.post("/text2img")
67 | async def text2img(params: Text2ImgParams) -> ImgResponse:
68 | logger.info(f"Params: {params}")
69 | if app.state.x2img_model is None:
70 | return {"error": "x2img model is not loaded"}
71 |
72 | images, _ = app.state.x2img_model.text2img_generate(
73 | params.prompt,
74 | num_images=params.num_images,
75 | steps=params.steps,
76 | seed=params.seed,
77 | negative_prompt=params.negative_prompt,
78 | scheduler=params.scheduler,
79 | image_size=(params.image_height, params.image_width),
80 | guidance_scale=params.guidance_scale,
81 | )
82 | base64images = convert_to_b64_list(images)
83 | return ImgResponse(images=base64images, metadata=params.dict())
84 |
85 |
86 | @app.post("/img2img")
87 | async def img2img(params: Img2ImgParams = Depends(), image: UploadFile = File(...)) -> ImgResponse:
88 | if app.state.x2img_model is None:
89 | return {"error": "x2img model is not loaded"}
90 | image = Image.open(io.BytesIO(image.file.read()))
91 | images, _ = app.state.x2img_model.img2img_generate(
92 | image=image,
93 | prompt=params.prompt,
94 | negative_prompt=params.negative_prompt,
95 | num_images=params.num_images,
96 | steps=params.steps,
97 | seed=params.seed,
98 | scheduler=params.scheduler,
99 | guidance_scale=params.guidance_scale,
100 | strength=params.strength,
101 | )
102 | base64images = convert_to_b64_list(images)
103 | return ImgResponse(images=base64images, metadata=params.dict())
104 |
105 |
106 | @app.post("/instruct-pix2pix")
107 | async def instruct_pix2pix(params: InstructPix2PixParams = Depends(), image: UploadFile = File(...)) -> ImgResponse:
108 | if app.state.x2img_model is None:
109 | return {"error": "x2img model is not loaded"}
110 | image = Image.open(io.BytesIO(image.file.read()))
111 | images, _ = app.state.x2img_model.pix2pix_generate(
112 | image=image,
113 | prompt=params.prompt,
114 | negative_prompt=params.negative_prompt,
115 | num_images=params.num_images,
116 | steps=params.steps,
117 | seed=params.seed,
118 | scheduler=params.scheduler,
119 | guidance_scale=params.guidance_scale,
120 | image_guidance_scale=params.image_guidance_scale,
121 | )
122 | base64images = convert_to_b64_list(images)
123 | return ImgResponse(images=base64images, metadata=params.dict())
124 |
125 |
126 | @app.post("/inpainting")
127 | async def inpainting(
128 | params: InpaintingParams = Depends(), image: UploadFile = File(...), mask: UploadFile = File(...)
129 | ) -> ImgResponse:
130 | if app.state.inpainting_model is None:
131 | return {"error": "inpainting model is not loaded"}
132 | image = Image.open(io.BytesIO(image.file.read()))
133 | mask = Image.open(io.BytesIO(mask.file.read()))
134 | images, _ = app.state.inpainting_model.generate_image(
135 | image=image,
136 | mask=mask,
137 | prompt=params.prompt,
138 | negative_prompt=params.negative_prompt,
139 | scheduler=params.scheduler,
140 | height=params.image_height,
141 | width=params.image_width,
142 | num_images=params.num_images,
143 | guidance_scale=params.guidance_scale,
144 | steps=params.steps,
145 | seed=params.seed,
146 | )
147 | base64images = convert_to_b64_list(images)
148 | return ImgResponse(images=base64images, metadata=params.dict())
149 |
150 |
151 | @app.get("/")
152 | def read_root():
153 | return {"Hello": "World"}
154 |
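A rough sketch of exercising this API end to end: the env var names, default host/port, and `/text2img` request fields come from the code above and from `run_api.py`/`schemas.py`, while the model id and prompt are arbitrary examples.

```bash
# Start the API; startup_event() reads the model and device from these env vars.
X2IMG_MODEL="runwayml/stable-diffusion-v1-5" DEVICE="cuda" \
  uvicorn stablefusion.api.main:app --host 127.0.0.1 --port 10000

# In another shell, request a single image; only "prompt" is required, the rest have defaults.
curl -X POST http://127.0.0.1:10000/text2img \
  -H "Content-Type: application/json" \
  -d '{"prompt": "a watercolor painting of a lighthouse", "steps": 30, "seed": 42}'
```

The `images` field of the JSON response holds base64-encoded PNGs (see `convert_to_b64_list` in `stablefusion/api/utils.py`), which can be decoded back into image files.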
--------------------------------------------------------------------------------
/stablefusion/api/schemas.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 |
3 | from pydantic import BaseModel, Field
4 |
5 |
6 | class Text2ImgParams(BaseModel):
7 | prompt: str = Field(..., description="Text prompt for the model")
8 | negative_prompt: str = Field(None, description="Negative text prompt for the model")
9 | scheduler: str = Field("EulerAncestralDiscreteScheduler", description="Scheduler to use for the model")
10 | image_height: int = Field(512, description="Image height")
11 | image_width: int = Field(512, description="Image width")
12 | num_images: int = Field(1, description="Number of images to generate")
13 | guidance_scale: float = Field(7, description="Guidance scale")
14 | steps: int = Field(50, description="Number of steps to run the model for")
15 | seed: int = Field(42, description="Seed for the model")
16 |
17 |
18 | class Img2ImgParams(BaseModel):
19 | prompt: str = Field(..., description="Text prompt for the model")
20 | negative_prompt: str = Field(None, description="Negative text prompt for the model")
21 | scheduler: str = Field("EulerAncestralDiscreteScheduler", description="Scheduler to use for the model")
22 | strength: float = Field(0.7, description="Strength")
23 | num_images: int = Field(1, description="Number of images to generate")
24 | guidance_scale: float = Field(7, description="Guidance scale")
25 | steps: int = Field(50, description="Number of steps to run the model for")
26 | seed: int = Field(42, description="Seed for the model")
27 |
28 |
29 | class InstructPix2PixParams(BaseModel):
30 | prompt: str = Field(..., description="Text prompt for the model")
31 | negative_prompt: str = Field(None, description="Negative text prompt for the model")
32 | scheduler: str = Field("EulerAncestralDiscreteScheduler", description="Scheduler to use for the model")
33 | num_images: int = Field(1, description="Number of images to generate")
34 | guidance_scale: float = Field(7, description="Guidance scale")
35 | image_guidance_scale: float = Field(1.5, description="Image guidance scale")
36 | steps: int = Field(50, description="Number of steps to run the model for")
37 | seed: int = Field(42, description="Seed for the model")
38 |
39 |
40 | class ImgResponse(BaseModel):
41 | images: List[str] = Field(..., description="List of images in base64 format")
42 | metadata: Dict = Field(..., description="Metadata")
43 |
44 |
45 | class InpaintingParams(BaseModel):
46 | prompt: str = Field(..., description="Text prompt for the model")
47 | negative_prompt: str = Field(None, description="Negative text prompt for the model")
48 | scheduler: str = Field("EulerAncestralDiscreteScheduler", description="Scheduler to use for the model")
49 | image_height: int = Field(512, description="Image height")
50 | image_width: int = Field(512, description="Image width")
51 | num_images: int = Field(1, description="Number of images to generate")
52 | guidance_scale: float = Field(7, description="Guidance scale")
53 | steps: int = Field(50, description="Number of steps to run the model for")
54 | seed: int = Field(42, description="Seed for the model")
55 |
--------------------------------------------------------------------------------
/stablefusion/api/utils.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import io
3 |
4 |
5 | def convert_to_b64_list(images):
6 | base64images = []
7 | for image in images:
8 | buf = io.BytesIO()
9 | image.save(buf, format="PNG")
10 | byte_im = base64.b64encode(buf.getvalue())
11 | base64images.append(byte_im)
12 | return base64images
13 |
--------------------------------------------------------------------------------
/stablefusion/cli/__init__.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from argparse import ArgumentParser
3 |
4 |
5 | class BaseStableFusionCommand(ABC):
6 | @staticmethod
7 | @abstractmethod
8 | def register_subcommand(parser: ArgumentParser):
9 | raise NotImplementedError()
10 |
11 | @abstractmethod
12 | def run(self):
13 | raise NotImplementedError()
14 |
--------------------------------------------------------------------------------
/stablefusion/cli/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from .. import __version__
4 | from .run_api import RunStableFusionAPICommand
5 | from .run_app import RunStableFusionAppCommand
6 |
7 |
8 | def main():
9 | parser = argparse.ArgumentParser(
10 | "StableFusion CLI",
11 |         usage="stablefusion <command> [<args>]",
12 |         epilog="For more information about a command, run: `stablefusion <command> --help`",
13 | )
14 | parser.add_argument("--version", "-v", help="Display stablefusion version", action="store_true")
15 | commands_parser = parser.add_subparsers(help="commands")
16 |
17 | # Register commands
18 | RunStableFusionAppCommand.register_subcommand(commands_parser)
19 | #RunStableFusionAPICommand.register_subcommand(commands_parser)
20 |
21 | args = parser.parse_args()
22 |
23 | if args.version:
24 | print(__version__)
25 | exit(0)
26 |
27 | if not hasattr(args, "func"):
28 | parser.print_help()
29 | exit(1)
30 |
31 | command = args.func(args)
32 | command.run()
33 |
34 |
35 | if __name__ == "__main__":
36 | main()
37 |
--------------------------------------------------------------------------------
/stablefusion/cli/run_api.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from argparse import ArgumentParser
3 |
4 | import torch
5 |
6 | from . import BaseStableFusionCommand
7 |
8 |
9 | def run_api_command_factory(args):
10 | return RunStableFusionAPICommand(
11 | args.output,
12 | args.port,
13 | args.host,
14 | args.device,
15 | args.workers,
16 | args.ssl_certfile,
17 | args.ssl_keyfile,
18 | )
19 |
20 |
21 | class RunStableFusionAPICommand(BaseStableFusionCommand):
22 | @staticmethod
23 | def register_subcommand(parser: ArgumentParser):
24 | run_api_parser = parser.add_parser(
25 | "api",
26 | description="✨ Run StableFusion api",
27 | )
28 | run_api_parser.add_argument(
29 | "--output",
30 | type=str,
31 | required=False,
32 | help="Output path is optional, but if provided, all generations will automatically be saved to this path.",
33 | )
34 | run_api_parser.add_argument(
35 | "--port",
36 | type=int,
37 | default=10000,
38 | help="Port to run the app on",
39 | required=False,
40 | )
41 | run_api_parser.add_argument(
42 | "--host",
43 | type=str,
44 | default="127.0.0.1",
45 | help="Host to run the app on",
46 | required=False,
47 | )
48 | run_api_parser.add_argument(
49 | "--device",
50 | type=str,
51 | required=False,
52 | help="Device to use, e.g. cpu, cuda, cuda:0, mps (for m1 mac) etc.",
53 | )
54 | run_api_parser.add_argument(
55 | "--workers",
56 | type=int,
57 | required=False,
58 | default=1,
59 | help="Number of workers to use",
60 | )
61 | run_api_parser.add_argument(
62 | "--ssl_certfile",
63 | type=str,
64 | required=False,
65 | help="the path to your ssl cert",
66 | )
67 | run_api_parser.add_argument(
68 | "--ssl_keyfile",
69 | type=str,
70 | required=False,
71 | help="the path to your ssl key",
72 | )
73 | run_api_parser.set_defaults(func=run_api_command_factory)
74 |
75 | def __init__(self, output, port, host, device, workers, ssl_certfile, ssl_keyfile):
76 | self.output = output
77 | self.port = port
78 | self.host = host
79 | self.device = device
80 | self.workers = workers
81 | self.ssl_certfile = ssl_certfile
82 | self.ssl_keyfile = ssl_keyfile
83 |
84 | if self.device is None:
85 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
86 |
87 | self.port = str(self.port)
88 | self.workers = str(self.workers)
89 |
90 | def run(self):
91 | cmd = [
92 | "uvicorn",
93 | "stablefusion.api.main:app",
94 | "--host",
95 | self.host,
96 | "--port",
97 | self.port,
98 | "--workers",
99 | self.workers,
100 | ]
101 | if self.ssl_certfile is not None:
102 | cmd.extend(["--ssl-certfile", self.ssl_certfile])
103 | if self.ssl_keyfile is not None:
104 | cmd.extend(["--ssl-keyfile", self.ssl_keyfile])
105 |
106 | proc = subprocess.Popen(
107 | cmd,
108 | stdout=subprocess.PIPE,
109 | stderr=subprocess.STDOUT,
110 | shell=False,
111 | universal_newlines=True,
112 | bufsize=1,
113 | )
114 | with proc as p:
115 | try:
116 | for line in p.stdout:
117 | print(line, end="")
118 | except KeyboardInterrupt:
119 | print("Killing api")
120 | p.kill()
121 | p.wait()
122 | raise
123 |
--------------------------------------------------------------------------------
/stablefusion/cli/run_app.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from argparse import ArgumentParser
3 |
4 | import torch
5 | from pyngrok import ngrok
6 |
7 | from . import BaseStableFusionCommand
8 |
9 |
10 | def run_app_command_factory(args):
11 | return RunStableFusionAppCommand(
12 | args.output,
13 | args.share,
14 | args.port,
15 | args.host,
16 | args.device,
17 | args.ngrok_key,
18 | )
19 |
20 |
21 | class RunStableFusionAppCommand(BaseStableFusionCommand):
22 | @staticmethod
23 | def register_subcommand(parser: ArgumentParser):
24 | run_app_parser = parser.add_parser(
25 | "app",
26 | description="✨ Run stablefusion app",
27 | )
28 | run_app_parser.add_argument(
29 | "--output",
30 | type=str,
31 | required=False,
32 | help="Output path is optional, but if provided, all generations will automatically be saved to this path.",
33 | )
34 | run_app_parser.add_argument(
35 | "--share",
36 | action="store_true",
37 | help="Share the app",
38 | )
39 | run_app_parser.add_argument(
40 | "--port",
41 | type=int,
42 | default=10000,
43 | help="Port to run the app on",
44 | required=False,
45 | )
46 | run_app_parser.add_argument(
47 | "--host",
48 | type=str,
49 | default="127.0.0.1",
50 | help="Host to run the app on",
51 | required=False,
52 | )
53 | run_app_parser.add_argument(
54 | "--device",
55 | type=str,
56 | required=False,
57 | help="Device to use, e.g. cpu, cuda, cuda:0, mps (for m1 mac) etc.",
58 | )
59 | run_app_parser.add_argument(
60 | "--ngrok_key",
61 | type=str,
62 | required=False,
63 | help="Ngrok key to use for sharing the app. Only required if you want to share the app",
64 | )
65 |
66 | run_app_parser.set_defaults(func=run_app_command_factory)
67 |
68 | def __init__(self, output, share, port, host, device, ngrok_key):
69 | self.output = output
70 | self.share = share
71 | self.port = port
72 | self.host = host
73 | self.device = device
74 | self.ngrok_key = ngrok_key
75 |
76 | if self.device is None:
77 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
78 |
79 | if self.share:
80 | if self.ngrok_key is None:
81 | raise ValueError(
82 | "ngrok key is required if you want to share the app. Get it for free from https://dashboard.ngrok.com/get-started/your-authtoken"
83 | )
84 |
85 | def run(self):
86 | # from ..app import stablefusion
87 |
88 | # print(self.share)
89 | # app = stablefusion(self.model, self.output).app()
90 | # app.launch(show_api=False, share=self.share, server_port=self.port, server_name=self.host)
91 | import os
92 |
93 | dirname = os.path.dirname(__file__)
94 | filename = os.path.join(dirname, "..", "Home.py")
95 | cmd = [
96 | "streamlit",
97 | "run",
98 | filename,
99 | "--browser.gatherUsageStats",
100 | "false",
101 | "--browser.serverAddress",
102 | self.host,
103 | "--server.port",
104 | str(self.port),
105 | "--theme.base",
106 | "dark",
107 | "--",
108 | "--device",
109 | self.device,
110 | ]
111 | if self.output is not None:
112 | cmd.extend(["--output", self.output])
113 |
114 | if self.share:
115 | ngrok.set_auth_token(self.ngrok_key)
116 | public_url = ngrok.connect(self.port).public_url
117 | print(f"Sharing app at {public_url}")
118 |
119 | proc = subprocess.Popen(
120 | cmd,
121 | stdout=subprocess.PIPE,
122 | stderr=subprocess.STDOUT,
123 | shell=False,
124 | universal_newlines=True,
125 | bufsize=1,
126 | )
127 | with proc as p:
128 | try:
129 | for line in p.stdout:
130 | print(line, end="")
131 | except KeyboardInterrupt:
132 | print("Killing streamlit app")
133 | p.kill()
134 | if self.share:
135 | print("Killing ngrok tunnel")
136 | ngrok.kill()
137 | p.wait()
138 | raise
139 |
--------------------------------------------------------------------------------
/stablefusion/data/mediums.txt:
--------------------------------------------------------------------------------
1 | a 3D render
2 | a black and white photo
3 | a bronze sculpture
4 | a cartoon
5 | a cave painting
6 | a character portrait
7 | a charcoal drawing
8 | a child's drawing
9 | a color pencil sketch
10 | a colorized photo
11 | a comic book panel
12 | a computer rendering
13 | a cross stitch
14 | a cubist painting
15 | a detailed drawing
16 | a detailed matte painting
17 | a detailed painting
18 | a diagram
19 | a digital painting
20 | a digital rendering
21 | a drawing
22 | a fine art painting
23 | a flemish Baroque
24 | a gouache
25 | a hologram
26 | a hyperrealistic painting
27 | a jigsaw puzzle
28 | a low poly render
29 | a macro photograph
30 | a manga drawing
31 | a marble sculpture
32 | a matte painting
33 | a microscopic photo
34 | a mid-nineteenth century engraving
35 | a minimalist painting
36 | a mosaic
37 | a painting
38 | a pastel
39 | a pencil sketch
40 | a photo
41 | a photocopy
42 | a photorealistic painting
43 | a picture
44 | a pointillism painting
45 | a polaroid photo
46 | a pop art painting
47 | a portrait
48 | a poster
49 | a raytraced image
50 | a renaissance painting
51 | a screenprint
52 | a screenshot
53 | a silk screen
54 | a sketch
55 | a statue
56 | a still life
57 | a stipple
58 | a stock photo
59 | a storybook illustration
60 | a surrealist painting
61 | a surrealist sculpture
62 | a tattoo
63 | a tilt shift photo
64 | a watercolor painting
65 | a wireframe diagram
66 | a woodcut
67 | an abstract drawing
68 | an abstract painting
69 | an abstract sculpture
70 | an acrylic painting
71 | an airbrush painting
72 | an album cover
73 | an ambient occlusion render
74 | an anime drawing
75 | an art deco painting
76 | an art deco sculpture
77 | an engraving
78 | an etching
79 | an illustration of
80 | an impressionist painting
81 | an ink drawing
82 | an oil on canvas painting
83 | an oil painting
84 | an ultrafine detailed painting
85 | chalk art
86 | computer graphics
87 | concept art
88 | cyberpunk art
89 | digital art
90 | egyptian art
91 | graffiti art
92 | lineart
93 | pixel art
94 | poster art
95 | vector art
96 |
--------------------------------------------------------------------------------
/stablefusion/data/movements.txt:
--------------------------------------------------------------------------------
1 | abstract art
2 | abstract expressionism
3 | abstract illusionism
4 | academic art
5 | action painting
6 | aestheticism
7 | afrofuturism
8 | altermodern
9 | american barbizon school
10 | american impressionism
11 | american realism
12 | american romanticism
13 | american scene painting
14 | analytical art
15 | antipodeans
16 | arabesque
17 | arbeitsrat für kunst
18 | art & language
19 | art brut
20 | art deco
21 | art informel
22 | art nouveau
23 | art photography
24 | arte povera
25 | arts and crafts movement
26 | ascii art
27 | ashcan school
28 | assemblage
29 | australian tonalism
30 | auto-destructive art
31 | barbizon school
32 | baroque
33 | bauhaus
34 | bengal school of art
35 | berlin secession
36 | black arts movement
37 | brutalism
38 | classical realism
39 | cloisonnism
40 | cobra
41 | color field
42 | computer art
43 | conceptual art
44 | concrete art
45 | constructivism
46 | context art
47 | crayon art
48 | crystal cubism
49 | cubism
50 | cubo-futurism
51 | cynical realism
52 | dada
53 | danube school
54 | dau-al-set
55 | de stijl
56 | deconstructivism
57 | digital art
58 | ecological art
59 | environmental art
60 | excessivism
61 | expressionism
62 | fantastic realism
63 | fantasy art
64 | fauvism
65 | feminist art
66 | figuration libre
67 | figurative art
68 | figurativism
69 | fine art
70 | fluxus
71 | folk art
72 | funk art
73 | furry art
74 | futurism
75 | generative art
76 | geometric abstract art
77 | german romanticism
78 | gothic art
79 | graffiti
80 | gutai group
81 | happening
82 | harlem renaissance
83 | heidelberg school
84 | holography
85 | hudson river school
86 | hurufiyya
87 | hypermodernism
88 | hyperrealism
89 | impressionism
90 | incoherents
91 | institutional critique
92 | interactive art
93 | international gothic
94 | international typographic style
95 | kinetic art
96 | kinetic pointillism
97 | kitsch movement
98 | land art
99 | les automatistes
100 | les nabis
101 | letterism
102 | light and space
103 | lowbrow
104 | lyco art
105 | lyrical abstraction
106 | magic realism
107 | magical realism
108 | mail art
109 | mannerism
110 | massurrealism
111 | maximalism
112 | metaphysical painting
113 | mingei
114 | minimalism
115 | modern european ink painting
116 | modernism
117 | modular constructivism
118 | naive art
119 | naturalism
120 | neo-dada
121 | neo-expressionism
122 | neo-fauvism
123 | neo-figurative
124 | neo-primitivism
125 | neo-romanticism
126 | neoclassicism
127 | neogeo
128 | neoism
129 | neoplasticism
130 | net art
131 | new objectivity
132 | new sculpture
133 | northwest school
134 | nuclear art
135 | objective abstraction
136 | op art
137 | optical illusion
138 | orphism
139 | panfuturism
140 | paris school
141 | photorealism
142 | pixel art
143 | plasticien
144 | plein air
145 | pointillism
146 | pop art
147 | pop surrealism
148 | post-impressionism
149 | postminimalism
150 | pre-raphaelitism
151 | precisionism
152 | primitivism
153 | private press
154 | process art
155 | psychedelic art
156 | purism
157 | qajar art
158 | quito school
159 | rasquache
160 | rayonism
161 | realism
162 | regionalism
163 | remodernism
164 | renaissance
165 | retrofuturism
166 | rococo
167 | romanesque
168 | romanticism
169 | samikshavad
170 | serial art
171 | shin hanga
172 | shock art
173 | socialist realism
174 | sots art
175 | space art
176 | street art
177 | stuckism
178 | sumatraism
179 | superflat
180 | suprematism
181 | surrealism
182 | symbolism
183 | synchromism
184 | synthetism
185 | sōsaku hanga
186 | tachisme
187 | temporary art
188 | tonalism
189 | toyism
190 | transgressive art
191 | ukiyo-e
192 | underground comix
193 | unilalianism
194 | vancouver school
195 | vanitas
196 | verdadism
197 | video art
198 | viennese actionism
199 | visual art
200 | vorticism
201 |
--------------------------------------------------------------------------------
/stablefusion/model_list.txt:
--------------------------------------------------------------------------------
1 | ['runwayml/stable-diffusion-v1-5', 'stabilityai/stable-diffusion-2-1', 'Linaqruf/anything-v3.0', 'Envvi/Inkpunk-Diffusion', 'prompthero/openjourney', 'darkstorm2150/Protogen_v2.2_Official_Release', 'darkstorm2150/Protogen_x3.4_Official_Release']
--------------------------------------------------------------------------------
/stablefusion/models/ckpt_interference/t:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuralRealm/StableFusion/9268d686f49b45194a2251c6728e74191f298654/stablefusion/models/ckpt_interference/t
--------------------------------------------------------------------------------
/stablefusion/models/cpkt_models/t:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuralRealm/StableFusion/9268d686f49b45194a2251c6728e74191f298654/stablefusion/models/cpkt_models/t
--------------------------------------------------------------------------------
/stablefusion/models/diffusion_models/t.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuralRealm/StableFusion/9268d686f49b45194a2251c6728e74191f298654/stablefusion/models/diffusion_models/t.txt
--------------------------------------------------------------------------------
/stablefusion/models/realesrgan/inference_realesrgan.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import glob
4 | import os
5 | from basicsr.archs.rrdbnet_arch import RRDBNet
6 | from basicsr.utils.download_util import load_file_from_url
7 | import tempfile
8 | from realesrgan import RealESRGANer
9 | from realesrgan.archs.srvgg_arch import SRVGGNetCompact
10 | import numpy as np
11 | from stablefusion import utils
12 | import streamlit as st
13 | from PIL import Image
14 | from io import BytesIO
15 | import base64
16 | import datetime
17 |
18 | def main(model_name, denoise_strength, tile, tile_pad, pre_pad, fp32, gpu_id, face_enhance, outscale, input_image, model_path):
19 | # determine models according to model names
20 | model_name = model_name.split('.')[0]
21 | if model_name == 'RealESRGAN_x4plus': # x4 RRDBNet model
22 | model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
23 | netscale = 4
24 | file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
25 | elif model_name == 'RealESRNet_x4plus': # x4 RRDBNet model
26 | model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
27 | netscale = 4
28 | file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
29 | elif model_name == 'RealESRGAN_x4plus_anime_6B': # x4 RRDBNet model with 6 blocks
30 | model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
31 | netscale = 4
32 | file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
33 | elif model_name == 'RealESRGAN_x2plus': # x2 RRDBNet model
34 | model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
35 | netscale = 2
36 | file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
37 | elif model_name == 'realesr-animevideov3': # x4 VGG-style model (XS size)
38 | model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
39 | netscale = 4
40 | file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth']
41 | elif model_name == 'realesr-general-x4v3': # x4 VGG-style model (S size)
42 | model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
43 | netscale = 4
44 | file_url = [
45 | 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
46 | 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
47 | ]
48 |
49 | # determine model paths
50 | if model_path is not None:
51 | model_path = model_path
52 | else:
53 | model_path = os.path.join('weights', model_name + '.pth')
54 | if not os.path.isfile(model_path):
55 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
56 | for url in file_url:
57 | # model_path will be updated
58 | model_path = load_file_from_url(
59 | url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
60 |
61 | # use dni to control the denoise strength
62 | dni_weight = None
63 | if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
64 | wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
65 | model_path = [model_path, wdn_model_path]
66 | dni_weight = [denoise_strength, 1 - denoise_strength]
67 |
68 | # restorer
69 | upsampler = RealESRGANer(
70 | scale=netscale,
71 | model_path=model_path,
72 | dni_weight=dni_weight,
73 | model=model,
74 | tile=tile,
75 | tile_pad=tile_pad,
76 | pre_pad=pre_pad,
77 | half=not fp32,
78 | gpu_id=gpu_id)
79 |
80 | if face_enhance: # Use GFPGAN for face enhancement
81 | from gfpgan import GFPGANer
82 | face_enhancer = GFPGANer(
83 | model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
84 | upscale=outscale,
85 | arch='clean',
86 | channel_multiplier=2,
87 | bg_upsampler=upsampler)
88 |
89 |
90 |
91 | image_array = np.asarray(bytearray(input_image.read()), dtype=np.uint8)
92 |
93 | img = cv2.imdecode(image_array, cv2.IMREAD_UNCHANGED)
94 |
95 |
96 | try:
97 | if face_enhance:
98 | _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
99 | else:
100 | output, _ = upsampler.enhance(img, outscale=outscale)
101 |
102 |
103 | img = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
104 | st.image(img)
105 |
106 | output_img = Image.fromarray(img)
107 | buffered = BytesIO()
108 | output_img.save(buffered, format="PNG")
109 | img_str = base64.b64encode(buffered.getvalue()).decode()
110 | now = datetime.datetime.now()
111 | formatted_date_time = now.strftime("%Y-%m-%d_%H_%M_%S")
112 |         href = f'<a href="data:image/png;base64,{img_str}" download="{formatted_date_time}.png">Download Image</a>'
113 | st.markdown(href, unsafe_allow_html=True)
114 |
115 |
116 | except RuntimeError as error:
117 | print('Error', error)
118 | print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
119 |
120 |
121 |
122 | if __name__ == '__main__':
123 | main()
124 |
--------------------------------------------------------------------------------
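For orientation, a hedged sketch of how `main()` above could be driven from a Streamlit page. Only the parameter names come from the signature in this file; the import path, the widget, and the default values are assumptions:

```python
# Hypothetical caller for inference_realesrgan.main(); all values below are illustrative.
import streamlit as st
import inference_realesrgan  # assumes this module is importable from the page

uploaded = st.file_uploader("Image to upscale", type=["png", "jpg", "jpeg"])
if uploaded is not None:
    inference_realesrgan.main(
        model_name="RealESRGAN_x4plus",
        denoise_strength=0.5,   # only used by realesr-general-x4v3
        tile=0,                 # 0 disables tiling
        tile_pad=10,
        pre_pad=0,
        fp32=False,             # half precision by default
        gpu_id=None,
        face_enhance=False,
        outscale=4,
        input_image=uploaded,   # main() calls .read() on this object
        model_path=None,        # None: download the weights automatically
    )
```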
/stablefusion/models/realesrgan/options/finetune_realesrgan_x4plus.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: finetune_RealESRGANx4plus_400k
3 | model_type: RealESRGANModel
4 | scale: 4
5 | num_gpu: auto
6 | manual_seed: 0
7 |
8 | # ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
9 | # USM the ground-truth
10 | l1_gt_usm: True
11 | percep_gt_usm: True
12 | gan_gt_usm: False
13 |
14 | # the first degradation process
15 | resize_prob: [0.2, 0.7, 0.1] # up, down, keep
16 | resize_range: [0.15, 1.5]
17 | gaussian_noise_prob: 0.5
18 | noise_range: [1, 30]
19 | poisson_scale_range: [0.05, 3]
20 | gray_noise_prob: 0.4
21 | jpeg_range: [30, 95]
22 |
23 | # the second degradation process
24 | second_blur_prob: 0.8
25 | resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
26 | resize_range2: [0.3, 1.2]
27 | gaussian_noise_prob2: 0.5
28 | noise_range2: [1, 25]
29 | poisson_scale_range2: [0.05, 2.5]
30 | gray_noise_prob2: 0.4
31 | jpeg_range2: [30, 95]
32 |
33 | gt_size: 256
34 | queue_size: 180
35 |
36 | # dataset and data loader settings
37 | datasets:
38 | train:
39 | name: DF2K+OST
40 | type: RealESRGANDataset
41 | dataroot_gt: datasets/DF2K
42 | meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
43 | io_backend:
44 | type: disk
45 |
46 | blur_kernel_size: 21
47 | kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
48 | kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
49 | sinc_prob: 0.1
50 | blur_sigma: [0.2, 3]
51 | betag_range: [0.5, 4]
52 | betap_range: [1, 2]
53 |
54 | blur_kernel_size2: 21
55 | kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
56 | kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
57 | sinc_prob2: 0.1
58 | blur_sigma2: [0.2, 1.5]
59 | betag_range2: [0.5, 4]
60 | betap_range2: [1, 2]
61 |
62 | final_sinc_prob: 0.8
63 |
64 | gt_size: 256
65 | use_hflip: True
66 | use_rot: False
67 |
68 | # data loader
69 | use_shuffle: true
70 | num_worker_per_gpu: 5
71 | batch_size_per_gpu: 12
72 | dataset_enlarge_ratio: 1
73 | prefetch_mode: ~
74 |
75 | # Uncomment these for validation
76 | # val:
77 | # name: validation
78 | # type: PairedImageDataset
79 | # dataroot_gt: path_to_gt
80 | # dataroot_lq: path_to_lq
81 | # io_backend:
82 | # type: disk
83 |
84 | # network structures
85 | network_g:
86 | type: RRDBNet
87 | num_in_ch: 3
88 | num_out_ch: 3
89 | num_feat: 64
90 | num_block: 23
91 | num_grow_ch: 32
92 |
93 | network_d:
94 | type: UNetDiscriminatorSN
95 | num_in_ch: 3
96 | num_feat: 64
97 | skip_connection: True
98 |
99 | # path
100 | path:
101 | # use the pre-trained Real-ESRNet model
102 | pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth
103 | param_key_g: params_ema
104 | strict_load_g: true
105 | pretrain_network_d: experiments/pretrained_models/RealESRGAN_x4plus_netD.pth
106 | param_key_d: params
107 | strict_load_d: true
108 | resume_state: ~
109 |
110 | # training settings
111 | train:
112 | ema_decay: 0.999
113 | optim_g:
114 | type: Adam
115 | lr: !!float 1e-4
116 | weight_decay: 0
117 | betas: [0.9, 0.99]
118 | optim_d:
119 | type: Adam
120 | lr: !!float 1e-4
121 | weight_decay: 0
122 | betas: [0.9, 0.99]
123 |
124 | scheduler:
125 | type: MultiStepLR
126 | milestones: [400000]
127 | gamma: 0.5
128 |
129 | total_iter: 400000
130 | warmup_iter: -1 # no warm up
131 |
132 | # losses
133 | pixel_opt:
134 | type: L1Loss
135 | loss_weight: 1.0
136 | reduction: mean
137 | # perceptual loss (content and style losses)
138 | perceptual_opt:
139 | type: PerceptualLoss
140 | layer_weights:
141 | # before relu
142 | 'conv1_2': 0.1
143 | 'conv2_2': 0.1
144 | 'conv3_4': 1
145 | 'conv4_4': 1
146 | 'conv5_4': 1
147 | vgg_type: vgg19
148 | use_input_norm: true
149 | perceptual_weight: !!float 1.0
150 | style_weight: 0
151 | range_norm: false
152 | criterion: l1
153 | # gan loss
154 | gan_opt:
155 | type: GANLoss
156 | gan_type: vanilla
157 | real_label_val: 1.0
158 | fake_label_val: 0.0
159 | loss_weight: !!float 1e-1
160 |
161 | net_d_iters: 1
162 | net_d_init_iters: 0
163 |
164 | # Uncomment these for validation
165 | # validation settings
166 | # val:
167 | # val_freq: !!float 5e3
168 | # save_img: True
169 |
170 | # metrics:
171 | # psnr: # metric name
172 | # type: calculate_psnr
173 | # crop_border: 4
174 | # test_y_channel: false
175 |
176 | # logging settings
177 | logger:
178 | print_freq: 100
179 | save_checkpoint_freq: !!float 5e3
180 | use_tb_logger: true
181 | wandb:
182 | project: ~
183 | resume_id: ~
184 |
185 | # dist training settings
186 | dist_params:
187 | backend: nccl
188 | port: 29500
189 |
--------------------------------------------------------------------------------
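The option files in this folder are plain YAML consumed by basicsr. A small sketch of inspecting and tweaking one before launching a finetune (the dataset path and iteration count below are placeholders, not project defaults):

```python
# Sketch: load an option file, point it at local data, and write a modified copy.
import yaml

with open("stablefusion/models/realesrgan/options/finetune_realesrgan_x4plus.yml") as f:
    opt = yaml.safe_load(f)

opt["datasets"]["train"]["dataroot_gt"] = "/data/my_hr_images"   # placeholder path
opt["train"]["total_iter"] = 100_000                             # shorter finetune

with open("finetune_local.yml", "w") as f:
    yaml.safe_dump(opt, f, sort_keys=False)
```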
/stablefusion/models/realesrgan/options/finetune_realesrgan_x4plus_pairdata.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: finetune_RealESRGANx4plus_400k_pairdata
3 | model_type: RealESRGANModel
4 | scale: 4
5 | num_gpu: auto
6 | manual_seed: 0
7 |
8 | # USM the ground-truth
9 | l1_gt_usm: True
10 | percep_gt_usm: True
11 | gan_gt_usm: False
12 |
13 | high_order_degradation: False # do not use the high-order degradation generation process
14 |
15 | # dataset and data loader settings
16 | datasets:
17 | train:
18 | name: DIV2K
19 | type: RealESRGANPairedDataset
20 | dataroot_gt: datasets/DF2K
21 | dataroot_lq: datasets/DF2K
22 | meta_info: datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt
23 | io_backend:
24 | type: disk
25 |
26 | gt_size: 256
27 | use_hflip: True
28 | use_rot: False
29 |
30 | # data loader
31 | use_shuffle: true
32 | num_worker_per_gpu: 5
33 | batch_size_per_gpu: 12
34 | dataset_enlarge_ratio: 1
35 | prefetch_mode: ~
36 |
37 | # Uncomment these for validation
38 | # val:
39 | # name: validation
40 | # type: PairedImageDataset
41 | # dataroot_gt: path_to_gt
42 | # dataroot_lq: path_to_lq
43 | # io_backend:
44 | # type: disk
45 |
46 | # network structures
47 | network_g:
48 | type: RRDBNet
49 | num_in_ch: 3
50 | num_out_ch: 3
51 | num_feat: 64
52 | num_block: 23
53 | num_grow_ch: 32
54 |
55 | network_d:
56 | type: UNetDiscriminatorSN
57 | num_in_ch: 3
58 | num_feat: 64
59 | skip_connection: True
60 |
61 | # path
62 | path:
63 | # use the pre-trained Real-ESRNet model
64 | pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth
65 | param_key_g: params_ema
66 | strict_load_g: true
67 | pretrain_network_d: experiments/pretrained_models/RealESRGAN_x4plus_netD.pth
68 | param_key_d: params
69 | strict_load_d: true
70 | resume_state: ~
71 |
72 | # training settings
73 | train:
74 | ema_decay: 0.999
75 | optim_g:
76 | type: Adam
77 | lr: !!float 1e-4
78 | weight_decay: 0
79 | betas: [0.9, 0.99]
80 | optim_d:
81 | type: Adam
82 | lr: !!float 1e-4
83 | weight_decay: 0
84 | betas: [0.9, 0.99]
85 |
86 | scheduler:
87 | type: MultiStepLR
88 | milestones: [400000]
89 | gamma: 0.5
90 |
91 | total_iter: 400000
92 | warmup_iter: -1 # no warm up
93 |
94 | # losses
95 | pixel_opt:
96 | type: L1Loss
97 | loss_weight: 1.0
98 | reduction: mean
99 | # perceptual loss (content and style losses)
100 | perceptual_opt:
101 | type: PerceptualLoss
102 | layer_weights:
103 | # before relu
104 | 'conv1_2': 0.1
105 | 'conv2_2': 0.1
106 | 'conv3_4': 1
107 | 'conv4_4': 1
108 | 'conv5_4': 1
109 | vgg_type: vgg19
110 | use_input_norm: true
111 | perceptual_weight: !!float 1.0
112 | style_weight: 0
113 | range_norm: false
114 | criterion: l1
115 | # gan loss
116 | gan_opt:
117 | type: GANLoss
118 | gan_type: vanilla
119 | real_label_val: 1.0
120 | fake_label_val: 0.0
121 | loss_weight: !!float 1e-1
122 |
123 | net_d_iters: 1
124 | net_d_init_iters: 0
125 |
126 | # Uncomment these for validation
127 | # validation settings
128 | # val:
129 | # val_freq: !!float 5e3
130 | # save_img: True
131 |
132 | # metrics:
133 | # psnr: # metric name
134 | # type: calculate_psnr
135 | # crop_border: 4
136 | # test_y_channel: false
137 |
138 | # logging settings
139 | logger:
140 | print_freq: 100
141 | save_checkpoint_freq: !!float 5e3
142 | use_tb_logger: true
143 | wandb:
144 | project: ~
145 | resume_id: ~
146 |
147 | # dist training settings
148 | dist_params:
149 | backend: nccl
150 | port: 29500
151 |
--------------------------------------------------------------------------------
/stablefusion/models/realesrgan/options/train_realesrgan_x2plus.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: train_RealESRGANx2plus_400k_B12G4
3 | model_type: RealESRGANModel
4 | scale: 2
5 | num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs
6 | manual_seed: 0
7 |
8 | # ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
9 | # USM the ground-truth
10 | l1_gt_usm: True
11 | percep_gt_usm: True
12 | gan_gt_usm: False
13 |
14 | # the first degradation process
15 | resize_prob: [0.2, 0.7, 0.1] # up, down, keep
16 | resize_range: [0.15, 1.5]
17 | gaussian_noise_prob: 0.5
18 | noise_range: [1, 30]
19 | poisson_scale_range: [0.05, 3]
20 | gray_noise_prob: 0.4
21 | jpeg_range: [30, 95]
22 |
23 | # the second degradation process
24 | second_blur_prob: 0.8
25 | resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
26 | resize_range2: [0.3, 1.2]
27 | gaussian_noise_prob2: 0.5
28 | noise_range2: [1, 25]
29 | poisson_scale_range2: [0.05, 2.5]
30 | gray_noise_prob2: 0.4
31 | jpeg_range2: [30, 95]
32 |
33 | gt_size: 256
34 | queue_size: 180
35 |
36 | # dataset and data loader settings
37 | datasets:
38 | train:
39 | name: DF2K+OST
40 | type: RealESRGANDataset
41 | dataroot_gt: datasets/DF2K
42 | meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
43 | io_backend:
44 | type: disk
45 |
46 | blur_kernel_size: 21
47 | kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
48 | kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
49 | sinc_prob: 0.1
50 | blur_sigma: [0.2, 3]
51 | betag_range: [0.5, 4]
52 | betap_range: [1, 2]
53 |
54 | blur_kernel_size2: 21
55 | kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
56 | kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
57 | sinc_prob2: 0.1
58 | blur_sigma2: [0.2, 1.5]
59 | betag_range2: [0.5, 4]
60 | betap_range2: [1, 2]
61 |
62 | final_sinc_prob: 0.8
63 |
64 | gt_size: 256
65 | use_hflip: True
66 | use_rot: False
67 |
68 | # data loader
69 | use_shuffle: true
70 | num_worker_per_gpu: 5
71 | batch_size_per_gpu: 12
72 | dataset_enlarge_ratio: 1
73 | prefetch_mode: ~
74 |
75 | # Uncomment these for validation
76 | # val:
77 | # name: validation
78 | # type: PairedImageDataset
79 | # dataroot_gt: path_to_gt
80 | # dataroot_lq: path_to_lq
81 | # io_backend:
82 | # type: disk
83 |
84 | # network structures
85 | network_g:
86 | type: RRDBNet
87 | num_in_ch: 3
88 | num_out_ch: 3
89 | num_feat: 64
90 | num_block: 23
91 | num_grow_ch: 32
92 | scale: 2
93 |
94 | network_d:
95 | type: UNetDiscriminatorSN
96 | num_in_ch: 3
97 | num_feat: 64
98 | skip_connection: True
99 |
100 | # path
101 | path:
102 | # use the pre-trained Real-ESRNet model
103 | pretrain_network_g: experiments/pretrained_models/RealESRNet_x2plus.pth
104 | param_key_g: params_ema
105 | strict_load_g: true
106 | resume_state: ~
107 |
108 | # training settings
109 | train:
110 | ema_decay: 0.999
111 | optim_g:
112 | type: Adam
113 | lr: !!float 1e-4
114 | weight_decay: 0
115 | betas: [0.9, 0.99]
116 | optim_d:
117 | type: Adam
118 | lr: !!float 1e-4
119 | weight_decay: 0
120 | betas: [0.9, 0.99]
121 |
122 | scheduler:
123 | type: MultiStepLR
124 | milestones: [400000]
125 | gamma: 0.5
126 |
127 | total_iter: 400000
128 | warmup_iter: -1 # no warm up
129 |
130 | # losses
131 | pixel_opt:
132 | type: L1Loss
133 | loss_weight: 1.0
134 | reduction: mean
135 | # perceptual loss (content and style losses)
136 | perceptual_opt:
137 | type: PerceptualLoss
138 | layer_weights:
139 | # before relu
140 | 'conv1_2': 0.1
141 | 'conv2_2': 0.1
142 | 'conv3_4': 1
143 | 'conv4_4': 1
144 | 'conv5_4': 1
145 | vgg_type: vgg19
146 | use_input_norm: true
147 | perceptual_weight: !!float 1.0
148 | style_weight: 0
149 | range_norm: false
150 | criterion: l1
151 | # gan loss
152 | gan_opt:
153 | type: GANLoss
154 | gan_type: vanilla
155 | real_label_val: 1.0
156 | fake_label_val: 0.0
157 | loss_weight: !!float 1e-1
158 |
159 | net_d_iters: 1
160 | net_d_init_iters: 0
161 |
162 | # Uncomment these for validation
163 | # validation settings
164 | # val:
165 | # val_freq: !!float 5e3
166 | # save_img: True
167 |
168 | # metrics:
169 | # psnr: # metric name
170 | # type: calculate_psnr
171 | # crop_border: 4
172 | # test_y_channel: false
173 |
174 | # logging settings
175 | logger:
176 | print_freq: 100
177 | save_checkpoint_freq: !!float 5e3
178 | use_tb_logger: true
179 | wandb:
180 | project: ~
181 | resume_id: ~
182 |
183 | # dist training settings
184 | dist_params:
185 | backend: nccl
186 | port: 29500
187 |
--------------------------------------------------------------------------------
/stablefusion/models/realesrgan/options/train_realesrgan_x4plus.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: train_RealESRGANx4plus_400k_B12G4
3 | model_type: RealESRGANModel
4 | scale: 4
5 | num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs
6 | manual_seed: 0
7 |
8 | # ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
9 | # USM the ground-truth
10 | l1_gt_usm: True
11 | percep_gt_usm: True
12 | gan_gt_usm: False
13 |
14 | # the first degradation process
15 | resize_prob: [0.2, 0.7, 0.1] # up, down, keep
16 | resize_range: [0.15, 1.5]
17 | gaussian_noise_prob: 0.5
18 | noise_range: [1, 30]
19 | poisson_scale_range: [0.05, 3]
20 | gray_noise_prob: 0.4
21 | jpeg_range: [30, 95]
22 |
23 | # the second degradation process
24 | second_blur_prob: 0.8
25 | resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
26 | resize_range2: [0.3, 1.2]
27 | gaussian_noise_prob2: 0.5
28 | noise_range2: [1, 25]
29 | poisson_scale_range2: [0.05, 2.5]
30 | gray_noise_prob2: 0.4
31 | jpeg_range2: [30, 95]
32 |
33 | gt_size: 256
34 | queue_size: 180
35 |
36 | # dataset and data loader settings
37 | datasets:
38 | train:
39 | name: DF2K+OST
40 | type: RealESRGANDataset
41 | dataroot_gt: datasets/DF2K
42 | meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
43 | io_backend:
44 | type: disk
45 |
46 | blur_kernel_size: 21
47 | kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
48 | kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
49 | sinc_prob: 0.1
50 | blur_sigma: [0.2, 3]
51 | betag_range: [0.5, 4]
52 | betap_range: [1, 2]
53 |
54 | blur_kernel_size2: 21
55 | kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
56 | kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
57 | sinc_prob2: 0.1
58 | blur_sigma2: [0.2, 1.5]
59 | betag_range2: [0.5, 4]
60 | betap_range2: [1, 2]
61 |
62 | final_sinc_prob: 0.8
63 |
64 | gt_size: 256
65 | use_hflip: True
66 | use_rot: False
67 |
68 | # data loader
69 | use_shuffle: true
70 | num_worker_per_gpu: 5
71 | batch_size_per_gpu: 12
72 | dataset_enlarge_ratio: 1
73 | prefetch_mode: ~
74 |
75 | # Uncomment these for validation
76 | # val:
77 | # name: validation
78 | # type: PairedImageDataset
79 | # dataroot_gt: path_to_gt
80 | # dataroot_lq: path_to_lq
81 | # io_backend:
82 | # type: disk
83 |
84 | # network structures
85 | network_g:
86 | type: RRDBNet
87 | num_in_ch: 3
88 | num_out_ch: 3
89 | num_feat: 64
90 | num_block: 23
91 | num_grow_ch: 32
92 |
93 | network_d:
94 | type: UNetDiscriminatorSN
95 | num_in_ch: 3
96 | num_feat: 64
97 | skip_connection: True
98 |
99 | # path
100 | path:
101 | # use the pre-trained Real-ESRNet model
102 | pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth
103 | param_key_g: params_ema
104 | strict_load_g: true
105 | resume_state: ~
106 |
107 | # training settings
108 | train:
109 | ema_decay: 0.999
110 | optim_g:
111 | type: Adam
112 | lr: !!float 1e-4
113 | weight_decay: 0
114 | betas: [0.9, 0.99]
115 | optim_d:
116 | type: Adam
117 | lr: !!float 1e-4
118 | weight_decay: 0
119 | betas: [0.9, 0.99]
120 |
121 | scheduler:
122 | type: MultiStepLR
123 | milestones: [400000]
124 | gamma: 0.5
125 |
126 | total_iter: 400000
127 | warmup_iter: -1 # no warm up
128 |
129 | # losses
130 | pixel_opt:
131 | type: L1Loss
132 | loss_weight: 1.0
133 | reduction: mean
134 | # perceptual loss (content and style losses)
135 | perceptual_opt:
136 | type: PerceptualLoss
137 | layer_weights:
138 | # before relu
139 | 'conv1_2': 0.1
140 | 'conv2_2': 0.1
141 | 'conv3_4': 1
142 | 'conv4_4': 1
143 | 'conv5_4': 1
144 | vgg_type: vgg19
145 | use_input_norm: true
146 | perceptual_weight: !!float 1.0
147 | style_weight: 0
148 | range_norm: false
149 | criterion: l1
150 | # gan loss
151 | gan_opt:
152 | type: GANLoss
153 | gan_type: vanilla
154 | real_label_val: 1.0
155 | fake_label_val: 0.0
156 | loss_weight: !!float 1e-1
157 |
158 | net_d_iters: 1
159 | net_d_init_iters: 0
160 |
161 | # Uncomment these for validation
162 | # validation settings
163 | # val:
164 | # val_freq: !!float 5e3
165 | # save_img: True
166 |
167 | # metrics:
168 | # psnr: # metric name
169 | # type: calculate_psnr
170 | # crop_border: 4
171 | # test_y_channel: false
172 |
173 | # logging settings
174 | logger:
175 | print_freq: 100
176 | save_checkpoint_freq: !!float 5e3
177 | use_tb_logger: true
178 | wandb:
179 | project: ~
180 | resume_id: ~
181 |
182 | # dist training settings
183 | dist_params:
184 | backend: nccl
185 | port: 29500
186 |
--------------------------------------------------------------------------------
/stablefusion/models/realesrgan/options/train_realesrnet_x2plus.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: train_RealESRNetx2plus_1000k_B12G4
3 | model_type: RealESRNetModel
4 | scale: 2
5 | num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs
6 | manual_seed: 0
7 |
8 | # ----------------- options for synthesizing training data in RealESRNetModel ----------------- #
9 | gt_usm: True # USM the ground-truth
10 |
11 | # the first degradation process
12 | resize_prob: [0.2, 0.7, 0.1] # up, down, keep
13 | resize_range: [0.15, 1.5]
14 | gaussian_noise_prob: 0.5
15 | noise_range: [1, 30]
16 | poisson_scale_range: [0.05, 3]
17 | gray_noise_prob: 0.4
18 | jpeg_range: [30, 95]
19 |
20 | # the second degradation process
21 | second_blur_prob: 0.8
22 | resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
23 | resize_range2: [0.3, 1.2]
24 | gaussian_noise_prob2: 0.5
25 | noise_range2: [1, 25]
26 | poisson_scale_range2: [0.05, 2.5]
27 | gray_noise_prob2: 0.4
28 | jpeg_range2: [30, 95]
29 |
30 | gt_size: 256
31 | queue_size: 180
32 |
33 | # dataset and data loader settings
34 | datasets:
35 | train:
36 | name: DF2K+OST
37 | type: RealESRGANDataset
38 | dataroot_gt: datasets/DF2K
39 | meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
40 | io_backend:
41 | type: disk
42 |
43 | blur_kernel_size: 21
44 | kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
45 | kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
46 | sinc_prob: 0.1
47 | blur_sigma: [0.2, 3]
48 | betag_range: [0.5, 4]
49 | betap_range: [1, 2]
50 |
51 | blur_kernel_size2: 21
52 | kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
53 | kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
54 | sinc_prob2: 0.1
55 | blur_sigma2: [0.2, 1.5]
56 | betag_range2: [0.5, 4]
57 | betap_range2: [1, 2]
58 |
59 | final_sinc_prob: 0.8
60 |
61 | gt_size: 256
62 | use_hflip: True
63 | use_rot: False
64 |
65 | # data loader
66 | use_shuffle: true
67 | num_worker_per_gpu: 5
68 | batch_size_per_gpu: 12
69 | dataset_enlarge_ratio: 1
70 | prefetch_mode: ~
71 |
72 | # Uncomment these for validation
73 | # val:
74 | # name: validation
75 | # type: PairedImageDataset
76 | # dataroot_gt: path_to_gt
77 | # dataroot_lq: path_to_lq
78 | # io_backend:
79 | # type: disk
80 |
81 | # network structures
82 | network_g:
83 | type: RRDBNet
84 | num_in_ch: 3
85 | num_out_ch: 3
86 | num_feat: 64
87 | num_block: 23
88 | num_grow_ch: 32
89 | scale: 2
90 |
91 | # path
92 | path:
93 | pretrain_network_g: experiments/pretrained_models/RealESRGAN_x4plus.pth
94 | param_key_g: params_ema
95 | strict_load_g: False
96 | resume_state: ~
97 |
98 | # training settings
99 | train:
100 | ema_decay: 0.999
101 | optim_g:
102 | type: Adam
103 | lr: !!float 2e-4
104 | weight_decay: 0
105 | betas: [0.9, 0.99]
106 |
107 | scheduler:
108 | type: MultiStepLR
109 | milestones: [1000000]
110 | gamma: 0.5
111 |
112 | total_iter: 1000000
113 | warmup_iter: -1 # no warm up
114 |
115 | # losses
116 | pixel_opt:
117 | type: L1Loss
118 | loss_weight: 1.0
119 | reduction: mean
120 |
121 | # Uncomment these for validation
122 | # validation settings
123 | # val:
124 | # val_freq: !!float 5e3
125 | # save_img: True
126 |
127 | # metrics:
128 | # psnr: # metric name
129 | # type: calculate_psnr
130 | # crop_border: 4
131 | # test_y_channel: false
132 |
133 | # logging settings
134 | logger:
135 | print_freq: 100
136 | save_checkpoint_freq: !!float 5e3
137 | use_tb_logger: true
138 | wandb:
139 | project: ~
140 | resume_id: ~
141 |
142 | # dist training settings
143 | dist_params:
144 | backend: nccl
145 | port: 29500
146 |
--------------------------------------------------------------------------------
/stablefusion/models/realesrgan/options/train_realesrnet_x4plus.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: train_RealESRNetx4plus_1000k_B12G4
3 | model_type: RealESRNetModel
4 | scale: 4
5 | num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs
6 | manual_seed: 0
7 |
8 | # ----------------- options for synthesizing training data in RealESRNetModel ----------------- #
9 | gt_usm: True # USM the ground-truth
10 |
11 | # the first degradation process
12 | resize_prob: [0.2, 0.7, 0.1] # up, down, keep
13 | resize_range: [0.15, 1.5]
14 | gaussian_noise_prob: 0.5
15 | noise_range: [1, 30]
16 | poisson_scale_range: [0.05, 3]
17 | gray_noise_prob: 0.4
18 | jpeg_range: [30, 95]
19 |
20 | # the second degradation process
21 | second_blur_prob: 0.8
22 | resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
23 | resize_range2: [0.3, 1.2]
24 | gaussian_noise_prob2: 0.5
25 | noise_range2: [1, 25]
26 | poisson_scale_range2: [0.05, 2.5]
27 | gray_noise_prob2: 0.4
28 | jpeg_range2: [30, 95]
29 |
30 | gt_size: 256
31 | queue_size: 180
32 |
33 | # dataset and data loader settings
34 | datasets:
35 | train:
36 | name: DF2K+OST
37 | type: RealESRGANDataset
38 | dataroot_gt: datasets/DF2K
39 | meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
40 | io_backend:
41 | type: disk
42 |
43 | blur_kernel_size: 21
44 | kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
45 | kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
46 | sinc_prob: 0.1
47 | blur_sigma: [0.2, 3]
48 | betag_range: [0.5, 4]
49 | betap_range: [1, 2]
50 |
51 | blur_kernel_size2: 21
52 | kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
53 | kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
54 | sinc_prob2: 0.1
55 | blur_sigma2: [0.2, 1.5]
56 | betag_range2: [0.5, 4]
57 | betap_range2: [1, 2]
58 |
59 | final_sinc_prob: 0.8
60 |
61 | gt_size: 256
62 | use_hflip: True
63 | use_rot: False
64 |
65 | # data loader
66 | use_shuffle: true
67 | num_worker_per_gpu: 5
68 | batch_size_per_gpu: 12
69 | dataset_enlarge_ratio: 1
70 | prefetch_mode: ~
71 |
72 | # Uncomment these for validation
73 | # val:
74 | # name: validation
75 | # type: PairedImageDataset
76 | # dataroot_gt: path_to_gt
77 | # dataroot_lq: path_to_lq
78 | # io_backend:
79 | # type: disk
80 |
81 | # network structures
82 | network_g:
83 | type: RRDBNet
84 | num_in_ch: 3
85 | num_out_ch: 3
86 | num_feat: 64
87 | num_block: 23
88 | num_grow_ch: 32
89 |
90 | # path
91 | path:
92 | pretrain_network_g: experiments/pretrained_models/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth
93 | param_key_g: params_ema
94 | strict_load_g: true
95 | resume_state: ~
96 |
97 | # training settings
98 | train:
99 | ema_decay: 0.999
100 | optim_g:
101 | type: Adam
102 | lr: !!float 2e-4
103 | weight_decay: 0
104 | betas: [0.9, 0.99]
105 |
106 | scheduler:
107 | type: MultiStepLR
108 | milestones: [1000000]
109 | gamma: 0.5
110 |
111 | total_iter: 1000000
112 | warmup_iter: -1 # no warm up
113 |
114 | # losses
115 | pixel_opt:
116 | type: L1Loss
117 | loss_weight: 1.0
118 | reduction: mean
119 |
120 | # Uncomment these for validation
121 | # validation settings
122 | # val:
123 | # val_freq: !!float 5e3
124 | # save_img: True
125 |
126 | # metrics:
127 | # psnr: # metric name
128 | # type: calculate_psnr
129 | # crop_border: 4
130 | # test_y_channel: false
131 |
132 | # logging settings
133 | logger:
134 | print_freq: 100
135 | save_checkpoint_freq: !!float 5e3
136 | use_tb_logger: true
137 | wandb:
138 | project: ~
139 | resume_id: ~
140 |
141 | # dist training settings
142 | dist_params:
143 | backend: nccl
144 | port: 29500
145 |
--------------------------------------------------------------------------------
/stablefusion/models/realesrgan/realesrgan/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | from .archs import *
3 | from .data import *
4 | from .models import *
5 | from .utils import *
6 | from .version import *
7 |
--------------------------------------------------------------------------------
/stablefusion/models/realesrgan/realesrgan/archs/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from basicsr.utils import scandir
3 | from os import path as osp
4 |
5 | # automatically scan and import arch modules for registry
6 | # scan all the files that end with '_arch.py' under the archs folder
7 | arch_folder = osp.dirname(osp.abspath(__file__))
8 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
9 | # import all the arch modules
10 | _arch_modules = [importlib.import_module(f'realesrgan.archs.{file_name}') for file_name in arch_filenames]
11 |
--------------------------------------------------------------------------------
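The scan above is what makes every `*_arch.py` class decorated with `@ARCH_REGISTRY.register()` reachable by name, which is how the `type:` fields in the option files are resolved. A short sketch (assuming basicsr is installed and the bundled `realesrgan` package is importable):

```python
# Sketch: after importing the package, registered architectures can be built by name.
from basicsr.utils.registry import ARCH_REGISTRY
import realesrgan.archs  # triggers the scan/import above

disc_cls = ARCH_REGISTRY.get("UNetDiscriminatorSN")
disc = disc_cls(num_in_ch=3, num_feat=64)
```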
/stablefusion/models/realesrgan/realesrgan/archs/discriminator_arch.py:
--------------------------------------------------------------------------------
1 | from basicsr.utils.registry import ARCH_REGISTRY
2 | from torch import nn as nn
3 | from torch.nn import functional as F
4 | from torch.nn.utils import spectral_norm
5 |
6 |
7 | @ARCH_REGISTRY.register()
8 | class UNetDiscriminatorSN(nn.Module):
9 | """Defines a U-Net discriminator with spectral normalization (SN)
10 |
11 | It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
12 |
13 |     Args:
14 | num_in_ch (int): Channel number of inputs. Default: 3.
15 | num_feat (int): Channel number of base intermediate features. Default: 64.
16 | skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
17 | """
18 |
19 | def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
20 | super(UNetDiscriminatorSN, self).__init__()
21 | self.skip_connection = skip_connection
22 | norm = spectral_norm
23 | # the first convolution
24 | self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
25 | # downsample
26 | self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
27 | self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
28 | self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
29 | # upsample
30 | self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
31 | self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
32 | self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
33 | # extra convolutions
34 | self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
35 | self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
36 | self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
37 |
38 | def forward(self, x):
39 | # downsample
40 | x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
41 | x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
42 | x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
43 | x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
44 |
45 | # upsample
46 | x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
47 | x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
48 |
49 | if self.skip_connection:
50 | x4 = x4 + x2
51 | x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
52 | x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
53 |
54 | if self.skip_connection:
55 | x5 = x5 + x1
56 | x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
57 | x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
58 |
59 | if self.skip_connection:
60 | x6 = x6 + x0
61 |
62 | # extra convolutions
63 | out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
64 | out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
65 | out = self.conv9(out)
66 |
67 | return out
68 |
--------------------------------------------------------------------------------
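A quick shape check for the discriminator above (a sketch; the input size is arbitrary, but the three stride-2 downsamples mean it should be divisible by 8):

```python
# Sketch: UNetDiscriminatorSN maps an image to a per-pixel realness map of the same spatial size.
import torch
from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN

disc = UNetDiscriminatorSN(num_in_ch=3, num_feat=64, skip_connection=True)
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    score_map = disc(x)
print(score_map.shape)  # torch.Size([1, 1, 256, 256])
```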
/stablefusion/models/realesrgan/realesrgan/archs/srvgg_arch.py:
--------------------------------------------------------------------------------
1 | from basicsr.utils.registry import ARCH_REGISTRY
2 | from torch import nn as nn
3 | from torch.nn import functional as F
4 |
5 |
6 | @ARCH_REGISTRY.register()
7 | class SRVGGNetCompact(nn.Module):
8 | """A compact VGG-style network structure for super-resolution.
9 |
10 | It is a compact network structure, which performs upsampling in the last layer and no convolution is
11 | conducted on the HR feature space.
12 |
13 | Args:
14 | num_in_ch (int): Channel number of inputs. Default: 3.
15 | num_out_ch (int): Channel number of outputs. Default: 3.
16 | num_feat (int): Channel number of intermediate features. Default: 64.
17 | num_conv (int): Number of convolution layers in the body network. Default: 16.
18 | upscale (int): Upsampling factor. Default: 4.
19 | act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
20 | """
21 |
22 | def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
23 | super(SRVGGNetCompact, self).__init__()
24 | self.num_in_ch = num_in_ch
25 | self.num_out_ch = num_out_ch
26 | self.num_feat = num_feat
27 | self.num_conv = num_conv
28 | self.upscale = upscale
29 | self.act_type = act_type
30 |
31 | self.body = nn.ModuleList()
32 | # the first conv
33 | self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
34 | # the first activation
35 | if act_type == 'relu':
36 | activation = nn.ReLU(inplace=True)
37 | elif act_type == 'prelu':
38 | activation = nn.PReLU(num_parameters=num_feat)
39 | elif act_type == 'leakyrelu':
40 | activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
41 | self.body.append(activation)
42 |
43 | # the body structure
44 | for _ in range(num_conv):
45 | self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
46 | # activation
47 | if act_type == 'relu':
48 | activation = nn.ReLU(inplace=True)
49 | elif act_type == 'prelu':
50 | activation = nn.PReLU(num_parameters=num_feat)
51 | elif act_type == 'leakyrelu':
52 | activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
53 | self.body.append(activation)
54 |
55 | # the last conv
56 | self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
57 | # upsample
58 | self.upsampler = nn.PixelShuffle(upscale)
59 |
60 | def forward(self, x):
61 | out = x
62 | for i in range(0, len(self.body)):
63 | out = self.body[i](out)
64 |
65 | out = self.upsampler(out)
66 | # add the nearest upsampled image, so that the network learns the residual
67 | base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
68 | out += base
69 | return out
70 |
--------------------------------------------------------------------------------
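A matching sketch for SRVGGNetCompact, showing the x4 PixelShuffle upscale plus the nearest-neighbour residual (the input size is illustrative):

```python
# Sketch: SRVGGNetCompact upscales by `upscale` and adds a nearest-neighbour-upsampled base image.
import torch
from realesrgan.archs.srvgg_arch import SRVGGNetCompact

net = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type="prelu")
lr = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    sr = net(lr)
print(sr.shape)  # torch.Size([1, 3, 256, 256])
```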
/stablefusion/models/realesrgan/realesrgan/data/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from basicsr.utils import scandir
3 | from os import path as osp
4 |
5 | # automatically scan and import dataset modules for registry
6 | # scan all the files that end with '_dataset.py' under the data folder
7 | data_folder = osp.dirname(osp.abspath(__file__))
8 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
9 | # import all the dataset modules
10 | _dataset_modules = [importlib.import_module(f'realesrgan.data.{file_name}') for file_name in dataset_filenames]
11 |
--------------------------------------------------------------------------------
/stablefusion/models/realesrgan/realesrgan/data/realesrgan_dataset.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import math
3 | import numpy as np
4 | import os
5 | import os.path as osp
6 | import random
7 | import time
8 | import torch
9 | from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
10 | from basicsr.data.transforms import augment
11 | from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
12 | from basicsr.utils.registry import DATASET_REGISTRY
13 | from torch.utils import data as data
14 |
15 |
16 | @DATASET_REGISTRY.register()
17 | class RealESRGANDataset(data.Dataset):
18 | """Dataset used for Real-ESRGAN model:
19 | Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
20 |
21 | It loads gt (Ground-Truth) images, and augments them.
22 | It also generates blur kernels and sinc kernels for generating low-quality images.
23 |     Note that the low-quality images are processed in tensors on GPUs for faster processing.
24 |
25 | Args:
26 | opt (dict): Config for train datasets. It contains the following keys:
27 | dataroot_gt (str): Data root path for gt.
28 | meta_info (str): Path for meta information file.
29 | io_backend (dict): IO backend type and other kwarg.
30 | use_hflip (bool): Use horizontal flips.
31 | use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
32 | Please see more options in the codes.
33 | """
34 |
35 | def __init__(self, opt):
36 | super(RealESRGANDataset, self).__init__()
37 | self.opt = opt
38 | self.file_client = None
39 | self.io_backend_opt = opt['io_backend']
40 | self.gt_folder = opt['dataroot_gt']
41 |
42 | # file client (lmdb io backend)
43 | if self.io_backend_opt['type'] == 'lmdb':
44 | self.io_backend_opt['db_paths'] = [self.gt_folder]
45 | self.io_backend_opt['client_keys'] = ['gt']
46 | if not self.gt_folder.endswith('.lmdb'):
47 | raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
48 | with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
49 | self.paths = [line.split('.')[0] for line in fin]
50 | else:
51 | # disk backend with meta_info
52 | # Each line in the meta_info describes the relative path to an image
53 | with open(self.opt['meta_info']) as fin:
54 | paths = [line.strip().split(' ')[0] for line in fin]
55 | self.paths = [os.path.join(self.gt_folder, v) for v in paths]
56 |
57 | # blur settings for the first degradation
58 | self.blur_kernel_size = opt['blur_kernel_size']
59 | self.kernel_list = opt['kernel_list']
60 | self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability
61 | self.blur_sigma = opt['blur_sigma']
62 | self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels
63 | self.betap_range = opt['betap_range'] # betap used in plateau blur kernels
64 | self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters
65 |
66 | # blur settings for the second degradation
67 | self.blur_kernel_size2 = opt['blur_kernel_size2']
68 | self.kernel_list2 = opt['kernel_list2']
69 | self.kernel_prob2 = opt['kernel_prob2']
70 | self.blur_sigma2 = opt['blur_sigma2']
71 | self.betag_range2 = opt['betag_range2']
72 | self.betap_range2 = opt['betap_range2']
73 | self.sinc_prob2 = opt['sinc_prob2']
74 |
75 | # a final sinc filter
76 | self.final_sinc_prob = opt['final_sinc_prob']
77 |
78 | self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
79 | # TODO: kernel range is now hard-coded, should be in the configure file
80 | self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
81 | self.pulse_tensor[10, 10] = 1
82 |
83 | def __getitem__(self, index):
84 | if self.file_client is None:
85 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
86 |
87 | # -------------------------------- Load gt images -------------------------------- #
88 | # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
89 | gt_path = self.paths[index]
90 | # avoid errors caused by high latency in reading files
91 | retry = 3
92 | while retry > 0:
93 | try:
94 | img_bytes = self.file_client.get(gt_path, 'gt')
95 | except (IOError, OSError) as e:
96 | logger = get_root_logger()
97 | logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}')
98 | # change another file to read
99 |                 index = random.randint(0, self.__len__() - 1)
100 | gt_path = self.paths[index]
101 | time.sleep(1) # sleep 1s for occasional server congestion
102 | else:
103 | break
104 | finally:
105 | retry -= 1
106 | img_gt = imfrombytes(img_bytes, float32=True)
107 |
108 | # -------------------- Do augmentation for training: flip, rotation -------------------- #
109 | img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])
110 |
111 | # crop or pad to 400
112 | # TODO: 400 is hard-coded. You may change it accordingly
113 | h, w = img_gt.shape[0:2]
114 | crop_pad_size = 400
115 | # pad
116 | if h < crop_pad_size or w < crop_pad_size:
117 | pad_h = max(0, crop_pad_size - h)
118 | pad_w = max(0, crop_pad_size - w)
119 | img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
120 | # crop
121 | if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
122 | h, w = img_gt.shape[0:2]
123 | # randomly choose top and left coordinates
124 | top = random.randint(0, h - crop_pad_size)
125 | left = random.randint(0, w - crop_pad_size)
126 | img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]
127 |
128 | # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
129 | kernel_size = random.choice(self.kernel_range)
130 | if np.random.uniform() < self.opt['sinc_prob']:
131 | # this sinc filter setting is for kernels ranging from [7, 21]
132 | if kernel_size < 13:
133 | omega_c = np.random.uniform(np.pi / 3, np.pi)
134 | else:
135 | omega_c = np.random.uniform(np.pi / 5, np.pi)
136 | kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
137 | else:
138 | kernel = random_mixed_kernels(
139 | self.kernel_list,
140 | self.kernel_prob,
141 | kernel_size,
142 | self.blur_sigma,
143 | self.blur_sigma, [-math.pi, math.pi],
144 | self.betag_range,
145 | self.betap_range,
146 | noise_range=None)
147 | # pad kernel
148 | pad_size = (21 - kernel_size) // 2
149 | kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
150 |
151 | # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
152 | kernel_size = random.choice(self.kernel_range)
153 | if np.random.uniform() < self.opt['sinc_prob2']:
154 | if kernel_size < 13:
155 | omega_c = np.random.uniform(np.pi / 3, np.pi)
156 | else:
157 | omega_c = np.random.uniform(np.pi / 5, np.pi)
158 | kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
159 | else:
160 | kernel2 = random_mixed_kernels(
161 | self.kernel_list2,
162 | self.kernel_prob2,
163 | kernel_size,
164 | self.blur_sigma2,
165 | self.blur_sigma2, [-math.pi, math.pi],
166 | self.betag_range2,
167 | self.betap_range2,
168 | noise_range=None)
169 |
170 | # pad kernel
171 | pad_size = (21 - kernel_size) // 2
172 | kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
173 |
174 | # ------------------------------------- the final sinc kernel ------------------------------------- #
175 | if np.random.uniform() < self.opt['final_sinc_prob']:
176 | kernel_size = random.choice(self.kernel_range)
177 | omega_c = np.random.uniform(np.pi / 3, np.pi)
178 | sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
179 | sinc_kernel = torch.FloatTensor(sinc_kernel)
180 | else:
181 | sinc_kernel = self.pulse_tensor
182 |
183 | # BGR to RGB, HWC to CHW, numpy to tensor
184 | img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
185 | kernel = torch.FloatTensor(kernel)
186 | kernel2 = torch.FloatTensor(kernel2)
187 |
188 | return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
189 | return return_d
190 |
191 | def __len__(self):
192 | return len(self.paths)
193 |
--------------------------------------------------------------------------------
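The keys this dataset reads from `opt` mirror the `datasets.train` block of the option files above. A minimal, hypothetical opt dict looks like this (paths are placeholders; it only works once the images and meta_info file exist on disk):

```python
# Sketch: minimal opt dict for RealESRGANDataset; values mirror the YAML configs, paths are placeholders.
opt = {
    "dataroot_gt": "datasets/DF2K",
    "meta_info": "datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt",
    "io_backend": {"type": "disk"},
    "use_hflip": True,
    "use_rot": False,
    "blur_kernel_size": 21,
    "kernel_list": ["iso", "aniso", "generalized_iso", "generalized_aniso", "plateau_iso", "plateau_aniso"],
    "kernel_prob": [0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
    "sinc_prob": 0.1,
    "blur_sigma": [0.2, 3],
    "betag_range": [0.5, 4],
    "betap_range": [1, 2],
    "blur_kernel_size2": 21,
    "kernel_list2": ["iso", "aniso", "generalized_iso", "generalized_aniso", "plateau_iso", "plateau_aniso"],
    "kernel_prob2": [0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
    "sinc_prob2": 0.1,
    "blur_sigma2": [0.2, 1.5],
    "betag_range2": [0.5, 4],
    "betap_range2": [1, 2],
    "final_sinc_prob": 0.8,
}
# dataset = RealESRGANDataset(opt)  # uncomment once the data is in place
```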
/stablefusion/models/realesrgan/realesrgan/data/realesrgan_paired_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
3 | from basicsr.data.transforms import augment, paired_random_crop
4 | from basicsr.utils import FileClient, imfrombytes, img2tensor
5 | from basicsr.utils.registry import DATASET_REGISTRY
6 | from torch.utils import data as data
7 | from torchvision.transforms.functional import normalize
8 |
9 |
10 | @DATASET_REGISTRY.register()
11 | class RealESRGANPairedDataset(data.Dataset):
12 | """Paired image dataset for image restoration.
13 |
14 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
15 |
16 | There are three modes:
17 | 1. 'lmdb': Use lmdb files.
18 | If opt['io_backend'] == lmdb.
19 | 2. 'meta_info': Use meta information file to generate paths.
20 | If opt['io_backend'] != lmdb and opt['meta_info'] is not None.
21 | 3. 'folder': Scan folders to generate paths.
22 | The rest.
23 |
24 | Args:
25 | opt (dict): Config for train datasets. It contains the following keys:
26 | dataroot_gt (str): Data root path for gt.
27 | dataroot_lq (str): Data root path for lq.
28 | meta_info (str): Path for meta information file.
29 | io_backend (dict): IO backend type and other kwarg.
30 | filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
31 | Default: '{}'.
32 | gt_size (int): Cropped patched size for gt patches.
33 | use_hflip (bool): Use horizontal flips.
34 | use_rot (bool): Use rotation (use vertical flip and transposing h
35 | and w for implementation).
36 |
37 |             scale (int): Scale factor, which will be added automatically.
38 | phase (str): 'train' or 'val'.
39 | """
40 |
41 | def __init__(self, opt):
42 | super(RealESRGANPairedDataset, self).__init__()
43 | self.opt = opt
44 | self.file_client = None
45 | self.io_backend_opt = opt['io_backend']
46 | # mean and std for normalizing the input images
47 | self.mean = opt['mean'] if 'mean' in opt else None
48 | self.std = opt['std'] if 'std' in opt else None
49 |
50 | self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
51 | self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'
52 |
53 | # file client (lmdb io backend)
54 | if self.io_backend_opt['type'] == 'lmdb':
55 | self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
56 | self.io_backend_opt['client_keys'] = ['lq', 'gt']
57 | self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
58 | elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
59 | # disk backend with meta_info
60 | # Each line in the meta_info describes the relative path to an image
61 | with open(self.opt['meta_info']) as fin:
62 | paths = [line.strip() for line in fin]
63 | self.paths = []
64 | for path in paths:
65 | gt_path, lq_path = path.split(', ')
66 | gt_path = os.path.join(self.gt_folder, gt_path)
67 | lq_path = os.path.join(self.lq_folder, lq_path)
68 | self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
69 | else:
70 | # disk backend
71 | # it will scan the whole folder to get meta info
72 | # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file
73 | self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
74 |
75 | def __getitem__(self, index):
76 | if self.file_client is None:
77 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
78 |
79 | scale = self.opt['scale']
80 |
81 | # Load gt and lq images. Dimension order: HWC; channel order: BGR;
82 | # image range: [0, 1], float32.
83 | gt_path = self.paths[index]['gt_path']
84 | img_bytes = self.file_client.get(gt_path, 'gt')
85 | img_gt = imfrombytes(img_bytes, float32=True)
86 | lq_path = self.paths[index]['lq_path']
87 | img_bytes = self.file_client.get(lq_path, 'lq')
88 | img_lq = imfrombytes(img_bytes, float32=True)
89 |
90 | # augmentation for training
91 | if self.opt['phase'] == 'train':
92 | gt_size = self.opt['gt_size']
93 | # random crop
94 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
95 | # flip, rotation
96 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])
97 |
98 | # BGR to RGB, HWC to CHW, numpy to tensor
99 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
100 | # normalize
101 | if self.mean is not None or self.std is not None:
102 | normalize(img_lq, self.mean, self.std, inplace=True)
103 | normalize(img_gt, self.mean, self.std, inplace=True)
104 |
105 | return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
106 |
107 | def __len__(self):
108 | return len(self.paths)
109 |
--------------------------------------------------------------------------------
/stablefusion/models/realesrgan/realesrgan/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from basicsr.utils import scandir
3 | from os import path as osp
4 |
5 | # automatically scan and import model modules for registry
6 | # scan all the files that end with '_model.py' under the model folder
7 | model_folder = osp.dirname(osp.abspath(__file__))
8 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
9 | # import all the model modules
10 | _model_modules = [importlib.import_module(f'realesrgan.models.{file_name}') for file_name in model_filenames]
11 |
--------------------------------------------------------------------------------
/stablefusion/models/realesrgan/realesrgan/models/realesrnet_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import torch
4 | from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
5 | from basicsr.data.transforms import paired_random_crop
6 | from basicsr.models.sr_model import SRModel
7 | from basicsr.utils import DiffJPEG, USMSharp
8 | from basicsr.utils.img_process_util import filter2D
9 | from basicsr.utils.registry import MODEL_REGISTRY
10 | from torch.nn import functional as F
11 |
12 |
13 | @MODEL_REGISTRY.register()
14 | class RealESRNetModel(SRModel):
15 | """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
16 |
17 | It is trained without GAN losses.
18 | It mainly performs:
19 | 1. randomly synthesize LQ images in GPU tensors
20 |     2. optimize the networks with pixel-wise (L1) loss only, without GAN training.
21 | """
22 |
23 | def __init__(self, opt):
24 | super(RealESRNetModel, self).__init__(opt)
25 | self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
26 | self.usm_sharpener = USMSharp().cuda() # do usm sharpening
27 | self.queue_size = opt.get('queue_size', 180)
28 |
29 | @torch.no_grad()
30 | def _dequeue_and_enqueue(self):
31 | """It is the training pair pool for increasing the diversity in a batch.
32 |
33 | Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
34 | batch could not have different resize scaling factors. Therefore, we employ this training pair pool
35 | to increase the degradation diversity in a batch.
36 | """
37 | # initialize
38 | b, c, h, w = self.lq.size()
39 | if not hasattr(self, 'queue_lr'):
40 | assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
41 | self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
42 | _, c, h, w = self.gt.size()
43 | self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
44 | self.queue_ptr = 0
45 | if self.queue_ptr == self.queue_size: # the pool is full
46 | # do dequeue and enqueue
47 | # shuffle
48 | idx = torch.randperm(self.queue_size)
49 | self.queue_lr = self.queue_lr[idx]
50 | self.queue_gt = self.queue_gt[idx]
51 | # get first b samples
52 | lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
53 | gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
54 | # update the queue
55 | self.queue_lr[0:b, :, :, :] = self.lq.clone()
56 | self.queue_gt[0:b, :, :, :] = self.gt.clone()
57 |
58 | self.lq = lq_dequeue
59 | self.gt = gt_dequeue
60 | else:
61 | # only do enqueue
62 | self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
63 | self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
64 | self.queue_ptr = self.queue_ptr + b
65 |
66 | @torch.no_grad()
67 | def feed_data(self, data):
68 | """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
69 | """
70 | if self.is_train and self.opt.get('high_order_degradation', True):
71 | # training data synthesis
72 | self.gt = data['gt'].to(self.device)
73 | # USM sharpen the GT images
74 | if self.opt['gt_usm'] is True:
75 | self.gt = self.usm_sharpener(self.gt)
76 |
77 | self.kernel1 = data['kernel1'].to(self.device)
78 | self.kernel2 = data['kernel2'].to(self.device)
79 | self.sinc_kernel = data['sinc_kernel'].to(self.device)
80 |
81 | ori_h, ori_w = self.gt.size()[2:4]
82 |
83 | # ----------------------- The first degradation process ----------------------- #
84 | # blur
85 | out = filter2D(self.gt, self.kernel1)
86 | # random resize
87 | updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
88 | if updown_type == 'up':
89 | scale = np.random.uniform(1, self.opt['resize_range'][1])
90 | elif updown_type == 'down':
91 | scale = np.random.uniform(self.opt['resize_range'][0], 1)
92 | else:
93 | scale = 1
94 | mode = random.choice(['area', 'bilinear', 'bicubic'])
95 | out = F.interpolate(out, scale_factor=scale, mode=mode)
96 | # add noise
97 | gray_noise_prob = self.opt['gray_noise_prob']
98 | if np.random.uniform() < self.opt['gaussian_noise_prob']:
99 | out = random_add_gaussian_noise_pt(
100 | out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
101 | else:
102 | out = random_add_poisson_noise_pt(
103 | out,
104 | scale_range=self.opt['poisson_scale_range'],
105 | gray_prob=gray_noise_prob,
106 | clip=True,
107 | rounds=False)
108 | # JPEG compression
109 | jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
110 | out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
111 | out = self.jpeger(out, quality=jpeg_p)
112 |
113 | # ----------------------- The second degradation process ----------------------- #
114 | # blur
115 | if np.random.uniform() < self.opt['second_blur_prob']:
116 | out = filter2D(out, self.kernel2)
117 | # random resize
118 | updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
119 | if updown_type == 'up':
120 | scale = np.random.uniform(1, self.opt['resize_range2'][1])
121 | elif updown_type == 'down':
122 | scale = np.random.uniform(self.opt['resize_range2'][0], 1)
123 | else:
124 | scale = 1
125 | mode = random.choice(['area', 'bilinear', 'bicubic'])
126 | out = F.interpolate(
127 | out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
128 | # add noise
129 | gray_noise_prob = self.opt['gray_noise_prob2']
130 | if np.random.uniform() < self.opt['gaussian_noise_prob2']:
131 | out = random_add_gaussian_noise_pt(
132 | out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
133 | else:
134 | out = random_add_poisson_noise_pt(
135 | out,
136 | scale_range=self.opt['poisson_scale_range2'],
137 | gray_prob=gray_noise_prob,
138 | clip=True,
139 | rounds=False)
140 |
141 | # JPEG compression + the final sinc filter
142 | # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
143 | # as one operation.
144 | # We consider two orders:
145 | # 1. [resize back + sinc filter] + JPEG compression
146 | # 2. JPEG compression + [resize back + sinc filter]
147 | # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
148 | if np.random.uniform() < 0.5:
149 | # resize back + the final sinc filter
150 | mode = random.choice(['area', 'bilinear', 'bicubic'])
151 | out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
152 | out = filter2D(out, self.sinc_kernel)
153 | # JPEG compression
154 | jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
155 | out = torch.clamp(out, 0, 1)
156 | out = self.jpeger(out, quality=jpeg_p)
157 | else:
158 | # JPEG compression
159 | jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
160 | out = torch.clamp(out, 0, 1)
161 | out = self.jpeger(out, quality=jpeg_p)
162 | # resize back + the final sinc filter
163 | mode = random.choice(['area', 'bilinear', 'bicubic'])
164 | out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
165 | out = filter2D(out, self.sinc_kernel)
166 |
167 | # clamp and round
168 | self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
169 |
170 | # random crop
171 | gt_size = self.opt['gt_size']
172 | self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale'])
173 |
174 | # training pair pool
175 | self._dequeue_and_enqueue()
176 | self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
177 | else:
178 | # for paired training or validation
179 | self.lq = data['lq'].to(self.device)
180 | if 'gt' in data:
181 | self.gt = data['gt'].to(self.device)
182 | self.gt_usm = self.usm_sharpener(self.gt)
183 |
184 | def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
185 | # do not use the synthetic process during validation
186 | self.is_train = False
187 | super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
188 | self.is_train = True
189 |
--------------------------------------------------------------------------------
/stablefusion/models/realesrgan/realesrgan/train.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | import os.path as osp
3 | from basicsr.train import train_pipeline
4 |
5 | import realesrgan.archs
6 | import realesrgan.data
7 | import realesrgan.models
8 |
9 | if __name__ == '__main__':
10 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
11 | train_pipeline(root_path)
12 |
--------------------------------------------------------------------------------
/stablefusion/models/realesrgan/weights/README.md:
--------------------------------------------------------------------------------
1 | # Weights
2 |
3 | Put the downloaded weights in this folder.
4 |
--------------------------------------------------------------------------------
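The Real-ESRGAN code bundled under stablefusion/models/realesrgan expects the .pth weight files it uses to live in this folder. As a minimal sketch (not part of the repo), the weights can be fetched with plain Python; the release URL below is the commonly published upstream Real-ESRGAN asset and is an assumption that should be verified against that project's GitHub releases page.

# Hypothetical download helper; the folder path matches this repo, the URL is an
# assumption taken from the upstream Real-ESRGAN releases and should be double-checked.
import os
import urllib.request

weights_dir = "stablefusion/models/realesrgan/weights"
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"

os.makedirs(weights_dir, exist_ok=True)
urllib.request.urlretrieve(url, os.path.join(weights_dir, os.path.basename(url)))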
/stablefusion/models/safetensors_models/t:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuralRealm/StableFusion/9268d686f49b45194a2251c6728e74191f298654/stablefusion/models/safetensors_models/t
--------------------------------------------------------------------------------
/stablefusion/pages/10_Utilities.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 |
3 | from stablefusion import utils
4 | from stablefusion.scripts.gfp_gan import GFPGAN
5 | from stablefusion.scripts.image_info import ImageInfo
6 | from stablefusion.scripts.interrogator import ImageInterrogator
7 | from stablefusion.scripts.upscaler import Upscaler
8 | from stablefusion.scripts.model_adding import ModelAdding
9 | from stablefusion.scripts.model_removing import ModelRemoving
10 |
11 |
12 | def app():
13 | utils.create_base_page()
14 | task = st.selectbox(
15 | "Choose a utility",
16 | [
17 | "ImageInfo",
18 | "Model Adding",
19 | "Model Removing",
20 | "SD Upscaler",
21 | "GFPGAN",
22 | "CLIP Interrogator",
23 | ],
24 | )
25 | if task == "ImageInfo":
26 | ImageInfo().app()
27 | elif task == "Model Adding":
28 | ModelAdding().app()
29 | elif task == "Model Removing":
30 | ModelRemoving().app()
31 | elif task == "SD Upscaler":
32 | with st.form("upscaler_model"):
33 | upscaler_model = st.text_input("Model", "stabilityai/stable-diffusion-x4-upscaler")
34 | submit = st.form_submit_button("Load model")
35 | if submit:
36 | with st.spinner("Loading model..."):
37 | ups = Upscaler(
38 | model=upscaler_model,
39 | device=st.session_state.device,
40 | output_path=st.session_state.output_path,
41 | )
42 | st.session_state.ups = ups
43 | if "ups" in st.session_state:
44 | st.write(f"Current model: {st.session_state.ups}")
45 | st.session_state.ups.app()
46 |
47 | elif task == "GFPGAN":
48 | with st.spinner("Loading model..."):
49 | gfpgan = GFPGAN(
50 | device=st.session_state.device,
51 | output_path=st.session_state.output_path,
52 | )
53 | gfpgan.app()
54 | elif task == "CLIP Interrogator":
55 | interrogator = ImageInterrogator(
56 | device=st.session_state.device,
57 | output_path=st.session_state.output_path,
58 | )
59 | interrogator.app()
60 |
61 |
62 | if __name__ == "__main__":
63 | app()
64 |
--------------------------------------------------------------------------------
/stablefusion/pages/1_Text2Image.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 |
3 | from stablefusion.scripts.img2img import Img2Img
4 | from stablefusion.scripts.text2img import Text2Image
5 | from stablefusion.Home import read_model_list
6 | from stablefusion import utils
7 |
8 | def app():
9 | utils.create_base_page()
10 | if "img2img" in st.session_state and "text2img" not in st.session_state:
11 | text2img = Text2Image(
12 | model=st.session_state.img2img.model,
13 | device=st.session_state.device,
14 | output_path=st.session_state.output_path,
15 | )
16 | st.session_state.text2img = text2img
17 | with st.form("text2img_model"):
18 | model = st.selectbox(
19 | "Which model do you want to use?",
20 | options=read_model_list(),
21 | )
22 | # submit_col, _, clear_col = st.columns(3)
23 | # with submit_col:
24 | submit = st.form_submit_button("Load model")
25 | # with clear_col:
26 | # clear = st.form_submit_button("Clear memory")
27 | # if clear:
28 | # clear_memory(preserve="text2img")
29 | if submit:
30 | with st.spinner("Loading model..."):
31 | text2img = Text2Image(
32 | model=model,
33 | device=st.session_state.device,
34 | output_path=st.session_state.output_path,
35 | )
36 | st.session_state.text2img = text2img
37 | img2img = Img2Img(
38 | model=None,
39 | device=st.session_state.device,
40 | output_path=st.session_state.output_path,
41 | text2img_model=text2img.pipeline,
42 | )
43 | st.session_state.img2img = img2img
44 | if "text2img" in st.session_state:
45 | st.write(f"Current model: {st.session_state.text2img}")
46 | st.session_state.text2img.app()
47 |
48 |
49 | if __name__ == "__main__":
50 | app()
51 |
--------------------------------------------------------------------------------
/stablefusion/pages/2_Image2Image.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 |
3 | from stablefusion.scripts.img2img import Img2Img
4 | from stablefusion.scripts.text2img import Text2Image
5 | from stablefusion.Home import read_model_list
6 | from stablefusion import utils
7 |
8 | def app():
9 | utils.create_base_page()
10 | if "text2img" in st.session_state and "img2img" not in st.session_state:
11 | img2img = Img2Img(
12 | model=None,
13 | device=st.session_state.device,
14 | output_path=st.session_state.output_path,
15 | text2img_model=st.session_state.text2img.pipeline,
16 | )
17 | st.session_state.img2img = img2img
18 | with st.form("img2img_model"):
19 | model = st.selectbox(
20 | "Which model do you want to use?",
21 | options=read_model_list(),
22 | )
23 | submit = st.form_submit_button("Load model")
24 | if submit:
25 | with st.spinner("Loading model..."):
26 | img2img = Img2Img(
27 | model=model,
28 | device=st.session_state.device,
29 | output_path=st.session_state.output_path,
30 | )
31 | st.session_state.img2img = img2img
32 | text2img = Text2Image(
33 | model=model,
34 | device=st.session_state.device,
35 | output_path=st.session_state.output_path,
36 | )
37 | st.session_state.text2img = text2img
38 | if "img2img" in st.session_state:
39 | st.write(f"Current model: {st.session_state.img2img}")
40 | st.session_state.img2img.app()
41 |
42 |
43 | if __name__ == "__main__":
44 | app()
45 |
--------------------------------------------------------------------------------
/stablefusion/pages/3_Inpainting.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 |
3 | from stablefusion import utils
4 | from stablefusion.scripts.inpainting import Inpainting
5 | from stablefusion.Home import read_model_list
6 |
7 | def app():
8 | utils.create_base_page()
9 | with st.form("inpainting_model_form"):
10 | model = st.selectbox(
11 | "Which model do you want to use for inpainting?",
12 | options=read_model_list()
13 | )
14 |         pipeline = st.selectbox(label="Select Your Pipeline: ", options=["StableDiffusionInpaintPipelineLegacy", "StableDiffusionInpaintPipeline"])
15 | submit = st.form_submit_button("Load model")
16 | if submit:
17 | st.session_state.inpainting_model = model
18 | with st.spinner("Loading model..."):
19 | inpainting = Inpainting(
20 | model=model,
21 | device=st.session_state.device,
22 | output_path=st.session_state.output_path,
23 | pipeline_select=pipeline
24 | )
25 | st.session_state.inpainting = inpainting
26 | if "inpainting" in st.session_state:
27 | st.write(f"Current model: {st.session_state.inpainting}")
28 | st.session_state.inpainting.app()
29 |
30 |
31 | if __name__ == "__main__":
32 | app()
--------------------------------------------------------------------------------
/stablefusion/pages/4_ControlNet.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 |
3 | from stablefusion import utils
4 | from stablefusion.scripts.controlnet import Controlnet
5 | from stablefusion.Home import read_model_list
6 |
7 | control_net_model_list = ["lllyasviel/sd-controlnet-canny",
8 | "lllyasviel/sd-controlnet-hed",
9 | "lllyasviel/sd-controlnet-normal",
10 | "lllyasviel/sd-controlnet-scribble",
11 | "lllyasviel/sd-controlnet-depth",
12 | "lllyasviel/sd-controlnet-mlsd",
13 | "lllyasviel/sd-controlnet-openpose",
14 | ]
15 |
16 | processer_list = [ "Canny",
17 | "Hed",
18 | "Normal",
19 | "Scribble",
20 | "Depth",
21 | "Mlsd",
22 | "OpenPose",
23 | ]
24 |
25 |
26 | import streamlit as st
27 |
28 | from stablefusion import utils
29 | from stablefusion.scripts.gfp_gan import GFPGAN
30 | from stablefusion.scripts.image_info import ImageInfo
31 | from stablefusion.scripts.interrogator import ImageInterrogator
32 | from stablefusion.scripts.upscaler import Upscaler
33 | from stablefusion.scripts.model_adding import ModelAdding
34 | from stablefusion.scripts.model_removing import ModelRemoving
35 |
36 |
37 | def app():
38 | utils.create_base_page()
39 | with st.form("inpainting_model_form"):
40 | base_model = st.selectbox(
41 | "Base model For your ControlNet: ",
42 | options=read_model_list()
43 | )
44 | controlnet_model = st.selectbox(label="Choose Your ControlNet: ", options=control_net_model_list)
45 | processer = st.selectbox(label="Choose Your Processer: ", options=processer_list)
46 | submit = st.form_submit_button("Load model")
47 | if submit:
48 | st.session_state.controlnet_models = base_model
49 | with st.spinner("Loading model..."):
50 | controlnet = Controlnet(
51 | model=base_model,
52 | device=st.session_state.device,
53 | output_path=st.session_state.output_path,
54 | controlnet_model=controlnet_model,
55 | processer = processer
56 | )
57 | st.session_state.controlnet = controlnet
58 | if "controlnet" in st.session_state:
59 | st.write(f"Current model: {st.session_state.controlnet}")
60 | st.session_state.controlnet.app()
61 |
62 |
63 | if __name__ == "__main__":
64 | app()
--------------------------------------------------------------------------------
/stablefusion/pages/5_OpenPose_Editor.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 |
3 | from stablefusion import utils
4 | from stablefusion.scripts.pose_editor import OpenPoseEditor
5 | from stablefusion.Home import read_model_list
6 |
7 | def app():
8 | utils.create_base_page()
9 | with st.form("openpose_editor_form"):
10 | model = st.selectbox(
11 | "Which model do you want to use for OpenPose?",
12 | options=read_model_list()
13 | )
14 | submit = st.form_submit_button("Load model")
15 | if submit:
16 | st.session_state.openpose_editor = model
17 | with st.spinner("Loading model..."):
18 | openpose = OpenPoseEditor(
19 | model=model,
20 | device=st.session_state.device,
21 | output_path=st.session_state.output_path,
22 | )
23 | st.session_state.openpose_editor = openpose
24 | if "openpose_editor" in st.session_state:
25 | st.write(f"Current model: {st.session_state.openpose_editor}")
26 | st.session_state.openpose_editor.app()
27 |
28 |
29 | if __name__ == "__main__":
30 | app()
--------------------------------------------------------------------------------
/stablefusion/pages/6_Textual Inversion.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 |
3 | from stablefusion.scripts.textual_inversion import TextualInversion
4 | from stablefusion import utils
5 | from stablefusion.Home import read_model_list
6 |
7 | def app():
8 | utils.create_base_page()
9 | with st.form("textual_inversion_form"):
10 | model = st.selectbox(
11 | "Which base model do you want to use?",
12 | options=read_model_list(),
13 | )
14 | token_identifier = st.text_input(
15 | "Token identifier",
16 | value=""
17 | if st.session_state.get("textual_inversion_token_identifier") is None
18 | else st.session_state.textual_inversion_token_identifier,
19 | )
20 | embeddings = st.text_input(
21 | "Embeddings",
22 | value="https://huggingface.co/sd-concepts-library/axe-tattoo/resolve/main/learned_embeds.bin"
23 | if st.session_state.get("textual_inversion_embeddings") is None
24 | else st.session_state.textual_inversion_embeddings,
25 | )
26 | # st.file_uploader("Embeddings", type=["pt", "bin"])
27 | submit = st.form_submit_button("Load model")
28 | if submit:
29 | st.session_state.textual_inversion_model = model
30 | st.session_state.textual_inversion_token_identifier = token_identifier
31 | st.session_state.textual_inversion_embeddings = embeddings
32 | with st.spinner("Loading model..."):
33 | textual_inversion = TextualInversion(
34 | model=model,
35 | token_identifier=token_identifier,
36 | embeddings_url=embeddings,
37 | device=st.session_state.device,
38 | output_path=st.session_state.output_path,
39 | )
40 | st.session_state.textual_inversion = textual_inversion
41 | if "textual_inversion" in st.session_state:
42 | st.write(f"Current model: {st.session_state.textual_inversion}")
43 | st.session_state.textual_inversion.app()
44 |
45 |
46 | if __name__ == "__main__":
47 | app()
48 |
--------------------------------------------------------------------------------
/stablefusion/pages/7_Upscaler.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 |
3 | from stablefusion import utils
4 | from stablefusion.models.realesrgan import inference_realesrgan
5 | from stablefusion.scripts.upscaler import Upscaler
6 | from stablefusion.scripts.gfp_gan import GFPGAN
7 |
8 |
9 | def app():
10 | utils.create_base_page()
11 | task = st.selectbox(
12 | label="Choose Your Upsacler",
13 | options=[
14 | "RealESRGAN",
15 | "SD Upscaler",
16 | "GFPGAN",
17 | ],
18 | )
19 |
20 | if task == "RealESRGAN":
21 |
22 | model_name = st.selectbox(
23 | label="Choose Your Model",
24 | options=[
25 | "RealESRGAN_x4plus",
26 | "RealESRNet_x4plus",
27 | "RealESRGAN_x4plus_anime_6B",
28 | "RealESRGAN_x2plus",
29 | "realesr-animevideov3",
30 | "realesr-general-x4v3"
31 | ],
32 | )
33 |
34 | input_image = st.file_uploader(label="Upload The Picture", type=["png", "jpg", "jpeg"])
35 |
36 | col1, col2 = st.columns(2)
37 | with col1:
38 |             denoise_strength = st.slider(label="Select your Denoise Strength", min_value=0.0, max_value=1.0, value=0.5, step=0.1, help="Denoise strength: 0 for weak denoising (keeps noise), 1 for strong denoising.")
39 | outscale = st.slider(label="Select your final Upsampling scale", min_value=0, max_value=4, value=4, step=1, help="The final upsampling scale of the image")
40 |
41 | with col2:
42 |             face_enhance = st.selectbox(label="Do you want to Enhance Face", options=[True, False], help="Use GFPGAN to enhance face")
43 | alpha_upsampler = st.selectbox(label="The upsampler for the alpha channels", options=["realesrgan", "bicubic"])
44 |
45 | if st.button("Start Upscaling"):
46 | with st.spinner("Upscaling The Image..."):
47 | inference_realesrgan.main(model_name=model_name, outscale=outscale, denoise_strength=denoise_strength, face_enhance=face_enhance, tile=0, tile_pad=10, pre_pad=0, fp32="fp32", gpu_id=None, input_image=input_image, model_path=None)
48 |
49 |
50 | elif task == "SD Upscaler":
51 | with st.form("upscaler_model"):
52 | upscaler_model = st.text_input("Model", "stabilityai/stable-diffusion-x4-upscaler")
53 | submit = st.form_submit_button("Load model")
54 | if submit:
55 | with st.spinner("Loading model..."):
56 | ups = Upscaler(
57 | model=upscaler_model,
58 | device=st.session_state.device,
59 | output_path=st.session_state.output_path,
60 | )
61 | st.session_state.ups = ups
62 | if "ups" in st.session_state:
63 | st.write(f"Current model: {st.session_state.ups}")
64 | st.session_state.ups.app()
65 |
66 | elif task == "GFPGAN":
67 | with st.spinner("Loading model..."):
68 | gfpgan = GFPGAN(
69 | device=st.session_state.device,
70 | output_path=st.session_state.output_path,
71 | )
72 | gfpgan.app()
73 |
74 | if __name__ == "__main__":
75 | app()
76 |
--------------------------------------------------------------------------------
/stablefusion/pages/8_Convertor.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 |
3 | from stablefusion import utils
4 | from stablefusion.scripts.gfp_gan import GFPGAN
5 | from stablefusion.scripts.image_info import ImageInfo
6 | from stablefusion.scripts.interrogator import ImageInterrogator
7 | from stablefusion.scripts.upscaler import Upscaler
8 | from stablefusion.scripts.model_adding import ModelAdding
9 | from stablefusion.scripts.model_removing import ModelRemoving
10 | from stablefusion.scripts.ckpt_to_diffusion import convert_ckpt_to_diffusion
11 | from stablefusion.scripts.safetensors_to_diffusion import convert_safetensor_to_diffusers
12 |
13 |
14 | def app():
15 | utils.create_base_page()
16 | task = st.selectbox(
17 | "Choose a Convertor",
18 | [
19 | "CKPT to Diffusers",
20 | "Safetensors to Diffusers",
21 | ],
22 | )
23 | if task == "CKPT to Diffusers":
24 |
25 | with st.form("Convert Your CKPT to Diffusers"):
26 | ckpt_model_name = st.text_input(label="Name of the CKPT Model", value="NameOfYourModel.ckpt")
27 |             ckpt_model = st.text_input(label="Download link of the CKPT model: ", value="https://huggingface.co/andite/anything-v4.0/resolve/main/anything-v4.0-pruned-fp16.ckpt", help="Path to the checkpoint to convert.")
28 |             overwrite_mode = st.selectbox("Do you want to overwrite the file: ", options=[False, True], help="If a ckpt file with the same name is already present, should it be overwritten, or reused instead of downloading it again?")
29 |             st.subheader("Advanced Settings")
30 | st.text("Don't Change anything if you are not sure about it.")
31 |
32 | config_file = st.text_input(label="Enter The Config File: ", value=None, help="The YAML config file corresponding to the original architecture.")
33 |
34 | col1, col2 = st.columns(2)
35 |
36 | with col1:
37 | num_in_channels = st.text_input(label="Enter The Number Of Channels", value=None, help="The number of input channels. If `None` number of input channels will be automatically inferred.")
38 | scheduler_type = st.selectbox(label="Select the Scheduler: ", options=['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm'], help="Type of scheduler to use.")
39 | pipeline_type = st.text_input(label="Enter the pipeline type: ", value=None, help="The pipeline type. If `None` pipeline will be automatically inferred.")
40 |
41 | with col2:
42 |                 image_size = st.selectbox(label="Image Size", options=[None, "512", "768"], help="The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2 Base. Use 768 for Stable Diffusion v2.")
43 | prediction_type = st.selectbox(label="Select your prediction type", options=[None, "epsilon", "v-prediction"])
44 |                 extract_ema = st.selectbox(label="Extract the EMA", options=[True, False], help="Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to `False`. EMA weights usually yield higher-quality images for inference; non-EMA weights are usually better for continued fine-tuning.")
45 |
46 | device = None
47 |
48 | submit = st.form_submit_button("Start Converting")
49 |
50 | if submit:
51 | with st.spinner("Converting..."):
52 | convert_ckpt_to_diffusion(device=device, checkpoint_link=ckpt_model, checkpoint_name=ckpt_model_name, num_in_channels=num_in_channels, scheduler_type=scheduler_type, pipeline_type=pipeline_type, extract_ema=extract_ema, dump_path=" ", image_size=image_size, original_config_file=config_file, prediction_type=prediction_type, overwrite_file=overwrite_mode)
53 |
54 |
55 | elif task == "Safetensors to Diffusers":
56 |
57 | with st.form("Convert Your Safetensors to Diffusers"):
58 | stabletensor_model_name = st.text_input(label="Name of the Safetensors Model", value="NameOfYourModel.safetensors")
59 |             stabletensor_model = st.text_input(label="Download link of the Safetensors model: ", value="https://civitai.com/api/download/models/4007?type=Model&format=SafeTensor", help="Path to the checkpoint to convert.")
60 |             overwrite_mode = st.selectbox("Do you want to overwrite the file: ", options=[False, True], help="If a Safetensors file with the same name is already present, should it be overwritten, or reused instead of downloading it again?")
61 |             st.subheader("Advanced Settings")
62 | st.text("Don't Change anything if you are not sure about it.")
63 |
64 | config_file = st.text_input(label="Enter The Config File: ", value=None, help="The YAML config file corresponding to the original architecture.")
65 |
66 | col1, col2 = st.columns(2)
67 |
68 | with col1:
69 | num_in_channels = st.text_input(label="Enter The Number Of Channels", value=None, help="The number of input channels. If `None` number of input channels will be automatically inferred.")
70 | scheduler_type = st.selectbox(label="Select the Scheduler: ", options=['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm'], help="Type of scheduler to use.")
71 | pipeline_type = st.text_input(label="Enter the pipeline type: ", value=None, help="The pipeline type. If `None` pipeline will be automatically inferred.")
72 |
73 | with col2:
74 |                 image_size = st.selectbox(label="Image Size", options=[None, "512", "768"], help="The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2 Base. Use 768 for Stable Diffusion v2.")
75 | prediction_type = st.selectbox(label="Select your prediction type", options=[None, "epsilon", "v-prediction"])
76 |                 extract_ema = st.selectbox(label="Extract the EMA", options=[True, False], help="Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to `False`. EMA weights usually yield higher-quality images for inference; non-EMA weights are usually better for continued fine-tuning.")
77 |
78 | device = None
79 |
80 | submit = st.form_submit_button("Start Converting")
81 |
82 | if submit:
83 | with st.spinner("Converting..."):
84 | if config_file == "None":
85 | config_file = None
86 |
87 | if num_in_channels == "None":
88 | num_in_channels = None
89 |
90 | if pipeline_type == "None":
91 | pipeline_type = None
92 | convert_safetensor_to_diffusers(original_config_file=config_file, image_size=image_size, prediction_type=prediction_type, pipeline_type=pipeline_type, extract_ema=extract_ema, scheduler_type=scheduler_type, num_in_channels=num_in_channels, upcast_attention=False, from_safetensors=True, device=device, stable_unclip=None, stable_unclip_prior=None, clip_stats_path=None, controlnet=None, to_safetensors=None, checkpoint_name=stabletensor_model_name, checkpoint_link=stabletensor_model, overwrite=overwrite_mode)
93 |
94 |
95 |
96 | if __name__ == "__main__":
97 | app()
--------------------------------------------------------------------------------
/stablefusion/pages/9_Train.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from stablefusion import utils
3 | from stablefusion.scripts.dreambooth import train_dreambooth
4 | from stablefusion.Home import read_model_list
5 | import json
6 | import os
7 |
8 |
9 | def dump_concept_file(instance_prompt, class_prompt, instance_data_dir, class_data_dir):
10 |
11 | concepts_list = [
12 | {
13 | "instance_prompt": instance_prompt,
14 | "class_prompt": class_prompt,
15 | "instance_data_dir": "{}/trainner_assets/instance_data/{}".format(utils.base_path(), instance_data_dir),
16 | "class_data_dir": "{}/trainner_assets/class_data/{}".format(utils.base_path(), class_data_dir)
17 | },
18 | ]
19 | for c in concepts_list:
20 | os.makedirs(c["instance_data_dir"], exist_ok=True)
21 |
22 | with open("{}/trainner_assets/concepts_list.json".format(utils.base_path()), "w") as f:
23 | json.dump(concepts_list, f, indent=4)
24 |
25 |
26 | def app():
27 | utils.create_base_page()
28 | task = st.selectbox(
29 | "Choose a Convertor",
30 | [
31 | "Dreambooth",
32 | ],
33 | )
34 | if task == "Dreambooth":
35 |
36 | with st.form("Train With Dreambooth"):
37 |
38 | col1, col2 = st.columns(2)
39 | with col1:
40 | base_model = st.selectbox(label="Name of the Base Model", options=read_model_list(), help="Path to pretrained model or model identifier from huggingface.co/models.")
41 |                 instance_prompt = st.text_input(label="The prompt with identifier specifying the instance", value="photo of elon musk person", help="Replace 'elon musk' with the name of your object or person")
42 |                 class_prompt = st.text_input(label="The prompt to specify images in the same class as provided instance images.", value="photo of a person", help="Replace 'person' with your object type if you are training for any kind of object")
43 |                 instance_data_dir_name = st.text_input(label="Name of the training data of instance images (folder): ", value="elon musk", help="A folder containing the training data of instance images")
44 |                 class_data_dir_name = st.text_input(label="Name of the training data of class images (folder): ", value="person", help="A folder containing the training data of class images.")
45 |                 output_dir_name = st.text_input(label="Name of your trained model (folder): ", value="elon musk", help="The output directory name where the model predictions and checkpoints will be written.")
46 | seed = st.number_input(label="Enter The Seed", value=1337, step=1, help="A seed for reproducible training.")
47 |
48 | with col2:
49 | resolution = st.number_input(label="Enter The Image Resolution: ", value=512, step=1, help="The resolution for input images, all the images in the train/validation dataset will be resized to this resolution")
50 | train_batch_size = st.number_input(label="Train Batch Size", value=4, step=1, help="Batch size (per device) for sampling images.")
51 |                 learning_rate = st.number_input(label="Learning Rate (exponent)", value=6, help="Exponent of the initial learning rate (after the potential warmup period): a value of N gives a learning rate of 1e-N, e.g. 6 gives 1e-6.")
52 | num_class_images = st.number_input(label="Number of Class Images", value=100, step=1, help="Minimal class images for prior preservation loss. If there are not enough images already present in class_data_dir, additional images will be sampled with class_prompt.")
53 | sample_batch_size = st.number_input(label="Sample Batch Size: ", value=4, help="Batch size (per device) for sampling images.")
54 | max_train_steps = st.number_input(label="Max Train Steps: ", value=20, help="Total number of training steps to perform.")
55 |                 save_interval = st.number_input(label="Save Interval Steps: ", value=10000, help="Save weights every N steps.")
56 |
57 |             st.subheader("Advanced Settings")
58 | st.text("Don't Change anything if you are not sure about it.")
59 | col3, col4 = st.columns(2)
60 | with col3:
61 | revision = st.text_input(label="Revision of pretrained model identifier: ", value=None)
62 | prior_loss_weight_condition = st.selectbox(label="Want to use prior_loss_weight: ", options=[True, False])
63 | if prior_loss_weight_condition:
64 | prior_loss_weight = st.slider(label="Prior Loss Weight: ",value=1.0, max_value=10.0, help="The weight of prior preservation loss.")
65 | else:
66 | prior_loss_weight = 0
67 |
68 | train_text_encoder = st.selectbox("Whether to train the text encoder", options=[True, False])
69 | log_interval = st.number_input(label="Save Log Interval: ", value=10, help="Log every N steps.")
70 | tokenizer_name = st.text_input("Enter the tokenizer name", value=None, help="Pretrained tokenizer name or path if not the same as model_name")
71 |
72 | with col4:
73 | use_8bit_adam = st.selectbox(label="Want to use 8bit adam: ", options=[True, False], help="Whether or not to use 8-bit Adam from bitsandbytes.")
74 | gradient_accumulation_steps = st.slider(label="Gradient Accumulation steps", value=1, help="Number of updates steps to accumulate before performing a backward/update pass.")
75 |                 lr_scheduler = st.selectbox(label="Select Your Learning Rate Scheduler: ", options=["constant", "cosine", "cosine_with_restarts", "polynomial", "linear", "constant_with_warmup"])
76 |                 lr_warmup_steps = st.number_input(label="Learning Rate Warmup Steps: ", value=50, step=1, help="Number of steps for the warmup in the lr scheduler.")
77 |                 mixed_precision = st.selectbox(label="Whether to use mixed precision", options=["no", "fp16", "bf16"], help="Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.")
78 |
79 |
80 |             submit = st.form_submit_button("Start Training")
81 |
82 | if submit:
83 |                 with st.spinner("Training..."):
84 |
85 | output_dir = "{}/trainner_assets/stable_diffusion_weights/{}".format(utils.base_path(), output_dir_name)
86 |
87 | learning_rate = 1 * 10**(-learning_rate)
88 |
89 | dump_concept_file(instance_prompt=instance_prompt, class_prompt=class_prompt, instance_data_dir=instance_data_dir_name, class_data_dir=class_data_dir_name)
90 | train_dreambooth.main(pretrained_model_name_or_path=base_model, revision=revision, output_dir=output_dir, with_prior_preservation=prior_loss_weight, prior_loss_weight=prior_loss_weight, seed=seed, resolution=resolution, train_batch_size=train_batch_size, train_text_encoder=train_text_encoder, mixed_precision=mixed_precision, use_8bit_adam=use_8bit_adam, gradient_accumulation_steps=gradient_accumulation_steps, learning_rate=learning_rate, lr_scheduler=lr_scheduler, lr_warmup_steps=lr_warmup_steps, num_class_images=num_class_images, sample_batch_size=sample_batch_size, max_train_steps=max_train_steps, save_interval=save_interval, save_sample_prompt=None, log_interval=log_interval, tokenizer_name=tokenizer_name)
91 |
92 |
93 |
94 | if __name__ == "__main__":
95 | app()
96 |
--------------------------------------------------------------------------------
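For reference, dump_concept_file() in the page above writes trainner_assets/concepts_list.json with roughly the following shape (a sketch: the values shown are just the page defaults, and "<base_path>" stands in for whatever utils.base_path() returns at runtime):

# Illustrative shape of trainner_assets/concepts_list.json, written here as a Python literal.
concepts_list = [
    {
        "instance_prompt": "photo of elon musk person",
        "class_prompt": "photo of a person",
        "instance_data_dir": "<base_path>/trainner_assets/instance_data/elon musk",
        "class_data_dir": "<base_path>/trainner_assets/class_data/person",
    }
]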
/stablefusion/scripts/dreambooth/requirements_flax.txt:
--------------------------------------------------------------------------------
1 | transformers>=4.25.1
2 | flax
3 | optax
4 | torch
5 | torchvision
6 | ftfy
7 | tensorboard
8 | Jinja2
9 |
--------------------------------------------------------------------------------
/stablefusion/scripts/gfp_gan.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import os
3 | import shutil
4 | from dataclasses import dataclass
5 | from io import BytesIO
6 | from typing import Optional
7 | import datetime
8 | import cv2
9 | import numpy as np
10 | import streamlit as st
11 | import torch
12 | from basicsr.archs.srvgg_arch import SRVGGNetCompact
13 | from gfpgan.utils import GFPGANer
14 | from loguru import logger
15 | from PIL import Image
16 | from realesrgan.utils import RealESRGANer
17 |
18 | from stablefusion import utils
19 |
20 |
21 | @dataclass
22 | class GFPGAN:
23 | device: Optional[str] = None
24 | output_path: Optional[str] = None
25 |
26 | def __str__(self) -> str:
27 | return f"GFPGAN(device={self.device}, output_path={self.output_path})"
28 |
29 | def __post_init__(self):
30 | files = {
31 | "realesr-general-x4v3.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
32 | "v1.2": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth",
33 | "v1.3": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
34 | "v1.4": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
35 | "RestoreFormer": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth",
36 | "CodeFormer": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth",
37 | }
38 | _ = utils.cache_folder()
39 | self.model_paths = {}
40 | for file_key, file in files.items():
41 | logger.info(f"Downloading {file_key} from {file}")
42 | basename = os.path.basename(file)
43 | output_path = os.path.join(utils.cache_folder(), basename)
44 | if os.path.exists(output_path):
45 | self.model_paths[file_key] = output_path
46 | continue
47 | temp_file = utils.download_file(file)
48 | shutil.move(temp_file, output_path)
49 | self.model_paths[file_key] = output_path
50 |
51 | self.model = SRVGGNetCompact(
52 | num_in_ch=3,
53 | num_out_ch=3,
54 | num_feat=64,
55 | num_conv=32,
56 | upscale=4,
57 | act_type="prelu",
58 | )
59 | model_path = os.path.join(utils.cache_folder(), self.model_paths["realesr-general-x4v3.pth"])
60 | half = True if torch.cuda.is_available() else False
61 | self.upsampler = RealESRGANer(
62 | scale=4,
63 | model_path=model_path,
64 | model=self.model,
65 | tile=0,
66 | tile_pad=10,
67 | pre_pad=0,
68 | half=half,
69 | )
70 |
71 | def inference(self, img, version, scale):
72 | # taken from: https://huggingface.co/spaces/Xintao/GFPGAN/blob/main/app.py
73 | # weight /= 100
74 | if scale > 4:
75 | scale = 4 # avoid too large scale value
76 |
77 | file_bytes = np.asarray(bytearray(img.read()), dtype=np.uint8)
78 | img = cv2.imdecode(file_bytes, 1)
79 | # img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
80 | if len(img.shape) == 2: # for gray inputs
81 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
82 |
83 | h, w = img.shape[0:2]
84 | if h < 300:
85 | img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
86 |
87 | if version == "v1.2":
88 | face_enhancer = GFPGANer(
89 | model_path=self.model_paths["v1.2"],
90 | upscale=2,
91 | arch="clean",
92 | channel_multiplier=2,
93 | bg_upsampler=self.upsampler,
94 | )
95 | elif version == "v1.3":
96 | face_enhancer = GFPGANer(
97 | model_path=self.model_paths["v1.3"],
98 | upscale=2,
99 | arch="clean",
100 | channel_multiplier=2,
101 | bg_upsampler=self.upsampler,
102 | )
103 | elif version == "v1.4":
104 | face_enhancer = GFPGANer(
105 | model_path=self.model_paths["v1.4"],
106 | upscale=2,
107 | arch="clean",
108 | channel_multiplier=2,
109 | bg_upsampler=self.upsampler,
110 | )
111 | elif version == "RestoreFormer":
112 | face_enhancer = GFPGANer(
113 | model_path=self.model_paths["RestoreFormer"],
114 | upscale=2,
115 | arch="RestoreFormer",
116 | channel_multiplier=2,
117 | bg_upsampler=self.upsampler,
118 | )
119 | # elif version == 'CodeFormer':
120 | # face_enhancer = GFPGANer(
121 | # model_path='CodeFormer.pth', upscale=2, arch='CodeFormer', channel_multiplier=2, bg_upsampler=upsampler)
122 | try:
123 | # _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True, weight=weight)
124 | _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
125 | except RuntimeError as error:
126 |             logger.error(f"Error: {error}")
127 |
128 | try:
129 | if scale != 2:
130 | interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
131 | h, w = img.shape[0:2]
132 | output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
133 | except Exception as error:
134 |             logger.error(f"Wrong scale input: {error}")
135 |
136 | output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
137 | return output
138 |
139 | def app(self):
140 | input_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
141 | if input_image is not None:
142 | st.image(input_image, width=512)
143 | with st.form(key="gfpgan"):
144 | version = st.selectbox("GFPGAN version", ["v1.2", "v1.3", "v1.4", "RestoreFormer"])
145 | scale = st.slider("Scale", 2, 4, 4, 1)
146 | submit = st.form_submit_button("Upscale")
147 | if submit:
148 | if input_image is not None:
149 | with st.spinner("Upscaling image..."):
150 | output_img = self.inference(input_image, version, scale)
151 | st.image(output_img, width=512)
152 | # add image download button
153 | output_img = Image.fromarray(output_img)
154 | buffered = BytesIO()
155 | output_img.save(buffered, format="PNG")
156 | img_str = base64.b64encode(buffered.getvalue()).decode()
157 | now = datetime.datetime.now()
158 | formatted_date_time = now.strftime("%Y-%m-%d_%H_%M_%S")
159 |                     href = f'<a href="data:file/png;base64,{img_str}" download="{formatted_date_time}.png">Download Image</a>'
160 | st.markdown(href, unsafe_allow_html=True)
161 |
--------------------------------------------------------------------------------
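Because GFPGAN above is a plain dataclass, it can also be driven outside Streamlit. A minimal sketch (the file names, device and output path are illustrative assumptions): inference() only needs an object exposing .read() that returns image bytes, a GFPGAN version string, and an output scale, and it returns an RGB numpy array.

# Minimal usage sketch; "face.jpg", device and output_path are assumptions, not repo defaults.
from PIL import Image

from stablefusion.scripts.gfp_gan import GFPGAN

gfp = GFPGAN(device="cuda", output_path="outputs")  # __post_init__ downloads and caches the weights
with open("face.jpg", "rb") as fh:  # any object with .read() returning image bytes works
    restored = gfp.inference(fh, version="v1.4", scale=2)  # RGB numpy array
Image.fromarray(restored).save("face_restored.png")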
/stablefusion/scripts/image_info.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import streamlit as st
4 | from PIL import Image
5 |
6 |
7 | @dataclass
8 | class ImageInfo:
9 | def app(self):
10 | # upload image
11 | uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
12 | if uploaded_file is not None:
13 | # read image using pil
14 | pil_image = Image.open(uploaded_file)
15 | st.image(uploaded_file, use_column_width=True)
16 | image_info = pil_image.info
17 | # display image info
18 | st.write(image_info)
19 |
--------------------------------------------------------------------------------
/stablefusion/scripts/img2img.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import json
3 | from dataclasses import dataclass
4 | from io import BytesIO
5 | from typing import Optional, Union
6 | import random
7 | import requests
8 | import streamlit as st
9 | import torch
10 | from diffusers import (
11 | AltDiffusionImg2ImgPipeline,
12 | AltDiffusionPipeline,
13 | DiffusionPipeline,
14 | StableDiffusionImg2ImgPipeline,
15 | StableDiffusionPipeline,
16 | )
17 | from loguru import logger
18 | from PIL import Image
19 | from PIL.PngImagePlugin import PngInfo
20 |
21 | from stablefusion import utils
22 |
23 |
24 | @dataclass
25 | class Img2Img:
26 | model: Optional[str] = None
27 | device: Optional[str] = None
28 | output_path: Optional[str] = None
29 | text2img_model: Optional[Union[StableDiffusionPipeline, AltDiffusionPipeline]] = None
30 |
31 | def __str__(self) -> str:
32 | return f"Img2Img(model={self.model}, device={self.device}, output_path={self.output_path})"
33 |
34 | def __post_init__(self):
35 | if self.model is not None:
36 | self.text2img_model = DiffusionPipeline.from_pretrained(
37 | self.model,
38 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
39 | )
40 | components = self.text2img_model.components
41 | print(components)
42 | if isinstance(self.text2img_model, StableDiffusionPipeline):
43 | self.pipeline = StableDiffusionImg2ImgPipeline(**components)
44 | elif isinstance(self.text2img_model, AltDiffusionPipeline):
45 | self.pipeline = AltDiffusionImg2ImgPipeline(**components)
46 | else:
47 | raise ValueError("Model type not supported")
48 |
49 | self.pipeline.to(self.device)
50 | self.pipeline.safety_checker = utils.no_safety_checker
51 | self._compatible_schedulers = self.pipeline.scheduler.compatibles
52 | self.scheduler_config = self.pipeline.scheduler.config
53 | self.compatible_schedulers = {scheduler.__name__: scheduler for scheduler in self._compatible_schedulers}
54 |
55 | if self.device == "mps":
56 | self.pipeline.enable_attention_slicing()
57 | # warmup
58 | url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
59 | response = requests.get(url)
60 | init_image = Image.open(BytesIO(response.content)).convert("RGB")
61 | init_image.thumbnail((768, 768))
62 | prompt = "A fantasy landscape, trending on artstation"
63 | _ = self.pipeline(
64 | prompt=prompt,
65 | image=init_image,
66 | strength=0.75,
67 | guidance_scale=7.5,
68 | num_inference_steps=2,
69 | )
70 |
71 | def _set_scheduler(self, scheduler_name):
72 | scheduler = self.compatible_schedulers[scheduler_name].from_config(self.scheduler_config)
73 | self.pipeline.scheduler = scheduler
74 |
75 | def generate_image(
76 | self, prompt, image, strength, negative_prompt, scheduler, num_images, guidance_scale, steps, seed
77 | ):
78 | self._set_scheduler(scheduler)
79 | logger.info(self.pipeline.scheduler)
80 | if self.device == "mps":
81 | generator = torch.manual_seed(seed)
82 | num_images = 1
83 | else:
84 | generator = torch.Generator(device=self.device).manual_seed(seed)
85 | num_images = int(num_images)
86 | output_images = self.pipeline(
87 | prompt=prompt,
88 | image=image,
89 | strength=strength,
90 | negative_prompt=negative_prompt,
91 | num_inference_steps=steps,
92 | guidance_scale=guidance_scale,
93 | num_images_per_prompt=num_images,
94 | generator=generator,
95 | ).images
96 | torch.cuda.empty_cache()
97 | gc.collect()
98 | metadata = {
99 | "prompt": prompt,
100 | "negative_prompt": negative_prompt,
101 | "scheduler": scheduler,
102 | "num_images": num_images,
103 | "guidance_scale": guidance_scale,
104 | "steps": steps,
105 | "seed": seed,
106 | }
107 | metadata = json.dumps(metadata)
108 | _metadata = PngInfo()
109 | _metadata.add_text("img2img", metadata)
110 |
111 | utils.save_images(
112 | images=output_images,
113 | module="img2img",
114 | metadata=metadata,
115 | output_path=self.output_path,
116 | )
117 | return output_images, _metadata
118 |
119 | def app(self):
120 | available_schedulers = list(self.compatible_schedulers.keys())
121 | # if EulerAncestralDiscreteScheduler is available in available_schedulers, move it to the first position
122 | if "EulerAncestralDiscreteScheduler" in available_schedulers:
123 | available_schedulers.insert(
124 | 0, available_schedulers.pop(available_schedulers.index("EulerAncestralDiscreteScheduler"))
125 | )
126 |
127 | input_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
128 | if input_image is not None:
129 | input_image = Image.open(input_image)
130 | st.image(input_image, use_column_width=True)
131 |
132 | # with st.form(key="img2img"):
133 | col1, col2 = st.columns(2)
134 | with col1:
135 | prompt = st.text_area("Prompt", "", help="Prompt to guide image generation")
136 | with col2:
137 |             negative_prompt = st.text_area("Negative Prompt", "", help="The prompt not to guide image generation. Write things that you don't want to see in the image.")
138 |
139 | scheduler = st.sidebar.selectbox("Scheduler", available_schedulers, index=0, help="Scheduler(Sampler) to use for generation")
140 | guidance_scale = st.sidebar.slider("Guidance scale", 1.0, 40.0, 7.5, 0.5, help="Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.")
141 | strength = st.sidebar.slider("Denoise Strength", 0.0, 1.0, 0.5, 0.05)
142 |         num_images = st.sidebar.slider("Number of images per prompt", 1, 30, 1, 1, help="Number of images you want to generate. More images require more time and use more GPU memory.")
143 | steps = st.sidebar.slider("Steps", 1, 150, 50, 1, help="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.")
144 | seed_choice = st.sidebar.selectbox("Do you want a random seed", options=["Yes", "No"])
145 | if seed_choice == "Yes":
146 | seed = random.randint(0, 9999999)
147 | else:
148 | seed = st.sidebar.number_input(
149 | "Seed",
150 | value=42,
151 | step=1,
152 | help="Random seed. Change for different results using same parameters.",
153 | )
154 | # seed = st.sidebar.number_input("Seed", 1, 999999, 1, 1)
155 | sub_col, download_col = st.columns(2)
156 | with sub_col:
157 | submit = st.button("Generate")
158 |
159 | if submit:
160 | with st.spinner("Generating images..."):
161 | output_images, metadata = self.generate_image(
162 | prompt=prompt,
163 | image=input_image,
164 | negative_prompt=negative_prompt,
165 | scheduler=scheduler,
166 | num_images=num_images,
167 | guidance_scale=guidance_scale,
168 | steps=steps,
169 | seed=seed,
170 | strength=strength,
171 | )
172 |
173 | utils.display_and_download_images(output_images, metadata, download_col)
174 |
--------------------------------------------------------------------------------
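Img2Img above can likewise be used programmatically. Below is a minimal sketch based on the generate_image() signature shown in the file; the model id, scheduler name, file paths and parameter values are assumptions, not repo defaults, and the scheduler string must be one of the keys in compatible_schedulers.

# Minimal usage sketch (illustrative values only).
from PIL import Image

from stablefusion.scripts.img2img import Img2Img

i2i = Img2Img(model="runwayml/stable-diffusion-v1-5", device="cuda", output_path="outputs")
init_image = Image.open("sketch.png").convert("RGB")
images, png_metadata = i2i.generate_image(
    prompt="a fantasy landscape, trending on artstation",
    image=init_image,
    strength=0.75,
    negative_prompt="",
    scheduler="EulerAncestralDiscreteScheduler",  # must be a key of i2i.compatible_schedulers
    num_images=1,
    guidance_scale=7.5,
    steps=30,
    seed=42,
)
images[0].save("img2img_result.png")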
/stablefusion/scripts/interrogator.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass
3 | from typing import Optional
4 |
5 | import streamlit as st
6 | from stablefusion.scripts.clip_interrogator import Config, Interrogator
7 | from huggingface_hub import hf_hub_download
8 | from loguru import logger
9 | from PIL import Image
10 |
11 | from stablefusion import utils
12 |
13 |
14 | @dataclass
15 | class ImageInterrogator:
16 | device: Optional[str] = None
17 | output_path: Optional[str] = None
18 |
19 | def inference(self, model, image, mode):
20 | preprocess_files = [
21 | "ViT-H-14_laion2b_s32b_b79k_artists.pkl",
22 | "ViT-H-14_laion2b_s32b_b79k_flavors.pkl",
23 | "ViT-H-14_laion2b_s32b_b79k_mediums.pkl",
24 | "ViT-H-14_laion2b_s32b_b79k_movements.pkl",
25 | "ViT-H-14_laion2b_s32b_b79k_trendings.pkl",
26 | "ViT-L-14_openai_artists.pkl",
27 | "ViT-L-14_openai_flavors.pkl",
28 | "ViT-L-14_openai_mediums.pkl",
29 | "ViT-L-14_openai_movements.pkl",
30 | "ViT-L-14_openai_trendings.pkl",
31 | ]
32 |
33 | logger.info("Downloading preprocessed cache files...")
34 | for file in preprocess_files:
35 | path = hf_hub_download(repo_id="pharma/ci-preprocess", filename=file, cache_dir=utils.cache_folder())
36 | cache_path = os.path.dirname(path)
37 |
38 | config = Config(cache_path=cache_path, clip_model_path=utils.cache_folder(), clip_model_name=model)
39 | pipeline = Interrogator(config)
40 |
41 | pipeline.config.blip_num_beams = 64
42 | pipeline.config.chunk_size = 2048
43 | pipeline.config.flavor_intermediate_count = 2048 if model == "ViT-L-14/openai" else 1024
44 |
45 | if mode == "best":
46 | prompt = pipeline.interrogate(image)
47 | elif mode == "classic":
48 | prompt = pipeline.interrogate_classic(image)
49 | else:
50 | prompt = pipeline.interrogate_fast(image)
51 | return prompt
52 |
53 | def app(self):
54 | # upload image
55 | input_image = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
56 | with st.form(key="image_interrogator"):
57 | clip_model = st.selectbox("CLIP Model", ["ViT-L (Best for SD1.X)", "ViT-H (Best for SD2.X)"])
58 | mode = st.selectbox("Mode", ["Best", "Classic"])
59 | submit = st.form_submit_button("Interrogate")
60 | if input_image is not None:
61 | # read image using pil
62 | pil_image = Image.open(input_image).convert("RGB")
63 | if submit:
64 | with st.spinner("Interrogating image..."):
65 | if clip_model == "ViT-L (Best for SD1.X)":
66 | model = "ViT-L-14/openai"
67 | else:
68 | model = "ViT-H-14/laion2b_s32b_b79k"
69 | prompt = self.inference(model, pil_image, mode.lower())
70 | col1, col2 = st.columns(2)
71 | with col1:
72 | st.image(input_image, use_column_width=True)
73 | with col2:
74 | st.write(prompt)
75 |
--------------------------------------------------------------------------------
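ImageInterrogator above can also be called directly. A minimal sketch (the image path and device are assumptions): mode is "best", "classic", or anything else for the fast path, matching the branches in inference().

# Minimal usage sketch (illustrative values only).
from PIL import Image

from stablefusion.scripts.interrogator import ImageInterrogator

interrogator = ImageInterrogator(device="cuda", output_path="outputs")
image = Image.open("photo.png").convert("RGB")
prompt = interrogator.inference(model="ViT-L-14/openai", image=image, mode="best")
print(prompt)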
/stablefusion/scripts/model_adding.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | import streamlit as st
3 | import ast
4 | import os
5 | from stablefusion.utils import base_path
6 |
7 | base_path = base_path()
8 |
9 | @dataclass
10 | class ModelAdding:
11 |
12 | def read_model_list(self):
13 |
14 | try:
15 | with open('{}/model_list.txt'.format(base_path), 'r') as f:
16 | contents = f.read()
17 | except:
18 | with open('stablefusion/model_list.txt', 'r') as f:
19 | contents = f.read()
20 | model_list = ast.literal_eval(contents)
21 |
22 | return model_list
23 |
24 | def write_model_list(self, model_list):
25 |
26 | try:
27 | with open('{}/model_list.txt'.format(base_path), 'w') as f:
28 | f.write(model_list)
29 | except:
30 | with open('stablefusion/model_list.txt', 'w') as f:
31 | f.write(model_list)
32 |
33 | def check_models(self, model_name):
34 |
35 | model_list = self.read_model_list()
36 |
37 | if model_name in model_list:
38 | st.warning("{} already present in the list".format(model_name))
39 |
40 | else:
41 | model_list.append(model_name)
42 | self.write_model_list(model_list=str(model_list))
43 |             st.success("Successfully added {} to your list".format(model_name))
44 |
45 |
46 | def app(self):
47 | # upload image
48 | model_name = st.text_input(label="Enter The Model Name", value="runwayml/stable-diffusion-v1-5")
49 |
50 | if st.button("Apply Changes"):
51 | self.check_models(model_name=model_name)
52 |
53 |
--------------------------------------------------------------------------------
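ModelAdding above and ModelRemoving below both parse model_list.txt with ast.literal_eval, so the file is expected to contain a single Python list literal of model identifiers. An illustrative example (these entries are placeholders, not the shipped list):

# Illustrative contents of stablefusion/model_list.txt
["runwayml/stable-diffusion-v1-5", "stabilityai/stable-diffusion-2-1"]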
/stablefusion/scripts/model_removing.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | import streamlit as st
3 | import ast
4 | import os
5 | from stablefusion.utils import base_path
6 |
7 |
8 | base_path = base_path()
9 |
10 | @dataclass
11 | class ModelRemoving:
12 |
13 | def read_model_list(self):
14 |
15 | try:
16 | with open('{}/model_list.txt'.format(base_path), 'r') as f:
17 | contents = f.read()
18 | except:
19 | with open('stablefusion/model_list.txt', 'r') as f:
20 | contents = f.read()
21 | model_list = ast.literal_eval(contents)
22 |
23 | return model_list
24 |
25 | def write_model_list(self, model_list):
26 |
27 | try:
28 | with open('{}/model_list.txt'.format(base_path), 'w') as f:
29 | f.write(model_list)
30 | except:
31 | with open('stablefusion/model_list.txt', 'w') as f:
32 | f.write(model_list)
33 |
34 | def check_models(self, model_name):
35 |
36 | model_list = self.read_model_list()
37 |
38 | if model_name not in model_list:
39 | st.warning("{} not present in the list".format(model_name))
40 |
41 | else:
42 | model_list.remove(model_name)
43 | self.write_model_list(model_list=str(model_list))
44 |             st.success("Successfully removed {} from your list".format(model_name))
45 |
46 |
47 | def app(self):
48 | # upload image
49 | model_name = st.selectbox(label="Enter The Model Name", options=self.read_model_list())
50 |
51 | if st.button("Apply Changes"):
52 | self.check_models(model_name=model_name)
53 |
54 |
--------------------------------------------------------------------------------
/stablefusion/scripts/pose_editor.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import json
3 | from dataclasses import dataclass
4 | from io import BytesIO
5 | from typing import Optional, Union
6 | import random
7 | import requests
8 | import streamlit as st
9 | import torch
10 | from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
11 | from loguru import logger
12 | from PIL import Image
13 | from PIL.PngImagePlugin import PngInfo
14 | from controlnet_aux import OpenposeDetector, HEDdetector, MLSDdetector
15 | from stablefusion import utils
16 | import cv2
17 | from PIL import Image
18 | import numpy as np
19 | from streamlit_drawable_canvas import st_canvas
20 | from transformers import pipeline
21 | from stablefusion.scripts.pose_html import html_component
22 |
23 |
24 | @dataclass
25 | class OpenPoseEditor:
26 | model: Optional[str] = None
27 | device: Optional[str] = None
28 | output_path: Optional[str] = None
29 |
30 | def __str__(self) -> str:
31 | return f"BaseModel(model={self.model}, device={self.device}, output_path={self.output_path})"
32 |
33 |
34 | def __post_init__(self):
35 | self.controlnet = ControlNetModel.from_pretrained(
36 | "lllyasviel/sd-controlnet-openpose",
37 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
38 | use_auth_token=utils.use_auth_token(),
39 | )
40 | self.pipeline = StableDiffusionControlNetPipeline.from_pretrained(
41 | self.model,
42 | controlnet=self.controlnet,
43 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
44 | use_auth_token=utils.use_auth_token(),
45 | )
46 |
47 | self.pipeline.to(self.device)
48 | self.pipeline.safety_checker = utils.no_safety_checker
49 | self._compatible_schedulers = self.pipeline.scheduler.compatibles
50 | self.scheduler_config = self.pipeline.scheduler.config
51 | self.compatible_schedulers = {scheduler.__name__: scheduler for scheduler in self._compatible_schedulers}
52 |
53 | if self.device == "mps":
54 | self.pipeline.enable_attention_slicing()
55 | # warmup
56 | url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
57 | response = requests.get(url)
58 | init_image = Image.open(BytesIO(response.content)).convert("RGB")
59 | init_image.thumbnail((768, 768))
60 | prompt = "A fantasy landscape, trending on artstation"
61 | _ = self.pipeline(
62 | prompt=prompt,
63 | image=init_image,
64 | strength=0.75,
65 | guidance_scale=7.5,
66 | num_inference_steps=2,
67 | )
68 |
69 | def _set_scheduler(self, scheduler_name):
70 | scheduler = self.compatible_schedulers[scheduler_name].from_config(self.scheduler_config)
71 | self.pipeline.scheduler = scheduler
72 |
73 | def generate_image(
74 | self, prompt, image, negative_prompt, scheduler, num_images, guidance_scale, steps, seed, height, width
75 | ):
76 | self._set_scheduler(scheduler)
77 | logger.info(self.pipeline.scheduler)
78 | if self.device == "mps":
79 | generator = torch.manual_seed(seed)
80 | num_images = 1
81 | else:
82 | generator = torch.Generator(device=self.device).manual_seed(seed)
83 | num_images = int(num_images)
84 | output_images = self.pipeline(
85 | prompt=prompt,
86 | image=image,
87 | negative_prompt=negative_prompt,
88 | num_inference_steps=steps,
89 | guidance_scale=guidance_scale,
90 | num_images_per_prompt=num_images,
91 | generator=generator,
92 | height=height,
93 | width=width
94 | ).images
95 | torch.cuda.empty_cache()
96 | gc.collect()
97 | metadata = {
98 | "prompt": prompt,
99 | "negative_prompt": negative_prompt,
100 | "scheduler": scheduler,
101 | "num_images": num_images,
102 | "guidance_scale": guidance_scale,
103 | "steps": steps,
104 | "seed": seed,
105 | }
106 | metadata = json.dumps(metadata)
107 | _metadata = PngInfo()
108 | _metadata.add_text("img2img", metadata)
109 |
110 | utils.save_images(
111 | images=output_images,
112 | module="controlnet",
113 | metadata=metadata,
114 | output_path=self.output_path,
115 | )
116 | return output_images, _metadata
117 |
118 | def app(self):
119 | available_schedulers = list(self.compatible_schedulers.keys())
120 | # if EulerAncestralDiscreteScheduler is available in available_schedulers, move it to the first position
121 | if "EulerAncestralDiscreteScheduler" in available_schedulers:
122 | available_schedulers.insert(
123 | 0, available_schedulers.pop(available_schedulers.index("EulerAncestralDiscreteScheduler"))
124 | )
125 |
126 | html_component()
127 |
128 | input_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
129 |
130 | if input_image is not None:
131 |
132 | input_image = Image.open(input_image).convert("RGB")
133 |
134 | st.image(input_image, use_column_width=True)
135 |
136 |
137 | # with st.form(key="img2img"):
138 | col1, col2 = st.columns(2)
139 | with col1:
140 | prompt = st.text_area("Prompt", "", help="Prompt to guide image generation")
141 | with col2:
142 |             negative_prompt = st.text_area("Negative Prompt", "", help="The prompt not to guide image generation. Write things that you don't want to see in the image.")
143 |
144 | scheduler = st.sidebar.selectbox("Scheduler", available_schedulers, index=0, help="Scheduler(Sampler) to use for generation")
145 | image_height = st.sidebar.slider("Image height", 128, 1024, 512, 128, help="The height in pixels of the generated image.")
146 | image_width = st.sidebar.slider("Image width", 128, 1024, 512, 128, help="The width in pixels of the generated image.")
147 |         guidance_scale = st.sidebar.slider("Guidance scale", 1.0, 40.0, 7.5, 0.5, help="Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.")
148 |         num_images = st.sidebar.slider("Number of images per prompt", 1, 30, 1, 1, help="Number of images you want to generate. More images require more time and use more GPU memory.")
149 | steps = st.sidebar.slider("Steps", 1, 150, 50, 1, help="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.")
150 |         seed_choice = st.sidebar.selectbox("Do you want a random seed?", options=["Yes", "No"])
151 | if seed_choice == "Yes":
152 | seed = random.randint(0, 9999999)
153 | else:
154 | seed = st.sidebar.number_input(
155 | "Seed",
156 | value=42,
157 | step=1,
158 | help="Random seed. Change for different results using same parameters.",
159 | )
160 | sub_col, download_col = st.columns(2)
161 | with sub_col:
162 | submit = st.button("Generate")
163 |
164 | if submit:
165 | with st.spinner("Generating images..."):
166 | output_images, metadata = self.generate_image(
167 | prompt=prompt,
168 | image=input_image,
169 | negative_prompt=negative_prompt,
170 | scheduler=scheduler,
171 | num_images=num_images,
172 | guidance_scale=guidance_scale,
173 | steps=steps,
174 | seed=seed,
175 | height=image_height,
176 | width=image_width
177 | )
178 |
179 | utils.display_and_download_images(output_images, metadata, download_col)
180 |
--------------------------------------------------------------------------------
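A minimal sketch of driving the OpenPoseEditor class above outside its Streamlit page. The import path, the checkpoint id, and the pose image file are assumptions; generate_image is called with the same keyword arguments that app() collects from the sidebar.

# Hypothetical driver for OpenPoseEditor; import path and model id are assumptions.
from PIL import Image

from stablefusion.scripts.pose_editor import OpenPoseEditor  # assumed module location

editor = OpenPoseEditor(
    model="runwayml/stable-diffusion-v1-5",  # any diffusers-format SD 1.x checkpoint (assumption)
    device="cuda",
    output_path=None,  # images are returned but not saved to disk or the Hub
)
pose = Image.open("pose.png").convert("RGB")  # a pre-rendered OpenPose skeleton image (placeholder path)
images, png_metadata = editor.generate_image(
    prompt="a dancer on a stage, studio lighting",
    image=pose,
    negative_prompt="",
    scheduler="EulerAncestralDiscreteScheduler",
    num_images=1,
    guidance_scale=7.5,
    steps=30,
    seed=42,
    height=512,
    width=512,
)
images[0].save("dancer.png", pnginfo=png_metadata)  # the PNG keeps the generation metadata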
/stablefusion/scripts/safetensors_to_diffusion.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Conversion script for the LDM checkpoints. """
16 | import ast
17 | import argparse
18 | import os
19 | import streamlit as st
20 | from diffusers.pipelines.stable_diffusion.convert_from_ckpt import load_pipeline_from_original_stable_diffusion_ckpt
21 | import requests
22 | from stablefusion.utils import base_path
23 |
24 | current_path = base_path()
25 |
26 | def convert_to_diffusers_model(checkpoint_path, original_config_file, image_size, prediction_type, pipeline_type, extract_ema, scheduler_type, num_in_channels, upcast_attention, from_safetensors, device, stable_unclip, stable_unclip_prior, clip_stats_path, controlnet, checkpoint_name, to_safetensors):
27 |
28 | pipe = load_pipeline_from_original_stable_diffusion_ckpt(
29 | checkpoint_path=checkpoint_path,
30 | original_config_file=original_config_file,
31 | image_size=image_size,
32 | prediction_type=prediction_type,
33 | model_type=pipeline_type,
34 | extract_ema=extract_ema,
35 | scheduler_type=scheduler_type,
36 | num_in_channels=num_in_channels,
37 | upcast_attention=upcast_attention,
38 | from_safetensors=from_safetensors,
39 | device=device,
40 | stable_unclip=stable_unclip,
41 | stable_unclip_prior=stable_unclip_prior,
42 | clip_stats_path=clip_stats_path,
43 | controlnet=controlnet,
44 | )
45 |
46 | dump_path = "{}/models/diffusion_models/{}".format(current_path, checkpoint_name.split(".")[0])
47 |
48 | if controlnet:
49 | # only save the controlnet model
50 | pipe.controlnet.save_pretrained(dump_path, safe_serialization=to_safetensors)
51 | else:
52 | pipe.save_pretrained(dump_path, safe_serialization=to_safetensors)
53 |
54 | st.success("Your Model Has Been Created!")
55 | append_created_model(model_name=str(checkpoint_name).split(".")[0])
56 |
57 |
58 | def download_ckpt_model(checkpoint_link, checkpoint_name):
59 |
60 | try:
61 | print("Started Downloading the model...")
62 | response = requests.get(checkpoint_link, stream=True)
63 |
64 | with open("{}/models/safetensors_models/{}".format(current_path, checkpoint_name), "wb") as f:
65 | f.write(response.content)
66 |
67 | print("Model Downloading Completed!")
68 |
69 | except:
70 | print("Started Downloading the model...")
71 | response = requests.get(checkpoint_link, stream=True)
72 |
73 | with open("stablefusion/models/safetensors_models/{}".format(checkpoint_name), "wb") as f:
74 | f.write(response.content)
75 |
76 | print("Model Downloading Completed!")
77 |
78 | def read_model_file():
79 |
80 | try:
81 | with open('{}/model_list.txt'.format(current_path), 'r') as f:
82 | contents = f.read()
83 | except:
84 | with open('stablefusion/model_list.txt', 'r') as f:
85 | contents = f.read()
86 |
87 | model_list = ast.literal_eval(contents)
88 |
89 | return model_list
90 |
91 |
92 | def write_model_list(model_list):
93 |
94 | try:
95 | with open('{}/model_list.txt'.format(current_path), 'w') as f:
96 | f.write(model_list)
97 | except:
98 | with open('stablefusion/model_list.txt', 'w') as f:
99 | f.write(model_list)
100 |
101 |
102 | def append_created_model(model_name):
103 | model_list = read_model_file()
104 |     try:
105 |         appending_list = os.listdir("{}/models/diffusion_models".format(current_path))
106 |     except OSError:
107 |         appending_list = os.listdir("stablefusion/models/diffusion_models")
108 | 
109 |     for working_item in appending_list:
110 |         # placeholder .txt files are skipped; only converted model folders count
111 |         if str(working_item).split(".")[-1] == "txt":
112 |             continue
113 |         # register the freshly converted model by its on-disk path, once
114 |         model_path = "{}/models/diffusion_models/{}".format(current_path, model_name)
115 |         if model_path not in model_list:
116 |             model_list.append(model_path)
117 | 
118 |     # persist the updated list so the Home page picks up the new model
119 |     write_model_list(model_list=str(model_list))
120 |     st.success("Model added to your list. You can now use it from your Home Page.")
121 |
122 |
123 | def convert_safetensor_to_diffusers(original_config_file, image_size, prediction_type, pipeline_type, extract_ema, scheduler_type, num_in_channels, upcast_attention, from_safetensors, device, stable_unclip, stable_unclip_prior, clip_stats_path, controlnet, to_safetensors, checkpoint_name, checkpoint_link, overwrite):
124 |
125 | checkpoint_path = "{}/models/safetensors_models/{}".format(current_path, checkpoint_name)
126 |
127 | try:
128 | custom_diffusion_model = "{}/models/diffusion_models/{}".format(current_path, str(checkpoint_name).split(".")[0])
129 | except:
130 | custom_diffusion_model = "stablefusion/models/diffusion_models/{}".format(str(checkpoint_name).split(".")[0])
131 |
132 | if overwrite is False:
133 | if not os.path.isfile(checkpoint_path):
134 |
135 | download_ckpt_model(checkpoint_link=checkpoint_link, checkpoint_name=checkpoint_name)
136 |
137 | else:
138 | st.warning("Using {}".format(checkpoint_path))
139 | else:
140 |
141 | download_ckpt_model(checkpoint_link=checkpoint_link, checkpoint_name=checkpoint_name)
142 |
143 |
144 | if overwrite is False:
145 |
146 | if not os.path.exists(custom_diffusion_model):
147 |
148 | convert_to_diffusers_model(checkpoint_name=checkpoint_name, original_config_file=original_config_file, image_size=image_size, prediction_type=prediction_type, extract_ema=extract_ema, scheduler_type=scheduler_type, num_in_channels=num_in_channels, upcast_attention=upcast_attention, from_safetensors=from_safetensors, device=device, stable_unclip=stable_unclip, stable_unclip_prior=stable_unclip_prior, clip_stats_path=clip_stats_path, controlnet=controlnet, to_safetensors=to_safetensors, checkpoint_path=checkpoint_path, pipeline_type=pipeline_type)
149 |
150 | else:
151 | st.warning("Model {} is already present".format(custom_diffusion_model))
152 |
153 | else:
154 | convert_to_diffusers_model(checkpoint_name=checkpoint_name, original_config_file=original_config_file, image_size=image_size, prediction_type=prediction_type, extract_ema=extract_ema, scheduler_type=scheduler_type, num_in_channels=num_in_channels, upcast_attention=upcast_attention, from_safetensors=from_safetensors, device=device, stable_unclip=stable_unclip, stable_unclip_prior=stable_unclip_prior, clip_stats_path=clip_stats_path, controlnet=controlnet, to_safetensors=to_safetensors, checkpoint_path=checkpoint_path, pipeline_type=pipeline_type)
--------------------------------------------------------------------------------
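A hedged sketch of calling the conversion entry point defined above directly. The checkpoint URL is a placeholder and the remaining argument values are assumptions that mirror the defaults of the upstream diffusers conversion utility; the converted model is written under models/diffusion_models/ and registered in model_list.txt by append_created_model.

# Hypothetical invocation; the URL and most argument values are placeholders/assumptions.
from stablefusion.scripts.safetensors_to_diffusion import convert_safetensor_to_diffusers

convert_safetensor_to_diffusers(
    original_config_file=None,    # fall back to the default SD v1 inference config
    image_size=512,
    prediction_type="epsilon",
    pipeline_type=None,           # let the converter infer the pipeline class
    extract_ema=False,
    scheduler_type="pndm",
    num_in_channels=None,
    upcast_attention=False,
    from_safetensors=True,
    device="cpu",
    stable_unclip=None,
    stable_unclip_prior=None,
    clip_stats_path=None,
    controlnet=False,
    to_safetensors=True,
    checkpoint_name="my_model.safetensors",                       # saved under models/safetensors_models/
    checkpoint_link="https://example.com/my_model.safetensors",   # placeholder download URL
    overwrite=False,
)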
/stablefusion/scripts/text2img.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import json
3 | from dataclasses import dataclass
4 | from typing import Optional
5 | import random
6 | import streamlit as st
7 | import torch
8 | from diffusers import DiffusionPipeline
9 | from loguru import logger
10 | from PIL.PngImagePlugin import PngInfo
11 |
12 | from stablefusion import utils
13 |
14 |
15 | @dataclass
16 | class Text2Image:
17 | device: Optional[str] = None
18 | model: Optional[str] = None
19 | output_path: Optional[str] = None
20 |
21 | def __str__(self) -> str:
22 | return f"Text2Image(model={self.model}, device={self.device}, output_path={self.output_path})"
23 |
24 | def __post_init__(self):
25 | self.pipeline = DiffusionPipeline.from_pretrained(
26 | self.model,
27 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
28 | )
29 | self.pipeline.to(self.device)
30 | self.pipeline.safety_checker = utils.no_safety_checker
31 | self._compatible_schedulers = self.pipeline.scheduler.compatibles
32 | self.scheduler_config = self.pipeline.scheduler.config
33 | self.compatible_schedulers = {scheduler.__name__: scheduler for scheduler in self._compatible_schedulers}
34 |
35 | if self.device == "mps":
36 | self.pipeline.enable_attention_slicing()
37 | # warmup
38 | prompt = "a photo of an astronaut riding a horse on mars"
39 | _ = self.pipeline(prompt, num_inference_steps=2)
40 |
41 | def _set_scheduler(self, scheduler_name):
42 | scheduler = self.compatible_schedulers[scheduler_name].from_config(self.scheduler_config)
43 | self.pipeline.scheduler = scheduler
44 |
45 | def generate_image(self, prompt, negative_prompt, scheduler, image_size, num_images, guidance_scale, steps, seed):
46 | self._set_scheduler(scheduler)
47 | logger.info(self.pipeline.scheduler)
48 | if self.device == "mps":
49 | generator = torch.manual_seed(seed)
50 | num_images = 1
51 | else:
52 | generator = torch.Generator(device=self.device).manual_seed(seed)
53 | num_images = int(num_images)
54 | output_images = self.pipeline(
55 | prompt,
56 | negative_prompt=negative_prompt,
57 | width=image_size[1],
58 | height=image_size[0],
59 | num_inference_steps=steps,
60 | guidance_scale=guidance_scale,
61 | num_images_per_prompt=num_images,
62 | generator=generator,
63 | ).images
64 | torch.cuda.empty_cache()
65 | gc.collect()
66 | metadata = {
67 | "prompt": prompt,
68 | "negative_prompt": negative_prompt,
69 | "scheduler": scheduler,
70 | "image_size": image_size,
71 | "num_images": num_images,
72 | "guidance_scale": guidance_scale,
73 | "steps": steps,
74 | "seed": seed,
75 | }
76 | metadata = json.dumps(metadata)
77 | _metadata = PngInfo()
78 | _metadata.add_text("text2img", metadata)
79 |
80 | utils.save_images(
81 | images=output_images,
82 | module="text2img",
83 | metadata=metadata,
84 | output_path=self.output_path,
85 | )
86 |
87 | return output_images, _metadata
88 |
89 | def app(self):
90 | available_schedulers = list(self.compatible_schedulers.keys())
91 | if "EulerAncestralDiscreteScheduler" in available_schedulers:
92 | available_schedulers.insert(
93 | 0, available_schedulers.pop(available_schedulers.index("EulerAncestralDiscreteScheduler"))
94 | )
95 | # with st.form(key="text2img"):
96 | col1, col2 = st.columns(2)
97 | with col1:
98 | prompt = st.text_area("Prompt", "Blue elephant", help="Prompt to guide image generation")
99 | with col2:
100 |             negative_prompt = st.text_area("Negative Prompt", "", help="The prompt not to guide image generation. Write things that you don't want to see in the image.")
101 | # sidebar options
102 | scheduler = st.sidebar.selectbox("Scheduler", available_schedulers, index=0, help="Scheduler(Sampler) to use for generation")
103 | image_height = st.sidebar.slider("Image height", 128, 1024, 512, 128, help="The height in pixels of the generated image.")
104 | image_width = st.sidebar.slider("Image width", 128, 1024, 512, 128, help="The width in pixels of the generated image.")
105 |         guidance_scale = st.sidebar.slider("Guidance scale", 1.0, 40.0, 7.5, 0.5, help="Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.")
106 |         num_images = st.sidebar.slider("Number of images per prompt", 1, 30, 1, 1, help="Number of images you want to generate. More images require more time and use more GPU memory.")
107 | steps = st.sidebar.slider("Steps", 1, 150, 50, 1, help="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.")
108 |
109 |         seed_choice = st.sidebar.selectbox("Do you want a random seed?", options=["Yes", "No"])
110 | if seed_choice == "Yes":
111 | seed = random.randint(0, 9999999)
112 | else:
113 | seed = st.sidebar.number_input(
114 | "Seed",
115 | value=42,
116 | step=1,
117 | help="Random seed. Change for different results using same parameters.",
118 | )
119 |
120 | sub_col, download_col = st.columns(2)
121 | with sub_col:
122 | submit = st.button("Generate")
123 |
124 | if submit:
125 | with st.spinner("Generating images..."):
126 | output_images, metadata = self.generate_image(
127 | prompt=prompt,
128 | negative_prompt=negative_prompt,
129 | scheduler=scheduler,
130 | image_size=(image_height, image_width),
131 | num_images=num_images,
132 | guidance_scale=guidance_scale,
133 | steps=steps,
134 | seed=seed,
135 | )
136 |
137 | utils.display_and_download_images(output_images, metadata, download_col)
138 |
--------------------------------------------------------------------------------
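A minimal sketch of using the Text2Image class above without the Streamlit UI; the model id and output filename are assumptions, and image_size follows the (height, width) order used by app().

# Hypothetical driver for Text2Image; import path and model id are assumptions.
from stablefusion.scripts.text2img import Text2Image  # assumed module location

t2i = Text2Image(
    model="runwayml/stable-diffusion-v1-5",  # any diffusers text-to-image checkpoint (assumption)
    device="cuda",
    output_path=None,  # skip saving; images are only returned
)
images, png_metadata = t2i.generate_image(
    prompt="Blue elephant",
    negative_prompt="",
    scheduler="EulerAncestralDiscreteScheduler",
    image_size=(512, 512),  # (height, width)
    num_images=1,
    guidance_scale=7.5,
    steps=30,
    seed=42,
)
images[0].save("blue_elephant.png", pnginfo=png_metadata)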
/stablefusion/scripts/textual_inversion.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import json
3 | from dataclasses import dataclass
4 | from typing import Optional
5 | import random
6 | import streamlit as st
7 | import torch
8 | from diffusers import DiffusionPipeline
9 | from loguru import logger
10 | from PIL.PngImagePlugin import PngInfo
11 |
12 | from stablefusion import utils
13 |
14 |
15 | def load_embed(learned_embeds_path, text_encoder, tokenizer, token=None):
16 | loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
17 | if len(loaded_learned_embeds) > 2:
18 | embeds = loaded_learned_embeds["string_to_param"]["*"][-1, :]
19 | else:
20 | # separate token and the embeds
21 | trained_token = list(loaded_learned_embeds.keys())[0]
22 | embeds = loaded_learned_embeds[trained_token]
23 |
24 | # add the token in tokenizer
25 | token = token if token is not None else trained_token
26 | num_added_tokens = tokenizer.add_tokens(token)
27 | i = 1
28 | while num_added_tokens == 0:
29 | logger.warning(f"The tokenizer already contains the token {token}.")
30 | token = f"{token[:-1]}-{i}>"
31 | logger.info(f"Attempting to add the token {token}.")
32 | num_added_tokens = tokenizer.add_tokens(token)
33 | i += 1
34 |
35 | # resize the token embeddings
36 | text_encoder.resize_token_embeddings(len(tokenizer))
37 |
38 | # get the id for the token and assign the embeds
39 | token_id = tokenizer.convert_tokens_to_ids(token)
40 | text_encoder.get_input_embeddings().weight.data[token_id] = embeds
41 | return token
42 |
43 |
44 | @dataclass
45 | class TextualInversion:
46 | model: str
47 | embeddings_url: str
48 | token_identifier: str
49 | device: Optional[str] = None
50 | output_path: Optional[str] = None
51 |
52 | def __str__(self) -> str:
53 | return f"TextualInversion(model={self.model}, embeddings={self.embeddings_url}, token_identifier={self.token_identifier}, device={self.device}, output_path={self.output_path})"
54 |
55 | def __post_init__(self):
56 | self.pipeline = DiffusionPipeline.from_pretrained(
57 | self.model,
58 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
59 | )
60 | self.pipeline.to(self.device)
61 | self.pipeline.safety_checker = utils.no_safety_checker
62 | self._compatible_schedulers = self.pipeline.scheduler.compatibles
63 | self.scheduler_config = self.pipeline.scheduler.config
64 | self.compatible_schedulers = {scheduler.__name__: scheduler for scheduler in self._compatible_schedulers}
65 |
66 | # download the embeddings
67 | self.embeddings_path = utils.download_file(self.embeddings_url)
68 | load_embed(
69 | learned_embeds_path=self.embeddings_path,
70 | text_encoder=self.pipeline.text_encoder,
71 | tokenizer=self.pipeline.tokenizer,
72 | token=self.token_identifier,
73 | )
74 |
75 | if self.device == "mps":
76 | self.pipeline.enable_attention_slicing()
77 | # warmup
78 | prompt = "a photo of an astronaut riding a horse on mars"
79 | _ = self.pipeline(prompt, num_inference_steps=1)
80 |
81 | def _set_scheduler(self, scheduler_name):
82 | scheduler = self.compatible_schedulers[scheduler_name].from_config(self.scheduler_config)
83 | self.pipeline.scheduler = scheduler
84 |
85 | def generate_image(self, prompt, negative_prompt, scheduler, image_size, num_images, guidance_scale, steps, seed):
86 | self._set_scheduler(scheduler)
87 | logger.info(self.pipeline.scheduler)
88 | if self.device == "mps":
89 | generator = torch.manual_seed(seed)
90 | num_images = 1
91 | else:
92 | generator = torch.Generator(device=self.device).manual_seed(seed)
93 | num_images = int(num_images)
94 | output_images = self.pipeline(
95 | prompt,
96 | negative_prompt=negative_prompt,
97 | width=image_size[1],
98 | height=image_size[0],
99 | num_inference_steps=steps,
100 | guidance_scale=guidance_scale,
101 | num_images_per_prompt=num_images,
102 | generator=generator,
103 | ).images
104 | torch.cuda.empty_cache()
105 | gc.collect()
106 | metadata = {
107 | "prompt": prompt,
108 | "negative_prompt": negative_prompt,
109 | "scheduler": scheduler,
110 | "image_size": image_size,
111 | "num_images": num_images,
112 | "guidance_scale": guidance_scale,
113 | "steps": steps,
114 | "seed": seed,
115 | }
116 | metadata = json.dumps(metadata)
117 | _metadata = PngInfo()
118 | _metadata.add_text("textual_inversion", metadata)
119 |
120 | utils.save_images(
121 | images=output_images,
122 | module="textual_inversion",
123 | metadata=metadata,
124 | output_path=self.output_path,
125 | )
126 |
127 | return output_images, _metadata
128 |
129 | def app(self):
130 | available_schedulers = list(self.compatible_schedulers.keys())
131 | if "EulerAncestralDiscreteScheduler" in available_schedulers:
132 | available_schedulers.insert(
133 | 0, available_schedulers.pop(available_schedulers.index("EulerAncestralDiscreteScheduler"))
134 | )
135 | col1, col2 = st.columns(2)
136 | with col1:
137 | prompt = st.text_area("Prompt", "Blue elephant")
138 | with col2:
139 | negative_prompt = st.text_area("Negative Prompt", "")
140 | # sidebar options
141 | scheduler = st.sidebar.selectbox("Scheduler", available_schedulers, index=0)
142 | image_height = st.sidebar.slider("Image height", 128, 1024, 512, 128)
143 | image_width = st.sidebar.slider("Image width", 128, 1024, 512, 128)
144 | guidance_scale = st.sidebar.slider("Guidance scale", 1.0, 40.0, 7.5, 0.5)
145 | num_images = st.sidebar.slider("Number of images per prompt", 1, 30, 1, 1)
146 | steps = st.sidebar.slider("Steps", 1, 150, 50, 1)
147 |         seed_choice = st.sidebar.selectbox("Do you want a random seed?", options=["Yes", "No"])
148 | if seed_choice == "Yes":
149 | seed = random.randint(0, 9999999)
150 | else:
151 | seed = st.sidebar.number_input(
152 | "Seed",
153 | value=42,
154 | step=1,
155 | help="Random seed. Change for different results using same parameters.",
156 | )
157 | sub_col, download_col = st.columns(2)
158 | with sub_col:
159 | submit = st.button("Generate")
160 |
161 | if submit:
162 | with st.spinner("Generating images..."):
163 | output_images, metadata = self.generate_image(
164 | prompt=prompt,
165 | negative_prompt=negative_prompt,
166 | scheduler=scheduler,
167 | image_size=(image_height, image_width),
168 | num_images=num_images,
169 | guidance_scale=guidance_scale,
170 | steps=steps,
171 | seed=seed,
172 | )
173 |
174 | utils.display_and_download_images(output_images, metadata, download_col)
175 |
--------------------------------------------------------------------------------
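A hedged sketch of the TextualInversion class above: it downloads a learned-embeddings file, injects the token into the tokenizer and text encoder via load_embed, and then generates like Text2Image. The embeddings URL, token, and model id below are placeholders.

# Hypothetical driver for TextualInversion; URL, token, and model id are placeholders.
from stablefusion.scripts.textual_inversion import TextualInversion  # assumed module location

ti = TextualInversion(
    model="runwayml/stable-diffusion-v1-5",                   # base checkpoint (assumption)
    embeddings_url="https://example.com/learned_embeds.bin",  # placeholder textual-inversion embedding
    token_identifier="<my-concept>",                          # token the embedding was trained for
    device="cuda",
    output_path=None,
)
images, png_metadata = ti.generate_image(
    prompt="a photo of <my-concept> on a beach",
    negative_prompt="",
    scheduler="EulerAncestralDiscreteScheduler",
    image_size=(512, 512),
    num_images=1,
    guidance_scale=7.5,
    steps=30,
    seed=42,
)
images[0].save("concept.png", pnginfo=png_metadata)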
/stablefusion/scripts/upscaler.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import json
3 | from dataclasses import dataclass
4 | from typing import Optional
5 |
6 | import streamlit as st
7 | import torch
8 | from diffusers import StableDiffusionUpscalePipeline
9 | from loguru import logger
10 | from PIL import Image
11 | from PIL.PngImagePlugin import PngInfo
12 |
13 | from stablefusion import utils
14 |
15 | @dataclass
16 | class Upscaler:
17 | model: str = "stabilityai/stable-diffusion-x4-upscaler"
18 | device: Optional[str] = None
19 | output_path: Optional[str] = None
20 |
21 | def __str__(self) -> str:
22 | return f"Upscaler(model={self.model}, device={self.device}, output_path={self.output_path})"
23 |
24 | def __post_init__(self):
25 | self.pipeline = StableDiffusionUpscalePipeline.from_pretrained(
26 | self.model,
27 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
28 | )
29 | self.pipeline.to(self.device)
30 | self.pipeline.safety_checker = utils.no_safety_checker
31 | self._compatible_schedulers = self.pipeline.scheduler.compatibles
32 | self.scheduler_config = self.pipeline.scheduler.config
33 | self.compatible_schedulers = {scheduler.__name__: scheduler for scheduler in self._compatible_schedulers}
34 |
35 | def _set_scheduler(self, scheduler_name):
36 | scheduler = self.compatible_schedulers[scheduler_name].from_config(self.scheduler_config)
37 | self.pipeline.scheduler = scheduler
38 |
39 | def generate_image(
40 | self, image, prompt, negative_prompt, guidance_scale, noise_level, num_images, eta, scheduler, steps, seed
41 | ):
42 | self._set_scheduler(scheduler)
43 | logger.info(self.pipeline.scheduler)
44 | if self.device == "mps":
45 | generator = torch.manual_seed(seed)
46 | num_images = 1
47 | else:
48 | generator = torch.Generator(device=self.device).manual_seed(seed)
49 | num_images = int(num_images)
50 | output_images = self.pipeline(
51 | image=image,
52 | prompt=prompt,
53 | negative_prompt=negative_prompt,
54 | noise_level=noise_level,
55 | num_inference_steps=steps,
56 | eta=eta,
57 | num_images_per_prompt=num_images,
58 | generator=generator,
59 | guidance_scale=guidance_scale,
60 | ).images
61 |
62 | torch.cuda.empty_cache()
63 | gc.collect()
64 | metadata = {
65 | "prompt": prompt,
66 | "negative_prompt": negative_prompt,
67 | "noise_level": noise_level,
68 | "num_images": num_images,
69 | "eta": eta,
70 | "scheduler": scheduler,
71 | "steps": steps,
72 | "seed": seed,
73 | }
74 |
75 | metadata = json.dumps(metadata)
76 | _metadata = PngInfo()
77 | _metadata.add_text("upscaler", metadata)
78 |
79 | utils.save_images(
80 | images=output_images,
81 | module="upscaler",
82 | metadata=metadata,
83 | output_path=self.output_path,
84 | )
85 |
86 | return output_images, _metadata
87 |
88 | def app(self):
89 | available_schedulers = list(self.compatible_schedulers.keys())
90 | # if EulerAncestralDiscreteScheduler is available in available_schedulers, move it to the first position
91 | if "EulerAncestralDiscreteScheduler" in available_schedulers:
92 | available_schedulers.insert(
93 | 0, available_schedulers.pop(available_schedulers.index("EulerAncestralDiscreteScheduler"))
94 | )
95 |
96 | input_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
97 | if input_image is not None:
98 | input_image = Image.open(input_image)
99 | input_image = input_image.convert("RGB").resize((128, 128), resample=Image.LANCZOS)
100 | st.image(input_image, use_column_width=True)
101 |
102 | col1, col2 = st.columns(2)
103 | with col1:
104 | prompt = st.text_area("Prompt (Optional)", "")
105 | with col2:
106 | negative_prompt = st.text_area("Negative Prompt (Optional)", "")
107 |
108 | scheduler = st.sidebar.selectbox("Scheduler", available_schedulers, index=0)
109 | guidance_scale = st.sidebar.slider("Guidance scale", 1.0, 40.0, 7.5, 0.5)
110 | noise_level = st.sidebar.slider("Noise level", 0, 100, 20, 1)
111 | eta = st.sidebar.slider("Eta", 0.0, 1.0, 0.0, 0.1)
112 | num_images = st.sidebar.slider("Number of images per prompt", 1, 30, 1, 1)
113 | steps = st.sidebar.slider("Steps", 1, 150, 50, 1)
114 | seed_placeholder = st.sidebar.empty()
115 | seed = seed_placeholder.number_input("Seed", value=42, min_value=1, max_value=999999, step=1)
116 | random_seed = st.sidebar.button("Random seed")
117 | _seed = torch.randint(1, 999999, (1,)).item()
118 | if random_seed:
119 | seed = seed_placeholder.number_input("Seed", value=_seed, min_value=1, max_value=999999, step=1)
120 | sub_col, download_col = st.columns(2)
121 | with sub_col:
122 | submit = st.button("Generate")
123 |
124 | if submit:
125 | with st.spinner("Generating images..."):
126 | output_images, metadata = self.generate_image(
127 | prompt=prompt,
128 | image=input_image,
129 | negative_prompt=negative_prompt,
130 | scheduler=scheduler,
131 | num_images=num_images,
132 | guidance_scale=guidance_scale,
133 | steps=steps,
134 | seed=seed,
135 | noise_level=noise_level,
136 | eta=eta,
137 | )
138 |
139 | utils.display_and_download_images(output_images, metadata, download_col)
140 |
--------------------------------------------------------------------------------
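A minimal sketch of running the Upscaler class above headlessly. It mirrors what app() does: the input is resized to 128x128 and the x4 upscale pipeline returns a 512x512 result; the input path and scheduler choice are assumptions.

# Hypothetical driver for Upscaler; the input file and scheduler name are assumptions.
from PIL import Image

from stablefusion.scripts.upscaler import Upscaler  # assumed module location

upscaler = Upscaler(device="cuda")  # model defaults to stabilityai/stable-diffusion-x4-upscaler
low_res = Image.open("input.png").convert("RGB").resize((128, 128), resample=Image.LANCZOS)
images, png_metadata = upscaler.generate_image(
    image=low_res,
    prompt="",
    negative_prompt="",
    guidance_scale=7.5,
    noise_level=20,
    num_images=1,
    eta=0.0,
    scheduler="DDIMScheduler",  # any scheduler from the pipeline's compatibles list (assumption)
    steps=50,
    seed=42,
)
images[0].save("upscaled.png", pnginfo=png_metadata)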
/stablefusion/utils.py:
--------------------------------------------------------------------------------
1 | import base64
2 | from PIL import Image
3 | import gc
4 | import io
5 | import os
6 | import tempfile
7 | import zipfile
8 | from datetime import datetime
9 | from threading import Thread
10 | import cv2
11 | import requests
12 | import streamlit as st
13 | import torch
14 | from huggingface_hub import HfApi
15 | from huggingface_hub.utils._errors import RepositoryNotFoundError
16 | from huggingface_hub.utils._validators import HFValidationError
17 | from loguru import logger
18 | from PIL.PngImagePlugin import PngInfo
19 | from st_clickable_images import clickable_images
20 |
21 |
22 |
23 | no_safety_checker = None
24 |
25 |
26 | CODE_OF_CONDUCT = """
27 | ## Code of conduct
28 | The app should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
29 |
30 | Using the app to generate content that is cruel to individuals is a misuse of this app. One shall not use this app to generate content that is intended to be cruel to individuals, whether or not that cruelty is obvious to the viewer.
31 | This includes, but is not limited to:
32 | - Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
33 | - Intentionally promoting or propagating discriminatory content or harmful stereotypes.
34 | - Impersonating individuals without their consent.
35 | - Sexual content without consent of the people who might see it.
36 | - Mis- and disinformation
37 | - Representations of egregious violence and gore
38 | - Sharing of copyrighted or licensed material in violation of its terms of use.
39 | - Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use.
40 |
41 | By using this app, you agree to the above code of conduct.
42 |
43 | """
44 |
45 | def use_auth_token():
46 | token_path = os.path.join(os.path.expanduser("~"), ".huggingface", "token")
47 | if os.path.exists(token_path):
48 | return True
49 | return False
50 |
51 |
52 | def create_base_page():
53 | st.set_page_config(layout="wide",
54 | menu_items={
55 | "Get Help": "https://github.com/NeuralRealm/StableFusion/discussions",
56 | "Report a bug": "https://github.com/NeuralRealm/StableFusion/issues",
57 | 'About': "# Stable Fusion\nWelcome to StableFusion! Our AI-powered web app offers a user-friendly interface for transforming text into images and generating new images with customizable styles and formats. Built with Diffusion, Python, and Streamlit, our app makes it easy to create stunning visuals with just a few clicks.\n### Getting Started\nTo get started with StableFusion, simply visit our website and follow the on-screen instructions. You can input your text or upload an image, select your preferred style and output format, and let our AI do the rest.\n### About NeuralRealm\nStableFusion was developed by NeuralRealm, an organization dedicated to advancing the field of artificial intelligence. We are grateful for their contributions and proud to offer this user-friendly web app to our users\n### Contributions\nWe welcome contributions from the community to help improve StableFusion. If you have any feedback, suggestions, or would like to contribute code, please visit our [GitHub repository](https://github.com/NeuralRealm/StableFusion).\n### License\nStable Fusion is licensed under the [GPL-3.0 license](https://github.com/NeuralRealm/StableFusion/blob/master/LICENSE). See the LICENSE file for more information."
58 | }
59 | )
60 | st.title("StableFusion")
61 | st.markdown("Welcome to **StableFusion**! A web app for **Stable Diffusion Models**")
62 |
63 | def download_file(file_url):
64 | r = requests.get(file_url, stream=True)
65 | with tempfile.NamedTemporaryFile(delete=False) as tmp:
66 | for chunk in r.iter_content(chunk_size=1024):
67 | if chunk: # filter out keep-alive new chunks
68 | tmp.write(chunk)
69 | return tmp.name
70 |
71 |
72 | def base_path():
73 |
74 | return os.path.dirname(__file__)
75 |
76 | def cache_folder():
77 | _cache_folder = os.path.join(os.path.expanduser("~"), ".stablefusion")
78 | os.makedirs(_cache_folder, exist_ok=True)
79 | return _cache_folder
80 |
81 |
82 | def clear_memory(preserve):
83 | torch.cuda.empty_cache()
84 | gc.collect()
85 | to_clear = ["inpainting", "text2img", "img2text"]
86 | for key in to_clear:
87 | if key not in preserve and key in st.session_state:
88 | del st.session_state[key]
89 |
90 |
91 | def save_to_hub(api, images, module, current_datetime, metadata, output_path):
92 | logger.info(f"Saving images to hub: {output_path}")
93 | _metadata = PngInfo()
94 |     _metadata.add_text(module, metadata)
95 | for i, img in enumerate(images):
96 | img_byte_arr = io.BytesIO()
97 | img.save(img_byte_arr, format="PNG", pnginfo=_metadata)
98 | img_byte_arr = img_byte_arr.getvalue()
99 | api.upload_file(
100 | path_or_fileobj=img_byte_arr,
101 | path_in_repo=f"{module}/{current_datetime}/{i}.png",
102 | repo_id=output_path,
103 | repo_type="dataset",
104 | )
105 |
106 | api.upload_file(
107 | path_or_fileobj=str.encode(metadata),
108 | path_in_repo=f"{module}/{current_datetime}/metadata.json",
109 | repo_id=output_path,
110 | repo_type="dataset",
111 | )
112 | logger.info(f"Saved images to hub: {output_path}")
113 |
114 |
115 | def save_to_local(images, module, current_datetime, metadata, output_path):
116 | _metadata = PngInfo()
117 |     _metadata.add_text(module, metadata)
118 | os.makedirs(output_path, exist_ok=True)
119 | os.makedirs(f"{output_path}/{module}", exist_ok=True)
120 | os.makedirs(f"{output_path}/{module}/{current_datetime}", exist_ok=True)
121 |
122 | for i, img in enumerate(images):
123 | img.save(
124 | f"{output_path}/{module}/{current_datetime}/{i}.png",
125 | pnginfo=_metadata,
126 | )
127 |
128 | # save metadata as text file
129 | with open(f"{output_path}/{module}/{current_datetime}/metadata.txt", "w") as f:
130 | f.write(metadata)
131 | logger.info(f"Saved images to {output_path}/{module}/{current_datetime}")
132 |
133 |
134 | def save_images(images, module, metadata, output_path):
135 | if output_path is None:
136 | logger.warning("No output path specified, skipping saving images")
137 | return
138 |
139 | api = HfApi()
140 | dset_info = None
141 | try:
142 | dset_info = api.dataset_info(output_path)
143 | except (HFValidationError, RepositoryNotFoundError):
144 | logger.warning("No valid hugging face repo. Saving locally...")
145 |
146 | current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
147 |
148 | if not dset_info:
149 | save_to_local(images, module, current_datetime, metadata, output_path)
150 | else:
151 | Thread(target=save_to_hub, args=(api, images, module, current_datetime, metadata, output_path)).start()
152 |
153 |
154 | def display_and_download_images(output_images, metadata, download_col=None):
155 | # st.image(output_images, width=128, output_format="PNG")
156 |
157 | with st.spinner("Preparing images for download..."):
158 | # save images to a temporary directory
159 | with tempfile.TemporaryDirectory() as tmpdir:
160 | gallery_images = []
161 | for i, image in enumerate(output_images):
162 |
163 | image.save(os.path.join(tmpdir, f"{i + 1}.png"), pnginfo=metadata)
164 | with open(os.path.join(tmpdir, f"{i + 1}.png"), "rb") as img:
165 | encoded = base64.b64encode(img.read()).decode()
166 |                 gallery_images.append(f"data:image/png;base64,{encoded}")
167 |
168 | # zip the images
169 | zip_path = os.path.join(tmpdir, "images.zip")
170 | with zipfile.ZipFile(zip_path, "w") as zip:
171 | for filename in os.listdir(tmpdir):
172 | if filename.endswith(".png"):
173 | zip.write(os.path.join(tmpdir, filename), filename)
174 |
175 | # convert zip to base64
176 | with open(zip_path, "rb") as f:
177 | encoded = base64.b64encode(f.read()).decode()
178 |
179 | _ = clickable_images(
180 | gallery_images,
181 | titles=[f"Image #{str(i)}" for i in range(len(gallery_images))],
182 | div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
183 | img_style={"margin": "5px", "height": "512px", "width": "512px"},
184 | )
185 |
186 | now = datetime.now()
187 | formatted_date_time = now.strftime("%Y-%m-%d_%H_%M_%S")
188 |
189 | # add download link
190 | st.markdown(
191 | f"""
192 |             <a href="data:application/zip;base64,{encoded}" download="images_{formatted_date_time}.zip">
193 |                 Download Images
194 |             </a>
195 | """,
196 | unsafe_allow_html=True,
197 | )
198 |
199 |
--------------------------------------------------------------------------------
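A small sketch of how save_images in the utilities above routes its output: if output_path resolves to an existing Hugging Face dataset repo, the images are uploaded in a background thread; otherwise they are written to a local folder named after the module and a timestamp. The values below are placeholders.

# Hypothetical example for utils.save_images; repo id and folder name are placeholders.
import json

from PIL import Image

from stablefusion import utils

images = [Image.new("RGB", (64, 64), color="blue")]
metadata = json.dumps({"prompt": "blue square", "steps": 1})

# no matching dataset repo -> images land in ./outputs/text2img/<timestamp>/
utils.save_images(images=images, module="text2img", metadata=metadata, output_path="outputs")

# an existing dataset repo id would instead trigger a background upload (assumption):
# utils.save_images(images=images, module="text2img", metadata=metadata, output_path="username/my-dataset")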
/static/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuralRealm/StableFusion/9268d686f49b45194a2251c6728e74191f298654/static/.keep
--------------------------------------------------------------------------------
/static/Screenshot1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuralRealm/StableFusion/9268d686f49b45194a2251c6728e74191f298654/static/Screenshot1.png
--------------------------------------------------------------------------------
/static/Screenshot2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuralRealm/StableFusion/9268d686f49b45194a2251c6728e74191f298654/static/Screenshot2.png
--------------------------------------------------------------------------------
/static/Screenshot3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuralRealm/StableFusion/9268d686f49b45194a2251c6728e74191f298654/static/Screenshot3.png
--------------------------------------------------------------------------------
/static/screenshot4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuralRealm/StableFusion/9268d686f49b45194a2251c6728e74191f298654/static/screenshot4.png
--------------------------------------------------------------------------------