├── .gitignore ├── .python-version ├── LICENSE ├── Makefile ├── README.md ├── docs ├── Makefile ├── conf.py ├── index.rst ├── make.bat ├── modules.rst └── omniinfer_client.rst ├── examples ├── controlnet_qrcode.py ├── fixtures │ └── qrcode.png ├── model_search.py ├── txt2img_with_hiresfix.py ├── txt2img_with_lora.py └── txt2img_with_refiner.py ├── pyproject.toml ├── requirements-dev.lock ├── requirements.lock ├── src └── omniinfer_client │ ├── __init__.py │ ├── exceptions.py │ ├── omni.py │ ├── proto.py │ ├── serializer.py │ ├── settings.py │ ├── utils.py │ └── version.py └── tests ├── test_basics.py └── test_model.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | tests/data -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.7.9 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 omniinfer.io 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | build: 2 | @rm -rf dist 3 | @.venv/bin/python3 -m build 4 | 5 | upload: build 6 | @.venv/bin/python3 -m twine upload dist/* 7 | 8 | docs: 9 | @cd docs && make html -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Omniinfer Python SDK 2 | 3 | Thanks to the initial contribution of [@shanginn](https://github.com/shanginn), we have made the decision to create this SDK. 4 | 5 | this SDK is based on the official [API documentation](https://docs.omniinfer.io/) 6 | 7 | **join our discord server for help** 8 | 9 | [![](https://dcbadge.vercel.app/api/server/nzqq8UScpx)](https://discord.gg/nzqq8UScpx) 10 | 11 | ## Installation 12 | 13 | ```bash 14 | pip install omniinfer-client 15 | ``` 16 | 17 | ## Quick Start 18 | 19 | **Get api key refer to [https://docs.omniinfer.io/get-started](https://docs.omniinfer.io/get-started/)** 20 | 21 | ```python 22 | import os 23 | from omniinfer_client import OmniClient, Txt2ImgRequest, Samplers, ModelType, save_image 24 | 25 | client = OmniClient(os.getenv('OMNI_API_KEY')) 26 | 27 | req = Txt2ImgRequest( 28 | model_name='sd_xl_base_1.0.safetensors', 29 | prompt='a dog flying in the sky', 30 | batch_size=1, 31 | cfg_scale=7.5, 32 | height=1024, 33 | width=1024, 34 | sampler_name=Samplers.EULER_A, 35 | ) 36 | save_image(client.sync_txt2img(req).data.imgs_bytes[0], 'output.png') 37 | ``` 38 | 39 | ## Examples 40 | 41 | [txt2img_with_lora.py](./examples/txt2img_with_lora.py) 42 | 43 | ```python 44 | #!/usr/bin/env python 45 | # -*- coding: UTF-8 -*- 46 | 47 | import os 48 | from omniinfer_client import OmniClient, Txt2ImgRequest, Samplers, ProgressResponseStatusCode, ModelType, add_lora_to_prompt, save_image 49 | 50 | 51 | client = OmniClient(os.getenv('OMNI_API_KEY')) 52 | models = client.models() 53 | 54 | # Anything V5/Ink, https://civitai.com/models/9409/or-anything-v5ink 55 | checkpoint_model = models.filter_by_type(ModelType.CHECKPOINT).get_by_civitai_version_id(90854) 56 | 57 | # Detail Tweaker LoRA, https://civitai.com/models/58390/detail-tweaker-lora-lora 58 | lora_model = models.filter_by_type(ModelType.LORA).get_by_civitai_version_id(62833) 59 | 60 | prompt = add_lora_to_prompt('a dog flying in the sky', lora_model.sd_name, "0.8") 61 | 62 | res = client.sync_txt2img(Txt2ImgRequest( 63 | prompt=prompt, 64 | batch_size=1, 65 | cfg_scale=7.5, 66 | sampler_name=Samplers.EULER_A, 67 | model_name=checkpoint_model.sd_name, 68 | seed=103304, 69 | )) 70 | 71 | if res.data.status != ProgressResponseStatusCode.SUCCESSFUL: 72 | raise Exception('Failed to generate image with error: ' + 73 | res.data.failed_reason) 74 | save_image(res.data.imgs_bytes[0], "test.png") 75 | ``` 76 | 77 | ### Model Search 78 | 79 | [model_search.py](./examples/model_search.py) 80 | 81 | ```python 82 | from omniinfer_client import OmniClient, ModelType 83 | 84 | client = OmniClient(os.getenv('OMNI_API_KEY')) 85 | 86 | # filter by model type 87 | print("lora count", len(client.models().filter_by_type(ModelType.LORA))) 88 | print("checkpoint count", len(client.models().filter_by_type(ModelType.CHECKPOINT))) 89 | print("textinversion count", len( 90 | client.models().filter_by_type(ModelType.TEXT_INVERSION))) 91 | print("vae count", len(client.models().filter_by_type(ModelType.VAE))) 92 | print("controlnet count", len(client.models().filter_by_type(ModelType.CONTROLNET))) 93 | 94 | 95 | # filter by civitai tags 96 | client.models().filter_by_civi_tags('anime') 97 | 98 | # filter by nsfw 99 | client.models().filter_by_nsfw(False) # or True 100 | 101 | # sort by civitai download 102 | client.models().sort_by_civitai_download() 103 | 104 | # chain filters 105 | client.models().\ 106 | filter_by_type(ModelType.CHECKPOINT).\ 107 | filter_by_nsfw(False).\ 108 | filter_by_civitai_tags('anime') 109 | ``` 110 | 111 | ### ControlNet QRCode 112 | 113 | [controlnet_qrcode.py](./examples/controlnet_qrcode.py) 114 | 115 | ```python 116 | import os 117 | 118 | from omniinfer_client import * 119 | 120 | # get your api key refer to https://docs.omniinfer.io/get-started/ 121 | client = OmniClient(os.getenv('OMNI_API_KEY')) 122 | 123 | controlnet_model = client.models().filter_by_type(ModelType.CONTROLNET).get_by_name("control_v1p_sd15_qrcode_monster_v2") 124 | if controlnet_model is None: 125 | raise Exception("controlnet model not found") 126 | 127 | req = Txt2ImgRequest( 128 | prompt="a beautify butterfly in the colorful flowers, best quality, best details, masterpiece", 129 | sampler_name=Samplers.DPMPP_M_KARRAS, 130 | width=512, 131 | height=512, 132 | steps=30, 133 | controlnet_units=[ 134 | ControlnetUnit( 135 | input_image=read_image_to_base64(os.path.join(os.path.abspath(os.path.dirname(__file__)), "fixtures/qrcode.png")), 136 | control_mode=ControlNetMode.BALANCED, 137 | model=controlnet_model.sd_name, 138 | module=ControlNetPreprocessor.NULL, 139 | resize_mode=ControlNetResizeMode.JUST_RESIZE, 140 | weight=2.0, 141 | ) 142 | ] 143 | ) 144 | 145 | res = client.sync_txt2img(req) 146 | if res.data.status != ProgressResponseStatusCode.SUCCESSFUL: 147 | raise Exception('Failed to generate image with error: ' + 148 | res.data.failed_reason) 149 | 150 | save_image(res.data.imgs_bytes[0], "qrcode-art.png") 151 | ``` 152 | 153 | ### Txt2Img with Hires.Fix 154 | 155 | [txt2img_with_hiresfix.py](./examples/txt2img_with_hiresfix.py) 156 | 157 | ```python 158 | import os 159 | 160 | from omniinfer_client import * 161 | 162 | client = OmniClient(os.getenv('OMNI_API_KEY')) 163 | req = Txt2ImgRequest( 164 | model_name='dreamshaper_8_93211.safetensors', 165 | prompt='a dog flying in the sky', 166 | width=512, 167 | height=512, 168 | batch_size=1, 169 | cfg_scale=7.5, 170 | sampler_name=Samplers.EULER_A, 171 | enable_hr=True, 172 | hr_scale=2.0 173 | ) 174 | 175 | res = client.sync_txt2img(req) 176 | if res.data.status != ProgressResponseStatusCode.SUCCESSFUL: 177 | raise Exception('Failed to generate image with error: ' + 178 | res.data.failed_reason) 179 | 180 | save_image(res.data.imgs_bytes[0], "txt2img-hiresfix-1024.png") 181 | ``` 182 | 183 | ### SDXL Refiner 184 | 185 | [sdxl_refiner.py](./txt2img_with_refiner.py) 186 | 187 | ```python 188 | import os 189 | 190 | from omniinfer_client import * 191 | 192 | client = OmniClient(os.getenv('OMNI_API_KEY')) 193 | req = Txt2ImgRequest( 194 | model_name='sd_xl_base_1.0.safetensors', 195 | prompt='a dog flying in the sky', 196 | width=1024, 197 | height=1024, 198 | batch_size=1, 199 | cfg_scale=7.5, 200 | sampler_name=Samplers.EULER_A, 201 | sd_refiner=Refiner( 202 | checkpoint='sd_xl_refiner_1.0.safetensors', 203 | switch_at=0.5, 204 | )) 205 | 206 | res = client.sync_txt2img(req) 207 | if res.data.status != ProgressResponseStatusCode.SUCCESSFUL: 208 | raise Exception('Failed to generate image with error: ' + 209 | res.data.failed_reason) 210 | 211 | save_image(res.data.imgs_bytes[0], "txt2img-refiner.png") 212 | ``` 213 | 214 | 215 | ## Testing 216 | 217 | ``` 218 | export OMNI_API_KEY= 219 | 220 | python -m pytest 221 | ``` -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | import os 10 | import sys 11 | 12 | sys.path.insert(0, os.path.abspath('..')) 13 | 14 | project = 'omniinfer-python-sdk' 15 | copyright = '2023, Omniinfer' 16 | author = 'Omniinfer' 17 | release = '0.0.1' 18 | 19 | # -- General configuration --------------------------------------------------- 20 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 21 | 22 | extensions = [ 23 | 'sphinx.ext.autodoc', 24 | 'sphinx.ext.viewcode', 25 | 'sphinx.ext.napoleon' 26 | ] 27 | 28 | templates_path = ['_templates'] 29 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 30 | 31 | 32 | # -- Options for HTML output ------------------------------------------------- 33 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 34 | 35 | html_theme = 'alabaster' 36 | html_static_path = ['_static'] 37 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. omniinfer-python-sdk documentation master file, created by 2 | sphinx-quickstart on Sun Aug 6 20:52:31 2023. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to omniinfer-python-sdk's documentation! 7 | ================================================ 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | 14 | 15 | Indices and tables 16 | ================== 17 | 18 | * :ref:`genindex` 19 | * :ref:`modindex` 20 | * :ref:`search` 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/modules.rst: -------------------------------------------------------------------------------- 1 | src 2 | === 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | omniinfer_client 8 | -------------------------------------------------------------------------------- /docs/omniinfer_client.rst: -------------------------------------------------------------------------------- 1 | omniinfer\_client package 2 | ========================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | omniinfer\_client.exceptions module 8 | ----------------------------------- 9 | 10 | .. automodule:: omniinfer_client.exceptions 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | omniinfer\_client.omni module 16 | ----------------------------- 17 | 18 | .. automodule:: omniinfer_client.omni 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | omniinfer\_client.proto module 24 | ------------------------------ 25 | 26 | .. automodule:: omniinfer_client.proto 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | omniinfer\_client.serializer module 32 | 33 | --------------------------------- 34 | 35 | .. automodule:: omniinfer_client.settings 36 | :members: 37 | :undoc-members: 38 | :show-inheritance: 39 | 40 | omniinfer\_client.utils module 41 | ------------------------------ 42 | 43 | .. automodule:: omniinfer_client.utils 44 | :members: 45 | :undoc-members: 46 | :show-inheritance: 47 | 48 | Module contents 49 | --------------- 50 | 51 | .. automodule:: omniinfer_client 52 | :members: 53 | :undoc-members: 54 | :show-inheritance: 55 | -------------------------------------------------------------------------------- /examples/controlnet_qrcode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | 4 | import os 5 | 6 | from omniinfer_client import * 7 | 8 | # get your api key refer to https://docs.omniinfer.io/get-started/ 9 | client = OmniClient(os.getenv('OMNI_API_KEY')) 10 | 11 | controlnet_model = client.models().filter_by_type(ModelType.CONTROLNET).get_by_name("control_v1p_sd15_qrcode_monster_v2") 12 | if controlnet_model is None: 13 | raise Exception("controlnet model not found") 14 | 15 | req = Txt2ImgRequest( 16 | prompt="a beautify butterfly in the colorful flowers, best quality, best details, masterpiece", 17 | sampler_name=Samplers.DPMPP_M_KARRAS, 18 | width=512, 19 | height=512, 20 | steps=30, 21 | controlnet_units=[ 22 | ControlnetUnit( 23 | input_image=read_image_to_base64(os.path.join(os.path.abspath(os.path.dirname(__file__)), "fixtures/qrcode.png")), 24 | control_mode=ControlNetMode.BALANCED, 25 | model=controlnet_model.sd_name, 26 | module=ControlNetPreprocessor.NULL, 27 | resize_mode=ControlNetResizeMode.JUST_RESIZE, 28 | weight=2.0, 29 | ) 30 | ] 31 | ) 32 | 33 | res = client.sync_txt2img(req) 34 | if res.data.status != ProgressResponseStatusCode.SUCCESSFUL: 35 | raise Exception('Failed to generate image with error: ' + 36 | res.data.failed_reason) 37 | 38 | save_image(res.data.imgs_bytes[0], "qrcode-art.png") 39 | -------------------------------------------------------------------------------- /examples/fixtures/qrcode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omniinfer/python-sdk/61b6d21bcfb4e7d27fe772018a4466216b3f5647/examples/fixtures/qrcode.png -------------------------------------------------------------------------------- /examples/model_search.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | 4 | from omniinfer_client import OmniClient, ModelType 5 | # get your api key refer to https://docs.omniinfer.io/get-started/ 6 | client = OmniClient(os.getenv('OMNI_API_KEY')) 7 | 8 | # filter by model type 9 | print("lora count", len(client.models().filter_by_type(ModelType.LORA))) 10 | print("checkpoint count", len(client.models().filter_by_type(ModelType.CHECKPOINT))) 11 | print("textinversion count", len( 12 | client.models().filter_by_type(ModelType.TEXT_INVERSION))) 13 | print("vae count", len(client.models().filter_by_type(ModelType.VAE))) 14 | print("controlnet count", len(client.models().filter_by_type(ModelType.CONTROLNET))) 15 | 16 | 17 | # filter by civitai tags 18 | client.models().filter_by_civi_tags('anime') 19 | 20 | # filter by nsfw 21 | client.models().filter_by_nsfw(False) # or True 22 | 23 | # sort by civitai download 24 | client.models().sort_by_civitai_download() 25 | 26 | # chain filters 27 | client.models().\ 28 | filter_by_type(ModelType.CHECKPOINT).\ 29 | filter_by_nsfw(False).\ 30 | filter_by_civitai_tags('anime') 31 | -------------------------------------------------------------------------------- /examples/txt2img_with_hiresfix.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from omniinfer_client import * 4 | 5 | client = OmniClient(os.getenv('OMNI_API_KEY')) 6 | req = Txt2ImgRequest( 7 | model_name='dreamshaper_8_93211.safetensors', 8 | prompt='a dog flying in the sky', 9 | width=512, 10 | height=512, 11 | batch_size=1, 12 | cfg_scale=7.5, 13 | sampler_name=Samplers.EULER_A, 14 | enable_hr=True, 15 | hr_scale=2.0 16 | ) 17 | 18 | res = client.sync_txt2img(req) 19 | if res.data.status != ProgressResponseStatusCode.SUCCESSFUL: 20 | raise Exception('Failed to generate image with error: ' + 21 | res.data.failed_reason) 22 | 23 | save_image(res.data.imgs_bytes[0], "txt2img-hiresfix-1024.png") 24 | -------------------------------------------------------------------------------- /examples/txt2img_with_lora.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | 4 | import os 5 | from omniinfer_client import OmniClient, Txt2ImgRequest, Samplers, ProgressResponseStatusCode, ModelType, add_lora_to_prompt, save_image 6 | 7 | 8 | client = OmniClient(os.getenv('OMNI_API_KEY')) 9 | models = client.models() 10 | 11 | # Anything V5/Ink, https://civitai.com/models/9409/or-anything-v5ink 12 | checkpoint_model = models.filter_by_type(ModelType.CHECKPOINT).get_by_civitai_version_id(90854) 13 | 14 | # Detail Tweaker LoRA, https://civitai.com/models/58390/detail-tweaker-lora-lora 15 | lora_model = models.filter_by_type(ModelType.LORA).get_by_civitai_version_id(62833) 16 | 17 | prompt = add_lora_to_prompt('a dog flying in the sky', lora_model.sd_name, "0.8") 18 | 19 | res = client.sync_txt2img(Txt2ImgRequest( 20 | prompt=prompt, 21 | batch_size=1, 22 | cfg_scale=7.5, 23 | sampler_name=Samplers.EULER_A, 24 | model_name=checkpoint_model.sd_name, 25 | seed=103304, 26 | )) 27 | 28 | if res.data.status != ProgressResponseStatusCode.SUCCESSFUL: 29 | raise Exception('Failed to generate image with error: ' + 30 | res.data.failed_reason) 31 | save_image(res.data.imgs_bytes[0], "test.png") 32 | -------------------------------------------------------------------------------- /examples/txt2img_with_refiner.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from omniinfer_client import * 4 | 5 | client = OmniClient(os.getenv('OMNI_API_KEY')) 6 | req = Txt2ImgRequest( 7 | model_name='sd_xl_base_1.0.safetensors', 8 | prompt='a dog flying in the sky', 9 | width=1024, 10 | height=1024, 11 | batch_size=1, 12 | cfg_scale=7.5, 13 | sampler_name=Samplers.EULER_A, 14 | sd_refiner=Refiner( 15 | checkpoint='sd_xl_refiner_1.0.safetensors', 16 | switch_at=0.5, 17 | )) 18 | 19 | res = client.sync_txt2img(req) 20 | if res.data.status != ProgressResponseStatusCode.SUCCESSFUL: 21 | raise Exception('Failed to generate image with error: ' + 22 | res.data.failed_reason) 23 | 24 | save_image(res.data.imgs_bytes[0], "txt2img-refiner.png") 25 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "omniinfer_client" 3 | version = "0.3.5" 4 | description = "Omniinfer SDK for Python" 5 | authors = [ 6 | { name = "Omniinfer", email = "omniinfer@gmail.com" } 7 | ] 8 | 9 | 10 | dependencies = [ 11 | "dataclass_wizard>=0.22.2", 12 | "requests>=2.27.1", 13 | ] 14 | 15 | license = { file = "LICENSE" } 16 | 17 | classifiers = [ 18 | "License :: OSI Approved :: MIT License", 19 | "Programming Language :: Python", 20 | "Programming Language :: Python :: 3", 21 | ] 22 | 23 | readme = "README.md" 24 | requires-python = ">= 3.6" 25 | 26 | [project.urls] 27 | "Homepage" = "https://github.com/omniinfer/python-sdk" 28 | "Bug Tracker" = "https://discord.gg/nzqq8UScpx" 29 | 30 | [build-system] 31 | requires = ["hatchling"] 32 | build-backend = "hatchling.build" 33 | 34 | [tool.rye] 35 | managed = true 36 | dev-dependencies = [ 37 | "pytest>=7.0.1", 38 | "pytest-dependency>=0.5.1", 39 | "build>=0.9.0", 40 | "twine>=3.8.0", 41 | "pillow>=8.4.0", 42 | ] 43 | 44 | [tool.hatch.metadata] 45 | allow-direct-references = true 46 | -------------------------------------------------------------------------------- /requirements-dev.lock: -------------------------------------------------------------------------------- 1 | # generated by rye 2 | # use `rye lock` or `rye sync` to update this lockfile 3 | # 4 | # last locked with the following flags: 5 | # pre: false 6 | # features: [] 7 | # all-features: false 8 | 9 | -e file:. 10 | bleach==6.0.0 11 | build==0.10.0 12 | certifi==2023.7.22 13 | cffi==1.15.1 14 | charset-normalizer==3.2.0 15 | cryptography==41.0.3 16 | dataclass-wizard==0.22.2 17 | docutils==0.20.1 18 | exceptiongroup==1.1.2 19 | idna==3.4 20 | importlib-metadata==6.7.0 21 | importlib-resources==5.12.0 22 | iniconfig==2.0.0 23 | jaraco-classes==3.2.3 24 | jeepney==0.8.0 25 | keyring==24.1.1 26 | markdown-it-py==2.2.0 27 | mdurl==0.1.2 28 | more-itertools==9.1.0 29 | packaging==23.1 30 | pillow==9.5.0 31 | pkginfo==1.9.6 32 | pluggy==1.2.0 33 | pycparser==2.21 34 | pygments==2.16.0 35 | pyproject-hooks==1.0.0 36 | pytest==7.4.0 37 | pytest-dependency==0.5.1 38 | readme-renderer==37.3 39 | requests==2.31.0 40 | requests-toolbelt==1.0.0 41 | rfc3986==2.0.0 42 | rich==13.5.2 43 | secretstorage==3.3.3 44 | six==1.16.0 45 | tomli==2.0.1 46 | twine==4.0.2 47 | typing-extensions==4.7.1 48 | urllib3==2.0.4 49 | webencodings==0.5.1 50 | zipp==3.15.0 51 | -------------------------------------------------------------------------------- /requirements.lock: -------------------------------------------------------------------------------- 1 | # generated by rye 2 | # use `rye lock` or `rye sync` to update this lockfile 3 | # 4 | # last locked with the following flags: 5 | # pre: false 6 | # features: [] 7 | # all-features: false 8 | 9 | -e file:. 10 | certifi==2023.7.22 11 | charset-normalizer==3.2.0 12 | dataclass-wizard==0.22.2 13 | idna==3.4 14 | requests==2.31.0 15 | typing-extensions==4.7.1 16 | urllib3==2.0.4 17 | -------------------------------------------------------------------------------- /src/omniinfer_client/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | 4 | from .omni import * 5 | from .proto import * 6 | from .utils import * 7 | -------------------------------------------------------------------------------- /src/omniinfer_client/exceptions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | 4 | class OmniError(Exception): 5 | pass 6 | 7 | 8 | class OmniResponseError(OmniError): 9 | pass 10 | 11 | 12 | class OmniTimeoutError(OmniError): 13 | pass 14 | -------------------------------------------------------------------------------- /src/omniinfer_client/omni.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | 4 | import logging 5 | 6 | from time import sleep 7 | 8 | from .version import __version__ 9 | 10 | from .exceptions import * 11 | from .proto import * 12 | 13 | import requests 14 | from . import settings 15 | 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class OmniClient: 21 | """OmniClient is the main entry point for interacting with the Omni API.""" 22 | 23 | def __init__(self, api_key): 24 | self.base_url = "https://api.omniinfer.io/v2" 25 | self.api_key = api_key 26 | self.session = requests.Session() 27 | 28 | if not self.api_key: 29 | raise ValueError("OMNI_API_KEY environment variable not set") 30 | 31 | # eg: {"all": [proto.ModelInfo], "checkpoint": [proto.ModelInfo], "lora": [proto.ModelInfo]} 32 | self._model_list_cache = None 33 | self._extra_headers = {} 34 | 35 | def set_extra_headers(self, headers: dict): 36 | self._extra_headers = headers 37 | 38 | def _get(self, api_path, params=None) -> dict: 39 | headers = { 40 | 'Accept': 'application/json', 41 | 'Content-Type': 'application/json', 42 | 'X-Omni-Key': self.api_key, 43 | 'User-Agent': "omniinfer-python-sdk/{}".format(__version__), 44 | 'Accept-Encoding': 'gzip, deflate', 45 | } 46 | headers.update(self._extra_headers) 47 | 48 | logger.debug(f"[GET] params: {params}") 49 | 50 | response = self.session.get( 51 | self.base_url + api_path, 52 | headers=headers, 53 | params=params, 54 | timeout=settings.DEFAULT_REQUEST_TIMEOUT, 55 | ) 56 | 57 | logger.debug(f"[GET] response: {response.content}") 58 | if response.status_code != 200: 59 | logger.error(f"Request failed: {response}") 60 | raise OmniResponseError( 61 | f"Request failed with status {response.status_code}") 62 | 63 | return response.json() 64 | 65 | def _post(self, api_path, data) -> dict: 66 | headers = { 67 | 'Accept': 'application/json', 68 | 'Content-Type': 'application/json', 69 | 'X-Omni-Key': self.api_key, 70 | 'User-Agent': "omniinfer-python-sdk/{}".format(__version__), 71 | 'Accept-Encoding': 'gzip, deflate', 72 | } 73 | headers.update(self._extra_headers) 74 | 75 | logger.debug(f"[POST] data: {data}") 76 | 77 | response = self.session.post( 78 | self.base_url + api_path, 79 | headers=headers, 80 | json=data, 81 | timeout=settings.DEFAULT_REQUEST_TIMEOUT, 82 | ) 83 | 84 | logger.debug(f"[POST] response: {response.content}") 85 | if response.status_code != 200: 86 | logger.error(f"Request failed: {response}") 87 | raise OmniResponseError( 88 | f"Request failed with status {response.status_code}") 89 | 90 | return response.json() 91 | 92 | def txt2img(self, request: Txt2ImgRequest) -> Txt2ImgResponse: 93 | """Asynchronously generate images from request 94 | 95 | Args: 96 | request (Txt2ImgRequest): The request object containing the text and image generation parameters. 97 | 98 | Returns: 99 | Txt2ImgResponse: The response object containing the task ID and status URL. 100 | """ 101 | response = self._post('/txt2img', request.to_dict()) 102 | 103 | return Txt2ImgResponse.from_dict(response) 104 | 105 | def progress(self, task_id: str) -> ProgressResponse: 106 | """Get the progress of a task. 107 | 108 | Args: 109 | task_id (str): The ID of the task to get the progress for. 110 | 111 | Returns: 112 | ProgressResponse: The response object containing the progress information for the task. 113 | """ 114 | response = self._get('/progress', { 115 | 'task_id': task_id, 116 | }) 117 | 118 | return ProgressResponse.from_dict(response) 119 | 120 | def img2img(self, request: Img2ImgRequest) -> Img2ImgResponse: 121 | """Asynchronously generate images from request 122 | 123 | Args: 124 | request (Img2ImgRequest): The request object containing the image and image generation parameters. 125 | 126 | Returns: 127 | Img2ImgResponse: The response object containing the task ID and status URL. 128 | """ 129 | response = self._post('/img2img', request.to_dict()) 130 | 131 | return Img2ImgResponse.from_dict(response) 132 | 133 | def wait_for_task(self, task_id, wait_for: int = 300, callback: callable = None) -> ProgressResponse: 134 | """Wait for a task to complete 135 | 136 | This method waits for a task to complete by periodically checking its progress. If the task is not completed within the specified time, an OmniTimeoutError is raised. 137 | 138 | Args: 139 | task_id (_type_): The ID of the task to wait for. 140 | wait_for (int, optional): The maximum time to wait for the task to complete, in seconds. Defaults to 300. 141 | 142 | Raises: 143 | OmniTimeoutError: If the task fails to complete within the specified time. 144 | 145 | Returns: 146 | ProgressResponse: The response object containing the progress information for the task. 147 | """ 148 | i = 0 149 | 150 | while i < wait_for: 151 | logger.info(f"Waiting for task {task_id} to complete") 152 | 153 | progress = self.progress(task_id) 154 | 155 | if callback and callable(callback): 156 | try: 157 | callback(progress) 158 | except Exception as e: 159 | logger.error(f"Task {task_id} progress callback failed: {e}") 160 | 161 | logger.info( 162 | f"Task {task_id} progress eta_relative: {progress.data.eta_relative}") 163 | 164 | if progress.data.status.finished(): 165 | logger.info(f"Task {task_id} completed") 166 | return progress 167 | 168 | sleep(settings.DEFAULT_POLL_INTERVAL) 169 | i += 1 170 | 171 | raise OmniTimeoutError( 172 | f"Task {task_id} failed to complete in {wait_for} seconds") 173 | 174 | def sync_txt2img(self, request: Txt2ImgRequest, download_images=True, callback: callable = None) -> ProgressResponse: 175 | """Synchronously generate images from request, optionally download images 176 | 177 | This method generates images synchronously from the given request object. If download_images is set to True, the generated images will be downloaded. 178 | 179 | Args: 180 | request (Txt2ImgRequest): The request object containing the input text and other parameters. 181 | download_images (bool, optional): Whether to download the generated images. Defaults to True. 182 | 183 | Raises: 184 | OmniResponseError: If the text to image generation fails. 185 | 186 | Returns: 187 | ProgressResponse: The response object containing the task status and generated images. 188 | """ 189 | response = self.txt2img(request) 190 | 191 | if response.data is None: 192 | raise OmniResponseError(f"Text to Image generation failed with response {response.msg}, code: {response.code}") 193 | 194 | res = self.wait_for_task(response.data.task_id, callback=callback) 195 | if download_images: 196 | res.download_images() 197 | return res 198 | 199 | def sync_img2img(self, request: Img2ImgRequest, download_images=True, callback: callable = None) -> ProgressResponse: 200 | """Synchronously generate images from request, optionally download images 201 | 202 | Args: 203 | request (Img2ImgRequest): The request object containing the input image and other parameters. 204 | download_images (bool, optional): Whether to download the generated images. Defaults to True. 205 | 206 | Returns: 207 | ProgressResponse: The response object containing the task status and generated images. 208 | """ 209 | response = self.img2img(request) 210 | 211 | if response.data is None: 212 | raise OmniResponseError(f"Image to Image generation failed with response {response.msg}, code: {response.code}") 213 | 214 | res = self.wait_for_task(response.data.task_id, callback=callback) 215 | if download_images: 216 | res.download_images() 217 | return res 218 | 219 | def sync_upscale(self, request: UpscaleRequest, download_images=True, callback: callable = None) -> ProgressResponse: 220 | """Syncronously upscale image from request, optionally download images 221 | 222 | Args: 223 | request (UpscaleRequest): _description_ 224 | download_images (bool, optional): _description_. Defaults to True. 225 | 226 | Returns: 227 | ProgressResponse: _description_ 228 | """ 229 | response = self.upscale(request) 230 | 231 | if response.data is None: 232 | raise OmniResponseError(f"Upscale failed with response {response.msg}, code: {response.code}") 233 | 234 | res = self.wait_for_task(response.data.task_id, callback=callback) 235 | if download_images: 236 | res.download_images() 237 | return res 238 | 239 | def upscale(self, request: UpscaleRequest) -> UpscaleResponse: 240 | """Upscale image 241 | 242 | This method sends a request to the Omni API to upscale an image using the specified parameters. 243 | 244 | Args: 245 | request (UpscaleRequest): An object containing the input image and other parameters. 246 | 247 | Returns: 248 | UpscaleResponse: An object containing the task status and the URL of the upscaled image. 249 | """ 250 | response = self._post('/upscale', request.to_dict()) 251 | 252 | return UpscaleResponse.from_dict(response) 253 | 254 | def models(self, refresh=False) -> ModelList: 255 | """Get list of models 256 | 257 | This method retrieves a list of models available in the Omni API. If the list has already been retrieved and 258 | `refresh` is False, the cached list will be returned. Otherwise, a new request will be made to the API to 259 | retrieve the list. 260 | 261 | Args: 262 | refresh (bool, optional): If True, a new request will be made to the API to retrieve the list of models. 263 | If False and the list has already been retrieved, the cached list will be returned. Defaults to False. 264 | 265 | Returns: 266 | ModelList: A list of models available in the Omni API. 267 | """ 268 | 269 | if (self._model_list_cache is None or len(self._model_list_cache) == 0) or refresh: 270 | res = self._get('/models') 271 | 272 | # TODO: fix this 273 | res_controlnet = self._get( 274 | '/models', params={'type': 'controlnet'}) 275 | res_vae = self._get('/models', params={'type': 'vae'}) 276 | 277 | tmp = [] 278 | tmp.extend(MoodelsResponse.from_dict(res).data.models) 279 | tmp.extend(MoodelsResponse.from_dict(res_controlnet).data.models) 280 | tmp.extend(MoodelsResponse.from_dict(res_vae).data.models) 281 | 282 | # In future /models maybe return all models, so we need to filter out duplicates 283 | tmp_set = set() 284 | models = [] 285 | for m in tmp: 286 | if m.sd_name not in tmp_set: 287 | tmp_set.add(m.sd_name) 288 | models.append(m) 289 | 290 | self._model_list_cache = ModelList(models) 291 | 292 | return self._model_list_cache 293 | -------------------------------------------------------------------------------- /src/omniinfer_client/proto.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | 4 | from dataclasses import dataclass, field 5 | from typing import Optional, List 6 | from .serializer import JSONe 7 | from enum import Enum 8 | from .utils import batch_download_images 9 | 10 | 11 | # --------------- ControlNet --------------- 12 | 13 | class ControlNetMode(Enum): 14 | BALANCED = 0 15 | PROMPT_IMPORTANCE = 1 16 | CONTROLNET_IMPORTANCE = 2 17 | 18 | def __str__(self): 19 | return self.name 20 | 21 | 22 | class ControlNetResizeMode(Enum): 23 | JUST_RESIZE = 0 24 | RESIZE_OR_CORP = 1 25 | RESIZE_AND_FILL = 2 26 | 27 | def __str__(self): 28 | return self.name 29 | 30 | 31 | class ControlNetPreprocessor(Enum): 32 | NULL = 'none' 33 | CANNY = 'canny' 34 | DEPTH = 'depth' 35 | DEPTH_LERES = 'depth_leres' 36 | DEPTH_LERES_PLUS_PLUS = 'depth_leres++' 37 | HED = 'hed' 38 | HED_SAFE = 'hed_safe' 39 | MEDIAPIPE_FACE = 'mediapipe_face' 40 | MLSD = 'mlsd' 41 | NORMAL_MAP = 'normal_map' 42 | OPENPOSE = 'openpose' 43 | OPENPOSE_HAND = 'openpose_hand' 44 | OPENPOSE_FACE = 'openpose_face' 45 | OPENPOSE_FACEONLY = 'openpose_faceonly' 46 | OPENPOSE_FULL = 'openpose_full' 47 | CLIP_VISION = 'clip_vision' 48 | COLOR = 'color' 49 | PIDINET = 'pidinet' 50 | PIDINET_SAFE = 'pidinet_safe' 51 | PIDINET_SKETCH = 'pidinet_sketch' 52 | PIDINET_SCRIBBLE = 'pidinet_scribble' 53 | SCRIBBLE_XDOG = 'scribble_xdog' 54 | SCRIBBLE_HED = 'scribble_hed' 55 | SEGMENTATION = 'segmentation' 56 | THRESHOLD = 'threshold' 57 | DEPTH_ZOE = 'depth_zoe' 58 | NORMAL_BAE = 'normal_bae' 59 | ONEFORMER_COCO = 'oneformer_coco' 60 | ONEFORMER_ADE20K = 'oneformer_ade20k' 61 | LINEART = 'lineart' 62 | LINEART_COARSE = 'lineart_coarse' 63 | LINEART_ANIME = 'lineart_anime' 64 | LINEART_STANDARD = 'lineart_standard' 65 | SHUFFLE = 'shuffle' 66 | TILE_RESAMPLE = 'tile_resample' 67 | INVERT = 'invert' 68 | LINEART_ANIME_DENOISE = 'lineart_anime_denoise' 69 | REFERENCE_ONLY = 'reference_only' 70 | REFERENCE_ADAIN = 'reference_adain' 71 | REFERENCE_ADAIN_PLUS_ATTN = 'reference_adain+attn' 72 | INPAINT = 'inpaint' 73 | INPAINT_ONLY = 'inpaint_only' 74 | INPAINT_ONLY_PLUS_LAMA = 'inpaint_only+lama' 75 | TILE_COLORFIX = 'tile_colorfix' 76 | TILE_COLORFIX_PLUS_SHARP = 'tile_colorfix+sharp' 77 | 78 | def __str__(self): 79 | return self.name 80 | 81 | 82 | @dataclass 83 | class ControlnetUnit(JSONe): 84 | model: str 85 | weight: Optional[float] = 1 86 | module: Optional[ControlNetPreprocessor] = ControlNetPreprocessor.NULL 87 | input_image: Optional[str] = None 88 | control_mode: Optional[ControlNetMode] = ControlNetMode.BALANCED 89 | resize_mode: Optional[ControlNetResizeMode] = ControlNetResizeMode.RESIZE_OR_CORP 90 | mask: Optional[str] = None 91 | processor_res: Optional[int] = 512 92 | threshold_a: Optional[int] = 64 93 | threshold_b: Optional[int] = 64 94 | guidance_start: Optional[float] = 0.0 95 | guidance_end: Optional[float] = 1.0 96 | pixel_perfect: Optional[bool] = False 97 | 98 | 99 | # --------------- Samplers --------------- 100 | @dataclass 101 | class Samplers: 102 | EULER_A = 'Euler a' 103 | EULER = 'Euler' 104 | LMS = 'LMS' 105 | HEUN = 'Heun' 106 | DPM2 = 'DPM2' 107 | DPM2_A = 'DPM2 a' 108 | DPM2_KARRAS = 'DPM2 Karras' 109 | DPM2_A_KARRAS = 'DPM2 a Karras' 110 | DPMPP_S_A = 'DPM++ 2S a' 111 | DPMPP_M = 'DPM++ 2M' 112 | DPMPP_SDE = 'DPM++ SDE' 113 | DPMPP_KARRAS = 'DPM++ Karras' 114 | DPMPP_S_A_KARRAS = 'DPM++ 2S a Karras' 115 | DPMPP_M_KARRAS = 'DPM++ 2M Karras' 116 | DPMPP_SDE_KARRAS = 'DPM++ SDE Karras' 117 | DDIM = 'DDIM' 118 | PLMS = 'PLMS' 119 | UNIPC = 'UniPC' 120 | 121 | 122 | # --------------- Refiner --------------- 123 | @dataclass 124 | class Refiner: 125 | checkpoint: str 126 | switch_at: float 127 | 128 | # --------------- Text2Image --------------- 129 | 130 | 131 | @dataclass 132 | class Txt2ImgRequest(JSONe): 133 | prompt: str 134 | negative_prompt: Optional[str] = None 135 | model_name: str = 'dreamshaper_5BakedVae.safetensors' 136 | sampler_name: str = None 137 | batch_size: int = 1 138 | n_iter: int = 1 139 | steps: int = 20 140 | cfg_scale: float = 7 141 | height: int = 512 142 | width: int = 512 143 | 144 | seed: Optional[int] = -1 145 | restore_faces: Optional[bool] = False 146 | sd_vae: Optional[str] = None 147 | clip_skip: Optional[int] = 1 148 | 149 | controlnet_units: Optional[List[ControlnetUnit]] = None 150 | controlnet_no_detectmap: Optional[bool] = False 151 | 152 | enable_hr: Optional[bool] = False 153 | hr_upscaler: Optional[str] = 'R-ESRGAN 4x+' 154 | hr_scale: Optional[float] = 2.0 155 | hr_resize_x: Optional[int] = None 156 | hr_resize_y: Optional[int] = None 157 | 158 | sd_refiner: Optional[Refiner] = None 159 | 160 | 161 | class Txt2ImgResponseCode(Enum): 162 | NORMAL = 0 163 | INTERNAL_ERROR = -1 164 | INVALID_JSON = 1 165 | MODEL_NOT_EXISTS = 2 166 | TASK_ID_NOT_EXISTS = 3 167 | INVALID_AUTH = 4 168 | HOST_UNAVAILABLE = 5 169 | PARAM_RANGE_ERROR = 6 170 | COST_BALANCE_ERROR = 7 171 | SAMPLER_NOT_EXISTS = 8 172 | TIMEOUT = 9 173 | 174 | UNKNOWN = 100 175 | 176 | @classmethod 177 | def _missing_(cls, number): 178 | return cls(cls.UNKNOWN) 179 | 180 | 181 | @dataclass 182 | class Txt2ImgResponseData(JSONe): 183 | task_id: str 184 | warn: Optional[str] = None 185 | 186 | 187 | @dataclass 188 | class Txt2ImgResponse(JSONe): 189 | code: Txt2ImgResponseCode 190 | msg: str 191 | data: Optional[Txt2ImgResponseData] = None 192 | 193 | 194 | # --------------- Image2Image --------------- 195 | 196 | 197 | @dataclass 198 | class Img2ImgRequest(JSONe): 199 | model_name: str = 'dreamshaper_5BakedVae.safetensors' 200 | sampler_name: str = None 201 | init_images: List[str] = None 202 | mask: Optional[str] = None 203 | resize_mode: Optional[int] = 0 204 | denoising_strength: Optional[float] = 0.75 205 | cfg_scale: Optional[float] = None 206 | mask_blur: Optional[int] = 4 207 | inpainting_fill: Optional[int] = 1 208 | inpaint_full_res: Optional[int] = 0 209 | inpaint_full_res_padding: Optional[int] = 32 210 | inpainting_mask_invert: Optional[int] = 0 211 | initial_noise_multiplier: Optional[float] = 1.0 212 | prompt: Optional[str] = None 213 | seed: Optional[int] = None 214 | negative_prompt: Optional[str] = None 215 | batch_size: Optional[int] = 1 216 | n_iter: Optional[int] = 1 217 | steps: Optional[int] = 20 218 | width: Optional[int] = 1024 219 | height: Optional[int] = 1024 220 | restore_faces: Optional[bool] = False 221 | sd_vae: Optional[str] = None 222 | clip_skip: Optional[int] = 1 223 | 224 | controlnet_units: Optional[List[ControlnetUnit]] = None 225 | controlnet_no_detectmap: Optional[bool] = False 226 | 227 | sd_refiner: Optional[Refiner] = None 228 | 229 | 230 | class Img2ImgResponseCode(Enum): 231 | NORMAL = 0 232 | INTERNAL_ERROR = -1 233 | INVALID_JSON = 1 234 | MODEL_NOT_EXISTS = 2 235 | TASK_ID_NOT_EXISTS = 3 236 | INVALID_AUTH = 4 237 | HOST_UNAVAILABLE = 5 238 | PARAM_RANGE_ERROR = 6 239 | COST_BALANCE_ERROR = 7 240 | SAMPLER_NOT_EXISTS = 8 241 | TIMEOUT = 9 242 | 243 | UNKNOWN = 100 244 | 245 | @classmethod 246 | def _missing_(cls, number): 247 | return cls(cls.UNKNOWN) 248 | 249 | 250 | @dataclass 251 | class Img2ImgResponseData(JSONe): 252 | task_id: str 253 | warn: Optional[str] = None 254 | 255 | 256 | @dataclass 257 | class Img2ImgResponse(JSONe): 258 | code: Img2ImgResponseCode 259 | msg: str 260 | data: Optional[Img2ImgResponseData] = None 261 | 262 | # --------------- Progress --------------- 263 | 264 | 265 | class ProgressResponseStatusCode(Enum): 266 | INITIALIZING = 0 267 | RUNNING = 1 268 | SUCCESSFUL = 2 269 | FAILED = 3 270 | TIMEOUT = 4 271 | 272 | UNKNOWN = 100 273 | 274 | @classmethod 275 | def _missing_(cls, number): 276 | return cls(cls.UNKNOWN) 277 | 278 | def finished(self): 279 | return self in (ProgressResponseStatusCode.SUCCESSFUL, ProgressResponseStatusCode.FAILED, ProgressResponseStatusCode.TIMEOUT) 280 | 281 | 282 | @dataclass 283 | class ProgressData(JSONe): 284 | status: ProgressResponseStatusCode 285 | progress: int 286 | eta_relative: int 287 | imgs: Optional[List[str]] = None 288 | imgs_bytes: Optional[List[str]] = None 289 | # info: Optional[Dict[str, str]] = None # We'll handle this field separately 290 | failed_reason: Optional[str] = "" 291 | current_images: Optional[List[str]] = None 292 | submit_time: Optional[str] = "" 293 | execution_time: Optional[str] = "" 294 | txt2img_time: Optional[str] = "" 295 | finish_time: Optional[str] = "" 296 | 297 | 298 | class ProgressResponseCode(Enum): 299 | NORMAL = 0 300 | INTERNAL_ERROR = -1 301 | INVALID_JSON = 1 302 | MODEL_NOT_EXISTS = 2 303 | TASK_ID_NOT_EXISTS = 3 304 | INVALID_AUTH = 4 305 | HOST_UNAVAILABLE = 5 306 | PARAM_RANGE_ERROR = 6 307 | COST_BALANCE_ERROR = 7 308 | SAMPLER_NOT_EXISTS = 8 309 | TIMEOUT = 9 310 | 311 | UNKNOWN = 100 312 | 313 | @classmethod 314 | def _missing_(cls, number): 315 | return cls(cls.UNKNOWN) 316 | 317 | 318 | @dataclass 319 | class ProgressResponse(JSONe): 320 | code: ProgressResponseCode 321 | data: Optional[ProgressData] = None 322 | msg: Optional[str] = "" 323 | 324 | def download_images(self): 325 | if self.data.imgs is not None and len(self.data.imgs) > 0: 326 | self.data.imgs_bytes = batch_download_images(self.data.imgs) 327 | 328 | # --------------- Upscale --------------- 329 | 330 | 331 | class UpscaleResizeMode(Enum): 332 | SCALE = 0 333 | SIZE = 1 334 | 335 | 336 | @dataclass 337 | class UpscaleRequest(JSONe): 338 | image: str 339 | upscaler_1: Optional[str] = 'R-ESRGAN 4x+' 340 | resize_mode: Optional[UpscaleResizeMode] = UpscaleResizeMode.SCALE 341 | upscaling_resize: Optional[float] = 2.0 342 | upscaling_resize_w: Optional[int] = None 343 | upscaling_resize_h: Optional[int] = None 344 | upscaling_crop: Optional[bool] = False 345 | 346 | upscaler_2: Optional[str] = None 347 | extras_upscaler_2_visibility: Optional[float] = None 348 | gfpgan_visibility: Optional[float] = None 349 | codeformer_visibility: Optional[float] = None 350 | codeformer_weight: Optional[float] = None 351 | 352 | 353 | class UpscaleResponseCode(Enum): 354 | NORMAL = 0 355 | INTERNAL_ERROR = -1 356 | INVALID_JSON = 1 357 | MODEL_NOT_EXISTS = 2 358 | TASK_ID_NOT_EXISTS = 3 359 | INVALID_AUTH = 4 360 | HOST_UNAVAILABLE = 5 361 | PARAM_RANGE_ERROR = 6 362 | COST_BALANCE_ERROR = 7 363 | SAMPLER_NOT_EXISTS = 8 364 | TIMEOUT = 9 365 | 366 | UNKNOWN = 100 367 | 368 | @classmethod 369 | def _missing_(cls, number): 370 | return cls(cls.UNKNOWN) 371 | 372 | 373 | @dataclass 374 | class UpscaleResponseData(JSONe): 375 | task_id: str 376 | warn: Optional[str] = None 377 | 378 | 379 | @dataclass 380 | class UpscaleResponse(JSONe): 381 | code: UpscaleResponseCode 382 | msg: str 383 | data: Optional[UpscaleResponseData] = None 384 | 385 | 386 | # --------------- Model --------------- 387 | 388 | 389 | class ModelType(Enum): 390 | CHECKPOINT = "checkpoint" 391 | LORA = "lora" 392 | VAE = "vae" 393 | CONTROLNET = "controlnet" 394 | TEXT_INVERSION = "textualinversion" 395 | UPSCALER = "upscaler" 396 | 397 | UNKNOWN = "unknown" 398 | 399 | @classmethod 400 | def _missing_(cls, number): 401 | return cls(cls.UNKNOWN) 402 | 403 | 404 | @dataclass 405 | class CivitaiImageMeta(JSONe): 406 | prompt: Optional[str] = None 407 | negative_prompt: Optional[str] = None 408 | sampler_name: Optional[str] = None 409 | steps: Optional[int] = None 410 | cfg_scale: Optional[int] = None 411 | seed: Optional[int] = None 412 | height: Optional[int] = None 413 | width: Optional[int] = None 414 | model_name: Optional[str] = None 415 | 416 | 417 | @dataclass 418 | class CivitaiImage(JSONe): 419 | url: str 420 | nsfw: str 421 | meta: Optional[CivitaiImageMeta] = None 422 | 423 | 424 | @dataclass 425 | class ModelInfo(JSONe): 426 | name: str 427 | hash: str 428 | civitai_version_id: int 429 | sd_name: str 430 | third_source: str 431 | download_status: int 432 | download_name: str 433 | dependency_status: int 434 | type: ModelType 435 | civitai_nsfw: Optional[bool] = False 436 | civitai_model_id: Optional[int] = 0 437 | civitai_link: Optional[str] = None 438 | civitai_images: Optional[List[CivitaiImage]] = field(default_factory=lambda: []) 439 | civitai_download_url: Optional[str] = None 440 | civitai_allow_commercial_use: Optional[bool] = True 441 | civitai_allow_different_license: Optional[bool] = True 442 | civitai_create_at: Optional[str] = None 443 | civitai_update_at: Optional[str] = None 444 | civitai_tags: Optional[str] = None 445 | civitai_download_count: Optional[int] = 0 446 | civitai_favorite_count: Optional[int] = 0 447 | civitai_comment_count: Optional[int] = 0 448 | civitai_rating_count: Optional[int] = 0 449 | civitai_rating: Optional[float] = 0.0 450 | omni_used_count: Optional[int] = None 451 | civitai_image_url: Optional[str] = None 452 | civitai_image_nsfw: Optional[bool] = False 453 | civitai_origin_image_url: Optional[str] = None 454 | civitai_image_prompt: Optional[str] = None 455 | civitai_image_negative_prompt: Optional[str] = None 456 | civitai_image_sampler_name: Optional[str] = None 457 | civitai_image_height: Optional[int] = None 458 | civitai_image_width: Optional[int] = None 459 | civitai_image_steps: Optional[int] = None 460 | civitai_image_cfg_scale: Optional[int] = None 461 | civitai_image_seed: Optional[int] = None 462 | 463 | 464 | @dataclass 465 | class ModelData(JSONe): 466 | models: List[ModelInfo] = None 467 | 468 | 469 | @dataclass 470 | class MoodelsResponse(JSONe): 471 | code: int 472 | msg: str 473 | data: Optional[ModelData] = field(default_factory=lambda: []) 474 | 475 | 476 | class ModelList(list): 477 | """A list of ModelInfo""" 478 | 479 | def __init__(self, *args, **kwargs): 480 | super().__init__(*args, **kwargs) 481 | 482 | def get_by_civitai_version_id(self, civitai_version_id: int): 483 | for model in self: 484 | if model.civitai_version_id == civitai_version_id: 485 | return model 486 | return None 487 | 488 | def get_by_name(self, name): 489 | for model in self: 490 | if model.name == name: 491 | return model 492 | return None 493 | 494 | def get_by_sd_name(self, sd_name): 495 | for model in self: 496 | if model.sd_name == sd_name: 497 | return model 498 | return None 499 | 500 | def list_civitai_tags(self) -> List[str]: 501 | s = set() 502 | for model in self: 503 | if model.civitai_tags: 504 | s.update(s.strip() 505 | for s in model.civitai_tags.split(",") if s.strip()) 506 | return list(s) 507 | 508 | def filter_by_civitai_tags(self, *tags): 509 | ret = [] 510 | for model in self: 511 | if model.civitai_tags: 512 | if set(tags).issubset(set(s.strip() for s in model.civitai_tags.split(","))): 513 | ret.append(model) 514 | return ModelList(ret) 515 | 516 | def filter_by_nsfw(self, nsfw: bool): 517 | return ModelList([model for model in self if model.civitai_nsfw == nsfw]) 518 | 519 | def filter_by_type(self, type): 520 | return ModelList([model for model in self if model.type == type]) 521 | 522 | def filter_by_civitai_model_id(self, civitai_model_id: int): 523 | return ModelList([model for model in self if model.civitai_model_id == civitai_model_id]) 524 | 525 | def filter_by_civitai_model_name(self, name: str): 526 | return ModelList([model for model in self if model.name == name]) 527 | 528 | def sort_by_civitai_download(self): 529 | return ModelList(sorted(self, key=lambda x: x.civitai_download_count, reverse=True)) 530 | 531 | def sort_by_civitai_rating(self): 532 | return ModelList(sorted(self, key=lambda x: x.civitai_rating, reverse=True)) 533 | 534 | def sort_by_civitai_favorite(self): 535 | return ModelList(sorted(self, key=lambda x: x.civitai_favorite, reverse=True)) 536 | 537 | def sort_by_civitai_comment(self): 538 | return ModelList(sorted(self, key=lambda x: x.civitai_comment, reverse=True)) 539 | -------------------------------------------------------------------------------- /src/omniinfer_client/serializer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | 4 | from dataclass_wizard import JSONWizard, DumpMeta 5 | 6 | 7 | class JSONe(JSONWizard): 8 | def __init_subclass__(cls, **kwargs): 9 | super().__init_subclass__(**kwargs) 10 | DumpMeta(key_transform='SNAKE').bind_to(cls) 11 | -------------------------------------------------------------------------------- /src/omniinfer_client/settings.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | 4 | DEFAULT_REQUEST_TIMEOUT = 120 5 | DEFAULT_DOWNLOAD_IMAGE_ATTEMPTS = 5 6 | DEFAULT_DOWNLOAD_ONE_IMAGE_TIMEOUT = 30 7 | DEFAULT_POLL_INTERVAL = 0.5 -------------------------------------------------------------------------------- /src/omniinfer_client/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | 4 | from multiprocessing.pool import ThreadPool 5 | 6 | import base64 7 | import logging 8 | 9 | import requests 10 | 11 | from . import settings 12 | from .proto import * 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def batch_download_images(image_links): 18 | def _download(image_link): 19 | attempts = settings.DEFAULT_DOWNLOAD_IMAGE_ATTEMPTS 20 | while attempts > 0: 21 | try: 22 | response = requests.get( 23 | image_link, timeout=settings.DEFAULT_DOWNLOAD_ONE_IMAGE_TIMEOUT) 24 | return response.content 25 | except Exception: 26 | logger.warning("Failed to download image, retrying...") 27 | attempts -= 1 28 | return None 29 | 30 | pool = ThreadPool() 31 | applied = [] 32 | for img_url in image_links: 33 | applied.append(pool.apply_async(_download, (img_url, ))) 34 | ret = [r.get() for r in applied] 35 | return [_ for _ in ret if _ is not None] 36 | 37 | 38 | def save_image(image_bytes, name): 39 | with open(name, "wb") as f: 40 | f.write(image_bytes) 41 | 42 | 43 | def read_image(name): 44 | with open(name, "rb") as f: 45 | return f.read() 46 | 47 | 48 | def read_image_to_base64(name): 49 | with open(name, "rb") as f: 50 | return base64.b64encode(f.read()).decode('utf-8') 51 | 52 | 53 | def add_lora_to_prompt(prompt: str, lora_name: str, weight: float = 1.0) -> str: 54 | prompt_split = [s.strip() for s in prompt.split(",")] 55 | ret = [] 56 | replace = False 57 | for prompt_chunk in prompt_split: 58 | if prompt_chunk.startswith("".format(lora_name, weight)) 60 | replace = True 61 | else: 62 | ret.append(prompt_chunk) 63 | if not replace: 64 | ret.append("".format(lora_name, weight)) 65 | return ", ".join(ret) 66 | -------------------------------------------------------------------------------- /src/omniinfer_client/version.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | 4 | __version__ = "0.3.4" 5 | -------------------------------------------------------------------------------- /tests/test_basics.py: -------------------------------------------------------------------------------- 1 | from omniinfer_client import * 2 | from omniinfer_client.utils import save_image, read_image_to_base64 3 | import os 4 | from PIL import Image 5 | import io 6 | 7 | 8 | import pytest 9 | 10 | 11 | @pytest.mark.dependency() 12 | def test_txt2img_sync(): 13 | client = OmniClient(os.getenv('OMNI_API_KEY')) 14 | res = client.sync_txt2img(Txt2ImgRequest( 15 | prompt='a dog flying in the sky', 16 | batch_size=1, 17 | cfg_scale=7.5, 18 | sampler_name=Samplers.EULER_A, 19 | )) 20 | 21 | assert (res.data.status == ProgressResponseStatusCode.SUCCESSFUL) 22 | assert (len(res.data.imgs_bytes) == 1) 23 | 24 | test_path = os.path.join(os.path.abspath( 25 | os.path.dirname(__name__)), "tests/data") 26 | if not os.path.exists(test_path): 27 | os.makedirs(test_path) 28 | save_image(res.data.imgs_bytes[0], os.path.join( 29 | test_path, 'test_txt2img_sync.png')) 30 | 31 | 32 | @pytest.mark.dependency(depends=['test_txt2img_sync']) 33 | def test_img2img_sync(): 34 | client = OmniClient(os.getenv('OMNI_API_KEY')) 35 | init_image = os.path.join(os.path.abspath( 36 | os.path.dirname(__name__)), "tests/data/test_txt2img_sync.png") 37 | init_image_base64 = read_image_to_base64(init_image) 38 | 39 | res = client.sync_img2img(Img2ImgRequest( 40 | prompt='a dog flying in the sky', 41 | batch_size=1, 42 | cfg_scale=7.5, 43 | sampler_name=Samplers.EULER_A, 44 | init_images=[init_image_base64] 45 | )) 46 | 47 | assert (res.data.status == ProgressResponseStatusCode.SUCCESSFUL) 48 | assert (len(res.data.imgs_bytes) == 1) 49 | 50 | test_path = os.path.join(os.path.abspath( 51 | os.path.dirname(__name__)), "tests/data") 52 | if not os.path.exists(test_path): 53 | os.makedirs(test_path) 54 | save_image(res.data.imgs_bytes[0], os.path.join( 55 | test_path, 'test_img2img_sync.png')) 56 | 57 | 58 | @pytest.mark.dependency(depends=['test_img2img_sync']) 59 | def test_txt2img_controlnet(): 60 | client = OmniClient(os.getenv('OMNI_API_KEY')) 61 | init_image = os.path.join(os.path.abspath(os.path.dirname( 62 | __name__)), "tests/data/test_txt2img_sync.png") 63 | init_image_base64 = read_image_to_base64(init_image) 64 | 65 | client = OmniClient(os.getenv('OMNI_API_KEY')) 66 | request = Txt2ImgRequest( 67 | prompt='a dog flying in the sky', 68 | batch_size=1, 69 | cfg_scale=7.5, 70 | sampler_name=Samplers.EULER_A, 71 | controlnet_units=[ 72 | ControlnetUnit( 73 | input_image=init_image_base64, 74 | model='control_v11p_sd15_canny', 75 | module='canny', 76 | ), 77 | ] 78 | ) 79 | 80 | res = client.sync_txt2img(request) 81 | test_path = os.path.join(os.path.abspath( 82 | os.path.dirname(__name__)), "tests/data") 83 | 84 | if not os.path.exists(test_path): 85 | os.makedirs(test_path) 86 | 87 | assert len(res.data.imgs_bytes) == 2 88 | 89 | save_image(res.data.imgs_bytes[1], os.path.join( 90 | test_path, 'test_txt2img_controlnet_processor.png')) 91 | save_image(res.data.imgs_bytes[0], os.path.join( 92 | test_path, 'test_txt2img_controlnet_result.png')) 93 | 94 | 95 | @pytest.mark.dependency(depends=['test_img2img_sync']) 96 | def test_img2img_controlnet(): 97 | client = OmniClient(os.getenv('OMNI_API_KEY')) 98 | init_image = os.path.join(os.path.abspath( 99 | os.path.dirname(__name__)), "tests/data/test_txt2img_sync.png") 100 | init_image_base64 = read_image_to_base64(init_image) 101 | 102 | res = client.sync_img2img(Img2ImgRequest( 103 | prompt='a dog flying in the sky', 104 | batch_size=1, 105 | cfg_scale=7.5, 106 | sampler_name=Samplers.EULER_A, 107 | init_images=[init_image_base64], 108 | controlnet_units=[ 109 | ControlnetUnit( 110 | input_image=init_image_base64, 111 | model='control_v11p_sd15_canny', 112 | module='canny', 113 | ), 114 | ] 115 | )) 116 | 117 | assert (res.data.status == ProgressResponseStatusCode.SUCCESSFUL) 118 | assert len(res.data.imgs_bytes) == 2 119 | 120 | test_path = os.path.join(os.path.abspath( 121 | os.path.dirname(__name__)), "tests/data") 122 | if not os.path.exists(test_path): 123 | os.makedirs(test_path) 124 | 125 | save_image(res.data.imgs_bytes[1], os.path.join( 126 | test_path, 'test_img2img_controlnet_processor.png')) 127 | save_image(res.data.imgs_bytes[0], os.path.join( 128 | test_path, 'test_img2img_controlnet_result.png')) 129 | 130 | 131 | def test_txt2img_upscale_2x(): 132 | client = OmniClient(os.getenv('OMNI_API_KEY')) 133 | res = client.sync_txt2img(Txt2ImgRequest( 134 | model_name='dreamshaper_8_93211.safetensors', 135 | prompt='a dog flying in the sky', 136 | width=512, 137 | height=512, 138 | batch_size=1, 139 | cfg_scale=7.5, 140 | sampler_name=Samplers.EULER_A, 141 | enable_hr=True, 142 | hr_scale=2.0 143 | )) 144 | 145 | assert (res.data.status == ProgressResponseStatusCode.SUCCESSFUL) 146 | assert (len(res.data.imgs_bytes) == 1) 147 | 148 | test_path = os.path.join(os.path.abspath( 149 | os.path.dirname(__name__)), "tests/data") 150 | if not os.path.exists(test_path): 151 | os.makedirs(test_path) 152 | img = Image.open(io.BytesIO(res.data.imgs_bytes[0])) 153 | assert img.size == (1024, 1024) 154 | 155 | 156 | def test_txt2img_upscale_specify_size(): 157 | client = OmniClient(os.getenv('OMNI_API_KEY')) 158 | res = client.sync_txt2img(Txt2ImgRequest( 159 | model_name='dreamshaper_8_93211.safetensors', 160 | prompt='a dog flying in the sky', 161 | width=512, 162 | height=512, 163 | batch_size=1, 164 | cfg_scale=7.5, 165 | sampler_name=Samplers.EULER_A, 166 | enable_hr=True, 167 | hr_resize_x=768, 168 | hr_resize_y=768 169 | )) 170 | 171 | assert (res.data.status == ProgressResponseStatusCode.SUCCESSFUL) 172 | assert (len(res.data.imgs_bytes) == 1) 173 | 174 | test_path = os.path.join(os.path.abspath( 175 | os.path.dirname(__name__)), "tests/data") 176 | if not os.path.exists(test_path): 177 | os.makedirs(test_path) 178 | img = Image.open(io.BytesIO(res.data.imgs_bytes[0])) 179 | assert img.size == (768, 768) 180 | 181 | 182 | def test_upscale_2x(): 183 | client = OmniClient(os.getenv('OMNI_API_KEY')) 184 | res = client.sync_img2img(Img2ImgRequest( 185 | model_name='dreamshaper_8_93211.safetensors', 186 | prompt='a dog flying in the sky', 187 | width=512, 188 | height=512, 189 | batch_size=1, 190 | cfg_scale=7.5, 191 | sampler_name=Samplers.EULER_A, 192 | )) 193 | 194 | assert (res.data.status == ProgressResponseStatusCode.SUCCESSFUL) 195 | assert (len(res.data.imgs_bytes) == 1) 196 | 197 | image = base64.b64encode(res.data.imgs_bytes[0]).decode('utf-8') 198 | upscale_req = UpscaleRequest( 199 | image=image, 200 | resize_mode=UpscaleResizeMode.SCALE, 201 | upscaling_resize=2 202 | ) 203 | upscale_res = client.sync_upscale(upscale_req) 204 | assert (upscale_res.data.status == ProgressResponseStatusCode.SUCCESSFUL) 205 | assert (len(upscale_res.data.imgs_bytes) == 1) 206 | img = Image.open(io.BytesIO(upscale_res.data.imgs_bytes[0])) 207 | assert img.size == (1024, 1024) 208 | 209 | 210 | def test_upscale_specify_size(): 211 | client = OmniClient(os.getenv('OMNI_API_KEY')) 212 | res = client.sync_img2img(Img2ImgRequest( 213 | model_name='dreamshaper_8_93211.safetensors', 214 | prompt='a dog flying in the sky', 215 | width=512, 216 | height=512, 217 | batch_size=1, 218 | cfg_scale=7.5, 219 | sampler_name=Samplers.EULER_A, 220 | )) 221 | 222 | assert (res.data.status == ProgressResponseStatusCode.SUCCESSFUL) 223 | assert (len(res.data.imgs_bytes) == 1) 224 | 225 | image = base64.b64encode(res.data.imgs_bytes[0]).decode('utf-8') 226 | upscale_req = UpscaleRequest( 227 | image=image, 228 | resize_mode=UpscaleResizeMode.SIZE, 229 | upscaling_resize_h=768, 230 | upscaling_resize_w=768 231 | ) 232 | upscale_res = client.sync_upscale(upscale_req) 233 | assert (upscale_res.data.status == ProgressResponseStatusCode.SUCCESSFUL) 234 | assert (len(upscale_res.data.imgs_bytes) == 1) 235 | img = Image.open(io.BytesIO(upscale_res.data.imgs_bytes[0])) 236 | assert img.size == (768, 768) 237 | 238 | 239 | def test_upscale_multiple_upscaler(): 240 | client = OmniClient(os.getenv('OMNI_API_KEY')) 241 | res = client.sync_img2img(Img2ImgRequest( 242 | model_name='dreamshaper_8_93211.safetensors', 243 | prompt='a dog flying in the sky', 244 | width=512, 245 | height=512, 246 | batch_size=1, 247 | cfg_scale=7.5, 248 | sampler_name=Samplers.EULER_A, 249 | )) 250 | 251 | assert (res.data.status == ProgressResponseStatusCode.SUCCESSFUL) 252 | assert (len(res.data.imgs_bytes) == 1) 253 | 254 | image = base64.b64encode(res.data.imgs_bytes[0]).decode('utf-8') 255 | upscale_req = UpscaleRequest( 256 | image=image, 257 | upscaler_1='R-ESRGAN 4x+', 258 | resize_mode=UpscaleResizeMode.SIZE, 259 | upscaling_resize_h=768, 260 | upscaling_resize_w=768, 261 | upscaler_2='Nearest', 262 | gfpgan_visibility=1, 263 | ) 264 | upscale_res = client.sync_upscale(upscale_req) 265 | assert (upscale_res.data.status == ProgressResponseStatusCode.SUCCESSFUL) 266 | assert (len(upscale_res.data.imgs_bytes) == 1) 267 | img = Image.open(io.BytesIO(upscale_res.data.imgs_bytes[0])) 268 | assert img.size == (768, 768) 269 | 270 | 271 | def test_txt2img_custom_headers(): 272 | client = OmniClient(os.getenv('OMNI_API_KEY')) 273 | client.set_extra_headers({"User-Agent": "test-custom-user-agent"}) 274 | 275 | res = client.sync_img2img(Img2ImgRequest( 276 | model_name='dreamshaper_8_93211.safetensors', 277 | prompt='a dog flying in the sky', 278 | width=512, 279 | height=512, 280 | batch_size=1, 281 | cfg_scale=7.5, 282 | sampler_name=Samplers.EULER_A, 283 | )) 284 | 285 | assert (res.data.status == ProgressResponseStatusCode.SUCCESSFUL) 286 | assert (len(res.data.imgs_bytes) == 1) 287 | 288 | 289 | def test_txt2img_with_callback(): 290 | client = OmniClient(os.getenv('OMNI_API_KEY')) 291 | 292 | def callback(res: ProgressResponse): 293 | assert isinstance(res.data.progress, float) 294 | 295 | res = client.sync_txt2img(Txt2ImgRequest( 296 | model_name='dreamshaper_8_93211.safetensors', 297 | prompt='a dog flying in the sky', 298 | width=512, 299 | height=512, 300 | batch_size=1, 301 | cfg_scale=7.5, 302 | sampler_name=Samplers.EULER_A, 303 | ), callback=callback) 304 | 305 | assert (res.data.status == ProgressResponseStatusCode.SUCCESSFUL) 306 | assert (len(res.data.imgs_bytes) == 1) 307 | 308 | 309 | def test_txt2img_sdxl_with_refiner(): 310 | client = OmniClient(os.getenv('OMNI_API_KEY')) 311 | res = client.sync_txt2img(Txt2ImgRequest( 312 | model_name='sd_xl_base_1.0.safetensors', 313 | prompt='a dog flying in the sky', 314 | width=1024, 315 | height=1024, 316 | batch_size=1, 317 | cfg_scale=7.5, 318 | sampler_name=Samplers.EULER_A, 319 | sd_refiner=Refiner( 320 | checkpoint='sd_xl_refiner_1.0.safetensors', 321 | switch_at=0.5, 322 | ) 323 | )) 324 | 325 | assert (res.data.status == ProgressResponseStatusCode.SUCCESSFUL) 326 | assert (len(res.data.imgs_bytes) == 1) 327 | 328 | test_path = os.path.join(os.path.abspath( 329 | os.path.dirname(__name__)), "tests/data") 330 | if not os.path.exists(test_path): 331 | os.makedirs(test_path) 332 | save_image(res.data.imgs_bytes[0], os.path.join( 333 | test_path, 'test_sdxl_txt2img_refienr.png')) 334 | 335 | 336 | def test_img2img_sdxl_with_refiner(): 337 | client = OmniClient(os.getenv('OMNI_API_KEY')) 338 | res = client.sync_txt2img(Txt2ImgRequest( 339 | model_name='sd_xl_base_1.0.safetensors', 340 | prompt='a dog flying in the sky', 341 | width=1024, 342 | height=1024, 343 | batch_size=1, 344 | cfg_scale=7.5, 345 | sampler_name=Samplers.EULER_A, 346 | sd_refiner=Refiner( 347 | checkpoint='sd_xl_refiner_1.0.safetensors', 348 | switch_at=0.5, 349 | ) 350 | )) 351 | 352 | assert (res.data.status == ProgressResponseStatusCode.SUCCESSFUL) 353 | assert (len(res.data.imgs_bytes) == 1) 354 | 355 | base64_image = base64.b64encode(res.data.imgs_bytes[0]).decode('utf-8') 356 | 357 | res = client.sync_img2img(Img2ImgRequest( 358 | prompt='a dog flying in the sky', 359 | model_name='sd_xl_base_1.0.safetensors', 360 | batch_size=1, 361 | cfg_scale=7.5, 362 | sampler_name=Samplers.EULER_A, 363 | init_images=[base64_image], 364 | sd_refiner=Refiner( 365 | checkpoint='sd_xl_refiner_1.0.safetensors', 366 | switch_at=0.5, 367 | ) 368 | )) 369 | 370 | assert (res.data.status == ProgressResponseStatusCode.SUCCESSFUL) 371 | assert (len(res.data.imgs_bytes) == 1) 372 | 373 | test_path = os.path.join(os.path.abspath( 374 | os.path.dirname(__name__)), "tests/data") 375 | if not os.path.exists(test_path): 376 | os.makedirs(test_path) 377 | save_image(res.data.imgs_bytes[0], os.path.join( 378 | test_path, 'test_sdxl_img2img_refienr.png')) 379 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | from omniinfer_client import * 2 | import os 3 | 4 | 5 | def test_model_api(): 6 | client = OmniClient(os.getenv('OMNI_API_KEY')) 7 | models = client.models() 8 | assert all([m.civitai_nsfw is True for m in models.filter_by_nsfw(True)]) 9 | assert all([m.civitai_nsfw is False for m in models.filter_by_nsfw(False)]) 10 | 11 | assert len(models. \ 12 | filter_by_type(ModelType.LORA). \ 13 | filter_by_nsfw(False). \ 14 | filter_by_civitai_tags('anime'). \ 15 | sort_by_civitai_rating()) > 0 16 | 17 | assert len(models.filter_by_type(ModelType.CHECKPOINT)) > 0 18 | assert len(models.filter_by_type(ModelType.LORA)) > 0 19 | assert len(models.filter_by_type(ModelType.TEXT_INVERSION)) > 0 20 | assert len(models.filter_by_type(ModelType.CONTROLNET)) > 0 21 | --------------------------------------------------------------------------------