├── .github └── FUNDING.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── Makefile ├── OTHER_LICENSES ├── README.md ├── diffuzers.ipynb ├── docs ├── Makefile ├── conf.py ├── index.rst └── make.bat ├── requirements.txt ├── setup.cfg ├── setup.py ├── stablefusion ├── Home.py ├── __init__.py ├── api │ ├── __init__.py │ ├── main.py │ ├── schemas.py │ └── utils.py ├── cli │ ├── __init__.py │ ├── main.py │ ├── run_api.py │ └── run_app.py ├── data │ ├── artists.txt │ ├── flavors.txt │ ├── mediums.txt │ └── movements.txt ├── model_list.txt ├── models │ ├── ckpt_interference │ │ └── t │ ├── cpkt_models │ │ └── t │ ├── diffusion_models │ │ └── t.txt │ ├── realesrgan │ │ ├── inference_realesrgan.py │ │ ├── options │ │ │ ├── finetune_realesrgan_x4plus.yml │ │ │ ├── finetune_realesrgan_x4plus_pairdata.yml │ │ │ ├── train_realesrgan_x2plus.yml │ │ │ ├── train_realesrgan_x4plus.yml │ │ │ ├── train_realesrnet_x2plus.yml │ │ │ └── train_realesrnet_x4plus.yml │ │ ├── realesrgan │ │ │ ├── __init__.py │ │ │ ├── archs │ │ │ │ ├── __init__.py │ │ │ │ ├── discriminator_arch.py │ │ │ │ └── srvgg_arch.py │ │ │ ├── data │ │ │ │ ├── __init__.py │ │ │ │ ├── realesrgan_dataset.py │ │ │ │ └── realesrgan_paired_dataset.py │ │ │ ├── models │ │ │ │ ├── __init__.py │ │ │ │ ├── realesrgan_model.py │ │ │ │ └── realesrnet_model.py │ │ │ ├── train.py │ │ │ └── utils.py │ │ └── weights │ │ │ └── README.md │ └── safetensors_models │ │ └── t ├── pages │ ├── 10_Utilities.py │ ├── 1_Text2Image.py │ ├── 2_Image2Image.py │ ├── 3_Inpainting.py │ ├── 4_ControlNet.py │ ├── 5_OpenPose_Editor.py │ ├── 6_Textual Inversion.py │ ├── 7_Upscaler.py │ ├── 8_Convertor.py │ └── 9_Train.py ├── scripts │ ├── blip.py │ ├── ckpt_to_diffusion.py │ ├── clip_interrogator.py │ ├── controlnet.py │ ├── dreambooth │ │ ├── README.md │ │ ├── convert_weights_to_ckpt.py │ │ ├── requirements_flax.txt │ │ ├── train_dreambooth.py │ │ ├── train_dreambooth_flax.py │ │ └── train_dreambooth_lora.py │ ├── gfp_gan.py │ ├── gradio_app.py │ ├── image_info.py │ ├── img2img.py │ ├── inpainting.py │ ├── interrogator.py │ ├── model_adding.py │ ├── model_removing.py │ ├── pose_editor.py │ ├── pose_html.py │ ├── safetensors_to_diffusion.py │ ├── text2img.py │ ├── textual_inversion.py │ ├── upscaler.py │ └── x2image.py └── utils.py └── static ├── .keep ├── Screenshot1.png ├── Screenshot2.png ├── Screenshot3.png └── screenshot4.png /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: everydaycodings -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # local stuff 2 | .vscode/ 3 | examples/.logs/ 4 | *.bin 5 | *.csv 6 | input/ 7 | logs/ 8 | gfpgan/ 9 | *.pkl 10 | *.pt 11 | *.pth 12 | abhishek/ 13 | diffout/ 14 | datasets/* 15 | experiments/* 16 | results/* 17 | tb_logger/* 18 | wandb/* 19 | tmp/* 20 | weights/* 21 | # Byte-compiled / optimized / DLL files 22 | __pycache__/ 23 | *.py[cod] 24 | *$py.class 25 | .pypirc 26 | # C extensions 27 | *.so 28 | 7_Train.py 29 | 30 | # Distribution / packaging 31 | .Python 32 | build/ 33 | develop-eggs/ 34 | dist/ 35 | downloads/ 36 | eggs/ 37 | .eggs/ 38 | lib/ 39 | lib64/ 40 | parts/ 41 | sdist/ 42 | var/ 43 | pip-wheel-metadata/ 44 | share/python-wheels/ 45 | *.egg-info/ 46 | .installed.cfg 47 | *.egg 48 | MANIFEST 49 | src/ 50 | 51 | # PyInstaller 52 | # Usually these files are written by a python script from a template 53 | # before PyInstaller builds 
the exe, so as to inject date/other infos into it. 54 | *.manifest 55 | *.spec 56 | 57 | # Installer logs 58 | pip-log.txt 59 | pip-delete-this-directory.txt 60 | 61 | # Unit test / coverage reports 62 | htmlcov/ 63 | .tox/ 64 | .nox/ 65 | .coverage 66 | .coverage.* 67 | .cache 68 | nosetests.xml 69 | coverage.xml 70 | *.cover 71 | *.py,cover 72 | .hypothesis/ 73 | .pytest_cache/ 74 | 75 | # Translations 76 | *.mo 77 | *.pot 78 | 79 | # Django stuff: 80 | *.log 81 | local_settings.py 82 | db.sqlite3 83 | db.sqlite3-journal 84 | 85 | # Flask stuff: 86 | instance/ 87 | .webassets-cache 88 | 89 | # Scrapy stuff: 90 | .scrapy 91 | 92 | # Sphinx documentation 93 | docs/_build/ 94 | 95 | # PyBuilder 96 | target/ 97 | 98 | # Jupyter Notebook 99 | .ipynb_checkpoints 100 | 101 | # IPython 102 | profile_default/ 103 | ipython_config.py 104 | 105 | # pyenv 106 | .python-version 107 | 108 | # pipenv 109 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 110 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 111 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 112 | # install all needed dependencies. 113 | #Pipfile.lock 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | test.ipynb 152 | # vscode 153 | .vscode/ 154 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include stablefusion/pages/*.py 2 | include stablefusion/data/*.txt 3 | include stablefusion/*.txt 4 | recursive-include stablefusion/scripts * 5 | recursive-include stablefusion/models * 6 | include OTHER_LICENSES -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: quality style 2 | 3 | quality: 4 | python -m black --check --line-length 119 --target-version py38 . 5 | python -m isort --check-only . 6 | python -m flake8 --max-line-length 119 . 7 | 8 | style: 9 | python -m black --line-length 119 --target-version py38 . 10 | python -m isort . 
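# Note (added for readers of this dump): `make quality` runs black, isort, and flake8 in check-only mode and fails on violations; `make style` applies the same black/isort settings in place.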
-------------------------------------------------------------------------------- /OTHER_LICENSES: -------------------------------------------------------------------------------- 1 | clip_interrogator license 2 | for stablefusion/scripts/clip_interrogator.py 3 | for stablefusion/data/*.txt 4 | ------------------------- 5 | MIT License 6 | 7 | Copyright (c) 2022 pharmapsychotic 8 | 9 | Permission is hereby granted, free of charge, to any person obtaining a copy 10 | of this software and associated documentation files (the "Software"), to deal 11 | in the Software without restriction, including without limitation the rights 12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | copies of the Software, and to permit persons to whom the Software is 14 | furnished to do so, subject to the following conditions: 15 | 16 | The above copyright notice and this permission notice shall be included in all 17 | copies or substantial portions of the Software. 18 | 19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | SOFTWARE. 26 | 27 | blip license 28 | for stablefusion/scripts/blip.py 29 | ------------ 30 | Copyright (c) 2022, Salesforce.com, Inc. 31 | All rights reserved. 32 | 33 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 34 | 35 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 36 | 37 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 38 | 39 | * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 40 | 41 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StableFusion 2 | 3 | A web UI for **Stable Diffusion Models**.
4 | 5 | ## Update (Version 0.1.4) 6 | 7 | In this version, we've added two major features to the project: 8 | 9 | - **ControlNet Option** 10 | - **OpenPose Editor** 11 | 12 | We've also made the following improvements and bug fixes: 13 | 14 | - **Improved File Arrangement**: We've rearranged the files in the project to make it better organized and easier to navigate. 15 | - **Fixed Random Seed Generator**: We've fixed a bug in the random seed generator that was causing incorrect results. This fix ensures that the project operates more accurately and reliably. 16 | 17 | We hope these updates improve the overall functionality and user experience of the project. As always, please feel free to reach out to us if you have any questions or feedback. 18 | 19 | 20 | 21 | Open In Colab 22 | 23 | 24 | 25 | ![image](https://raw.githubusercontent.com/NeuralRealm/StableFusion/master/static/Screenshot1.png) 26 | ![image](https://raw.githubusercontent.com/NeuralRealm/StableFusion/master/static/Screenshot2.png) 27 | 28 | If something doesn't work as expected, or if you need a feature that is not available, please create a request using [GitHub issues](https://github.com/NeuralRealm/StableFusion/issues). 29 | 30 | 31 | ## Features available in the app: 32 | 33 | - text to image 34 | - image to image 35 | - Inpainting 36 | - instruct pix2pix 37 | - textual inversion 38 | - ControlNet 39 | - OpenPose Editor 40 | - image info 41 | - Upscale Your Image 42 | - clip interrogator 43 | - Convert ckpt file to diffusers 44 | - Convert safetensors file to diffusers 45 | - Add your own diffusers model 46 | - more coming soon! 47 | 48 | 49 | 50 | ## Installation 51 | 52 | To install the bleeding-edge version of StableFusion, clone the repo and install it using pip. 53 | 54 | ```bash 55 | git clone https://github.com/NeuralRealm/StableFusion 56 | cd StableFusion 57 | pip install -e . 58 | ``` 59 | 60 | To install the latest release using pip: 61 | 62 | ```bash 63 | pip install stablefusion 64 | ``` 65 | 66 | ## Usage 67 | 68 | ### Web App 69 | To run the web app, run the following command: 70 | 71 | For a local host: 72 | ```bash 73 | stablefusion app 74 | ``` 75 | or 76 | 77 | For a public shareable link: 78 | ```bash 79 | stablefusion app --port 10000 --ngrok_key YourNgrokAuthtoken --share 80 | ``` 81 | 82 | ## All CLI Options for running the app: 83 | 84 | ```bash 85 | ❯ stablefusion app --help 86 | usage: stablefusion <command> [<args>] app [-h] [--output OUTPUT] [--share] [--port PORT] [--host HOST] 87 | [--device DEVICE] [--ngrok_key NGROK_KEY] 88 | 89 | ✨ Run stablefusion app 90 | 91 | optional arguments: 92 | -h, --help show this help message and exit 93 | --output OUTPUT Output path is optional, but if provided, all generations will automatically be saved to this 94 | path. 95 | --share Share the app 96 | --port PORT Port to run the app on 97 | --host HOST Host to run the app on 98 | --device DEVICE Device to use, e.g. cpu, cuda, cuda:0, mps (for m1 mac) etc. 99 | --ngrok_key NGROK_KEY 100 | Ngrok key to use for sharing the app. Only required if you want to share the app 101 | ``` 102 | 103 | 104 | ## Using private models from the Hugging Face Hub 105 | 106 | If you want to use private models from the Hugging Face Hub, you first need to log in using the `huggingface-cli login` command. 107 | 108 | Note: You can also save your generations directly to the Hugging Face Hub if your output path points to a Hub dataset repo that you have push access to. This way, you save a lot of local disk space.
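If you prefer to authenticate from Python (for example, in the bundled notebook), the `huggingface_hub` library that StableFusion already depends on exposes the same login programmatically. A minimal sketch — the token string below is a placeholder, not a real credential:

```python
from huggingface_hub import login

# Paste a "read" token from https://huggingface.co/settings/tokens;
# "hf_xxx" is a placeholder value for illustration only.
login(token="hf_xxx")
```

Once logged in, private model ids can be used like any public model id.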
109 | 110 | ## Acknowledgements 111 | 112 | I would like to express my gratitude to the following individuals and organizations for sharing their code, which formed the basis of the implementation used in this project: 113 | 114 | - [Tencent ARC](https://github.com/TencentARC) for their code for the [GFPGAN](https://github.com/TencentARC/GFPGAN) package, which was used for face restoration. 115 | - [LexKoin](https://github.com/LexKoin) for their code for the [Real-ESRGAN-UpScale](https://github.com/LexKoin/Real-ESRGAN-UpScale) package, which was used for image upscaling. 116 | - [Hugging Face](https://github.com/huggingface) for their code for the [diffusers](https://github.com/huggingface/diffusers) package, which provides the diffusion pipelines used throughout this project. 117 | - [Abhishek Thakur](https://github.com/abhishekkrthakur) for sharing his code for the [diffuzers](https://github.com/abhishekkrthakur/diffuzers) package, which StableFusion's web app and API are built on. 118 | 119 | I am grateful for their contributions to the open-source community, which made this project possible. 120 | 121 | ## Contributing 122 | 123 | StableFusion is an open-source project, and we welcome contributions from the community. Whether you're a developer, designer, or user, there are many ways you can help make this project better. Here are a few ways you can get involved: 124 | 125 | - **Report issues:** If you find a bug or have a feature request, please open an issue on our [GitHub repository](https://github.com/NeuralRealm/StableFusion/issues). We appreciate detailed bug reports and constructive feedback. 126 | - **Submit pull requests:** If you're interested in contributing code, we welcome pull requests for bug fixes, new features, and documentation improvements. 127 | - **Spread the word:** If you enjoy using StableFusion, please help us spread the word by sharing it with your friends, colleagues, and social media networks. We appreciate any support you can give us! 128 | 129 | ### We believe that open-source software is the future of technology, and we're excited to have you join us in making StableFusion a success. Thank you for your support!
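StableFusion also includes a FastAPI server (`stablefusion/api`) with endpoints such as `/text2img`, `/img2img`, `/instruct-pix2pix`, and `/inpainting`. Below is a minimal sketch of calling `/text2img` with `requests` (not a declared dependency — any HTTP client works), assuming the API is already running locally on its default port 10000; the prompt and output filenames are illustrative:

```python
import base64

import requests

# Fields match the Text2ImgParams schema in stablefusion/api/schemas.py;
# omitted fields fall back to their declared defaults.
payload = {
    "prompt": "a watercolor painting of a lighthouse at dawn",
    "num_images": 1,
    "steps": 30,
    "seed": 42,
}

resp = requests.post("http://127.0.0.1:10000/text2img", json=payload)
resp.raise_for_status()

# The API returns base64-encoded PNGs plus the metadata used to generate them.
for i, b64_img in enumerate(resp.json()["images"]):
    with open(f"generation_{i}.png", "wb") as f:
        f.write(base64.b64decode(b64_img))
```

Note that the `api` subcommand is currently commented out in `stablefusion/cli/main.py`, so you may need to start the server directly with `uvicorn stablefusion.api.main:app` after setting the environment variables read in `stablefusion/api/main.py` (for example `X2IMG_MODEL` and `DEVICE`).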
130 | -------------------------------------------------------------------------------- /diffuzers.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "!pip3 install stablefusion -q" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "/bin/bash: line 1: diffuzers: command not found\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "!stablefusion app --port 10000 --ngrok_key YOUR_NGROK_AUTHTOKEN --share" 27 | ] 28 | } 29 | ], 30 | "metadata": { 31 | "kernelspec": { 32 | "display_name": "diffuzers", 33 | "language": "python", 34 | "name": "python3" 35 | }, 36 | "language_info": { 37 | "codemirror_mode": { 38 | "name": "ipython", 39 | "version": 3 40 | }, 41 | "file_extension": ".py", 42 | "mimetype": "text/x-python", 43 | "name": "python", 44 | "nbconvert_exporter": "python", 45 | "pygments_lexer": "ipython3", 46 | "version": "3.9.16" 47 | }, 48 | "orig_nbformat": 4, 49 | "vscode": { 50 | "interpreter": { 51 | "hash": "e667c4130092153eeed1faae4ad5e1fb640895671448a3853c00b657af26211d" 52 | } 53 | } 54 | }, 55 | "nbformat": 4, 56 | "nbformat_minor": 2 57 | } 58 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = "StableFusion: webapp and api for 🤗 diffusers" 21 | copyright = "2023, NeuralRealm" 22 | author = "NeuralRealm" 23 | 24 | 25 | # -- General configuration --------------------------------------------------- 26 | 27 | # Add any Sphinx extension module names here, as strings. 
They can be 28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 29 | # ones. 30 | extensions = [] 31 | 32 | # Add any paths that contain templates here, relative to this directory. 33 | templates_path = ["_templates"] 34 | 35 | # List of patterns, relative to source directory, that match files and 36 | # directories to ignore when looking for source files. 37 | # This pattern also affects html_static_path and html_extra_path. 38 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 39 | 40 | 41 | # -- Options for HTML output ------------------------------------------------- 42 | 43 | # The theme to use for HTML and HTML Help pages. See the documentation for 44 | # a list of builtin themes. 45 | # 46 | html_theme = "alabaster" 47 | 48 | # Add any paths that contain custom static files (such as style sheets) here, 49 | # relative to this directory. They are copied after the builtin static files, 50 | # so a file named "default.css" will overwrite the builtin "default.css". 51 | html_static_path = ["_static"] 52 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to StableFusion's documentation! 2 | ====================================================================== 3 | 4 | StableFusion offers a web app and an API for 🤗 diffusers. Installation is very simple. 5 | You can install via pip: 6 | 7 | .. code-block:: bash 8 | 9 | pip install stablefusion 10 | 11 | .. toctree:: 12 | :maxdepth: 2 13 | :caption: Contents: 14 | 15 | 16 | 17 | Indices and tables 18 | ================== 19 | 20 | * :ref:`genindex` 21 | * :ref:`modindex` 22 | * :ref:`search` 23 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | 	set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | 	echo. 18 | 	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | 	echo.installed, then set the SPHINXBUILD environment variable to point 20 | 	echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | 	echo.may add the Sphinx directory to PATH. 22 | 	echo.
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.15.0 2 | basicsr>=1.4.2 3 | diffusers>=0.14.0 4 | facexlib>=0.2.5 5 | fairscale==0.4.4 6 | fastapi==0.88.0 7 | gfpgan>=1.3.7 8 | huggingface_hub>=0.11.1 9 | loguru==0.6.0 10 | open_clip_torch==2.9.1 11 | opencv-python 12 | protobuf==3.20 13 | pyngrok==5.2.1 14 | python-multipart==0.0.5 15 | realesrgan>=0.2.5 16 | streamlit==1.16.0 17 | streamlit-drawable-canvas==0.9.0 18 | st-clickable-images==0.0.3 19 | timm==0.4.12 20 | torch>=1.12.0 21 | torchvision>=0.13.0 22 | transformers==4.25.1 23 | uvicorn==0.15.0 24 | OmegaConf 25 | tqdl 26 | xformers 27 | bitsandbytes 28 | safetensors 29 | controlnet_aux -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | version = attr: stablefusion.__version__ 3 | 4 | [isort] 5 | ensure_newline_before_comments = True 6 | force_grid_wrap = 0 7 | include_trailing_comma = True 8 | line_length = 119 9 | lines_after_imports = 2 10 | multi_line_output = 3 11 | use_parentheses = True 12 | 13 | [flake8] 14 | ignore = E203, E501, W503 15 | max-line-length = 119 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # pylint: enable=line-too-long 3 | 4 | from setuptools import find_packages, setup 5 | 6 | 7 | with open("README.md" , encoding='utf-8') as f: 8 | long_description = f.read() 9 | 10 | QUALITY_REQUIRE = [ 11 | "black~=22.0", 12 | "isort==5.8.0", 13 | "flake8==3.9.2", 14 | "mypy==0.901", 15 | ] 16 | 17 | TEST_REQUIRE = ["pytest", "pytest-cov"] 18 | 19 | EXTRAS_REQUIRE = { 20 | "dev": QUALITY_REQUIRE, 21 | "quality": QUALITY_REQUIRE, 22 | "test": TEST_REQUIRE, 23 | "docs": [ 24 | "recommonmark", 25 | "sphinx==3.1.2", 26 | "sphinx-markdown-tables", 27 | "sphinx-rtd-theme==0.4.3", 28 | "sphinx-copybutton", 29 | ], 30 | } 31 | 32 | with open("requirements.txt") as f: 33 | INSTALL_REQUIRES = f.read().splitlines() 34 | 35 | setup( 36 | name="stablefusion", 37 | description="StableFusion", 38 | long_description=long_description, 39 | long_description_content_type="text/markdown", 40 | author="NeuralRealm", 41 | url="https://github.com/NeuralRealm/StableFusion", 42 | packages=find_packages("."), 43 | entry_points={"console_scripts": ["stablefusion=stablefusion.cli.main:main"]}, 44 | install_requires=INSTALL_REQUIRES, 45 | extras_require=EXTRAS_REQUIRE, 46 | python_requires=">=3.7", 47 | classifiers=[ 48 | "Intended Audience :: Developers", 49 | "Intended Audience :: Education", 50 | "Intended Audience :: Science/Research", 51 | "License :: OSI Approved :: Apache Software License", 52 | "Operating System :: OS Independent", 53 | "Programming Language :: Python :: 3.8", 54 | "Programming Language :: Python :: 3.9", 55 | "Programming Language :: Python :: 3.10", 56 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 57 | ], 58 | keywords="stablefusion", 59 | 
include_package_data=True, 60 | ) 61 | -------------------------------------------------------------------------------- /stablefusion/Home.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import streamlit as st 4 | from loguru import logger 5 | 6 | from stablefusion import utils 7 | from stablefusion.scripts.x2image import X2Image 8 | import ast 9 | import os 10 | 11 | def read_model_list(): 12 | 13 | try: 14 | with open('{}/model_list.txt'.format(os.path.dirname(__file__)), 'r') as f: 15 | contents = f.read() 16 | except: 17 | with open('stablefusion/model_list.txt', 'r') as f: 18 | contents = f.read() 19 | 20 | model_list = ast.literal_eval(contents) 21 | 22 | return model_list 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument( 27 | "--output", 28 | type=str, 29 | required=False, 30 | default=None, 31 | help="Output path", 32 | ) 33 | parser.add_argument( 34 | "--device", 35 | type=str, 36 | required=True, 37 | help="Device to use, e.g. cpu, cuda, cuda:0, mps etc.", 38 | ) 39 | return parser.parse_args() 40 | 41 | 42 | def x2img_app(): 43 | with st.form("x2img_model_form"): 44 | col1, col2 = st.columns(2) 45 | with col1: 46 | model = st.selectbox( 47 | "Which model do you want to use?", 48 | options=read_model_list(), 49 | ) 50 | with col2: 51 | custom_pipeline = st.selectbox( 52 | "Custom pipeline", 53 | options=[ 54 | "Vanilla", 55 | "Long Prompt Weighting", 56 | ], 57 | index=0 if st.session_state.get("x2img_custom_pipeline") in (None, "Vanilla") else 1, 58 | ) 59 | 60 | with st.expander("Textual Inversion (Optional)"): 61 | token_identifier = st.text_input( 62 | "Token identifier", 63 | placeholder="" 64 | if st.session_state.get("textual_inversion_token_identifier") is None 65 | else st.session_state.textual_inversion_token_identifier, 66 | ) 67 | embeddings = st.text_input( 68 | "Embeddings", 69 | placeholder="https://huggingface.co/sd-concepts-library/axe-tattoo/resolve/main/learned_embeds.bin" 70 | if st.session_state.get("textual_inversion_embeddings") is None 71 | else st.session_state.textual_inversion_embeddings, 72 | ) 73 | submit = st.form_submit_button("Load model") 74 | 75 | if submit: 76 | st.session_state.x2img_model = model 77 | st.session_state.x2img_custom_pipeline = custom_pipeline 78 | st.session_state.textual_inversion_token_identifier = token_identifier 79 | st.session_state.textual_inversion_embeddings = embeddings 80 | cpipe = "lpw_stable_diffusion" if custom_pipeline == "Long Prompt Weighting" else None 81 | with st.spinner("Loading model..."): 82 | x2img = X2Image( 83 | model=model, 84 | device=st.session_state.device, 85 | output_path=st.session_state.output_path, 86 | custom_pipeline=cpipe, 87 | token_identifier=token_identifier, 88 | embeddings_url=embeddings, 89 | ) 90 | st.session_state.x2img = x2img 91 | if "x2img" in st.session_state: 92 | st.write(f"Current model: {st.session_state.x2img}") 93 | st.session_state.x2img.app() 94 | 95 | 96 | def run_app(): 97 | utils.create_base_page() 98 | x2img_app() 99 | 100 | 101 | if __name__ == "__main__": 102 | args = parse_args() 103 | logger.info(f"Args: {args}") 104 | logger.info(st.session_state) 105 | st.session_state.device = args.device 106 | st.session_state.output_path = args.output 107 | run_app() 108 | -------------------------------------------------------------------------------- /stablefusion/__init__.py: -------------------------------------------------------------------------------- 1 | import 
sys 2 | 3 | from loguru import logger 4 | 5 | 6 | logger.configure(handlers=[dict(sink=sys.stderr, format="> {level:<7} {message}")]) 7 | 8 | __version__ = "0.1.4" 9 | -------------------------------------------------------------------------------- /stablefusion/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralRealm/StableFusion/9268d686f49b45194a2251c6728e74191f298654/stablefusion/api/__init__.py -------------------------------------------------------------------------------- /stablefusion/api/main.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | 4 | from fastapi import Depends, FastAPI, File, UploadFile 5 | from loguru import logger 6 | from PIL import Image 7 | from starlette.middleware.cors import CORSMiddleware 8 | 9 | from stablefusion.api.schemas import Img2ImgParams, ImgResponse, InpaintingParams, InstructPix2PixParams, Text2ImgParams 10 | from stablefusion.api.utils import convert_to_b64_list 11 | from stablefusion.scripts.inpainting import Inpainting 12 | from stablefusion.scripts.x2image import X2Image 13 | 14 | 15 | app = FastAPI( 16 | title="StableFusion API", 17 | license_info={ 18 | "name": "Apache 2.0", 19 | "url": "https://www.apache.org/licenses/LICENSE-2.0.html", 20 | }, 21 | ) 22 | app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]) 23 | 24 | 25 | @app.on_event("startup") 26 | async def startup_event(): 27 | 28 | x2img_model = os.environ.get("X2IMG_MODEL") 29 | x2img_pipeline = os.environ.get("X2IMG_PIPELINE") 30 | inpainting_model = os.environ.get("INPAINTING_MODEL") 31 | device = os.environ.get("DEVICE") 32 | output_path = os.environ.get("OUTPUT_PATH") 33 | ti_identifier = os.environ.get("TOKEN_IDENTIFIER", "") 34 | ti_embeddings_url = os.environ.get("TOKEN_EMBEDDINGS_URL", "") 35 | logger.info("@@@@@ Starting StableFusion API @@@@@") 36 | logger.info(f"Text2Image Model: {x2img_model}") 37 | logger.info(f"Text2Image Pipeline: {x2img_pipeline if x2img_pipeline is not None else 'Vanilla'}") 38 | logger.info(f"Inpainting Model: {inpainting_model}") 39 | logger.info(f"Device: {device}") 40 | logger.info(f"Output Path: {output_path}") 41 | logger.info(f"Token Identifier: {ti_identifier}") 42 | logger.info(f"Token Embeddings URL: {ti_embeddings_url}") 43 | app.state.inpainting_model = None  # default, so the /inpainting route can detect a missing model instead of raising AttributeError 44 | logger.info("Loading x2img model...") 45 | if x2img_model is not None: 46 | app.state.x2img_model = X2Image( 47 | model=x2img_model, 48 | device=device, 49 | output_path=output_path, 50 | custom_pipeline=x2img_pipeline, 51 | token_identifier=ti_identifier, 52 | embeddings_url=ti_embeddings_url, 53 | ) 54 | else: 55 | app.state.x2img_model = None 56 | logger.info("Loading inpainting model...") 57 | if inpainting_model is not None: 58 | app.state.inpainting_model = Inpainting( 59 | model=inpainting_model, 60 | device=device, 61 | output_path=output_path, 62 | ) 63 | logger.info("API is ready to use!") 64 | 65 | 66 | @app.post("/text2img") 67 | async def text2img(params: Text2ImgParams) -> ImgResponse: 68 | logger.info(f"Params: {params}") 69 | if app.state.x2img_model is None: 70 | return {"error": "x2img model is not loaded"} 71 | 72 | images, _ = app.state.x2img_model.text2img_generate( 73 | params.prompt, 74 | num_images=params.num_images, 75 | steps=params.steps, 76 | seed=params.seed, 77 | negative_prompt=params.negative_prompt, 78 | scheduler=params.scheduler, 79 | image_size=(params.image_height,
params.image_width), 80 | guidance_scale=params.guidance_scale, 81 | ) 82 | base64images = convert_to_b64_list(images) 83 | return ImgResponse(images=base64images, metadata=params.dict()) 84 | 85 | 86 | @app.post("/img2img") 87 | async def img2img(params: Img2ImgParams = Depends(), image: UploadFile = File(...)) -> ImgResponse: 88 | if app.state.x2img_model is None: 89 | return {"error": "x2img model is not loaded"} 90 | image = Image.open(io.BytesIO(image.file.read())) 91 | images, _ = app.state.x2img_model.img2img_generate( 92 | image=image, 93 | prompt=params.prompt, 94 | negative_prompt=params.negative_prompt, 95 | num_images=params.num_images, 96 | steps=params.steps, 97 | seed=params.seed, 98 | scheduler=params.scheduler, 99 | guidance_scale=params.guidance_scale, 100 | strength=params.strength, 101 | ) 102 | base64images = convert_to_b64_list(images) 103 | return ImgResponse(images=base64images, metadata=params.dict()) 104 | 105 | 106 | @app.post("/instruct-pix2pix") 107 | async def instruct_pix2pix(params: InstructPix2PixParams = Depends(), image: UploadFile = File(...)) -> ImgResponse: 108 | if app.state.x2img_model is None: 109 | return {"error": "x2img model is not loaded"} 110 | image = Image.open(io.BytesIO(image.file.read())) 111 | images, _ = app.state.x2img_model.pix2pix_generate( 112 | image=image, 113 | prompt=params.prompt, 114 | negative_prompt=params.negative_prompt, 115 | num_images=params.num_images, 116 | steps=params.steps, 117 | seed=params.seed, 118 | scheduler=params.scheduler, 119 | guidance_scale=params.guidance_scale, 120 | image_guidance_scale=params.image_guidance_scale, 121 | ) 122 | base64images = convert_to_b64_list(images) 123 | return ImgResponse(images=base64images, metadata=params.dict()) 124 | 125 | 126 | @app.post("/inpainting") 127 | async def inpainting( 128 | params: InpaintingParams = Depends(), image: UploadFile = File(...), mask: UploadFile = File(...) 
129 | ) -> ImgResponse: 130 | if app.state.inpainting_model is None: 131 | return {"error": "inpainting model is not loaded"} 132 | image = Image.open(io.BytesIO(image.file.read())) 133 | mask = Image.open(io.BytesIO(mask.file.read())) 134 | images, _ = app.state.inpainting_model.generate_image( 135 | image=image, 136 | mask=mask, 137 | prompt=params.prompt, 138 | negative_prompt=params.negative_prompt, 139 | scheduler=params.scheduler, 140 | height=params.image_height, 141 | width=params.image_width, 142 | num_images=params.num_images, 143 | guidance_scale=params.guidance_scale, 144 | steps=params.steps, 145 | seed=params.seed, 146 | ) 147 | base64images = convert_to_b64_list(images) 148 | return ImgResponse(images=base64images, metadata=params.dict()) 149 | 150 | 151 | @app.get("/") 152 | def read_root(): 153 | return {"Hello": "World"} 154 | -------------------------------------------------------------------------------- /stablefusion/api/schemas.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | from pydantic import BaseModel, Field 4 | 5 | 6 | class Text2ImgParams(BaseModel): 7 | prompt: str = Field(..., description="Text prompt for the model") 8 | negative_prompt: str = Field(None, description="Negative text prompt for the model") 9 | scheduler: str = Field("EulerAncestralDiscreteScheduler", description="Scheduler to use for the model") 10 | image_height: int = Field(512, description="Image height") 11 | image_width: int = Field(512, description="Image width") 12 | num_images: int = Field(1, description="Number of images to generate") 13 | guidance_scale: float = Field(7, description="Guidance scale") 14 | steps: int = Field(50, description="Number of steps to run the model for") 15 | seed: int = Field(42, description="Seed for the model") 16 | 17 | 18 | class Img2ImgParams(BaseModel): 19 | prompt: str = Field(..., description="Text prompt for the model") 20 | negative_prompt: str = Field(None, description="Negative text prompt for the model") 21 | scheduler: str = Field("EulerAncestralDiscreteScheduler", description="Scheduler to use for the model") 22 | strength: float = Field(0.7, description="Strength") 23 | num_images: int = Field(1, description="Number of images to generate") 24 | guidance_scale: float = Field(7, description="Guidance scale") 25 | steps: int = Field(50, description="Number of steps to run the model for") 26 | seed: int = Field(42, description="Seed for the model") 27 | 28 | 29 | class InstructPix2PixParams(BaseModel): 30 | prompt: str = Field(..., description="Text prompt for the model") 31 | negative_prompt: str = Field(None, description="Negative text prompt for the model") 32 | scheduler: str = Field("EulerAncestralDiscreteScheduler", description="Scheduler to use for the model") 33 | num_images: int = Field(1, description="Number of images to generate") 34 | guidance_scale: float = Field(7, description="Guidance scale") 35 | image_guidance_scale: float = Field(1.5, description="Image guidance scale") 36 | steps: int = Field(50, description="Number of steps to run the model for") 37 | seed: int = Field(42, description="Seed for the model") 38 | 39 | 40 | class ImgResponse(BaseModel): 41 | images: List[str] = Field(..., description="List of images in base64 format") 42 | metadata: Dict = Field(..., description="Metadata") 43 | 44 | 45 | class InpaintingParams(BaseModel): 46 | prompt: str = Field(..., description="Text prompt for the model") 47 | negative_prompt: str = Field(None, 
description="Negative text prompt for the model") 48 | scheduler: str = Field("EulerAncestralDiscreteScheduler", description="Scheduler to use for the model") 49 | image_height: int = Field(512, description="Image height") 50 | image_width: int = Field(512, description="Image width") 51 | num_images: int = Field(1, description="Number of images to generate") 52 | guidance_scale: float = Field(7, description="Guidance scale") 53 | steps: int = Field(50, description="Number of steps to run the model for") 54 | seed: int = Field(42, description="Seed for the model") 55 | -------------------------------------------------------------------------------- /stablefusion/api/utils.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import io 3 | 4 | 5 | def convert_to_b64_list(images): 6 | base64images = [] 7 | for image in images: 8 | buf = io.BytesIO() 9 | image.save(buf, format="PNG") 10 | byte_im = base64.b64encode(buf.getvalue()) 11 | base64images.append(byte_im) 12 | return base64images 13 | -------------------------------------------------------------------------------- /stablefusion/cli/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from argparse import ArgumentParser 3 | 4 | 5 | class BaseStableFusionCommand(ABC): 6 | @staticmethod 7 | @abstractmethod 8 | def register_subcommand(parser: ArgumentParser): 9 | raise NotImplementedError() 10 | 11 | @abstractmethod 12 | def run(self): 13 | raise NotImplementedError() 14 | -------------------------------------------------------------------------------- /stablefusion/cli/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from .. import __version__ 4 | from .run_api import RunStableFusionAPICommand 5 | from .run_app import RunStableFusionAppCommand 6 | 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser( 10 | "StableFusion CLI", 11 | usage="stablefusion <command> [<args>]", 12 | epilog="For more information about a command, run: `stablefusion <command> --help`", 13 | ) 14 | parser.add_argument("--version", "-v", help="Display stablefusion version", action="store_true") 15 | commands_parser = parser.add_subparsers(help="commands") 16 | 17 | # Register commands 18 | RunStableFusionAppCommand.register_subcommand(commands_parser) 19 | #RunStableFusionAPICommand.register_subcommand(commands_parser) 20 | 21 | args = parser.parse_args() 22 | 23 | if args.version: 24 | print(__version__) 25 | exit(0) 26 | 27 | if not hasattr(args, "func"): 28 | parser.print_help() 29 | exit(1) 30 | 31 | command = args.func(args) 32 | command.run() 33 | 34 | 35 | if __name__ == "__main__": 36 | main() 37 | -------------------------------------------------------------------------------- /stablefusion/cli/run_api.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from argparse import ArgumentParser 3 | 4 | import torch 5 | 6 | from .
import BaseStableFusionCommand 7 | 8 | 9 | def run_api_command_factory(args): 10 | return RunStableFusionAPICommand( 11 | args.output, 12 | args.port, 13 | args.host, 14 | args.device, 15 | args.workers, 16 | args.ssl_certfile, 17 | args.ssl_keyfile, 18 | ) 19 | 20 | 21 | class RunStableFusionAPICommand(BaseStableFusionCommand): 22 | @staticmethod 23 | def register_subcommand(parser: ArgumentParser): 24 | run_api_parser = parser.add_parser( 25 | "api", 26 | description="✨ Run StableFusion api", 27 | ) 28 | run_api_parser.add_argument( 29 | "--output", 30 | type=str, 31 | required=False, 32 | help="Output path is optional, but if provided, all generations will automatically be saved to this path.", 33 | ) 34 | run_api_parser.add_argument( 35 | "--port", 36 | type=int, 37 | default=10000, 38 | help="Port to run the app on", 39 | required=False, 40 | ) 41 | run_api_parser.add_argument( 42 | "--host", 43 | type=str, 44 | default="127.0.0.1", 45 | help="Host to run the app on", 46 | required=False, 47 | ) 48 | run_api_parser.add_argument( 49 | "--device", 50 | type=str, 51 | required=False, 52 | help="Device to use, e.g. cpu, cuda, cuda:0, mps (for m1 mac) etc.", 53 | ) 54 | run_api_parser.add_argument( 55 | "--workers", 56 | type=int, 57 | required=False, 58 | default=1, 59 | help="Number of workers to use", 60 | ) 61 | run_api_parser.add_argument( 62 | "--ssl_certfile", 63 | type=str, 64 | required=False, 65 | help="the path to your ssl cert", 66 | ) 67 | run_api_parser.add_argument( 68 | "--ssl_keyfile", 69 | type=str, 70 | required=False, 71 | help="the path to your ssl key", 72 | ) 73 | run_api_parser.set_defaults(func=run_api_command_factory) 74 | 75 | def __init__(self, output, port, host, device, workers, ssl_certfile, ssl_keyfile): 76 | self.output = output 77 | self.port = port 78 | self.host = host 79 | self.device = device 80 | self.workers = workers 81 | self.ssl_certfile = ssl_certfile 82 | self.ssl_keyfile = ssl_keyfile 83 | 84 | if self.device is None: 85 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 86 | 87 | self.port = str(self.port) 88 | self.workers = str(self.workers) 89 | 90 | def run(self): 91 | cmd = [ 92 | "uvicorn", 93 | "stablefusion.api.main:app", 94 | "--host", 95 | self.host, 96 | "--port", 97 | self.port, 98 | "--workers", 99 | self.workers, 100 | ] 101 | if self.ssl_certfile is not None: 102 | cmd.extend(["--ssl-certfile", self.ssl_certfile]) 103 | if self.ssl_keyfile is not None: 104 | cmd.extend(["--ssl-keyfile", self.ssl_keyfile]) 105 | 106 | proc = subprocess.Popen( 107 | cmd, 108 | stdout=subprocess.PIPE, 109 | stderr=subprocess.STDOUT, 110 | shell=False, 111 | universal_newlines=True, 112 | bufsize=1, 113 | ) 114 | with proc as p: 115 | try: 116 | for line in p.stdout: 117 | print(line, end="") 118 | except KeyboardInterrupt: 119 | print("Killing api") 120 | p.kill() 121 | p.wait() 122 | raise 123 | -------------------------------------------------------------------------------- /stablefusion/cli/run_app.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from argparse import ArgumentParser 3 | 4 | import torch 5 | from pyngrok import ngrok 6 | 7 | from . 
import BaseStableFusionCommand 8 | 9 | 10 | def run_app_command_factory(args): 11 | return RunStableFusionAppCommand( 12 | args.output, 13 | args.share, 14 | args.port, 15 | args.host, 16 | args.device, 17 | args.ngrok_key, 18 | ) 19 | 20 | 21 | class RunStableFusionAppCommand(BaseStableFusionCommand): 22 | @staticmethod 23 | def register_subcommand(parser: ArgumentParser): 24 | run_app_parser = parser.add_parser( 25 | "app", 26 | description="✨ Run stablefusion app", 27 | ) 28 | run_app_parser.add_argument( 29 | "--output", 30 | type=str, 31 | required=False, 32 | help="Output path is optional, but if provided, all generations will automatically be saved to this path.", 33 | ) 34 | run_app_parser.add_argument( 35 | "--share", 36 | action="store_true", 37 | help="Share the app", 38 | ) 39 | run_app_parser.add_argument( 40 | "--port", 41 | type=int, 42 | default=10000, 43 | help="Port to run the app on", 44 | required=False, 45 | ) 46 | run_app_parser.add_argument( 47 | "--host", 48 | type=str, 49 | default="127.0.0.1", 50 | help="Host to run the app on", 51 | required=False, 52 | ) 53 | run_app_parser.add_argument( 54 | "--device", 55 | type=str, 56 | required=False, 57 | help="Device to use, e.g. cpu, cuda, cuda:0, mps (for m1 mac) etc.", 58 | ) 59 | run_app_parser.add_argument( 60 | "--ngrok_key", 61 | type=str, 62 | required=False, 63 | help="Ngrok key to use for sharing the app. Only required if you want to share the app", 64 | ) 65 | 66 | run_app_parser.set_defaults(func=run_app_command_factory) 67 | 68 | def __init__(self, output, share, port, host, device, ngrok_key): 69 | self.output = output 70 | self.share = share 71 | self.port = port 72 | self.host = host 73 | self.device = device 74 | self.ngrok_key = ngrok_key 75 | 76 | if self.device is None: 77 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 78 | 79 | if self.share: 80 | if self.ngrok_key is None: 81 | raise ValueError( 82 | "ngrok key is required if you want to share the app. 
Get it for free from https://dashboard.ngrok.com/get-started/your-authtoken" 83 | ) 84 | 85 | def run(self): 86 | # from ..app import stablefusion 87 | 88 | # print(self.share) 89 | # app = stablefusion(self.model, self.output).app() 90 | # app.launch(show_api=False, share=self.share, server_port=self.port, server_name=self.host) 91 | import os 92 | 93 | dirname = os.path.dirname(__file__) 94 | filename = os.path.join(dirname, "..", "Home.py") 95 | cmd = [ 96 | "streamlit", 97 | "run", 98 | filename, 99 | "--browser.gatherUsageStats", 100 | "false", 101 | "--browser.serverAddress", 102 | self.host, 103 | "--server.port", 104 | str(self.port), 105 | "--theme.base", 106 | "dark", 107 | "--", 108 | "--device", 109 | self.device, 110 | ] 111 | if self.output is not None: 112 | cmd.extend(["--output", self.output]) 113 | 114 | if self.share: 115 | ngrok.set_auth_token(self.ngrok_key) 116 | public_url = ngrok.connect(self.port).public_url 117 | print(f"Sharing app at {public_url}") 118 | 119 | proc = subprocess.Popen( 120 | cmd, 121 | stdout=subprocess.PIPE, 122 | stderr=subprocess.STDOUT, 123 | shell=False, 124 | universal_newlines=True, 125 | bufsize=1, 126 | ) 127 | with proc as p: 128 | try: 129 | for line in p.stdout: 130 | print(line, end="") 131 | except KeyboardInterrupt: 132 | print("Killing streamlit app") 133 | p.kill() 134 | if self.share: 135 | print("Killing ngrok tunnel") 136 | ngrok.kill() 137 | p.wait() 138 | raise 139 | -------------------------------------------------------------------------------- /stablefusion/data/mediums.txt: -------------------------------------------------------------------------------- 1 | a 3D render 2 | a black and white photo 3 | a bronze sculpture 4 | a cartoon 5 | a cave painting 6 | a character portrait 7 | a charcoal drawing 8 | a child's drawing 9 | a color pencil sketch 10 | a colorized photo 11 | a comic book panel 12 | a computer rendering 13 | a cross stitch 14 | a cubist painting 15 | a detailed drawing 16 | a detailed matte painting 17 | a detailed painting 18 | a diagram 19 | a digital painting 20 | a digital rendering 21 | a drawing 22 | a fine art painting 23 | a flemish Baroque 24 | a gouache 25 | a hologram 26 | a hyperrealistic painting 27 | a jigsaw puzzle 28 | a low poly render 29 | a macro photograph 30 | a manga drawing 31 | a marble sculpture 32 | a matte painting 33 | a microscopic photo 34 | a mid-nineteenth century engraving 35 | a minimalist painting 36 | a mosaic 37 | a painting 38 | a pastel 39 | a pencil sketch 40 | a photo 41 | a photocopy 42 | a photorealistic painting 43 | a picture 44 | a pointillism painting 45 | a polaroid photo 46 | a pop art painting 47 | a portrait 48 | a poster 49 | a raytraced image 50 | a renaissance painting 51 | a screenprint 52 | a screenshot 53 | a silk screen 54 | a sketch 55 | a statue 56 | a still life 57 | a stipple 58 | a stock photo 59 | a storybook illustration 60 | a surrealist painting 61 | a surrealist sculpture 62 | a tattoo 63 | a tilt shift photo 64 | a watercolor painting 65 | a wireframe diagram 66 | a woodcut 67 | an abstract drawing 68 | an abstract painting 69 | an abstract sculpture 70 | an acrylic painting 71 | an airbrush painting 72 | an album cover 73 | an ambient occlusion render 74 | an anime drawing 75 | an art deco painting 76 | an art deco sculpture 77 | an engraving 78 | an etching 79 | an illustration of 80 | an impressionist painting 81 | an ink drawing 82 | an oil on canvas painting 83 | an oil painting 84 | an ultrafine detailed painting 85 | chalk art 86 | 
computer graphics 87 | concept art 88 | cyberpunk art 89 | digital art 90 | egyptian art 91 | graffiti art 92 | lineart 93 | pixel art 94 | poster art 95 | vector art 96 | -------------------------------------------------------------------------------- /stablefusion/data/movements.txt: -------------------------------------------------------------------------------- 1 | abstract art 2 | abstract expressionism 3 | abstract illusionism 4 | academic art 5 | action painting 6 | aestheticism 7 | afrofuturism 8 | altermodern 9 | american barbizon school 10 | american impressionism 11 | american realism 12 | american romanticism 13 | american scene painting 14 | analytical art 15 | antipodeans 16 | arabesque 17 | arbeitsrat für kunst 18 | art & language 19 | art brut 20 | art deco 21 | art informel 22 | art nouveau 23 | art photography 24 | arte povera 25 | arts and crafts movement 26 | ascii art 27 | ashcan school 28 | assemblage 29 | australian tonalism 30 | auto-destructive art 31 | barbizon school 32 | baroque 33 | bauhaus 34 | bengal school of art 35 | berlin secession 36 | black arts movement 37 | brutalism 38 | classical realism 39 | cloisonnism 40 | cobra 41 | color field 42 | computer art 43 | conceptual art 44 | concrete art 45 | constructivism 46 | context art 47 | crayon art 48 | crystal cubism 49 | cubism 50 | cubo-futurism 51 | cynical realism 52 | dada 53 | danube school 54 | dau-al-set 55 | de stijl 56 | deconstructivism 57 | digital art 58 | ecological art 59 | environmental art 60 | excessivism 61 | expressionism 62 | fantastic realism 63 | fantasy art 64 | fauvism 65 | feminist art 66 | figuration libre 67 | figurative art 68 | figurativism 69 | fine art 70 | fluxus 71 | folk art 72 | funk art 73 | furry art 74 | futurism 75 | generative art 76 | geometric abstract art 77 | german romanticism 78 | gothic art 79 | graffiti 80 | gutai group 81 | happening 82 | harlem renaissance 83 | heidelberg school 84 | holography 85 | hudson river school 86 | hurufiyya 87 | hypermodernism 88 | hyperrealism 89 | impressionism 90 | incoherents 91 | institutional critique 92 | interactive art 93 | international gothic 94 | international typographic style 95 | kinetic art 96 | kinetic pointillism 97 | kitsch movement 98 | land art 99 | les automatistes 100 | les nabis 101 | letterism 102 | light and space 103 | lowbrow 104 | lyco art 105 | lyrical abstraction 106 | magic realism 107 | magical realism 108 | mail art 109 | mannerism 110 | massurrealism 111 | maximalism 112 | metaphysical painting 113 | mingei 114 | minimalism 115 | modern european ink painting 116 | modernism 117 | modular constructivism 118 | naive art 119 | naturalism 120 | neo-dada 121 | neo-expressionism 122 | neo-fauvism 123 | neo-figurative 124 | neo-primitivism 125 | neo-romanticism 126 | neoclassicism 127 | neogeo 128 | neoism 129 | neoplasticism 130 | net art 131 | new objectivity 132 | new sculpture 133 | northwest school 134 | nuclear art 135 | objective abstraction 136 | op art 137 | optical illusion 138 | orphism 139 | panfuturism 140 | paris school 141 | photorealism 142 | pixel art 143 | plasticien 144 | plein air 145 | pointillism 146 | pop art 147 | pop surrealism 148 | post-impressionism 149 | postminimalism 150 | pre-raphaelitism 151 | precisionism 152 | primitivism 153 | private press 154 | process art 155 | psychedelic art 156 | purism 157 | qajar art 158 | quito school 159 | rasquache 160 | rayonism 161 | realism 162 | regionalism 163 | remodernism 164 | renaissance 165 | retrofuturism 166 | rococo 167 | 
romanesque 168 | romanticism 169 | samikshavad 170 | serial art 171 | shin hanga 172 | shock art 173 | socialist realism 174 | sots art 175 | space art 176 | street art 177 | stuckism 178 | sumatraism 179 | superflat 180 | suprematism 181 | surrealism 182 | symbolism 183 | synchromism 184 | synthetism 185 | sōsaku hanga 186 | tachisme 187 | temporary art 188 | tonalism 189 | toyism 190 | transgressive art 191 | ukiyo-e 192 | underground comix 193 | unilalianism 194 | vancouver school 195 | vanitas 196 | verdadism 197 | video art 198 | viennese actionism 199 | visual art 200 | vorticism 201 | -------------------------------------------------------------------------------- /stablefusion/model_list.txt: -------------------------------------------------------------------------------- 1 | ['runwayml/stable-diffusion-v1-5', 'stabilityai/stable-diffusion-2-1', 'Linaqruf/anything-v3.0', 'Envvi/Inkpunk-Diffusion', 'prompthero/openjourney', 'darkstorm2150/Protogen_v2.2_Official_Release', 'darkstorm2150/Protogen_x3.4_Official_Release'] -------------------------------------------------------------------------------- /stablefusion/models/ckpt_interference/t: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralRealm/StableFusion/9268d686f49b45194a2251c6728e74191f298654/stablefusion/models/ckpt_interference/t -------------------------------------------------------------------------------- /stablefusion/models/cpkt_models/t: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralRealm/StableFusion/9268d686f49b45194a2251c6728e74191f298654/stablefusion/models/cpkt_models/t -------------------------------------------------------------------------------- /stablefusion/models/diffusion_models/t.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralRealm/StableFusion/9268d686f49b45194a2251c6728e74191f298654/stablefusion/models/diffusion_models/t.txt -------------------------------------------------------------------------------- /stablefusion/models/realesrgan/inference_realesrgan.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import glob 4 | import os 5 | from basicsr.archs.rrdbnet_arch import RRDBNet 6 | from basicsr.utils.download_util import load_file_from_url 7 | import tempfile 8 | from realesrgan import RealESRGANer 9 | from realesrgan.archs.srvgg_arch import SRVGGNetCompact 10 | import numpy as np 11 | from stablefusion import utils 12 | import streamlit as st 13 | from PIL import Image 14 | from io import BytesIO 15 | import base64 16 | import datetime 17 | 18 | def main(model_name, denoise_strength, tile, tile_pad, pre_pad, fp32, gpu_id, face_enhance, outscale, input_image, model_path): 19 | # determine models according to model names 20 | model_name = model_name.split('.')[0] 21 | if model_name == 'RealESRGAN_x4plus': # x4 RRDBNet model 22 | model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) 23 | netscale = 4 24 | file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'] 25 | elif model_name == 'RealESRNet_x4plus': # x4 RRDBNet model 26 | model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) 27 | netscale = 4 28 | file_url = 
['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth'] 29 | elif model_name == 'RealESRGAN_x4plus_anime_6B': # x4 RRDBNet model with 6 blocks 30 | model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) 31 | netscale = 4 32 | file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth'] 33 | elif model_name == 'RealESRGAN_x2plus': # x2 RRDBNet model 34 | model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) 35 | netscale = 2 36 | file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth'] 37 | elif model_name == 'realesr-animevideov3': # x4 VGG-style model (XS size) 38 | model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu') 39 | netscale = 4 40 | file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth'] 41 | elif model_name == 'realesr-general-x4v3': # x4 VGG-style model (S size) 42 | model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') 43 | netscale = 4 44 | file_url = [ 45 | 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth', 46 | 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth' 47 | ] 48 | 49 | # determine model paths 50 | if model_path is not None: 51 | model_path = model_path 52 | else: 53 | model_path = os.path.join('weights', model_name + '.pth') 54 | if not os.path.isfile(model_path): 55 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 56 | for url in file_url: 57 | # model_path will be updated 58 | model_path = load_file_from_url( 59 | url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None) 60 | 61 | # use dni to control the denoise strength 62 | dni_weight = None 63 | if model_name == 'realesr-general-x4v3' and denoise_strength != 1: 64 | wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3') 65 | model_path = [model_path, wdn_model_path] 66 | dni_weight = [denoise_strength, 1 - denoise_strength] 67 | 68 | # restorer 69 | upsampler = RealESRGANer( 70 | scale=netscale, 71 | model_path=model_path, 72 | dni_weight=dni_weight, 73 | model=model, 74 | tile=tile, 75 | tile_pad=tile_pad, 76 | pre_pad=pre_pad, 77 | half=not fp32, 78 | gpu_id=gpu_id) 79 | 80 | if face_enhance: # Use GFPGAN for face enhancement 81 | from gfpgan import GFPGANer 82 | face_enhancer = GFPGANer( 83 | model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth', 84 | upscale=outscale, 85 | arch='clean', 86 | channel_multiplier=2, 87 | bg_upsampler=upsampler) 88 | 89 | 90 | 91 | image_array = np.asarray(bytearray(input_image.read()), dtype=np.uint8) 92 | 93 | img = cv2.imdecode(image_array, cv2.IMREAD_UNCHANGED) 94 | 95 | 96 | try: 97 | if face_enhance: 98 | _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True) 99 | else: 100 | output, _ = upsampler.enhance(img, outscale=outscale) 101 | 102 | 103 | img = cv2.cvtColor(output, cv2.COLOR_BGR2RGB) 104 | st.image(img) 105 | 106 | output_img = Image.fromarray(img) 107 | buffered = BytesIO() 108 | output_img.save(buffered, format="PNG") 109 | img_str = base64.b64encode(buffered.getvalue()).decode() 110 | now = datetime.datetime.now() 111 | formatted_date_time = 
now.strftime("%Y-%m-%d_%H_%M_%S") 112 | href = f'<a href="data:file/png;base64,{img_str}" download="{formatted_date_time}.png">Download Image</a>
' 113 | st.markdown(href, unsafe_allow_html=True) 114 | 115 | 116 | except RuntimeError as error: 117 | print('Error', error) 118 | print('If you encounter CUDA out of memory, try to set --tile with a smaller number.') 119 | 120 | 121 | 122 | if __name__ == '__main__': 123 | main() 124 | -------------------------------------------------------------------------------- /stablefusion/models/realesrgan/options/finetune_realesrgan_x4plus.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: finetune_RealESRGANx4plus_400k 3 | model_type: RealESRGANModel 4 | scale: 4 5 | num_gpu: auto 6 | manual_seed: 0 7 | 8 | # ----------------- options for synthesizing training data in RealESRGANModel ----------------- # 9 | # USM the ground-truth 10 | l1_gt_usm: True 11 | percep_gt_usm: True 12 | gan_gt_usm: False 13 | 14 | # the first degradation process 15 | resize_prob: [0.2, 0.7, 0.1] # up, down, keep 16 | resize_range: [0.15, 1.5] 17 | gaussian_noise_prob: 0.5 18 | noise_range: [1, 30] 19 | poisson_scale_range: [0.05, 3] 20 | gray_noise_prob: 0.4 21 | jpeg_range: [30, 95] 22 | 23 | # the second degradation process 24 | second_blur_prob: 0.8 25 | resize_prob2: [0.3, 0.4, 0.3] # up, down, keep 26 | resize_range2: [0.3, 1.2] 27 | gaussian_noise_prob2: 0.5 28 | noise_range2: [1, 25] 29 | poisson_scale_range2: [0.05, 2.5] 30 | gray_noise_prob2: 0.4 31 | jpeg_range2: [30, 95] 32 | 33 | gt_size: 256 34 | queue_size: 180 35 | 36 | # dataset and data loader settings 37 | datasets: 38 | train: 39 | name: DF2K+OST 40 | type: RealESRGANDataset 41 | dataroot_gt: datasets/DF2K 42 | meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt 43 | io_backend: 44 | type: disk 45 | 46 | blur_kernel_size: 21 47 | kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] 48 | kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] 49 | sinc_prob: 0.1 50 | blur_sigma: [0.2, 3] 51 | betag_range: [0.5, 4] 52 | betap_range: [1, 2] 53 | 54 | blur_kernel_size2: 21 55 | kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] 56 | kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] 57 | sinc_prob2: 0.1 58 | blur_sigma2: [0.2, 1.5] 59 | betag_range2: [0.5, 4] 60 | betap_range2: [1, 2] 61 | 62 | final_sinc_prob: 0.8 63 | 64 | gt_size: 256 65 | use_hflip: True 66 | use_rot: False 67 | 68 | # data loader 69 | use_shuffle: true 70 | num_worker_per_gpu: 5 71 | batch_size_per_gpu: 12 72 | dataset_enlarge_ratio: 1 73 | prefetch_mode: ~ 74 | 75 | # Uncomment these for validation 76 | # val: 77 | # name: validation 78 | # type: PairedImageDataset 79 | # dataroot_gt: path_to_gt 80 | # dataroot_lq: path_to_lq 81 | # io_backend: 82 | # type: disk 83 | 84 | # network structures 85 | network_g: 86 | type: RRDBNet 87 | num_in_ch: 3 88 | num_out_ch: 3 89 | num_feat: 64 90 | num_block: 23 91 | num_grow_ch: 32 92 | 93 | network_d: 94 | type: UNetDiscriminatorSN 95 | num_in_ch: 3 96 | num_feat: 64 97 | skip_connection: True 98 | 99 | # path 100 | path: 101 | # use the pre-trained Real-ESRNet model 102 | pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth 103 | param_key_g: params_ema 104 | strict_load_g: true 105 | pretrain_network_d: experiments/pretrained_models/RealESRGAN_x4plus_netD.pth 106 | param_key_d: params 107 | strict_load_d: true 108 | resume_state: ~ 109 | 110 | # training settings 111 | train: 112 | ema_decay: 0.999 113 | optim_g: 114 | type: Adam 115 | lr: 
!!float 1e-4 116 | weight_decay: 0 117 | betas: [0.9, 0.99] 118 | optim_d: 119 | type: Adam 120 | lr: !!float 1e-4 121 | weight_decay: 0 122 | betas: [0.9, 0.99] 123 | 124 | scheduler: 125 | type: MultiStepLR 126 | milestones: [400000] 127 | gamma: 0.5 128 | 129 | total_iter: 400000 130 | warmup_iter: -1 # no warm up 131 | 132 | # losses 133 | pixel_opt: 134 | type: L1Loss 135 | loss_weight: 1.0 136 | reduction: mean 137 | # perceptual loss (content and style losses) 138 | perceptual_opt: 139 | type: PerceptualLoss 140 | layer_weights: 141 | # before relu 142 | 'conv1_2': 0.1 143 | 'conv2_2': 0.1 144 | 'conv3_4': 1 145 | 'conv4_4': 1 146 | 'conv5_4': 1 147 | vgg_type: vgg19 148 | use_input_norm: true 149 | perceptual_weight: !!float 1.0 150 | style_weight: 0 151 | range_norm: false 152 | criterion: l1 153 | # gan loss 154 | gan_opt: 155 | type: GANLoss 156 | gan_type: vanilla 157 | real_label_val: 1.0 158 | fake_label_val: 0.0 159 | loss_weight: !!float 1e-1 160 | 161 | net_d_iters: 1 162 | net_d_init_iters: 0 163 | 164 | # Uncomment these for validation 165 | # validation settings 166 | # val: 167 | # val_freq: !!float 5e3 168 | # save_img: True 169 | 170 | # metrics: 171 | # psnr: # metric name 172 | # type: calculate_psnr 173 | # crop_border: 4 174 | # test_y_channel: false 175 | 176 | # logging settings 177 | logger: 178 | print_freq: 100 179 | save_checkpoint_freq: !!float 5e3 180 | use_tb_logger: true 181 | wandb: 182 | project: ~ 183 | resume_id: ~ 184 | 185 | # dist training settings 186 | dist_params: 187 | backend: nccl 188 | port: 29500 189 | -------------------------------------------------------------------------------- /stablefusion/models/realesrgan/options/finetune_realesrgan_x4plus_pairdata.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: finetune_RealESRGANx4plus_400k_pairdata 3 | model_type: RealESRGANModel 4 | scale: 4 5 | num_gpu: auto 6 | manual_seed: 0 7 | 8 | # USM the ground-truth 9 | l1_gt_usm: True 10 | percep_gt_usm: True 11 | gan_gt_usm: False 12 | 13 | high_order_degradation: False # do not use the high-order degradation generation process 14 | 15 | # dataset and data loader settings 16 | datasets: 17 | train: 18 | name: DIV2K 19 | type: RealESRGANPairedDataset 20 | dataroot_gt: datasets/DF2K 21 | dataroot_lq: datasets/DF2K 22 | meta_info: datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt 23 | io_backend: 24 | type: disk 25 | 26 | gt_size: 256 27 | use_hflip: True 28 | use_rot: False 29 | 30 | # data loader 31 | use_shuffle: true 32 | num_worker_per_gpu: 5 33 | batch_size_per_gpu: 12 34 | dataset_enlarge_ratio: 1 35 | prefetch_mode: ~ 36 | 37 | # Uncomment these for validation 38 | # val: 39 | # name: validation 40 | # type: PairedImageDataset 41 | # dataroot_gt: path_to_gt 42 | # dataroot_lq: path_to_lq 43 | # io_backend: 44 | # type: disk 45 | 46 | # network structures 47 | network_g: 48 | type: RRDBNet 49 | num_in_ch: 3 50 | num_out_ch: 3 51 | num_feat: 64 52 | num_block: 23 53 | num_grow_ch: 32 54 | 55 | network_d: 56 | type: UNetDiscriminatorSN 57 | num_in_ch: 3 58 | num_feat: 64 59 | skip_connection: True 60 | 61 | # path 62 | path: 63 | # use the pre-trained Real-ESRNet model 64 | pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth 65 | param_key_g: params_ema 66 | strict_load_g: true 67 | pretrain_network_d: experiments/pretrained_models/RealESRGAN_x4plus_netD.pth 68 | param_key_d: params 69 | strict_load_d: true 70 | resume_state: ~ 71 | 72 | # training 
settings 73 | train: 74 | ema_decay: 0.999 75 | optim_g: 76 | type: Adam 77 | lr: !!float 1e-4 78 | weight_decay: 0 79 | betas: [0.9, 0.99] 80 | optim_d: 81 | type: Adam 82 | lr: !!float 1e-4 83 | weight_decay: 0 84 | betas: [0.9, 0.99] 85 | 86 | scheduler: 87 | type: MultiStepLR 88 | milestones: [400000] 89 | gamma: 0.5 90 | 91 | total_iter: 400000 92 | warmup_iter: -1 # no warm up 93 | 94 | # losses 95 | pixel_opt: 96 | type: L1Loss 97 | loss_weight: 1.0 98 | reduction: mean 99 | # perceptual loss (content and style losses) 100 | perceptual_opt: 101 | type: PerceptualLoss 102 | layer_weights: 103 | # before relu 104 | 'conv1_2': 0.1 105 | 'conv2_2': 0.1 106 | 'conv3_4': 1 107 | 'conv4_4': 1 108 | 'conv5_4': 1 109 | vgg_type: vgg19 110 | use_input_norm: true 111 | perceptual_weight: !!float 1.0 112 | style_weight: 0 113 | range_norm: false 114 | criterion: l1 115 | # gan loss 116 | gan_opt: 117 | type: GANLoss 118 | gan_type: vanilla 119 | real_label_val: 1.0 120 | fake_label_val: 0.0 121 | loss_weight: !!float 1e-1 122 | 123 | net_d_iters: 1 124 | net_d_init_iters: 0 125 | 126 | # Uncomment these for validation 127 | # validation settings 128 | # val: 129 | # val_freq: !!float 5e3 130 | # save_img: True 131 | 132 | # metrics: 133 | # psnr: # metric name 134 | # type: calculate_psnr 135 | # crop_border: 4 136 | # test_y_channel: false 137 | 138 | # logging settings 139 | logger: 140 | print_freq: 100 141 | save_checkpoint_freq: !!float 5e3 142 | use_tb_logger: true 143 | wandb: 144 | project: ~ 145 | resume_id: ~ 146 | 147 | # dist training settings 148 | dist_params: 149 | backend: nccl 150 | port: 29500 151 | -------------------------------------------------------------------------------- /stablefusion/models/realesrgan/options/train_realesrgan_x2plus.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: train_RealESRGANx2plus_400k_B12G4 3 | model_type: RealESRGANModel 4 | scale: 2 5 | num_gpu: auto # auto: can infer from your visible devices automatically. 
official: 4 GPUs 6 | manual_seed: 0 7 | 8 | # ----------------- options for synthesizing training data in RealESRGANModel ----------------- # 9 | # USM the ground-truth 10 | l1_gt_usm: True 11 | percep_gt_usm: True 12 | gan_gt_usm: False 13 | 14 | # the first degradation process 15 | resize_prob: [0.2, 0.7, 0.1] # up, down, keep 16 | resize_range: [0.15, 1.5] 17 | gaussian_noise_prob: 0.5 18 | noise_range: [1, 30] 19 | poisson_scale_range: [0.05, 3] 20 | gray_noise_prob: 0.4 21 | jpeg_range: [30, 95] 22 | 23 | # the second degradation process 24 | second_blur_prob: 0.8 25 | resize_prob2: [0.3, 0.4, 0.3] # up, down, keep 26 | resize_range2: [0.3, 1.2] 27 | gaussian_noise_prob2: 0.5 28 | noise_range2: [1, 25] 29 | poisson_scale_range2: [0.05, 2.5] 30 | gray_noise_prob2: 0.4 31 | jpeg_range2: [30, 95] 32 | 33 | gt_size: 256 34 | queue_size: 180 35 | 36 | # dataset and data loader settings 37 | datasets: 38 | train: 39 | name: DF2K+OST 40 | type: RealESRGANDataset 41 | dataroot_gt: datasets/DF2K 42 | meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt 43 | io_backend: 44 | type: disk 45 | 46 | blur_kernel_size: 21 47 | kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] 48 | kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] 49 | sinc_prob: 0.1 50 | blur_sigma: [0.2, 3] 51 | betag_range: [0.5, 4] 52 | betap_range: [1, 2] 53 | 54 | blur_kernel_size2: 21 55 | kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] 56 | kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] 57 | sinc_prob2: 0.1 58 | blur_sigma2: [0.2, 1.5] 59 | betag_range2: [0.5, 4] 60 | betap_range2: [1, 2] 61 | 62 | final_sinc_prob: 0.8 63 | 64 | gt_size: 256 65 | use_hflip: True 66 | use_rot: False 67 | 68 | # data loader 69 | use_shuffle: true 70 | num_worker_per_gpu: 5 71 | batch_size_per_gpu: 12 72 | dataset_enlarge_ratio: 1 73 | prefetch_mode: ~ 74 | 75 | # Uncomment these for validation 76 | # val: 77 | # name: validation 78 | # type: PairedImageDataset 79 | # dataroot_gt: path_to_gt 80 | # dataroot_lq: path_to_lq 81 | # io_backend: 82 | # type: disk 83 | 84 | # network structures 85 | network_g: 86 | type: RRDBNet 87 | num_in_ch: 3 88 | num_out_ch: 3 89 | num_feat: 64 90 | num_block: 23 91 | num_grow_ch: 32 92 | scale: 2 93 | 94 | network_d: 95 | type: UNetDiscriminatorSN 96 | num_in_ch: 3 97 | num_feat: 64 98 | skip_connection: True 99 | 100 | # path 101 | path: 102 | # use the pre-trained Real-ESRNet model 103 | pretrain_network_g: experiments/pretrained_models/RealESRNet_x2plus.pth 104 | param_key_g: params_ema 105 | strict_load_g: true 106 | resume_state: ~ 107 | 108 | # training settings 109 | train: 110 | ema_decay: 0.999 111 | optim_g: 112 | type: Adam 113 | lr: !!float 1e-4 114 | weight_decay: 0 115 | betas: [0.9, 0.99] 116 | optim_d: 117 | type: Adam 118 | lr: !!float 1e-4 119 | weight_decay: 0 120 | betas: [0.9, 0.99] 121 | 122 | scheduler: 123 | type: MultiStepLR 124 | milestones: [400000] 125 | gamma: 0.5 126 | 127 | total_iter: 400000 128 | warmup_iter: -1 # no warm up 129 | 130 | # losses 131 | pixel_opt: 132 | type: L1Loss 133 | loss_weight: 1.0 134 | reduction: mean 135 | # perceptual loss (content and style losses) 136 | perceptual_opt: 137 | type: PerceptualLoss 138 | layer_weights: 139 | # before relu 140 | 'conv1_2': 0.1 141 | 'conv2_2': 0.1 142 | 'conv3_4': 1 143 | 'conv4_4': 1 144 | 'conv5_4': 1 145 | vgg_type: vgg19 146 | use_input_norm: true 147 | perceptual_weight: !!float 1.0 148 | 
style_weight: 0 149 | range_norm: false 150 | criterion: l1 151 | # gan loss 152 | gan_opt: 153 | type: GANLoss 154 | gan_type: vanilla 155 | real_label_val: 1.0 156 | fake_label_val: 0.0 157 | loss_weight: !!float 1e-1 158 | 159 | net_d_iters: 1 160 | net_d_init_iters: 0 161 | 162 | # Uncomment these for validation 163 | # validation settings 164 | # val: 165 | # val_freq: !!float 5e3 166 | # save_img: True 167 | 168 | # metrics: 169 | # psnr: # metric name 170 | # type: calculate_psnr 171 | # crop_border: 4 172 | # test_y_channel: false 173 | 174 | # logging settings 175 | logger: 176 | print_freq: 100 177 | save_checkpoint_freq: !!float 5e3 178 | use_tb_logger: true 179 | wandb: 180 | project: ~ 181 | resume_id: ~ 182 | 183 | # dist training settings 184 | dist_params: 185 | backend: nccl 186 | port: 29500 187 | -------------------------------------------------------------------------------- /stablefusion/models/realesrgan/options/train_realesrgan_x4plus.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: train_RealESRGANx4plus_400k_B12G4 3 | model_type: RealESRGANModel 4 | scale: 4 5 | num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs 6 | manual_seed: 0 7 | 8 | # ----------------- options for synthesizing training data in RealESRGANModel ----------------- # 9 | # USM the ground-truth 10 | l1_gt_usm: True 11 | percep_gt_usm: True 12 | gan_gt_usm: False 13 | 14 | # the first degradation process 15 | resize_prob: [0.2, 0.7, 0.1] # up, down, keep 16 | resize_range: [0.15, 1.5] 17 | gaussian_noise_prob: 0.5 18 | noise_range: [1, 30] 19 | poisson_scale_range: [0.05, 3] 20 | gray_noise_prob: 0.4 21 | jpeg_range: [30, 95] 22 | 23 | # the second degradation process 24 | second_blur_prob: 0.8 25 | resize_prob2: [0.3, 0.4, 0.3] # up, down, keep 26 | resize_range2: [0.3, 1.2] 27 | gaussian_noise_prob2: 0.5 28 | noise_range2: [1, 25] 29 | poisson_scale_range2: [0.05, 2.5] 30 | gray_noise_prob2: 0.4 31 | jpeg_range2: [30, 95] 32 | 33 | gt_size: 256 34 | queue_size: 180 35 | 36 | # dataset and data loader settings 37 | datasets: 38 | train: 39 | name: DF2K+OST 40 | type: RealESRGANDataset 41 | dataroot_gt: datasets/DF2K 42 | meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt 43 | io_backend: 44 | type: disk 45 | 46 | blur_kernel_size: 21 47 | kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] 48 | kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] 49 | sinc_prob: 0.1 50 | blur_sigma: [0.2, 3] 51 | betag_range: [0.5, 4] 52 | betap_range: [1, 2] 53 | 54 | blur_kernel_size2: 21 55 | kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] 56 | kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] 57 | sinc_prob2: 0.1 58 | blur_sigma2: [0.2, 1.5] 59 | betag_range2: [0.5, 4] 60 | betap_range2: [1, 2] 61 | 62 | final_sinc_prob: 0.8 63 | 64 | gt_size: 256 65 | use_hflip: True 66 | use_rot: False 67 | 68 | # data loader 69 | use_shuffle: true 70 | num_worker_per_gpu: 5 71 | batch_size_per_gpu: 12 72 | dataset_enlarge_ratio: 1 73 | prefetch_mode: ~ 74 | 75 | # Uncomment these for validation 76 | # val: 77 | # name: validation 78 | # type: PairedImageDataset 79 | # dataroot_gt: path_to_gt 80 | # dataroot_lq: path_to_lq 81 | # io_backend: 82 | # type: disk 83 | 84 | # network structures 85 | network_g: 86 | type: RRDBNet 87 | num_in_ch: 3 88 | num_out_ch: 3 89 | num_feat: 64 90 | num_block: 23 
91 | num_grow_ch: 32 92 | 93 | network_d: 94 | type: UNetDiscriminatorSN 95 | num_in_ch: 3 96 | num_feat: 64 97 | skip_connection: True 98 | 99 | # path 100 | path: 101 | # use the pre-trained Real-ESRNet model 102 | pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth 103 | param_key_g: params_ema 104 | strict_load_g: true 105 | resume_state: ~ 106 | 107 | # training settings 108 | train: 109 | ema_decay: 0.999 110 | optim_g: 111 | type: Adam 112 | lr: !!float 1e-4 113 | weight_decay: 0 114 | betas: [0.9, 0.99] 115 | optim_d: 116 | type: Adam 117 | lr: !!float 1e-4 118 | weight_decay: 0 119 | betas: [0.9, 0.99] 120 | 121 | scheduler: 122 | type: MultiStepLR 123 | milestones: [400000] 124 | gamma: 0.5 125 | 126 | total_iter: 400000 127 | warmup_iter: -1 # no warm up 128 | 129 | # losses 130 | pixel_opt: 131 | type: L1Loss 132 | loss_weight: 1.0 133 | reduction: mean 134 | # perceptual loss (content and style losses) 135 | perceptual_opt: 136 | type: PerceptualLoss 137 | layer_weights: 138 | # before relu 139 | 'conv1_2': 0.1 140 | 'conv2_2': 0.1 141 | 'conv3_4': 1 142 | 'conv4_4': 1 143 | 'conv5_4': 1 144 | vgg_type: vgg19 145 | use_input_norm: true 146 | perceptual_weight: !!float 1.0 147 | style_weight: 0 148 | range_norm: false 149 | criterion: l1 150 | # gan loss 151 | gan_opt: 152 | type: GANLoss 153 | gan_type: vanilla 154 | real_label_val: 1.0 155 | fake_label_val: 0.0 156 | loss_weight: !!float 1e-1 157 | 158 | net_d_iters: 1 159 | net_d_init_iters: 0 160 | 161 | # Uncomment these for validation 162 | # validation settings 163 | # val: 164 | # val_freq: !!float 5e3 165 | # save_img: True 166 | 167 | # metrics: 168 | # psnr: # metric name 169 | # type: calculate_psnr 170 | # crop_border: 4 171 | # test_y_channel: false 172 | 173 | # logging settings 174 | logger: 175 | print_freq: 100 176 | save_checkpoint_freq: !!float 5e3 177 | use_tb_logger: true 178 | wandb: 179 | project: ~ 180 | resume_id: ~ 181 | 182 | # dist training settings 183 | dist_params: 184 | backend: nccl 185 | port: 29500 186 | -------------------------------------------------------------------------------- /stablefusion/models/realesrgan/options/train_realesrnet_x2plus.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: train_RealESRNetx2plus_1000k_B12G4 3 | model_type: RealESRNetModel 4 | scale: 2 5 | num_gpu: auto # auto: can infer from your visible devices automatically. 
official: 4 GPUs 6 | manual_seed: 0 7 | 8 | # ----------------- options for synthesizing training data in RealESRNetModel ----------------- # 9 | gt_usm: True # USM the ground-truth 10 | 11 | # the first degradation process 12 | resize_prob: [0.2, 0.7, 0.1] # up, down, keep 13 | resize_range: [0.15, 1.5] 14 | gaussian_noise_prob: 0.5 15 | noise_range: [1, 30] 16 | poisson_scale_range: [0.05, 3] 17 | gray_noise_prob: 0.4 18 | jpeg_range: [30, 95] 19 | 20 | # the second degradation process 21 | second_blur_prob: 0.8 22 | resize_prob2: [0.3, 0.4, 0.3] # up, down, keep 23 | resize_range2: [0.3, 1.2] 24 | gaussian_noise_prob2: 0.5 25 | noise_range2: [1, 25] 26 | poisson_scale_range2: [0.05, 2.5] 27 | gray_noise_prob2: 0.4 28 | jpeg_range2: [30, 95] 29 | 30 | gt_size: 256 31 | queue_size: 180 32 | 33 | # dataset and data loader settings 34 | datasets: 35 | train: 36 | name: DF2K+OST 37 | type: RealESRGANDataset 38 | dataroot_gt: datasets/DF2K 39 | meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt 40 | io_backend: 41 | type: disk 42 | 43 | blur_kernel_size: 21 44 | kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] 45 | kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] 46 | sinc_prob: 0.1 47 | blur_sigma: [0.2, 3] 48 | betag_range: [0.5, 4] 49 | betap_range: [1, 2] 50 | 51 | blur_kernel_size2: 21 52 | kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] 53 | kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] 54 | sinc_prob2: 0.1 55 | blur_sigma2: [0.2, 1.5] 56 | betag_range2: [0.5, 4] 57 | betap_range2: [1, 2] 58 | 59 | final_sinc_prob: 0.8 60 | 61 | gt_size: 256 62 | use_hflip: True 63 | use_rot: False 64 | 65 | # data loader 66 | use_shuffle: true 67 | num_worker_per_gpu: 5 68 | batch_size_per_gpu: 12 69 | dataset_enlarge_ratio: 1 70 | prefetch_mode: ~ 71 | 72 | # Uncomment these for validation 73 | # val: 74 | # name: validation 75 | # type: PairedImageDataset 76 | # dataroot_gt: path_to_gt 77 | # dataroot_lq: path_to_lq 78 | # io_backend: 79 | # type: disk 80 | 81 | # network structures 82 | network_g: 83 | type: RRDBNet 84 | num_in_ch: 3 85 | num_out_ch: 3 86 | num_feat: 64 87 | num_block: 23 88 | num_grow_ch: 32 89 | scale: 2 90 | 91 | # path 92 | path: 93 | pretrain_network_g: experiments/pretrained_models/RealESRGAN_x4plus.pth 94 | param_key_g: params_ema 95 | strict_load_g: False 96 | resume_state: ~ 97 | 98 | # training settings 99 | train: 100 | ema_decay: 0.999 101 | optim_g: 102 | type: Adam 103 | lr: !!float 2e-4 104 | weight_decay: 0 105 | betas: [0.9, 0.99] 106 | 107 | scheduler: 108 | type: MultiStepLR 109 | milestones: [1000000] 110 | gamma: 0.5 111 | 112 | total_iter: 1000000 113 | warmup_iter: -1 # no warm up 114 | 115 | # losses 116 | pixel_opt: 117 | type: L1Loss 118 | loss_weight: 1.0 119 | reduction: mean 120 | 121 | # Uncomment these for validation 122 | # validation settings 123 | # val: 124 | # val_freq: !!float 5e3 125 | # save_img: True 126 | 127 | # metrics: 128 | # psnr: # metric name 129 | # type: calculate_psnr 130 | # crop_border: 4 131 | # test_y_channel: false 132 | 133 | # logging settings 134 | logger: 135 | print_freq: 100 136 | save_checkpoint_freq: !!float 5e3 137 | use_tb_logger: true 138 | wandb: 139 | project: ~ 140 | resume_id: ~ 141 | 142 | # dist training settings 143 | dist_params: 144 | backend: nccl 145 | port: 29500 146 | -------------------------------------------------------------------------------- 
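A minimal usage sketch for the option files in this folder (assumed from upstream Real-ESRGAN/basicsr conventions rather than this repository's own docs: -opt, --launcher and --auto_resume are flags parsed by basicsr's train_pipeline, which the bundled realesrgan/train.py shown later in this dump hands control to). Run from stablefusion/models/realesrgan/ so the realesrgan package resolves:

# single GPU
python realesrgan/train.py -opt options/train_realesrnet_x2plus.yml --auto_resume

# the "official: 4 GPUs" setup noted in the configs, matching the dist_params (nccl backend, port 29500) at the bottom of each option file
python -m torch.distributed.launch --nproc_per_node=4 --master_port=29500 realesrgan/train.py -opt options/train_realesrnet_x2plus.yml --launcher pytorch --auto_resume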
/stablefusion/models/realesrgan/options/train_realesrnet_x4plus.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: train_RealESRNetx4plus_1000k_B12G4 3 | model_type: RealESRNetModel 4 | scale: 4 5 | num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs 6 | manual_seed: 0 7 | 8 | # ----------------- options for synthesizing training data in RealESRNetModel ----------------- # 9 | gt_usm: True # USM the ground-truth 10 | 11 | # the first degradation process 12 | resize_prob: [0.2, 0.7, 0.1] # up, down, keep 13 | resize_range: [0.15, 1.5] 14 | gaussian_noise_prob: 0.5 15 | noise_range: [1, 30] 16 | poisson_scale_range: [0.05, 3] 17 | gray_noise_prob: 0.4 18 | jpeg_range: [30, 95] 19 | 20 | # the second degradation process 21 | second_blur_prob: 0.8 22 | resize_prob2: [0.3, 0.4, 0.3] # up, down, keep 23 | resize_range2: [0.3, 1.2] 24 | gaussian_noise_prob2: 0.5 25 | noise_range2: [1, 25] 26 | poisson_scale_range2: [0.05, 2.5] 27 | gray_noise_prob2: 0.4 28 | jpeg_range2: [30, 95] 29 | 30 | gt_size: 256 31 | queue_size: 180 32 | 33 | # dataset and data loader settings 34 | datasets: 35 | train: 36 | name: DF2K+OST 37 | type: RealESRGANDataset 38 | dataroot_gt: datasets/DF2K 39 | meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt 40 | io_backend: 41 | type: disk 42 | 43 | blur_kernel_size: 21 44 | kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] 45 | kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] 46 | sinc_prob: 0.1 47 | blur_sigma: [0.2, 3] 48 | betag_range: [0.5, 4] 49 | betap_range: [1, 2] 50 | 51 | blur_kernel_size2: 21 52 | kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] 53 | kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] 54 | sinc_prob2: 0.1 55 | blur_sigma2: [0.2, 1.5] 56 | betag_range2: [0.5, 4] 57 | betap_range2: [1, 2] 58 | 59 | final_sinc_prob: 0.8 60 | 61 | gt_size: 256 62 | use_hflip: True 63 | use_rot: False 64 | 65 | # data loader 66 | use_shuffle: true 67 | num_worker_per_gpu: 5 68 | batch_size_per_gpu: 12 69 | dataset_enlarge_ratio: 1 70 | prefetch_mode: ~ 71 | 72 | # Uncomment these for validation 73 | # val: 74 | # name: validation 75 | # type: PairedImageDataset 76 | # dataroot_gt: path_to_gt 77 | # dataroot_lq: path_to_lq 78 | # io_backend: 79 | # type: disk 80 | 81 | # network structures 82 | network_g: 83 | type: RRDBNet 84 | num_in_ch: 3 85 | num_out_ch: 3 86 | num_feat: 64 87 | num_block: 23 88 | num_grow_ch: 32 89 | 90 | # path 91 | path: 92 | pretrain_network_g: experiments/pretrained_models/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth 93 | param_key_g: params_ema 94 | strict_load_g: true 95 | resume_state: ~ 96 | 97 | # training settings 98 | train: 99 | ema_decay: 0.999 100 | optim_g: 101 | type: Adam 102 | lr: !!float 2e-4 103 | weight_decay: 0 104 | betas: [0.9, 0.99] 105 | 106 | scheduler: 107 | type: MultiStepLR 108 | milestones: [1000000] 109 | gamma: 0.5 110 | 111 | total_iter: 1000000 112 | warmup_iter: -1 # no warm up 113 | 114 | # losses 115 | pixel_opt: 116 | type: L1Loss 117 | loss_weight: 1.0 118 | reduction: mean 119 | 120 | # Uncomment these for validation 121 | # validation settings 122 | # val: 123 | # val_freq: !!float 5e3 124 | # save_img: True 125 | 126 | # metrics: 127 | # psnr: # metric name 128 | # type: calculate_psnr 129 | # crop_border: 4 130 | # test_y_channel: false 131 | 132 | # logging settings 133 | 
logger: 134 | print_freq: 100 135 | save_checkpoint_freq: !!float 5e3 136 | use_tb_logger: true 137 | wandb: 138 | project: ~ 139 | resume_id: ~ 140 | 141 | # dist training settings 142 | dist_params: 143 | backend: nccl 144 | port: 29500 145 | -------------------------------------------------------------------------------- /stablefusion/models/realesrgan/realesrgan/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from .archs import * 3 | from .data import * 4 | from .models import * 5 | from .utils import * 6 | from .version import * 7 | -------------------------------------------------------------------------------- /stablefusion/models/realesrgan/realesrgan/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from basicsr.utils import scandir 3 | from os import path as osp 4 | 5 | # automatically scan and import arch modules for registry 6 | # scan all the files that end with '_arch.py' under the archs folder 7 | arch_folder = osp.dirname(osp.abspath(__file__)) 8 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 9 | # import all the arch modules 10 | _arch_modules = [importlib.import_module(f'realesrgan.archs.{file_name}') for file_name in arch_filenames] 11 | -------------------------------------------------------------------------------- /stablefusion/models/realesrgan/realesrgan/archs/discriminator_arch.py: -------------------------------------------------------------------------------- 1 | from basicsr.utils.registry import ARCH_REGISTRY 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | from torch.nn.utils import spectral_norm 5 | 6 | 7 | @ARCH_REGISTRY.register() 8 | class UNetDiscriminatorSN(nn.Module): 9 | """Defines a U-Net discriminator with spectral normalization (SN) 10 | 11 | It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. 12 | 13 | Arg: 14 | num_in_ch (int): Channel number of inputs. Default: 3. 15 | num_feat (int): Channel number of base intermediate features. Default: 64. 16 | skip_connection (bool): Whether to use skip connections between U-Net. Default: True. 
17 | """ 18 | 19 | def __init__(self, num_in_ch, num_feat=64, skip_connection=True): 20 | super(UNetDiscriminatorSN, self).__init__() 21 | self.skip_connection = skip_connection 22 | norm = spectral_norm 23 | # the first convolution 24 | self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1) 25 | # downsample 26 | self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False)) 27 | self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False)) 28 | self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False)) 29 | # upsample 30 | self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False)) 31 | self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False)) 32 | self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False)) 33 | # extra convolutions 34 | self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) 35 | self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) 36 | self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1) 37 | 38 | def forward(self, x): 39 | # downsample 40 | x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True) 41 | x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True) 42 | x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True) 43 | x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True) 44 | 45 | # upsample 46 | x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False) 47 | x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True) 48 | 49 | if self.skip_connection: 50 | x4 = x4 + x2 51 | x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False) 52 | x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True) 53 | 54 | if self.skip_connection: 55 | x5 = x5 + x1 56 | x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False) 57 | x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True) 58 | 59 | if self.skip_connection: 60 | x6 = x6 + x0 61 | 62 | # extra convolutions 63 | out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True) 64 | out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True) 65 | out = self.conv9(out) 66 | 67 | return out 68 | -------------------------------------------------------------------------------- /stablefusion/models/realesrgan/realesrgan/archs/srvgg_arch.py: -------------------------------------------------------------------------------- 1 | from basicsr.utils.registry import ARCH_REGISTRY 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | @ARCH_REGISTRY.register() 7 | class SRVGGNetCompact(nn.Module): 8 | """A compact VGG-style network structure for super-resolution. 9 | 10 | It is a compact network structure, which performs upsampling in the last layer and no convolution is 11 | conducted on the HR feature space. 12 | 13 | Args: 14 | num_in_ch (int): Channel number of inputs. Default: 3. 15 | num_out_ch (int): Channel number of outputs. Default: 3. 16 | num_feat (int): Channel number of intermediate features. Default: 64. 17 | num_conv (int): Number of convolution layers in the body network. Default: 16. 18 | upscale (int): Upsampling factor. Default: 4. 19 | act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu. 
20 | """ 21 | 22 | def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'): 23 | super(SRVGGNetCompact, self).__init__() 24 | self.num_in_ch = num_in_ch 25 | self.num_out_ch = num_out_ch 26 | self.num_feat = num_feat 27 | self.num_conv = num_conv 28 | self.upscale = upscale 29 | self.act_type = act_type 30 | 31 | self.body = nn.ModuleList() 32 | # the first conv 33 | self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) 34 | # the first activation 35 | if act_type == 'relu': 36 | activation = nn.ReLU(inplace=True) 37 | elif act_type == 'prelu': 38 | activation = nn.PReLU(num_parameters=num_feat) 39 | elif act_type == 'leakyrelu': 40 | activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) 41 | self.body.append(activation) 42 | 43 | # the body structure 44 | for _ in range(num_conv): 45 | self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) 46 | # activation 47 | if act_type == 'relu': 48 | activation = nn.ReLU(inplace=True) 49 | elif act_type == 'prelu': 50 | activation = nn.PReLU(num_parameters=num_feat) 51 | elif act_type == 'leakyrelu': 52 | activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) 53 | self.body.append(activation) 54 | 55 | # the last conv 56 | self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) 57 | # upsample 58 | self.upsampler = nn.PixelShuffle(upscale) 59 | 60 | def forward(self, x): 61 | out = x 62 | for i in range(0, len(self.body)): 63 | out = self.body[i](out) 64 | 65 | out = self.upsampler(out) 66 | # add the nearest upsampled image, so that the network learns the residual 67 | base = F.interpolate(x, scale_factor=self.upscale, mode='nearest') 68 | out += base 69 | return out 70 | -------------------------------------------------------------------------------- /stablefusion/models/realesrgan/realesrgan/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from basicsr.utils import scandir 3 | from os import path as osp 4 | 5 | # automatically scan and import dataset modules for registry 6 | # scan all the files that end with '_dataset.py' under the data folder 7 | data_folder = osp.dirname(osp.abspath(__file__)) 8 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] 9 | # import all the dataset modules 10 | _dataset_modules = [importlib.import_module(f'realesrgan.data.{file_name}') for file_name in dataset_filenames] 11 | -------------------------------------------------------------------------------- /stablefusion/models/realesrgan/realesrgan/data/realesrgan_dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import os 5 | import os.path as osp 6 | import random 7 | import time 8 | import torch 9 | from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels 10 | from basicsr.data.transforms import augment 11 | from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor 12 | from basicsr.utils.registry import DATASET_REGISTRY 13 | from torch.utils import data as data 14 | 15 | 16 | @DATASET_REGISTRY.register() 17 | class RealESRGANDataset(data.Dataset): 18 | """Dataset used for Real-ESRGAN model: 19 | Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. 20 | 21 | It loads gt (Ground-Truth) images, and augments them. 
22 | It also generates blur kernels and sinc kernels for generating low-quality images. 23 | Note that the low-quality images are processed in tensors on GPUS for faster processing. 24 | 25 | Args: 26 | opt (dict): Config for train datasets. It contains the following keys: 27 | dataroot_gt (str): Data root path for gt. 28 | meta_info (str): Path for meta information file. 29 | io_backend (dict): IO backend type and other kwarg. 30 | use_hflip (bool): Use horizontal flips. 31 | use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). 32 | Please see more options in the codes. 33 | """ 34 | 35 | def __init__(self, opt): 36 | super(RealESRGANDataset, self).__init__() 37 | self.opt = opt 38 | self.file_client = None 39 | self.io_backend_opt = opt['io_backend'] 40 | self.gt_folder = opt['dataroot_gt'] 41 | 42 | # file client (lmdb io backend) 43 | if self.io_backend_opt['type'] == 'lmdb': 44 | self.io_backend_opt['db_paths'] = [self.gt_folder] 45 | self.io_backend_opt['client_keys'] = ['gt'] 46 | if not self.gt_folder.endswith('.lmdb'): 47 | raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}") 48 | with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: 49 | self.paths = [line.split('.')[0] for line in fin] 50 | else: 51 | # disk backend with meta_info 52 | # Each line in the meta_info describes the relative path to an image 53 | with open(self.opt['meta_info']) as fin: 54 | paths = [line.strip().split(' ')[0] for line in fin] 55 | self.paths = [os.path.join(self.gt_folder, v) for v in paths] 56 | 57 | # blur settings for the first degradation 58 | self.blur_kernel_size = opt['blur_kernel_size'] 59 | self.kernel_list = opt['kernel_list'] 60 | self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability 61 | self.blur_sigma = opt['blur_sigma'] 62 | self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels 63 | self.betap_range = opt['betap_range'] # betap used in plateau blur kernels 64 | self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters 65 | 66 | # blur settings for the second degradation 67 | self.blur_kernel_size2 = opt['blur_kernel_size2'] 68 | self.kernel_list2 = opt['kernel_list2'] 69 | self.kernel_prob2 = opt['kernel_prob2'] 70 | self.blur_sigma2 = opt['blur_sigma2'] 71 | self.betag_range2 = opt['betag_range2'] 72 | self.betap_range2 = opt['betap_range2'] 73 | self.sinc_prob2 = opt['sinc_prob2'] 74 | 75 | # a final sinc filter 76 | self.final_sinc_prob = opt['final_sinc_prob'] 77 | 78 | self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21 79 | # TODO: kernel range is now hard-coded, should be in the configure file 80 | self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect 81 | self.pulse_tensor[10, 10] = 1 82 | 83 | def __getitem__(self, index): 84 | if self.file_client is None: 85 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 86 | 87 | # -------------------------------- Load gt images -------------------------------- # 88 | # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32. 
89 | gt_path = self.paths[index] 90 | # avoid errors caused by high latency in reading files 91 | retry = 3 92 | while retry > 0: 93 | try: 94 | img_bytes = self.file_client.get(gt_path, 'gt') 95 | except (IOError, OSError) as e: 96 | logger = get_root_logger() 97 | logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}') 98 | # change another file to read 99 | index = random.randint(0, self.__len__()) 100 | gt_path = self.paths[index] 101 | time.sleep(1) # sleep 1s for occasional server congestion 102 | else: 103 | break 104 | finally: 105 | retry -= 1 106 | img_gt = imfrombytes(img_bytes, float32=True) 107 | 108 | # -------------------- Do augmentation for training: flip, rotation -------------------- # 109 | img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot']) 110 | 111 | # crop or pad to 400 112 | # TODO: 400 is hard-coded. You may change it accordingly 113 | h, w = img_gt.shape[0:2] 114 | crop_pad_size = 400 115 | # pad 116 | if h < crop_pad_size or w < crop_pad_size: 117 | pad_h = max(0, crop_pad_size - h) 118 | pad_w = max(0, crop_pad_size - w) 119 | img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101) 120 | # crop 121 | if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size: 122 | h, w = img_gt.shape[0:2] 123 | # randomly choose top and left coordinates 124 | top = random.randint(0, h - crop_pad_size) 125 | left = random.randint(0, w - crop_pad_size) 126 | img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...] 127 | 128 | # ------------------------ Generate kernels (used in the first degradation) ------------------------ # 129 | kernel_size = random.choice(self.kernel_range) 130 | if np.random.uniform() < self.opt['sinc_prob']: 131 | # this sinc filter setting is for kernels ranging from [7, 21] 132 | if kernel_size < 13: 133 | omega_c = np.random.uniform(np.pi / 3, np.pi) 134 | else: 135 | omega_c = np.random.uniform(np.pi / 5, np.pi) 136 | kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) 137 | else: 138 | kernel = random_mixed_kernels( 139 | self.kernel_list, 140 | self.kernel_prob, 141 | kernel_size, 142 | self.blur_sigma, 143 | self.blur_sigma, [-math.pi, math.pi], 144 | self.betag_range, 145 | self.betap_range, 146 | noise_range=None) 147 | # pad kernel 148 | pad_size = (21 - kernel_size) // 2 149 | kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size))) 150 | 151 | # ------------------------ Generate kernels (used in the second degradation) ------------------------ # 152 | kernel_size = random.choice(self.kernel_range) 153 | if np.random.uniform() < self.opt['sinc_prob2']: 154 | if kernel_size < 13: 155 | omega_c = np.random.uniform(np.pi / 3, np.pi) 156 | else: 157 | omega_c = np.random.uniform(np.pi / 5, np.pi) 158 | kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) 159 | else: 160 | kernel2 = random_mixed_kernels( 161 | self.kernel_list2, 162 | self.kernel_prob2, 163 | kernel_size, 164 | self.blur_sigma2, 165 | self.blur_sigma2, [-math.pi, math.pi], 166 | self.betag_range2, 167 | self.betap_range2, 168 | noise_range=None) 169 | 170 | # pad kernel 171 | pad_size = (21 - kernel_size) // 2 172 | kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size))) 173 | 174 | # ------------------------------------- the final sinc kernel ------------------------------------- # 175 | if np.random.uniform() < self.opt['final_sinc_prob']: 176 | kernel_size = random.choice(self.kernel_range) 177 | omega_c = np.random.uniform(np.pi / 
3, np.pi) 178 | sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21) 179 | sinc_kernel = torch.FloatTensor(sinc_kernel) 180 | else: 181 | sinc_kernel = self.pulse_tensor 182 | 183 | # BGR to RGB, HWC to CHW, numpy to tensor 184 | img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0] 185 | kernel = torch.FloatTensor(kernel) 186 | kernel2 = torch.FloatTensor(kernel2) 187 | 188 | return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path} 189 | return return_d 190 | 191 | def __len__(self): 192 | return len(self.paths) 193 | -------------------------------------------------------------------------------- /stablefusion/models/realesrgan/realesrgan/data/realesrgan_paired_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb 3 | from basicsr.data.transforms import augment, paired_random_crop 4 | from basicsr.utils import FileClient, imfrombytes, img2tensor 5 | from basicsr.utils.registry import DATASET_REGISTRY 6 | from torch.utils import data as data 7 | from torchvision.transforms.functional import normalize 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class RealESRGANPairedDataset(data.Dataset): 12 | """Paired image dataset for image restoration. 13 | 14 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. 15 | 16 | There are three modes: 17 | 1. 'lmdb': Use lmdb files. 18 | If opt['io_backend'] == lmdb. 19 | 2. 'meta_info': Use meta information file to generate paths. 20 | If opt['io_backend'] != lmdb and opt['meta_info'] is not None. 21 | 3. 'folder': Scan folders to generate paths. 22 | The rest. 23 | 24 | Args: 25 | opt (dict): Config for train datasets. It contains the following keys: 26 | dataroot_gt (str): Data root path for gt. 27 | dataroot_lq (str): Data root path for lq. 28 | meta_info (str): Path for meta information file. 29 | io_backend (dict): IO backend type and other kwarg. 30 | filename_tmpl (str): Template for each filename. Note that the template excludes the file extension. 31 | Default: '{}'. 32 | gt_size (int): Cropped patched size for gt patches. 33 | use_hflip (bool): Use horizontal flips. 34 | use_rot (bool): Use rotation (use vertical flip and transposing h 35 | and w for implementation). 36 | 37 | scale (bool): Scale, which will be added automatically. 38 | phase (str): 'train' or 'val'. 
39 | """ 40 | 41 | def __init__(self, opt): 42 | super(RealESRGANPairedDataset, self).__init__() 43 | self.opt = opt 44 | self.file_client = None 45 | self.io_backend_opt = opt['io_backend'] 46 | # mean and std for normalizing the input images 47 | self.mean = opt['mean'] if 'mean' in opt else None 48 | self.std = opt['std'] if 'std' in opt else None 49 | 50 | self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] 51 | self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}' 52 | 53 | # file client (lmdb io backend) 54 | if self.io_backend_opt['type'] == 'lmdb': 55 | self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] 56 | self.io_backend_opt['client_keys'] = ['lq', 'gt'] 57 | self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) 58 | elif 'meta_info' in self.opt and self.opt['meta_info'] is not None: 59 | # disk backend with meta_info 60 | # Each line in the meta_info describes the relative path to an image 61 | with open(self.opt['meta_info']) as fin: 62 | paths = [line.strip() for line in fin] 63 | self.paths = [] 64 | for path in paths: 65 | gt_path, lq_path = path.split(', ') 66 | gt_path = os.path.join(self.gt_folder, gt_path) 67 | lq_path = os.path.join(self.lq_folder, lq_path) 68 | self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)])) 69 | else: 70 | # disk backend 71 | # it will scan the whole folder to get meta info 72 | # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file 73 | self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) 74 | 75 | def __getitem__(self, index): 76 | if self.file_client is None: 77 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 78 | 79 | scale = self.opt['scale'] 80 | 81 | # Load gt and lq images. Dimension order: HWC; channel order: BGR; 82 | # image range: [0, 1], float32. 
83 | gt_path = self.paths[index]['gt_path'] 84 | img_bytes = self.file_client.get(gt_path, 'gt') 85 | img_gt = imfrombytes(img_bytes, float32=True) 86 | lq_path = self.paths[index]['lq_path'] 87 | img_bytes = self.file_client.get(lq_path, 'lq') 88 | img_lq = imfrombytes(img_bytes, float32=True) 89 | 90 | # augmentation for training 91 | if self.opt['phase'] == 'train': 92 | gt_size = self.opt['gt_size'] 93 | # random crop 94 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) 95 | # flip, rotation 96 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) 97 | 98 | # BGR to RGB, HWC to CHW, numpy to tensor 99 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) 100 | # normalize 101 | if self.mean is not None or self.std is not None: 102 | normalize(img_lq, self.mean, self.std, inplace=True) 103 | normalize(img_gt, self.mean, self.std, inplace=True) 104 | 105 | return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} 106 | 107 | def __len__(self): 108 | return len(self.paths) 109 | -------------------------------------------------------------------------------- /stablefusion/models/realesrgan/realesrgan/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from basicsr.utils import scandir 3 | from os import path as osp 4 | 5 | # automatically scan and import model modules for registry 6 | # scan all the files that end with '_model.py' under the model folder 7 | model_folder = osp.dirname(osp.abspath(__file__)) 8 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] 9 | # import all the model modules 10 | _model_modules = [importlib.import_module(f'realesrgan.models.{file_name}') for file_name in model_filenames] 11 | -------------------------------------------------------------------------------- /stablefusion/models/realesrgan/realesrgan/models/realesrnet_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt 5 | from basicsr.data.transforms import paired_random_crop 6 | from basicsr.models.sr_model import SRModel 7 | from basicsr.utils import DiffJPEG, USMSharp 8 | from basicsr.utils.img_process_util import filter2D 9 | from basicsr.utils.registry import MODEL_REGISTRY 10 | from torch.nn import functional as F 11 | 12 | 13 | @MODEL_REGISTRY.register() 14 | class RealESRNetModel(SRModel): 15 | """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. 16 | 17 | It is trained without GAN losses. 18 | It mainly performs: 19 | 1. randomly synthesize LQ images in GPU tensors 20 | 2. optimize the networks with GAN training. 21 | """ 22 | 23 | def __init__(self, opt): 24 | super(RealESRNetModel, self).__init__(opt) 25 | self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts 26 | self.usm_sharpener = USMSharp().cuda() # do usm sharpening 27 | self.queue_size = opt.get('queue_size', 180) 28 | 29 | @torch.no_grad() 30 | def _dequeue_and_enqueue(self): 31 | """It is the training pair pool for increasing the diversity in a batch. 32 | 33 | Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a 34 | batch could not have different resize scaling factors. 
Therefore, we employ this training pair pool 35 | to increase the degradation diversity in a batch. 36 | """ 37 | # initialize 38 | b, c, h, w = self.lq.size() 39 | if not hasattr(self, 'queue_lr'): 40 | assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}' 41 | self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda() 42 | _, c, h, w = self.gt.size() 43 | self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda() 44 | self.queue_ptr = 0 45 | if self.queue_ptr == self.queue_size: # the pool is full 46 | # do dequeue and enqueue 47 | # shuffle 48 | idx = torch.randperm(self.queue_size) 49 | self.queue_lr = self.queue_lr[idx] 50 | self.queue_gt = self.queue_gt[idx] 51 | # get first b samples 52 | lq_dequeue = self.queue_lr[0:b, :, :, :].clone() 53 | gt_dequeue = self.queue_gt[0:b, :, :, :].clone() 54 | # update the queue 55 | self.queue_lr[0:b, :, :, :] = self.lq.clone() 56 | self.queue_gt[0:b, :, :, :] = self.gt.clone() 57 | 58 | self.lq = lq_dequeue 59 | self.gt = gt_dequeue 60 | else: 61 | # only do enqueue 62 | self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone() 63 | self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone() 64 | self.queue_ptr = self.queue_ptr + b 65 | 66 | @torch.no_grad() 67 | def feed_data(self, data): 68 | """Accept data from dataloader, and then add two-order degradations to obtain LQ images. 69 | """ 70 | if self.is_train and self.opt.get('high_order_degradation', True): 71 | # training data synthesis 72 | self.gt = data['gt'].to(self.device) 73 | # USM sharpen the GT images 74 | if self.opt['gt_usm'] is True: 75 | self.gt = self.usm_sharpener(self.gt) 76 | 77 | self.kernel1 = data['kernel1'].to(self.device) 78 | self.kernel2 = data['kernel2'].to(self.device) 79 | self.sinc_kernel = data['sinc_kernel'].to(self.device) 80 | 81 | ori_h, ori_w = self.gt.size()[2:4] 82 | 83 | # ----------------------- The first degradation process ----------------------- # 84 | # blur 85 | out = filter2D(self.gt, self.kernel1) 86 | # random resize 87 | updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0] 88 | if updown_type == 'up': 89 | scale = np.random.uniform(1, self.opt['resize_range'][1]) 90 | elif updown_type == 'down': 91 | scale = np.random.uniform(self.opt['resize_range'][0], 1) 92 | else: 93 | scale = 1 94 | mode = random.choice(['area', 'bilinear', 'bicubic']) 95 | out = F.interpolate(out, scale_factor=scale, mode=mode) 96 | # add noise 97 | gray_noise_prob = self.opt['gray_noise_prob'] 98 | if np.random.uniform() < self.opt['gaussian_noise_prob']: 99 | out = random_add_gaussian_noise_pt( 100 | out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob) 101 | else: 102 | out = random_add_poisson_noise_pt( 103 | out, 104 | scale_range=self.opt['poisson_scale_range'], 105 | gray_prob=gray_noise_prob, 106 | clip=True, 107 | rounds=False) 108 | # JPEG compression 109 | jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) 110 | out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts 111 | out = self.jpeger(out, quality=jpeg_p) 112 | 113 | # ----------------------- The second degradation process ----------------------- # 114 | # blur 115 | if np.random.uniform() < self.opt['second_blur_prob']: 116 | out = filter2D(out, self.kernel2) 117 | # random resize 118 | updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0] 119 | if updown_type == 
'up': 120 | scale = np.random.uniform(1, self.opt['resize_range2'][1]) 121 | elif updown_type == 'down': 122 | scale = np.random.uniform(self.opt['resize_range2'][0], 1) 123 | else: 124 | scale = 1 125 | mode = random.choice(['area', 'bilinear', 'bicubic']) 126 | out = F.interpolate( 127 | out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode) 128 | # add noise 129 | gray_noise_prob = self.opt['gray_noise_prob2'] 130 | if np.random.uniform() < self.opt['gaussian_noise_prob2']: 131 | out = random_add_gaussian_noise_pt( 132 | out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob) 133 | else: 134 | out = random_add_poisson_noise_pt( 135 | out, 136 | scale_range=self.opt['poisson_scale_range2'], 137 | gray_prob=gray_noise_prob, 138 | clip=True, 139 | rounds=False) 140 | 141 | # JPEG compression + the final sinc filter 142 | # We also need to resize images to desired sizes. We group [resize back + sinc filter] together 143 | # as one operation. 144 | # We consider two orders: 145 | # 1. [resize back + sinc filter] + JPEG compression 146 | # 2. JPEG compression + [resize back + sinc filter] 147 | # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines. 148 | if np.random.uniform() < 0.5: 149 | # resize back + the final sinc filter 150 | mode = random.choice(['area', 'bilinear', 'bicubic']) 151 | out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) 152 | out = filter2D(out, self.sinc_kernel) 153 | # JPEG compression 154 | jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) 155 | out = torch.clamp(out, 0, 1) 156 | out = self.jpeger(out, quality=jpeg_p) 157 | else: 158 | # JPEG compression 159 | jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) 160 | out = torch.clamp(out, 0, 1) 161 | out = self.jpeger(out, quality=jpeg_p) 162 | # resize back + the final sinc filter 163 | mode = random.choice(['area', 'bilinear', 'bicubic']) 164 | out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) 165 | out = filter2D(out, self.sinc_kernel) 166 | 167 | # clamp and round 168 | self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255. 
169 | 170 | # random crop 171 | gt_size = self.opt['gt_size'] 172 | self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale']) 173 | 174 | # training pair pool 175 | self._dequeue_and_enqueue() 176 | self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract 177 | else: 178 | # for paired training or validation 179 | self.lq = data['lq'].to(self.device) 180 | if 'gt' in data: 181 | self.gt = data['gt'].to(self.device) 182 | self.gt_usm = self.usm_sharpener(self.gt) 183 | 184 | def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): 185 | # do not use the synthetic process during validation 186 | self.is_train = False 187 | super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img) 188 | self.is_train = True 189 | -------------------------------------------------------------------------------- /stablefusion/models/realesrgan/realesrgan/train.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import os.path as osp 3 | from basicsr.train import train_pipeline 4 | 5 | import realesrgan.archs 6 | import realesrgan.data 7 | import realesrgan.models 8 | 9 | if __name__ == '__main__': 10 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 11 | train_pipeline(root_path) 12 | -------------------------------------------------------------------------------- /stablefusion/models/realesrgan/weights/README.md: -------------------------------------------------------------------------------- 1 | # Weights 2 | 3 | Put the downloaded weights to this folder. 4 | -------------------------------------------------------------------------------- /stablefusion/models/safetensors_models/t: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralRealm/StableFusion/9268d686f49b45194a2251c6728e74191f298654/stablefusion/models/safetensors_models/t -------------------------------------------------------------------------------- /stablefusion/pages/10_Utilities.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | from stablefusion import utils 4 | from stablefusion.scripts.gfp_gan import GFPGAN 5 | from stablefusion.scripts.image_info import ImageInfo 6 | from stablefusion.scripts.interrogator import ImageInterrogator 7 | from stablefusion.scripts.upscaler import Upscaler 8 | from stablefusion.scripts.model_adding import ModelAdding 9 | from stablefusion.scripts.model_removing import ModelRemoving 10 | 11 | 12 | def app(): 13 | utils.create_base_page() 14 | task = st.selectbox( 15 | "Choose a utility", 16 | [ 17 | "ImageInfo", 18 | "Model Adding", 19 | "Model Removing", 20 | "SD Upscaler", 21 | "GFPGAN", 22 | "CLIP Interrogator", 23 | ], 24 | ) 25 | if task == "ImageInfo": 26 | ImageInfo().app() 27 | elif task == "Model Adding": 28 | ModelAdding().app() 29 | elif task == "Model Removing": 30 | ModelRemoving().app() 31 | elif task == "SD Upscaler": 32 | with st.form("upscaler_model"): 33 | upscaler_model = st.text_input("Model", "stabilityai/stable-diffusion-x4-upscaler") 34 | submit = st.form_submit_button("Load model") 35 | if submit: 36 | with st.spinner("Loading model..."): 37 | ups = Upscaler( 38 | model=upscaler_model, 39 | device=st.session_state.device, 40 | output_path=st.session_state.output_path, 41 | ) 42 | st.session_state.ups = ups 43 | if "ups" in st.session_state: 44 | 
st.write(f"Current model: {st.session_state.ups}") 45 | st.session_state.ups.app() 46 | 47 | elif task == "GFPGAN": 48 | with st.spinner("Loading model..."): 49 | gfpgan = GFPGAN( 50 | device=st.session_state.device, 51 | output_path=st.session_state.output_path, 52 | ) 53 | gfpgan.app() 54 | elif task == "CLIP Interrogator": 55 | interrogator = ImageInterrogator( 56 | device=st.session_state.device, 57 | output_path=st.session_state.output_path, 58 | ) 59 | interrogator.app() 60 | 61 | 62 | if __name__ == "__main__": 63 | app() 64 | -------------------------------------------------------------------------------- /stablefusion/pages/1_Text2Image.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | from stablefusion.scripts.img2img import Img2Img 4 | from stablefusion.scripts.text2img import Text2Image 5 | from stablefusion.Home import read_model_list 6 | from stablefusion import utils 7 | 8 | def app(): 9 | utils.create_base_page() 10 | if "img2img" in st.session_state and "text2img" not in st.session_state: 11 | text2img = Text2Image( 12 | model=st.session_state.img2img.model, 13 | device=st.session_state.device, 14 | output_path=st.session_state.output_path, 15 | ) 16 | st.session_state.text2img = text2img 17 | with st.form("text2img_model"): 18 | model = st.selectbox( 19 | "Which model do you want to use?", 20 | options=read_model_list(), 21 | ) 22 | # submit_col, _, clear_col = st.columns(3) 23 | # with submit_col: 24 | submit = st.form_submit_button("Load model") 25 | # with clear_col: 26 | # clear = st.form_submit_button("Clear memory") 27 | # if clear: 28 | # clear_memory(preserve="text2img") 29 | if submit: 30 | with st.spinner("Loading model..."): 31 | text2img = Text2Image( 32 | model=model, 33 | device=st.session_state.device, 34 | output_path=st.session_state.output_path, 35 | ) 36 | st.session_state.text2img = text2img 37 | img2img = Img2Img( 38 | model=None, 39 | device=st.session_state.device, 40 | output_path=st.session_state.output_path, 41 | text2img_model=text2img.pipeline, 42 | ) 43 | st.session_state.img2img = img2img 44 | if "text2img" in st.session_state: 45 | st.write(f"Current model: {st.session_state.text2img}") 46 | st.session_state.text2img.app() 47 | 48 | 49 | if __name__ == "__main__": 50 | app() 51 | -------------------------------------------------------------------------------- /stablefusion/pages/2_Image2Image.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | from stablefusion.scripts.img2img import Img2Img 4 | from stablefusion.scripts.text2img import Text2Image 5 | from stablefusion.Home import read_model_list 6 | from stablefusion import utils 7 | 8 | def app(): 9 | utils.create_base_page() 10 | if "text2img" in st.session_state and "img2img" not in st.session_state: 11 | img2img = Img2Img( 12 | model=None, 13 | device=st.session_state.device, 14 | output_path=st.session_state.output_path, 15 | text2img_model=st.session_state.text2img.pipeline, 16 | ) 17 | st.session_state.img2img = img2img 18 | with st.form("img2img_model"): 19 | model = st.selectbox( 20 | "Which model do you want to use?", 21 | options=read_model_list(), 22 | ) 23 | submit = st.form_submit_button("Load model") 24 | if submit: 25 | with st.spinner("Loading model..."): 26 | img2img = Img2Img( 27 | model=model, 28 | device=st.session_state.device, 29 | output_path=st.session_state.output_path, 30 | ) 31 | st.session_state.img2img = img2img 32 | text2img 
= Text2Image( 33 | model=model, 34 | device=st.session_state.device, 35 | output_path=st.session_state.output_path, 36 | ) 37 | st.session_state.text2img = text2img 38 | if "img2img" in st.session_state: 39 | st.write(f"Current model: {st.session_state.img2img}") 40 | st.session_state.img2img.app() 41 | 42 | 43 | if __name__ == "__main__": 44 | app() 45 | -------------------------------------------------------------------------------- /stablefusion/pages/3_Inpainting.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | from stablefusion import utils 4 | from stablefusion.scripts.inpainting import Inpainting 5 | from stablefusion.Home import read_model_list 6 | 7 | def app(): 8 | utils.create_base_page() 9 | with st.form("inpainting_model_form"): 10 | model = st.selectbox( 11 | "Which model do you want to use for inpainting?", 12 | options=read_model_list() 13 | ) 14 | pipeline = st.selectbox(label="Select Your Pipeline: ", options=["StableDiffusionInpaintPipelineLegacy" ,"StableDiffusionInpaintPipeline"]) 15 | submit = st.form_submit_button("Load model") 16 | if submit: 17 | st.session_state.inpainting_model = model 18 | with st.spinner("Loading model..."): 19 | inpainting = Inpainting( 20 | model=model, 21 | device=st.session_state.device, 22 | output_path=st.session_state.output_path, 23 | pipeline_select=pipeline 24 | ) 25 | st.session_state.inpainting = inpainting 26 | if "inpainting" in st.session_state: 27 | st.write(f"Current model: {st.session_state.inpainting}") 28 | st.session_state.inpainting.app() 29 | 30 | 31 | if __name__ == "__main__": 32 | app() -------------------------------------------------------------------------------- /stablefusion/pages/4_ControlNet.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | from stablefusion import utils 4 | from stablefusion.scripts.controlnet import Controlnet 5 | from stablefusion.Home import read_model_list 6 | 7 | control_net_model_list = ["lllyasviel/sd-controlnet-canny", 8 | "lllyasviel/sd-controlnet-hed", 9 | "lllyasviel/sd-controlnet-normal", 10 | "lllyasviel/sd-controlnet-scribble", 11 | "lllyasviel/sd-controlnet-depth", 12 | "lllyasviel/sd-controlnet-mlsd", 13 | "lllyasviel/sd-controlnet-openpose", 14 | ] 15 | 16 | processer_list = [ "Canny", 17 | "Hed", 18 | "Normal", 19 | "Scribble", 20 | "Depth", 21 | "Mlsd", 22 | "OpenPose", 23 | ] 24 | 25 | 26 | import streamlit as st 27 | 28 | from stablefusion import utils 29 | from stablefusion.scripts.gfp_gan import GFPGAN 30 | from stablefusion.scripts.image_info import ImageInfo 31 | from stablefusion.scripts.interrogator import ImageInterrogator 32 | from stablefusion.scripts.upscaler import Upscaler 33 | from stablefusion.scripts.model_adding import ModelAdding 34 | from stablefusion.scripts.model_removing import ModelRemoving 35 | 36 | 37 | def app(): 38 | utils.create_base_page() 39 | with st.form("inpainting_model_form"): 40 | base_model = st.selectbox( 41 | "Base model For your ControlNet: ", 42 | options=read_model_list() 43 | ) 44 | controlnet_model = st.selectbox(label="Choose Your ControlNet: ", options=control_net_model_list) 45 | processer = st.selectbox(label="Choose Your Processer: ", options=processer_list) 46 | submit = st.form_submit_button("Load model") 47 | if submit: 48 | st.session_state.controlnet_models = base_model 49 | with st.spinner("Loading model..."): 50 | controlnet = Controlnet( 51 | model=base_model, 52 | 
device=st.session_state.device, 53 | output_path=st.session_state.output_path, 54 | controlnet_model=controlnet_model, 55 | processer = processer 56 | ) 57 | st.session_state.controlnet = controlnet 58 | if "controlnet" in st.session_state: 59 | st.write(f"Current model: {st.session_state.controlnet}") 60 | st.session_state.controlnet.app() 61 | 62 | 63 | if __name__ == "__main__": 64 | app() -------------------------------------------------------------------------------- /stablefusion/pages/5_OpenPose_Editor.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | from stablefusion import utils 4 | from stablefusion.scripts.pose_editor import OpenPoseEditor 5 | from stablefusion.Home import read_model_list 6 | 7 | def app(): 8 | utils.create_base_page() 9 | with st.form("openpose_editor_form"): 10 | model = st.selectbox( 11 | "Which model do you want to use for OpenPose?", 12 | options=read_model_list() 13 | ) 14 | submit = st.form_submit_button("Load model") 15 | if submit: 16 | st.session_state.openpose_editor = model 17 | with st.spinner("Loading model..."): 18 | openpose = OpenPoseEditor( 19 | model=model, 20 | device=st.session_state.device, 21 | output_path=st.session_state.output_path, 22 | ) 23 | st.session_state.openpose_editor = openpose 24 | if "openpose_editor" in st.session_state: 25 | st.write(f"Current model: {st.session_state.openpose_editor}") 26 | st.session_state.openpose_editor.app() 27 | 28 | 29 | if __name__ == "__main__": 30 | app() -------------------------------------------------------------------------------- /stablefusion/pages/6_Textual Inversion.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | from stablefusion.scripts.textual_inversion import TextualInversion 4 | from stablefusion import utils 5 | from stablefusion.Home import read_model_list 6 | 7 | def app(): 8 | utils.create_base_page() 9 | with st.form("textual_inversion_form"): 10 | model = st.selectbox( 11 | "Which base model do you want to use?", 12 | options=read_model_list(), 13 | ) 14 | token_identifier = st.text_input( 15 | "Token identifier", 16 | value="" 17 | if st.session_state.get("textual_inversion_token_identifier") is None 18 | else st.session_state.textual_inversion_token_identifier, 19 | ) 20 | embeddings = st.text_input( 21 | "Embeddings", 22 | value="https://huggingface.co/sd-concepts-library/axe-tattoo/resolve/main/learned_embeds.bin" 23 | if st.session_state.get("textual_inversion_embeddings") is None 24 | else st.session_state.textual_inversion_embeddings, 25 | ) 26 | # st.file_uploader("Embeddings", type=["pt", "bin"]) 27 | submit = st.form_submit_button("Load model") 28 | if submit: 29 | st.session_state.textual_inversion_model = model 30 | st.session_state.textual_inversion_token_identifier = token_identifier 31 | st.session_state.textual_inversion_embeddings = embeddings 32 | with st.spinner("Loading model..."): 33 | textual_inversion = TextualInversion( 34 | model=model, 35 | token_identifier=token_identifier, 36 | embeddings_url=embeddings, 37 | device=st.session_state.device, 38 | output_path=st.session_state.output_path, 39 | ) 40 | st.session_state.textual_inversion = textual_inversion 41 | if "textual_inversion" in st.session_state: 42 | st.write(f"Current model: {st.session_state.textual_inversion}") 43 | st.session_state.textual_inversion.app() 44 | 45 | 46 | if __name__ == "__main__": 47 | app() 48 | 
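# A minimal sketch of constructing the TextualInversion wrapper directly, assuming
# the constructor signature used above; the model name and token identifier below are
# illustrative placeholders, and app() still needs a running Streamlit session:
#
#     from stablefusion.scripts.textual_inversion import TextualInversion
#     ti = TextualInversion(
#         model="runwayml/stable-diffusion-v1-5",
#         token_identifier="<axe-tattoo>",
#         embeddings_url="https://huggingface.co/sd-concepts-library/axe-tattoo/resolve/main/learned_embeds.bin",
#         device="cuda",
#         output_path="outputs",
#     )
#     ti.app()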
-------------------------------------------------------------------------------- /stablefusion/pages/7_Upscaler.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | from stablefusion import utils 4 | from stablefusion.models.realesrgan import inference_realesrgan 5 | from stablefusion.scripts.upscaler import Upscaler 6 | from stablefusion.scripts.gfp_gan import GFPGAN 7 | 8 | 9 | def app(): 10 | utils.create_base_page() 11 | task = st.selectbox( 12 | label="Choose Your Upscaler", 13 | options=[ 14 | "RealESRGAN", 15 | "SD Upscaler", 16 | "GFPGAN", 17 | ], 18 | ) 19 | 20 | if task == "RealESRGAN": 21 | 22 | model_name = st.selectbox( 23 | label="Choose Your Model", 24 | options=[ 25 | "RealESRGAN_x4plus", 26 | "RealESRNet_x4plus", 27 | "RealESRGAN_x4plus_anime_6B", 28 | "RealESRGAN_x2plus", 29 | "realesr-animevideov3", 30 | "realesr-general-x4v3" 31 | ], 32 | ) 33 | 34 | input_image = st.file_uploader(label="Upload The Picture", type=["png", "jpg", "jpeg"]) 35 | 36 | col1, col2 = st.columns(2) 37 | with col1: 38 | denoise_strength = st.slider(label="Select your Denoise Strength", min_value=0.0, max_value=1.0, value=0.5, step=0.1, help="Denoise strength. 0 for weak denoise (keep noise), 1 for strong denoise ability.") 39 | outscale = st.slider(label="Select your final Upsampling scale", min_value=0, max_value=4, value=4, step=1, help="The final upsampling scale of the image") 40 | 41 | with col2: 42 | face_enhance = st.selectbox(label="Do you want to Enhance the Face", options=[True, False], help="Use GFPGAN to enhance face") 43 | alpha_upsampler = st.selectbox(label="The upsampler for the alpha channels", options=["realesrgan", "bicubic"]) 44 | 45 | if st.button("Start Upscaling"): 46 | with st.spinner("Upscaling The Image..."): 47 | inference_realesrgan.main(model_name=model_name, outscale=outscale, denoise_strength=denoise_strength, face_enhance=face_enhance, tile=0, tile_pad=10, pre_pad=0, fp32="fp32", gpu_id=None, input_image=input_image, model_path=None) 48 | 49 | 50 | elif task == "SD Upscaler": 51 | with st.form("upscaler_model"): 52 | upscaler_model = st.text_input("Model", "stabilityai/stable-diffusion-x4-upscaler") 53 | submit = st.form_submit_button("Load model") 54 | if submit: 55 | with st.spinner("Loading model..."): 56 | ups = Upscaler( 57 | model=upscaler_model, 58 | device=st.session_state.device, 59 | output_path=st.session_state.output_path, 60 | ) 61 | st.session_state.ups = ups 62 | if "ups" in st.session_state: 63 | st.write(f"Current model: {st.session_state.ups}") 64 | st.session_state.ups.app() 65 | 66 | elif task == "GFPGAN": 67 | with st.spinner("Loading model..."): 68 | gfpgan = GFPGAN( 69 | device=st.session_state.device, 70 | output_path=st.session_state.output_path, 71 | ) 72 | gfpgan.app() 73 | 74 | if __name__ == "__main__": 75 | app() 76 | -------------------------------------------------------------------------------- /stablefusion/pages/8_Convertor.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | from stablefusion import utils 4 | from stablefusion.scripts.gfp_gan import GFPGAN 5 | from stablefusion.scripts.image_info import ImageInfo 6 | from stablefusion.scripts.interrogator import ImageInterrogator 7 | from stablefusion.scripts.upscaler import Upscaler 8 | from stablefusion.scripts.model_adding import ModelAdding 9 | from stablefusion.scripts.model_removing import ModelRemoving 10 | from
stablefusion.scripts.ckpt_to_diffusion import convert_ckpt_to_diffusion 11 | from stablefusion.scripts.safetensors_to_diffusion import convert_safetensor_to_diffusers 12 | 13 | 14 | def app(): 15 | utils.create_base_page() 16 | task = st.selectbox( 17 | "Choose a Converter", 18 | [ 19 | "CKPT to Diffusers", 20 | "Safetensors to Diffusers", 21 | ], 22 | ) 23 | if task == "CKPT to Diffusers": 24 | 25 | with st.form("Convert Your CKPT to Diffusers"): 26 | ckpt_model_name = st.text_input(label="Name of the CKPT Model", value="NameOfYourModel.ckpt") 27 | ckpt_model = st.text_input(label="Download Link of the CKPT Model: ", value="https://huggingface.co/andite/anything-v4.0/resolve/main/anything-v4.0-pruned-fp16.ckpt", help="Path to the checkpoint to convert.") 28 | overwrite_mode = st.selectbox("Do you want to overwrite the file: ", options=[False, True], help="If a ckpt file with the same name is already present, do you want to overwrite it, or use the existing file instead of downloading it again?") 29 | st.subheader("Advanced Settings") 30 | st.text("Don't change anything if you are not sure about it.") 31 | 32 | config_file = st.text_input(label="Enter The Config File: ", value=None, help="The YAML config file corresponding to the original architecture.") 33 | 34 | col1, col2 = st.columns(2) 35 | 36 | with col1: 37 | num_in_channels = st.text_input(label="Enter The Number Of Channels", value=None, help="The number of input channels. If `None`, the number of input channels will be automatically inferred.") 38 | scheduler_type = st.selectbox(label="Select the Scheduler: ", options=['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm'], help="Type of scheduler to use.") 39 | pipeline_type = st.text_input(label="Enter the pipeline type: ", value=None, help="The pipeline type. If `None`, the pipeline will be automatically inferred.") 40 | 41 | with col2: 42 | image_size = st.selectbox(label="Image Size", options=[None, "512", "768"], help="The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2 Base. Use 768 for Stable Diffusion v2.") 43 | prediction_type = st.selectbox(label="Select your prediction type", options=[None, "epsilon", "v-prediction"]) 44 | extract_ema = st.selectbox(label="Extract the EMA", options=[True, False], help="Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. EMA weights usually yield higher quality images for inference.
Non-EMA weights are usually better to continue fine-tuning.") 45 | 46 | device = None 47 | 48 | submit = st.form_submit_button("Start Converting") 49 | 50 | if submit: 51 | with st.spinner("Converting..."): 52 | convert_ckpt_to_diffusion(device=device, checkpoint_link=ckpt_model, checkpoint_name=ckpt_model_name, num_in_channels=num_in_channels, scheduler_type=scheduler_type, pipeline_type=pipeline_type, extract_ema=extract_ema, dump_path=" ", image_size=image_size, original_config_file=config_file, prediction_type=prediction_type, overwrite_file=overwrite_mode) 53 | 54 | 55 | elif task == "Safetensors to Diffusers": 56 | 57 | with st.form("Convert Your Safetensors to Diffusers"): 58 | stabletensor_model_name = st.text_input(label="Name of the Safetensors Model", value="NameOfYourModel.safetensors") 59 | stabletensor_model = st.text_input(label="Download Link of the Safetensors Model: ", value="https://civitai.com/api/download/models/4007?type=Model&format=SafeTensor", help="Path to the checkpoint to convert.") 60 | overwrite_mode = st.selectbox("Do you want to overwrite the file: ", options=[False, True], help="If a Safetensors file with the same name is already present, do you want to overwrite it, or use the existing file instead of downloading it again?") 61 | st.subheader("Advanced Settings") 62 | st.text("Don't change anything if you are not sure about it.") 63 | 64 | config_file = st.text_input(label="Enter The Config File: ", value=None, help="The YAML config file corresponding to the original architecture.") 65 | 66 | col1, col2 = st.columns(2) 67 | 68 | with col1: 69 | num_in_channels = st.text_input(label="Enter The Number Of Channels", value=None, help="The number of input channels. If `None`, the number of input channels will be automatically inferred.") 70 | scheduler_type = st.selectbox(label="Select the Scheduler: ", options=['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm'], help="Type of scheduler to use.") 71 | pipeline_type = st.text_input(label="Enter the pipeline type: ", value=None, help="The pipeline type. If `None`, the pipeline will be automatically inferred.") 72 | 73 | with col2: 74 | image_size = st.selectbox(label="Image Size", options=[None, "512", "768"], help="The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2 Base. Use 768 for Stable Diffusion v2.") 75 | prediction_type = st.selectbox(label="Select your prediction type", options=[None, "epsilon", "v-prediction"]) 76 | extract_ema = st.selectbox(label="Extract the EMA", options=[True, False], help="Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. EMA weights usually yield higher quality images for inference.
Non-EMA weights are usually better to continue fine-tuning.") 77 | 78 | device = None 79 | 80 | submit = st.form_submit_button("Start Converting") 81 | 82 | if submit: 83 | with st.spinner("Converting..."): 84 | if config_file == "None": 85 | config_file = None 86 | 87 | if num_in_channels == "None": 88 | num_in_channels = None 89 | 90 | if pipeline_type == "None": 91 | pipeline_type = None 92 | convert_safetensor_to_diffusers(original_config_file=config_file, image_size=image_size, prediction_type=prediction_type, pipeline_type=pipeline_type, extract_ema=extract_ema, scheduler_type=scheduler_type, num_in_channels=num_in_channels, upcast_attention=False, from_safetensors=True, device=device, stable_unclip=None, stable_unclip_prior=None, clip_stats_path=None, controlnet=None, to_safetensors=None, checkpoint_name=stabletensor_model_name, checkpoint_link=stabletensor_model, overwrite=overwrite_mode) 93 | 94 | 95 | 96 | if __name__ == "__main__": 97 | app() -------------------------------------------------------------------------------- /stablefusion/pages/9_Train.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from stablefusion import utils 3 | from stablefusion.scripts.dreambooth import train_dreambooth 4 | from stablefusion.Home import read_model_list 5 | import json 6 | import os 7 | 8 | 9 | def dump_concept_file(instance_prompt, class_prompt, instance_data_dir, class_data_dir): 10 | 11 | concepts_list = [ 12 | { 13 | "instance_prompt": instance_prompt, 14 | "class_prompt": class_prompt, 15 | "instance_data_dir": "{}/trainner_assets/instance_data/{}".format(utils.base_path(), instance_data_dir), 16 | "class_data_dir": "{}/trainner_assets/class_data/{}".format(utils.base_path(), class_data_dir) 17 | }, 18 | ] 19 | for c in concepts_list: 20 | os.makedirs(c["instance_data_dir"], exist_ok=True) 21 | 22 | with open("{}/trainner_assets/concepts_list.json".format(utils.base_path()), "w") as f: 23 | json.dump(concepts_list, f, indent=4) 24 | 25 | 26 | def app(): 27 | utils.create_base_page() 28 | task = st.selectbox( 29 | "Choose a Trainer", 30 | [ 31 | "Dreambooth", 32 | ], 33 | ) 34 | if task == "Dreambooth": 35 | 36 | with st.form("Train With Dreambooth"): 37 | 38 | col1, col2 = st.columns(2) 39 | with col1: 40 | base_model = st.selectbox(label="Name of the Base Model", options=read_model_list(), help="Path to pretrained model or model identifier from huggingface.co/models.") 41 | instance_prompt = st.text_input(label="The prompt with identifier specifying the instance", value="photo of elon musk person", help="Replace Elon Musk with the name of your object or person.") 42 | class_prompt = st.text_input(label="The prompt to specify images in the same class as provided instance images.", value="photo of a person", help="Replace person with object if you are training for any kind of object.") 43 | instance_data_dir_name = st.text_input(label="Name of the instance images folder: ", value="elon musk", help="A folder containing the training data of instance images.") 44 | class_data_dir_name = st.text_input(label="Name of the class images folder: ", value="person", help="A folder containing the training data of class images.") 45 | output_dir_name = st.text_input(label="Name of your Trained model (folder): ", value="elon musk", help="The output directory name where the model predictions and checkpoints will be written.") 46 | seed = st.number_input(label="Enter The Seed", value=1337, step=1,
help="A seed for reproducible training.") 47 | 48 | with col2: 49 | resolution = st.number_input(label="Enter The Image Resolution: ", value=512, step=1, help="The resolution for input images; all the images in the train/validation dataset will be resized to this resolution.") 50 | train_batch_size = st.number_input(label="Train Batch Size", value=4, step=1, help="Batch size (per device) for the training dataloader.") 51 | learning_rate = st.number_input(label="Learning Rate", value=6, help="Exponent of the initial learning rate (after the potential warmup period): a value of n gives 1e-n, so 6 means 1e-6.") 52 | num_class_images = st.number_input(label="Number of Class Images", value=100, step=1, help="Minimal class images for prior preservation loss. If there are not enough images already present in class_data_dir, additional images will be sampled with class_prompt.") 53 | sample_batch_size = st.number_input(label="Sample Batch Size: ", value=4, help="Batch size (per device) for sampling images.") 54 | max_train_steps = st.number_input(label="Max Train Steps: ", value=20, help="Total number of training steps to perform.") 55 | save_interval = st.number_input(label="Save Interval Steps: ", value=10000, help="Save weights every N steps.") 56 | 57 | st.subheader("Advanced Settings") 58 | st.text("Don't change anything if you are not sure about it.") 59 | col3, col4 = st.columns(2) 60 | with col3: 61 | revision = st.text_input(label="Revision of pretrained model identifier: ", value=None) 62 | prior_loss_weight_condition = st.selectbox(label="Want to use prior_loss_weight: ", options=[True, False]) 63 | if prior_loss_weight_condition: 64 | prior_loss_weight = st.slider(label="Prior Loss Weight: ", value=1.0, max_value=10.0, help="The weight of prior preservation loss.") 65 | else: 66 | prior_loss_weight = 0 67 | 68 | train_text_encoder = st.selectbox("Whether to train the text encoder", options=[True, False]) 69 | log_interval = st.number_input(label="Save Log Interval: ", value=10, help="Log every N steps.") 70 | tokenizer_name = st.text_input("Enter the tokenizer name", value=None, help="Pretrained tokenizer name or path if not the same as model_name.") 71 | 72 | with col4: 73 | use_8bit_adam = st.selectbox(label="Want to use 8-bit Adam: ", options=[True, False], help="Whether or not to use 8-bit Adam from bitsandbytes.") 74 | gradient_accumulation_steps = st.slider(label="Gradient Accumulation steps", value=1, help="Number of update steps to accumulate before performing a backward/update pass.") 75 | lr_scheduler = st.selectbox(label="Select Your Learning Rate Scheduler: ", options=["constant", "cosine", "cosine_with_restarts", "polynomial", "linear", "constant_with_warmup"]) 76 | lr_warmup_steps = st.number_input(label="Learning Rate Warmup Steps: ", value=50, step=1, help="Number of steps for the warmup in the lr scheduler.") 77 | mixed_precision = st.selectbox(label="Whether to use mixed precision", options=["no", "fp16", "bf16"], help="Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the flag passed with the `accelerate.launch` command.
Use this argument to override the accelerate config.") 78 | 79 | 80 | submit = st.form_submit_button("Start Training") 81 | 82 | if submit: 83 | with st.spinner("Training..."): 84 | 85 | output_dir = "{}/trainner_assets/stable_diffusion_weights/{}".format(utils.base_path(), output_dir_name) 86 | 87 | learning_rate = 1 * 10**(-learning_rate) 88 | 89 | dump_concept_file(instance_prompt=instance_prompt, class_prompt=class_prompt, instance_data_dir=instance_data_dir_name, class_data_dir=class_data_dir_name) 90 | train_dreambooth.main(pretrained_model_name_or_path=base_model, revision=revision, output_dir=output_dir, with_prior_preservation=prior_loss_weight, prior_loss_weight=prior_loss_weight, seed=seed, resolution=resolution, train_batch_size=train_batch_size, train_text_encoder=train_text_encoder, mixed_precision=mixed_precision, use_8bit_adam=use_8bit_adam, gradient_accumulation_steps=gradient_accumulation_steps, learning_rate=learning_rate, lr_scheduler=lr_scheduler, lr_warmup_steps=lr_warmup_steps, num_class_images=num_class_images, sample_batch_size=sample_batch_size, max_train_steps=max_train_steps, save_interval=save_interval, save_sample_prompt=None, log_interval=log_interval, tokenizer_name=tokenizer_name) 91 | 92 | 93 | 94 | if __name__ == "__main__": 95 | app() 96 | -------------------------------------------------------------------------------- /stablefusion/scripts/dreambooth/requirements_flax.txt: -------------------------------------------------------------------------------- 1 | transformers>=4.25.1 2 | flax 3 | optax 4 | torch 5 | torchvision 6 | ftfy 7 | tensorboard 8 | Jinja2 9 | -------------------------------------------------------------------------------- /stablefusion/scripts/gfp_gan.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import os 3 | import shutil 4 | from dataclasses import dataclass 5 | from io import BytesIO 6 | from typing import Optional 7 | import datetime 8 | import cv2 9 | import numpy as np 10 | import streamlit as st 11 | import torch 12 | from basicsr.archs.srvgg_arch import SRVGGNetCompact 13 | from gfpgan.utils import GFPGANer 14 | from loguru import logger 15 | from PIL import Image 16 | from realesrgan.utils import RealESRGANer 17 | 18 | from stablefusion import utils 19 | 20 | 21 | @dataclass 22 | class GFPGAN: 23 | device: Optional[str] = None 24 | output_path: Optional[str] = None 25 | 26 | def __str__(self) -> str: 27 | return f"GFPGAN(device={self.device}, output_path={self.output_path})" 28 | 29 | def __post_init__(self): 30 | files = { 31 | "realesr-general-x4v3.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", 32 | "v1.2": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth", 33 | "v1.3": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth", 34 | "v1.4": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth", 35 | "RestoreFormer": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth", 36 | "CodeFormer": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth", 37 | } 38 | _ = utils.cache_folder() 39 | self.model_paths = {} 40 | for file_key, file in files.items(): 41 | logger.info(f"Downloading {file_key} from {file}") 42 | basename = os.path.basename(file) 43 | output_path = os.path.join(utils.cache_folder(), basename) 44 | if os.path.exists(output_path): 45 | self.model_paths[file_key] = output_path
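                # a cached copy of this checkpoint already exists locally, so record
                # its path and skip the download handled by the lines below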
46 | continue 47 | temp_file = utils.download_file(file) 48 | shutil.move(temp_file, output_path) 49 | self.model_paths[file_key] = output_path 50 | 51 | self.model = SRVGGNetCompact( 52 | num_in_ch=3, 53 | num_out_ch=3, 54 | num_feat=64, 55 | num_conv=32, 56 | upscale=4, 57 | act_type="prelu", 58 | ) 59 | model_path = os.path.join(utils.cache_folder(), self.model_paths["realesr-general-x4v3.pth"]) 60 | half = True if torch.cuda.is_available() else False 61 | self.upsampler = RealESRGANer( 62 | scale=4, 63 | model_path=model_path, 64 | model=self.model, 65 | tile=0, 66 | tile_pad=10, 67 | pre_pad=0, 68 | half=half, 69 | ) 70 | 71 | def inference(self, img, version, scale): 72 | # taken from: https://huggingface.co/spaces/Xintao/GFPGAN/blob/main/app.py 73 | # weight /= 100 74 | if scale > 4: 75 | scale = 4 # avoid too large scale value 76 | 77 | file_bytes = np.asarray(bytearray(img.read()), dtype=np.uint8) 78 | img = cv2.imdecode(file_bytes, 1) 79 | # img = cv2.imread(img, cv2.IMREAD_UNCHANGED) 80 | if len(img.shape) == 2: # for gray inputs 81 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 82 | 83 | h, w = img.shape[0:2] 84 | if h < 300: 85 | img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4) 86 | 87 | if version == "v1.2": 88 | face_enhancer = GFPGANer( 89 | model_path=self.model_paths["v1.2"], 90 | upscale=2, 91 | arch="clean", 92 | channel_multiplier=2, 93 | bg_upsampler=self.upsampler, 94 | ) 95 | elif version == "v1.3": 96 | face_enhancer = GFPGANer( 97 | model_path=self.model_paths["v1.3"], 98 | upscale=2, 99 | arch="clean", 100 | channel_multiplier=2, 101 | bg_upsampler=self.upsampler, 102 | ) 103 | elif version == "v1.4": 104 | face_enhancer = GFPGANer( 105 | model_path=self.model_paths["v1.4"], 106 | upscale=2, 107 | arch="clean", 108 | channel_multiplier=2, 109 | bg_upsampler=self.upsampler, 110 | ) 111 | elif version == "RestoreFormer": 112 | face_enhancer = GFPGANer( 113 | model_path=self.model_paths["RestoreFormer"], 114 | upscale=2, 115 | arch="RestoreFormer", 116 | channel_multiplier=2, 117 | bg_upsampler=self.upsampler, 118 | ) 119 | # elif version == 'CodeFormer': 120 | # face_enhancer = GFPGANer( 121 | # model_path='CodeFormer.pth', upscale=2, arch='CodeFormer', channel_multiplier=2, bg_upsampler=upsampler) 122 | try: 123 | # _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True, weight=weight) 124 | _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True) 125 | except RuntimeError as error: 126 | logger.error("Error", error) 127 | 128 | try: 129 | if scale != 2: 130 | interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4 131 | h, w = img.shape[0:2] 132 | output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation) 133 | except Exception as error: 134 | logger.error("wrong scale input.", error) 135 | 136 | output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB) 137 | return output 138 | 139 | def app(self): 140 | input_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"]) 141 | if input_image is not None: 142 | st.image(input_image, width=512) 143 | with st.form(key="gfpgan"): 144 | version = st.selectbox("GFPGAN version", ["v1.2", "v1.3", "v1.4", "RestoreFormer"]) 145 | scale = st.slider("Scale", 2, 4, 4, 1) 146 | submit = st.form_submit_button("Upscale") 147 | if submit: 148 | if input_image is not None: 149 | with st.spinner("Upscaling image..."): 150 | output_img = self.inference(input_image, 
version, scale) 151 | st.image(output_img, width=512) 152 | # add image download button 153 | output_img = Image.fromarray(output_img) 154 | buffered = BytesIO() 155 | output_img.save(buffered, format="PNG") 156 | img_str = base64.b64encode(buffered.getvalue()).decode() 157 | now = datetime.datetime.now() 158 | formatted_date_time = now.strftime("%Y-%m-%d_%H_%M_%S") 159 | href = f'<a href="data:file/png;base64,{img_str}" download="{formatted_date_time}.png">Download Image</a>' 160 | st.markdown(href, unsafe_allow_html=True) 161 | -------------------------------------------------------------------------------- /stablefusion/scripts/image_info.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import streamlit as st 4 | from PIL import Image 5 | 6 | 7 | @dataclass 8 | class ImageInfo: 9 | def app(self): 10 | # upload image 11 | uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"]) 12 | if uploaded_file is not None: 13 | # read image using pil 14 | pil_image = Image.open(uploaded_file) 15 | st.image(uploaded_file, use_column_width=True) 16 | image_info = pil_image.info 17 | # display image info 18 | st.write(image_info) 19 | -------------------------------------------------------------------------------- /stablefusion/scripts/img2img.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import json 3 | from dataclasses import dataclass 4 | from io import BytesIO 5 | from typing import Optional, Union 6 | import random 7 | import requests 8 | import streamlit as st 9 | import torch 10 | from diffusers import ( 11 | AltDiffusionImg2ImgPipeline, 12 | AltDiffusionPipeline, 13 | DiffusionPipeline, 14 | StableDiffusionImg2ImgPipeline, 15 | StableDiffusionPipeline, 16 | ) 17 | from loguru import logger 18 | from PIL import Image 19 | from PIL.PngImagePlugin import PngInfo 20 | 21 | from stablefusion import utils 22 | 23 | 24 | @dataclass 25 | class Img2Img: 26 | model: Optional[str] = None 27 | device: Optional[str] = None 28 | output_path: Optional[str] = None 29 | text2img_model: Optional[Union[StableDiffusionPipeline, AltDiffusionPipeline]] = None 30 | 31 | def __str__(self) -> str: 32 | return f"Img2Img(model={self.model}, device={self.device}, output_path={self.output_path})" 33 | 34 | def __post_init__(self): 35 | if self.model is not None: 36 | self.text2img_model = DiffusionPipeline.from_pretrained( 37 | self.model, 38 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, 39 | ) 40 | components = self.text2img_model.components 41 | print(components) 42 | if isinstance(self.text2img_model, StableDiffusionPipeline): 43 | self.pipeline = StableDiffusionImg2ImgPipeline(**components) 44 | elif isinstance(self.text2img_model, AltDiffusionPipeline): 45 | self.pipeline = AltDiffusionImg2ImgPipeline(**components) 46 | else: 47 | raise ValueError("Model type not supported") 48 | 49 | self.pipeline.to(self.device) 50 | self.pipeline.safety_checker = utils.no_safety_checker 51 | self._compatible_schedulers = self.pipeline.scheduler.compatibles 52 | self.scheduler_config = self.pipeline.scheduler.config 53 | self.compatible_schedulers = {scheduler.__name__: scheduler for scheduler in self._compatible_schedulers} 54 | 55 | if self.device == "mps": 56 | self.pipeline.enable_attention_slicing() 57 | # warmup 58 | url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" 59 | response = requests.get(url) 60 | init_image = Image.open(BytesIO(response.content)).convert("RGB") 61 | init_image.thumbnail((768, 768)) 62 | prompt = "A fantasy landscape, trending on artstation" 63 | _ = self.pipeline( 64 | prompt=prompt, 65 | image=init_image, 66 | strength=0.75, 67 | guidance_scale=7.5, 68 | num_inference_steps=2, 69 | ) 70 | 71 | def _set_scheduler(self, scheduler_name): 72 | scheduler =
self.compatible_schedulers[scheduler_name].from_config(self.scheduler_config) 73 | self.pipeline.scheduler = scheduler 74 | 75 | def generate_image( 76 | self, prompt, image, strength, negative_prompt, scheduler, num_images, guidance_scale, steps, seed 77 | ): 78 | self._set_scheduler(scheduler) 79 | logger.info(self.pipeline.scheduler) 80 | if self.device == "mps": 81 | generator = torch.manual_seed(seed) 82 | num_images = 1 83 | else: 84 | generator = torch.Generator(device=self.device).manual_seed(seed) 85 | num_images = int(num_images) 86 | output_images = self.pipeline( 87 | prompt=prompt, 88 | image=image, 89 | strength=strength, 90 | negative_prompt=negative_prompt, 91 | num_inference_steps=steps, 92 | guidance_scale=guidance_scale, 93 | num_images_per_prompt=num_images, 94 | generator=generator, 95 | ).images 96 | torch.cuda.empty_cache() 97 | gc.collect() 98 | metadata = { 99 | "prompt": prompt, 100 | "negative_prompt": negative_prompt, 101 | "scheduler": scheduler, 102 | "num_images": num_images, 103 | "guidance_scale": guidance_scale, 104 | "steps": steps, 105 | "seed": seed, 106 | } 107 | metadata = json.dumps(metadata) 108 | _metadata = PngInfo() 109 | _metadata.add_text("img2img", metadata) 110 | 111 | utils.save_images( 112 | images=output_images, 113 | module="img2img", 114 | metadata=metadata, 115 | output_path=self.output_path, 116 | ) 117 | return output_images, _metadata 118 | 119 | def app(self): 120 | available_schedulers = list(self.compatible_schedulers.keys()) 121 | # if EulerAncestralDiscreteScheduler is available in available_schedulers, move it to the first position 122 | if "EulerAncestralDiscreteScheduler" in available_schedulers: 123 | available_schedulers.insert( 124 | 0, available_schedulers.pop(available_schedulers.index("EulerAncestralDiscreteScheduler")) 125 | ) 126 | 127 | input_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"]) 128 | if input_image is not None: 129 | input_image = Image.open(input_image) 130 | st.image(input_image, use_column_width=True) 131 | 132 | # with st.form(key="img2img"): 133 | col1, col2 = st.columns(2) 134 | with col1: 135 | prompt = st.text_area("Prompt", "", help="Prompt to guide image generation") 136 | with col2: 137 | negative_prompt = st.text_area("Negative Prompt", "", help="The prompt not to guide image generation. Write things that you dont want to see in the image.") 138 | 139 | scheduler = st.sidebar.selectbox("Scheduler", available_schedulers, index=0, help="Scheduler(Sampler) to use for generation") 140 | guidance_scale = st.sidebar.slider("Guidance scale", 1.0, 40.0, 7.5, 0.5, help="Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.") 141 | strength = st.sidebar.slider("Denoise Strength", 0.0, 1.0, 0.5, 0.05) 142 | num_images = st.sidebar.slider("Number of images per prompt", 1, 30, 1, 1, help="Number of images you want to generate. More images requires more time and uses more GPU memory.") 143 | steps = st.sidebar.slider("Steps", 1, 150, 50, 1, help="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.") 144 | seed_choice = st.sidebar.selectbox("Do you want a random seed", options=["Yes", "No"]) 145 | if seed_choice == "Yes": 146 | seed = random.randint(0, 9999999) 147 | else: 148 | seed = st.sidebar.number_input( 149 | "Seed", 150 | value=42, 151 | step=1, 152 | help="Random seed. 
Change for different results using same parameters.", 153 | ) 154 | # seed = st.sidebar.number_input("Seed", 1, 999999, 1, 1) 155 | sub_col, download_col = st.columns(2) 156 | with sub_col: 157 | submit = st.button("Generate") 158 | 159 | if submit: 160 | with st.spinner("Generating images..."): 161 | output_images, metadata = self.generate_image( 162 | prompt=prompt, 163 | image=input_image, 164 | negative_prompt=negative_prompt, 165 | scheduler=scheduler, 166 | num_images=num_images, 167 | guidance_scale=guidance_scale, 168 | steps=steps, 169 | seed=seed, 170 | strength=strength, 171 | ) 172 | 173 | utils.display_and_download_images(output_images, metadata, download_col) 174 | -------------------------------------------------------------------------------- /stablefusion/scripts/interrogator.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | 5 | import streamlit as st 6 | from stablefusion.scripts.clip_interrogator import Config, Interrogator 7 | from huggingface_hub import hf_hub_download 8 | from loguru import logger 9 | from PIL import Image 10 | 11 | from stablefusion import utils 12 | 13 | 14 | @dataclass 15 | class ImageInterrogator: 16 | device: Optional[str] = None 17 | output_path: Optional[str] = None 18 | 19 | def inference(self, model, image, mode): 20 | preprocess_files = [ 21 | "ViT-H-14_laion2b_s32b_b79k_artists.pkl", 22 | "ViT-H-14_laion2b_s32b_b79k_flavors.pkl", 23 | "ViT-H-14_laion2b_s32b_b79k_mediums.pkl", 24 | "ViT-H-14_laion2b_s32b_b79k_movements.pkl", 25 | "ViT-H-14_laion2b_s32b_b79k_trendings.pkl", 26 | "ViT-L-14_openai_artists.pkl", 27 | "ViT-L-14_openai_flavors.pkl", 28 | "ViT-L-14_openai_mediums.pkl", 29 | "ViT-L-14_openai_movements.pkl", 30 | "ViT-L-14_openai_trendings.pkl", 31 | ] 32 | 33 | logger.info("Downloading preprocessed cache files...") 34 | for file in preprocess_files: 35 | path = hf_hub_download(repo_id="pharma/ci-preprocess", filename=file, cache_dir=utils.cache_folder()) 36 | cache_path = os.path.dirname(path) 37 | 38 | config = Config(cache_path=cache_path, clip_model_path=utils.cache_folder(), clip_model_name=model) 39 | pipeline = Interrogator(config) 40 | 41 | pipeline.config.blip_num_beams = 64 42 | pipeline.config.chunk_size = 2048 43 | pipeline.config.flavor_intermediate_count = 2048 if model == "ViT-L-14/openai" else 1024 44 | 45 | if mode == "best": 46 | prompt = pipeline.interrogate(image) 47 | elif mode == "classic": 48 | prompt = pipeline.interrogate_classic(image) 49 | else: 50 | prompt = pipeline.interrogate_fast(image) 51 | return prompt 52 | 53 | def app(self): 54 | # upload image 55 | input_image = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"]) 56 | with st.form(key="image_interrogator"): 57 | clip_model = st.selectbox("CLIP Model", ["ViT-L (Best for SD1.X)", "ViT-H (Best for SD2.X)"]) 58 | mode = st.selectbox("Mode", ["Best", "Classic"]) 59 | submit = st.form_submit_button("Interrogate") 60 | if input_image is not None: 61 | # read image using pil 62 | pil_image = Image.open(input_image).convert("RGB") 63 | if submit: 64 | with st.spinner("Interrogating image..."): 65 | if clip_model == "ViT-L (Best for SD1.X)": 66 | model = "ViT-L-14/openai" 67 | else: 68 | model = "ViT-H-14/laion2b_s32b_b79k" 69 | prompt = self.inference(model, pil_image, mode.lower()) 70 | col1, col2 = st.columns(2) 71 | with col1: 72 | st.image(input_image, use_column_width=True) 73 | with col2: 74 | st.write(prompt) 75 
| -------------------------------------------------------------------------------- /stablefusion/scripts/model_adding.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import streamlit as st 3 | import ast 4 | import os 5 | from stablefusion.utils import base_path 6 | 7 | base_path = base_path() 8 | 9 | @dataclass 10 | class ModelAdding: 11 | 12 | def read_model_list(self): 13 | 14 | try: 15 | with open('{}/model_list.txt'.format(base_path), 'r') as f: 16 | contents = f.read() 17 | except: 18 | with open('stablefusion/model_list.txt', 'r') as f: 19 | contents = f.read() 20 | model_list = ast.literal_eval(contents) 21 | 22 | return model_list 23 | 24 | def write_model_list(self, model_list): 25 | 26 | try: 27 | with open('{}/model_list.txt'.format(base_path), 'w') as f: 28 | f.write(model_list) 29 | except: 30 | with open('stablefusion/model_list.txt', 'w') as f: 31 | f.write(model_list) 32 | 33 | def check_models(self, model_name): 34 | 35 | model_list = self.read_model_list() 36 | 37 | if model_name in model_list: 38 | st.warning("{} already present in the list".format(model_name)) 39 | 40 | else: 41 | model_list.append(model_name) 42 | self.write_model_list(model_list=str(model_list)) 43 | st.success("Successfully added {} to your list".format(model_name)) 44 | 45 | 46 | def app(self): 47 | # model name input 48 | model_name = st.text_input(label="Enter The Model Name", value="runwayml/stable-diffusion-v1-5") 49 | 50 | if st.button("Apply Changes"): 51 | self.check_models(model_name=model_name) 52 | 53 | -------------------------------------------------------------------------------- /stablefusion/scripts/model_removing.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import streamlit as st 3 | import ast 4 | import os 5 | from stablefusion.utils import base_path 6 | 7 | 8 | base_path = base_path() 9 | 10 | @dataclass 11 | class ModelRemoving: 12 | 13 | def read_model_list(self): 14 | 15 | try: 16 | with open('{}/model_list.txt'.format(base_path), 'r') as f: 17 | contents = f.read() 18 | except: 19 | with open('stablefusion/model_list.txt', 'r') as f: 20 | contents = f.read() 21 | model_list = ast.literal_eval(contents) 22 | 23 | return model_list 24 | 25 | def write_model_list(self, model_list): 26 | 27 | try: 28 | with open('{}/model_list.txt'.format(base_path), 'w') as f: 29 | f.write(model_list) 30 | except: 31 | with open('stablefusion/model_list.txt', 'w') as f: 32 | f.write(model_list) 33 | 34 | def check_models(self, model_name): 35 | 36 | model_list = self.read_model_list() 37 | 38 | if model_name not in model_list: 39 | st.warning("{} not present in the list".format(model_name)) 40 | 41 | else: 42 | model_list.remove(model_name) 43 | self.write_model_list(model_list=str(model_list)) 44 | st.success("Successfully removed {} from your list".format(model_name)) 45 | 46 | 47 | def app(self): 48 | # model selection 49 | model_name = st.selectbox(label="Select The Model Name", options=self.read_model_list()) 50 | 51 | if st.button("Apply Changes"): 52 | self.check_models(model_name=model_name) 53 | 54 | -------------------------------------------------------------------------------- /stablefusion/scripts/pose_editor.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import json 3 | from dataclasses import dataclass 4 | from io import BytesIO 5 | from typing import Optional, Union 6 | import
random 7 | import requests 8 | import streamlit as st 9 | import torch 10 | from diffusers import StableDiffusionControlNetPipeline, ControlNetModel 11 | from loguru import logger 12 | from PIL import Image 13 | from PIL.PngImagePlugin import PngInfo 14 | from controlnet_aux import OpenposeDetector, HEDdetector, MLSDdetector 15 | from stablefusion import utils 16 | import cv2 17 | from PIL import Image 18 | import numpy as np 19 | from streamlit_drawable_canvas import st_canvas 20 | from transformers import pipeline 21 | from stablefusion.scripts.pose_html import html_component 22 | 23 | 24 | @dataclass 25 | class OpenPoseEditor: 26 | model: Optional[str] = None 27 | device: Optional[str] = None 28 | output_path: Optional[str] = None 29 | 30 | def __str__(self) -> str: 31 | return f"BaseModel(model={self.model}, device={self.device}, output_path={self.output_path})" 32 | 33 | 34 | def __post_init__(self): 35 | self.controlnet = ControlNetModel.from_pretrained( 36 | "lllyasviel/sd-controlnet-openpose", 37 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, 38 | use_auth_token=utils.use_auth_token(), 39 | ) 40 | self.pipeline = StableDiffusionControlNetPipeline.from_pretrained( 41 | self.model, 42 | controlnet=self.controlnet, 43 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, 44 | use_auth_token=utils.use_auth_token(), 45 | ) 46 | 47 | self.pipeline.to(self.device) 48 | self.pipeline.safety_checker = utils.no_safety_checker 49 | self._compatible_schedulers = self.pipeline.scheduler.compatibles 50 | self.scheduler_config = self.pipeline.scheduler.config 51 | self.compatible_schedulers = {scheduler.__name__: scheduler for scheduler in self._compatible_schedulers} 52 | 53 | if self.device == "mps": 54 | self.pipeline.enable_attention_slicing() 55 | # warmup 56 | url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" 57 | response = requests.get(url) 58 | init_image = Image.open(BytesIO(response.content)).convert("RGB") 59 | init_image.thumbnail((768, 768)) 60 | prompt = "A fantasy landscape, trending on artstation" 61 | _ = self.pipeline( 62 | prompt=prompt, 63 | image=init_image, 64 | strength=0.75, 65 | guidance_scale=7.5, 66 | num_inference_steps=2, 67 | ) 68 | 69 | def _set_scheduler(self, scheduler_name): 70 | scheduler = self.compatible_schedulers[scheduler_name].from_config(self.scheduler_config) 71 | self.pipeline.scheduler = scheduler 72 | 73 | def generate_image( 74 | self, prompt, image, negative_prompt, scheduler, num_images, guidance_scale, steps, seed, height, width 75 | ): 76 | self._set_scheduler(scheduler) 77 | logger.info(self.pipeline.scheduler) 78 | if self.device == "mps": 79 | generator = torch.manual_seed(seed) 80 | num_images = 1 81 | else: 82 | generator = torch.Generator(device=self.device).manual_seed(seed) 83 | num_images = int(num_images) 84 | output_images = self.pipeline( 85 | prompt=prompt, 86 | image=image, 87 | negative_prompt=negative_prompt, 88 | num_inference_steps=steps, 89 | guidance_scale=guidance_scale, 90 | num_images_per_prompt=num_images, 91 | generator=generator, 92 | height=height, 93 | width=width 94 | ).images 95 | torch.cuda.empty_cache() 96 | gc.collect() 97 | metadata = { 98 | "prompt": prompt, 99 | "negative_prompt": negative_prompt, 100 | "scheduler": scheduler, 101 | "num_images": num_images, 102 | "guidance_scale": guidance_scale, 103 | "steps": steps, 104 | "seed": seed, 105 | } 106 | metadata = json.dumps(metadata) 107 
| _metadata = PngInfo() 108 | _metadata.add_text("img2img", metadata) 109 | 110 | utils.save_images( 111 | images=output_images, 112 | module="controlnet", 113 | metadata=metadata, 114 | output_path=self.output_path, 115 | ) 116 | return output_images, _metadata 117 | 118 | def app(self): 119 | available_schedulers = list(self.compatible_schedulers.keys()) 120 | # if EulerAncestralDiscreteScheduler is available in available_schedulers, move it to the first position 121 | if "EulerAncestralDiscreteScheduler" in available_schedulers: 122 | available_schedulers.insert( 123 | 0, available_schedulers.pop(available_schedulers.index("EulerAncestralDiscreteScheduler")) 124 | ) 125 | 126 | html_component() 127 | 128 | input_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"]) 129 | 130 | if input_image is not None: 131 | 132 | input_image = Image.open(input_image).convert("RGB") 133 | 134 | st.image(input_image, use_column_width=True) 135 | 136 | 137 | # with st.form(key="img2img"): 138 | col1, col2 = st.columns(2) 139 | with col1: 140 | prompt = st.text_area("Prompt", "", help="Prompt to guide image generation") 141 | with col2: 142 | negative_prompt = st.text_area("Negative Prompt", "", help="The prompt not to guide image generation. Write things that you dont want to see in the image.") 143 | 144 | scheduler = st.sidebar.selectbox("Scheduler", available_schedulers, index=0, help="Scheduler(Sampler) to use for generation") 145 | image_height = st.sidebar.slider("Image height", 128, 1024, 512, 128, help="The height in pixels of the generated image.") 146 | image_width = st.sidebar.slider("Image width", 128, 1024, 512, 128, help="The width in pixels of the generated image.") 147 | guidance_scale = st.sidebar.slider("Guidance scale", 1.0, 40.0, 7.5, 0.5, help="Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.") 148 | num_images = st.sidebar.slider("Number of images per prompt", 1, 30, 1, 1, help="Number of images you want to generate. More images requires more time and uses more GPU memory.") 149 | steps = st.sidebar.slider("Steps", 1, 150, 50, 1, help="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.") 150 | seed_choice = st.sidebar.selectbox("Do you want a random seed", options=["Yes", "No"]) 151 | if seed_choice == "Yes": 152 | seed = random.randint(0, 9999999) 153 | else: 154 | seed = st.sidebar.number_input( 155 | "Seed", 156 | value=42, 157 | step=1, 158 | help="Random seed. Change for different results using same parameters.", 159 | ) 160 | sub_col, download_col = st.columns(2) 161 | with sub_col: 162 | submit = st.button("Generate") 163 | 164 | if submit: 165 | with st.spinner("Generating images..."): 166 | output_images, metadata = self.generate_image( 167 | prompt=prompt, 168 | image=input_image, 169 | negative_prompt=negative_prompt, 170 | scheduler=scheduler, 171 | num_images=num_images, 172 | guidance_scale=guidance_scale, 173 | steps=steps, 174 | seed=seed, 175 | height=image_height, 176 | width=image_width 177 | ) 178 | 179 | utils.display_and_download_images(output_images, metadata, download_col) 180 | -------------------------------------------------------------------------------- /stablefusion/scripts/safetensors_to_diffusion.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Conversion script for the LDM checkpoints. """ 16 | import ast 17 | import argparse 18 | import os 19 | import streamlit as st 20 | from diffusers.pipelines.stable_diffusion.convert_from_ckpt import load_pipeline_from_original_stable_diffusion_ckpt 21 | import requests 22 | from stablefusion.utils import base_path 23 | 24 | current_path = base_path() 25 | 26 | def convert_to_diffusers_model(checkpoint_path, original_config_file, image_size, prediction_type, pipeline_type, extract_ema, scheduler_type, num_in_channels, upcast_attention, from_safetensors, device, stable_unclip, stable_unclip_prior, clip_stats_path, controlnet, checkpoint_name, to_safetensors): 27 | 28 | pipe = load_pipeline_from_original_stable_diffusion_ckpt( 29 | checkpoint_path=checkpoint_path, 30 | original_config_file=original_config_file, 31 | image_size=image_size, 32 | prediction_type=prediction_type, 33 | model_type=pipeline_type, 34 | extract_ema=extract_ema, 35 | scheduler_type=scheduler_type, 36 | num_in_channels=num_in_channels, 37 | upcast_attention=upcast_attention, 38 | from_safetensors=from_safetensors, 39 | device=device, 40 | stable_unclip=stable_unclip, 41 | stable_unclip_prior=stable_unclip_prior, 42 | clip_stats_path=clip_stats_path, 43 | controlnet=controlnet, 44 | ) 45 | 46 | dump_path = "{}/models/diffusion_models/{}".format(current_path, checkpoint_name.split(".")[0]) 47 | 48 | if controlnet: 49 | # only save the controlnet model 50 | pipe.controlnet.save_pretrained(dump_path, safe_serialization=to_safetensors) 51 | else: 52 | pipe.save_pretrained(dump_path, safe_serialization=to_safetensors) 53 | 54 | st.success("Your Model Has Been Created!") 55 | append_created_model(model_name=str(checkpoint_name).split(".")[0]) 56 | 57 | 58 | def download_ckpt_model(checkpoint_link, checkpoint_name): 59 | 60 | try: 61 | print("Started Downloading the model...") 62 | response = requests.get(checkpoint_link, stream=True) 63 | 64 | with open("{}/models/safetensors_models/{}".format(current_path, checkpoint_name), "wb") as f: 65 | f.write(response.content) 66 | 67 | print("Model Downloading Completed!") 68 | 69 | except: 70 | print("Started Downloading the model...") 71 | response = requests.get(checkpoint_link, stream=True) 72 | 73 | with open("stablefusion/models/safetensors_models/{}".format(checkpoint_name), "wb") as f: 74 | f.write(response.content) 75 | 76 | print("Model Downloading Completed!") 77 | 78 | def read_model_file(): 79 | 80 | try: 81 | with open('{}/model_list.txt'.format(current_path), 'r') as f: 82 | contents = f.read() 83 | except: 84 | with open('stablefusion/model_list.txt', 'r') as f: 85 | contents = f.read() 86 | 87 | model_list = ast.literal_eval(contents) 88 | 89 | return model_list 90 | 91 | 92 | def write_model_list(model_list): 93 | 94 | try: 95 | with open('{}/model_list.txt'.format(current_path), 'w') as f: 96 | f.write(model_list) 97 | except: 98 | with open('stablefusion/model_list.txt', 
'w') as f: 99 | f.write(model_list) 100 | 101 | 102 | def append_created_model(model_name): 103 | model_list = read_model_file() 104 | try: 105 | appending_list = os.listdir("{}/models/diffusion_models".format(current_path)) 106 | except OSError: 107 | appending_list = os.listdir("stablefusion/models/diffusion_models") 108 | 109 | for working_item in appending_list: 110 | if str(working_item).split(".")[-1] == "txt": 111 | pass 112 | else: 113 | model_path = "{}/models/diffusion_models/{}".format(current_path, model_name) 114 | # model_list stores full diffusers paths, so compare the full path rather than the bare model name 115 | if model_path not in model_list: 116 | model_list.append(model_path) 117 | 118 | 119 | write_model_list(model_list=str(model_list)) 120 | st.success("Model added to your list. You can now use it from the Home page.") 121 | 122 | 123 | def convert_safetensor_to_diffusers(original_config_file, image_size, prediction_type, pipeline_type, extract_ema, scheduler_type, num_in_channels, upcast_attention, from_safetensors, device, stable_unclip, stable_unclip_prior, clip_stats_path, controlnet, to_safetensors, checkpoint_name, checkpoint_link, overwrite): 124 | 125 | checkpoint_path = "{}/models/safetensors_models/{}".format(current_path, checkpoint_name) 126 | 127 | # str.format cannot fail here, so no try/except fallback is needed 128 | custom_diffusion_model = "{}/models/diffusion_models/{}".format(current_path, str(checkpoint_name).split(".")[0]) 129 | 130 | 131 | 132 | if overwrite is False: 133 | if not os.path.isfile(checkpoint_path): 134 | 135 | download_ckpt_model(checkpoint_link=checkpoint_link, checkpoint_name=checkpoint_name) 136 | 137 | else: 138 | st.warning("Using {}".format(checkpoint_path)) 139 | else: 140 | 141 | download_ckpt_model(checkpoint_link=checkpoint_link, checkpoint_name=checkpoint_name) 142 | 143 | 144 | if overwrite is False: 145 | 146 | if not os.path.exists(custom_diffusion_model): 147 | 148 | convert_to_diffusers_model(checkpoint_name=checkpoint_name, original_config_file=original_config_file, image_size=image_size, prediction_type=prediction_type, extract_ema=extract_ema, scheduler_type=scheduler_type, num_in_channels=num_in_channels, upcast_attention=upcast_attention, from_safetensors=from_safetensors, device=device, stable_unclip=stable_unclip, stable_unclip_prior=stable_unclip_prior, clip_stats_path=clip_stats_path, controlnet=controlnet, to_safetensors=to_safetensors, checkpoint_path=checkpoint_path, pipeline_type=pipeline_type) 149 | 150 | else: 151 | st.warning("Model {} is already present".format(custom_diffusion_model)) 152 | 153 | else: 154 | convert_to_diffusers_model(checkpoint_name=checkpoint_name, original_config_file=original_config_file, image_size=image_size, prediction_type=prediction_type, extract_ema=extract_ema, scheduler_type=scheduler_type, num_in_channels=num_in_channels, upcast_attention=upcast_attention, from_safetensors=from_safetensors, device=device, stable_unclip=stable_unclip, stable_unclip_prior=stable_unclip_prior, clip_stats_path=clip_stats_path, controlnet=controlnet, to_safetensors=to_safetensors, checkpoint_path=checkpoint_path, pipeline_type=pipeline_type) -------------------------------------------------------------------------------- /stablefusion/scripts/text2img.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import json 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | import random
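# text2img.py wraps a diffusers DiffusionPipeline in a small dataclass:
# __post_init__ loads the pipeline and collects its compatible schedulers,
# generate_image() runs inference and stamps the parameters into the PNG
# metadata, and app() renders the Streamlit controls that drive it.
6 | import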
streamlit as st 7 | import torch 8 | from diffusers import DiffusionPipeline 9 | from loguru import logger 10 | from PIL.PngImagePlugin import PngInfo 11 | 12 | from stablefusion import utils 13 | 14 | 15 | @dataclass 16 | class Text2Image: 17 | device: Optional[str] = None 18 | model: Optional[str] = None 19 | output_path: Optional[str] = None 20 | 21 | def __str__(self) -> str: 22 | return f"Text2Image(model={self.model}, device={self.device}, output_path={self.output_path})" 23 | 24 | def __post_init__(self): 25 | self.pipeline = DiffusionPipeline.from_pretrained( 26 | self.model, 27 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, 28 | ) 29 | self.pipeline.to(self.device) 30 | self.pipeline.safety_checker = utils.no_safety_checker 31 | self._compatible_schedulers = self.pipeline.scheduler.compatibles 32 | self.scheduler_config = self.pipeline.scheduler.config 33 | self.compatible_schedulers = {scheduler.__name__: scheduler for scheduler in self._compatible_schedulers} 34 | 35 | if self.device == "mps": 36 | self.pipeline.enable_attention_slicing() 37 | # warmup 38 | prompt = "a photo of an astronaut riding a horse on mars" 39 | _ = self.pipeline(prompt, num_inference_steps=2) 40 | 41 | def _set_scheduler(self, scheduler_name): 42 | scheduler = self.compatible_schedulers[scheduler_name].from_config(self.scheduler_config) 43 | self.pipeline.scheduler = scheduler 44 | 45 | def generate_image(self, prompt, negative_prompt, scheduler, image_size, num_images, guidance_scale, steps, seed): 46 | self._set_scheduler(scheduler) 47 | logger.info(self.pipeline.scheduler) 48 | if self.device == "mps": 49 | generator = torch.manual_seed(seed) 50 | num_images = 1 51 | else: 52 | generator = torch.Generator(device=self.device).manual_seed(seed) 53 | num_images = int(num_images) 54 | output_images = self.pipeline( 55 | prompt, 56 | negative_prompt=negative_prompt, 57 | width=image_size[1], 58 | height=image_size[0], 59 | num_inference_steps=steps, 60 | guidance_scale=guidance_scale, 61 | num_images_per_prompt=num_images, 62 | generator=generator, 63 | ).images 64 | torch.cuda.empty_cache() 65 | gc.collect() 66 | metadata = { 67 | "prompt": prompt, 68 | "negative_prompt": negative_prompt, 69 | "scheduler": scheduler, 70 | "image_size": image_size, 71 | "num_images": num_images, 72 | "guidance_scale": guidance_scale, 73 | "steps": steps, 74 | "seed": seed, 75 | } 76 | metadata = json.dumps(metadata) 77 | _metadata = PngInfo() 78 | _metadata.add_text("text2img", metadata) 79 | 80 | utils.save_images( 81 | images=output_images, 82 | module="text2img", 83 | metadata=metadata, 84 | output_path=self.output_path, 85 | ) 86 | 87 | return output_images, _metadata 88 | 89 | def app(self): 90 | available_schedulers = list(self.compatible_schedulers.keys()) 91 | if "EulerAncestralDiscreteScheduler" in available_schedulers: 92 | available_schedulers.insert( 93 | 0, available_schedulers.pop(available_schedulers.index("EulerAncestralDiscreteScheduler")) 94 | ) 95 | # with st.form(key="text2img"): 96 | col1, col2 = st.columns(2) 97 | with col1: 98 | prompt = st.text_area("Prompt", "Blue elephant", help="Prompt to guide image generation") 99 | with col2: 100 | negative_prompt = st.text_area("Negative Prompt", "", help="The prompt not to guide image generation. 
Write things that you don't want to see in the image.") 101 | # sidebar options 102 | scheduler = st.sidebar.selectbox("Scheduler", available_schedulers, index=0, help="Scheduler (sampler) to use for generation") 103 | image_height = st.sidebar.slider("Image height", 128, 1024, 512, 128, help="The height in pixels of the generated image.") 104 | image_width = st.sidebar.slider("Image width", 128, 1024, 512, 128, help="The width in pixels of the generated image.") 105 | guidance_scale = st.sidebar.slider("Guidance scale", 1.0, 40.0, 7.5, 0.5, help="Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.") 106 | num_images = st.sidebar.slider("Number of images per prompt", 1, 30, 1, 1, help="Number of images you want to generate. More images require more time and use more GPU memory.") 107 | steps = st.sidebar.slider("Steps", 1, 150, 50, 1, help="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.") 108 | 109 | seed_choice = st.sidebar.selectbox("Do you want a random seed", options=["Yes", "No"]) 110 | if seed_choice == "Yes": 111 | seed = random.randint(0, 9999999) 112 | else: 113 | seed = st.sidebar.number_input( 114 | "Seed", 115 | value=42, 116 | step=1, 117 | help="Random seed. Change for different results using same parameters.", 118 | ) 119 | 120 | sub_col, download_col = st.columns(2) 121 | with sub_col: 122 | submit = st.button("Generate") 123 | 124 | if submit: 125 | with st.spinner("Generating images..."): 126 | output_images, metadata = self.generate_image( 127 | prompt=prompt, 128 | negative_prompt=negative_prompt, 129 | scheduler=scheduler, 130 | image_size=(image_height, image_width), 131 | num_images=num_images, 132 | guidance_scale=guidance_scale, 133 | steps=steps, 134 | seed=seed, 135 | ) 136 | 137 | utils.display_and_download_images(output_images, metadata, download_col) 138 | -------------------------------------------------------------------------------- /stablefusion/scripts/textual_inversion.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import json 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | import random 6 | import streamlit as st 7 | import torch 8 | from diffusers import DiffusionPipeline 9 | from loguru import logger 10 | from PIL.PngImagePlugin import PngInfo 11 | 12 | from stablefusion import utils 13 | 14 | 15 | def load_embed(learned_embeds_path, text_encoder, tokenizer, token=None): 16 | loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu") 17 | if len(loaded_learned_embeds) > 2: 18 | embeds = loaded_learned_embeds["string_to_param"]["*"][-1, :] 19 | else: 20 | # separate token and the embeds 21 | trained_token = list(loaded_learned_embeds.keys())[0] 22 | embeds = loaded_learned_embeds[trained_token] 23 | 24 | # add the token in tokenizer 25 | token = token if token is not None else trained_token 26 | num_added_tokens = tokenizer.add_tokens(token) 27 | i = 1 28 | while num_added_tokens == 0: 29 | logger.warning(f"The tokenizer already contains the token {token}.") 30 | token = f"{token[:-1]}-{i}>" 31 | logger.info(f"Attempting to add the token {token}.") 32 | num_added_tokens = tokenizer.add_tokens(token) 33 | i += 1 34 | 35 | # resize the token embeddings 36 | text_encoder.resize_token_embeddings(len(tokenizer)) 37 | 38 | # get the id for the token and assign the embeds
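# This is the core of textual inversion at inference time: the tokenizer gains
# a new token, the text encoder's embedding matrix is resized to match, and the
# new row is overwritten with the learned vector so prompts can reference it.
39 | token_id =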
tokenizer.convert_tokens_to_ids(token) 40 | text_encoder.get_input_embeddings().weight.data[token_id] = embeds 41 | return token 42 | 43 | 44 | @dataclass 45 | class TextualInversion: 46 | model: str 47 | embeddings_url: str 48 | token_identifier: str 49 | device: Optional[str] = None 50 | output_path: Optional[str] = None 51 | 52 | def __str__(self) -> str: 53 | return f"TextualInversion(model={self.model}, embeddings={self.embeddings_url}, token_identifier={self.token_identifier}, device={self.device}, output_path={self.output_path})" 54 | 55 | def __post_init__(self): 56 | self.pipeline = DiffusionPipeline.from_pretrained( 57 | self.model, 58 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, 59 | ) 60 | self.pipeline.to(self.device) 61 | self.pipeline.safety_checker = utils.no_safety_checker 62 | self._compatible_schedulers = self.pipeline.scheduler.compatibles 63 | self.scheduler_config = self.pipeline.scheduler.config 64 | self.compatible_schedulers = {scheduler.__name__: scheduler for scheduler in self._compatible_schedulers} 65 | 66 | # download the embeddings 67 | self.embeddings_path = utils.download_file(self.embeddings_url) 68 | load_embed( 69 | learned_embeds_path=self.embeddings_path, 70 | text_encoder=self.pipeline.text_encoder, 71 | tokenizer=self.pipeline.tokenizer, 72 | token=self.token_identifier, 73 | ) 74 | 75 | if self.device == "mps": 76 | self.pipeline.enable_attention_slicing() 77 | # warmup 78 | prompt = "a photo of an astronaut riding a horse on mars" 79 | _ = self.pipeline(prompt, num_inference_steps=1) 80 | 81 | def _set_scheduler(self, scheduler_name): 82 | scheduler = self.compatible_schedulers[scheduler_name].from_config(self.scheduler_config) 83 | self.pipeline.scheduler = scheduler 84 | 85 | def generate_image(self, prompt, negative_prompt, scheduler, image_size, num_images, guidance_scale, steps, seed): 86 | self._set_scheduler(scheduler) 87 | logger.info(self.pipeline.scheduler) 88 | if self.device == "mps": 89 | generator = torch.manual_seed(seed) 90 | num_images = 1 91 | else: 92 | generator = torch.Generator(device=self.device).manual_seed(seed) 93 | num_images = int(num_images) 94 | output_images = self.pipeline( 95 | prompt, 96 | negative_prompt=negative_prompt, 97 | width=image_size[1], 98 | height=image_size[0], 99 | num_inference_steps=steps, 100 | guidance_scale=guidance_scale, 101 | num_images_per_prompt=num_images, 102 | generator=generator, 103 | ).images 104 | torch.cuda.empty_cache() 105 | gc.collect() 106 | metadata = { 107 | "prompt": prompt, 108 | "negative_prompt": negative_prompt, 109 | "scheduler": scheduler, 110 | "image_size": image_size, 111 | "num_images": num_images, 112 | "guidance_scale": guidance_scale, 113 | "steps": steps, 114 | "seed": seed, 115 | } 116 | metadata = json.dumps(metadata) 117 | _metadata = PngInfo() 118 | _metadata.add_text("textual_inversion", metadata) 119 | 120 | utils.save_images( 121 | images=output_images, 122 | module="textual_inversion", 123 | metadata=metadata, 124 | output_path=self.output_path, 125 | ) 126 | 127 | return output_images, _metadata 128 | 129 | def app(self): 130 | available_schedulers = list(self.compatible_schedulers.keys()) 131 | if "EulerAncestralDiscreteScheduler" in available_schedulers: 132 | available_schedulers.insert( 133 | 0, available_schedulers.pop(available_schedulers.index("EulerAncestralDiscreteScheduler")) 134 | ) 135 | col1, col2 = st.columns(2) 136 | with col1: 137 | prompt = st.text_area("Prompt", "Blue elephant") 138 | with col2: 139 | 
negative_prompt = st.text_area("Negative Prompt", "") 140 | # sidebar options 141 | scheduler = st.sidebar.selectbox("Scheduler", available_schedulers, index=0) 142 | image_height = st.sidebar.slider("Image height", 128, 1024, 512, 128) 143 | image_width = st.sidebar.slider("Image width", 128, 1024, 512, 128) 144 | guidance_scale = st.sidebar.slider("Guidance scale", 1.0, 40.0, 7.5, 0.5) 145 | num_images = st.sidebar.slider("Number of images per prompt", 1, 30, 1, 1) 146 | steps = st.sidebar.slider("Steps", 1, 150, 50, 1) 147 | seed_choice = st.sidebar.selectbox("Do you want a random seed", options=["Yes", "No"]) 148 | if seed_choice == "Yes": 149 | seed = random.randint(0, 9999999) 150 | else: 151 | seed = st.sidebar.number_input( 152 | "Seed", 153 | value=42, 154 | step=1, 155 | help="Random seed. Change for different results using same parameters.", 156 | ) 157 | sub_col, download_col = st.columns(2) 158 | with sub_col: 159 | submit = st.button("Generate") 160 | 161 | if submit: 162 | with st.spinner("Generating images..."): 163 | output_images, metadata = self.generate_image( 164 | prompt=prompt, 165 | negative_prompt=negative_prompt, 166 | scheduler=scheduler, 167 | image_size=(image_height, image_width), 168 | num_images=num_images, 169 | guidance_scale=guidance_scale, 170 | steps=steps, 171 | seed=seed, 172 | ) 173 | 174 | utils.display_and_download_images(output_images, metadata, download_col) 175 | -------------------------------------------------------------------------------- /stablefusion/scripts/upscaler.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import json 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | import streamlit as st 7 | import torch 8 | from diffusers import StableDiffusionUpscalePipeline 9 | from loguru import logger 10 | from PIL import Image 11 | from PIL.PngImagePlugin import PngInfo 12 | 13 | from stablefusion import utils 14 | 15 | @dataclass 16 | class Upscaler: 17 | model: str = "stabilityai/stable-diffusion-x4-upscaler" 18 | device: Optional[str] = None 19 | output_path: Optional[str] = None 20 | 21 | def __str__(self) -> str: 22 | return f"Upscaler(model={self.model}, device={self.device}, output_path={self.output_path})" 23 | 24 | def __post_init__(self): 25 | self.pipeline = StableDiffusionUpscalePipeline.from_pretrained( 26 | self.model, 27 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, 28 | ) 29 | self.pipeline.to(self.device) 30 | self.pipeline.safety_checker = utils.no_safety_checker 31 | self._compatible_schedulers = self.pipeline.scheduler.compatibles 32 | self.scheduler_config = self.pipeline.scheduler.config 33 | self.compatible_schedulers = {scheduler.__name__: scheduler for scheduler in self._compatible_schedulers} 34 | 35 | def _set_scheduler(self, scheduler_name): 36 | scheduler = self.compatible_schedulers[scheduler_name].from_config(self.scheduler_config) 37 | self.pipeline.scheduler = scheduler 38 | 39 | def generate_image( 40 | self, image, prompt, negative_prompt, guidance_scale, noise_level, num_images, eta, scheduler, steps, seed 41 | ): 42 | self._set_scheduler(scheduler) 43 | logger.info(self.pipeline.scheduler) 44 | if self.device == "mps": 45 | generator = torch.manual_seed(seed) 46 | num_images = 1 47 | else: 48 | generator = torch.Generator(device=self.device).manual_seed(seed) 49 | num_images = int(num_images) 50 | output_images = self.pipeline( 51 | image=image, 52 | prompt=prompt, 53 | 
negative_prompt=negative_prompt, 54 | noise_level=noise_level, 55 | num_inference_steps=steps, 56 | eta=eta, 57 | num_images_per_prompt=num_images, 58 | generator=generator, 59 | guidance_scale=guidance_scale, 60 | ).images 61 | 62 | torch.cuda.empty_cache() 63 | gc.collect() 64 | metadata = { 65 | "prompt": prompt, 66 | "negative_prompt": negative_prompt, 67 | "noise_level": noise_level, 68 | "num_images": num_images, 69 | "eta": eta, 70 | "scheduler": scheduler, 71 | "steps": steps, 72 | "seed": seed, 73 | } 74 | 75 | metadata = json.dumps(metadata) 76 | _metadata = PngInfo() 77 | _metadata.add_text("upscaler", metadata) 78 | 79 | utils.save_images( 80 | images=output_images, 81 | module="upscaler", 82 | metadata=metadata, 83 | output_path=self.output_path, 84 | ) 85 | 86 | return output_images, _metadata 87 | 88 | def app(self): 89 | available_schedulers = list(self.compatible_schedulers.keys()) 90 | # if EulerAncestralDiscreteScheduler is available in available_schedulers, move it to the first position 91 | if "EulerAncestralDiscreteScheduler" in available_schedulers: 92 | available_schedulers.insert( 93 | 0, available_schedulers.pop(available_schedulers.index("EulerAncestralDiscreteScheduler")) 94 | ) 95 | 96 | input_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"]) 97 | if input_image is not None: 98 | input_image = Image.open(input_image) 99 | input_image = input_image.convert("RGB").resize((128, 128), resample=Image.LANCZOS) 100 | st.image(input_image, use_column_width=True) 101 | 102 | col1, col2 = st.columns(2) 103 | with col1: 104 | prompt = st.text_area("Prompt (Optional)", "") 105 | with col2: 106 | negative_prompt = st.text_area("Negative Prompt (Optional)", "") 107 | 108 | scheduler = st.sidebar.selectbox("Scheduler", available_schedulers, index=0) 109 | guidance_scale = st.sidebar.slider("Guidance scale", 1.0, 40.0, 7.5, 0.5) 110 | noise_level = st.sidebar.slider("Noise level", 0, 100, 20, 1) 111 | eta = st.sidebar.slider("Eta", 0.0, 1.0, 0.0, 0.1) 112 | num_images = st.sidebar.slider("Number of images per prompt", 1, 30, 1, 1) 113 | steps = st.sidebar.slider("Steps", 1, 150, 50, 1) 114 | seed_placeholder = st.sidebar.empty() 115 | seed = seed_placeholder.number_input("Seed", value=42, min_value=1, max_value=999999, step=1) 116 | random_seed = st.sidebar.button("Random seed") 117 | _seed = torch.randint(1, 999999, (1,)).item() 118 | if random_seed: 119 | seed = seed_placeholder.number_input("Seed", value=_seed, min_value=1, max_value=999999, step=1) 120 | sub_col, download_col = st.columns(2) 121 | with sub_col: 122 | submit = st.button("Generate") 123 | 124 | if submit: 125 | with st.spinner("Generating images..."): 126 | output_images, metadata = self.generate_image( 127 | prompt=prompt, 128 | image=input_image, 129 | negative_prompt=negative_prompt, 130 | scheduler=scheduler, 131 | num_images=num_images, 132 | guidance_scale=guidance_scale, 133 | steps=steps, 134 | seed=seed, 135 | noise_level=noise_level, 136 | eta=eta, 137 | ) 138 | 139 | utils.display_and_download_images(output_images, metadata, download_col) 140 | -------------------------------------------------------------------------------- /stablefusion/utils.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from PIL import Image 3 | import gc 4 | import io 5 | import os 6 | import tempfile 7 | import zipfile 8 | from datetime import datetime 9 | from threading import Thread 10 | import cv2 11 | import requests 12 | import streamlit 
as st 13 | import torch 14 | from huggingface_hub import HfApi 15 | from huggingface_hub.utils import RepositoryNotFoundError 16 | from huggingface_hub.utils import HFValidationError 17 | from loguru import logger 18 | from PIL.PngImagePlugin import PngInfo 19 | from st_clickable_images import clickable_images 20 | 21 | 22 | 23 | no_safety_checker = None 24 | 25 | 26 | CODE_OF_CONDUCT = """ 27 | ## Code of conduct 28 | The app should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes. 29 | 30 | Using the app to generate content that is cruel to individuals is a misuse of this app. One shall not use this app to generate content that is intended to be cruel to individuals, whether or not the cruelty is obvious to the viewer. 31 | This includes, but is not limited to: 32 | - Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc. 33 | - Intentionally promoting or propagating discriminatory content or harmful stereotypes. 34 | - Impersonating individuals without their consent. 35 | - Sexual content without consent of the people who might see it. 36 | - Mis- and disinformation. 37 | - Representations of egregious violence and gore. 38 | - Sharing of copyrighted or licensed material in violation of its terms of use. 39 | - Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use. 40 | 41 | By using this app, you agree to the above code of conduct. 42 | 43 | """ 44 | 45 | def use_auth_token(): 46 | token_path = os.path.join(os.path.expanduser("~"), ".huggingface", "token") 47 | if os.path.exists(token_path): 48 | return True 49 | return False 50 | 51 | 52 | def create_base_page(): 53 | st.set_page_config(layout="wide", 54 | menu_items={ 55 | "Get Help": "https://github.com/NeuralRealm/StableFusion/discussions", 56 | "Report a bug": "https://github.com/NeuralRealm/StableFusion/issues", 57 | "About": "# StableFusion\nWelcome to StableFusion! Our AI-powered web app offers a user-friendly interface for transforming text into images and generating new images with customizable styles and formats. Built with Diffusers, Python, and Streamlit, our app makes it easy to create stunning visuals with just a few clicks.\n### Getting Started\nTo get started with StableFusion, simply visit our website and follow the on-screen instructions. You can input your text or upload an image, select your preferred style and output format, and let our AI do the rest.\n### About NeuralRealm\nStableFusion was developed by NeuralRealm, an organization dedicated to advancing the field of artificial intelligence. We are grateful for their contributions and proud to offer this user-friendly web app to our users.\n### Contributions\nWe welcome contributions from the community to help improve StableFusion. If you have any feedback, suggestions, or would like to contribute code, please visit our [GitHub repository](https://github.com/NeuralRealm/StableFusion).\n### License\nStableFusion is licensed under the [GPL-3.0 license](https://github.com/NeuralRealm/StableFusion/blob/master/LICENSE). See the LICENSE file for more information." 
58 | } 59 | ) 60 | st.title("StableFusion") 61 | st.markdown("Welcome to **StableFusion**! A web app for **Stable Diffusion Models**") 62 | 63 | def download_file(file_url): 64 | r = requests.get(file_url, stream=True) 65 | with tempfile.NamedTemporaryFile(delete=False) as tmp: 66 | for chunk in r.iter_content(chunk_size=1024): 67 | if chunk: # filter out keep-alive new chunks 68 | tmp.write(chunk) 69 | return tmp.name 70 | 71 | 72 | def base_path(): 73 | 74 | return os.path.dirname(__file__) 75 | 76 | def cache_folder(): 77 | _cache_folder = os.path.join(os.path.expanduser("~"), ".stablefusion") 78 | os.makedirs(_cache_folder, exist_ok=True) 79 | return _cache_folder 80 | 81 | 82 | def clear_memory(preserve): 83 | torch.cuda.empty_cache() 84 | gc.collect() 85 | to_clear = ["inpainting", "text2img", "img2text"] 86 | for key in to_clear: 87 | if key not in preserve and key in st.session_state: 88 | del st.session_state[key] 89 | 90 | 91 | def save_to_hub(api, images, module, current_datetime, metadata, output_path): 92 | logger.info(f"Saving images to hub: {output_path}") 93 | _metadata = PngInfo() 94 | _metadata.add_text(module, metadata)  # key the PNG text chunk by the module that produced it 95 | for i, img in enumerate(images): 96 | img_byte_arr = io.BytesIO() 97 | img.save(img_byte_arr, format="PNG", pnginfo=_metadata) 98 | img_byte_arr = img_byte_arr.getvalue() 99 | api.upload_file( 100 | path_or_fileobj=img_byte_arr, 101 | path_in_repo=f"{module}/{current_datetime}/{i}.png", 102 | repo_id=output_path, 103 | repo_type="dataset", 104 | ) 105 | 106 | api.upload_file( 107 | path_or_fileobj=str.encode(metadata), 108 | path_in_repo=f"{module}/{current_datetime}/metadata.json", 109 | repo_id=output_path, 110 | repo_type="dataset", 111 | ) 112 | logger.info(f"Saved images to hub: {output_path}") 113 | 114 | 115 | def save_to_local(images, module, current_datetime, metadata, output_path): 116 | _metadata = PngInfo() 117 | _metadata.add_text(module, metadata)  # same module-keyed metadata as the hub path 118 | os.makedirs(output_path, exist_ok=True) 119 | os.makedirs(f"{output_path}/{module}", exist_ok=True) 120 | os.makedirs(f"{output_path}/{module}/{current_datetime}", exist_ok=True) 121 | 122 | for i, img in enumerate(images): 123 | img.save( 124 | f"{output_path}/{module}/{current_datetime}/{i}.png", 125 | pnginfo=_metadata, 126 | ) 127 | 128 | # save metadata as text file 129 | with open(f"{output_path}/{module}/{current_datetime}/metadata.txt", "w") as f: 130 | f.write(metadata) 131 | logger.info(f"Saved images to {output_path}/{module}/{current_datetime}") 132 | 133 | 134 | def save_images(images, module, metadata, output_path): 135 | if output_path is None: 136 | logger.warning("No output path specified, skipping saving images") 137 | return 138 | 139 | api = HfApi() 140 | dset_info = None 141 | try: 142 | dset_info = api.dataset_info(output_path) 143 | except (HFValidationError, RepositoryNotFoundError):
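# save_images probes output_path as a Hugging Face dataset repo; when the
# probe fails it falls back to writing the PNGs plus a metadata.txt under
# the local output_path instead.
144 | logger.warning("No valid hugging face repo. 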
Saving locally...") 145 | 146 | current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 147 | 148 | if not dset_info: 149 | save_to_local(images, module, current_datetime, metadata, output_path) 150 | else: 151 | Thread(target=save_to_hub, args=(api, images, module, current_datetime, metadata, output_path)).start() 152 | 153 | 154 | def display_and_download_images(output_images, metadata, download_col=None): 155 | # st.image(output_images, width=128, output_format="PNG") 156 | 157 | with st.spinner("Preparing images for download..."): 158 | # save images to a temporary directory 159 | with tempfile.TemporaryDirectory() as tmpdir: 160 | gallery_images = [] 161 | for i, image in enumerate(output_images): 162 | 163 | image.save(os.path.join(tmpdir, f"{i + 1}.png"), pnginfo=metadata) 164 | with open(os.path.join(tmpdir, f"{i + 1}.png"), "rb") as img: 165 | encoded = base64.b64encode(img.read()).decode() 166 | gallery_images.append(f"data:image/jpeg;base64,{encoded}") 167 | 168 | # zip the images 169 | zip_path = os.path.join(tmpdir, "images.zip") 170 | with zipfile.ZipFile(zip_path, "w") as zip: 171 | for filename in os.listdir(tmpdir): 172 | if filename.endswith(".png"): 173 | zip.write(os.path.join(tmpdir, filename), filename) 174 | 175 | # convert zip to base64 176 | with open(zip_path, "rb") as f: 177 | encoded = base64.b64encode(f.read()).decode() 178 | 179 | _ = clickable_images( 180 | gallery_images, 181 | titles=[f"Image #{str(i)}" for i in range(len(gallery_images))], 182 | div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"}, 183 | img_style={"margin": "5px", "height": "512px", "width": "512px"}, 184 | ) 185 | 186 | now = datetime.now() 187 | formatted_date_time = now.strftime("%Y-%m-%d_%H_%M_%S") 188 | 189 | # add download link 190 | st.markdown( 191 | f""" 192 | 193 |

193 | <a href="data:application/zip;base64,{encoded}" download="images_{formatted_date_time}.zip">Download Images</a>
194 |
195 | """, 196 | unsafe_allow_html=True, 197 | ) 198 | 199 | -------------------------------------------------------------------------------- /static/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralRealm/StableFusion/9268d686f49b45194a2251c6728e74191f298654/static/.keep -------------------------------------------------------------------------------- /static/Screenshot1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralRealm/StableFusion/9268d686f49b45194a2251c6728e74191f298654/static/Screenshot1.png -------------------------------------------------------------------------------- /static/Screenshot2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralRealm/StableFusion/9268d686f49b45194a2251c6728e74191f298654/static/Screenshot2.png -------------------------------------------------------------------------------- /static/Screenshot3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralRealm/StableFusion/9268d686f49b45194a2251c6728e74191f298654/static/Screenshot3.png -------------------------------------------------------------------------------- /static/screenshot4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralRealm/StableFusion/9268d686f49b45194a2251c6728e74191f298654/static/screenshot4.png --------------------------------------------------------------------------------