├── extension
│   ├── __init__.py
│   ├── version.py
│   ├── utils.py
│   └── api.py
├── docs
│   └── how-it-works.png
├── install.py
├── scripts
│   ├── main_tabs.py
│   ├── hijack.py
│   └── main_ui.py
├── .gitignore
├── style.css
├── javascript
│   └── modelBrowser.js
└── README.md

/extension/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/extension/version.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.1.9"
--------------------------------------------------------------------------------
/docs/how-it-works.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omniinfer/sd-webui-cloud-inference/HEAD/docs/how-it-works.png
--------------------------------------------------------------------------------
/install.py:
--------------------------------------------------------------------------------
1 | import launch
2 | 
3 | launch.run_pip("install omniinfer_client==0.3.5", "requirements for sd-webui-cloud-inference")
4 | 
--------------------------------------------------------------------------------
/extension/utils.py:
--------------------------------------------------------------------------------
1 | import io
2 | import base64
3 | import os
4 | 
5 | 
6 | def image_to_base64(image, format="PNG"):
7 |     buffer = io.BytesIO()
8 |     image.save(buffer, format=format)
9 |     buffer_bytes = buffer.getvalue()
10 | 
11 |     base64_str = base64.b64encode(buffer_bytes).decode("utf-8")
12 | 
13 |     return base64_str
14 | 
15 | 
16 | def read_image_files(folder_path):
17 |     image_extensions = [".png", ".jpg", ".jpeg", ".webp"]
18 |     image_files = []
19 | 
20 |     for file_name in os.listdir(folder_path):
21 |         file_ext_lower = os.path.splitext(file_name)[1].lower()
22 |         if file_ext_lower in image_extensions:
23 |             file_path = os.path.join(folder_path, file_name)
24 |             image_files.append(file_path)
25 | 
26 |     images_base64 = []
27 |     for file_path in image_files:
28 |         with open(file_path, "rb") as image_file:
29 |             image_data = image_file.read()
30 |             encoded_data = base64.b64encode(image_data)
31 |             base64_string = encoded_data.decode("utf-8")
32 |             images_base64.append(base64_string)
33 | 
34 |     return images_base64
--------------------------------------------------------------------------------
/scripts/main_tabs.py:
--------------------------------------------------------------------------------
1 | import modules.scripts as scripts
2 | import gradio as gr
3 | 
4 | from modules import script_callbacks
5 | from extension import api
6 | 
7 | 
8 | def on_ui_tabs():
9 |     with gr.Blocks(analytics_enabled=False) as ui_component:
10 |         with gr.Tab(label="Omniinfer"):
11 |             with gr.Blocks():
12 |                 gr.Markdown("""
13 | # Omniinfer
14 | Omniinfer is a cloud inference service that allows you to run txt2img and img2img in the cloud. 
15 | """) 16 | with gr.Row(): 17 | key_textbox = gr.Textbox( 18 | value=api.get_instance().__dict__.get('_token') 19 | if api.get_instance() is not None else "", 20 | label="Omniinfer Key", 21 | type="password", 22 | placeholder="Enter omniinfer key here", 23 | elem_id="settings_remote_inference_omniinfer_key", 24 | ) 25 | 26 | test_button = gr.Button( 27 | "Test Connection", 28 | label="Test Connection", 29 | variant="primary", 30 | elem_id="settings_remote_inference_omniinfer_test", 31 | ) 32 | 33 | test_message_textbox = gr.Textbox(label="Test Message Results", 34 | interactive=False) 35 | 36 | gr.HTML(value=""" 37 | Register for a free key at Stable Diffusion WebUI Cloud Inference Tutorial. 38 | """) 39 | 40 | def test_callback(key): 41 | try: 42 | ok_msg = api.OmniinferAPI.test_connection(key) 43 | api.OmniinferAPI.update_key_to_config(key) 44 | api.refresh_instance() 45 | return ok_msg 46 | except Exception as exp: 47 | return str(exp) 48 | 49 | test_button.click(fn=test_callback, 50 | inputs=[key_textbox], 51 | outputs=[test_message_textbox]) 52 | with gr.Tab(label="Additional Providers"): 53 | gr.Markdown(""" 54 | # Do you require support for additional api providers? 55 | discuss in [Github](https://github.com/omniinfer/sd-webui-cloud-inference/discussions/new?category=general) 56 | """) 57 | 58 | return [(ui_component, "Cloud Inference", "extension_template_tab")] 59 | 60 | 61 | script_callbacks.on_ui_tabs(on_ui_tabs) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | /extensions/.omniinfer.json 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
161 | #.idea/
162 | extension/.omniinfer.json
--------------------------------------------------------------------------------
/style.css:
--------------------------------------------------------------------------------
1 | .search-container {
2 |     text-align: center;
3 |     margin: 20px 0;
4 | }
5 | 
6 | #search-input {
7 |     padding: 10px;
8 |     width: 80%;
9 |     border: 1px solid #ccc;
10 |     border-radius: 5px;
11 |     font-size: 16px;
12 | }
13 | 
14 | .title-container {
15 |     background: rgba(214, 214, 214, 0.6);
16 |     padding: 1px;
17 |     /* Adjust padding for a compact look */
18 |     border-radius: 5px;
19 |     position: absolute;
20 |     bottom: 0;
21 |     left: 0;
22 |     width: 100%;
23 |     display: flex;
24 |     align-items: center;
25 |     /* Center vertically */
26 |     justify-content: space-between;
27 |     /* Spread items horizontally */
28 | }
29 | 
30 | .title {
31 |     font-weight: bold;
32 |     font-size: 12px;
33 |     /* Adjust font size for better visibility */
34 |     color: var(--background-fill-primary);
35 |     /* Adjust text color */
36 |     margin: 0;
37 |     /* Reset margin for the title */
38 | }
39 | 
40 | .buttons {
41 |     position: absolute;
42 |     top: 10px;
43 |     right: 10px;
44 |     display: flex;
45 |     gap: 5px;
46 | }
47 | 
48 | .btn {
49 |     background-color: var(--background-fill-primary);
50 |     color: #fff;
51 |     border: none;
52 |     padding: 5px 10px;
53 |     border-radius: 5px;
54 |     cursor: pointer;
55 | }
56 | 
57 | 
58 | 
59 | .heading-text {
60 |     margin-bottom: 2rem;
61 |     font-size: 2rem;
62 | }
63 | 
64 | .heading-text span {
65 |     font-weight: 100;
66 | }
67 | 
68 | /* Responsive image gallery rules begin */
69 | 
70 | .image-gallery {
71 |     /* Mobile first */
72 |     /* max-height: 500px; */
73 |     overflow-y: auto;
74 |     display: flex;
75 |     flex-direction: column;
76 |     gap: 10px;
77 | }
78 | 
79 | .image-gallery .column {
80 |     display: flex;
81 |     flex-direction: column;
82 |     gap: 10px;
83 | }
84 | 
85 | .image-item img {
86 |     width: 100%;
87 |     border-radius: 5px;
88 |     height: 100%;
89 |     object-fit: cover;
90 | }
91 | 
92 | .image-item.nsfw img {
93 |     filter: blur(10px);
94 | }
95 | 
96 | @media only screen and (min-width: 768px) {
97 |     .image-gallery {
98 |         flex-direction: row;
99 |     }
100 | }
101 | 
102 | /* overlay styles */
103 | 
104 | .image-item {
105 |     position: relative;
106 |     cursor: pointer;
107 |     min-height: 200px;
108 | }
109 | 
110 | .overlay {
111 |     position: absolute;
112 |     width: 100%;
113 |     height: 100%;
114 |     background: rgba(57, 57, 57, 0.502);
115 |     top: 0;
116 |     left: 0;
117 |     transform: scale(0);
118 |     transition: all 0.2s 0.1s ease-in-out;
119 |     color: #fff;
120 |     /* center overlay content */
121 |     display: flex;
122 |     align-items: center;
123 |     justify-content: center;
124 | }
125 | 
126 | /* hover */
127 | .image-item:hover .overlay {
128 |     transform: scale(1);
129 | }
130 | 
131 | .filter-buttons {
132 |     gap: 10px;
133 |     margin-top: 10px;
134 | }
135 | 
136 | .filter-buttons button {
137 |     background-color: var(--background-fill-primary);
138 |     color: #fff;
139 |     border: none;
140 |     padding: 5px 10px;
141 |     border-radius: 5px;
142 |     cursor: pointer;
143 | }
144 | 
145 | /* Style for selected filter button */
146 | .filter-buttons button.selected {
147 |     background-color: var(--background-fill-primary);
148 |     /* Change color for selected button */
149 | }
150 | 
151 | /* Style for selected filter button */
152 | .filter-buttons button.selected {
153 |     background-color: var(--block-label-text-color);
154 |     color: var(--background-fill-primary);
155 |     border-top-left-radius: 5px;
156 |     border-top-right-radius: 5px;
157 | 
border: 1px solid #ccc; 158 | border-bottom: none; 159 | } 160 | 161 | #select-button { 162 | background-color: transparent; 163 | color: white; 164 | } 165 | 166 | 167 | .search-bar { 168 | margin-top: 10px; 169 | padding: 10px; 170 | } 171 | 172 | .filter-search-input { 173 | color: black; 174 | } 175 | 176 | 177 | .global-model-browser-popup { 178 | display: flex; 179 | position: fixed; 180 | z-index: 1001; 181 | left: 0; 182 | top: 0; 183 | width: 100%; 184 | height: 100%; 185 | overflow: auto; 186 | /* background-color: rgba(20, 20, 20, 0.95); */ 187 | } 188 | 189 | .global-model-browser-popup * { 190 | box-sizing: border-box; 191 | } 192 | 193 | .global-model-browser-popup-close:before { 194 | content: "×"; 195 | } 196 | 197 | .global-model-browser-popup-close { 198 | position: fixed; 199 | right: 0.25em; 200 | top: 0; 201 | cursor: pointer; 202 | color: var(--block-label-text-color); 203 | font-size: 32pt; 204 | } 205 | 206 | .global-model-browser-popup-inner { 207 | display: inline-block; 208 | margin: auto; 209 | padding: 2em; 210 | } 211 | 212 | div.block.gradio-box.popup-model-browser { 213 | position: absolute; 214 | left: 50%; 215 | width: 40%; 216 | top: 40%; 217 | height: 60%; 218 | background: var(--body-background-fill); 219 | /* padding: 2em !important; */ 220 | } 221 | 222 | 223 | .gradio-button.model-browser-button { 224 | height: 2.4em; 225 | align-self: end; 226 | line-height: 1em; 227 | border-radius: 0.5em; 228 | } -------------------------------------------------------------------------------- /javascript/modelBrowser.js: -------------------------------------------------------------------------------- 1 | var globalModelBrowserPopup = null; 2 | var globalModelBrowserPopupInner = null; 3 | var globalModelBrowserListeners = []; 4 | function closeModelBrowserPopup() { 5 | if (!globalModelBrowserPopup) return; 6 | 7 | globalModelBrowserPopup.style.display = "none"; 8 | } 9 | 10 | function modelBrowserPopup(tab, contents) { 11 | if (!globalModelBrowserPopup) { 12 | globalModelBrowserPopup = document.createElement('div'); 13 | globalModelBrowserPopup.onclick = closeModelBrowserPopup; 14 | globalModelBrowserPopup.classList.add('global-model-browser-popup'); 15 | 16 | var close = document.createElement('div'); 17 | close.classList.add('global-model-browser-popup-close'); 18 | close.onclick = closeModelBrowserPopup; 19 | close.title = "Close"; 20 | globalModelBrowserPopup.appendChild(close); 21 | 22 | globalModelBrowserPopupInner = document.createElement('div'); 23 | globalModelBrowserPopupInner.onclick = function (event) { 24 | event.stopPropagation(); 25 | return false; 26 | }; 27 | globalModelBrowserPopupInner.classList.add('global-model-browser-popup-inner'); 28 | globalModelBrowserPopup.appendChild(globalModelBrowserPopupInner); 29 | 30 | gradioApp().querySelector('.main').appendChild(globalModelBrowserPopup); 31 | } 32 | 33 | doThingsAfterPopup(tab) 34 | globalModelBrowserPopupInner.innerHTML = ''; 35 | globalModelBrowserPopupInner.appendChild(contents); 36 | 37 | globalModelBrowserPopup.style.display = "flex"; 38 | } 39 | 40 | 41 | 42 | function toggleSelected(button) { 43 | const filterButtons = document.querySelectorAll('.filter-btn'); 44 | 45 | filterButtons.forEach(btn => { 46 | btn.classList.remove('selected'); 47 | }); 48 | button.classList.add('selected'); 49 | } 50 | 51 | function filterImages(tab, kind, selectedTag) { 52 | const imageItems = document.querySelectorAll(`.image-item[data-kind="${kind}"]`); 53 | 54 | imageItems.forEach(item => { 55 | const 
itemTags = item.getAttribute('data-tags').split(' ');
56 |         const shouldDisplay = selectedTag === 'all' || itemTags.includes(selectedTag);
57 |         item.style.display = shouldDisplay ? 'block' : 'none';
58 |     });
59 | }
60 | 
61 | function doThingsAfterPopup(tab) {
62 |     addFilterButtons(tab, 'checkpoint');
63 |     addFilterButtons(tab, 'lora');
64 |     addFilterButtons(tab, 'embedding');
65 | 
66 |     applyTextSearch(tab, 'checkpoint');
67 |     applyTextSearch(tab, 'lora');
68 |     applyTextSearch(tab, 'embedding');
69 | 
70 |     // addNsfwToggle()
71 |     addImageClickListener(tab);
72 | 
73 | 
74 |     applyNsfwClass(tab);
75 | }
76 | 
77 | function addImageClickListener(tab) {
78 |     const imageItems = document.querySelectorAll('.image-item');
79 | 
80 |     imageItems.forEach(item => {
81 |         const selectButton = item.querySelector('#select-button');
82 |         // const favoriteButton = item.querySelector('#favorite-btn');
83 |         const titleElement = item.querySelector('.title').getAttribute('data-alias');
84 |         const browserTabName = item.parentElement.parentElement.parentElement.querySelector('.heading-text').textContent;
85 |         selectButton.addEventListener('click', (event) => {
86 |             if (browserTabName == 'CHECKPOINT Browser') {
87 |                 desiredCloudInferenceCheckpointName = titleElement;
88 |                 gradioApp().getElementById(`${tab}_change_cloud_checkpoint`).click()
89 |             } else if (browserTabName == 'LORA Browser') {
90 |                 desiredCloudInferenceLoraName = titleElement;
91 |                 gradioApp().getElementById(`${tab}_change_cloud_lora`).click()
92 |             } else if (browserTabName == 'EMBEDDING Browser') {
93 |                 desiredCloudInferenceEmbeddingName = titleElement;
94 |                 gradioApp().getElementById(`${tab}_change_cloud_embedding`).click()
95 |             }
96 |         });
97 | 
98 |         // favoriteButton.addEventListener('click', () => {
99 |         //     desiredCloudInferenceFavoriteModelName = titleElement;
100 |         //     gradioApp().getElementById(`${tab}_favorite`).click()
101 |         // })
102 |     })
103 | }
104 | 
105 | function addFilterButtons(tab, kind) {
106 |     const filterButtons = document.querySelectorAll(`.filter-btn[data-kind="${kind}"][data-tab="${tab}"]`);
107 |     // Filter images based on selected filter
108 |     filterButtons.forEach(button => {
109 |         button.addEventListener('click', () => {
110 |             const selectedTag = button.getAttribute('data-tag');
111 |             filterImages(tab, kind, selectedTag);
112 | 
113 |             toggleSelected(button); // Add selected style to clicked button
114 |         });
115 |     });
116 | }
117 | 
118 | 
119 | function applyNsfwClass(tab) {
120 |     const imageItems = document.querySelectorAll('.image-item');
121 | 
122 |     imageItems.forEach(item => {
123 |         const itemTags = item.getAttribute('data-tags').split(' ');
124 |         if (itemTags.includes('nsfw')) {
125 |             item.classList.add('nsfw');
126 |         } else {
127 |             item.classList.remove('nsfw');
128 |         }
129 |     });
130 | }
131 | 
132 | function applyTextSearch(tab, kind) {
133 |     document.getElementById(`${tab}-${kind}-filter-search-input`).addEventListener(`input`, () => {
134 |         const searchText = document.getElementById(`${tab}-${kind}-filter-search-input`).value.toLowerCase();
135 |         const selectedTag = getSelectedTag(kind);
136 | 
137 |         // only this kind's items, so one tab's search does not hide the other tabs' items
138 |         const imageItems = document.querySelectorAll(`.image-item[data-kind="${kind}"]`);
139 |         imageItems.forEach(item => {
140 |             const itemTags = item.getAttribute('data-tags').split(' ');
141 |             const searchTerms = item.getAttribute('data-search-terms');
142 |             const matchesSearch = searchTerms.toLowerCase().includes(searchText);
143 |             const matchesTag = selectedTag === 'all' || itemTags.includes(selectedTag);
144 |             console.log(item, matchesSearch, matchesTag)
145 |             item.style.display = matchesSearch && matchesTag ? 'block' : 'none';
146 |         });
147 |     });
148 | }
149 | 
150 | // function doThingsAfterClosePopup() {
151 | //     for (kind of ['checkpoint', 'lora', 'embedding']) {
152 | //         searchInput.removeEventListener(`${kind}-search-input-listener`)
153 | //     }
154 | // }
155 | 
156 | function getSelectedTag(kind) {
157 |     const selectedButton = document.querySelector(`.filter-btn.selected[data-kind="${kind}"]`);
158 |     return selectedButton ? selectedButton.getAttribute('data-tag') : 'all';
159 | }
160 | 
161 | function openInNewTab(url) {
162 |     var win = window.open(url, '_blank');
163 |     win.focus();
164 | }
--------------------------------------------------------------------------------
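Note: modelBrowser.js drives model selection indirectly -- `addImageClickListener` stores the chosen alias in a JS global and then clicks a hidden Gradio button with elem_id `{tab}_change_cloud_checkpoint` (or the lora/embedding variant). The Python counterpart lives in scripts/main_ui.py, which this dump does not include; the sketch below shows the pattern it has to implement (function name and handler body are illustrative, not the extension's actual code):

```python
# Hypothetical sketch of the hidden-button bridge modelBrowser.js expects.
import gradio as gr

def wire_model_browser(tab: str, checkpoint_dropdown: gr.Dropdown):
    # Hidden button that the JS side clicks after setting its global variable.
    change_btn = gr.Button(visible=False, elem_id=f"{tab}_change_cloud_checkpoint")
    change_btn.click(
        fn=lambda name: gr.update(value=name),  # select the model in the dropdown
        inputs=[checkpoint_dropdown],           # placeholder; value is supplied by _js
        outputs=[checkpoint_dropdown],
        _js="() => desiredCloudInferenceCheckpointName",
    )
```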
/README.md:
--------------------------------------------------------------------------------
1 | # Stable Diffusion Web UI Cloud Inference
2 | 
3 | 
4 | [](https://discord.gg/nzqq8UScpx)
5 | 
6 | 
7 | [](https://www.youtube.com/watch?v=B8s2L_o3DrU)
8 | 
9 | 
10 | 
11 | ## What capabilities does this extension offer?
12 | 
13 | This extension enables faster image generation without the need for expensive GPUs, and it integrates seamlessly with the AUTOMATIC1111 UI.
14 | 
15 | ## Benefits:
16 | 1. **No expensive GPUs required** - it can even run on a CPU-only machine.
17 | 2. **No need to change your workflow** - compatible with the usual sd-webui usage and scripts, such as X/Y/Z Plot, Prompt from file, etc.
18 | 3. **Support for 10,000+ checkpoint models** - no downloads needed.
19 | 
20 | ## Compatibility and Limitations
21 | 
22 | | Feature | Support | Limitations |
23 | | -------------------------- | ------- | ----------- |
24 | | txt2img | ✅ | |
25 | | txt2img_hires.fix | ✅ | |
26 | | txt2img_sdxl_refiner | ✅ | |
27 | | txt2img_controlnet | ✅ | |
28 | | img2img | ✅ | |
29 | | img2img_inpaint | ✅ | |
30 | | img2img_sdxl_refiner | ✅ | |
31 | | img2img_controlnet | ✅ | |
32 | | extras upscale | ✅ | |
33 | | vae model | ✅ | |
34 | | scripts - X/Y/Z plot | ✅ | |
35 | | scripts - Prompt matrix | ✅ | |
36 | | scripts - Prompt from file | ✅ | |
37 | 
38 | 
39 | 
40 | ## How it works
41 | 
42 | 
43 | ## Guide
44 | ## 1. Install sd-webui-cloud-inference
45 | 
46 | 
47 | 
48 | 
49 | 
50 | 
51 | ## 2. Get your [omniinfer.io](https://www.omniinfer.io/user/login?utm_source=github_wiki) Key
52 | 
53 | Open [omniinfer.io](https://www.omniinfer.io/user/login?utm_source=github_wiki) in your browser
54 | 
55 | You can choose either "Google Login" or "GitHub Login"
56 | 
57 | 
58 | 
59 | 
60 | 
61 | 
62 | ## 3. Enable Cloud Inference feature
63 | 
64 | Go back to the `Cloud Inference` tab of stable-diffusion-webui
65 | 
66 | 
67 | 
68 | ## 4. Test txt2img
69 | 
70 | Go back to the `txt2img` tab of stable-diffusion-webui
71 | 
72 | 
73 | 
74 | You can now give it a try and enjoy your creative journey.
75 | 
76 | You are also welcome to discuss your user experience, share suggestions, and provide feedback on our Discord channel.
77 | [](https://discord.gg/nzqq8UScpx)
78 | 
79 | 
80 | ## 5. Advanced - Lora
81 | 
82 | 
83 | 
84 | 
85 | ## 6. Advanced - Img2img Inpainting
86 | 
87 | 
88 | 
89 | ## 7. Advanced - VAE
90 | 
91 | 
92 | 
93 | or you can use the VAE feature with the X/Y/Z Plot script
94 | 
95 | 
96 | 
97 | 
98 | ## 8. Advanced - ControlNet
99 | 
100 | 
101 | 
102 | 
103 | 
104 | ## 9. Advanced - Upscale and Hires.Fix
105 | 
106 | 
107 | 
108 | 
109 | 
110 | ## 10. Advanced - Model Browser
111 | 
112 | 
113 | 
114 | 
115 | ## 11. Advanced - Tiny Model
116 | 
117 | The AUTOMATIC1111 webui loads the model on startup. However, on low-memory computers like the MacBook Air, the performance is suboptimal. To address this, we have developed a stripped-down, minimal-size model. You can use the following commands to enable it.
118 | 
119 | This reduces memory usage from about 4.8 GB to 739 MB.
120 | 
121 | 1. Download the tiny model and its config into the models directory:
122 | ```
123 | wget -O ./models/Stable-diffusion/tiny.yaml https://github.com/omniinfer/sd-webui-cloud-inference/releases/download/tiny-model/tiny.yaml
124 | wget -O ./models/Stable-diffusion/tiny.safetensors https://github.com/omniinfer/sd-webui-cloud-inference/releases/download/tiny-model/tiny.safetensors
125 | ```
126 | 2. Start the webui with the tiny model:
127 | `--ckpt=/stable-diffusion-webui/models/Stable-diffusion/tiny.safetensors`
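If both downloads succeeded, the files sit next to your other checkpoints. A quick sanity check before pointing `--ckpt` at the file (illustrative; assumes the wget commands above were run from the webui root):

```python
# Verify the tiny model files exist and report their sizes.
import os

for name in ("tiny.yaml", "tiny.safetensors"):
    path = os.path.join("models", "Stable-diffusion", name)
    if os.path.exists(path):
        print(f"{path}: OK ({os.path.getsize(path) / 1e6:.1f} MB)")
    else:
        print(f"{path}: MISSING - re-run the matching wget command above")
```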
128 | 
129 | ## 12. Advanced - SDXL Refiner
130 | 
131 | 
132 | 
--------------------------------------------------------------------------------
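The extension never forks the webui. Instead, scripts/hijack.py (next) reroutes generation to the cloud backend by monkey-patching webui functions at runtime and keeping the originals around for local fallback. A minimal sketch of that core idea (names here are illustrative, not the extension's API):

```python
# Swap a module-level function for a wrapper while remembering the original,
# so requests can still be delegated locally when cloud inference is disabled.
import importlib

_originals = {}

def hijack(module_name: str, func_name: str, new_fn):
    module = importlib.import_module(module_name)
    _originals[f"{module_name}.{func_name}"] = getattr(module, func_name)
    setattr(module, func_name, new_fn)

def original(module_name: str, func_name: str):
    # e.g. original('modules.processing', 'process_images')(p)
    return _originals[f"{module_name}.{func_name}"]
```

The real `_hijack_func` below additionally has to patch call sites that imported the function by value (`from modules.processing import process_images`), which is why it walks `sys.modules` and every loaded script module, reloading or re-binding them.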
/scripts/hijack.py:
--------------------------------------------------------------------------------
1 | import modules.scripts as scripts
2 | import os
3 | import sys
4 | import importlib
5 | 
6 | from modules import images, script_callbacks, errors, processing, ui, shared, scripts_postprocessing, ui_common
7 | try:
8 |     # newer webui builds renamed this module to infotext_utils
9 |     from modules import infotext_utils as generation_parameters_copypaste
10 | except ImportError:
11 |     from modules import generation_parameters_copypaste
12 | 
13 | from modules.processing import Processed, StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, StableDiffusionProcessing
14 | from modules.shared import opts, state, prompt_styles
15 | from extension import api
16 | 
17 | from inspect import getmembers, isfunction, ismodule
18 | import random
19 | 
20 | from PIL import Image
21 | 
22 | 
23 | class _HijackManager(object):
24 | 
25 |     def __init__(self):
26 |         self._hijacked_onload = False
27 |         self._hijacked_on_app_started = False
28 |         self._binding = None
29 | 
30 |         self.hijack_map = {}
31 | 
32 |     def hijack_one(self, name, new_fn):
33 |         tmp = name.rsplit('.', 1)
34 |         if len(tmp) < 2:
35 |             raise Exception('invalid module.func name: {}'.format(name))
36 | 
37 |         module_name, func_name = tmp
38 |         old_fn = _hijack_func(module_name, func_name, new_fn)
39 |         if old_fn is None:
40 |             print('[cloud-inference] hijack failed: {}'.format(name))
41 |             return False
42 | 
43 |         self.hijack_map[name] = {
44 |             'old': old_fn,
45 |             'new': new_fn,
46 |         }
47 | 
48 |         print('[cloud-inference] hijack {}, old: <{}>, new: <{}>'.format(
49 |             name, old_fn.__module__ + '.' + old_fn.__name__, new_fn.__module__ + '.' + new_fn.__name__))
50 | 
51 |     def hijack_onload(self):
52 |         if self._hijacked_onload:
53 |             return
54 |         self.hijack_one('modules.processing.process_images', self._hijack_process_images)
55 |         self.hijack_one('modules.postprocessing.run_postprocessing', self._hijack_run_postprocessing)
56 |         self._apply_xyz()
57 |         self._hijacked_onload = True
58 | 
59 |     def hijack_on_app_started(self, *args, **kwargs):
60 |         if self._hijacked_on_app_started:
61 |             return
62 | 
63 |         self.hijack_one('extensions.sd-webui-controlnet.scripts.global_state.update_cn_models', self._hijack_update_cn_models)
64 |         self._hijack_update_cn_models()  # update once
65 | 
66 |         self._hijacked_on_app_started = True
67 | 
68 |     def _apply_xyz(self):
69 |         def find_module(module_names):
70 |             if isinstance(module_names, str):
71 |                 module_names = [s.strip() for s in module_names.split(",")]
72 |             for data in scripts.scripts_data:
73 |                 if data.script_class.__module__ in module_names and hasattr(
74 |                         data, "module"):
75 |                     return data.module
76 |             return None
77 | 
78 |         xyz_grid = find_module("xyz_grid.py, xy_grid.py")
79 |         if xyz_grid:
80 |             def xyz_checkpoint_apply(p: StableDiffusionProcessing, opt, v):
81 |                 if '_cloud_inference_settings' not in p.__dict__:
82 |                     p._cloud_inference_settings = {}
83 | 
84 |                 m = self._binding.find_model_by_alias(opt)
85 |                 p._cloud_inference_settings['sd_checkpoint'] = m.name
86 | 
87 |             def xyz_checkpoint_confirm(p, opt):
88 |                 return
89 | 
90 |             def xyz_checkpoint_format(p, opt, v):
91 |                 return self._binding.find_model_by_alias(v).name.rsplit(".", 1)[0]
92 | 
93 |             def xyz_vae_apply(p: StableDiffusionProcessing, opt, v):
94 |                 if '_cloud_inference_settings' not in p.__dict__:
95 |                     p._cloud_inference_settings = {}
96 | 
97 |                 p._cloud_inference_settings['sd_vae'] = opt
98 | 
99 |             def xyz_vae_confirm(p, opt):
100 |                 return
101 | 
102 |             def xyz_vae_format(p, opt, v):
103 |                 return v
104 | 
105 |             print('[cloud-inference] hijack xyz_grid')
106 |             xyz_grid.axis_options.append(
107 |                 xyz_grid.AxisOption('[Cloud Inference] Checkpoint',
108 |                                     str,
109 |                                     apply=xyz_checkpoint_apply,
110 |                                     confirm=xyz_checkpoint_confirm,
111 |                                     format_value=xyz_checkpoint_format,
112 |                                     choices=lambda: [_.alias for _ in self._binding.remote_model_checkpoints]))
113 |             xyz_grid.axis_options.append(
114 |                 xyz_grid.AxisOption('[Cloud Inference] VAE',
115 |                                     str,
116 |                                     apply=xyz_vae_apply,
117 |                                     confirm=xyz_vae_confirm,
118 |                                     format_value=xyz_vae_format,
119 |                                     choices=lambda: ["Automatic", "None"] + [_.name for _ in self._binding.remote_model_vaes]))
120 |             self._xyz_hijacked = True
121 | 
122 |     def _hijack_update_cn_models(self):
123 |         from modules.scripts import scripts_data
124 |         for script in scripts_data:
125 |             if script.module.__name__ == 'controlnet.py':
126 |                 if self._binding.remote_inference_enabled:
127 |                     script.module.global_state.cn_models.clear()
128 |                     cn_models_keys = ["None"] + [_.name for _ in self._binding.remote_model_controlnet]
129 |                     cn_models_dict = {k: None for k in cn_models_keys}
130 | 
131 |                     script.module.global_state.cn_models.update(cn_models_dict)
132 |                     script.module.global_state.cn_models_names.clear()
133 |                     script.module.global_state.cn_models_names.update(cn_models_dict)
134 |                     break
135 |                 else:
136 |                     self.hijack_map['extensions.sd-webui-controlnet.scripts.global_state.update_cn_models']['old']()
137 | 
138 |     def _hijack_process_images(self, *args, **kwargs) -> Processed:
139 |         if len(args) > 0 and isinstance(args[0],
140 |                                         processing.StableDiffusionProcessing):
141 |             p = args[0]
142 |         else:
143 |             raise Exception(
144 |                 'process_images: first argument must be a processing object')
145 | 
146 |         remote_inference_enabled, selected_checkpoint_name, selected_vae_name = get_visible_extension_args(p, 'cloud inference')
147 | 
148 |         if not remote_inference_enabled:
149 |             return self.hijack_map['modules.processing.process_images']['old'](*args, **kwargs)
150 | 
151 |         # random seed locally if not specified
152 |         if p.seed == -1:
153 |             p.seed = int(random.randrange(4294967294))
154 | 
155 |         state.begin()
156 |         state.sampling_steps = p.steps
157 |         state.job_count = p.n_iter
158 | 
159 |         state.textinfo = "remote inferencing ({})".format(api.get_instance().__class__.__name__)
160 | 
161 |         if '_cloud_inference_settings' not in p.__dict__:
162 |             p._cloud_inference_settings = {}
163 | 
164 |         if 'sd_checkpoint' not in p._cloud_inference_settings:
165 |             p._cloud_inference_settings['sd_checkpoint'] = self._binding.find_model_by_alias(selected_checkpoint_name).name
166 |         if 'sd_vae' not in p._cloud_inference_settings:
167 |             p._cloud_inference_settings['sd_vae'] = selected_vae_name
168 | 
169 |         if isinstance(p, StableDiffusionProcessingTxt2Img):
170 |             generated_images = api.get_instance().txt2img(p)
171 |         elif isinstance(p, StableDiffusionProcessingImg2Img):
172 |             generated_images = api.get_instance().img2img(p)
173 |         else:
174 |             return self.hijack_map['modules.processing.process_images']['old'](*args, **kwargs)  # unsupported type: fall back to the original implementation
175 | 
176 |         # compatible with old version
177 |         if hasattr(p, 'setup_prompts'):
178 |             p.setup_prompts()
179 |         else:
180 |             if type(p.prompt) == list:
181 |                 p.all_prompts = p.prompt
182 |             else:
183 |                 p.all_prompts = p.batch_size * p.n_iter * [p.prompt]
184 |             if type(p.negative_prompt) == list:
185 |                 p.all_negative_prompts = p.negative_prompt
186 |             else:
187 |                 p.all_negative_prompts = p.batch_size * p.n_iter * [
188 |                     p.negative_prompt
189 |                 ]
190 |             p.all_prompts = [
191 |                 prompt_styles.apply_styles_to_prompt(x, p.styles)
192 |                 for x in p.all_prompts
193 |             ]
194 |             p.all_negative_prompts = [
195 |                 prompt_styles.apply_negative_styles_to_prompt(x, p.styles)
196 |                 for x in p.all_negative_prompts
197 |             ]
198 | 
199 |         # TODO: img2img hr prompts
200 | 
201 |         p.all_seeds = [p.seed for _ in range(len(generated_images))]
202 |         p.seeds = p.all_seeds
203 | 
204 |         index_of_first_image = 0
205 |         unwanted_grid_because_of_img_count = len(
206 |             generated_images) < 2 and opts.grid_only_if_multiple
207 | 
208 |         comments = {}
209 |         infotexts = []
210 | 
211 |         def infotext(iteration=0, position_in_batch=0):
212 |             return create_infotext(p, p.all_prompts, p.all_seeds,
213 |                                    p.all_subseeds, comments, iteration,
214 |                                    position_in_batch)
215 | 
216 |         for i, image in enumerate(generated_images):
217 |             if opts.enable_pnginfo:
218 |                 image.info["parameters"] = infotext()
219 |             infotexts.append(infotext())
220 | 
221 |             seed = None
222 |             if len(p.all_seeds) > i:
223 |                 seed = p.all_seeds[i]
224 |             prompt = None
225 |             if len(p.all_prompts) > i:
226 |                 prompt = p.all_prompts[i]
227 | 
228 |             if opts.samples_save and not p.do_not_save_samples:
229 |                 images.save_image(image,
230 |                                   p.outpath_samples,
231 |                                   "",
232 |                                   seed,
233 |                                   prompt,
234 |                                   opts.samples_format,
235 |                                   info=infotext(),
236 |                                   p=p)
237 |         if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
238 |             grid = images.image_grid(generated_images, p.batch_size)
239 | 
240 |             if opts.return_grid:
241 |                 text = infotext()
242 |                 infotexts.insert(0, text)
243 |                 if opts.enable_pnginfo:
244 |                     grid.info["parameters"] = text
245 | 
246 |             generated_images.insert(0, grid)
247 |             index_of_first_image = 1
248 | 
249 |             if opts.grid_save:
250 | 
images.save_image( 251 | grid, 252 | p.outpath_grids, 253 | "grid", 254 | p.all_seeds[0], 255 | p.all_prompts[0], 256 | opts.grid_format, 257 | info=infotext(), 258 | short_filename=not opts.grid_extended_filename, 259 | p=p, 260 | grid=True) 261 | p = Processed( 262 | p, 263 | generated_images, 264 | all_seeds=[p.seed for _ in range(len(generated_images))], 265 | all_prompts=[p.prompt for _ in range(len(generated_images))], 266 | comments="".join(f"{comment}\n" for comment in comments), 267 | index_of_first_image=index_of_first_image, 268 | infotexts=infotexts) 269 | state.end() 270 | return p 271 | 272 | def _hijack_run_postprocessing(self, *args, **kwargs): 273 | if not self._binding.remote_inference_enabled: 274 | return self.hijack_map['modules.postprocessing.run_postprocessing']['old'](*args, **kwargs) 275 | 276 | shared.state.begin() 277 | shared.state.textinfo = "remote inferencing ({})".format(api.get_instance().__class__.__name__) 278 | 279 | image_data = [] 280 | image_names = [] 281 | outputs = [] 282 | 283 | # extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True 284 | extras_mode = args[0] 285 | image = args[1] 286 | image_folder = args[2] 287 | input_dir = args[3] 288 | output_dir = args[4] 289 | show_extras_results = args[5] 290 | if len(args) > 6: 291 | args = args[6:] 292 | save_output = kwargs.get('save_output', True) 293 | 294 | if extras_mode == 1: 295 | for img in image_folder: 296 | if isinstance(img, Image.Image): 297 | image = img 298 | fn = '' 299 | else: 300 | image = Image.open(os.path.abspath(img.name)) 301 | fn = os.path.splitext(img.orig_name)[0] 302 | image_data.append(image) 303 | image_names.append(fn) 304 | elif extras_mode == 2: 305 | assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' 306 | assert input_dir, 'input directory not selected' 307 | 308 | image_list = shared.listfiles(input_dir) 309 | for filename in image_list: 310 | try: 311 | image = Image.open(filename) 312 | except Exception: 313 | continue 314 | image_data.append(image) 315 | image_names.append(filename) 316 | else: 317 | assert image, 'image not selected' 318 | 319 | image_data.append(image) 320 | image_names.append(None) 321 | 322 | if extras_mode == 2 and output_dir != '': 323 | outpath = output_dir 324 | else: 325 | outpath = opts.outdir_samples or opts.outdir_extras_samples 326 | 327 | infotext = '' 328 | 329 | for image, name in zip(image_data, image_names): 330 | shared.state.textinfo = name 331 | 332 | parameters, existing_pnginfo = images.read_info_from_image(image) 333 | if parameters: 334 | existing_pnginfo["parameters"] = parameters 335 | 336 | # api.get_instance().txt2img 337 | # scripts.scripts_postproc.run(pp, args) 338 | resize_mode, \ 339 | upscaling_resize, \ 340 | upscaling_resize_w, \ 341 | upscaling_resize_h, \ 342 | upscaling_crop, \ 343 | extras_upscaler_1, \ 344 | extras_upscaler_2, \ 345 | extras_upscaler_2_visibility, \ 346 | gfpgan_visibility, codeformer_visibility, \ 347 | codeformer_weight, *extra_args = args 348 | 349 | if extras_upscaler_1 and extras_upscaler_1 != 'None': 350 | extras_upscaler_1 = self._binding.find_name_by_alias(extras_upscaler_1) 351 | if extras_upscaler_2 and extras_upscaler_2 != 'None': 352 | extras_upscaler_2 = self._binding.find_name_by_alias(extras_upscaler_2) 353 | 354 | imgs = api.get_instance().upscale(image, 355 | resize_mode, 356 | upscaling_resize, 357 | upscaling_resize_w, 358 | upscaling_resize_h, 359 | upscaling_crop, 360 | 
extras_upscaler_1, 361 | extras_upscaler_2, 362 | extras_upscaler_2_visibility, 363 | gfpgan_visibility, 364 | codeformer_visibility, 365 | codeformer_weight, 366 | *extra_args) 367 | pp = scripts_postprocessing.PostprocessedImage(imgs[0].convert("RGB")) 368 | 369 | if opts.use_original_name_batch and name is not None: 370 | basename = os.path.splitext(os.path.basename(name))[0] 371 | else: 372 | basename = '' 373 | 374 | infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None]) 375 | 376 | if opts.enable_pnginfo: 377 | pp.image.info = existing_pnginfo 378 | pp.image.info["postprocessing"] = infotext 379 | 380 | if save_output: 381 | images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, 382 | short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) 383 | 384 | if extras_mode != 2 or show_extras_results: 385 | outputs.append(pp.image) 386 | 387 | return outputs, ui_common.plaintext_to_html(infotext), '' 388 | 389 | 390 | def create_infotext(p, 391 | all_prompts, 392 | all_seeds, 393 | all_subseeds, 394 | comments=None, 395 | iteration=0, 396 | position_in_batch=0): 397 | index = position_in_batch + iteration * p.batch_size 398 | 399 | clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers) 400 | enable_hr = getattr(p, 'enable_hr', False) 401 | 402 | # compatible with old version 403 | token_merging_ratio = None 404 | token_merging_ratio_hr = None 405 | if hasattr(p, 'get_token_merging_ratio'): 406 | token_merging_ratio = p.get_token_merging_ratio() 407 | token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True) 408 | 409 | uses_ensd = opts.eta_noise_seed_delta != 0 410 | if uses_ensd: 411 | uses_ensd = processing.sd_samplers_common.is_sampler_using_eta_noise_seed_delta( 412 | p) 413 | 414 | generation_params = { 415 | "Steps": p.steps, 416 | "Sampler": p.sampler_name, 417 | "CFG scale": p.cfg_scale, 418 | "Image CFG scale": getattr(p, 'image_cfg_scale', None), 419 | "Seed": all_seeds[index], 420 | "Face restoration": (opts.face_restoration_model if p.restore_faces else None), 421 | "Size": f"{p.width}x{p.height}", 422 | "Model": (None if not opts.add_model_name_to_info or not p._cloud_inference_settings['sd_checkpoint'] else p._cloud_inference_settings['sd_checkpoint'].replace(',', '').replace(':', '')), 423 | "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), 424 | "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), 425 | "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), 426 | "Denoising strength": getattr(p, 'denoising_strength', None), 427 | "Conditional mask weight": getattr(p, "inpainting_mask_weight", opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None, 428 | "Clip skip": None if clip_skip <= 1 else clip_skip, 429 | "ENSD": getattr(opts, 'eta_noise_seed_delta', None) if uses_ensd else None, 430 | "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio, 431 | "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr, 432 | "Init image hash": getattr(p, 'init_img_hash', None), 433 | "RNG": None, 434 | "NGMS": None, 435 | "Version": None, 436 | **p.extra_generation_params, 437 | } 438 | 439 | # 
compatible with old version
440 |     if getattr(p, 's_min_uncond', None) is not None:
441 |         if p.s_min_uncond != 0:
442 |             generation_params["NGMS"] = p.s_min_uncond
443 |     if getattr(opts, 'randn_source',
444 |                None) is not None and opts.randn_source != "GPU":
445 |         generation_params["RNG"] = opts.randn_source
446 | 
447 |     if getattr(opts, 'add_version_to_infotext', None):
448 |         if opts.add_version_to_infotext:
449 |             generation_params['Version'] = processing.program_version()
450 | 
451 |     generation_params_text = ", ".join([
452 |         k if k == v else
453 |         f'{k}: {generation_parameters_copypaste.quote(v)}'
454 |         for k, v in generation_params.items() if v is not None
455 |     ])
456 | 
457 |     negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[
458 |         index] else ""
459 | 
460 |     return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip(
461 |     )
462 | 
463 | 
464 | def get_visible_extension_args(p: processing.StableDiffusionProcessing, name):
465 |     for s in p.scripts.alwayson_scripts:
466 |         if s.name == name:
467 |             return p.script_args[s.args_from:s.args_to]
468 |     return []
469 | 
470 | 
471 | def _hijack_func(module_name, func_name, new_func):
472 |     old_func = None
473 |     extension_mode = False
474 |     extension_prefix = ""
475 |     if module_name.startswith('extensions.'):
476 |         extension_mode = True
477 | 
478 |     # from modules.processing import process_images
479 |     search_names = [module_name]
480 |     search_names.append(func_name)
481 |     tmp = module_name.split(".")
482 |     if len(tmp) >= 2:
483 |         search_names.append(".".join(tmp[-2:]))  # import modules.processing
484 |         search_names.append(tmp[-1])  # from modules import processing
485 |         # from modules import processing.process_images
486 |         search_names.append("{}.{}".format(tmp[-1], func_name))
487 | 
488 |     if not extension_mode:
489 |         # hijack for a normal module
490 | 
491 |         # case 1: import module, replace function, return old function
492 |         module = importlib.import_module(module_name)
493 |         old_func = getattr(module, func_name)
494 |         setattr(module, func_name, new_func)
495 | 
496 |         # case 2: from module import func_name
497 |         keys = list(sys.modules.keys())
498 |         for name in keys:
499 |             # if (name.startswith('modules')
500 |             #         or name.startswith('scripts')) and name != 'modules.processing':
501 |             if name.startswith('modules') and name != module_name:
502 |                 members = getmembers(sys.modules[name], isfunction)
503 |                 if func_name in dict(members):
504 |                     # func_fullname = '{}.{}'.format(members[func_name].__module__, members[func_name].__name__)
505 |                     # print(func_fullname, '{}.{}'.format(module_name, func_name))
506 |                     # if func_fullname == '{}.{}'.format(module_name, func_name):
507 |                     print('[cloud-inference] reloading', name)
508 |                     importlib.reload(sys.modules[name])
509 | 
510 |         from modules.scripts import scripts_data
511 |         for script in scripts_data:
512 |             for name in search_names:
513 |                 if name in script.module.__dict__:
514 |                     obj = script.module.__dict__[name]
515 |                     replace = False
516 |                     if ismodule(obj) and obj.__file__ == module.__file__:
517 |                         replace = True
518 |                     elif isfunction(obj) and obj.__module__ == module_name:  # ??
519 |                         replace = True
520 | 
521 |                     if replace:
522 |                         if name == func_name:
523 |                             print('[cloud-inference] reloading {} - {}'.format(script.module.__name__, func_name))
524 |                             setattr(script.module, name, new_func)
525 |                         else:
526 |                             print('[cloud-inference] reloading {} - {}'.format(script.module.__name__, name))
527 |                             t = getattr(script.module, name)
528 |                             setattr(t, func_name, new_func)
529 |                             # setattr(script.module, name, t)  # ?
530 |         return old_func
531 |     else:
532 |         # hijack for an extension module
533 | 
534 |         from modules.scripts import scripts_data
535 |         tmp1, tmp2, extension_suffix = module_name.split(
536 |             '.', 2)  # scripts internal module name
537 |         extension_prefix = "{}.{}".format(tmp1, tmp2)
538 |         module_name = "modules.{}".format(extension_suffix)
539 | 
540 |         for script in scripts_data:
541 |             if extension_mode and os.path.join(*extension_prefix.split('.')) not in script.basedir:
542 |                 continue
543 |             for name in search_names:
544 |                 if name in script.module.__dict__:
545 |                     obj = script.module.__dict__[name]
546 |                     replace = False
547 |                     if ismodule(obj) and module_name.endswith(obj.__name__):
548 |                         replace = True
549 |                     elif isfunction(obj) and obj.__module__ == module_name:  # ??
550 |                         replace = True
551 | 
552 |                     if replace:
553 |                         # import pkgutil
554 |                         # s = [_ for _ in pkgutil.iter_modules([os.path.dirname(script.module.__file__)])]
555 |                         if name == func_name:
556 |                             print(
557 |                                 '[cloud-inference] reloading {} - {}'.format(script.module.__name__, func_name))
558 |                             old_func = getattr(script.module, name)
559 |                             setattr(script.module, name, new_func)
560 |                         else:
561 |                             print(
562 |                                 '[cloud-inference] reloading {} - {}'.format(script.module.__name__, name))
563 |                             t = getattr(script.module, name)
564 |                             old_func = getattr(t, func_name)
565 |                             setattr(t, func_name, new_func)
566 |                             setattr(script.module, name, t)
567 |         # print(importlib.reload(importlib.import_module('extensions.sd-webui-controlnet.scripts.xyz_grid_support')))
568 |         return old_func
569 | 
570 | 
571 | _hijack_manager = _HijackManager()
572 | 
--------------------------------------------------------------------------------
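For reference, `create_infotext` above joins all non-None generation parameters into the standard webui infotext block that ends up in PNG metadata. A small reconstruction of that join with made-up values (the real code quotes values via `generation_parameters_copypaste.quote`):

```python
params = {"Steps": 20, "Sampler": "DPM++ 2M Karras", "CFG scale": 7.0,
          "Seed": 123456, "Size": "512x768", "Model": "example_checkpoint"}
text = ", ".join(f"{k}: {v}" for k, v in params.items() if v is not None)
print("a photo of a cat" + "\nNegative prompt: lowres\n" + text)
# a photo of a cat
# Negative prompt: lowres
# Steps: 20, Sampler: DPM++ 2M Karras, CFG scale: 7.0, Seed: 123456, Size: 512x768, Model: example_checkpoint
```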
/extension/api.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import io
3 | import base64
4 | from modules import processing
5 | from modules.shared import opts, state
6 | from PIL import Image, ImageFilter, ImageOps
7 | from multiprocessing.pool import ThreadPool
8 | import importlib
9 | 
10 | from omniinfer_client import *
11 | from dataclass_wizard import JSONWizard, DumpMeta
12 | from dataclasses import dataclass, field
13 | 
14 | from typing import Dict
15 | 
16 | import numpy as np
17 | 
18 | import os
19 | import copy
20 | import json
21 | 
22 | 
23 | from .utils import image_to_base64, read_image_files
24 | from .version import __version__
25 | 
26 | OMNIINFER_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)),
27 |                                 '.omniinfer.json')
28 | 
29 | OMNIINFER_API_ENDPOINT = "https://api.omniinfer.io"
30 | 
31 | 
32 | def _user_agent(model_name=None):
33 |     if model_name:
34 |         return 'sd-webui-cloud-inference/{} (model_name: {})'.format(__version__, model_name)
35 |     return 'sd-webui-cloud-inference/{}'.format(__version__)
36 | 
37 | 
38 | class BaseAPI(object):
39 |     def txt2img(self, p: processing.StableDiffusionProcessingTxt2Img):
40 |         pass
41 | 
42 |     def img2img(self, p: processing.StableDiffusionProcessingImg2Img):
43 |         pass
44 | 
45 |     def list_models(self):
46 |         pass
47 | 
48 |     def refresh_models(self):
49 |         pass
50 | 
51 | 
52 | class UpscaleAPI(object):
53 
| def upscale(self, *args, **kwargs): 54 | pass 55 | 56 | 57 | class JSONe(JSONWizard): 58 | def __init_subclass__(cls, **kwargs): 59 | super().__init_subclass__(**kwargs) 60 | DumpMeta(key_transform='SNAKE').bind_to(cls) 61 | 62 | 63 | @dataclass 64 | class StableDiffusionModelExample(JSONe): 65 | prompts: Optional[str] = None 66 | neg_prompt: Optional[str] = None 67 | sampler_name: Optional[str] = None 68 | steps: Optional[int] = None 69 | seed: Optional[int] = None 70 | height: Optional[int] = None 71 | width: Optional[int] = None 72 | preview: Optional[str] = None 73 | cfg_scale: Optional[float] = None 74 | 75 | 76 | @dataclass 77 | class StableDiffusionModel(JSONe): 78 | kind: str 79 | name: str 80 | rating: int = 0 81 | tags: List[str] = None 82 | child: Optional[List[str]] = field(default_factory=lambda: []) 83 | examples: Optional[List[StableDiffusionModelExample]] = field(default_factory=lambda: []) 84 | user_tags: Optional[List[str]] = field(default_factory=lambda: []) 85 | preview_url: Optional[str] = None 86 | search_terms: Optional[List[str]] = field(default_factory=lambda: []) 87 | origin_url: Optional[str] = None 88 | 89 | @property 90 | def alias(self): 91 | # format -> [] [] 92 | if self.kind in ["upscaler", "controlnet"]: 93 | return self.name 94 | # return "cloud://{}".format(self.name) 95 | 96 | n = "" 97 | if len(self.tags) > 0: 98 | n = "[{}] ".format(",".join(self.tags)) 99 | return n + os.path.splitext(self.name)[0] 100 | 101 | def add_user_tag(self, tag): 102 | if tag not in self.user_tags: 103 | self.user_tags.append(tag) 104 | 105 | 106 | # class StableDiffusionModelExample(object): 107 | 108 | # def __init__(self, 109 | # prompts=None, 110 | # neg_prompt=None, 111 | # sampler_name=None, 112 | # steps=None, 113 | # cfg_scale=None, 114 | # seed=None, 115 | # height=None, 116 | # width=None, 117 | # preview=None, 118 | # ): 119 | # self.prompts = prompts 120 | # self.neg_prompt = neg_prompt 121 | # self.sampler_name = sampler_name 122 | # self.steps = steps 123 | # self.cfg_scale = cfg_scale 124 | # self.seed = seed 125 | # self.height = height 126 | # self.width = width 127 | # self.preview = preview 128 | 129 | 130 | class OmniinferAPI(BaseAPI, UpscaleAPI): 131 | 132 | def __init__(self, api_key=None): 133 | self._api_key = api_key 134 | self._client: OmniClient = None 135 | 136 | if self._api_key is not None: 137 | self.update_client() 138 | 139 | self._models: List[StableDiffusionModel] = [] 140 | 141 | def update_client(self): 142 | self._client = OmniClient(self._api_key) 143 | self._client.set_extra_headers({'User-Agent': _user_agent()}) 144 | 145 | @classmethod 146 | def load_from_config(cls): 147 | config = {} 148 | try: 149 | with open(OMNIINFER_CONFIG, 'r') as f: 150 | config = json.load(f) 151 | except Exception as exp: 152 | pass 153 | 154 | o = OmniinferAPI() 155 | if config.get('key') is not None: 156 | o._api_key = config['key'] 157 | o.update_client() 158 | else: 159 | # if no key, we will set it to NONE 160 | o._api_key = 'NONE' 161 | o.update_client() 162 | 163 | if config.get('models') is not None: 164 | try: 165 | o._models = [StableDiffusionModel.from_dict(m) for m in config['models']] 166 | except Exception as exp: 167 | print('[cloud-inference] failed to load models from config file {}, we will create a new one'.format(exp)) 168 | o._models = [] 169 | 170 | return o 171 | 172 | @classmethod 173 | def update_key_to_config(cls, key): 174 | config = {} 175 | if os.path.exists(OMNIINFER_CONFIG): 176 | with open(OMNIINFER_CONFIG, 'r') as f: 177 | 
try:
178 |                 config = json.load(f)
179 |             except:
180 |                 print(
181 |                     '[cloud-inference] failed to load config file, we will create a new one'
182 |                 )
183 |                 pass
184 | 
185 |         config['key'] = key
186 |         with open(OMNIINFER_CONFIG, 'wb+') as f:
187 |             f.write(
188 |                 json.dumps(config, ensure_ascii=False, indent=2,
189 |                            default=vars).encode('utf-8'))
190 | 
191 |     @classmethod
192 |     def update_models_to_config(cls, models):
193 |         config = {}
194 |         if os.path.exists(OMNIINFER_CONFIG):
195 |             with open(OMNIINFER_CONFIG, 'r') as f:
196 |                 try:
197 |                     config = json.load(f)
198 |                 except:
199 |                     print(
200 |                         '[cloud-inference] failed to load config file, we will create a new one'
201 |                     )
202 |                     pass
203 | 
204 |         config['models'] = models
205 |         with open(OMNIINFER_CONFIG, 'wb+') as f:
206 |             f.write(
207 |                 json.dumps(config, ensure_ascii=False, indent=2,
208 |                            default=vars).encode('utf-8'))
209 | 
210 |     @classmethod
211 |     def test_connection(cls, api_key: str):
212 |         client = OmniClient(api_key)
213 |         try:
214 |             res = client.progress("sd-webui-test")
215 |         except Exception as e:
216 |             raise Exception("Failed to connect to Omniinfer API: {}".format(e))
217 |         if res.code == ProgressResponseCode.INVALID_AUTH:
218 |             raise Exception("Invalid API key")
219 |         return "✅ Omniinfer ready... you can now run inference in the cloud"
220 | 
221 |     def _update_state(self, progress: ProgressResponse):
222 |         # queue(0-10), generating(10-90), downloading(90-100)
223 |         if state.skipped or state.interrupted:
224 |             raise Exception("Interrupted")
225 | 
226 |         progress_data = progress.data
227 |         global_progress = 0.1  # default, so the checks below always have a value
228 |         if progress_data.status == ProgressResponseStatusCode.RUNNING:
229 |             global_progress = (0.7 * progress_data.progress)
230 |             if global_progress < 0.1:
231 |                 global_progress = 0.1
232 | 
233 |         if global_progress >= 0.9:
234 |             global_progress = 0.9  # reserve the last stretch for downloading results
235 |         if progress_data.status == ProgressResponseStatusCode.INITIALIZING:
236 |             global_progress = 0.1
237 |         elif progress_data.status == ProgressResponseStatusCode.SUCCESSFUL:
238 |             global_progress = 0.9
239 |         elif progress_data.status == ProgressResponseStatusCode.TIMEOUT:
240 |             raise Exception("failed to generate image: timeout")
241 |         elif progress_data.status == ProgressResponseStatusCode.FAILED:
242 |             raise Exception("failed to generate image({}): {}".format(progress_data.status, progress_data.failed_reason))
243 | 
244 |         state.sampling_step = int(state.sampling_steps * state.job_count * global_progress)
245 | 
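# [Illustrative] Worked example of the progress mapping in _update_state above:
# a RUNNING job reporting progress=0.5 yields global_progress = 0.7 * 0.5 = 0.35,
# clamped into [0.1, 0.9] (the last stretch is reserved for downloading results).
# With sampling_steps=30 and job_count=1 the local UI then shows
# sampling_step = int(30 * 1 * 0.35) = 10.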
246 |     def img2img(
247 |         self,
248 |         p: processing.StableDiffusionProcessingImg2Img,
249 |     ):
250 |         controlnet_batchs = get_controlnet_arg(p)
251 | 
252 |         live_previews_image_format = "png"
253 |         if getattr(opts, 'live_previews_image_format', None):
254 |             live_previews_image_format = opts.live_previews_image_format
255 | 
256 |         images_base64 = []
257 |         for i in p.init_images:
258 |             if live_previews_image_format == "png":
259 |                 # using optimize for large images takes an enormous amount of time
260 |                 if max(*i.size) <= 256:
261 |                     save_kwargs = {"optimize": True}
262 |                 else:
263 |                     save_kwargs = {"optimize": False, "compress_level": 1}
264 | 
265 |             else:
266 |                 save_kwargs = {}
267 | 
268 |             with io.BytesIO() as buffered:
269 |                 i.save(buffered, format=live_previews_image_format, **save_kwargs)
270 |                 base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
271 |                 images_base64.append(base64_image)
272 | 
273 |         def _req(p: processing.StableDiffusionProcessingImg2Img, controlnet_units):
274 |             req = Img2ImgRequest(
275 |                 model_name=p._cloud_inference_settings['sd_checkpoint'],
276 |                 sampler_name=p.sampler_name,
277 |                 init_images=images_base64,
278 |                 mask=image_to_base64(p.image_mask) if p.image_mask else None,
279 |                 resize_mode=p.resize_mode,
280 |                 denoising_strength=p.denoising_strength,
281 |                 cfg_scale=p.cfg_scale,
282 |                 mask_blur=p.mask_blur_x,
283 |                 inpainting_fill=p.inpainting_fill,
284 |                 inpaint_full_res=bool2int(p.inpaint_full_res),
285 |                 inpaint_full_res_padding=p.inpaint_full_res_padding,
286 |                 initial_noise_multiplier=p.initial_noise_multiplier,
287 |                 inpainting_mask_invert=bool2int(p.inpainting_mask_invert),
288 |                 prompt=p.prompt,
289 |                 seed=int(p.seed) or -1,
290 |                 negative_prompt=p.negative_prompt,
291 |                 batch_size=p.batch_size,
292 |                 n_iter=p.n_iter,
293 |                 width=p.width,
294 |                 height=p.height,
295 |                 restore_faces=p.restore_faces,
296 |                 clip_skip=opts.CLIP_stop_at_last_layers,
297 |             )
298 |             if 'CLIP_stop_at_last_layers' in p.override_settings:
299 |                 req.clip_skip = p.override_settings['CLIP_stop_at_last_layers']
300 | 
301 |             if 'sd_vae' in p._cloud_inference_settings:
302 |                 req.sd_vae = p._cloud_inference_settings['sd_vae']
303 | 
304 | 
305 |             if hasattr(p, 'refiner_checkpoint') and p.refiner_checkpoint is not None and p.refiner_checkpoint != "None":
306 |                 req.sd_refiner = Refiner(
307 |                     checkpoint=p.refiner_checkpoint,
308 |                     switch_at=p.refiner_switch_at,
309 |                 )
310 | 
311 |             if len(controlnet_units) > 0:
312 |                 req.controlnet_units = controlnet_units
313 |                 if opts.data.get("control_net_no_detectmap", False):
314 |                     req.controlnet_no_detectmap = True
315 | 
316 |             res = self._client.sync_img2img(req, download_images=False, callback=self._update_state)
317 |             return res.data.imgs
318 | 
319 | 
320 | 
321 |         imgs = []
322 |         if len(controlnet_batchs) > 0:
323 |             for c in controlnet_batchs:
324 |                 imgs.extend(_req(p, c))
325 |         else:
326 |             imgs.extend(_req(p, []))
327 |         return retrieve_images(imgs)
328 | 
329 |     def txt2img(self, p: processing.StableDiffusionProcessingTxt2Img):
330 |         controlnet_batchs = get_controlnet_arg(p)
331 | 
332 |         def _req(p: processing.StableDiffusionProcessingTxt2Img, controlnet_units):
333 |             req = Txt2ImgRequest(
334 |                 model_name=p._cloud_inference_settings['sd_checkpoint'],
335 |                 sampler_name=p.sampler_name,
336 |                 prompt=p.prompt,
337 |                 negative_prompt=p.negative_prompt,
338 |                 batch_size=p.batch_size,
339 |                 n_iter=p.n_iter,
340 |                 steps=p.steps,
341 |                 cfg_scale=p.cfg_scale,
342 |                 seed=int(p.seed) or -1,
343 |                 height=p.height,
344 |                 width=p.width,
345 |                 restore_faces=p.restore_faces,
346 |                 clip_skip=opts.CLIP_stop_at_last_layers,
347 |             )
348 | 
349 |             if p.enable_hr:
350 |                 req.enable_hr = True
351 |                 req.hr_upscaler = p.hr_upscaler
352 |                 req.hr_scale = p.hr_scale
353 |                 req.hr_resize_x = p.hr_resize_x
354 |                 req.hr_resize_y = p.hr_resize_y
355 | 
356 |             if 'CLIP_stop_at_last_layers' in p.override_settings:
357 |                 req.clip_skip = p.override_settings['CLIP_stop_at_last_layers']
358 |             if 'sd_vae' in p._cloud_inference_settings:
359 |                 req.sd_vae = p._cloud_inference_settings['sd_vae']
360 | 
361 |             if len(controlnet_units) > 0:
362 |                 req.controlnet_units = controlnet_units
363 |                 if opts.data.get("control_net_no_detectmap", False):
364 |                     req.controlnet_no_detectmap = True
365 | 
366 |             if hasattr(p, 'refiner_checkpoint') and p.refiner_checkpoint is not None and p.refiner_checkpoint != "None":
367 |                 req.sd_refiner = Refiner(
368 |                     checkpoint=p.refiner_checkpoint,
369 |                     switch_at=p.refiner_switch_at,
370 |                 )
371 | 
372 |             res = self._client.sync_txt2img(req, download_images=False, callback=self._update_state)
373 |             if 
res.data.status != ProgressResponseStatusCode.SUCCESSFUL: 374 | raise Exception(res.data.failed_reason) 375 | 376 | return res.data.imgs 377 | 378 | imgs = [] 379 | if len(controlnet_batchs) > 0: 380 | for c in controlnet_batchs: 381 | imgs.extend(_req(p, c)) 382 | else: 383 | imgs.extend(_req(p, [])) 384 | 385 | state.textinfo = "downloading images..." 386 | 387 | return retrieve_images(imgs) 388 | 389 | def upscale(self, image, 390 | resize_mode: int, 391 | upscaling_resize: float, 392 | upscaling_resize_w: int, 393 | upscaling_resize_h: int, 394 | upscaling_crop: bool, 395 | extras_upscaler_1: str, 396 | extras_upscaler_2: str, 397 | extras_upscaler_2_visibility: float, 398 | gfpgan_visibility: float, 399 | codeformer_visibility: float, 400 | codeformer_weight: float, 401 | *args, 402 | **kwargs 403 | ): 404 | req = UpscaleRequest( 405 | image=image_to_base64(image), 406 | upscaler_1=extras_upscaler_1, 407 | resize_mode=resize_mode, 408 | upscaling_resize=upscaling_resize, 409 | upscaling_resize_w=upscaling_resize_w, 410 | upscaling_resize_h=upscaling_resize_h, 411 | upscaling_crop=upscaling_crop, 412 | upscaler_2=extras_upscaler_2, 413 | extras_upscaler_2_visibility=extras_upscaler_2_visibility, 414 | gfpgan_visibility=gfpgan_visibility, 415 | codeformer_visibility=codeformer_visibility, 416 | codeformer_weight=codeformer_weight 417 | ) 418 | 419 | res = self._client.sync_upscale(req, download_images=False, callback=self._update_state) 420 | if res.data.status != ProgressResponseStatusCode.SUCCESSFUL: 421 | raise Exception(res.data.failed_reason) 422 | return retrieve_images(res.data.imgs) 423 | 424 | def list_models(self): 425 | if self._models is None or len(self._models) == 0: 426 | self._models = self.refresh_models() 427 | return sorted(self._models, key=lambda x: x.rating, reverse=True) 428 | 429 | def refresh_models(self): 430 | 431 | def get_models(type_): 432 | ret = [] 433 | models = self._client.models(refresh=True).filter_by_type(type_) 434 | for item in models: 435 | model = StableDiffusionModel(kind=item.type.value, 436 | name=item.sd_name) 437 | model.search_terms = [ 438 | item.sd_name, 439 | item.name, 440 | str(item.civitai_version_id) 441 | ] 442 | model.rating = item.civitai_download_count 443 | civitai_tags = item.civitai_tags.split(",") if item.civitai_tags is not None else [] 444 | 445 | if model.tags is None: 446 | model.tags = [] 447 | 448 | if len(civitai_tags) > 0: 449 | model.tags.append(civitai_tags[0]) 450 | 451 | if item.civitai_nsfw or item.civitai_image_nsfw: 452 | model.tags.append("nsfw") 453 | 454 | if item.civitai_image_url: 455 | model.preview_url = item.civitai_image_url 456 | 457 | model.examples = [] 458 | if item.civitai_images: 459 | for img in item.civitai_images: 460 | if img.meta.prompt: 461 | model.examples.append(StableDiffusionModelExample( 462 | prompts=img.meta.prompt, 463 | neg_prompt=img.meta.negative_prompt, 464 | width=img.meta.width, 465 | height=img.meta.height, 466 | sampler_name=img.meta.sampler_name, 467 | cfg_scale=img.meta.cfg_scale, 468 | )) 469 | 470 | ret.append(model) 471 | return ret 472 | 473 | sd_models = [] 474 | print("[cloud-inference] refreshing models...") 475 | 476 | sd_models.extend(get_models(ModelType.CHECKPOINT)) 477 | sd_models.extend(get_models(ModelType.LORA)) 478 | sd_models.extend(get_models(ModelType.CONTROLNET)) 479 | sd_models.extend(get_models(ModelType.VAE)) 480 | sd_models.extend(get_models(ModelType.UPSCALER)) 481 | sd_models.extend(get_models(ModelType.TEXT_INVERSION)) 482 | 483 | # build lora 
and checkpoint relationship
484 | 
485 |         merged_models = {}
486 |         origin_models = {}
487 |         for model in (self._models or []):  # _models may still be None on the first refresh
488 |             origin_models[model.name] = model
489 |         for model in sd_models:
490 |             if model.name in origin_models:
491 |                 # save user tags
492 |                 merged_models[model.name] = model
493 |                 merged_models[model.name].user_tags = origin_models[model.name].user_tags
494 |             else:
495 |                 merged_models[model.name] = model
496 | 
497 |         self._models = [v for k, v in merged_models.items()]
498 |         self.update_models_to_config(self._models)
499 |         return self._models
500 | 
501 | 
502 | _instance = None
503 | 
504 | 
505 | def get_instance():
506 |     global _instance
507 |     if _instance is not None:
508 |         return _instance
509 |     _instance = OmniinferAPI.load_from_config()
510 |     return _instance
511 | 
512 | 
513 | def refresh_instance():
514 |     global _instance
515 |     _instance = OmniinferAPI.load_from_config()
516 |     return _instance
517 | 
518 | 
519 | def get_visible_extension_args(p: processing.StableDiffusionProcessing, name):
520 |     for s in p.scripts.alwayson_scripts:
521 |         if s.name == name:
522 |             return p.script_args[s.args_from:s.args_to]
523 |     return []
524 | 
525 | 
526 | def get_controlnet_arg(p: processing.StableDiffusionProcessing):
527 |     controlnet_batchs = []
528 | 
529 |     # controlnet_units = get_visible_extension_args(p, 'controlnet')
530 |     try:
531 |         global external_code; external_code = importlib.import_module('extensions.sd-webui-controlnet.scripts.external_code', 'external_code')  # module-level name is also used by image_dict_from_any below
532 |     except ModuleNotFoundError:
533 |         return []
534 | 
535 |     controlnet_units = external_code.get_all_units_in_processing(p)
536 | 
537 |     for c in controlnet_units:
538 |         if not c.enabled:
539 |             continue
540 | 
541 |         controlnet_arg = {}
542 |         controlnet_arg['weight'] = c.weight
543 |         controlnet_arg['model'] = c.model
544 |         controlnet_arg['module'] = c.module
545 |         if c.resize_mode == "Just Resize":
546 |             controlnet_arg['resize_mode'] = 0
547 |         elif c.resize_mode == "Resize and Crop":
548 |             controlnet_arg['resize_mode'] = 1
549 |         elif c.resize_mode == "Envelope (Outer Fit)":
550 |             controlnet_arg['resize_mode'] = 2  # was mistakenly keyed 'resize_code', which nothing reads
551 | 
552 |         if 'pixel_perfect' in c.__dict__:
553 |             if c.pixel_perfect:
554 |                 controlnet_arg['pixel_perfect'] = True
555 | 
556 | 
557 |         if 'processor_res' in c.__dict__:
558 |             if c.processor_res > 0:
559 |                 controlnet_arg['processor_res'] = c.processor_res
560 | 
561 |         if 'threshold_a' in c.__dict__:
562 |             controlnet_arg['threshold_a'] = int(c.threshold_a)
563 |         if 'threshold_b' in c.__dict__:
564 |             controlnet_arg['threshold_b'] = int(c.threshold_b)
565 |         if 'guidance_start' in c.__dict__:
566 |             controlnet_arg['guidance_start'] = c.guidance_start
567 |         if 'guidance_end' in c.__dict__:
568 |             controlnet_arg['guidance_end'] = c.guidance_end
569 | 
570 |         if c.control_mode == "Balanced":
571 |             controlnet_arg['control_mode'] = 0
572 |         elif c.control_mode == "My prompt is more important":
573 |             controlnet_arg['control_mode'] = 1
574 |         elif c.control_mode == "ControlNet is more important":
575 |             controlnet_arg['control_mode'] = 2
576 |         else:
577 |             return []  # unknown control mode; callers expect a list, not None
578 | 
579 |         img2img = isinstance(p, processing.StableDiffusionProcessingImg2Img)
580 |         if img2img and not c.image:
581 |             c.image = {}
582 |             init_image = getattr(p, "init_images", [None])[0]
583 |             if init_image is not None:
584 |                 c.image['image'] = np.asarray(init_image)
585 | 
586 |             a1111_i2i_resize_mode = getattr(p, "resize_mode", None)
587 |             # TODO: mask
588 | 
589 |             if a1111_i2i_resize_mode is not None:
590 |                 controlnet_arg['resize_mode'] = a1111_i2i_resize_mode
591 | 
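        # `controlnet_batchs` is a list of request batches; each batch is a list
        # of per-unit argument dicts. A "simple" unit contributes its args to
        # batch 0 only, while a "batch" unit fans a folder of images out into
        # one batch per image, each batch becoming one cloud request, e.g.:
        #   controlnet_batchs = [[batch_unit_args(img_1), simple_unit_args],
        #                        [batch_unit_args(img_2)], ...]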
592 |         if getattr(c.input_mode, 'value', '') == "simple":
593 |             if c.image is not None:
594 |                 c.image = image_dict_from_any(c.image)
595 |                 if c.image.get("mask") is not None:
596 |                     mask = Image.fromarray(c.image["mask"])
597 |                     controlnet_arg['mask'] = image_to_base64(mask)
598 | 
599 |                 controlnet_arg['input_image'] = image_to_base64(Image.fromarray(c.image["image"]))
600 | 
601 |             if len(controlnet_batchs) == 0:
602 |                 controlnet_batchs = [[]]
603 | 
604 |             controlnet_batchs[0].append(controlnet_arg)
605 | 
606 |         elif getattr(c.input_mode, 'value', '') == "batch":
607 |             if c.batch_images:
608 |                 images = read_image_files(c.batch_images)
609 |                 for i, img in enumerate(images):
610 |                     if len(controlnet_batchs) <= i:
611 |                         controlnet_batchs.append([])
612 | 
613 |                     controlnet_new_arg = copy.deepcopy(controlnet_arg)
614 |                     controlnet_new_arg['input_image'] = img
615 | 
616 |                     controlnet_batchs[i].append(controlnet_new_arg)
617 |             else:
618 |                 print("[cloud-inference] controlnet batch_images is empty")
619 | 
620 |         else:
621 |             print("[cloud-inference] unsupported controlnet input_mode")
622 | 
623 |     return controlnet_batchs
624 | 
625 | 
626 | def image_has_mask(input_image: np.ndarray) -> bool:
627 |     return (
628 |         input_image.ndim == 3 and
629 |         input_image.shape[2] == 4 and
630 |         np.max(input_image[:, :, 3]) > 127
631 |     )
632 | 
633 | 
634 | def prepare_mask(
635 |     mask: Image.Image, p: processing.StableDiffusionProcessing
636 | ) -> Image.Image:
637 |     mask = mask.convert("L")
638 |     if getattr(p, "inpainting_mask_invert", False):
639 |         mask = ImageOps.invert(mask)
640 |     if getattr(p, "mask_blur", 0) > 0:
641 |         mask = mask.filter(ImageFilter.GaussianBlur(p.mask_blur))
642 |     return mask
643 | 
644 | 
645 | def image_dict_from_any(image) -> Optional[Dict[str, np.ndarray]]:
646 |     if image is None:
647 |         return None
648 | 
649 |     if isinstance(image, (tuple, list)):
650 |         image = {'image': image[0], 'mask': image[1]}
651 |     elif not isinstance(image, dict):
652 |         image = {'image': image, 'mask': None}
653 |     else:  # type(image) is dict
654 |         # copy to enable modifying the dict and prevent response serialization error
655 |         image = dict(image)
656 | 
657 |     if isinstance(image['image'], str):
658 |         if os.path.exists(image['image']):
659 |             image['image'] = np.array(Image.open(image['image'])).astype('uint8')
660 |         elif image['image']:
661 |             image['image'] = external_code.to_base64_nparray(image['image'])
662 |         else:
663 |             image['image'] = None
664 | 
665 |     # If there is no image, return image with None image and None mask
666 |     if image['image'] is None:
667 |         image['mask'] = None
668 |         return image
669 | 
670 |     if 'mask' not in image:
671 |         image['mask'] = None
672 | 
673 |     if isinstance(image['mask'], str):
674 |         if os.path.exists(image['mask']):
675 |             image['mask'] = np.array(Image.open(image['mask'])).astype('uint8')
676 |         elif image['mask']:
677 |             image['mask'] = external_code.to_base64_nparray(image['mask'])
678 |         else:
679 |             image['mask'] = np.zeros_like(image['image'], dtype=np.uint8)
680 |     elif image['mask'] is None:
681 |         image['mask'] = np.zeros_like(image['image'], dtype=np.uint8)
682 | 
683 |     return image
684 | 
685 | 
686 | def retrieve_images(img_urls):
687 |     def _download(img_url):
688 |         attempts = 5
689 |         while attempts > 0:
690 |             try:
691 |                 response = requests.get(img_url, timeout=2)
692 |                 with io.BytesIO(response.content) as fp:
693 |                     return Image.open(fp).copy()
694 |             except Exception:
695 |                 print("[cloud-inference] failed to download image, retrying...")
696 |                 attempts -= 1
697 |         return None
698 | 
699 |     pool = ThreadPool()
700 |     applied = []
701 |     for img_url in
img_urls: 702 | applied.append(pool.apply_async(_download, (img_url, ))) 703 | ret = [r.get() for r in applied] 704 | return [_ for _ in ret if _ is not None] 705 | 706 | 707 | def bool2int(b): 708 | if isinstance(b, bool): 709 | return 1 if b else 0 710 | return b 711 | -------------------------------------------------------------------------------- /scripts/main_ui.py: -------------------------------------------------------------------------------- 1 | import modules.scripts as scripts 2 | import gradio as gr 3 | import os 4 | 5 | from modules import script_callbacks, shared, paths_internal, ui_common 6 | from extension import api 7 | 8 | from collections import Counter 9 | import random 10 | 11 | 12 | refresh_symbol = '\U0001f504' # 🔄 13 | favorite_symbol = '\U0001f49e' # 💞 14 | model_browser_symbol = '\U0001f50d' # 🔍 15 | 16 | 17 | class FormComponent: 18 | def get_expected_parent(self): 19 | return gr.components.Form 20 | 21 | 22 | class FormButton(FormComponent, gr.Button): 23 | def __init__(self, *args, **kwargs): 24 | super().__init__(*args, **kwargs) 25 | 26 | def get_block_name(self): 27 | return "button" 28 | 29 | 30 | class ToolButton(FormComponent, gr.Button): 31 | """Small button with single emoji as text, fits inside gradio forms""" 32 | 33 | def __init__(self, *args, **kwargs): 34 | classes = kwargs.pop("elem_classes", []) 35 | super().__init__(*args, elem_classes=["tool", *classes], **kwargs) 36 | 37 | def get_block_name(self): 38 | return "button" 39 | 40 | 41 | class DataBinding: 42 | def __init__(self): 43 | self.enable_cloud_inference = None 44 | 45 | # internal component 46 | self.txt2img_prompt = None 47 | self.txt2img_neg_prompt = None 48 | self.img2img_prompt = None 49 | self.img2img_neg_prompt = None 50 | self.txt2img_generate = None 51 | self.img2img_generate = None 52 | 53 | # custom component, need to sync 54 | self.txt2img_cloud_inference_model_dropdown = None 55 | self.img2img_cloud_inference_model_dropdown = None 56 | self.img2img_cloud_inference_checkbox = None 57 | self.txt2img_cloud_inference_checkbox = None 58 | 59 | self.txt2img_cloud_inference_vae_dropdown = None 60 | self.img2img_cloud_inference_vae_dropdown = None 61 | 62 | self.txt2img_cloud_inference_suggest_prompts_checkbox = None 63 | self.img2img_cloud_inference_suggest_prompts_checkbox = None 64 | 65 | self.remote_inference_enabled = False 66 | 67 | self.remote_models = None 68 | self.remote_models_aliases = {} 69 | self.remote_model_checkpoints = None 70 | self.remote_model_embeddings = None 71 | self.remote_model_loras = None 72 | self.remote_model_controlnet = None 73 | self.remote_model_vaes = None 74 | self.remote_model_upscalers = None 75 | 76 | # refiner 77 | self.txt2img_checkpoint = None 78 | self.img2img_checkpoint = None 79 | self.txt2img_checkpoint_refresh = None 80 | self.img2img_checkpoint_refresh = None 81 | 82 | # third component 83 | self.txt2img_controlnet_model_dropdown_units = [] 84 | self.img2img_controlnet_model_dropdown_units = [] 85 | 86 | # upscale 87 | self.extras_upscaler_1 = None 88 | self.extras_upscaler_2 = None 89 | self.txt2img_hr_upscaler = None 90 | 91 | # backup config 92 | self.txt2img_hr_upscaler_original = None 93 | self.extras_upscaler_1_original = None 94 | self.extras_upscaler_2_original = None 95 | self.txt2img_controlnet_model_dropdown_original_units = [] 96 | self.img2img_controlnet_model_dropdown_original_units = [] 97 | self.txt2img_checkpoint_original = None 98 | self.img2img_checkpoint_original = None 99 | 100 | self.default_remote_model = None 
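        # set to True by on_after_component_callback below once every
        # cross-tab component has been captured and the txt2img/img2img
        # sync events have been wired up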
101 |         self.initialized = False
102 | 
103 |         self.bultin_refiner_supported = False
104 |         self.ext_controlnet_installed = False
105 | 
106 |     def on_selected_model(self, name_index: int, selected_loras: list[str], selected_embedding: list[str], suggest_prompts_enabled, prompt: str, neg_prompt: str):
107 |         selected: api.StableDiffusionModel = self.find_model_by_alias(name_index)
108 |         selected_checkpoint = selected
109 | 
110 |         # name = self.remote_sd_models[name_index].name
111 |         prompt = prompt
112 |         neg_prompt = neg_prompt
113 | 
114 |         if len(selected.examples) > 0:
115 |             example = random.choice(selected.examples)
116 |             if suggest_prompts_enabled and example.prompts:
117 |                 prompt = example.prompts
118 |                 prompt = prompt.replace("\n", "")
119 |                 if len(selected_loras) > 0:
120 |                     prompt = self._update_lora_in_prompt(prompt, selected_loras)
121 |             if suggest_prompts_enabled and example.neg_prompt:
122 |                 neg_prompt = example.neg_prompt
123 |                 neg_prompt = neg_prompt.replace("\n", "")
124 |                 if len(selected_embedding) > 0:
125 |                     neg_prompt = self._update_embedding_in_neg_prompt(neg_prompt, selected_embedding)
126 | 
127 |         return gr.Dropdown.update(
128 |             choices=[_.alias for _ in self.remote_model_checkpoints], value=selected_checkpoint.alias), gr.update(value=prompt), gr.update(value=neg_prompt)
129 | 
130 |     def update_models(self):
131 |         for model in self.remote_models:
132 |             self.remote_models_aliases[model.alias] = model
133 | 
134 |         _binding.remote_model_loras = _get_kind_from_remote_models(_binding.remote_models, "lora")
135 |         _binding.remote_model_embeddings = _get_kind_from_remote_models(_binding.remote_models, "textualinversion")
136 |         _binding.remote_model_checkpoints = _get_kind_from_remote_models(_binding.remote_models, "checkpoint")
137 |         _binding.remote_model_vaes = _get_kind_from_remote_models(_binding.remote_models, "vae")
138 |         _binding.remote_model_controlnet = _get_kind_from_remote_models(_binding.remote_models, "controlnet")
139 |         _binding.remote_model_upscalers = _get_kind_from_remote_models(_binding.remote_models, "upscaler")
140 | 
141 |     @staticmethod
142 |     def _update_lora_in_prompt(prompt, _lora_names, weight=1):
143 |         lora_names = []
144 |         for lora_name in _lora_names:
145 |             lora_names.append(_binding.find_model_by_alias(lora_name).name)
146 | 
147 |         prompt = prompt
148 |         add_lora_prompts = []
149 | 
150 |         prompt_split = [_.strip() for _ in prompt.split(',') if _.strip() != ""]
151 | 
152 |         # add
153 |         for lora_name in lora_names:
154 |             if '<lora:{}:'.format(lora_name) not in prompt:
155 |                 add_lora_prompts.append("<lora:{}:{}>".format(
156 |                     lora_name, weight))
157 |         # delete
158 |         for prompt_item in prompt_split[:]:  # iterate over a copy so .remove() is safe
159 |             if prompt_item.startswith("<lora:"):
160 |                 lora_name = prompt_item.split(":")[1]
161 |                 if lora_name not in lora_names:
162 |                     prompt_split.remove(prompt_item)
163 | 
164 |         prompt_split.extend(add_lora_prompts)
165 | 
166 |         return ", ".join(prompt_split)
167 | 
168 |     @staticmethod
169 |     def _update_embedding_in_neg_prompt(neg_prompt, _embedding_names):
170 |         embedding_names = []
171 |         for embedding_name in _embedding_names:
172 |             name = _binding.find_model_by_alias(embedding_name).name.rsplit(".", 1)[0]  # remove extension
173 |             embedding_names.append(name)
174 | 
175 |         neg_prompt = neg_prompt
176 |         add_embedding_prompts = []
177 | 
178 |         neg_prompt_split = [_.strip() for _ in neg_prompt.split(',') if _.strip() != ""]
179 | 
180 |         # add
181 |         for embedding_name in embedding_names:
182 |             if embedding_name not in neg_prompt:
183 |                 add_embedding_prompts.append(embedding_name)
184 |         # delete
185 |         for prompt_item in neg_prompt_split[:]:  # iterate over a copy so .remove() is safe
186 |             if prompt_item in embedding_names:
187 | neg_prompt_split.remove(prompt_item) 188 | 189 | neg_prompt_split.extend(add_embedding_prompts) 190 | 191 | return ", ".join(neg_prompt_split) 192 | 193 | def update_selected_lora(self, lora_names, prompt): 194 | return gr.update(value=self._update_lora_in_prompt(prompt, lora_names)) 195 | 196 | def update_selected_embedding(self, embedding_names, neg_prompt): 197 | return gr.update(value=self._update_embedding_in_neg_prompt(neg_prompt, embedding_names)) 198 | 199 | def update_cloud_api(self, v): 200 | self.cloud_api = v 201 | 202 | def find_model_by_alias(self, choice): # alias -> sd_name 203 | for model in self.remote_models: 204 | if model.alias == choice: 205 | return model 206 | 207 | def find_name_by_alias(self, choice): 208 | for model in self.remote_models: 209 | if model.alias == choice: 210 | return model.name 211 | 212 | # def update_model_favorite(self, alias): 213 | # model = self.find_model_by_alias(alias) 214 | # if model is not None: 215 | # if "favorite" in model.tags: 216 | # model.tags.remove("favorite") 217 | # else: 218 | # model.tags.append("favorite") 219 | # return gr.update(value=build_model_browser_html_for_checkpoint("txt2img", _binding.remote_model_checkpoints)), \ 220 | # gr.update(value=build_model_browser_html_for_loras("txt2img", _binding.remote_model_loras)), \ 221 | # gr.update(value=build_model_browser_html_for_embeddings("txt2img", _binding.remote_model_embeddings)), \ 222 | 223 | 224 | def _get_kind_from_remote_models(models, kind): 225 | t = [] 226 | for model in models: 227 | if model.kind == kind: 228 | t.append(model) 229 | return t 230 | 231 | 232 | class CloudInferenceScript(scripts.Script): 233 | # Extension title in menu UI 234 | def title(self): 235 | return "Cloud Inference" 236 | 237 | def show(self, is_img2img): 238 | return scripts.AlwaysVisible 239 | 240 | def ui(self, is_img2img): 241 | tabname = "txt2img" 242 | if is_img2img: 243 | tabname = "img2img" 244 | 245 | # data initialize, TODO: move 246 | if _binding.remote_models is None or len(_binding.remote_models) == 0: 247 | _binding.remote_models = api.get_instance().list_models() 248 | _binding.update_models() 249 | 250 | top_n = min(len(_binding.remote_model_checkpoints), 50) 251 | if _binding.default_remote_model is None: 252 | _binding.default_remote_model = random.choice(_binding.remote_model_checkpoints[:top_n]).alias if len(_binding.remote_model_checkpoints) > 0 else None 253 | 254 | default_enabled = shared.opts.data.get("cloud_inference_default_enabled", False) 255 | if default_enabled: 256 | _binding.remote_inference_enabled = True 257 | 258 | default_suggest_prompts_enabled = shared.opts.data.get("cloud_inference_suggest_prompts_default_enabled", True) 259 | 260 | # define ui layouts 261 | with gr.Accordion('Cloud Inference', open=True): 262 | with gr.Row(): 263 | cloud_inference_checkbox = gr.Checkbox( 264 | label="Enable Cloud Inference", 265 | value=lambda: default_enabled, 266 | visible=not shared.opts.data.get( 267 | "cloud_inference_checkbox_hidden", False), 268 | elem_id="{}_cloud_inference_checkbox".format(tabname)) 269 | 270 | cloud_inference_suggest_prompts_checkbox = gr.Checkbox( 271 | value=lambda: default_suggest_prompts_enabled, 272 | label="Suggest Prompts", 273 | elem_id="{}_cloud_inference_suggest_prompts_checkbox".format(tabname)) 274 | 275 | with gr.Row(): 276 | gr.Dropdown( 277 | label="Service Provider", 278 | choices=["Omniinfer"], 279 | value="Omniinfer", 280 | elem_id="{}_cloud_api_dropdown".format(tabname), 281 | scale=1 282 | ) 283 | 284 | 
cloud_inference_model_dropdown = gr.Dropdown(
285 |                     label="Checkpoint",
286 |                     choices=[_.alias for _ in _binding.remote_model_checkpoints],
287 |                     value=lambda: _binding.default_remote_model,
288 |                     elem_id="{}_cloud_inference_model_dropdown".format(tabname), scale=2)
289 | 
290 |                 model_browser_button = FormButton(value="{} Browser".format(model_browser_symbol), elem_classes='model-browser-button',
291 |                                                   elem_id="{}_cloud_inference_browser_button".format(tabname), scale=0)
292 |                 refresh_button = ToolButton(value=refresh_symbol, elem_id="{}_cloud_inference_refresh_button".format(tabname))
293 | 
294 |                 # model_browser_button = ToolButton(model_browser_symbol, elem_id="{}_cloud_inference_browser_button".format(tabname))
295 |                 # favorite_button = ToolButton(
296 |                 #     value=favorite_symbol, elem_id="{}_cloud_inference_favorite_button".format(tabname))
297 | 
298 |             with gr.Row():
299 |                 cloud_inference_lora_dropdown = gr.Dropdown(
300 |                     choices=[_.alias for _ in _binding.remote_model_loras],
301 |                     label="Lora",
302 |                     elem_id="{}_cloud_inference_lora_dropdown".format(tabname), multiselect=True, scale=4)
303 |                 cloud_inference_embedding_dropdown = gr.Dropdown(
304 |                     choices=[_.alias for _ in _binding.remote_model_embeddings],
305 |                     label="Embedding",
306 |                     elem_id="{}_cloud_inference_embedding_dropdown".format(tabname), multiselect=True, scale=4)
307 | 
308 |                 cloud_inference_extra_checkbox = gr.Checkbox(
309 |                     label="Extra",
310 |                     value=False,
311 |                     elem_id="{}_cloud_inference_extra_subseed_show".format(tabname),
312 |                     scale=1
313 |                 )
314 | 
315 |             # hidden buttons used as JS-to-Python bridges
316 |             hide_button_change_checkpoint = gr.Button('Change Cloud checkpoint', elem_id='{}_change_cloud_checkpoint'.format(tabname), visible=False)
317 |             hide_button_change_lora = gr.Button('Change Cloud LORA', elem_id='{}_change_cloud_lora'.format(tabname), visible=False)
318 |             hide_button_change_embedding = gr.Button('Change Cloud Embedding', elem_id='{}_change_cloud_embedding'.format(tabname), visible=False)
319 |             # hide_button_favorite = gr.Button('Favorite', elem_id='{}_favorite'.format(tabname), visible=False)
320 | 
321 |             with gr.Box(elem_id='{}_model_browser'.format(tabname), elem_classes="popup-model-browser", visible=False) as checkpoint_model_browser_dialog:
322 |                 with gr.Tab(label="Checkpoint", elem_id='{}_model_browser_checkpoint_tab'.format(tabname)):
323 |                     model_checkpoint_browser_dialog_html = gr.HTML(build_model_browser_html_for_checkpoint(tabname, _binding.remote_model_checkpoints))
324 |                 with gr.Tab(label="LORA", elem_id='{}_model_browser_lora_tab'.format(tabname)):
325 |                     model_lora_browser_dialog_html = gr.HTML(build_model_browser_html_for_loras(tabname, _binding.remote_model_loras))
326 |                 with gr.Tab(label="Embedding", elem_id='{}_model_browser_embedding_tab'.format(tabname)):
327 |                     model_embedding_browser_dialog_html = gr.HTML(build_model_browser_html_for_embeddings(tabname, _binding.remote_model_embeddings))
328 | 
329 |             checkpoint_model_browser_dialog.visible = False
330 |             model_browser_button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[checkpoint_model_browser_dialog],).\
331 |                 then(fn=None, _js="function(){ modelBrowserPopup('" + tabname + "', gradioApp().getElementById('" + checkpoint_model_browser_dialog.elem_id + "')); }", show_progress=True)
332 | 
333 |             with gr.Row(visible=False) as extra_row:
334 |                 cloud_inference_vae_dropdown = gr.Dropdown(
335 |                     choices=["Automatic", "None"] + [_.name for _ in _binding.remote_model_vaes],
336 |                     value="Automatic",
337 |                     label="VAE",
338 |                     elem_id="{}_cloud_inference_vae_dropdown".format(tabname),
339 |                 )
340 | 
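            # Event wiring. The dropdowns below fire regular gradio events; the
            # hidden buttons defined above are clicked programmatically by the
            # model-browser popup (javascript/modelBrowser.js), which passes the
            # user's selection through the `desiredCloudInference*Name` globals
            # that the _js callbacks read and clear.
341 | 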
cloud_inference_extra_checkbox.change(lambda x: gr.update(visible=x), inputs=[ 342 | cloud_inference_extra_checkbox], outputs=[extra_row]) 343 | 344 | # lora 345 | # define events of components. 346 | # auto fill prompt after select model 347 | hide_button_change_checkpoint.click( 348 | fn=_binding.on_selected_model, 349 | _js="function(a, b, c, d, e, f){ var res = desiredCloudInferenceCheckpointName; desiredCloudInferenceCheckpointName = ''; return [res, b, c, d, e, f]; }", 350 | inputs=[ 351 | cloud_inference_model_dropdown, 352 | cloud_inference_lora_dropdown, 353 | cloud_inference_embedding_dropdown, 354 | cloud_inference_suggest_prompts_checkbox, 355 | getattr(_binding, "{}_prompt".format(tabname)), 356 | getattr(_binding, "{}_neg_prompt".format(tabname)) 357 | ], 358 | outputs=[ 359 | cloud_inference_model_dropdown, 360 | getattr(_binding, "{}_prompt".format(tabname)), 361 | getattr(_binding, "{}_neg_prompt".format(tabname)) 362 | ] 363 | ) 364 | # dummy_component = gr.Label(visible=False) 365 | # hide_button_favorite.click( 366 | # fn=_binding.update_model_favorite, 367 | # _js='''function(){ name = desciredCloudInferenceFavoriteModelName; desciredCloudInferenceFavoriteModelName = ""; return [name]; }''', 368 | # inputs=[dummy_component], 369 | # outputs=[ 370 | # model_checkpoint_browser_dialog_html, 371 | # model_lora_browser_dialog_html, 372 | # model_embedding_browser_dialog_html, 373 | # ], 374 | # ) 375 | 376 | hide_button_change_lora.click( 377 | fn=lambda x, y: _binding.update_selected_lora(x, y), 378 | _js="function(a, b){ a.includes(desiredCloudInferenceLoraName) || a.push(desiredCloudInferenceLoraName); desiredCloudInferenceLoraName = ''; return [a, b]; }", 379 | inputs=[ 380 | cloud_inference_lora_dropdown, 381 | getattr(_binding, "{}_prompt".format(tabname)) 382 | ], 383 | outputs=getattr(_binding, "{}_prompt".format(tabname)), 384 | ) 385 | # auto fill prompt after select lora 386 | cloud_inference_lora_dropdown.select( 387 | fn=lambda x, y: _binding.update_selected_lora(x, y), 388 | inputs=[ 389 | cloud_inference_lora_dropdown, 390 | getattr(_binding, "{}_prompt".format(tabname)) 391 | ], 392 | outputs=getattr(_binding, "{}_prompt".format(tabname)), 393 | ) 394 | 395 | hide_button_change_embedding.click( 396 | fn=lambda x, y: _binding.update_selected_embedding(x, y), 397 | _js="function(a, b){ a.includes(desiredCloudInferenceEmbeddingName) || a.push(desiredCloudInferenceEmbeddingName); desiredCloudInferenceEmbeddingName = ''; return [a, b]; }", 398 | inputs=[ 399 | cloud_inference_embedding_dropdown, 400 | getattr(_binding, "{}_neg_prompt".format(tabname)) 401 | ], 402 | outputs=getattr(_binding, "{}_neg_prompt".format(tabname)), 403 | ) 404 | # embeddings 405 | cloud_inference_embedding_dropdown.select( 406 | fn=lambda x, y: _binding.update_selected_embedding(x, y), 407 | inputs=[ 408 | cloud_inference_embedding_dropdown, 409 | getattr(_binding, "{}_neg_prompt".format(tabname)) 410 | ], 411 | outputs=[ 412 | getattr(_binding, "{}_neg_prompt".format(tabname)), 413 | ] 414 | ) 415 | 416 | cloud_inference_model_dropdown.select( 417 | fn=_binding.on_selected_model, 418 | inputs=[ 419 | cloud_inference_model_dropdown, 420 | cloud_inference_lora_dropdown, 421 | cloud_inference_embedding_dropdown, 422 | cloud_inference_suggest_prompts_checkbox, 423 | getattr(_binding, "{}_prompt".format(tabname)), 424 | getattr(_binding, "{}_neg_prompt".format(tabname)) 425 | ], 426 | outputs=[ 427 | cloud_inference_model_dropdown, 428 | getattr(_binding, "{}_prompt".format(tabname)), 429 | 
getattr(_binding, "{}_neg_prompt".format(tabname))
430 |             ])
431 | 
432 |         def _model_refresh(tab):
433 |             def wrapper():
434 |                 api.get_instance().refresh_models()
435 |                 _binding.remote_models = api.get_instance().list_models()
436 |                 _binding.update_models()
437 | 
438 |                 # order must match the refresh_button outputs below
439 |                 return gr.update(choices=[_.alias for _ in _binding.remote_model_checkpoints]), \
440 |                     gr.update(choices=[_.alias for _ in _binding.remote_model_loras]), \
441 |                     gr.update(choices=[_.alias for _ in _binding.remote_model_embeddings]), \
442 |                     gr.update(choices=["Automatic", "None"] + [_.name for _ in _binding.remote_model_vaes]), \
443 |                     gr.update(value=build_model_browser_html_for_checkpoint(tab, _binding.remote_model_checkpoints)), \
444 |                     gr.update(value=build_model_browser_html_for_loras(tab, _binding.remote_model_loras)), \
445 |                     gr.update(value=build_model_browser_html_for_embeddings(tab, _binding.remote_model_embeddings))
446 |             return wrapper
447 | 
448 |         refresh_button.click(
449 |             fn=_model_refresh(tabname),
450 |             inputs=[],
451 |             outputs=[cloud_inference_model_dropdown,
452 |                      cloud_inference_lora_dropdown,
453 |                      cloud_inference_embedding_dropdown,
454 |                      cloud_inference_vae_dropdown,
455 | 
456 |                      model_checkpoint_browser_dialog_html,
457 |                      model_lora_browser_dialog_html,
458 |                      model_embedding_browser_dialog_html,
459 |                      ])
460 | 
461 |         return [cloud_inference_checkbox, cloud_inference_model_dropdown, cloud_inference_vae_dropdown]
462 | 
463 | 
464 | # TODO: refactor this
465 | # module-level singleton shared by both tabs and the hijack manager
466 | _binding = DataBinding()
467 | if shared.opts.data.get("cloud_inference_default_enabled", False):
468 |     _binding.remote_inference_enabled = True
469 | 
470 | if os.path.isdir(os.path.join(paths_internal.extensions_dir, "sd-webui-controlnet")) and 'sd-webui-controlnet' not in shared.opts.data.get('disabled_extensions', []):
471 |     _binding.ext_controlnet_installed = True
472 | 
473 | try:
474 |     import modules.processing_scripts.refiner
475 |     _binding.bultin_refiner_supported = True
476 | except ImportError:
477 |     pass
478 | 
479 | from scripts.hijack import _hijack_manager
480 | _hijack_manager._binding = _binding
481 | _hijack_manager.hijack_onload()
482 | 
483 | _binding.remote_models = api.get_instance().list_models()
484 | _binding.update_models()
485 | 
486 | 
487 | print('Loading extension: sd-webui-cloud-inference')
488 | 
489 | 
490 | def on_after_component_callback(component, **_kwargs):
491 |     if type(component) is gr.Button and getattr(component, 'elem_id', None) == 'txt2img_generate':
492 |         _binding.txt2img_generate = component
493 | 
494 |     if type(component) is gr.Button and getattr(component, 'elem_id', None) == 'img2img_generate':
495 |         _binding.img2img_generate = component
496 | 
497 |     if type(component) is gr.Textbox and getattr(component, 'elem_id', None) == 'txt2img_prompt':
498 |         _binding.txt2img_prompt = component
499 |     if type(component) is gr.Textbox and getattr(component, 'elem_id', None) == 'txt2img_neg_prompt':
500 |         _binding.txt2img_neg_prompt = component
501 |     if type(component) is gr.Textbox and getattr(component, 'elem_id', None) == 'img2img_prompt':
502 |         _binding.img2img_prompt = component
503 |     if type(component) is gr.Textbox and getattr(component, 'elem_id', None) == 'img2img_neg_prompt':
504 |         _binding.img2img_neg_prompt = component
505 | 
506 |     if type(component) is gr.Checkbox and getattr(component, 'elem_id', None) == 'txt2img_cloud_inference_checkbox':
507 |         _binding.txt2img_cloud_inference_checkbox = component
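    # Each branch below captures one WebUI component as it is created (or, when
    # cloud inference is already enabled, swaps its choices for the remote
    # model list); the `initialized` block at the end wires the cross-tab sync
    # events exactly once, after every required component has been seen.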
508 |     if type(component) is gr.Checkbox and getattr(component, 'elem_id', None) == 'img2img_cloud_inference_checkbox':
509 |         _binding.img2img_cloud_inference_checkbox = component
510 |     if type(component) is gr.Checkbox and getattr(component, 'elem_id', None) == 'txt2img_cloud_inference_suggest_prompts_checkbox':
511 |         _binding.txt2img_cloud_inference_suggest_prompts_checkbox = component
512 |     if type(component) is gr.Checkbox and getattr(component, 'elem_id', None) == 'img2img_cloud_inference_suggest_prompts_checkbox':
513 |         _binding.img2img_cloud_inference_suggest_prompts_checkbox = component
514 | 
515 |     if type(component) is gr.Dropdown and getattr(component, 'elem_id', None) == 'txt2img_cloud_inference_model_dropdown':
516 |         _binding.txt2img_cloud_inference_model_dropdown = component
517 |     if type(component) is gr.Dropdown and getattr(component, 'elem_id', None) == 'img2img_cloud_inference_model_dropdown':
518 |         _binding.img2img_cloud_inference_model_dropdown = component
519 | 
520 |     # example: txt2img_controlnet_ControlNet_controlnet_model_dropdown and img2img_controlnet_ControlNet-0_controlnet_model_dropdown
521 |     if type(component) is gr.Dropdown and getattr(component, 'elem_id', None) is not None and component.elem_id.startswith('txt2img_controlnet_ControlNet') and component.elem_id.endswith('_model_dropdown'):
522 |         _binding.txt2img_controlnet_model_dropdown_units.append(component)
523 |         _binding.txt2img_controlnet_model_dropdown_original_units.append(component.get_config())
524 | 
525 |         if _binding.remote_inference_enabled:
526 |             component.choices = ['None'] + [_.alias for _ in _binding.remote_model_controlnet]
527 |             component.value = component.choices[0]
528 | 
529 |     if type(component) is gr.Dropdown and getattr(component, 'elem_id', None) is not None and component.elem_id.startswith('img2img_controlnet_ControlNet') and component.elem_id.endswith('_model_dropdown'):
530 |         _binding.img2img_controlnet_model_dropdown_units.append(component)
531 |         _binding.img2img_controlnet_model_dropdown_original_units.append(component.get_config())
532 | 
533 |         if _binding.remote_inference_enabled:
534 |             component.choices = ['None'] + [_.alias for _ in _binding.remote_model_controlnet]
535 |             component.value = component.choices[0]
536 | 
537 |     if type(component) is gr.Dropdown and getattr(component, 'elem_id', None) == 'extras_upscaler_1':
538 |         _binding.extras_upscaler_1 = component
539 |         _binding.extras_upscaler_1_original = component.get_config()
540 | 
541 |         if _binding.remote_inference_enabled:
542 |             component.choices = [_.alias for _ in _binding.remote_model_upscalers]
543 |             component.value = component.choices[0]
544 | 
545 |     if type(component) is gr.Dropdown and getattr(component, 'elem_id', None) == 'extras_upscaler_2':
546 |         _binding.extras_upscaler_2 = component
547 |         _binding.extras_upscaler_2_original = component.get_config()
548 | 
549 |         if _binding.remote_inference_enabled:
550 |             component.choices = ['None'] + [_.alias for _ in _binding.remote_model_upscalers]
551 |             component.value = component.choices[0]
552 | 
553 |     if type(component) is gr.Dropdown and getattr(component, 'elem_id', None) == 'txt2img_hr_upscaler':
554 |         _binding.txt2img_hr_upscaler = component
555 |         _binding.txt2img_hr_upscaler_original = component.get_config()
556 | 
557 |         if _binding.remote_inference_enabled:
558 |             component.choices = [_.alias for _ in _binding.remote_model_upscalers]
559 |             component.value = component.choices[0]
560 | 
561 |     # txt2img refiner
562 |     if type(component) is gr.Dropdown and getattr(component, 'elem_id', None) == 'txt2img_checkpoint':
563 |         _binding.txt2img_checkpoint = component
564 |         _binding.txt2img_checkpoint_original = component.get_config()
565 | 
566 |         if _binding.remote_inference_enabled:
567 |             component.choices = ["None"] + [_.name for _ in _binding.remote_model_checkpoints if 'refiner' in _.name]  # TODO
568 |             component.value = component.choices[0]
569 |     if getattr(component, 'elem_id', None) == 'txt2img_checkpoint_refresh':  # `if gr.Dropdown and ...` was always truthy; only the elem_id matters
570 |         _binding.txt2img_checkpoint_refresh = component
571 |         if _binding.remote_inference_enabled:
572 |             component.visible = False
573 | 
574 |     # img2img refiner
575 |     if type(component) is gr.Dropdown and getattr(component, 'elem_id', None) == 'img2img_checkpoint':
576 |         _binding.img2img_checkpoint = component
577 |         _binding.img2img_checkpoint_original = component.get_config()
578 | 
579 |         if _binding.remote_inference_enabled:
580 |             component.choices = ["None"] + [_.name for _ in _binding.remote_model_checkpoints if 'refiner' in _.name]  # TODO
581 |             component.value = component.choices[0]
582 | 
583 |     if getattr(component, 'elem_id', None) == 'img2img_checkpoint_refresh':
584 |         _binding.img2img_checkpoint_refresh = component
585 |         if _binding.remote_inference_enabled:
586 |             component.visible = False
587 | 
588 |     if _binding.txt2img_cloud_inference_checkbox and \
589 |             _binding.img2img_cloud_inference_checkbox and \
590 |             _binding.txt2img_cloud_inference_model_dropdown and \
591 |             _binding.img2img_cloud_inference_model_dropdown and \
592 |             _binding.txt2img_cloud_inference_suggest_prompts_checkbox and \
593 |             _binding.img2img_cloud_inference_suggest_prompts_checkbox and \
594 |             _binding.txt2img_generate and \
595 |             _binding.img2img_generate and \
596 |             _binding.extras_upscaler_1 and \
597 |             _binding.extras_upscaler_2 and \
598 |             _binding.txt2img_hr_upscaler and \
599 |             not _binding.initialized:
600 | 
601 |         if _binding.ext_controlnet_installed:
602 |             expect_unit_amount = shared.opts.data.get("control_net_max_models_num", 1)
603 |             if expect_unit_amount != len(_binding.txt2img_controlnet_model_dropdown_units):
604 |                 return
605 | 
606 |         if _binding.bultin_refiner_supported:
607 |             if _binding.txt2img_checkpoint is None or _binding.img2img_checkpoint is None:
608 |                 return
609 | 
610 |         sync_cloud_model(_binding.txt2img_cloud_inference_model_dropdown,
611 |                          _binding.img2img_cloud_inference_model_dropdown)
612 |         sync_two_component(_binding.txt2img_cloud_inference_suggest_prompts_checkbox,
613 |                            _binding.img2img_cloud_inference_suggest_prompts_checkbox, 'change')
614 |         on_cloud_inference_checkbox_change(_binding)
615 | 
616 |         _binding.initialized = True
617 | 
618 | 
619 | def build_model_browser_html_for_checkpoint(tab, checkpoints):
620 |     column_html = ""
621 |     column_size = 5
622 |     column_items = [[] for _ in range(column_size)]
623 |     tag_counter = Counter()
624 |     kind = "checkpoint"
625 |     for i, model in enumerate(checkpoints):
626 |         trimed_tags = [_.replace(" ", "_") for _ in model.tags]
627 |         tag_counter.update(trimed_tags)
628 |         if model.preview_url is None or not model.preview_url.startswith("http"):
629 |             model.preview_url = "https://via.placeholder.com/512x512.png?text=Preview+Not+Available"
630 |         model_html = f"""
631 | 
632 | 
633 | {model.name.rsplit(".", 1)[0]}
634 | 
635 | 
636 | 
637 | 
638 | Select
639 | 
640 | """
641 |         column_index = i % column_size
642 |         column_items[column_index].append(model_html)
643 | 
644 |     for i in range(column_size):
645 |         column_image_items_html = ""
646 |         for item in column_items[i]:
647 |             column_image_items_html += item
648 |         column_html += """{}""".format(column_image_items_html)
649 | 
650 | 
tag_html = f""" 651 | ALL 652 | """ 653 | tag_html += """{}""" 654 | tag_html = tag_html.format("\n".join([f"""{_[0].upper()}""" for _ in tag_counter.most_common()])) 655 | 656 | return f"""{kind.upper()} Browser{tag_html} 657 | 658 | {column_html}""" 659 | 660 | 661 | def build_model_browser_html_for_loras(tab, loras): 662 | column_html = "" 663 | column_size = 5 664 | column_items = [[] for _ in range(column_size)] 665 | tag_counter = Counter() 666 | kind = "lora" 667 | for i, model in enumerate(loras): 668 | trimed_tags = [_.replace(" ", "_") for _ in model.tags] 669 | tag_counter.update(trimed_tags) 670 | model_html = f""" 671 | 672 | 673 | {model.name.rsplit(".", 1)[0]} 674 | 675 | 676 | 677 | 678 | Select 679 | 680 | """ 681 | column_index = i % column_size 682 | column_items[column_index].append(model_html) 683 | 684 | for i in range(column_size): 685 | column_image_items_html = "" 686 | for item in column_items[i]: 687 | column_image_items_html += item 688 | column_html += """{}""".format(column_image_items_html) 689 | 690 | tag_html = f""" 691 | ALL 692 | """ 693 | tag_html += """{}""" 694 | tag_html = tag_html.format("\n".join([f"""{_[0].upper()}""" for _ in tag_counter.most_common()])) 695 | 696 | return f"""{kind.upper()} Browser{tag_html}{column_html}""" 697 | 698 | 699 | def build_model_browser_html_for_embeddings(tab, embeddings): 700 | column_html = "" 701 | column_size = 5 702 | column_items = [[] for _ in range(column_size)] 703 | tag_counter = Counter() 704 | kind = "embedding" 705 | for i, model in enumerate(embeddings): 706 | trimed_tags = [_.replace(" ", "_") for _ in model.tags] 707 | tag_counter.update(trimed_tags) 708 | model_html = f""" 709 | 710 | 711 | {model.name.rsplit(".", 1)[0]} 712 | 713 | 714 | 715 | 716 | Select 717 | 718 | """ 719 | column_index = i % column_size 720 | column_items[column_index].append(model_html) 721 | 722 | for i in range(column_size): 723 | column_image_items_html = "" 724 | for item in column_items[i]: 725 | column_image_items_html += item 726 | column_html += """{}""".format(column_image_items_html) 727 | 728 | tag_html = f""" 729 | ALL 730 | """ 731 | tag_html += """{}""" 732 | tag_html = tag_html.format("\n".join([f"""{_[0].upper()}""" for _ in tag_counter.most_common()])) 733 | 734 | return f"""{kind.upper()} Browser{tag_html}{column_html}""" 735 | 736 | 737 | def sync_two_component(a, b, event_name): 738 | def mirror(a, b): 739 | if a != b: 740 | b = a 741 | return a, b 742 | getattr(a, event_name)(fn=mirror, inputs=[a, b], outputs=[a, b]) 743 | getattr(b, event_name)(fn=mirror, inputs=[b, a], outputs=[b, a]) 744 | 745 | 746 | def sync_cloud_model(a, b): 747 | def mirror(a, b): 748 | if a != b: 749 | b = a 750 | return a, b 751 | getattr(a, "change")(fn=mirror, inputs=[a, b], outputs=[a, b]) 752 | getattr(b, "change")(fn=mirror, inputs=[b, a], outputs=[b, a]) 753 | 754 | 755 | def on_cloud_inference_checkbox_change(binding: DataBinding): 756 | def mirror(source, target): 757 | enabled = source 758 | 759 | if source != target: 760 | target = source 761 | 762 | button_text = "Generate" 763 | if enabled: 764 | binding.remote_inference_enabled = True 765 | button_text = "Generate (cloud)" 766 | else: 767 | binding.remote_inference_enabled = False 768 | 769 | controlnet_models = ["None"] + [_.name for _ in binding.remote_model_controlnet] 770 | upscale_models_with_none = ["None"] + [_.alias for _ in binding.remote_model_upscalers] 771 | upscale_models = [_.alias for _ in binding.remote_model_upscalers] 772 | refiner_models = ["None"] 
+ [_.name for _ in binding.remote_model_checkpoints if 'refiner' in _.name] # TODO 773 | 774 | update_components = ( 775 | source, 776 | target, 777 | button_text, 778 | button_text, 779 | ) 780 | 781 | def back_to_original(origin_config): 782 | allow_update_fields = ['value', 'choices'] 783 | return {k: v for k, v in origin_config.items() if k in allow_update_fields} 784 | 785 | if not enabled: 786 | update_components += ( 787 | gr.update(**back_to_original(binding.extras_upscaler_1_original)), 788 | gr.update(**back_to_original(binding.extras_upscaler_2_original)), 789 | gr.update(**back_to_original(binding.txt2img_hr_upscaler_original)) 790 | ) 791 | if binding.ext_controlnet_installed: 792 | update_components += ( 793 | *[gr.update(**back_to_original(_)) for _ in binding.txt2img_controlnet_model_dropdown_original_units], 794 | *[gr.update(**back_to_original(_)) for _ in binding.img2img_controlnet_model_dropdown_original_units], 795 | ) 796 | if binding.bultin_refiner_supported: 797 | update_components += ( 798 | gr.update(**back_to_original(binding.txt2img_checkpoint_original)), 799 | gr.update(**back_to_original(binding.img2img_checkpoint_original)), 800 | gr.update(visible=True), 801 | gr.update(visible=True), 802 | ) 803 | 804 | return update_components 805 | 806 | update_components += ( 807 | gr.update(value=upscale_models[0], choices=upscale_models), 808 | gr.update(value=upscale_models_with_none[0], choices=upscale_models_with_none), 809 | gr.update(value=upscale_models[0], choices=upscale_models), 810 | ) 811 | if binding.ext_controlnet_installed: 812 | update_components += ( 813 | *[gr.update(value=controlnet_models[0], choices=controlnet_models) for _ in binding.txt2img_controlnet_model_dropdown_units], 814 | *[gr.update(value=controlnet_models[0], choices=controlnet_models) for _ in binding.img2img_controlnet_model_dropdown_units], 815 | ) 816 | if binding.bultin_refiner_supported: 817 | update_components += ( 818 | gr.update(value=refiner_models[0], choices=refiner_models), 819 | gr.update(value=refiner_models[0], choices=refiner_models), 820 | gr.update(visible=False), 821 | gr.update(visible=False), 822 | ) 823 | 824 | return update_components 825 | 826 | expect_update_components = ( 827 | _binding.txt2img_generate, 828 | _binding.img2img_generate, 829 | _binding.extras_upscaler_1, 830 | _binding.extras_upscaler_2, 831 | _binding.txt2img_hr_upscaler 832 | ) 833 | if _binding.ext_controlnet_installed: 834 | expect_update_components += ( 835 | *_binding.txt2img_controlnet_model_dropdown_units, 836 | *_binding.img2img_controlnet_model_dropdown_units, 837 | ) 838 | if _binding.bultin_refiner_supported: 839 | expect_update_components += ( 840 | _binding.txt2img_checkpoint, 841 | _binding.img2img_checkpoint, 842 | _binding.txt2img_checkpoint_refresh, 843 | _binding.img2img_checkpoint_refresh, 844 | ) 845 | 846 | _binding.txt2img_cloud_inference_checkbox.change(fn=mirror, 847 | inputs=[_binding.txt2img_cloud_inference_checkbox, 848 | _binding.img2img_cloud_inference_checkbox, 849 | ], 850 | outputs=[ 851 | _binding.img2img_cloud_inference_checkbox, 852 | _binding.txt2img_cloud_inference_checkbox, 853 | *expect_update_components]) 854 | _binding.img2img_cloud_inference_checkbox.change(fn=mirror, 855 | inputs=[_binding.img2img_cloud_inference_checkbox, 856 | _binding.txt2img_cloud_inference_checkbox], 857 | outputs=[ 858 | _binding.img2img_cloud_inference_checkbox, 859 | _binding.txt2img_cloud_inference_checkbox, 860 | *expect_update_components 861 | ]) 862 | 863 | 864 | def 
on_ui_settings():
865 |     section = ('cloud_inference', "Cloud Inference")
866 |     shared.opts.add_option("cloud_inference_default_enabled", shared.OptionInfo(
867 |         False, "Cloud Inference Default Enabled", component=gr.Checkbox, section=section))
868 |     shared.opts.add_option("cloud_inference_checkbox_hidden", shared.OptionInfo(
869 |         False, "Cloud Inference Checkbox Hidden", component=gr.Checkbox, section=section))
870 |     shared.opts.add_option("cloud_inference_suggest_prompts_default_enabled", shared.OptionInfo(
871 |         True, "Cloud Inference Suggest Prompts Default Enabled", component=gr.Checkbox, section=section))
872 | 
873 | 
874 | script_callbacks.on_after_component(on_after_component_callback)
875 | script_callbacks.on_ui_settings(on_ui_settings)
876 | script_callbacks.on_app_started(_hijack_manager.hijack_on_app_started)
877 | 
--------------------------------------------------------------------------------