├── .gitattributes ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── .style.yapf ├── LICENSE ├── LICENSE.ControlNet ├── README.md ├── app.py ├── gradio_canny2image.py ├── gradio_depth2image.py ├── gradio_fake_scribble2image.py ├── gradio_hed2image.py ├── gradio_hough2image.py ├── gradio_normal2image.py ├── gradio_pose2image.py ├── gradio_scribble2image.py ├── gradio_scribble2image_interactive.py ├── gradio_seg2image.py ├── model.py ├── patch ├── requirements.txt └── style.css /.gitattributes: -------------------------------------------------------------------------------- 1 | *.7z filter=lfs diff=lfs merge=lfs -text 2 | *.arrow filter=lfs diff=lfs merge=lfs -text 3 | *.bin filter=lfs diff=lfs merge=lfs -text 4 | *.bz2 filter=lfs diff=lfs merge=lfs -text 5 | *.ckpt filter=lfs diff=lfs merge=lfs -text 6 | *.ftz filter=lfs diff=lfs merge=lfs -text 7 | *.gz filter=lfs diff=lfs merge=lfs -text 8 | *.h5 filter=lfs diff=lfs merge=lfs -text 9 | *.joblib filter=lfs diff=lfs merge=lfs -text 10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text 11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text 12 | *.model filter=lfs diff=lfs merge=lfs -text 13 | *.msgpack filter=lfs diff=lfs merge=lfs -text 14 | *.npy filter=lfs diff=lfs merge=lfs -text 15 | *.npz filter=lfs diff=lfs merge=lfs -text 16 | *.onnx filter=lfs diff=lfs merge=lfs -text 17 | *.ot filter=lfs diff=lfs merge=lfs -text 18 | *.parquet filter=lfs diff=lfs merge=lfs -text 19 | *.pb filter=lfs diff=lfs merge=lfs -text 20 | *.pickle filter=lfs diff=lfs merge=lfs -text 21 | *.pkl filter=lfs diff=lfs merge=lfs -text 22 | *.pt filter=lfs diff=lfs merge=lfs -text 23 | *.pth filter=lfs diff=lfs merge=lfs -text 24 | *.rar filter=lfs diff=lfs merge=lfs -text 25 | *.safetensors filter=lfs diff=lfs merge=lfs -text 26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text 27 | *.tar.* filter=lfs diff=lfs merge=lfs -text 28 | *.tflite filter=lfs diff=lfs merge=lfs -text 29 | *.tgz filter=lfs diff=lfs merge=lfs -text 30 | *.wasm filter=lfs diff=lfs merge=lfs -text 31 | *.xz filter=lfs diff=lfs merge=lfs -text 32 | *.zip filter=lfs diff=lfs merge=lfs -text 33 | *.zst filter=lfs diff=lfs merge=lfs -text 34 | *tfevents* filter=lfs diff=lfs merge=lfs -text 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | models/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "ControlNet"] 2 | path = ControlNet 3 | url = https://github.com/lllyasviel/ControlNet 4 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: patch 2 | repos: 3 | - repo: https://github.com/pre-commit/pre-commit-hooks 4 | rev: v4.2.0 5 | hooks: 6 | - id: check-executables-have-shebangs 7 | - id: check-json 8 | - id: check-merge-conflict 9 | - id: check-shebang-scripts-are-executable 10 | - id: check-toml 11 | - id: check-yaml 12 | - id: double-quote-string-fixer 13 | - id: end-of-file-fixer 14 | - id: mixed-line-ending 15 | args: ['--fix=lf'] 16 | - id: requirements-txt-fixer 17 | - id: trailing-whitespace 18 | - repo: https://github.com/myint/docformatter 19 | rev: v1.4 20 | hooks: 21 | - id: docformatter 22 | args: ['--in-place'] 23 | - repo: https://github.com/pycqa/isort 24 | rev: 5.12.0 25 | hooks: 26 | - id: isort 27 | - repo: https://github.com/pre-commit/mirrors-mypy 28 | rev: v0.991 29 | hooks: 30 | - id: mypy 31 | args: ['--ignore-missing-imports'] 32 | additional_dependencies: ['types-python-slugify'] 33 | - repo: https://github.com/google/yapf 34 | rev: v0.32.0 35 | hooks: 36 | - id: yapf 37 | args: ['--parallel', '--in-place'] 38 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = pep8 3 | blank_line_before_nested_class_or_def = false 4 | spaces_before_comment = 2 5 | split_before_logical_operator = true 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 hysts 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LICENSE.ControlNet: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: ControlNet with other models 3 | emoji: 😻 4 | colorFrom: pink 5 | colorTo: blue 6 | sdk: gradio 7 | sdk_version: 3.18.0 8 | python_version: 3.10.9 9 | app_file: app.py 10 | pinned: false 11 | license: mit 12 | duplicated_from: hysts/ControlNet 13 | --- 14 | 15 | Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference 16 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from __future__ import annotations 4 | 5 | import os 6 | import pathlib 7 | import shlex 8 | import subprocess 9 | 10 | import gradio as gr 11 | 12 | if os.getenv('SYSTEM') == 'spaces': 13 | with open('patch') as f: 14 | subprocess.run(shlex.split('patch -p1'), stdin=f, cwd='ControlNet') 15 | 16 | base_url = 'https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/' 17 | names = [ 18 | 'body_pose_model.pth', 19 | 'dpt_hybrid-midas-501f0c75.pt', 20 | 'hand_pose_model.pth', 21 | 'mlsd_large_512_fp32.pth', 22 | 'mlsd_tiny_512_fp32.pth', 23 | 'network-bsds500.pth', 24 | 'upernet_global_small.pth', 25 | ] 26 | for name in names: 27 | command = f'wget https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/{name} -O {name}' 28 | out_path = pathlib.Path(f'ControlNet/annotator/ckpts/{name}') 29 | if out_path.exists(): 30 | continue 31 | subprocess.run(shlex.split(command), cwd='ControlNet/annotator/ckpts/') 32 | 33 | from gradio_canny2image import create_demo as create_demo_canny 34 | from gradio_depth2image import create_demo as create_demo_depth 35 | from gradio_fake_scribble2image import create_demo as create_demo_fake_scribble 36 | from gradio_hed2image import create_demo as create_demo_hed 37 | from gradio_hough2image import create_demo as create_demo_hough 38 | from gradio_normal2image import create_demo as create_demo_normal 39 | from gradio_pose2image import create_demo as create_demo_pose 40 | from gradio_scribble2image import create_demo as create_demo_scribble 41 | from gradio_scribble2image_interactive import \ 42 | create_demo as 
create_demo_scribble_interactive 43 | from gradio_seg2image import create_demo as create_demo_seg 44 | from model import (DEFAULT_BASE_MODEL_FILENAME, DEFAULT_BASE_MODEL_REPO, 45 | DEFAULT_BASE_MODEL_URL, Model) 46 | 47 | MAX_IMAGES = 4 48 | ALLOW_CHANGING_BASE_MODEL = 'hysts/ControlNet-with-other-models' 49 | 50 | model = Model() 51 | 52 | with gr.Blocks(css='style.css') as demo: 53 | with gr.Tabs(): 54 | with gr.TabItem('Canny'): 55 | create_demo_canny(model.process_canny, max_images=MAX_IMAGES) 56 | with gr.TabItem('Hough'): 57 | create_demo_hough(model.process_hough, max_images=MAX_IMAGES) 58 | with gr.TabItem('HED'): 59 | create_demo_hed(model.process_hed, max_images=MAX_IMAGES) 60 | with gr.TabItem('Scribble'): 61 | create_demo_scribble(model.process_scribble, max_images=MAX_IMAGES) 62 | with gr.TabItem('Scribble Interactive'): 63 | create_demo_scribble_interactive( 64 | model.process_scribble_interactive, max_images=MAX_IMAGES) 65 | with gr.TabItem('Fake Scribble'): 66 | create_demo_fake_scribble(model.process_fake_scribble, 67 | max_images=MAX_IMAGES) 68 | with gr.TabItem('Pose'): 69 | create_demo_pose(model.process_pose, max_images=MAX_IMAGES) 70 | with gr.TabItem('Segmentation'): 71 | create_demo_seg(model.process_seg, max_images=MAX_IMAGES) 72 | with gr.TabItem('Depth'): 73 | create_demo_depth(model.process_depth, max_images=MAX_IMAGES) 74 | with gr.TabItem('Normal map'): 75 | create_demo_normal(model.process_normal, max_images=MAX_IMAGES) 76 | 77 | with gr.Accordion(label='Base model', open=False): 78 | current_base_model = gr.Text(label='Current base model', 79 | value=DEFAULT_BASE_MODEL_URL) 80 | with gr.Row(): 81 | base_model_repo = gr.Text(label='Base model repo', 82 | max_lines=1, 83 | placeholder=DEFAULT_BASE_MODEL_REPO, 84 | interactive=ALLOW_CHANGING_BASE_MODEL) 85 | base_model_filename = gr.Text( 86 | label='Base model file', 87 | max_lines=1, 88 | placeholder=DEFAULT_BASE_MODEL_FILENAME, 89 | interactive=ALLOW_CHANGING_BASE_MODEL) 90 | change_base_model_button = gr.Button('Change base model') 91 | gr.Markdown( 92 | '''- You can use other base models by specifying the repository name and filename. 93 | The base model must be compatible with Stable Diffusion v1.5.''') 94 | 95 | change_base_model_button.click(fn=model.set_base_model, 96 | inputs=[ 97 | base_model_repo, 98 | base_model_filename, 99 | ], 100 | outputs=current_base_model) 101 | 102 | demo.queue(api_open=False).launch() 103 | -------------------------------------------------------------------------------- /gradio_canny2image.py: -------------------------------------------------------------------------------- 1 | # This file is adapted from https://github.com/lllyasviel/ControlNet/blob/f4748e3630d8141d7765e2bd9b1e348f47847707/gradio_canny2image.py 2 | # The original license file is LICENSE.ControlNet in this repo. 
3 | import gradio as gr 4 | 5 | 6 | def create_demo(process, max_images=12): 7 | with gr.Blocks() as demo: 8 | with gr.Row(): 9 | gr.Markdown('## Control Stable Diffusion with Canny Edge Maps') 10 | with gr.Row(): 11 | with gr.Column(): 12 | input_image = gr.Image(source='upload', type='numpy') 13 | prompt = gr.Textbox(label='Prompt') 14 | run_button = gr.Button(label='Run') 15 | with gr.Accordion('Advanced options', open=False): 16 | num_samples = gr.Slider(label='Images', 17 | minimum=1, 18 | maximum=max_images, 19 | value=1, 20 | step=1) 21 | image_resolution = gr.Slider(label='Image Resolution', 22 | minimum=256, 23 | maximum=768, 24 | value=512, 25 | step=256) 26 | low_threshold = gr.Slider(label='Canny low threshold', 27 | minimum=1, 28 | maximum=255, 29 | value=100, 30 | step=1) 31 | high_threshold = gr.Slider(label='Canny high threshold', 32 | minimum=1, 33 | maximum=255, 34 | value=200, 35 | step=1) 36 | ddim_steps = gr.Slider(label='Steps', 37 | minimum=1, 38 | maximum=100, 39 | value=20, 40 | step=1) 41 | scale = gr.Slider(label='Guidance Scale', 42 | minimum=0.1, 43 | maximum=30.0, 44 | value=9.0, 45 | step=0.1) 46 | seed = gr.Slider(label='Seed', 47 | minimum=-1, 48 | maximum=2147483647, 49 | step=1, 50 | randomize=True) 51 | eta = gr.Number(label='eta (DDIM)', value=0.0) 52 | a_prompt = gr.Textbox( 53 | label='Added Prompt', 54 | value='best quality, extremely detailed') 55 | n_prompt = gr.Textbox( 56 | label='Negative Prompt', 57 | value= 58 | 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality' 59 | ) 60 | with gr.Column(): 61 | result_gallery = gr.Gallery(label='Output', 62 | show_label=False, 63 | elem_id='gallery').style( 64 | grid=2, height='auto') 65 | ips = [ 66 | input_image, prompt, a_prompt, n_prompt, num_samples, 67 | image_resolution, ddim_steps, scale, seed, eta, low_threshold, 68 | high_threshold 69 | ] 70 | run_button.click(fn=process, 71 | inputs=ips, 72 | outputs=[result_gallery], 73 | api_name='canny') 74 | return demo 75 | -------------------------------------------------------------------------------- /gradio_depth2image.py: -------------------------------------------------------------------------------- 1 | # This file is adapted from https://github.com/lllyasviel/ControlNet/blob/f4748e3630d8141d7765e2bd9b1e348f47847707/gradio_depth2image.py 2 | # The original license file is LICENSE.ControlNet in this repo. 
3 | import gradio as gr 4 | 5 | 6 | def create_demo(process, max_images=12): 7 | with gr.Blocks() as demo: 8 | with gr.Row(): 9 | gr.Markdown('## Control Stable Diffusion with Depth Maps') 10 | with gr.Row(): 11 | with gr.Column(): 12 | input_image = gr.Image(source='upload', type='numpy') 13 | prompt = gr.Textbox(label='Prompt') 14 | run_button = gr.Button(label='Run') 15 | with gr.Accordion('Advanced options', open=False): 16 | num_samples = gr.Slider(label='Images', 17 | minimum=1, 18 | maximum=max_images, 19 | value=1, 20 | step=1) 21 | image_resolution = gr.Slider(label='Image Resolution', 22 | minimum=256, 23 | maximum=768, 24 | value=512, 25 | step=256) 26 | detect_resolution = gr.Slider(label='Depth Resolution', 27 | minimum=128, 28 | maximum=1024, 29 | value=384, 30 | step=1) 31 | ddim_steps = gr.Slider(label='Steps', 32 | minimum=1, 33 | maximum=100, 34 | value=20, 35 | step=1) 36 | scale = gr.Slider(label='Guidance Scale', 37 | minimum=0.1, 38 | maximum=30.0, 39 | value=9.0, 40 | step=0.1) 41 | seed = gr.Slider(label='Seed', 42 | minimum=-1, 43 | maximum=2147483647, 44 | step=1, 45 | randomize=True) 46 | eta = gr.Number(label='eta (DDIM)', value=0.0) 47 | a_prompt = gr.Textbox( 48 | label='Added Prompt', 49 | value='best quality, extremely detailed') 50 | n_prompt = gr.Textbox( 51 | label='Negative Prompt', 52 | value= 53 | 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality' 54 | ) 55 | with gr.Column(): 56 | result_gallery = gr.Gallery(label='Output', 57 | show_label=False, 58 | elem_id='gallery').style( 59 | grid=2, height='auto') 60 | ips = [ 61 | input_image, prompt, a_prompt, n_prompt, num_samples, 62 | image_resolution, detect_resolution, ddim_steps, scale, seed, eta 63 | ] 64 | run_button.click(fn=process, 65 | inputs=ips, 66 | outputs=[result_gallery], 67 | api_name='depth') 68 | return demo 69 | -------------------------------------------------------------------------------- /gradio_fake_scribble2image.py: -------------------------------------------------------------------------------- 1 | # This file is adapted from https://github.com/lllyasviel/ControlNet/blob/f4748e3630d8141d7765e2bd9b1e348f47847707/gradio_fake_scribble2image.py 2 | # The original license file is LICENSE.ControlNet in this repo. 
3 | import gradio as gr 4 | 5 | 6 | def create_demo(process, max_images=12): 7 | with gr.Blocks() as demo: 8 | with gr.Row(): 9 | gr.Markdown('## Control Stable Diffusion with Fake Scribble Maps') 10 | with gr.Row(): 11 | with gr.Column(): 12 | input_image = gr.Image(source='upload', type='numpy') 13 | prompt = gr.Textbox(label='Prompt') 14 | run_button = gr.Button(label='Run') 15 | with gr.Accordion('Advanced options', open=False): 16 | num_samples = gr.Slider(label='Images', 17 | minimum=1, 18 | maximum=max_images, 19 | value=1, 20 | step=1) 21 | image_resolution = gr.Slider(label='Image Resolution', 22 | minimum=256, 23 | maximum=768, 24 | value=512, 25 | step=256) 26 | detect_resolution = gr.Slider(label='HED Resolution', 27 | minimum=128, 28 | maximum=1024, 29 | value=512, 30 | step=1) 31 | ddim_steps = gr.Slider(label='Steps', 32 | minimum=1, 33 | maximum=100, 34 | value=20, 35 | step=1) 36 | scale = gr.Slider(label='Guidance Scale', 37 | minimum=0.1, 38 | maximum=30.0, 39 | value=9.0, 40 | step=0.1) 41 | seed = gr.Slider(label='Seed', 42 | minimum=-1, 43 | maximum=2147483647, 44 | step=1, 45 | randomize=True) 46 | eta = gr.Number(label='eta (DDIM)', value=0.0) 47 | a_prompt = gr.Textbox( 48 | label='Added Prompt', 49 | value='best quality, extremely detailed') 50 | n_prompt = gr.Textbox( 51 | label='Negative Prompt', 52 | value= 53 | 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality' 54 | ) 55 | with gr.Column(): 56 | result_gallery = gr.Gallery(label='Output', 57 | show_label=False, 58 | elem_id='gallery').style( 59 | grid=2, height='auto') 60 | ips = [ 61 | input_image, prompt, a_prompt, n_prompt, num_samples, 62 | image_resolution, detect_resolution, ddim_steps, scale, seed, eta 63 | ] 64 | run_button.click(fn=process, 65 | inputs=ips, 66 | outputs=[result_gallery], 67 | api_name='fake_scribble') 68 | return demo 69 | -------------------------------------------------------------------------------- /gradio_hed2image.py: -------------------------------------------------------------------------------- 1 | # This file is adapted from https://github.com/lllyasviel/ControlNet/blob/f4748e3630d8141d7765e2bd9b1e348f47847707/gradio_hed2image.py 2 | # The original license file is LICENSE.ControlNet in this repo. 
3 | import gradio as gr 4 | 5 | 6 | def create_demo(process, max_images=12): 7 | with gr.Blocks() as demo: 8 | with gr.Row(): 9 | gr.Markdown('## Control Stable Diffusion with HED Maps') 10 | with gr.Row(): 11 | with gr.Column(): 12 | input_image = gr.Image(source='upload', type='numpy') 13 | prompt = gr.Textbox(label='Prompt') 14 | run_button = gr.Button(label='Run') 15 | with gr.Accordion('Advanced options', open=False): 16 | num_samples = gr.Slider(label='Images', 17 | minimum=1, 18 | maximum=max_images, 19 | value=1, 20 | step=1) 21 | image_resolution = gr.Slider(label='Image Resolution', 22 | minimum=256, 23 | maximum=768, 24 | value=512, 25 | step=256) 26 | detect_resolution = gr.Slider(label='HED Resolution', 27 | minimum=128, 28 | maximum=1024, 29 | value=512, 30 | step=1) 31 | ddim_steps = gr.Slider(label='Steps', 32 | minimum=1, 33 | maximum=100, 34 | value=20, 35 | step=1) 36 | scale = gr.Slider(label='Guidance Scale', 37 | minimum=0.1, 38 | maximum=30.0, 39 | value=9.0, 40 | step=0.1) 41 | seed = gr.Slider(label='Seed', 42 | minimum=-1, 43 | maximum=2147483647, 44 | step=1, 45 | randomize=True) 46 | eta = gr.Number(label='eta (DDIM)', value=0.0) 47 | a_prompt = gr.Textbox( 48 | label='Added Prompt', 49 | value='best quality, extremely detailed') 50 | n_prompt = gr.Textbox( 51 | label='Negative Prompt', 52 | value= 53 | 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality' 54 | ) 55 | with gr.Column(): 56 | result_gallery = gr.Gallery(label='Output', 57 | show_label=False, 58 | elem_id='gallery').style( 59 | grid=2, height='auto') 60 | ips = [ 61 | input_image, prompt, a_prompt, n_prompt, num_samples, 62 | image_resolution, detect_resolution, ddim_steps, scale, seed, eta 63 | ] 64 | run_button.click(fn=process, 65 | inputs=ips, 66 | outputs=[result_gallery], 67 | api_name='hed') 68 | return demo 69 | -------------------------------------------------------------------------------- /gradio_hough2image.py: -------------------------------------------------------------------------------- 1 | # This file is adapted from https://github.com/lllyasviel/ControlNet/blob/f4748e3630d8141d7765e2bd9b1e348f47847707/gradio_hough2image.py 2 | # The original license file is LICENSE.ControlNet in this repo. 
3 | import gradio as gr 4 | 5 | 6 | def create_demo(process, max_images=12): 7 | with gr.Blocks() as demo: 8 | with gr.Row(): 9 | gr.Markdown('## Control Stable Diffusion with Hough Line Maps') 10 | with gr.Row(): 11 | with gr.Column(): 12 | input_image = gr.Image(source='upload', type='numpy') 13 | prompt = gr.Textbox(label='Prompt') 14 | run_button = gr.Button(label='Run') 15 | with gr.Accordion('Advanced options', open=False): 16 | num_samples = gr.Slider(label='Images', 17 | minimum=1, 18 | maximum=max_images, 19 | value=1, 20 | step=1) 21 | image_resolution = gr.Slider(label='Image Resolution', 22 | minimum=256, 23 | maximum=768, 24 | value=512, 25 | step=256) 26 | detect_resolution = gr.Slider(label='Hough Resolution', 27 | minimum=128, 28 | maximum=1024, 29 | value=512, 30 | step=1) 31 | value_threshold = gr.Slider( 32 | label='Hough value threshold (MLSD)', 33 | minimum=0.01, 34 | maximum=2.0, 35 | value=0.1, 36 | step=0.01) 37 | distance_threshold = gr.Slider( 38 | label='Hough distance threshold (MLSD)', 39 | minimum=0.01, 40 | maximum=20.0, 41 | value=0.1, 42 | step=0.01) 43 | ddim_steps = gr.Slider(label='Steps', 44 | minimum=1, 45 | maximum=100, 46 | value=20, 47 | step=1) 48 | scale = gr.Slider(label='Guidance Scale', 49 | minimum=0.1, 50 | maximum=30.0, 51 | value=9.0, 52 | step=0.1) 53 | seed = gr.Slider(label='Seed', 54 | minimum=-1, 55 | maximum=2147483647, 56 | step=1, 57 | randomize=True) 58 | eta = gr.Number(label='eta (DDIM)', value=0.0) 59 | a_prompt = gr.Textbox( 60 | label='Added Prompt', 61 | value='best quality, extremely detailed') 62 | n_prompt = gr.Textbox( 63 | label='Negative Prompt', 64 | value= 65 | 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality' 66 | ) 67 | with gr.Column(): 68 | result_gallery = gr.Gallery(label='Output', 69 | show_label=False, 70 | elem_id='gallery').style( 71 | grid=2, height='auto') 72 | ips = [ 73 | input_image, prompt, a_prompt, n_prompt, num_samples, 74 | image_resolution, detect_resolution, ddim_steps, scale, seed, eta, 75 | value_threshold, distance_threshold 76 | ] 77 | run_button.click(fn=process, 78 | inputs=ips, 79 | outputs=[result_gallery], 80 | api_name='hough') 81 | return demo 82 | -------------------------------------------------------------------------------- /gradio_normal2image.py: -------------------------------------------------------------------------------- 1 | # This file is adapted from https://github.com/lllyasviel/ControlNet/blob/f4748e3630d8141d7765e2bd9b1e348f47847707/gradio_normal2image.py 2 | # The original license file is LICENSE.ControlNet in this repo. 
3 | import gradio as gr 4 | 5 | 6 | def create_demo(process, max_images=12): 7 | with gr.Blocks() as demo: 8 | with gr.Row(): 9 | gr.Markdown('## Control Stable Diffusion with Normal Maps') 10 | with gr.Row(): 11 | with gr.Column(): 12 | input_image = gr.Image(source='upload', type='numpy') 13 | prompt = gr.Textbox(label='Prompt') 14 | run_button = gr.Button(label='Run') 15 | with gr.Accordion('Advanced options', open=False): 16 | num_samples = gr.Slider(label='Images', 17 | minimum=1, 18 | maximum=max_images, 19 | value=1, 20 | step=1) 21 | image_resolution = gr.Slider(label='Image Resolution', 22 | minimum=256, 23 | maximum=768, 24 | value=512, 25 | step=256) 26 | detect_resolution = gr.Slider(label='Normal Resolution', 27 | minimum=128, 28 | maximum=1024, 29 | value=384, 30 | step=1) 31 | bg_threshold = gr.Slider( 32 | label='Normal background threshold', 33 | minimum=0.0, 34 | maximum=1.0, 35 | value=0.4, 36 | step=0.01) 37 | ddim_steps = gr.Slider(label='Steps', 38 | minimum=1, 39 | maximum=100, 40 | value=20, 41 | step=1) 42 | scale = gr.Slider(label='Guidance Scale', 43 | minimum=0.1, 44 | maximum=30.0, 45 | value=9.0, 46 | step=0.1) 47 | seed = gr.Slider(label='Seed', 48 | minimum=-1, 49 | maximum=2147483647, 50 | step=1, 51 | randomize=True) 52 | eta = gr.Number(label='eta (DDIM)', value=0.0) 53 | a_prompt = gr.Textbox( 54 | label='Added Prompt', 55 | value='best quality, extremely detailed') 56 | n_prompt = gr.Textbox( 57 | label='Negative Prompt', 58 | value= 59 | 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality' 60 | ) 61 | with gr.Column(): 62 | result_gallery = gr.Gallery(label='Output', 63 | show_label=False, 64 | elem_id='gallery').style( 65 | grid=2, height='auto') 66 | ips = [ 67 | input_image, prompt, a_prompt, n_prompt, num_samples, 68 | image_resolution, detect_resolution, ddim_steps, scale, seed, eta, 69 | bg_threshold 70 | ] 71 | run_button.click(fn=process, 72 | inputs=ips, 73 | outputs=[result_gallery], 74 | api_name='normal') 75 | return demo 76 | -------------------------------------------------------------------------------- /gradio_pose2image.py: -------------------------------------------------------------------------------- 1 | # This file is adapted from https://github.com/lllyasviel/ControlNet/blob/f4748e3630d8141d7765e2bd9b1e348f47847707/gradio_pose2image.py 2 | # The original license file is LICENSE.ControlNet in this repo. 
3 | import gradio as gr 4 | 5 | 6 | def create_demo(process, max_images=12): 7 | with gr.Blocks() as demo: 8 | with gr.Row(): 9 | gr.Markdown('## Control Stable Diffusion with Human Pose') 10 | with gr.Row(): 11 | with gr.Column(): 12 | input_image = gr.Image(source='upload', type='numpy') 13 | prompt = gr.Textbox(label='Prompt') 14 | run_button = gr.Button(label='Run') 15 | with gr.Accordion('Advanced options', open=False): 16 | num_samples = gr.Slider(label='Images', 17 | minimum=1, 18 | maximum=max_images, 19 | value=1, 20 | step=1) 21 | image_resolution = gr.Slider(label='Image Resolution', 22 | minimum=256, 23 | maximum=768, 24 | value=512, 25 | step=256) 26 | detect_resolution = gr.Slider(label='OpenPose Resolution', 27 | minimum=128, 28 | maximum=1024, 29 | value=512, 30 | step=1) 31 | ddim_steps = gr.Slider(label='Steps', 32 | minimum=1, 33 | maximum=100, 34 | value=20, 35 | step=1) 36 | scale = gr.Slider(label='Guidance Scale', 37 | minimum=0.1, 38 | maximum=30.0, 39 | value=9.0, 40 | step=0.1) 41 | seed = gr.Slider(label='Seed', 42 | minimum=-1, 43 | maximum=2147483647, 44 | step=1, 45 | randomize=True) 46 | eta = gr.Number(label='eta (DDIM)', value=0.0) 47 | a_prompt = gr.Textbox( 48 | label='Added Prompt', 49 | value='best quality, extremely detailed') 50 | n_prompt = gr.Textbox( 51 | label='Negative Prompt', 52 | value= 53 | 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality' 54 | ) 55 | with gr.Column(): 56 | result_gallery = gr.Gallery(label='Output', 57 | show_label=False, 58 | elem_id='gallery').style( 59 | grid=2, height='auto') 60 | ips = [ 61 | input_image, prompt, a_prompt, n_prompt, num_samples, 62 | image_resolution, detect_resolution, ddim_steps, scale, seed, eta 63 | ] 64 | run_button.click(fn=process, 65 | inputs=ips, 66 | outputs=[result_gallery], 67 | api_name='pose') 68 | return demo 69 | -------------------------------------------------------------------------------- /gradio_scribble2image.py: -------------------------------------------------------------------------------- 1 | # This file is adapted from https://github.com/lllyasviel/ControlNet/blob/f4748e3630d8141d7765e2bd9b1e348f47847707/gradio_scribble2image.py 2 | # The original license file is LICENSE.ControlNet in this repo. 
3 | import gradio as gr 4 | 5 | 6 | def create_demo(process, max_images=12): 7 | with gr.Blocks() as demo: 8 | with gr.Row(): 9 | gr.Markdown('## Control Stable Diffusion with Scribble Maps') 10 | with gr.Row(): 11 | with gr.Column(): 12 | input_image = gr.Image(source='upload', type='numpy') 13 | prompt = gr.Textbox(label='Prompt') 14 | run_button = gr.Button(label='Run') 15 | with gr.Accordion('Advanced options', open=False): 16 | num_samples = gr.Slider(label='Images', 17 | minimum=1, 18 | maximum=max_images, 19 | value=1, 20 | step=1) 21 | image_resolution = gr.Slider(label='Image Resolution', 22 | minimum=256, 23 | maximum=768, 24 | value=512, 25 | step=256) 26 | ddim_steps = gr.Slider(label='Steps', 27 | minimum=1, 28 | maximum=100, 29 | value=20, 30 | step=1) 31 | scale = gr.Slider(label='Guidance Scale', 32 | minimum=0.1, 33 | maximum=30.0, 34 | value=9.0, 35 | step=0.1) 36 | seed = gr.Slider(label='Seed', 37 | minimum=-1, 38 | maximum=2147483647, 39 | step=1, 40 | randomize=True) 41 | eta = gr.Number(label='eta (DDIM)', value=0.0) 42 | a_prompt = gr.Textbox( 43 | label='Added Prompt', 44 | value='best quality, extremely detailed') 45 | n_prompt = gr.Textbox( 46 | label='Negative Prompt', 47 | value= 48 | 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality' 49 | ) 50 | with gr.Column(): 51 | result_gallery = gr.Gallery(label='Output', 52 | show_label=False, 53 | elem_id='gallery').style( 54 | grid=2, height='auto') 55 | ips = [ 56 | input_image, prompt, a_prompt, n_prompt, num_samples, 57 | image_resolution, ddim_steps, scale, seed, eta 58 | ] 59 | run_button.click(fn=process, 60 | inputs=ips, 61 | outputs=[result_gallery], 62 | api_name='scribble') 63 | return demo 64 | -------------------------------------------------------------------------------- /gradio_scribble2image_interactive.py: -------------------------------------------------------------------------------- 1 | # This file is adapted from https://github.com/lllyasviel/ControlNet/blob/f4748e3630d8141d7765e2bd9b1e348f47847707/gradio_scribble2image_interactive.py 2 | # The original license file is LICENSE.ControlNet in this repo. 3 | import gradio as gr 4 | import numpy as np 5 | 6 | 7 | def create_canvas(w, h): 8 | return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255 9 | 10 | 11 | def create_demo(process, max_images=12): 12 | with gr.Blocks() as demo: 13 | with gr.Row(): 14 | gr.Markdown( 15 | '## Control Stable Diffusion with Interactive Scribbles') 16 | with gr.Row(): 17 | with gr.Column(): 18 | canvas_width = gr.Slider(label='Canvas Width', 19 | minimum=256, 20 | maximum=1024, 21 | value=512, 22 | step=1) 23 | canvas_height = gr.Slider(label='Canvas Height', 24 | minimum=256, 25 | maximum=1024, 26 | value=512, 27 | step=1) 28 | create_button = gr.Button(label='Start', 29 | value='Open drawing canvas!') 30 | input_image = gr.Image(source='upload', 31 | type='numpy', 32 | tool='sketch') 33 | gr.Markdown( 34 | value= 35 | 'Do not forget to change your brush width to make it thinner. (Gradio do not allow developers to set brush width so you need to do it manually.) ' 36 | 'Just click on the small pencil icon in the upper right corner of the above block.' 
37 | ) 38 | create_button.click(fn=create_canvas, 39 | inputs=[canvas_width, canvas_height], 40 | outputs=[input_image], 41 | queue=False) 42 | prompt = gr.Textbox(label='Prompt') 43 | run_button = gr.Button(label='Run') 44 | with gr.Accordion('Advanced options', open=False): 45 | num_samples = gr.Slider(label='Images', 46 | minimum=1, 47 | maximum=max_images, 48 | value=1, 49 | step=1) 50 | image_resolution = gr.Slider(label='Image Resolution', 51 | minimum=256, 52 | maximum=768, 53 | value=512, 54 | step=256) 55 | ddim_steps = gr.Slider(label='Steps', 56 | minimum=1, 57 | maximum=100, 58 | value=20, 59 | step=1) 60 | scale = gr.Slider(label='Guidance Scale', 61 | minimum=0.1, 62 | maximum=30.0, 63 | value=9.0, 64 | step=0.1) 65 | seed = gr.Slider(label='Seed', 66 | minimum=-1, 67 | maximum=2147483647, 68 | step=1, 69 | randomize=True) 70 | eta = gr.Number(label='eta (DDIM)', value=0.0) 71 | a_prompt = gr.Textbox( 72 | label='Added Prompt', 73 | value='best quality, extremely detailed') 74 | n_prompt = gr.Textbox( 75 | label='Negative Prompt', 76 | value= 77 | 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality' 78 | ) 79 | with gr.Column(): 80 | result_gallery = gr.Gallery(label='Output', 81 | show_label=False, 82 | elem_id='gallery').style( 83 | grid=2, height='auto') 84 | ips = [ 85 | input_image, prompt, a_prompt, n_prompt, num_samples, 86 | image_resolution, ddim_steps, scale, seed, eta 87 | ] 88 | run_button.click(fn=process, inputs=ips, outputs=[result_gallery]) 89 | return demo 90 | -------------------------------------------------------------------------------- /gradio_seg2image.py: -------------------------------------------------------------------------------- 1 | # This file is adapted from https://github.com/lllyasviel/ControlNet/blob/f4748e3630d8141d7765e2bd9b1e348f47847707/gradio_seg2image.py 2 | # The original license file is LICENSE.ControlNet in this repo. 
3 | import gradio as gr 4 | 5 | 6 | def create_demo(process, max_images=12): 7 | with gr.Blocks() as demo: 8 | with gr.Row(): 9 | gr.Markdown('## Control Stable Diffusion with Segmentation Maps') 10 | with gr.Row(): 11 | with gr.Column(): 12 | input_image = gr.Image(source='upload', type='numpy') 13 | prompt = gr.Textbox(label='Prompt') 14 | run_button = gr.Button(label='Run') 15 | with gr.Accordion('Advanced options', open=False): 16 | num_samples = gr.Slider(label='Images', 17 | minimum=1, 18 | maximum=max_images, 19 | value=1, 20 | step=1) 21 | image_resolution = gr.Slider(label='Image Resolution', 22 | minimum=256, 23 | maximum=768, 24 | value=512, 25 | step=256) 26 | detect_resolution = gr.Slider( 27 | label='Segmentation Resolution', 28 | minimum=128, 29 | maximum=1024, 30 | value=512, 31 | step=1) 32 | ddim_steps = gr.Slider(label='Steps', 33 | minimum=1, 34 | maximum=100, 35 | value=20, 36 | step=1) 37 | scale = gr.Slider(label='Guidance Scale', 38 | minimum=0.1, 39 | maximum=30.0, 40 | value=9.0, 41 | step=0.1) 42 | seed = gr.Slider(label='Seed', 43 | minimum=-1, 44 | maximum=2147483647, 45 | step=1, 46 | randomize=True) 47 | eta = gr.Number(label='eta (DDIM)', value=0.0) 48 | a_prompt = gr.Textbox( 49 | label='Added Prompt', 50 | value='best quality, extremely detailed') 51 | n_prompt = gr.Textbox( 52 | label='Negative Prompt', 53 | value= 54 | 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality' 55 | ) 56 | with gr.Column(): 57 | result_gallery = gr.Gallery(label='Output', 58 | show_label=False, 59 | elem_id='gallery').style( 60 | grid=2, height='auto') 61 | ips = [ 62 | input_image, prompt, a_prompt, n_prompt, num_samples, 63 | image_resolution, detect_resolution, ddim_steps, scale, seed, eta 64 | ] 65 | run_button.click(fn=process, 66 | inputs=ips, 67 | outputs=[result_gallery], 68 | api_name='seg') 69 | return demo 70 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # This file is adapted from gradio_*.py in https://github.com/lllyasviel/ControlNet/tree/f4748e3630d8141d7765e2bd9b1e348f47847707 2 | # The original license file is LICENSE.ControlNet in this repo. 
3 | from __future__ import annotations 4 | 5 | import pathlib 6 | import random 7 | import shlex 8 | import subprocess 9 | import sys 10 | 11 | import cv2 12 | import einops 13 | import numpy as np 14 | import torch 15 | from huggingface_hub import hf_hub_url 16 | from pytorch_lightning import seed_everything 17 | 18 | sys.path.append('ControlNet') 19 | 20 | import config 21 | from annotator.canny import apply_canny 22 | from annotator.hed import apply_hed, nms 23 | from annotator.midas import apply_midas 24 | from annotator.mlsd import apply_mlsd 25 | from annotator.openpose import apply_openpose 26 | from annotator.uniformer import apply_uniformer 27 | from annotator.util import HWC3, resize_image 28 | from cldm.model import create_model, load_state_dict 29 | from ldm.models.diffusion.ddim import DDIMSampler 30 | from share import * 31 | 32 | MODEL_NAMES = { 33 | 'canny': 'control_canny-fp16.safetensors', 34 | 'hough': 'control_mlsd-fp16.safetensors', 35 | 'hed': 'control_hed-fp16.safetensors', 36 | 'scribble': 'control_scribble-fp16.safetensors', 37 | 'pose': 'control_openpose-fp16.safetensors', 38 | 'seg': 'control_seg-fp16.safetensors', 39 | 'depth': 'control_depth-fp16.safetensors', 40 | 'normal': 'control_normal-fp16.safetensors', 41 | } 42 | MODEL_REPO = 'webui/ControlNet-modules-safetensors' 43 | 44 | DEFAULT_BASE_MODEL_REPO = 'DEFAULT_BASE_MODEL_REPO_PLACEHOLDER' 45 | DEFAULT_BASE_MODEL_FILENAME = 'DEFAULT_BASE_MODEL_FILENAME_PLACEHOLDER' 46 | DEFAULT_BASE_MODEL_URL = 'DEFAULT_BASE_MODEL_URL_PLACEHOLDER' 47 | 48 | # DEFAULT_BASE_MODEL_REPO = 'andite/anything-v4.0' 49 | # DEFAULT_BASE_MODEL_FILENAME = 'anything-v4.0-pruned.safetensors' 50 | # DEFAULT_BASE_MODEL_URL = 'https://huggingface.co/andite/anything-v4.0/resolve/main/anything-v4.0-pruned.safetensors' 51 | 52 | class Model: 53 | def __init__(self, 54 | model_config_path: str = 'ControlNet/models/cldm_v15.yaml', 55 | model_dir: str = 'models'): 56 | self.device = torch.device( 57 | 'cuda:0' if torch.cuda.is_available() else 'cpu') 58 | self.model = create_model(model_config_path).to(self.device) 59 | self.ddim_sampler = DDIMSampler(self.model) 60 | self.task_name = '' 61 | 62 | self.base_model_url = '' 63 | self.model_dir = pathlib.Path(model_dir) 64 | self.model_dir.mkdir(exist_ok=True, parents=True) 65 | 66 | self.download_models() 67 | self.set_base_model(DEFAULT_BASE_MODEL_REPO, 68 | DEFAULT_BASE_MODEL_FILENAME) 69 | 70 | def set_base_model(self, model_id: str, filename: str) -> str: 71 | if not model_id or not filename: 72 | return self.base_model_url 73 | base_model_url = hf_hub_url(model_id, filename) 74 | if base_model_url != self.base_model_url: 75 | self.load_base_model(base_model_url) 76 | self.base_model_url = base_model_url 77 | return self.base_model_url 78 | 79 | def download_base_model(self, model_url: str) -> pathlib.Path: 80 | self.model_dir.mkdir(exist_ok=True, parents=True) 81 | model_name = model_url.split('/')[-1] 82 | out_path = self.model_dir / model_name 83 | if not out_path.exists(): 84 | subprocess.run(shlex.split(f'wget {model_url} -O {out_path}')) 85 | return out_path 86 | 87 | def load_base_model(self, model_url: str) -> None: 88 | model_path = self.download_base_model(model_url) 89 | self.model.load_state_dict(load_state_dict(model_path, 90 | location=self.device.type), 91 | strict=False) 92 | 93 | def load_weight(self, task_name: str) -> None: 94 | if task_name == self.task_name: 95 | return 96 | weight_path = self.get_weight_path(task_name) 97 | self.model.control_model.load_state_dict( 98 | 
load_state_dict(weight_path, location=self.device.type)) 99 | self.task_name = task_name 100 | 101 | def get_weight_path(self, task_name: str) -> str: 102 | if 'scribble' in task_name: 103 | task_name = 'scribble' 104 | return f'{self.model_dir}/{MODEL_NAMES[task_name]}' 105 | 106 | def download_models(self) -> None: 107 | self.model_dir.mkdir(exist_ok=True, parents=True) 108 | for name in MODEL_NAMES.values(): 109 | out_path = self.model_dir / name 110 | if out_path.exists(): 111 | continue 112 | model_url = hf_hub_url(MODEL_REPO, name) 113 | subprocess.run(shlex.split(f'wget {model_url} -O {out_path}')) 114 | 115 | @torch.inference_mode() 116 | def process_canny(self, input_image, prompt, a_prompt, n_prompt, 117 | num_samples, image_resolution, ddim_steps, scale, seed, 118 | eta, low_threshold, high_threshold): 119 | self.load_weight('canny') 120 | 121 | img = resize_image(HWC3(input_image), image_resolution) 122 | H, W, C = img.shape 123 | 124 | detected_map = apply_canny(img, low_threshold, high_threshold) 125 | detected_map = HWC3(detected_map) 126 | 127 | control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 128 | control = torch.stack([control for _ in range(num_samples)], dim=0) 129 | control = einops.rearrange(control, 'b h w c -> b c h w').clone() 130 | 131 | if seed == -1: 132 | seed = random.randint(0, 65535) 133 | seed_everything(seed) 134 | 135 | if config.save_memory: 136 | self.model.low_vram_shift(is_diffusing=False) 137 | 138 | cond = { 139 | 'c_concat': [control], 140 | 'c_crossattn': [ 141 | self.model.get_learned_conditioning( 142 | [prompt + ', ' + a_prompt] * num_samples) 143 | ] 144 | } 145 | un_cond = { 146 | 'c_concat': [control], 147 | 'c_crossattn': 148 | [self.model.get_learned_conditioning([n_prompt] * num_samples)] 149 | } 150 | shape = (4, H // 8, W // 8) 151 | 152 | if config.save_memory: 153 | self.model.low_vram_shift(is_diffusing=True) 154 | 155 | samples, intermediates = self.ddim_sampler.sample( 156 | ddim_steps, 157 | num_samples, 158 | shape, 159 | cond, 160 | verbose=False, 161 | eta=eta, 162 | unconditional_guidance_scale=scale, 163 | unconditional_conditioning=un_cond) 164 | 165 | if config.save_memory: 166 | self.model.low_vram_shift(is_diffusing=False) 167 | 168 | x_samples = self.model.decode_first_stage(samples) 169 | x_samples = ( 170 | einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 171 | 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) 172 | 173 | results = [x_samples[i] for i in range(num_samples)] 174 | return [255 - detected_map] + results 175 | 176 | @torch.inference_mode() 177 | def process_hough(self, input_image, prompt, a_prompt, n_prompt, 178 | num_samples, image_resolution, detect_resolution, 179 | ddim_steps, scale, seed, eta, value_threshold, 180 | distance_threshold): 181 | self.load_weight('hough') 182 | 183 | input_image = HWC3(input_image) 184 | detected_map = apply_mlsd(resize_image(input_image, detect_resolution), 185 | value_threshold, distance_threshold) 186 | detected_map = HWC3(detected_map) 187 | img = resize_image(input_image, image_resolution) 188 | H, W, C = img.shape 189 | 190 | detected_map = cv2.resize(detected_map, (W, H), 191 | interpolation=cv2.INTER_NEAREST) 192 | 193 | control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 194 | control = torch.stack([control for _ in range(num_samples)], dim=0) 195 | control = einops.rearrange(control, 'b h w c -> b c h w').clone() 196 | 197 | if seed == -1: 198 | seed = random.randint(0, 65535) 199 | seed_everything(seed) 200 
| 201 | if config.save_memory: 202 | self.model.low_vram_shift(is_diffusing=False) 203 | 204 | cond = { 205 | 'c_concat': [control], 206 | 'c_crossattn': [ 207 | self.model.get_learned_conditioning( 208 | [prompt + ', ' + a_prompt] * num_samples) 209 | ] 210 | } 211 | un_cond = { 212 | 'c_concat': [control], 213 | 'c_crossattn': 214 | [self.model.get_learned_conditioning([n_prompt] * num_samples)] 215 | } 216 | shape = (4, H // 8, W // 8) 217 | 218 | if config.save_memory: 219 | self.model.low_vram_shift(is_diffusing=True) 220 | 221 | samples, intermediates = self.ddim_sampler.sample( 222 | ddim_steps, 223 | num_samples, 224 | shape, 225 | cond, 226 | verbose=False, 227 | eta=eta, 228 | unconditional_guidance_scale=scale, 229 | unconditional_conditioning=un_cond) 230 | 231 | if config.save_memory: 232 | self.model.low_vram_shift(is_diffusing=False) 233 | 234 | x_samples = self.model.decode_first_stage(samples) 235 | x_samples = ( 236 | einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 237 | 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) 238 | 239 | results = [x_samples[i] for i in range(num_samples)] 240 | return [ 241 | 255 - cv2.dilate(detected_map, 242 | np.ones(shape=(3, 3), dtype=np.uint8), 243 | iterations=1) 244 | ] + results 245 | 246 | @torch.inference_mode() 247 | def process_hed(self, input_image, prompt, a_prompt, n_prompt, num_samples, 248 | image_resolution, detect_resolution, ddim_steps, scale, 249 | seed, eta): 250 | self.load_weight('hed') 251 | 252 | input_image = HWC3(input_image) 253 | detected_map = apply_hed(resize_image(input_image, detect_resolution)) 254 | detected_map = HWC3(detected_map) 255 | img = resize_image(input_image, image_resolution) 256 | H, W, C = img.shape 257 | 258 | detected_map = cv2.resize(detected_map, (W, H), 259 | interpolation=cv2.INTER_LINEAR) 260 | 261 | control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 262 | control = torch.stack([control for _ in range(num_samples)], dim=0) 263 | control = einops.rearrange(control, 'b h w c -> b c h w').clone() 264 | 265 | if seed == -1: 266 | seed = random.randint(0, 65535) 267 | seed_everything(seed) 268 | 269 | if config.save_memory: 270 | self.model.low_vram_shift(is_diffusing=False) 271 | 272 | cond = { 273 | 'c_concat': [control], 274 | 'c_crossattn': [ 275 | self.model.get_learned_conditioning( 276 | [prompt + ', ' + a_prompt] * num_samples) 277 | ] 278 | } 279 | un_cond = { 280 | 'c_concat': [control], 281 | 'c_crossattn': 282 | [self.model.get_learned_conditioning([n_prompt] * num_samples)] 283 | } 284 | shape = (4, H // 8, W // 8) 285 | 286 | if config.save_memory: 287 | self.model.low_vram_shift(is_diffusing=True) 288 | 289 | samples, intermediates = self.ddim_sampler.sample( 290 | ddim_steps, 291 | num_samples, 292 | shape, 293 | cond, 294 | verbose=False, 295 | eta=eta, 296 | unconditional_guidance_scale=scale, 297 | unconditional_conditioning=un_cond) 298 | 299 | if config.save_memory: 300 | self.model.low_vram_shift(is_diffusing=False) 301 | 302 | x_samples = self.model.decode_first_stage(samples) 303 | x_samples = ( 304 | einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 305 | 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) 306 | 307 | results = [x_samples[i] for i in range(num_samples)] 308 | return [detected_map] + results 309 | 310 | @torch.inference_mode() 311 | def process_scribble(self, input_image, prompt, a_prompt, n_prompt, 312 | num_samples, image_resolution, ddim_steps, scale, 313 | seed, eta): 314 | self.load_weight('scribble') 315 
| 316 | img = resize_image(HWC3(input_image), image_resolution) 317 | H, W, C = img.shape 318 | 319 | detected_map = np.zeros_like(img, dtype=np.uint8) 320 | detected_map[np.min(img, axis=2) < 127] = 255 321 | 322 | control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 323 | control = torch.stack([control for _ in range(num_samples)], dim=0) 324 | control = einops.rearrange(control, 'b h w c -> b c h w').clone() 325 | 326 | if seed == -1: 327 | seed = random.randint(0, 65535) 328 | seed_everything(seed) 329 | 330 | if config.save_memory: 331 | self.model.low_vram_shift(is_diffusing=False) 332 | 333 | cond = { 334 | 'c_concat': [control], 335 | 'c_crossattn': [ 336 | self.model.get_learned_conditioning( 337 | [prompt + ', ' + a_prompt] * num_samples) 338 | ] 339 | } 340 | un_cond = { 341 | 'c_concat': [control], 342 | 'c_crossattn': 343 | [self.model.get_learned_conditioning([n_prompt] * num_samples)] 344 | } 345 | shape = (4, H // 8, W // 8) 346 | 347 | if config.save_memory: 348 | self.model.low_vram_shift(is_diffusing=True) 349 | 350 | samples, intermediates = self.ddim_sampler.sample( 351 | ddim_steps, 352 | num_samples, 353 | shape, 354 | cond, 355 | verbose=False, 356 | eta=eta, 357 | unconditional_guidance_scale=scale, 358 | unconditional_conditioning=un_cond) 359 | 360 | if config.save_memory: 361 | self.model.low_vram_shift(is_diffusing=False) 362 | 363 | x_samples = self.model.decode_first_stage(samples) 364 | x_samples = ( 365 | einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 366 | 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) 367 | 368 | results = [x_samples[i] for i in range(num_samples)] 369 | return [255 - detected_map] + results 370 | 371 | @torch.inference_mode() 372 | def process_scribble_interactive(self, input_image, prompt, a_prompt, 373 | n_prompt, num_samples, image_resolution, 374 | ddim_steps, scale, seed, eta): 375 | self.load_weight('scribble') 376 | 377 | img = resize_image(HWC3(input_image['mask'][:, :, 0]), 378 | image_resolution) 379 | H, W, C = img.shape 380 | 381 | detected_map = np.zeros_like(img, dtype=np.uint8) 382 | detected_map[np.min(img, axis=2) > 127] = 255 383 | 384 | control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 385 | control = torch.stack([control for _ in range(num_samples)], dim=0) 386 | control = einops.rearrange(control, 'b h w c -> b c h w').clone() 387 | 388 | if seed == -1: 389 | seed = random.randint(0, 65535) 390 | seed_everything(seed) 391 | 392 | if config.save_memory: 393 | self.model.low_vram_shift(is_diffusing=False) 394 | 395 | cond = { 396 | 'c_concat': [control], 397 | 'c_crossattn': [ 398 | self.model.get_learned_conditioning( 399 | [prompt + ', ' + a_prompt] * num_samples) 400 | ] 401 | } 402 | un_cond = { 403 | 'c_concat': [control], 404 | 'c_crossattn': 405 | [self.model.get_learned_conditioning([n_prompt] * num_samples)] 406 | } 407 | shape = (4, H // 8, W // 8) 408 | 409 | if config.save_memory: 410 | self.model.low_vram_shift(is_diffusing=True) 411 | 412 | samples, intermediates = self.ddim_sampler.sample( 413 | ddim_steps, 414 | num_samples, 415 | shape, 416 | cond, 417 | verbose=False, 418 | eta=eta, 419 | unconditional_guidance_scale=scale, 420 | unconditional_conditioning=un_cond) 421 | 422 | if config.save_memory: 423 | self.model.low_vram_shift(is_diffusing=False) 424 | 425 | x_samples = self.model.decode_first_stage(samples) 426 | x_samples = ( 427 | einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 428 | 127.5).cpu().numpy().clip(0, 
255).astype(np.uint8) 429 | 430 | results = [x_samples[i] for i in range(num_samples)] 431 | return [255 - detected_map] + results 432 | 433 | @torch.inference_mode() 434 | def process_fake_scribble(self, input_image, prompt, a_prompt, n_prompt, 435 | num_samples, image_resolution, detect_resolution, 436 | ddim_steps, scale, seed, eta): 437 | self.load_weight('scribble') 438 | 439 | input_image = HWC3(input_image) 440 | detected_map = apply_hed(resize_image(input_image, detect_resolution)) 441 | detected_map = HWC3(detected_map) 442 | img = resize_image(input_image, image_resolution) 443 | H, W, C = img.shape 444 | 445 | detected_map = cv2.resize(detected_map, (W, H), 446 | interpolation=cv2.INTER_LINEAR) 447 | detected_map = nms(detected_map, 127, 3.0) 448 | detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0) 449 | detected_map[detected_map > 4] = 255 450 | detected_map[detected_map < 255] = 0 451 | 452 | control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 453 | control = torch.stack([control for _ in range(num_samples)], dim=0) 454 | control = einops.rearrange(control, 'b h w c -> b c h w').clone() 455 | 456 | if seed == -1: 457 | seed = random.randint(0, 65535) 458 | seed_everything(seed) 459 | 460 | if config.save_memory: 461 | self.model.low_vram_shift(is_diffusing=False) 462 | 463 | cond = { 464 | 'c_concat': [control], 465 | 'c_crossattn': [ 466 | self.model.get_learned_conditioning( 467 | [prompt + ', ' + a_prompt] * num_samples) 468 | ] 469 | } 470 | un_cond = { 471 | 'c_concat': [control], 472 | 'c_crossattn': 473 | [self.model.get_learned_conditioning([n_prompt] * num_samples)] 474 | } 475 | shape = (4, H // 8, W // 8) 476 | 477 | if config.save_memory: 478 | self.model.low_vram_shift(is_diffusing=True) 479 | 480 | samples, intermediates = self.ddim_sampler.sample( 481 | ddim_steps, 482 | num_samples, 483 | shape, 484 | cond, 485 | verbose=False, 486 | eta=eta, 487 | unconditional_guidance_scale=scale, 488 | unconditional_conditioning=un_cond) 489 | 490 | if config.save_memory: 491 | self.model.low_vram_shift(is_diffusing=False) 492 | 493 | x_samples = self.model.decode_first_stage(samples) 494 | x_samples = ( 495 | einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 496 | 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) 497 | 498 | results = [x_samples[i] for i in range(num_samples)] 499 | return [255 - detected_map] + results 500 | 501 | @torch.inference_mode() 502 | def process_pose(self, input_image, prompt, a_prompt, n_prompt, 503 | num_samples, image_resolution, detect_resolution, 504 | ddim_steps, scale, seed, eta): 505 | self.load_weight('pose') 506 | 507 | input_image = HWC3(input_image) 508 | detected_map, _ = apply_openpose( 509 | resize_image(input_image, detect_resolution)) 510 | detected_map = HWC3(detected_map) 511 | img = resize_image(input_image, image_resolution) 512 | H, W, C = img.shape 513 | 514 | detected_map = cv2.resize(detected_map, (W, H), 515 | interpolation=cv2.INTER_NEAREST) 516 | 517 | control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 518 | control = torch.stack([control for _ in range(num_samples)], dim=0) 519 | control = einops.rearrange(control, 'b h w c -> b c h w').clone() 520 | 521 | if seed == -1: 522 | seed = random.randint(0, 65535) 523 | seed_everything(seed) 524 | 525 | if config.save_memory: 526 | self.model.low_vram_shift(is_diffusing=False) 527 | 528 | cond = { 529 | 'c_concat': [control], 530 | 'c_crossattn': [ 531 | self.model.get_learned_conditioning( 532 | [prompt + ', ' + 
a_prompt] * num_samples) 533 | ] 534 | } 535 | un_cond = { 536 | 'c_concat': [control], 537 | 'c_crossattn': 538 | [self.model.get_learned_conditioning([n_prompt] * num_samples)] 539 | } 540 | shape = (4, H // 8, W // 8) 541 | 542 | if config.save_memory: 543 | self.model.low_vram_shift(is_diffusing=True) 544 | 545 | samples, intermediates = self.ddim_sampler.sample( 546 | ddim_steps, 547 | num_samples, 548 | shape, 549 | cond, 550 | verbose=False, 551 | eta=eta, 552 | unconditional_guidance_scale=scale, 553 | unconditional_conditioning=un_cond) 554 | 555 | if config.save_memory: 556 | self.model.low_vram_shift(is_diffusing=False) 557 | 558 | x_samples = self.model.decode_first_stage(samples) 559 | x_samples = ( 560 | einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 561 | 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) 562 | 563 | results = [x_samples[i] for i in range(num_samples)] 564 | return [detected_map] + results 565 | 566 | @torch.inference_mode() 567 | def process_seg(self, input_image, prompt, a_prompt, n_prompt, num_samples, 568 | image_resolution, detect_resolution, ddim_steps, scale, 569 | seed, eta): 570 | self.load_weight('seg') 571 | 572 | input_image = HWC3(input_image) 573 | detected_map = apply_uniformer( 574 | resize_image(input_image, detect_resolution)) 575 | img = resize_image(input_image, image_resolution) 576 | H, W, C = img.shape 577 | 578 | detected_map = cv2.resize(detected_map, (W, H), 579 | interpolation=cv2.INTER_NEAREST) 580 | 581 | control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 582 | control = torch.stack([control for _ in range(num_samples)], dim=0) 583 | control = einops.rearrange(control, 'b h w c -> b c h w').clone() 584 | 585 | if seed == -1: 586 | seed = random.randint(0, 65535) 587 | seed_everything(seed) 588 | 589 | if config.save_memory: 590 | self.model.low_vram_shift(is_diffusing=False) 591 | 592 | cond = { 593 | 'c_concat': [control], 594 | 'c_crossattn': [ 595 | self.model.get_learned_conditioning( 596 | [prompt + ', ' + a_prompt] * num_samples) 597 | ] 598 | } 599 | un_cond = { 600 | 'c_concat': [control], 601 | 'c_crossattn': 602 | [self.model.get_learned_conditioning([n_prompt] * num_samples)] 603 | } 604 | shape = (4, H // 8, W // 8) 605 | 606 | if config.save_memory: 607 | self.model.low_vram_shift(is_diffusing=True) 608 | 609 | samples, intermediates = self.ddim_sampler.sample( 610 | ddim_steps, 611 | num_samples, 612 | shape, 613 | cond, 614 | verbose=False, 615 | eta=eta, 616 | unconditional_guidance_scale=scale, 617 | unconditional_conditioning=un_cond) 618 | 619 | if config.save_memory: 620 | self.model.low_vram_shift(is_diffusing=False) 621 | 622 | x_samples = self.model.decode_first_stage(samples) 623 | x_samples = ( 624 | einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 625 | 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) 626 | 627 | results = [x_samples[i] for i in range(num_samples)] 628 | return [detected_map] + results 629 | 630 | @torch.inference_mode() 631 | def process_depth(self, input_image, prompt, a_prompt, n_prompt, 632 | num_samples, image_resolution, detect_resolution, 633 | ddim_steps, scale, seed, eta): 634 | self.load_weight('depth') 635 | 636 | input_image = HWC3(input_image) 637 | detected_map, _ = apply_midas( 638 | resize_image(input_image, detect_resolution)) 639 | detected_map = HWC3(detected_map) 640 | img = resize_image(input_image, image_resolution) 641 | H, W, C = img.shape 642 | 643 | detected_map = cv2.resize(detected_map, (W, H), 644 | 
interpolation=cv2.INTER_LINEAR) 645 | 646 | control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 647 | control = torch.stack([control for _ in range(num_samples)], dim=0) 648 | control = einops.rearrange(control, 'b h w c -> b c h w').clone() 649 | 650 | if seed == -1: 651 | seed = random.randint(0, 65535) 652 | seed_everything(seed) 653 | 654 | if config.save_memory: 655 | self.model.low_vram_shift(is_diffusing=False) 656 | 657 | cond = { 658 | 'c_concat': [control], 659 | 'c_crossattn': [ 660 | self.model.get_learned_conditioning( 661 | [prompt + ', ' + a_prompt] * num_samples) 662 | ] 663 | } 664 | un_cond = { 665 | 'c_concat': [control], 666 | 'c_crossattn': 667 | [self.model.get_learned_conditioning([n_prompt] * num_samples)] 668 | } 669 | shape = (4, H // 8, W // 8) 670 | 671 | if config.save_memory: 672 | self.model.low_vram_shift(is_diffusing=True) 673 | 674 | samples, intermediates = self.ddim_sampler.sample( 675 | ddim_steps, 676 | num_samples, 677 | shape, 678 | cond, 679 | verbose=False, 680 | eta=eta, 681 | unconditional_guidance_scale=scale, 682 | unconditional_conditioning=un_cond) 683 | 684 | if config.save_memory: 685 | self.model.low_vram_shift(is_diffusing=False) 686 | 687 | x_samples = self.model.decode_first_stage(samples) 688 | x_samples = ( 689 | einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 690 | 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) 691 | 692 | results = [x_samples[i] for i in range(num_samples)] 693 | return [detected_map] + results 694 | 695 | @torch.inference_mode() 696 | def process_normal(self, input_image, prompt, a_prompt, n_prompt, 697 | num_samples, image_resolution, detect_resolution, 698 | ddim_steps, scale, seed, eta, bg_threshold): 699 | self.load_weight('normal') 700 | 701 | input_image = HWC3(input_image) 702 | _, detected_map = apply_midas(resize_image(input_image, 703 | detect_resolution), 704 | bg_th=bg_threshold) 705 | detected_map = HWC3(detected_map) 706 | img = resize_image(input_image, image_resolution) 707 | H, W, C = img.shape 708 | 709 | detected_map = cv2.resize(detected_map, (W, H), 710 | interpolation=cv2.INTER_LINEAR) 711 | 712 | control = torch.from_numpy( 713 | detected_map[:, :, ::-1].copy()).float().cuda() / 255.0 714 | control = torch.stack([control for _ in range(num_samples)], dim=0) 715 | control = einops.rearrange(control, 'b h w c -> b c h w').clone() 716 | 717 | if seed == -1: 718 | seed = random.randint(0, 65535) 719 | seed_everything(seed) 720 | 721 | if config.save_memory: 722 | self.model.low_vram_shift(is_diffusing=False) 723 | 724 | cond = { 725 | 'c_concat': [control], 726 | 'c_crossattn': [ 727 | self.model.get_learned_conditioning( 728 | [prompt + ', ' + a_prompt] * num_samples) 729 | ] 730 | } 731 | un_cond = { 732 | 'c_concat': [control], 733 | 'c_crossattn': 734 | [self.model.get_learned_conditioning([n_prompt] * num_samples)] 735 | } 736 | shape = (4, H // 8, W // 8) 737 | 738 | if config.save_memory: 739 | self.model.low_vram_shift(is_diffusing=True) 740 | 741 | samples, intermediates = self.ddim_sampler.sample( 742 | ddim_steps, 743 | num_samples, 744 | shape, 745 | cond, 746 | verbose=False, 747 | eta=eta, 748 | unconditional_guidance_scale=scale, 749 | unconditional_conditioning=un_cond) 750 | 751 | if config.save_memory: 752 | self.model.low_vram_shift(is_diffusing=False) 753 | 754 | x_samples = self.model.decode_first_stage(samples) 755 | x_samples = ( 756 | einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 757 | 127.5).cpu().numpy().clip(0, 
255).astype(np.uint8) 758 | 759 | results = [x_samples[i] for i in range(num_samples)] 760 | return [detected_map] + results 761 | -------------------------------------------------------------------------------- /patch: -------------------------------------------------------------------------------- 1 | diff --git a/annotator/hed/__init__.py b/annotator/hed/__init__.py 2 | index 42d8dc6..1587035 100644 3 | --- a/annotator/hed/__init__.py 4 | +++ b/annotator/hed/__init__.py 5 | @@ -1,8 +1,12 @@ 6 | +import pathlib 7 | + 8 | import numpy as np 9 | import cv2 10 | import torch 11 | from einops import rearrange 12 | 13 | +root_dir = pathlib.Path(__file__).parents[2] 14 | + 15 | 16 | class Network(torch.nn.Module): 17 | def __init__(self): 18 | @@ -64,7 +68,7 @@ class Network(torch.nn.Module): 19 | torch.nn.Sigmoid() 20 | ) 21 | 22 | - self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load('./annotator/ckpts/network-bsds500.pth').items()}) 23 | + self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load(f'{root_dir}/annotator/ckpts/network-bsds500.pth').items()}) 24 | # end 25 | 26 | def forward(self, tenInput): 27 | diff --git a/annotator/midas/api.py b/annotator/midas/api.py 28 | index 9fa305e..d8594ea 100644 29 | --- a/annotator/midas/api.py 30 | +++ b/annotator/midas/api.py 31 | @@ -1,5 +1,7 @@ 32 | # based on https://github.com/isl-org/MiDaS 33 | 34 | +import pathlib 35 | + 36 | import cv2 37 | import torch 38 | import torch.nn as nn 39 | @@ -10,10 +12,11 @@ from .midas.midas_net import MidasNet 40 | from .midas.midas_net_custom import MidasNet_small 41 | from .midas.transforms import Resize, NormalizeImage, PrepareForNet 42 | 43 | +root_dir = pathlib.Path(__file__).parents[2] 44 | 45 | ISL_PATHS = { 46 | - "dpt_large": "annotator/ckpts/dpt_large-midas-2f21e586.pt", 47 | - "dpt_hybrid": "annotator/ckpts/dpt_hybrid-midas-501f0c75.pt", 48 | + "dpt_large": f"{root_dir}/annotator/ckpts/dpt_large-midas-2f21e586.pt", 49 | + "dpt_hybrid": f"{root_dir}/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt", 50 | "midas_v21": "", 51 | "midas_v21_small": "", 52 | } 53 | diff --git a/annotator/mlsd/__init__.py b/annotator/mlsd/__init__.py 54 | index 75db717..f310fe6 100644 55 | --- a/annotator/mlsd/__init__.py 56 | +++ b/annotator/mlsd/__init__.py 57 | @@ -1,3 +1,5 @@ 58 | +import pathlib 59 | + 60 | import cv2 61 | import numpy as np 62 | import torch 63 | @@ -8,8 +10,9 @@ from .models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny 64 | from .models.mbv2_mlsd_large import MobileV2_MLSD_Large 65 | from .utils import pred_lines 66 | 67 | +root_dir = pathlib.Path(__file__).parents[2] 68 | 69 | -model_path = './annotator/ckpts/mlsd_large_512_fp32.pth' 70 | +model_path = f'{root_dir}/annotator/ckpts/mlsd_large_512_fp32.pth' 71 | model = MobileV2_MLSD_Large() 72 | model.load_state_dict(torch.load(model_path), strict=True) 73 | model = model.cuda().eval() 74 | diff --git a/annotator/openpose/__init__.py b/annotator/openpose/__init__.py 75 | index 47d50a5..2369eed 100644 76 | --- a/annotator/openpose/__init__.py 77 | +++ b/annotator/openpose/__init__.py 78 | @@ -1,4 +1,5 @@ 79 | import os 80 | +import pathlib 81 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 82 | 83 | import torch 84 | @@ -7,8 +8,10 @@ from . 
import util 85 | from .body import Body 86 | from .hand import Hand 87 | 88 | -body_estimation = Body('./annotator/ckpts/body_pose_model.pth') 89 | -hand_estimation = Hand('./annotator/ckpts/hand_pose_model.pth') 90 | +root_dir = pathlib.Path(__file__).parents[2] 91 | + 92 | +body_estimation = Body(f'{root_dir}/annotator/ckpts/body_pose_model.pth') 93 | +hand_estimation = Hand(f'{root_dir}/annotator/ckpts/hand_pose_model.pth') 94 | 95 | 96 | def apply_openpose(oriImg, hand=False): 97 | diff --git a/annotator/uniformer/__init__.py b/annotator/uniformer/__init__.py 98 | index 500e53c..4061dbe 100644 99 | --- a/annotator/uniformer/__init__.py 100 | +++ b/annotator/uniformer/__init__.py 101 | @@ -1,9 +1,12 @@ 102 | +import pathlib 103 | + 104 | from annotator.uniformer.mmseg.apis import init_segmentor, inference_segmentor, show_result_pyplot 105 | from annotator.uniformer.mmseg.core.evaluation import get_palette 106 | 107 | +root_dir = pathlib.Path(__file__).parents[2] 108 | 109 | -checkpoint_file = "annotator/ckpts/upernet_global_small.pth" 110 | -config_file = 'annotator/uniformer/exp/upernet_global_small/config.py' 111 | +checkpoint_file = f"{root_dir}/annotator/ckpts/upernet_global_small.pth" 112 | +config_file = f'{root_dir}/annotator/uniformer/exp/upernet_global_small/config.py' 113 | model = init_segmentor(config_file, checkpoint_file).cuda() 114 | 115 | 116 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | addict==2.4.0 2 | albumentations==1.3.0 3 | einops==0.6.0 4 | gradio==3.18.0 5 | huggingface-hub==0.12.0 6 | imageio==2.25.0 7 | imageio-ffmpeg==0.4.8 8 | kornia==0.6.9 9 | omegaconf==2.3.0 10 | open-clip-torch==2.13.0 11 | opencv-contrib-python==4.7.0.68 12 | opencv-python-headless==4.7.0.68 13 | prettytable==3.6.0 14 | pytorch-lightning==1.9.0 15 | safetensors==0.2.8 16 | timm==0.6.12 17 | torch==1.13.1 18 | torchvision==0.14.1 19 | transformers==4.26.1 20 | xformers==0.0.16 21 | yapf==0.32.0 22 | -------------------------------------------------------------------------------- /style.css: -------------------------------------------------------------------------------- 1 | h1 { 2 | text-align: center; 3 | } 4 | --------------------------------------------------------------------------------
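
A note on the `model.py` code shown above: the eight `process_*` methods differ only in how they compute `detected_map`; the seeding, conditioning, DDIM sampling, and decoding steps are repeated verbatim in each one. The following is a hypothetical refactor sketch, not part of the repository, showing how that shared pipeline could be written once as a standalone helper (the function name `run_controlnet_sampling` and its parameter list are assumptions introduced here for illustration):

```python
# Hypothetical refactor sketch (not part of the repository): the sampling
# pipeline that every process_* method in model.py repeats, factored into a
# single standalone helper.
import random

import einops
import numpy as np
import torch
from pytorch_lightning import seed_everything


@torch.inference_mode()
def run_controlnet_sampling(model, ddim_sampler, detected_map, prompt,
                            a_prompt, n_prompt, num_samples, ddim_steps,
                            scale, seed, eta, save_memory=False):
    """Sample num_samples images conditioned on an (H, W, 3) uint8 control map."""
    H, W, _ = detected_map.shape

    # Normalize the control map to [0, 1] and repeat it once per sample.
    control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
    control = torch.stack([control for _ in range(num_samples)], dim=0)
    control = einops.rearrange(control, 'b h w c -> b c h w').clone()

    if seed == -1:
        seed = random.randint(0, 65535)
    seed_everything(seed)

    if save_memory:
        model.low_vram_shift(is_diffusing=False)

    cond = {
        'c_concat': [control],
        'c_crossattn': [
            model.get_learned_conditioning(
                [prompt + ', ' + a_prompt] * num_samples)
        ],
    }
    un_cond = {
        'c_concat': [control],
        'c_crossattn':
        [model.get_learned_conditioning([n_prompt] * num_samples)],
    }
    shape = (4, H // 8, W // 8)  # latent shape: 4 channels at 1/8 resolution

    if save_memory:
        model.low_vram_shift(is_diffusing=True)

    samples, _ = ddim_sampler.sample(ddim_steps,
                                     num_samples,
                                     shape,
                                     cond,
                                     verbose=False,
                                     eta=eta,
                                     unconditional_guidance_scale=scale,
                                     unconditional_conditioning=un_cond)

    if save_memory:
        model.low_vram_shift(is_diffusing=False)

    # Decode latents back to uint8 RGB images.
    x_samples = model.decode_first_stage(samples)
    x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
                 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
    return [x_samples[i] for i in range(num_samples)]
```

With such a helper, `process_canny`, for example, would reduce to computing its Canny `detected_map`, calling `run_controlnet_sampling(self.model, self.ddim_sampler, detected_map, ..., save_memory=config.save_memory)`, and prepending `255 - detected_map` to the returned list.

Separately, `download_models` shells out to `wget`, which fails silently if the binary is missing (the `subprocess.run` call is not checked). A minimal alternative sketch, assuming only the pinned `huggingface_hub` package (the helper name `download_models_via_hub` is hypothetical):

```python
# Hypothetical sketch (not part of the repository): fetch the checkpoints
# through huggingface_hub instead of shelling out to wget.  MODEL_REPO and
# MODEL_NAMES are the module-level constants already defined in model.py.
import pathlib
import shutil

from huggingface_hub import hf_hub_download


def download_models_via_hub(model_dir: pathlib.Path) -> None:
    model_dir.mkdir(exist_ok=True, parents=True)
    for name in MODEL_NAMES.values():
        out_path = model_dir / name
        if out_path.exists():
            continue
        # Download into the local HF cache, then copy next to the other weights.
        cached_path = hf_hub_download(MODEL_REPO, name)
        shutil.copy(cached_path, out_path)
```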