├── .python-version
├── examples
│   ├── fixtures
│   │   └── qrcode.png
│   ├── reimagine.py
│   ├── remove-text.py
│   ├── replace-background.py
│   ├── remove-background.py
│   ├── img2video.py
│   ├── merge-face.py
│   ├── cleanup.py
│   ├── txt2img-with-hiresfix.py
│   ├── txt2video.py
│   ├── model-search.py
│   ├── inpainting.py
│   ├── controlnet.py
│   ├── img2img.py
│   ├── instantid.py
│   ├── txt2img-with-lora.py
│   └── txt2img-with-refiner.py
├── src
│   └── novita_client
│       ├── version.py
│       ├── __init__.py
│       ├── settings.py
│       ├── exceptions.py
│       ├── serializer.py
│       ├── utils.py
│       ├── novita.py
│       └── proto.py
├── Makefile
├── requirements.lock
├── tests
│   ├── test_model.py
│   ├── test_basics.py
│   └── test_enhance.py
├── generate-readme.sh
├── LICENSE
├── pyproject.toml
├── .gitignore
├── requirements-dev.lock
└── README.md

/.python-version:
--------------------------------------------------------------------------------
3.11.6
--------------------------------------------------------------------------------
/examples/fixtures/qrcode.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/novitalabs/python-sdk/HEAD/examples/fixtures/qrcode.png
--------------------------------------------------------------------------------
/src/novita_client/version.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: UTF-8 -*-

__version__ = "0.5.0"
--------------------------------------------------------------------------------
/src/novita_client/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: UTF-8 -*-

from .novita import *
from .proto import *
from .utils import *
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
build:
	@rm -rf dist
	@.venv/bin/python3 -m build

upload: build
	@.venv/bin/python3 -m twine upload dist/*

docs:
	@cd docs && make html
--------------------------------------------------------------------------------
/src/novita_client/settings.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: UTF-8 -*-

DEFAULT_REQUEST_TIMEOUT = 120
DEFAULT_DOWNLOAD_IMAGE_ATTEMPTS = 5
DEFAULT_DOWNLOAD_ONE_IMAGE_TIMEOUT = 30
DEFAULT_POLL_INTERVAL = 0.5
--------------------------------------------------------------------------------
/src/novita_client/exceptions.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: UTF-8 -*-

class NovitaError(Exception):
    pass


class NovitaResponseError(NovitaError):
    pass


class NovitaTimeoutError(NovitaError):
    pass
--------------------------------------------------------------------------------
/src/novita_client/serializer.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: UTF-8 -*-

from dataclass_wizard import DumpMeta, JSONWizard


class JSONe(JSONWizard):
    # Every subclass is automatically bound to a snake_case dump transform,
    # so request/response dataclasses serialize with the API's JSON field names.
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        DumpMeta(key_transform='SNAKE').bind_to(cls)
--------------------------------------------------------------------------------
/examples/reimagine.py:
--------------------------------------------------------------------------------
import os

from novita_client import NovitaClient
from novita_client.utils import base64_to_image

client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None))
res = client.reimagine(
    image="/home/anyisalin/develop/novita-client-python/examples/doodle-generated.png"
)

base64_to_image(res.image_file).save("./reimagine.png")
--------------------------------------------------------------------------------
/examples/remove-text.py:
--------------------------------------------------------------------------------
import os

from novita_client import NovitaClient
from novita_client.utils import base64_to_image

client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None))
res = client.remove_text(
    image="https://images.uiiiuiii.com/wp-content/uploads/2023/07/i-banner-20230714-1.jpg"
)

base64_to_image(res.image_file).save("./remove_text.png")
--------------------------------------------------------------------------------
/examples/replace-background.py:
--------------------------------------------------------------------------------
import os

from novita_client import NovitaClient
from novita_client.utils import base64_to_image

client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None))
res = client.replace_background(
    image="./telegram-cloud-photo-size-2-5408823814353177899-y.jpg",
    prompt="in living room, Christmas tree",
)
base64_to_image(res.image_file).save("./replace_background.png")
--------------------------------------------------------------------------------
/examples/remove-background.py:
--------------------------------------------------------------------------------
import os

from novita_client import NovitaClient
from novita_client.utils import base64_to_image

client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None))
res = client.remove_background(
    image="https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png",
)
base64_to_image(res.image_file).save("./remove_background.png")
--------------------------------------------------------------------------------
/examples/img2video.py:
--------------------------------------------------------------------------------
import os

from novita_client import NovitaClient
from novita_client.utils import base64_to_image

client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None))
res = client.img2video(
    model_name="SVD-XT",
    steps=30,
    frames_num=25,
    image="https://replicate.delivery/pbxt/JvLi9smWKKDfQpylBYosqQRfPKZPntuAziesp0VuPjidq61n/rocket.png",
    enable_frame_interpolation=True
)


with open("test.mp4", "wb") as f:
    f.write(res.video_bytes[0])
--------------------------------------------------------------------------------
/examples/merge-face.py:
--------------------------------------------------------------------------------
import os

from novita_client import NovitaClient
from novita_client.utils import base64_to_image

client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None))
res = client.merge_face(
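    # merge_face composites the face from face_image onto the person in image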
image="https://toppng.com/uploads/preview/cut-out-people-png-personas-en-formato-11563277290kozkuzsos5.png", 9 | face_image="https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQDy7sXtuvCNUQoQZvTbLRbX6qK9_kP3PlQfg&s", 10 | enterprise_plan=False, 11 | ) 12 | 13 | base64_to_image(res.image_file).save("./merge_face.png") 14 | -------------------------------------------------------------------------------- /examples/cleanup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from novita_client import NovitaClient 4 | from novita_client.utils import base64_to_image 5 | 6 | client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None)) 7 | res = client.cleanup( 8 | image="https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png", 9 | mask="https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" 10 | ) 11 | 12 | base64_to_image(res.image_file).save("./cleanup.png") 13 | -------------------------------------------------------------------------------- /requirements.lock: -------------------------------------------------------------------------------- 1 | # generated by rye 2 | # use `rye lock` or `rye sync` to update this lockfile 3 | # 4 | # last locked with the following flags: 5 | # pre: false 6 | # features: [] 7 | # all-features: false 8 | # with-sources: false 9 | 10 | -e file:. 11 | certifi==2023.7.22 12 | # via requests 13 | charset-normalizer==3.2.0 14 | # via requests 15 | dataclass-wizard==0.22.2 16 | # via novita-client 17 | idna==3.4 18 | # via requests 19 | pillow==10.2.0 20 | # via novita-client 21 | requests==2.31.0 22 | # via novita-client 23 | urllib3==2.0.4 24 | # via requests 25 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | from novita_client import * 2 | import os 3 | 4 | 5 | def test_model_api(): 6 | client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None)) 7 | models = client.models_v3() 8 | assert all([m.is_nsfw is True for m in models.filter_by_nsfw(True)]) 9 | assert all([m.is_nsfw is False for m in models.filter_by_nsfw(False)]) 10 | 11 | assert len(models. \ 12 | filter_by_type(ModelType.LORA). 
\ 13 | filter_by_nsfw(False)) > 0 14 | 15 | assert len(models.filter_by_type(ModelType.CHECKPOINT)) > 0 16 | assert len(models.filter_by_type(ModelType.LORA)) > 0 17 | assert len(models.filter_by_type(ModelType.TEXT_INVERSION)) > 0 18 | assert len(models.filter_by_type(ModelType.CONTROLNET)) > 0 19 | -------------------------------------------------------------------------------- /examples/txt2img-with-hiresfix.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from novita_client import NovitaClient, Samplers, Txt2ImgV3HiresFix 4 | from novita_client.utils import base64_to_image 5 | 6 | from PIL import Image 7 | 8 | 9 | client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None)) 10 | res = client.txt2img_v3( 11 | model_name='dreamshaper_8_93211.safetensors', 12 | prompt="a cute girl", 13 | width=384, 14 | height=512, 15 | image_num=1, 16 | guidance_scale=7.5, 17 | seed=12345, 18 | sampler_name=Samplers.EULER_A, 19 | hires_fix=Txt2ImgV3HiresFix( 20 | # upscaler='Latent' 21 | target_width=768, 22 | target_height=1024, 23 | strength=0.5 24 | ) 25 | ) 26 | 27 | 28 | base64_to_image(res.images_encoded[0]).save("./txt2img_with_hiresfix.png") 29 | -------------------------------------------------------------------------------- /examples/txt2video.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from novita_client import NovitaClient 4 | from novita_client.utils import save_image 5 | 6 | client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None)) 7 | res = client.txt2video( 8 | model_name = "dreamshaper_8_93211.safetensors", 9 | prompts = [{ 10 | "prompt": "A girl, baby, portrait, 5 years old", 11 | "frames": 16,}, 12 | { 13 | "prompt": "A girl, child, portrait, 10 years old", 14 | "frames": 16, 15 | } 16 | ], 17 | steps = 20, 18 | guidance_scale = 10, 19 | height = 512, 20 | width = 768, 21 | clip_skip = 4, 22 | negative_prompt = "a rainy day", 23 | response_video_type = "mp4", 24 | ) 25 | save_image(res.video_bytes[0], 'output.mp4') 26 | -------------------------------------------------------------------------------- /generate-readme.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cat > README.md <<'EOF' 4 | # Novita AI Python SDK 5 | 6 | This SDK is based on the official [API documentation](https://docs.novita.ai/). 
7 | 8 | **Join our discord server for help:** 9 | 10 | [![](https://dcbadge.vercel.app/api/server/Mqx7nWYzDF)](https://discord.com/invite/Mqx7nWYzDF) 11 | 12 | ## Installation 13 | 14 | ```bash 15 | pip install novita-client 16 | ``` 17 | 18 | ## Examples 19 | 20 | - [fine tune example](https://colab.research.google.com/drive/1j_ii9TN67nuauvc3PiauwZnC2lT62tGF?usp=sharing) 21 | EOF 22 | 23 | 24 | for FILE in $(ls examples/ | grep py | sort -V); do 25 | NAME=$(echo "$FILE" | sed 's/.py//') 26 | echo "- [$NAME](./examples/$FILE)" >> README.md 27 | done 28 | 29 | echo "## Code Examples" >> README.md 30 | 31 | for FILE in $(ls examples/ | grep py | sort -V); do 32 | NAME=$(echo "$FILE" | sed 's/.py//') 33 | echo "" >> README.md 34 | echo "### $NAME" >> README.md 35 | echo "\`\`\`python" >> README.md 36 | cat examples/$FILE >> README.md 37 | echo "\`\`\`" >> README.md 38 | done -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 novita.ai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
--------------------------------------------------------------------------------
/examples/model-search.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: UTF-8 -*-

import os

from novita_client import NovitaClient, ModelType

# get your API key; refer to https://docs.novita.ai/get-started/
client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None))

# filter by model type
print("lora count", len(client.models().filter_by_type(ModelType.LORA)))
print("checkpoint count", len(client.models().filter_by_type(ModelType.CHECKPOINT)))
print("textinversion count", len(
    client.models().filter_by_type(ModelType.TEXT_INVERSION)))
print("vae count", len(client.models().filter_by_type(ModelType.VAE)))
print("controlnet count", len(client.models().filter_by_type(ModelType.CONTROLNET)))


# filter by civitai tags
client.models().filter_by_civitai_tags('anime')

# filter by nsfw
client.models().filter_by_nsfw(False)  # or True

# sort by civitai download
client.models().sort_by_civitai_download()

# chain filters
client.models().\
    filter_by_type(ModelType.CHECKPOINT).\
    filter_by_nsfw(False).\
    filter_by_civitai_tags('anime')
--------------------------------------------------------------------------------
/examples/inpainting.py:
--------------------------------------------------------------------------------
import os
import base64
from novita_client import NovitaClient
from novita_client.utils import base64_to_image

client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None))
res = client.inpainting(
    model_name="realisticVisionV40_v40VAE-inpainting_81543.safetensors",
    image="https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png",
    mask="https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png",
    seed=1,
    guidance_scale=15,
    steps=20,
    image_num=4,
    prompt="black rabbit",
    negative_prompt="white rabbit",
    sampler_name="Euler a",
    inpainting_full_res=1,
    inpainting_full_res_padding=32,
    inpainting_mask_invert=0,
    initial_noise_multiplier=1,
    mask_blur=1,
    clip_skip=1,
    strength=0.85,
)
os.makedirs("result/result_image", exist_ok=True)  # ensure the output directory exists
with open("result/result_image/inpaintingsdk.jpeg", "wb") as image_file:
    image_file.write(base64.b64decode(res.images_encoded[0]))
--------------------------------------------------------------------------------
/examples/controlnet.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: UTF-8 -*-

import os

from novita_client import NovitaClient, Img2ImgV3Request, Img2ImgV3ControlNetUnit, ControlnetUnit, Samplers, Img2ImgV3Embedding
from novita_client.utils import base64_to_image


client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None))
res = client.img2img_v3(
    input_image="https://img.freepik.com/premium-photo/close-up-dogs-face-with-big-smile-generative-ai_900101-62851.jpg",
    model_name="dreamshaper_8_93211.safetensors",
    prompt="a cute dog",
    sampler_name=Samplers.DPMPP_M_KARRAS,
    width=512,
    height=512,
    steps=30,
    controlnet_units=[
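        # each ControlNet unit conditions generation on a reference image; this one uses the SD 1.5 depth model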
Img2ImgV3ControlNetUnit( 21 | image_base64="https://img.freepik.com/premium-photo/close-up-dogs-face-with-big-smile-generative-ai_900101-62851.jpg", 22 | model_name="control_v11f1p_sd15_depth", 23 | strength=1.0 24 | ) 25 | ], 26 | embeddings=[Img2ImgV3Embedding(model_name=_) for _ in [ 27 | "BadDream_53202", 28 | ]], 29 | seed=-1, 30 | ) 31 | 32 | 33 | base64_to_image(res.images_encoded[0]).save("./img2img-controlnet.png") 34 | -------------------------------------------------------------------------------- /examples/img2img.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import os 3 | 4 | from novita_client import NovitaClient, Img2ImgV3ControlNetUnit, ControlNetPreprocessor, Img2ImgV3Embedding 5 | from novita_client.utils import base64_to_image, input_image_to_pil 6 | 7 | client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None)) 8 | res = client.img2img_v3( 9 | model_name="MeinaHentai_V5.safetensors", 10 | steps=30, 11 | height=512, 12 | width=512, 13 | input_image="https://img.freepik.com/premium-photo/close-up-dogs-face-with-big-smile-generative-ai_900101-62851.jpg", 14 | prompt="1 cute dog", 15 | strength=0.5, 16 | guidance_scale=7, 17 | embeddings=[Img2ImgV3Embedding(model_name=_) for _ in [ 18 | "bad-image-v2-39000", 19 | "verybadimagenegative_v1.3_21434", 20 | "BadDream_53202", 21 | "badhandv4_16755", 22 | "easynegative_8955.safetensors"]], 23 | seed=-1, 24 | sampler_name="DPM++ 2M Karras", 25 | clip_skip=2, 26 | # controlnet_units=[Img2ImgV3ControlNetUnit( 27 | # model_name="control_v11f1p_sd15_depth", 28 | # preprocessor="depth", 29 | # image_base64="./20240309-003206.jpeg", 30 | # strength=1.0 31 | # )] 32 | ) 33 | 34 | base64_to_image(res.images_encoded[0]).save("./img2img.png") 35 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "novita_client" 3 | version = "0.7.1" 4 | description = "Novita AI Python SDK" 5 | authors = [ 6 | { name = "Novita AI", email = "novitalabs@gmail.com" } 7 | ] 8 | 9 | dependencies = [ 10 | "dataclass_wizard>=0.22.2", 11 | "requests>=2.27.1", 12 | "pillow>=10.2.0", 13 | ] 14 | 15 | license = { file = "LICENSE" } 16 | 17 | classifiers = [ 18 | "License :: OSI Approved :: MIT License", 19 | "Programming Language :: Python", 20 | "Programming Language :: Python :: 3", 21 | ] 22 | 23 | readme = "README.md" 24 | requires-python = ">= 3.6" 25 | 26 | [project.urls] 27 | "Homepage" = "https://github.com/novita/python-sdk" 28 | "Bug Tracker" = "https://discord.gg/nzqq8UScpx" 29 | 30 | [build-system] 31 | requires = ["hatchling"] 32 | build-backend = "hatchling.build" 33 | 34 | [tool.rye] 35 | managed = true 36 | dev-dependencies = [ 37 | "pytest>=7.0.1", 38 | "pytest-dependency>=0.5.1", 39 | "build>=1.1.1", 40 | "twine>=5.0.0", 41 | "pillow>=8.4.0", 42 | "omegaconf>=2.3.0", 43 | "jupyterlab>=3.2.9", 44 | "pkginfo>=1.10.0", 45 | ] 46 | 47 | [tool.hatch.metadata] 48 | allow-direct-references = true 49 | 50 | [tool.pytest.ini_options] 51 | markers = [ 52 | "dependency: mark a test as dependency for other tests", 53 | ] 54 | -------------------------------------------------------------------------------- /tests/test_basics.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from novita_client import * 3 | 4 | 5 | @pytest.mark.dependency() 6 | def test_txt2img_v3(): 7 | 
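    # NOTE: this module never imports os explicitly; it relies on
    # "from novita_client import *" re-exporting it along with Samplers and base64_to_image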
client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None)) 8 | res = client.txt2img_v3( 9 | model_name='dreamshaper_8_93211.safetensors', 10 | prompt="a cute girl", 11 | width=384, 12 | height=512, 13 | image_num=1, 14 | guidance_scale=7.5, 15 | seed=12345, 16 | sampler_name=Samplers.EULER_A 17 | ) 18 | assert (len(res.images) == 1) 19 | test_path = os.path.join(os.path.abspath( 20 | os.path.dirname(__name__)), "tests/data") 21 | if not os.path.exists(test_path): 22 | os.makedirs(test_path) 23 | base64_to_image(res.images_encoded[0]).save(os.path.join( 24 | test_path, 'test_txt2img_v3.png')) 25 | 26 | 27 | @pytest.mark.dependency(depends=["test_txt2img_v3"]) 28 | def test_img2img_v3(): 29 | client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None)) 30 | init_image = os.path.join(os.path.abspath( 31 | os.path.dirname(__name__)), "tests/data/test_txt2img_v3.png") 32 | res = client.img2img_v3( 33 | model_name='dreamshaper_8_93211.safetensors', 34 | prompt="a cute girl", 35 | width=384, 36 | height=512, 37 | image_num=1, 38 | guidance_scale=7.5, 39 | seed=12345, 40 | steps=20, 41 | sampler_name=Samplers.EULER_A, 42 | input_image=init_image 43 | ) 44 | assert (len(res.images) == 1) 45 | -------------------------------------------------------------------------------- /examples/instantid.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from novita_client import NovitaClient, InstantIDControlnetUnit 4 | import base64 5 | 6 | 7 | 8 | if __name__ == '__main__': 9 | client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None)) 10 | 11 | res = client.instant_id( 12 | model_name="sdxlUnstableDiffusers_v8HEAVENSWRATH_133813.safetensors", 13 | face_images=[ 14 | "https://raw.githubusercontent.com/InstantID/InstantID/main/examples/yann-lecun_resize.jpg", 15 | ], 16 | prompt="Flat illustration, a Chinese a man, ancient style, wearing a red cloth, smile face, white skin, clean background, fireworks blooming, red lanterns", 17 | negative_prompt="(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green", 18 | id_strength=0.8, 19 | adapter_strength=0.8, 20 | steps=20, 21 | seed=42, 22 | width=1024, 23 | height=1024, 24 | controlnets=[ 25 | InstantIDControlnetUnit( 26 | model_name='controlnet-openpose-sdxl-1.0', 27 | strength=0.4, 28 | preprocessor='openpose', 29 | ), 30 | InstantIDControlnetUnit( 31 | model_name='controlnet-canny-sdxl-1.0', 32 | strength=0.3, 33 | preprocessor='canny', 34 | ), 35 | ], 36 | response_image_type='jpeg', 37 | enterprise_plan=False, 38 | ) 39 | 40 | print('res:', res) 41 | 42 | if hasattr(res, 'images_encoded'): 43 | with open(f"instantid.png", "wb") as f: 44 | f.write(base64.b64decode(res.images_encoded[0])) 45 | -------------------------------------------------------------------------------- /examples/txt2img-with-lora.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | 4 | import os 5 | from novita_client import NovitaClient, Txt2ImgV3LoRA, Samplers, ProgressResponseStatusCode, ModelType, add_lora_to_prompt, save_image 6 | from novita_client.utils import base64_to_image, input_image_to_pil 7 | from PIL import Image 8 | 9 | 10 | 
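# The comparison below renders the same prompt with a fixed seed twice, once
# without and once with the add_detail LoRA, then pastes both results side by side.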
def make_image_grid(images, rows: int, cols: int, resize: int = None): 11 | """ 12 | Prepares a single grid of images. Useful for visualization purposes. 13 | """ 14 | assert len(images) == rows * cols 15 | 16 | if resize is not None: 17 | images = [img.resize((resize, resize)) for img in images] 18 | 19 | w, h = images[0].size 20 | grid = Image.new("RGB", size=(cols * w, rows * h)) 21 | 22 | for i, img in enumerate(images): 23 | grid.paste(img, box=(i % cols * w, i // cols * h)) 24 | return grid 25 | 26 | 27 | client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None)) 28 | 29 | res1 = client.txt2img_v3( 30 | prompt="a photo of handsome man, close up", 31 | image_num=1, 32 | guidance_scale=7.0, 33 | sampler_name=Samplers.DPMPP_M_KARRAS, 34 | model_name="dreamshaper_8_93211.safetensors", 35 | height=512, 36 | width=512, 37 | seed=1024, 38 | ) 39 | res2 = client.txt2img_v3( 40 | prompt="a photo of handsome man, close up", 41 | image_num=1, 42 | guidance_scale=7.0, 43 | sampler_name=Samplers.DPMPP_M_KARRAS, 44 | model_name="dreamshaper_8_93211.safetensors", 45 | height=512, 46 | width=512, 47 | seed=1024, 48 | loras=[ 49 | Txt2ImgV3LoRA( 50 | model_name="add_detail_44319", 51 | strength=0.9, 52 | ) 53 | ] 54 | ) 55 | 56 | make_image_grid([base64_to_image(res1.images_encoded[0]), base64_to_image(res2.images_encoded[0])], 1, 2, 512).save("./txt2img-lora-compare.png") 57 | -------------------------------------------------------------------------------- /examples/txt2img-with-refiner.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from novita_client import NovitaClient, Txt2ImgV3Refiner, Samplers 4 | from novita_client.utils import base64_to_image 5 | from PIL import Image 6 | 7 | 8 | def make_image_grid(images, rows: int, cols: int, resize: int = None): 9 | """ 10 | Prepares a single grid of images. Useful for visualization purposes. 
11 | """ 12 | assert len(images) == rows * cols 13 | 14 | if resize is not None: 15 | images = [img.resize((resize, resize)) for img in images] 16 | 17 | w, h = images[0].size 18 | grid = Image.new("RGB", size=(cols * w, rows * h)) 19 | 20 | for i, img in enumerate(images): 21 | grid.paste(img, box=(i % cols * w, i // cols * h)) 22 | return grid 23 | 24 | 25 | client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None)) 26 | 27 | r1 = client.txt2img_v3( 28 | model_name='sd_xl_base_1.0.safetensors', 29 | prompt='a astronaut riding a bike on the moon', 30 | width=1024, 31 | height=1024, 32 | image_num=1, 33 | guidance_scale=7.5, 34 | sampler_name=Samplers.EULER_A, 35 | ) 36 | 37 | r2 = client.txt2img_v3( 38 | model_name='sd_xl_base_1.0.safetensors', 39 | prompt='a astronaut riding a bike on the moon', 40 | width=1024, 41 | height=1024, 42 | image_num=1, 43 | guidance_scale=7.5, 44 | sampler_name=Samplers.EULER_A, 45 | refiner=Txt2ImgV3Refiner( 46 | switch_at=0.7 47 | ) 48 | ) 49 | 50 | r3 = client.txt2img_v3( 51 | model_name='sd_xl_base_1.0.safetensors', 52 | prompt='a astronaut riding a bike on the moon', 53 | width=1024, 54 | height=1024, 55 | image_num=1, 56 | guidance_scale=7.5, 57 | sampler_name=Samplers.EULER_A, 58 | refiner=Txt2ImgV3Refiner( 59 | switch_at=0.5 60 | ) 61 | ) 62 | 63 | 64 | make_image_grid([base64_to_image(r1.images_encoded[0]), base64_to_image(r2.images_encoded[0]), base64_to_image(r3.images_encoded[0])], 1, 3, 1024).save("./txt2img-refiner-compare.png") 65 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | tests/data 162 | .DS_Store 163 | gradio_examples/ -------------------------------------------------------------------------------- /src/novita_client/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | 4 | import base64 5 | import logging 6 | from io import BytesIO 7 | from multiprocessing.pool import ThreadPool 8 | 9 | import requests 10 | from PIL import Image, ImageOps 11 | 12 | from . 
import settings
from .proto import *

import os  # os is used by the input_image helpers below

logger = logging.getLogger(__name__)


def batch_download_images(image_links):
    def _download(image_link):
        attempts = settings.DEFAULT_DOWNLOAD_IMAGE_ATTEMPTS
        while attempts > 0:
            try:
                response = requests.get(
                    image_link, timeout=settings.DEFAULT_DOWNLOAD_ONE_IMAGE_TIMEOUT)
                return response.content
            except Exception:
                logger.warning("Failed to download image, retrying...")
                attempts -= 1
        return None

    pool = ThreadPool()
    applied = []
    for img_url in image_links:
        applied.append(pool.apply_async(_download, (img_url, )))
    ret = [r.get() for r in applied]
    return [_ for _ in ret if _ is not None]


def save_image(image_bytes, name):
    with open(name, "wb") as f:
        f.write(image_bytes)


def read_image(name):
    with open(name, "rb") as f:
        return f.read()


def read_image_to_base64(name):
    with open(name, "rb") as f:
        return base64.b64encode(f.read()).decode('utf-8')


def image_to_base64(image: Image.Image, format=None) -> str:
    buffered = BytesIO()
    if format is None:
        format = image.format
    if not format:
        format = "PNG"
    image.save(buffered, format)
    return base64.b64encode(buffered.getvalue()).decode('ascii')


def base64_to_image(base64_image: str) -> Image.Image:
    # convert base64 string to image
    image = Image.open(BytesIO(base64.b64decode(base64_image)))
    image = ImageOps.exif_transpose(image)
    return image.convert("RGB")


def add_lora_to_prompt(prompt: str, lora_name: str, weight: float = 1.0) -> str:
    # ensure the prompt carries a "<lora:name:weight>" tag, replacing any
    # existing tag for the same LoRA with the requested weight
    prompt_split = [s.strip() for s in prompt.split(",")]
    ret = []
    replace = False
    for prompt_chunk in prompt_split:
        if prompt_chunk.startswith("<lora:{}:".format(lora_name)):
            ret.append("<lora:{}:{}>".format(lora_name, weight))
            replace = True
        else:
            ret.append(prompt_chunk)
    if not replace:
        ret.append("<lora:{}:{}>".format(lora_name, weight))
    return ", ".join(ret)


def input_image_to_pil(image) -> Image.Image:
    def _convert_to_pil(image):
        if isinstance(image, str):
            if os.path.exists(image):
                return Image.open(BytesIO(read_image(image)))

            if image.startswith("http") or image.startswith("https"):
                return Image.open(BytesIO(batch_download_images([image])[0]))

        if isinstance(image, os.PathLike):
            return Image.open(BytesIO(read_image(str(image))))

        if isinstance(image, Image.Image):
            return image
        raise ValueError("Unknown image type: {}".format(type(image)))

    return ImageOps.exif_transpose(_convert_to_pil(image))


def input_image_to_base64(image) -> str:
    if isinstance(image, str):
        if os.path.exists(image):
            return read_image_to_base64(image)

        if image.startswith("http") or image.startswith("https"):
            return base64.b64encode(batch_download_images([image])[0]).decode('ascii')

        # assume it is a base64 string
        return image

    if isinstance(image, os.PathLike):
        return read_image_to_base64(str(image))

    if isinstance(image, Image.Image):
        return image_to_base64(image)
    raise ValueError("Unknown image type: {}".format(type(image)))
--------------------------------------------------------------------------------
/tests/test_enhance.py:
--------------------------------------------------------------------------------
from novita_client import *
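# end-to-end tests for the enhance endpoints; each test calls the live API and needs NOVITA_API_KEY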
| from novita_client.utils import save_image, read_image_to_base64, base64_to_image 3 | import os 4 | from PIL import Image 5 | import random 6 | 7 | 8 | def make_image_grid(images, rows: int, cols: int, resize: int = None): 9 | """ 10 | Prepares a single grid of images. Useful for visualization purposes. 11 | """ 12 | assert len(images) == rows * cols 13 | 14 | if resize is not None: 15 | images = [img.resize((resize, resize)) for img in images] 16 | 17 | w, h = images[0].size 18 | grid = Image.new("RGB", size=(cols * w, rows * h)) 19 | 20 | for i, img in enumerate(images): 21 | grid.paste(img, box=(i % cols * w, i // cols * h)) 22 | return grid 23 | 24 | 25 | def test_cleanup(): 26 | client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None)) 27 | 28 | res = client.cleanup( 29 | image="https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png", 30 | mask="https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" 31 | ) 32 | 33 | assert (res.image_file is not None) 34 | 35 | test_path = os.path.join(os.path.abspath( 36 | os.path.dirname(__name__)), "tests/data") 37 | if not os.path.exists(test_path): 38 | os.makedirs(test_path) 39 | base64_to_image(res.image_file).save(os.path.join(test_path, "test_cleanup.png")) 40 | 41 | def test_remove_background(): 42 | client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None)) 43 | 44 | res = client.remove_background( 45 | image="https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png", 46 | ) 47 | 48 | assert (res.image_file is not None) 49 | 50 | test_path = os.path.join(os.path.abspath( 51 | os.path.dirname(__name__)), "tests/data") 52 | 53 | if not os.path.exists(test_path): 54 | os.makedirs(test_path) 55 | 56 | base64_to_image(res.image_file).save(os.path.join(test_path, f"test_remove_background.{res.image_type}")) 57 | 58 | 59 | def test_remove_text(): 60 | client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None)) 61 | 62 | res = client.remove_text( 63 | image="https://cf-images.novitai.com/sdk-cases/remove_text_example_1.jpg/public" 64 | ) 65 | 66 | assert (res.image_file is not None) 67 | 68 | test_path = os.path.join(os.path.abspath( 69 | os.path.dirname(__name__)), "tests/data") 70 | 71 | if not os.path.exists(test_path): 72 | os.makedirs(test_path) 73 | 74 | base64_to_image(res.image_file).save(os.path.join(test_path, f"test_remove_text.{res.image_type}")) 75 | 76 | 77 | def test_reimagine(): 78 | client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None)) 79 | 80 | res = client.reimagine( 81 | image="https://madera.objects.liquidweb.services/photos/20371-yosemite-may-yosemite-falls-waterfalls-cooks-meadow-spring-2023-Rectangle-600x400.jpg", 82 | ) 83 | 84 | assert (res.image_file is not None) 85 | 86 | test_path = os.path.join(os.path.abspath( 87 | os.path.dirname(__name__)), "tests/data") 88 | 89 | if res.image_type == "gif": 90 | return 91 | 92 | base64_to_image(res.image_file).save(os.path.join(os.path.abspath(test_path), f"test_reimagine.{res.image_type}")) 93 | 94 | def test_replace_background(): 95 | client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None)) 96 | 97 | res = client.replace_background( 98 | 
image="https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png", 99 | prompt="beautify beach" 100 | ) 101 | 102 | assert (res.image_file is not None) 103 | 104 | test_path = os.path.join(os.path.abspath( 105 | os.path.dirname(__name__)), "tests/data") 106 | 107 | base64_to_image(res.image_file).save(os.path.join(os.path.abspath(test_path), f"test_replace_background.{res.image_type}")) 108 | 109 | def test_merge_face(): 110 | client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None)) 111 | 112 | res = client.merge_face( 113 | image="https://cf-images.novitai.com/sdk-cases/merge_face_example_1_1.png/public", 114 | face_image="https://cf-images.novitai.com/sdk-cases/merge_face_example_1_2.png/public", 115 | ) 116 | 117 | assert (res.image_file is not None) 118 | 119 | test_path = os.path.join(os.path.abspath( 120 | os.path.dirname(__name__)), "tests/data") 121 | 122 | base64_to_image(res.image_file).save(os.path.join(os.path.abspath(test_path), f"test_merge_face.{res.image_type}")) 123 | -------------------------------------------------------------------------------- /requirements-dev.lock: -------------------------------------------------------------------------------- 1 | # generated by rye 2 | # use `rye lock` or `rye sync` to update this lockfile 3 | # 4 | # last locked with the following flags: 5 | # pre: false 6 | # features: [] 7 | # all-features: false 8 | # with-sources: false 9 | 10 | -e file:. 11 | aiofiles==22.1.0 12 | # via ypy-websocket 13 | aiosqlite==0.19.0 14 | # via ypy-websocket 15 | antlr4-python3-runtime==4.9.3 16 | # via omegaconf 17 | anyio==3.7.1 18 | # via jupyter-server 19 | argon2-cffi==23.1.0 20 | # via jupyter-server 21 | # via nbclassic 22 | # via notebook 23 | argon2-cffi-bindings==21.2.0 24 | # via argon2-cffi 25 | arrow==1.2.3 26 | # via isoduration 27 | attrs==23.1.0 28 | # via jsonschema 29 | babel==2.13.0 30 | # via jupyterlab-server 31 | backcall==0.2.0 32 | # via ipython 33 | beautifulsoup4==4.12.2 34 | # via nbconvert 35 | bleach==6.0.0 36 | # via nbconvert 37 | # via readme-renderer 38 | build==1.1.1 39 | certifi==2023.7.22 40 | # via requests 41 | cffi==1.15.1 42 | # via argon2-cffi-bindings 43 | # via cryptography 44 | charset-normalizer==3.2.0 45 | # via requests 46 | cryptography==41.0.3 47 | # via secretstorage 48 | dataclass-wizard==0.22.2 49 | # via novita-client 50 | debugpy==1.7.0 51 | # via ipykernel 52 | decorator==5.1.1 53 | # via ipython 54 | defusedxml==0.7.1 55 | # via nbconvert 56 | docutils==0.20.1 57 | # via readme-renderer 58 | entrypoints==0.4 59 | # via jupyter-client 60 | fastjsonschema==2.18.1 61 | # via nbformat 62 | fqdn==1.5.1 63 | # via jsonschema 64 | idna==3.4 65 | # via anyio 66 | # via jsonschema 67 | # via requests 68 | importlib-metadata==6.7.0 69 | # via keyring 70 | # via twine 71 | iniconfig==2.0.0 72 | # via pytest 73 | ipykernel==6.16.2 74 | # via nbclassic 75 | # via notebook 76 | ipython==7.34.0 77 | # via ipykernel 78 | # via jupyterlab 79 | ipython-genutils==0.2.0 80 | # via nbclassic 81 | # via notebook 82 | isoduration==20.11.0 83 | # via jsonschema 84 | jaraco-classes==3.2.3 85 | # via keyring 86 | jedi==0.19.1 87 | # via ipython 88 | jeepney==0.8.0 89 | # via keyring 90 | # via secretstorage 91 | jinja2==3.1.2 92 | # via jupyter-server 93 | # via jupyterlab 94 | # via jupyterlab-server 95 | # via nbclassic 96 | # via nbconvert 97 | # via notebook 98 | json5==0.9.14 99 | # via jupyterlab-server 100 | 
jsonpointer==2.4 101 | # via jsonschema 102 | jsonschema==4.17.3 103 | # via jupyter-events 104 | # via jupyterlab-server 105 | # via nbformat 106 | jupyter-client==7.4.9 107 | # via ipykernel 108 | # via jupyter-server 109 | # via nbclassic 110 | # via nbclient 111 | # via notebook 112 | jupyter-core==4.12.0 113 | # via jupyter-client 114 | # via jupyter-server 115 | # via jupyterlab 116 | # via nbclassic 117 | # via nbclient 118 | # via nbconvert 119 | # via nbformat 120 | # via notebook 121 | jupyter-events==0.6.3 122 | # via jupyter-server-fileid 123 | jupyter-server==1.24.0 124 | # via jupyter-server-fileid 125 | # via jupyterlab 126 | # via jupyterlab-server 127 | # via nbclassic 128 | # via notebook-shim 129 | jupyter-server-fileid==0.9.0 130 | # via jupyter-server-ydoc 131 | jupyter-server-ydoc==0.8.0 132 | # via jupyterlab 133 | jupyter-ydoc==0.2.5 134 | # via jupyter-server-ydoc 135 | # via jupyterlab 136 | jupyterlab==3.6.6 137 | jupyterlab-pygments==0.2.2 138 | # via nbconvert 139 | jupyterlab-server==2.24.0 140 | # via jupyterlab 141 | keyring==24.1.1 142 | # via twine 143 | markdown-it-py==2.2.0 144 | # via rich 145 | markupsafe==2.1.3 146 | # via jinja2 147 | # via nbconvert 148 | matplotlib-inline==0.1.6 149 | # via ipykernel 150 | # via ipython 151 | mdurl==0.1.2 152 | # via markdown-it-py 153 | mistune==3.0.2 154 | # via nbconvert 155 | more-itertools==9.1.0 156 | # via jaraco-classes 157 | nbclassic==1.0.0 158 | # via jupyterlab 159 | # via notebook 160 | nbclient==0.7.4 161 | # via nbconvert 162 | nbconvert==7.6.0 163 | # via jupyter-server 164 | # via nbclassic 165 | # via notebook 166 | nbformat==5.8.0 167 | # via jupyter-server 168 | # via nbclassic 169 | # via nbclient 170 | # via nbconvert 171 | # via notebook 172 | nest-asyncio==1.5.8 173 | # via ipykernel 174 | # via jupyter-client 175 | # via nbclassic 176 | # via notebook 177 | notebook==6.5.6 178 | # via jupyterlab 179 | notebook-shim==0.2.3 180 | # via nbclassic 181 | omegaconf==2.3.0 182 | packaging==23.1 183 | # via build 184 | # via ipykernel 185 | # via jupyter-server 186 | # via jupyterlab 187 | # via jupyterlab-server 188 | # via nbconvert 189 | # via pytest 190 | pandocfilters==1.5.0 191 | # via nbconvert 192 | parso==0.8.3 193 | # via jedi 194 | pexpect==4.8.0 195 | # via ipython 196 | pickleshare==0.7.5 197 | # via ipython 198 | pillow==10.2.0 199 | # via novita-client 200 | pkginfo==1.10.0 201 | # via twine 202 | pluggy==1.2.0 203 | # via pytest 204 | prometheus-client==0.17.1 205 | # via jupyter-server 206 | # via nbclassic 207 | # via notebook 208 | prompt-toolkit==3.0.39 209 | # via ipython 210 | psutil==5.9.6 211 | # via ipykernel 212 | ptyprocess==0.7.0 213 | # via pexpect 214 | # via terminado 215 | pycparser==2.21 216 | # via cffi 217 | pygments==2.16.0 218 | # via ipython 219 | # via nbconvert 220 | # via readme-renderer 221 | # via rich 222 | pyproject-hooks==1.0.0 223 | # via build 224 | pyrsistent==0.19.3 225 | # via jsonschema 226 | pytest==7.4.0 227 | # via pytest-dependency 228 | pytest-dependency==0.5.1 229 | python-dateutil==2.8.2 230 | # via arrow 231 | # via jupyter-client 232 | python-json-logger==2.0.7 233 | # via jupyter-events 234 | pyyaml==6.0.1 235 | # via jupyter-events 236 | # via omegaconf 237 | pyzmq==24.0.1 238 | # via ipykernel 239 | # via jupyter-client 240 | # via jupyter-server 241 | # via nbclassic 242 | # via notebook 243 | readme-renderer==37.3 244 | # via twine 245 | requests==2.31.0 246 | # via jupyterlab-server 247 | # via novita-client 248 | # via 
requests-toolbelt 249 | # via twine 250 | requests-toolbelt==1.0.0 251 | # via twine 252 | rfc3339-validator==0.1.4 253 | # via jsonschema 254 | # via jupyter-events 255 | rfc3986==2.0.0 256 | # via twine 257 | rfc3986-validator==0.1.1 258 | # via jsonschema 259 | # via jupyter-events 260 | rich==13.5.2 261 | # via twine 262 | secretstorage==3.3.3 263 | # via keyring 264 | send2trash==1.8.2 265 | # via jupyter-server 266 | # via nbclassic 267 | # via notebook 268 | setuptools==68.0.0 269 | # via ipython 270 | six==1.16.0 271 | # via bleach 272 | # via python-dateutil 273 | # via rfc3339-validator 274 | sniffio==1.3.0 275 | # via anyio 276 | soupsieve==2.4.1 277 | # via beautifulsoup4 278 | terminado==0.17.1 279 | # via jupyter-server 280 | # via nbclassic 281 | # via notebook 282 | tinycss2==1.2.1 283 | # via nbconvert 284 | tornado==6.2 285 | # via ipykernel 286 | # via jupyter-client 287 | # via jupyter-server 288 | # via jupyterlab 289 | # via nbclassic 290 | # via notebook 291 | # via terminado 292 | traitlets==5.9.0 293 | # via ipykernel 294 | # via ipython 295 | # via jupyter-client 296 | # via jupyter-core 297 | # via jupyter-events 298 | # via jupyter-server 299 | # via matplotlib-inline 300 | # via nbclassic 301 | # via nbclient 302 | # via nbconvert 303 | # via nbformat 304 | # via notebook 305 | twine==5.0.0 306 | uri-template==1.3.0 307 | # via jsonschema 308 | urllib3==2.0.4 309 | # via requests 310 | # via twine 311 | wcwidth==0.2.8 312 | # via prompt-toolkit 313 | webcolors==1.13 314 | # via jsonschema 315 | webencodings==0.5.1 316 | # via bleach 317 | # via tinycss2 318 | websocket-client==1.6.1 319 | # via jupyter-server 320 | y-py==0.6.2 321 | # via jupyter-ydoc 322 | # via ypy-websocket 323 | ypy-websocket==0.8.4 324 | # via jupyter-server-ydoc 325 | zipp==3.15.0 326 | # via importlib-metadata 327 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Novita AI Python SDK 2 | 3 | This SDK is based on the official [API documentation](https://docs.novita.ai/). 
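All of the examples below read the API key from the `NOVITA_API_KEY` environment variable (with an optional `NOVITA_API_URI` base-URL override), so export your key before running them.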
4 | 5 | **Join our discord server for help:** 6 | 7 | [![](https://dcbadge.vercel.app/api/server/Mqx7nWYzDF)](https://discord.com/invite/Mqx7nWYzDF) 8 | 9 | ## Installation 10 | 11 | ```bash 12 | pip install novita-client 13 | ``` 14 | 15 | ## Examples 16 | 17 | - [fine tune example](https://colab.research.google.com/drive/1j_ii9TN67nuauvc3PiauwZnC2lT62tGF?usp=sharing) 18 | - [cleanup](./examples/cleanup.py) 19 | - [controlnet](./examples/controlnet.py) 20 | - [img2img](./examples/img2img.py) 21 | - [img2video](./examples/img2video.py) 22 | - [inpainting](./examples/inpainting.py) 23 | - [instantid](./examples/instantid.py) 24 | - [merge-face](./examples/merge-face.py) 25 | - [model-search](./examples/model-search.py) 26 | - [reimagine](./examples/reimagine.py) 27 | - [remove-background](./examples/remove-background.py) 28 | - [remove-text](./examples/remove-text.py) 29 | - [replace-background](./examples/replace-background.py) 30 | - [txt2img-with-hiresfix](./examples/txt2img-with-hiresfix.py) 31 | - [txt2img-with-lora](./examples/txt2img-with-lora.py) 32 | - [txt2img-with-refiner](./examples/txt2img-with-refiner.py) 33 | - [txt2video](./examples/txt2video.py) 34 | ## Code Examples 35 | 36 | ### cleanup 37 | ```python 38 | import os 39 | 40 | from novita_client import NovitaClient 41 | from novita_client.utils import base64_to_image 42 | 43 | client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None)) 44 | res = client.cleanup( 45 | image="https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png", 46 | mask="https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" 47 | ) 48 | 49 | base64_to_image(res.image_file).save("./cleanup.png") 50 | ``` 51 | 52 | ### controlnet 53 | ```python 54 | #!/usr/bin/env python 55 | # -*- coding: UTF-8 -*- 56 | 57 | import os 58 | 59 | from novita_client import NovitaClient, Img2ImgV3Request, Img2ImgV3ControlNetUnit, ControlnetUnit, Samplers, Img2ImgV3Embedding 60 | from novita_client.utils import base64_to_image 61 | 62 | 63 | client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None)) 64 | res = client.img2img_v3( 65 | input_image="https://img.freepik.com/premium-photo/close-up-dogs-face-with-big-smile-generative-ai_900101-62851.jpg", 66 | model_name="dreamshaper_8_93211.safetensors", 67 | prompt="a cute dog", 68 | sampler_name=Samplers.DPMPP_M_KARRAS, 69 | width=512, 70 | height=512, 71 | steps=30, 72 | controlnet_units=[ 73 | Img2ImgV3ControlNetUnit( 74 | image_base64="https://img.freepik.com/premium-photo/close-up-dogs-face-with-big-smile-generative-ai_900101-62851.jpg", 75 | model_name="control_v11f1p_sd15_depth", 76 | strength=1.0 77 | ) 78 | ], 79 | embeddings=[Img2ImgV3Embedding(model_name=_) for _ in [ 80 | "BadDream_53202", 81 | ]], 82 | seed=-1, 83 | ) 84 | 85 | 86 | base64_to_image(res.images_encoded[0]).save("./img2img-controlnet.png") 87 | ``` 88 | 89 | ### img2img 90 | ```python 91 | import pdb 92 | import os 93 | 94 | from novita_client import NovitaClient, Img2ImgV3ControlNetUnit, ControlNetPreprocessor, Img2ImgV3Embedding 95 | from novita_client.utils import base64_to_image, input_image_to_pil 96 | 97 | client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None)) 98 | res = client.img2img_v3( 99 | model_name="MeinaHentai_V5.safetensors", 100 | steps=30, 101 | height=512, 102 | width=512, 103 | 
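    # the embeddings listed below are negative textual-inversion models commonly used to suppress artifacts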
    input_image="https://img.freepik.com/premium-photo/close-up-dogs-face-with-big-smile-generative-ai_900101-62851.jpg",
    prompt="1 cute dog",
    strength=0.5,
    guidance_scale=7,
    embeddings=[Img2ImgV3Embedding(model_name=_) for _ in [
        "bad-image-v2-39000",
        "verybadimagenegative_v1.3_21434",
        "BadDream_53202",
        "badhandv4_16755",
        "easynegative_8955.safetensors"]],
    seed=-1,
    sampler_name="DPM++ 2M Karras",
    clip_skip=2,
    # controlnet_units=[Img2ImgV3ControlNetUnit(
    #     model_name="control_v11f1p_sd15_depth",
    #     preprocessor="depth",
    #     image_base64="./20240309-003206.jpeg",
    #     strength=1.0
    # )]
)

base64_to_image(res.images_encoded[0]).save("./img2img.png")
```

### img2video
```python
import os

from novita_client import NovitaClient
from novita_client.utils import base64_to_image

client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None))
res = client.img2video(
    model_name="SVD-XT",
    steps=30,
    frames_num=25,
    image="https://replicate.delivery/pbxt/JvLi9smWKKDfQpylBYosqQRfPKZPntuAziesp0VuPjidq61n/rocket.png",
    enable_frame_interpolation=True
)


with open("test.mp4", "wb") as f:
    f.write(res.video_bytes[0])
```

### inpainting
```python
import os
import base64
from novita_client import NovitaClient
from novita_client.utils import base64_to_image

client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None))
res = client.inpainting(
    model_name="realisticVisionV40_v40VAE-inpainting_81543.safetensors",
    image="https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png",
    mask="https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png",
    seed=1,
    guidance_scale=15,
    steps=20,
    image_num=4,
    prompt="black rabbit",
    negative_prompt="white rabbit",
    sampler_name="Euler a",
    inpainting_full_res=1,
    inpainting_full_res_padding=32,
    inpainting_mask_invert=0,
    initial_noise_multiplier=1,
    mask_blur=1,
    clip_skip=1,
    strength=0.85,
)
os.makedirs("result/result_image", exist_ok=True)  # ensure the output directory exists
with open("result/result_image/inpaintingsdk.jpeg", "wb") as image_file:
    image_file.write(base64.b64decode(res.images_encoded[0]))
```

### instantid
```python

import os
from novita_client import NovitaClient, InstantIDControlnetUnit
import base64



if __name__ == '__main__':
    client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None))

    res = client.instant_id(
        model_name="sdxlUnstableDiffusers_v8HEAVENSWRATH_133813.safetensors",
        face_images=[
            "https://raw.githubusercontent.com/InstantID/InstantID/main/examples/yann-lecun_resize.jpg",
        ],
        prompt="Flat illustration, a Chinese a man, ancient style, wearing a red cloth, smile face, white skin, clean background, fireworks blooming, red lanterns",
        negative_prompt="(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
        id_strength=0.8,
        adapter_strength=0.8,
        steps=20,
        seed=42,
        width=1024,
        height=1024,
        controlnets=[
            InstantIDControlnetUnit(
                model_name='controlnet-openpose-sdxl-1.0',
                strength=0.4,
                preprocessor='openpose',
            ),
            InstantIDControlnetUnit(
                model_name='controlnet-canny-sdxl-1.0',
                strength=0.3,
                preprocessor='canny',
            ),
        ],
        response_image_type='jpeg',
        enterprise_plan=False,
    )

    print('res:', res)

    if hasattr(res, 'images_encoded'):
        with open(f"instantid.png", "wb") as f:
            f.write(base64.b64decode(res.images_encoded[0]))
```

### merge-face
```python
import os

from novita_client import NovitaClient
from novita_client.utils import base64_to_image

client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None))
res = client.merge_face(
    image="https://toppng.com/uploads/preview/cut-out-people-png-personas-en-formato-11563277290kozkuzsos5.png",
    face_image="https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQDy7sXtuvCNUQoQZvTbLRbX6qK9_kP3PlQfg&s",
    enterprise_plan=False,
)

base64_to_image(res.image_file).save("./merge_face.png")
```

### model-search
```python
#!/usr/bin/env python
# -*- coding: UTF-8 -*-

import os

from novita_client import NovitaClient, ModelType

# get your API key; refer to https://docs.novita.ai/get-started/
client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None))

# filter by model type
print("lora count", len(client.models().filter_by_type(ModelType.LORA)))
print("checkpoint count", len(client.models().filter_by_type(ModelType.CHECKPOINT)))
print("textinversion count", len(
    client.models().filter_by_type(ModelType.TEXT_INVERSION)))
print("vae count", len(client.models().filter_by_type(ModelType.VAE)))
print("controlnet count", len(client.models().filter_by_type(ModelType.CONTROLNET)))


# filter by civitai tags
client.models().filter_by_civitai_tags('anime')

# filter by nsfw
client.models().filter_by_nsfw(False)  # or True

# sort by civitai download
client.models().sort_by_civitai_download()

# chain filters
client.models().\
    filter_by_type(ModelType.CHECKPOINT).\
    filter_by_nsfw(False).\
    filter_by_civitai_tags('anime')
```

### reimagine
```python
import os

from novita_client import NovitaClient
from novita_client.utils import base64_to_image

client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None))
res = client.reimagine(
    image="/home/anyisalin/develop/novita-client-python/examples/doodle-generated.png"
)

base64_to_image(res.image_file).save("./reimagine.png")
```

### remove-background
```python
import os

from novita_client import NovitaClient
from novita_client.utils import base64_to_image

client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None))
res = client.remove_background(
    image="https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png",
)
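# image_file in the response is base64-encoded; base64_to_image decodes it into a PIL image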
| base64_to_image(res.image_file).save("./remove_background.png") 304 | ``` 305 | 306 | ### remove-text 307 | ```python 308 | import os 309 | 310 | from novita_client import NovitaClient 311 | from novita_client.utils import base64_to_image 312 | 313 | client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None)) 314 | res = client.remove_text( 315 | image="https://images.uiiiuiii.com/wp-content/uploads/2023/07/i-banner-20230714-1.jpg" 316 | ) 317 | 318 | base64_to_image(res.image_file).save("./remove_text.png") 319 | ``` 320 | 321 | ### replace-background 322 | ```python 323 | import os 324 | 325 | from novita_client import NovitaClient 326 | from novita_client.utils import base64_to_image 327 | 328 | client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None)) 329 | res = client.replace_background( 330 | image="./telegram-cloud-photo-size-2-5408823814353177899-y.jpg", 331 | prompt="in living room, Christmas tree", 332 | ) 333 | base64_to_image(res.image_file).save("./replace_background.png") 334 | ``` 335 | 336 | ### txt2img-with-hiresfix 337 | ```python 338 | import os 339 | 340 | from novita_client import NovitaClient, Samplers, Txt2ImgV3HiresFix 341 | from novita_client.utils import base64_to_image 342 | 343 | from PIL import Image 344 | 345 | 346 | client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None)) 347 | res = client.txt2img_v3( 348 | model_name='dreamshaper_8_93211.safetensors', 349 | prompt="a cute girl", 350 | width=384, 351 | height=512, 352 | image_num=1, 353 | guidance_scale=7.5, 354 | seed=12345, 355 | sampler_name=Samplers.EULER_A, 356 | hires_fix=Txt2ImgV3HiresFix( 357 | # upscaler='Latent' 358 | target_width=768, 359 | target_height=1024, 360 | strength=0.5 361 | ) 362 | ) 363 | 364 | 365 | base64_to_image(res.images_encoded[0]).save("./txt2img_with_hiresfix.png") 366 | ``` 367 | 368 | ### txt2img-with-lora 369 | ```python 370 | #!/usr/bin/env python 371 | # -*- coding: UTF-8 -*- 372 | 373 | import os 374 | from novita_client import NovitaClient, Txt2ImgV3LoRA, Samplers, ProgressResponseStatusCode, ModelType, add_lora_to_prompt, save_image 375 | from novita_client.utils import base64_to_image, input_image_to_pil 376 | from PIL import Image 377 | 378 | 379 | def make_image_grid(images, rows: int, cols: int, resize: int = None): 380 | """ 381 | Prepares a single grid of images. Useful for visualization purposes. 
382 | """ 383 | assert len(images) == rows * cols 384 | 385 | if resize is not None: 386 | images = [img.resize((resize, resize)) for img in images] 387 | 388 | w, h = images[0].size 389 | grid = Image.new("RGB", size=(cols * w, rows * h)) 390 | 391 | for i, img in enumerate(images): 392 | grid.paste(img, box=(i % cols * w, i // cols * h)) 393 | return grid 394 | 395 | 396 | client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None)) 397 | 398 | res1 = client.txt2img_v3( 399 | prompt="a photo of handsome man, close up", 400 | image_num=1, 401 | guidance_scale=7.0, 402 | sampler_name=Samplers.DPMPP_M_KARRAS, 403 | model_name="dreamshaper_8_93211.safetensors", 404 | height=512, 405 | width=512, 406 | seed=1024, 407 | ) 408 | res2 = client.txt2img_v3( 409 | prompt="a photo of handsome man, close up", 410 | image_num=1, 411 | guidance_scale=7.0, 412 | sampler_name=Samplers.DPMPP_M_KARRAS, 413 | model_name="dreamshaper_8_93211.safetensors", 414 | height=512, 415 | width=512, 416 | seed=1024, 417 | loras=[ 418 | Txt2ImgV3LoRA( 419 | model_name="add_detail_44319", 420 | strength=0.9, 421 | ) 422 | ] 423 | ) 424 | 425 | make_image_grid([base64_to_image(res1.images_encoded[0]), base64_to_image(res2.images_encoded[0])], 1, 2, 512).save("./txt2img-lora-compare.png") 426 | ``` 427 | 428 | ### txt2img-with-refiner 429 | ```python 430 | import os 431 | 432 | from novita_client import NovitaClient, Txt2ImgV3Refiner, Samplers 433 | from novita_client.utils import base64_to_image 434 | from PIL import Image 435 | 436 | 437 | def make_image_grid(images, rows: int, cols: int, resize: int = None): 438 | """ 439 | Prepares a single grid of images. Useful for visualization purposes. 440 | """ 441 | assert len(images) == rows * cols 442 | 443 | if resize is not None: 444 | images = [img.resize((resize, resize)) for img in images] 445 | 446 | w, h = images[0].size 447 | grid = Image.new("RGB", size=(cols * w, rows * h)) 448 | 449 | for i, img in enumerate(images): 450 | grid.paste(img, box=(i % cols * w, i // cols * h)) 451 | return grid 452 | 453 | 454 | client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', None)) 455 | 456 | r1 = client.txt2img_v3( 457 | model_name='sd_xl_base_1.0.safetensors', 458 | prompt='a astronaut riding a bike on the moon', 459 | width=1024, 460 | height=1024, 461 | image_num=1, 462 | guidance_scale=7.5, 463 | sampler_name=Samplers.EULER_A, 464 | ) 465 | 466 | r2 = client.txt2img_v3( 467 | model_name='sd_xl_base_1.0.safetensors', 468 | prompt='a astronaut riding a bike on the moon', 469 | width=1024, 470 | height=1024, 471 | image_num=1, 472 | guidance_scale=7.5, 473 | sampler_name=Samplers.EULER_A, 474 | refiner=Txt2ImgV3Refiner( 475 | switch_at=0.7 476 | ) 477 | ) 478 | 479 | r3 = client.txt2img_v3( 480 | model_name='sd_xl_base_1.0.safetensors', 481 | prompt='a astronaut riding a bike on the moon', 482 | width=1024, 483 | height=1024, 484 | image_num=1, 485 | guidance_scale=7.5, 486 | sampler_name=Samplers.EULER_A, 487 | refiner=Txt2ImgV3Refiner( 488 | switch_at=0.5 489 | ) 490 | ) 491 | 492 | 493 | make_image_grid([base64_to_image(r1.images_encoded[0]), base64_to_image(r2.images_encoded[0]), base64_to_image(r3.images_encoded[0])], 1, 3, 1024).save("./txt2img-refiner-compare.png") 494 | ``` 495 | 496 | ### txt2video 497 | ```python 498 | import os 499 | 500 | from novita_client import NovitaClient 501 | from novita_client.utils import save_image 502 | 503 | client = NovitaClient(os.getenv('NOVITA_API_KEY'), os.getenv('NOVITA_API_URI', 
None)) 504 | res = client.txt2video( 505 | model_name = "dreamshaper_8_93211.safetensors", 506 | prompts = [{ 507 | "prompt": "A girl, baby, portrait, 5 years old", 508 | "frames": 16,}, 509 | { 510 | "prompt": "A girl, child, portrait, 10 years old", 511 | "frames": 16, 512 | } 513 | ], 514 | steps = 20, 515 | guidance_scale = 10, 516 | height = 512, 517 | width = 768, 518 | clip_skip = 4, 519 | negative_prompt = "a rainy day", 520 | response_video_type = "mp4", 521 | ) 522 | save_image(res.video_bytes[0], 'output.mp4') 523 | ``` 524 | -------------------------------------------------------------------------------- /src/novita_client/novita.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | 4 | import logging 5 | from io import BytesIO 6 | from multiprocessing.pool import ThreadPool 7 | from time import sleep 8 | 9 | import requests 10 | 11 | from . import settings 12 | from .exceptions import * 13 | from .proto import * 14 | from .utils import input_image_to_base64, input_image_to_pil 15 | from .version import __version__ 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class NovitaClient: 21 | """NovitaClient is the main entry point for interacting with the Novita API.""" 22 | 23 | def __init__(self, api_key, base_url=None): 24 | self.base_url = base_url 25 | if self.base_url is None: 26 | self.base_url = "https://api.novita.ai" 27 | self.api_key = api_key 28 | self.session = requests.Session() 29 | 30 | if not self.api_key: 31 | raise ValueError("NOVITA_API_KEY environment variable not set") 32 | 33 | # eg: {"all": [proto.ModelInfo], "checkpoint": [proto.ModelInfo], "lora": [proto.ModelInfo]} 34 | self._model_list_cache = None 35 | self._model_list_cache_v3 = None 36 | self._extra_headers = {} 37 | self._default_response_image_type = "jpeg" 38 | 39 | def set_response_image_type(self, image_type: str): 40 | self._default_response_image_type = image_type 41 | 42 | def set_extra_headers(self, headers: dict): 43 | self._extra_headers = headers 44 | 45 | def _get(self, api_path, params=None) -> dict: 46 | headers = { 47 | 'Accept': 'application/json', 48 | 'Content-Type': 'application/json', 49 | 'Authorization': f'Bearer {self.api_key}', 50 | 'User-Agent': "novita-python-sdk/{}".format(__version__), 51 | 'Accept-Encoding': 'gzip, deflate', 52 | } 53 | headers.update(self._extra_headers) 54 | 55 | logger.debug(f"[GET] params: {params}") 56 | 57 | response = self.session.get( 58 | self.base_url + api_path, 59 | headers=headers, 60 | params=params, 61 | timeout=settings.DEFAULT_REQUEST_TIMEOUT, 62 | ) 63 | 64 | logger.debug(f"[GET] {self.base_url + api_path}, headers: {headers} response: {response.content}") 65 | if response.status_code != 200: 66 | logger.error(f"Request failed: {response}") 67 | raise NovitaResponseError( 68 | f"Request failed with status {response.status_code}") 69 | 70 | return response.json() 71 | 72 | def _post(self, api_path, data) -> dict: 73 | headers = { 74 | 'Accept': 'application/json', 75 | 'Content-Type': 'application/json', 76 | 'Authorization': f'Bearer {self.api_key}', 77 | 'User-Agent': "novita-python-sdk/{}".format(__version__), 78 | 'Accept-Encoding': 'gzip, deflate', 79 | } 80 | headers.update(self._extra_headers) 81 | 82 | logger.debug(f"[POST] {self.base_url + api_path}, headers: {headers} data: {data}") 83 | 84 | response = self.session.post( 85 | self.base_url + api_path, 86 | headers=headers, 87 | json=data, 88 | 
timeout=settings.DEFAULT_REQUEST_TIMEOUT, 89 | ) 90 | 91 | logger.debug(f"[POST] response: {response.content}") 92 | if response.status_code != 200: 93 | logger.error(f"Request failed: {response}") 94 | raise NovitaResponseError( 95 | f"Request failed with status {response.status_code}, {response.content}") 96 | 97 | return response.json() 98 | 99 | def async_task_result(self, task_id: str) -> V3TaskResponse: 100 | response = self._get('/v3/async/task-result', { 101 | 'task_id': task_id, 102 | }) 103 | return V3TaskResponse.from_dict(response) 104 | 105 | def wait_for_task_v3(self, task_id, wait_for: int = 300, callback: callable = None) -> V3TaskResponse: 106 | elapsed = 0 107 | 108 | while elapsed < wait_for: 109 | logger.info(f"Waiting for task {task_id} to complete") 110 | 111 | progress = self.async_task_result(task_id) 112 | 113 | if callback and callable(callback): 114 | try: 115 | callback(progress) 116 | except Exception as e: 117 | logger.error(f"Task {task_id} progress callback failed: {e}") 118 | 119 | if progress.finished(): 120 | logger.info(f"Task {task_id} completed") 121 | logger.debug(f"Task {progress.task.task_type}/{progress.task.task_id} debug_info: {progress.extra.debug_info}") 122 | return progress 123 | 124 | sleep(settings.DEFAULT_POLL_INTERVAL) 125 | elapsed += settings.DEFAULT_POLL_INTERVAL  # track elapsed seconds so wait_for is honored as a timeout in seconds 126 | raise NovitaTimeoutError( 127 | f"Task {task_id} failed to complete in {wait_for} seconds") 128 | 129 | def cleanup(self, image: InputImage, mask: InputImage, response_image_type=None, enterprise_plan: bool=None) -> CleanupResponse: 130 | image_b64 = input_image_to_base64(image) 131 | mask_b64 = input_image_to_base64(mask) 132 | request = CleanupRequest(image_file=image_b64, mask_file=mask_b64) 133 | if response_image_type is None: 134 | request.set_image_type(self._default_response_image_type) 135 | else: 136 | request.set_image_type(response_image_type) 137 | if enterprise_plan is not None: 138 | request.set_enterprise_plan(enterprise_plan) 139 | else: 140 | request.set_enterprise_plan(False) 141 | 142 | 143 | return CleanupResponse.from_dict(self._post('/v3/cleanup', request.to_dict())) 144 | 145 | def remove_background(self, image: InputImage, response_image_type=None, enterprise_plan: bool=None) -> RemoveBackgroundResponse: 146 | image_b64 = input_image_to_base64(image) 147 | request = RemoveBackgroundRequest(image_file=image_b64) 148 | if response_image_type is None: 149 | request.set_image_type(self._default_response_image_type) 150 | else: 151 | request.set_image_type(response_image_type) 152 | if enterprise_plan is not None: 153 | request.set_enterprise_plan(enterprise_plan) 154 | else: 155 | request.set_enterprise_plan(False) 156 | 157 | return RemoveBackgroundResponse.from_dict(self._post('/v3/remove-background', request.to_dict())) 158 | 159 | def remove_text(self, image: InputImage, response_image_type=None, enterprise_plan: bool=None) -> RemoveTextResponse: 160 | image_b64 = input_image_to_base64(image) 161 | request = RemoveTextRequest(image_file=image_b64) 162 | if response_image_type is None: 163 | request.set_image_type(self._default_response_image_type) 164 | else: 165 | request.set_image_type(response_image_type) 166 | if enterprise_plan is not None: 167 | request.set_enterprise_plan(enterprise_plan) 168 | else: 169 | request.set_enterprise_plan(False) 170 | 171 | return RemoveTextResponse.from_dict(self._post('/v3/remove-text', request.to_dict())) 172 | 173 | def reimagine(self, image: InputImage, response_image_type=None, enterprise_plan: bool=None) -> ReimagineResponse: 174 | image_b64 = input_image_to_base64(image)
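# NOTE: InputImage (defined in proto.py) is Union[str, os.PathLike, Image.Image],
# so input_image_to_base64 above accepts a local file path or an http(s) URL, and,
# per the InputImage alias, a PIL.Image as well. A quick sketch with hypothetical
# inputs (paths and URL are placeholders):
#
#   client.reimagine(image="./doodle.png")                # local path
#   client.reimagine(image="https://example.com/a.png")   # URL
#   client.reimagine(image=Image.open("./doodle.png"))    # in-memory PIL.Image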
175 | request = ReimagineRequest(image_file=image_b64) 176 | if response_image_type is None: 177 | request.set_image_type(self._default_response_image_type) 178 | else: 179 | request.set_image_type(response_image_type) 180 | if enterprise_plan is not None: 181 | request.set_enterprise_plan(enterprise_plan) 182 | else: 183 | request.set_enterprise_plan(False) 184 | 185 | return ReimagineResponse.from_dict(self._post('/v3/reimagine', request.to_dict())) 186 | 187 | def replace_background(self, image: InputImage, prompt: str, response_image_type=None, enterprise_plan: bool=None) -> ReplaceBackgroundResponse: 188 | image_b64 = input_image_to_base64(image) 189 | request = ReplaceBackgroundRequest(image_file=image_b64, prompt=prompt) 190 | if response_image_type is None: 191 | request.set_image_type(self._default_response_image_type) 192 | else: 193 | request.set_image_type(response_image_type) 194 | if enterprise_plan is not None: 195 | request.set_enterprise_plan(enterprise_plan) 196 | else: 197 | request.set_enterprise_plan(False) 198 | return ReplaceBackgroundResponse.from_dict(self._post('/v3/replace-background', request.to_dict())) 199 | 200 | def async_txt2video(self,model_name:str, height:int,width:int,steps:int,prompts:List[Txt2VideoPrompt],guidance_scale:float,seed:int=None,negative_prompt: Optional[str] = None,loras:List[Txt2VideoLoRA]=None,\ 201 | embeddings:List[Txt2VideoEmbedding]=None,clip_skip:int=None,closed_loop:bool=None,response_video_type:str=None,enterprise_plan: bool=None) -> Txt2VideoResponse: 202 | request = Txt2VideoRequest(model_name=model_name,height=height,width=width,steps=steps,prompts=prompts,negative_prompt=negative_prompt,guidance_scale=guidance_scale,loras=loras,embeddings=embeddings,clip_skip=clip_skip) 203 | if seed is not None: 204 | request.seed = seed 205 | if closed_loop is not None: 206 | request.closed_loop = closed_loop 207 | if response_video_type is not None: 208 | request.set_video_type(response_video_type) 209 | if enterprise_plan is not None: 210 | request.set_enterprise_plan(enterprise_plan) 211 | else: 212 | request.set_enterprise_plan(False) 213 | return Txt2VideoResponse.from_dict(self._post('/v3/async/txt2video', request.to_dict())) 214 | 215 | def txt2video(self,model_name:str, height:int,width:int,steps:int,prompts:List[Txt2VideoPrompt],guidance_scale:float,negative_prompt: Optional[str] = None,seed:int=None,loras:List[Txt2VideoLoRA]=None,\ 216 | embeddings:List[Txt2VideoEmbedding]=None,clip_skip:int=None,closed_loop:bool=None,response_video_type:str=None,enterprise_plan: bool=None) -> Txt2VideoResponse: 217 | res: Txt2VideoResponse = self.async_txt2video(model_name, height, width, steps, prompts, guidance_scale, seed, negative_prompt, loras, embeddings, clip_skip, closed_loop, response_video_type, enterprise_plan) 218 | final_res = self.wait_for_task_v3(res.task_id) 219 | if final_res.task.status == V3TaskResponseStatus.TASK_STATUS_SUCCEED: 220 | final_res.download_videos() 221 | else: 222 | raise NovitaResponseError(f"Task {res.task_id} failed with status {final_res.task.status}") 223 | return final_res 224 | 225 | 226 | def async_img2video(self, image: InputImage, model_name: str, steps: int, frames_num: int, frames_per_second: int = 6, seed: int = None, image_file_resize_mode: str = Img2VideoResizeMode.CROP_TO_ASPECT_RATIO,\ 227 | motion_bucket_id: int = 127, cond_aug: float = 0.02, enable_frame_interpolation: bool = False, response_video_type:str=None,enterprise_plan: bool=None) -> Img2VideoResponse: 228 | image_b64 = input_image_to_base64(image) 229 | request =
Img2VideoRequest(model_name=model_name, image_file=image_b64, steps=steps, frames_num=frames_num, frames_per_second=frames_per_second, seed=seed, 230 | image_file_resize_mode=image_file_resize_mode, motion_bucket_id=motion_bucket_id, cond_aug=cond_aug, enable_frame_interpolation=enable_frame_interpolation) 231 | if response_video_type is not None: 232 | request.set_video_type(response_video_type) 233 | if enterprise_plan is not None: 234 | request.set_enterprise_plan(enterprise_plan) 235 | else: 236 | request.set_enterprise_plan(False) 237 | return Img2VideoResponse.from_dict(self._post('/v3/async/img2video', request.to_dict())) 238 | 239 | def img2video(self, image: InputImage, model_name: str, steps: int, frames_num: int, frames_per_second: int = 6, seed: int = None, image_file_resize_mode: str = Img2VideoResizeMode.CROP_TO_ASPECT_RATIO, motion_bucket_id: int = 127, cond_aug: float = 0.02, enable_frame_interpolation: bool = False, response_video_type: str=None, enterprise_plan: bool=None) -> Img2VideoResponse: 240 | res: Img2VideoResponse = self.async_img2video(image, model_name, steps, frames_num, frames_per_second, seed, image_file_resize_mode, motion_bucket_id, cond_aug, enable_frame_interpolation, response_video_type, enterprise_plan) 241 | final_res = self.wait_for_task_v3(res.task_id) 242 | final_res.download_videos() 243 | return final_res 244 | 245 | def raw_inpainting(self, req: InpaintingRequest, extra: CommonV3Extra = None) -> InpaintingResponse: 246 | _req = CommonV3Request(request=req, extra=extra) 247 | return InpaintingResponse.from_dict(self._post('/v3/async/inpainting', _req.to_dict())) 248 | 249 | 250 | def async_inpainting(self, model_name: str, image: str, mask: str,\ 251 | prompt: str,image_num: int,sampler_name: str, steps:int, guidance_scale: float,\ 252 | seed: int, mask_blur: int=None, negative_prompt: str=None,\ 253 | sd_vae:str=None,loras: List[InpaintingLoRA] = None,\ 254 | embeddings: List[InpaintingEmbedding] = None,\ 255 | clip_skip: int = None, strength: float = None,\ 256 | inpainting_full_res: int=0, inpainting_full_res_padding: int=8,\ 257 | inpainting_mask_invert: int=0, initial_noise_multiplier: float=0.5, **kwargs)\ 258 | -> InpaintingResponse: 259 | request = InpaintingRequest(model_name=model_name, image_base64=image, mask_image_base64=mask, \ 260 | prompt=prompt,sampler_name=sampler_name, image_num=image_num, steps=steps,\ 261 | guidance_scale=guidance_scale, seed=seed, mask_blur=mask_blur,loras= loras,embeddings=embeddings,\ 262 | negative_prompt=negative_prompt, clip_skip=clip_skip, strength=strength,\ 263 | inpainting_full_res=inpainting_full_res, inpainting_full_res_padding=inpainting_full_res_padding,\ 264 | inpainting_mask_invert=inpainting_mask_invert, initial_noise_multiplier=initial_noise_multiplier) 265 | extra = CommonV3Extra(**kwargs) 266 | return self.raw_inpainting(request, extra) 267 | 268 | def inpainting(self, model_name: str, image: InputImage, mask: InputImage,\ 269 | prompt: str,image_num: int,sampler_name: str, steps:int, guidance_scale: float,\ 270 | seed: int, mask_blur: int=None, negative_prompt: str=None,\ 271 | sd_vae:str=None,loras: List[InpaintingLoRA] = None,\ 272 | embeddings: List[InpaintingEmbedding] = None,\ 273 | clip_skip: int = None, strength: float = None,\ 274 | inpainting_full_res: int=0, inpainting_full_res_padding: int=8,\ 275 | inpainting_mask_invert: int=0, initial_noise_multiplier: float=0.5,**kwargs) -> InpaintingResponse: 276 | input_image = input_image_to_base64(image) 277 | mask_image = 
input_image_to_base64(mask) 278 | res: InpaintingResponse = self.async_inpainting(model_name, input_image, mask_image, prompt, image_num, sampler_name, steps, guidance_scale, seed, mask_blur, negative_prompt, sd_vae, loras, embeddings, clip_skip, strength, inpainting_full_res, inpainting_full_res_padding, inpainting_mask_invert, initial_noise_multiplier, **kwargs) 279 | final_res = self.wait_for_task_v3(res.task_id) 280 | if final_res.task.status == V3TaskResponseStatus.TASK_STATUS_SUCCEED: 281 | final_res.download_images() 282 | else: 283 | logger.error(f"Failed to inpaint image: {final_res.task.status}") 284 | raise NovitaResponseError(f"Task {final_res.task.task_id} failed with status {final_res.task.status}") 285 | return final_res 286 | 287 | def img2prompt(self, image: InputImage) -> Img2PromptResponse: 288 | input_image = input_image_to_base64(image) 289 | request = Img2PromptRequest(image_file=input_image) 290 | return Img2PromptResponse.from_dict(self._post('/v3/img2prompt', request.to_dict())) 291 | 292 | def merge_face(self, image: InputImage, face_image: InputImage, response_image_type=None, enterprise_plan=None) -> MergeFaceResponse: 293 | input_image = input_image_to_base64(image) 294 | face_image = input_image_to_base64(face_image) 295 | request = MergeFaceRequest(image_file=input_image, face_image_file=face_image) 296 | if response_image_type is None: 297 | request.set_image_type(self._default_response_image_type) 298 | else: 299 | request.set_image_type(response_image_type) 300 | if enterprise_plan is not None: 301 | request.set_enterprise_plan(enterprise_plan) 302 | else: 303 | request.set_enterprise_plan(False) 304 | return MergeFaceResponse.from_dict(self._post('/v3/merge-face', request.to_dict())) 305 | 306 | def upload_training_assets(self, images: List[InputImage], batch_size=10) -> List[str]: 307 | def _upload_assets(image: InputImage) -> UploadAssetResponse: 308 | pil_image = input_image_to_pil(image) 309 | buff = BytesIO() 310 | if pil_image.format != "JPEG": 311 | pil_image = pil_image.convert("RGB") 312 | pil_image.save(buff, format="JPEG") 313 | else: 314 | pil_image.save(buff, format="JPEG") 315 | 316 | upload_res: UploadAssetResponse = UploadAssetResponse.from_dict(self._post("/v3/assets/training_dataset", UploadAssetRequest(file_extension="jpeg").to_dict())) 317 | res = requests.put(upload_res.upload_url, data=buff.getvalue(), headers={'Content-type': 'image/jpeg'}) 318 | if res.status_code != 200: 319 | raise NovitaResponseError(f"Failed to upload image: {res.content}") 320 | return upload_res 321 | 322 | with ThreadPool(batch_size) as pool: 323 | results = pool.map(_upload_assets, images) 324 | ret = [] 325 | try: 326 | for return_value in results: 327 | ret.append(return_value.assets_id) 328 | except Exception as e: 329 | raise NovitaResponseError(f"Failed to upload image: {e}") 330 | return ret 331 | 332 | # for image in images: 333 | # pil_image = input_image_to_pil(image) 334 | # buff = BytesIO() 335 | # if pil_image.format != "JPEG": 336 | # pil_image = pil_image.convert("RGB") 337 | # pil_image.save(buff, format="JPEG") 338 | # else: 339 | # pil_image.save(buff, format="JPEG") 340 | 341 | # upload_res: UploadAssetResponse = UploadAssetResponse.from_dict(self._post("/v3/assets/training_dataset", UploadAssetRequest(file_extension="jpeg").to_dict())) 342 | # res = requests.put(upload_res.upload_url, data=buff.getvalue(), headers={'Content-type': 'image/jpeg'}) 343 | # if res.status_code != 200: 344 | # raise NovitaResponseError(f"Failed to upload image:
{res.content}") 345 | # ret.append(upload_res.assets_id) 346 | # return ret 347 | 348 | def create_training_style(self, 349 | name, 350 | base_model, 351 | images: List[InputImage], 352 | captions: List[str], 353 | width: int = 512, 354 | height: int = 512, 355 | learning_rate: str = None, 356 | seed: int = None, 357 | lr_scheduler: str = None, 358 | with_prior_preservation: bool = None, 359 | prior_loss_weight: float = None, 360 | lora_r: int = None, 361 | lora_alpha: int = None, 362 | max_train_steps: str = None, 363 | lora_text_encoder_r: int = None, 364 | lora_text_encoder_alpha: int = None, 365 | components=None 366 | ): 367 | if len(images) != len(captions): 368 | raise ValueError("images and captions must have the same length") 369 | 370 | assets = self.upload_training_assets(images) 371 | req = CreateTrainingStyleRequest( 372 | name=name, 373 | base_model=base_model, 374 | image_dataset_items=[TrainingStyleImageDatasetItem(assets_id=assets_id, caption=caption) for assets_id, caption in zip(assets, captions)], 375 | expert_setting=TrainingExpertSetting( 376 | max_train_steps=max_train_steps, 377 | learning_rate=learning_rate, 378 | seed=seed, 379 | lr_scheduler=lr_scheduler, 380 | with_prior_preservation=with_prior_preservation, 381 | prior_loss_weight=prior_loss_weight, 382 | lora_r=lora_r, 383 | lora_alpha=lora_alpha, 384 | lora_text_encoder_r=lora_text_encoder_r, 385 | lora_text_encoder_alpha=lora_text_encoder_alpha, 386 | ), 387 | components=[_.to_dict() for _ in components] if components is not None else None, 388 | width=width, 389 | height=height, 390 | ) 391 | res = CreateTrainingStyleResponse.from_dict(self._post("/v3/training/style", req.to_dict())) 392 | return res.task_id 393 | 394 | def create_training_subject(self, name, 395 | base_model, 396 | images: List[InputImage], 397 | instance_prompt: str, 398 | class_prompt: str, 399 | width: int = 512, 400 | height: int = 512, 401 | learning_rate: str = None, 402 | seed: int = None, 403 | lr_scheduler: str = None, 404 | with_prior_preservation: bool = None, 405 | prior_loss_weight: float = None, 406 | lora_r: int = None, 407 | lora_alpha: int = None, 408 | max_train_steps: str = None, 409 | lora_text_encoder_r: int = None, 410 | lora_text_encoder_alpha: int = None, 411 | components=None) -> str: 412 | assets = self.upload_training_assets(images) 413 | req = CreateTrainingSubjectRequest( 414 | name=name, 415 | base_model=base_model, 416 | image_dataset_items=[TrainingImageDatasetItem(assets_id=assets_id) for assets_id in assets], 417 | expert_setting=TrainingExpertSetting( 418 | instance_prompt=instance_prompt, 419 | class_prompt=class_prompt, 420 | max_train_steps=max_train_steps, 421 | learning_rate=learning_rate, 422 | seed=seed, 423 | lr_scheduler=lr_scheduler, 424 | with_prior_preservation=with_prior_preservation, 425 | prior_loss_weight=prior_loss_weight, 426 | lora_r=lora_r, 427 | lora_alpha=lora_alpha, 428 | lora_text_encoder_r=lora_text_encoder_r, 429 | lora_text_encoder_alpha=lora_text_encoder_alpha, 430 | ), 431 | components=[_.to_dict() for _ in components] if components is not None else None, 432 | width=width, 433 | height=height, 434 | ) 435 | res = CreateTrainingSubjectResponse.from_dict(self._post("/v3/training/subject", req.to_dict())) 436 | return res.task_id 437 | 438 | def query_training_subject_status(self, task_id: str) -> QueryTrainingSubjectStatusResponse: 439 | return QueryTrainingSubjectStatusResponse.from_dict(self._get("/v3/training/subject", params={"task_id": task_id})) 440 | 441 | def 
list_training(self, task_type: str = None) -> TrainingTaskList: 442 | params = {} 443 | if task_type is not None: 444 | params["task_type"] = task_type 445 | 446 | return TrainingTaskList(TrainingTaskListResponse.from_dict(self._get("/v3/training", params=params)).tasks) 447 | 448 | def upload_assets(self, images: List[InputImage], batch_size=10) -> List[str]: 449 | buffs = [] 450 | for image in images: 451 | if os.path.exists(image): 452 | pil_image = input_image_to_pil(image) 453 | buff = BytesIO() 454 | if pil_image.format != "JPEG": 455 | pil_image = pil_image.convert("RGB") 456 | pil_image.save(buff, format="JPEG") 457 | else: 458 | pil_image.save(buff, format="JPEG") 459 | buffs.append(buff) 460 | elif image.startswith(("http://", "https://")): 461 | buff = BytesIO(requests.get(image).content) 462 | buffs.append(buff) 463 | 464 | 465 | def _upload_asset(buff): 466 | attempt = 5 467 | while attempt > 0: 468 | upload_res = requests.put("https://assets.novitai.com/image", data=buff.getvalue(), headers={'Content-type': 'image/jpeg'}) 469 | if upload_res.status_code < 400: 470 | return upload_res.json()["assets_id"] 471 | attempt -= 1 472 | raise NovitaResponseError(f"Failed to upload image: {upload_res.content}") 473 | 474 | with ThreadPool(batch_size) as pool: 475 | results = pool.map(_upload_asset, buffs) 476 | ret = [] 477 | try: 478 | for return_value in results: 479 | ret.append(return_value) 480 | except Exception as e: 481 | raise NovitaResponseError(f"Failed to upload image: {e}") 482 | return ret 483 | 484 | def instant_id(self, 485 | face_images: List[InputImage], 486 | ref_images: List[InputImage] = None, 487 | model_name: str = None, 488 | prompt: str = None, 489 | negative_prompt: str = None, 490 | width: int = None, 491 | height: int = None, # if size arguments (width or height) is None, default size is equal to reference image size 492 | id_strength: float = None, 493 | adapter_strength: float = None, 494 | steps: int = 20, 495 | seed: int = -1, 496 | guidance_scale: float = 5., 497 | sampler_name: str = 'Euler', 498 | controlnets: List[InstantIDControlnetUnit] = None, 499 | loras: List[InstantIDLora] = None, 500 | response_image_type: str = None, 501 | download_images: bool = True, 502 | callback: callable = None, 503 | enterprise_plan=None 504 | ): 505 | #face_images = [input_image_to_pil(img) for img in face_images] 506 | #ref_images = ref_images and [input_image_to_pil(img) for img in ref_images] 507 | 508 | face_image_assets_ids = self.upload_assets(face_images) 509 | if ref_images is not None and len(ref_images) > 0: 510 | ref_image_assets_ids = self.upload_assets(ref_images) 511 | else: 512 | ref_image_assets_ids = face_image_assets_ids[:1] 513 | 514 | if width is None or height is None: 515 | ref_img = ref_images[0] if ref_images and len(ref_images) > 0 else face_images[0] 516 | width, height = input_image_to_pil(ref_img).size # ref_img may be a path/URL rather than a PIL image, so convert before reading its size 517 | 518 | payload_data = InstantIDRequest( 519 | face_image_assets_ids=face_image_assets_ids, 520 | ref_image_assets_ids=ref_image_assets_ids, 521 | model_name=model_name, 522 | prompt=prompt, 523 | negative_prompt=negative_prompt, 524 | width=width, 525 | height=height, 526 | id_strength=id_strength, 527 | adapter_strength=adapter_strength, 528 | steps=steps, 529 | seed=seed, 530 | guidance_scale=guidance_scale, 531 | sampler_name=sampler_name, 532 | controlnet=InstantIDRequestControlNet(units=controlnets), 533 | loras=loras, 534 | ) 535 | 536 | if response_image_type is not None: 537 | payload_data.set_image_type(response_image_type) 538 | if
enterprise_plan is not None: 539 | payload_data.set_enterprise_plan(enterprise_plan) 540 | else: 541 | payload_data.set_enterprise_plan(False) 542 | 543 | res = self._post("/v3/async/instant-id", payload_data.to_dict()) 544 | final_res = self.wait_for_task_v3(res["task_id"], callback=callback) 545 | if final_res.task.status != V3TaskResponseStatus.TASK_STATUS_SUCCEED: 546 | logger.error(f"Task {final_res.task.task_id} failed with status {final_res.task.status}") 547 | else: 548 | if download_images: 549 | final_res.download_images() 550 | 551 | return final_res 552 | 553 | def raw_img2img_v3(self, req: Img2ImgV3Request, extra: CommonV3Extra = None) -> Img2ImgV3Response: 554 | _req = CommonV3Request(request=req, extra=extra) 555 | 556 | return Img2ImgV3Response.from_dict(self._post('/v3/async/img2img', _req.to_dict())) 557 | 558 | def img2img_v3(self, model_name: str, input_image: str, prompt: str, image_num: int, height: int = None, width: int = None, negative_prompt: str = None, sd_vae: str = None, steps: int = None, guidance_scale: float = None, clip_skip: int = None, seed: int = None, strength: float = None, sampler_name: str = None, response_image_type: str = None, loras: List[Img2V3ImgLoRA] = None, embeddings: List[Img2ImgV3Embedding] = None, controlnet_units: List[Img2ImgV3ControlNetUnit] = None, download_images: bool = True, callback: callable = None, **kwargs) -> Img2ImgV3Response: 559 | input_image = input_image_to_base64(input_image) 560 | req = Img2ImgV3Request( 561 | model_name=model_name, 562 | image_base64=input_image, 563 | prompt=prompt, 564 | negative_prompt=negative_prompt, 565 | sd_vae=sd_vae, 566 | steps=steps, 567 | clip_skip=clip_skip, 568 | loras=loras, 569 | embeddings=embeddings, 570 | ) 571 | if height is not None: 572 | req.height = height 573 | if width is not None: 574 | req.width = width 575 | if sampler_name is not None: 576 | req.sampler_name = sampler_name 577 | if image_num is not None: 578 | req.image_num = image_num 579 | if guidance_scale is not None: 580 | req.guidance_scale = guidance_scale 581 | if strength is not None: 582 | req.strength = strength 583 | if seed is not None: 584 | req.seed = seed 585 | if controlnet_units is not None: 586 | for unit in controlnet_units: 587 | unit.image_base64 = input_image_to_base64(unit.image_base64) 588 | req.controlnet = Img2ImgV3ControlNet(units=controlnet_units) 589 | 590 | # req.set_image_type(response_image_type) 591 | extra = CommonV3Extra(**kwargs) 592 | if response_image_type is not None: 593 | extra.response_image_type = response_image_type 594 | 595 | res = self.raw_img2img_v3(req, extra) 596 | final_res = self.wait_for_task_v3(res.task_id, callback=callback) 597 | if final_res.task.status != V3TaskResponseStatus.TASK_STATUS_SUCCEED: 598 | logger.error(f"Task {res.task_id} failed with status {final_res.task.status}") 599 | else: 600 | if download_images: 601 | final_res.download_images() 602 | return final_res 603 | 604 | def raw_txt2img_v3(self, req: Txt2ImgV3Request, extra: CommonV3Extra = None) -> Txt2ImgV3Response: 605 | _req = CommonV3Request(request=req, extra=extra) 606 | return Txt2ImgV3Response.from_dict(self._post('/v3/async/txt2img', _req.to_dict())) 607 |
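# A minimal usage sketch (not part of the SDK) of the raw/wrapper split used
# throughout this class: the raw_* methods only submit the async job and return
# a response carrying a task_id, while the wrapper methods poll wait_for_task_v3
# and download the results. Argument values here are illustrative:
#
#   task = client.raw_txt2img_v3(Txt2ImgV3Request(model_name="sd_xl_base_1.0.safetensors", prompt="a cat"))
#   final = client.wait_for_task_v3(task.task_id)
#   if final.task.status == V3TaskResponseStatus.TASK_STATUS_SUCCEED:
#       final.download_images()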
608 | def txt2img_v3(self, model_name: str, prompt: str, image_num: int, height: int = None, width: int = None, negative_prompt: str = None, sd_vae: str = None, steps: int = None, guidance_scale: float = None, clip_skip: int = None, seed: int = None, strength: float = None, sampler_name: str = None, response_image_type: str = None, loras: List[Txt2ImgV3LoRA] = None, embeddings: List[Txt2ImgV3Embedding] = None, hires_fix: Txt2ImgV3HiresFix = None, refiner: Txt2ImgV3Refiner = None, download_images: bool = True, callback: callable = None, **kwargs) -> Txt2ImgV3Response: 609 | req = Txt2ImgV3Request( 610 | model_name=model_name, 611 | prompt=prompt, 612 | negative_prompt=negative_prompt, 613 | sd_vae=sd_vae, 614 | clip_skip=clip_skip, 615 | loras=loras, 616 | embeddings=embeddings, 617 | hires_fix=hires_fix, 618 | refiner=refiner, 619 | ) 620 | if steps is not None: 621 | req.steps = steps 622 | 623 | if height is not None: 624 | req.height = height 625 | if width is not None: 626 | req.width = width 627 | if sampler_name is not None: 628 | req.sampler_name = sampler_name 629 | if image_num is not None: 630 | req.image_num = image_num 631 | if guidance_scale is not None: 632 | req.guidance_scale = guidance_scale 633 | if strength is not None: 634 | req.strength = strength 635 | if seed is not None: 636 | req.seed = seed 637 | 638 | extra = CommonV3Extra(**kwargs) 639 | if response_image_type is not None: 640 | extra.response_image_type = response_image_type 641 | 642 | res = self.raw_txt2img_v3(req, extra) 643 | final_res = self.wait_for_task_v3(res.task_id, callback=callback) 644 | if final_res.task.status != V3TaskResponseStatus.TASK_STATUS_SUCCEED: 645 | logger.error(f"Task {res.task_id} failed with status {final_res.task.status}") 646 | raise NovitaResponseError(f"Task {res.task_id} failed with status {final_res.task.status}") 647 | else: 648 | if download_images: 649 | final_res.download_images() 650 | return final_res 651 | 652 | def user_info(self) -> UserInfoResponse: 653 | return UserInfoResponse.from_dict(self._get("/v3/user")) 654 | 655 | def query_model_v3(self, visibility:str = None, limit:str = None,\ 656 | query:str = None, cursor:str = None,is_inpainting:bool \ 657 | = False, source:str = None, is_sdxl:bool = None,\ 658 | types:str= None) -> MoodelsResponseV3: 659 | parameters = {} 660 | if visibility is not None: 661 | parameters["filter.visibility"] = visibility 662 | if source is not None: 663 | parameters["filter.source"] = source 664 | if types is not None: 665 | parameters["filter.type"] = types 666 | if is_sdxl is not None: 667 | parameters["filter.is_sdxl"] = is_sdxl 668 | if query is not None: 669 | parameters["filter.query"] = query 670 | if cursor is not None: 671 | parameters["pagination.cursor"] = f"c_{cursor}" 672 | parameters["filter.is_inpainting"] = is_inpainting 673 | if limit is not None: 674 | if float(limit) > 100 or float(limit) <= 0: 675 | limit = 1 676 | parameters["pagination.limit"] = limit 677 | res = self._get('/v3/model', params=parameters) 678 | return MoodelsResponseV3.from_dict(res) 679 | 680 | def models_v3(self, refresh=False) -> ModelListV3: 681 | if self._model_list_cache_v3 is None or len(self._model_list_cache_v3) == 0 or refresh: 682 | visibilities = ["public", "private"] 683 | ret = [] 684 | for visibility in visibilities: 685 | offset = 0 686 | page_size = 100 687 | while True: 688 | res = self._get('/v3/model', params={"pagination.cursor": f"c_{offset}", "pagination.limit": page_size, "filter.visibility": visibility}) 689 | model_response: MoodelsResponseV3 = MoodelsResponseV3.from_dict(res) 690 | for i in range(len(model_response.models)): 691 | model_response.models[i].visibility = visibility 692 | ret.extend(model_response.models) 693 | if model_response.models is None or len(model_response.models) == 0 or len(model_response.models) <
page_size: 694 | break 695 | offset += page_size 696 | self._model_list_cache_v3 = ModelListV3(ret) 697 | return self._model_list_cache_v3 698 | -------------------------------------------------------------------------------- /src/novita_client/proto.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | 4 | import base64 5 | import os 6 | from dataclasses import dataclass, field 7 | from enum import Enum 8 | from typing import Any, Dict, List, Optional, Union 9 | 10 | from PIL import Image 11 | 12 | from .serializer import JSONe 13 | from .utils import batch_download_images 14 | 15 | # --------------- ControlNet --------------- 16 | 17 | class ControlNetMode(Enum): 18 | BALANCED = 0 19 | PROMPT_IMPORTANCE = 1 20 | CONTROLNET_IMPORTANCE = 2 21 | 22 | def __str__(self): 23 | return self.name 24 | 25 | 26 | class ControlNetResizeMode(Enum): 27 | JUST_RESIZE = 0 28 | RESIZE_OR_CORP = 1 29 | RESIZE_AND_FILL = 2 30 | 31 | def __str__(self): 32 | return self.name 33 | 34 | 35 | class ControlNetPreprocessor(Enum): 36 | NULL = 'none' 37 | CANNY = 'canny' 38 | DEPTH = 'depth' 39 | DEPTH_LERES = 'depth_leres' 40 | DEPTH_LERES_PLUS_PLUS = 'depth_leres++' 41 | HED = 'hed' 42 | HED_SAFE = 'hed_safe' 43 | MEDIAPIPE_FACE = 'mediapipe_face' 44 | MLSD = 'mlsd' 45 | NORMAL_MAP = 'normal_map' 46 | OPENPOSE = 'openpose' 47 | OPENPOSE_HAND = 'openpose_hand' 48 | OPENPOSE_FACE = 'openpose_face' 49 | OPENPOSE_FACEONLY = 'openpose_faceonly' 50 | OPENPOSE_FULL = 'openpose_full' 51 | CLIP_VISION = 'clip_vision' 52 | COLOR = 'color' 53 | PIDINET = 'pidinet' 54 | PIDINET_SAFE = 'pidinet_safe' 55 | PIDINET_SKETCH = 'pidinet_sketch' 56 | PIDINET_SCRIBBLE = 'pidinet_scribble' 57 | SCRIBBLE_XDOG = 'scribble_xdog' 58 | SCRIBBLE_HED = 'scribble_hed' 59 | SEGMENTATION = 'segmentation' 60 | THRESHOLD = 'threshold' 61 | DEPTH_ZOE = 'depth_zoe' 62 | NORMAL_BAE = 'normal_bae' 63 | ONEFORMER_COCO = 'oneformer_coco' 64 | ONEFORMER_ADE20K = 'oneformer_ade20k' 65 | LINEART = 'lineart' 66 | LINEART_COARSE = 'lineart_coarse' 67 | LINEART_ANIME = 'lineart_anime' 68 | LINEART_STANDARD = 'lineart_standard' 69 | SHUFFLE = 'shuffle' 70 | TILE_RESAMPLE = 'tile_resample' 71 | INVERT = 'invert' 72 | LINEART_ANIME_DENOISE = 'lineart_anime_denoise' 73 | REFERENCE_ONLY = 'reference_only' 74 | REFERENCE_ADAIN = 'reference_adain' 75 | REFERENCE_ADAIN_PLUS_ATTN = 'reference_adain+attn' 76 | INPAINT = 'inpaint' 77 | INPAINT_ONLY = 'inpaint_only' 78 | INPAINT_ONLY_PLUS_LAMA = 'inpaint_only+lama' 79 | TILE_COLORFIX = 'tile_colorfix' 80 | TILE_COLORFIX_PLUS_SHARP = 'tile_colorfix+sharp' 81 | 82 | def __str__(self): 83 | return self.name 84 | 85 | 86 | @dataclass 87 | class ControlnetUnit(JSONe): 88 | model: str 89 | weight: Optional[float] = 1 90 | module: Optional[ControlNetPreprocessor] = ControlNetPreprocessor.NULL 91 | input_image: Optional[str] = None 92 | control_mode: Optional[ControlNetMode] = ControlNetMode.BALANCED 93 | resize_mode: Optional[ControlNetResizeMode] = ControlNetResizeMode.RESIZE_OR_CORP 94 | mask: Optional[str] = None 95 | processor_res: Optional[int] = 512 96 | threshold_a: Optional[int] = 64 97 | threshold_b: Optional[int] = 64 98 | guidance_start: Optional[float] = 0.0 99 | guidance_end: Optional[float] = 1.0 100 | pixel_perfect: Optional[bool] = False 101 | 102 | 103 | # --------------- Samplers --------------- 104 | @dataclass 105 | class Samplers: 106 | EULER_A = 'Euler a' 107 | EULER = 'Euler' 108 | LMS = 'LMS' 109 | HEUN = 'Heun' 110 
| DPM2 = 'DPM2' 111 | DPM2_A = 'DPM2 a' 112 | DPM2_KARRAS = 'DPM2 Karras' 113 | DPM2_A_KARRAS = 'DPM2 a Karras' 114 | DPMPP_S_A = 'DPM++ 2S a' 115 | DPMPP_M = 'DPM++ 2M' 116 | DPMPP_SDE = 'DPM++ SDE' 117 | DPMPP_KARRAS = 'DPM++ Karras' 118 | DPMPP_S_A_KARRAS = 'DPM++ 2S a Karras' 119 | DPMPP_M_KARRAS = 'DPM++ 2M Karras' 120 | DPMPP_SDE_KARRAS = 'DPM++ SDE Karras' 121 | DDIM = 'DDIM' 122 | PLMS = 'PLMS' 123 | UNIPC = 'UniPC' 124 | 125 | 126 | # --------------- Refiner --------------- 127 | @dataclass 128 | class Refiner: 129 | checkpoint: str 130 | switch_at: float 131 | 132 | 133 | # --------------- ADEtailer --------------- 134 | 135 | 136 | @dataclass 137 | class ADEtailer: 138 | prompt: str 139 | negative_prompt: Optional[str] = None 140 | steps: Optional[int] = 20 141 | strength: Optional[float] = 0.5 142 | seed: Optional[int] = None 143 | 144 | 145 | 146 | 147 | 148 | # --------------- Text2Image --------------- 149 | 150 | @dataclass 151 | class Txt2ImgExtra(JSONe): 152 | enable_nsfw_detection: bool = False 153 | nsfw_detection_level: int = 0 154 | enable_progress_info: bool = True 155 | 156 | 157 | @dataclass 158 | class Txt2ImgRequest(JSONe): 159 | prompt: str 160 | negative_prompt: Optional[str] = None 161 | model_name: str = 'dreamshaper_5BakedVae.safetensors' 162 | sampler_name: str = None 163 | batch_size: int = 1 164 | n_iter: int = 1 165 | steps: int = 20 166 | cfg_scale: float = 7 167 | height: int = 512 168 | width: int = 512 169 | 170 | seed: Optional[int] = -1 171 | restore_faces: Optional[bool] = False 172 | sd_vae: Optional[str] = None 173 | clip_skip: Optional[int] = None 174 | 175 | controlnet_units: Optional[List[ControlnetUnit]] = None 176 | controlnet_no_detectmap: Optional[bool] = False 177 | 178 | enable_hr: Optional[bool] = False 179 | hr_upscaler: Optional[str] = 'R-ESRGAN 4x+' 180 | hr_scale: Optional[float] = 2.0 181 | hr_resize_x: Optional[int] = None 182 | hr_resize_y: Optional[int] = None 183 | 184 | sd_refiner: Optional[Refiner] = None 185 | 186 | adetailer: Optional[ADEtailer] = None 187 | 188 | extra: Optional[Txt2ImgExtra] = None 189 | 190 | 191 | class Txt2ImgResponseCode(Enum): 192 | NORMAL = 0 193 | INTERNAL_ERROR = -1 194 | INVALID_JSON = 1 195 | MODEL_NOT_EXISTS = 2 196 | TASK_ID_NOT_EXISTS = 3 197 | INVALID_AUTH = 4 198 | HOST_UNAVAILABLE = 5 199 | PARAM_RANGE_ERROR = 6 200 | COST_BALANCE_ERROR = 7 201 | SAMPLER_NOT_EXISTS = 8 202 | TIMEOUT = 9 203 | 204 | UNKNOWN = 100 205 | 206 | @classmethod 207 | def _missing_(cls, number): 208 | return cls(cls.UNKNOWN) 209 | 210 | 211 | @dataclass 212 | class Txt2ImgResponseData(JSONe): 213 | task_id: str 214 | warn: Optional[str] = None 215 | 216 | 217 | @dataclass 218 | class Txt2ImgResponse(JSONe): 219 | code: Txt2ImgResponseCode 220 | msg: str 221 | data: Optional[Txt2ImgResponseData] = None 222 | 223 | 224 | # --------------- Image2Image --------------- 225 | 226 | 227 | @dataclass 228 | class Img2ImgExtra(JSONe): 229 | enable_nsfw_detection: bool = False 230 | nsfw_detection_level: int = 0 231 | enable_progress_info: bool = True 232 | 233 | 234 | @dataclass 235 | class Img2ImgRequest(JSONe): 236 | model_name: str = 'dreamshaper_5BakedVae.safetensors' 237 | sampler_name: str = None 238 | init_images: List[str] = None 239 | mask: Optional[str] = None 240 | resize_mode: Optional[int] = 0 241 | denoising_strength: Optional[float] = 0.75 242 | cfg_scale: Optional[float] = None 243 | mask_blur: Optional[int] = 4 244 | inpainting_fill: Optional[int] = 1 245 | inpaint_full_res: Optional[int] = 0 246 | 
inpaint_full_res_padding: Optional[int] = 32 247 | inpainting_mask_invert: Optional[int] = 0 248 | initial_noise_multiplier: Optional[float] = 1.0 249 | prompt: Optional[str] = None 250 | seed: Optional[int] = None 251 | negative_prompt: Optional[str] = None 252 | batch_size: Optional[int] = 1 253 | n_iter: Optional[int] = 1 254 | steps: Optional[int] = 20 255 | width: Optional[int] = 1024 256 | height: Optional[int] = 1024 257 | restore_faces: Optional[bool] = False 258 | sd_vae: Optional[str] = None 259 | clip_skip: Optional[int] = None 260 | 261 | controlnet_units: Optional[List[ControlnetUnit]] = None 262 | controlnet_no_detectmap: Optional[bool] = False 263 | 264 | sd_refiner: Optional[Refiner] = None 265 | 266 | extra: Optional[Img2ImgExtra] = None 267 | 268 | 269 | class Img2ImgResponseCode(Enum): 270 | NORMAL = 0 271 | INTERNAL_ERROR = -1 272 | INVALID_JSON = 1 273 | MODEL_NOT_EXISTS = 2 274 | TASK_ID_NOT_EXISTS = 3 275 | INVALID_AUTH = 4 276 | HOST_UNAVAILABLE = 5 277 | PARAM_RANGE_ERROR = 6 278 | COST_BALANCE_ERROR = 7 279 | SAMPLER_NOT_EXISTS = 8 280 | TIMEOUT = 9 281 | 282 | UNKNOWN = 100 283 | 284 | @classmethod 285 | def _missing_(cls, number): 286 | return cls(cls.UNKNOWN) 287 | 288 | 289 | @dataclass 290 | class Img2ImgResponseData(JSONe): 291 | task_id: str 292 | warn: Optional[str] = None 293 | 294 | 295 | @dataclass 296 | class Img2ImgResponse(JSONe): 297 | code: Img2ImgResponseCode 298 | msg: str 299 | data: Optional[Img2ImgResponseData] = None 300 | 301 | # --------------- Progress --------------- 302 | 303 | 304 | class ProgressResponseStatusCode(Enum): 305 | INITIALIZING = 0 306 | RUNNING = 1 307 | SUCCESSFUL = 2 308 | FAILED = 3 309 | TIMEOUT = 4 310 | 311 | UNKNOWN = 100 312 | 313 | @classmethod 314 | def _missing_(cls, number): 315 | return cls(cls.UNKNOWN) 316 | 317 | def finished(self): 318 | return self in (ProgressResponseStatusCode.SUCCESSFUL, ProgressResponseStatusCode.FAILED, ProgressResponseStatusCode.TIMEOUT) 319 | 320 | 321 | @dataclass 322 | class ProgressDataDebugInfo(JSONe): 323 | submit_time_ms: int 324 | execution_time_ms: int 325 | txt2img_time_ms: int 326 | finish_time_ms: int 327 | 328 | 329 | @dataclass 330 | class ProgressDataNSFWResult(JSONe): 331 | valid: bool = False 332 | confidence: float = 0.0 333 | 334 | 335 | @dataclass 336 | class ProgressData(JSONe): 337 | status: ProgressResponseStatusCode 338 | progress: int 339 | eta_relative: int 340 | imgs: Optional[List[str]] = None 341 | imgs_bytes: Optional[List[str]] = None 342 | info: Optional[str] = "" 343 | failed_reason: Optional[str] = "" 344 | current_images: Optional[List[str]] = None 345 | submit_time: Optional[str] = "" 346 | execution_time: Optional[str] = "" 347 | txt2img_time: Optional[str] = "" 348 | finish_time: Optional[str] = "" 349 | 350 | enable_nsfw_detection: Optional[bool] = False 351 | nsfw_detection_result: Optional[Union[List[ProgressDataNSFWResult], None]] = None 352 | debug_info: Optional[ProgressDataDebugInfo] = None 353 | 354 | 355 | class ProgressResponseCode(Enum): 356 | NORMAL = 0 357 | INTERNAL_ERROR = -1 358 | INVALID_JSON = 1 359 | MODEL_NOT_EXISTS = 2 360 | TASK_ID_NOT_EXISTS = 3 361 | INVALID_AUTH = 4 362 | HOST_UNAVAILABLE = 5 363 | PARAM_RANGE_ERROR = 6 364 | COST_BALANCE_ERROR = 7 365 | SAMPLER_NOT_EXISTS = 8 366 | TIMEOUT = 9 367 | 368 | UNKNOWN = 100 369 | 370 | @classmethod 371 | def _missing_(cls, number): 372 | return cls(cls.UNKNOWN) 373 | 374 | 375 | @dataclass 376 | class ProgressResponse(JSONe): 377 | code: ProgressResponseCode 378 | data: 
Optional[ProgressData] = None 379 | msg: Optional[str] = "" 380 | 381 | def download_images(self): 382 | if self.data.imgs is not None and len(self.data.imgs) > 0: 383 | self.data.imgs_bytes = batch_download_images(self.data.imgs) 384 | 385 | # --------------- Upscale --------------- 386 | 387 | 388 | class UpscaleResizeMode(Enum): 389 | SCALE = 0 390 | SIZE = 1 391 | 392 | 393 | @dataclass 394 | class UpscaleRequest(JSONe): 395 | image: str 396 | upscaler_1: Optional[str] = 'R-ESRGAN 4x+' 397 | resize_mode: Optional[UpscaleResizeMode] = UpscaleResizeMode.SCALE 398 | upscaling_resize: Optional[float] = 2.0 399 | upscaling_resize_w: Optional[int] = None 400 | upscaling_resize_h: Optional[int] = None 401 | upscaling_crop: Optional[bool] = False 402 | 403 | upscaler_2: Optional[str] = None 404 | extras_upscaler_2_visibility: Optional[float] = None 405 | gfpgan_visibility: Optional[float] = None 406 | codeformer_visibility: Optional[float] = None 407 | codeformer_weight: Optional[float] = None 408 | 409 | 410 | class UpscaleResponseCode(Enum): 411 | NORMAL = 0 412 | INTERNAL_ERROR = -1 413 | INVALID_JSON = 1 414 | MODEL_NOT_EXISTS = 2 415 | TASK_ID_NOT_EXISTS = 3 416 | INVALID_AUTH = 4 417 | HOST_UNAVAILABLE = 5 418 | PARAM_RANGE_ERROR = 6 419 | COST_BALANCE_ERROR = 7 420 | SAMPLER_NOT_EXISTS = 8 421 | TIMEOUT = 9 422 | 423 | UNKNOWN = 100 424 | 425 | @classmethod 426 | def _missing_(cls, number): 427 | return cls(cls.UNKNOWN) 428 | 429 | 430 | @dataclass 431 | class UpscaleResponseData(JSONe): 432 | task_id: str 433 | warn: Optional[str] = None 434 | 435 | 436 | @dataclass 437 | class UpscaleResponse(JSONe): 438 | code: UpscaleResponseCode 439 | msg: str 440 | data: Optional[UpscaleResponseData] = None 441 | 442 | # --------------- Cleanup --------------- 443 | 444 | 445 | @dataclass 446 | class CleanupRequest(JSONe): 447 | image_file: str 448 | mask_file: str 449 | extra: Dict = field(default_factory=lambda: dict()) 450 | 451 | def set_image_type(self, image_type: str): 452 | self.extra['response_image_type'] = image_type 453 | 454 | def set_enterprise_plan(self, enterprise_plan: bool): 455 | self.extra.setdefault('enterprise_plan', {}) 456 | self.extra['enterprise_plan']['enabled'] = enterprise_plan 457 | 458 | 459 | @dataclass 460 | class CleanupResponse(JSONe): 461 | image_file: str 462 | image_type: str 463 | 464 | 465 | InputImage = Union[str, os.PathLike, Image.Image] 466 | 467 | # --------------- Remove Background --------------- 468 | 469 | 470 | @dataclass 471 | class RemoveBackgroundRequest(JSONe): 472 | image_file: str 473 | extra: Dict = field(default_factory=lambda: dict()) 474 | 475 | def set_image_type(self, image_type: str): 476 | self.extra['response_image_type'] = image_type 477 | 478 | def set_enterprise_plan(self, enterprise_plan: bool): 479 | self.extra.setdefault('enterprise_plan', {}) 480 | self.extra['enterprise_plan']['enabled'] = enterprise_plan 481 | 482 | 483 | @dataclass 484 | class RemoveBackgroundResponse(JSONe): 485 | image_file: str 486 | image_type: str 487 | 488 | # --------------- Remove Text --------------- 489 | 490 | 491 | @dataclass 492 | class RemoveTextRequest(JSONe): 493 | image_file: str 494 | extra: Dict = field(default_factory=lambda: dict()) 495 | 496 | def set_image_type(self, image_type: str): 497 | self.extra['response_image_type'] = image_type 498 | 499 | def set_enterprise_plan(self, enterprise_plan: bool): 500 | self.extra.setdefault('enterprise_plan', {}) 501 | self.extra['enterprise_plan']['enabled'] = enterprise_plan 502 | 503 | 
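# A sketch of the shared `extra` pattern used by the request types in this
# module (the payload shape follows from set_image_type/set_enterprise_plan
# above; the base64 value is a placeholder):
#
#   req = RemoveTextRequest(image_file="<base64>")
#   req.set_image_type("png")
#   req.set_enterprise_plan(False)
#   req.to_dict()
#   # -> {"image_file": "<base64>",
#   #     "extra": {"response_image_type": "png",
#   #               "enterprise_plan": {"enabled": False}}}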
@dataclass 504 | class RemoveTextResponse(JSONe): 505 | image_file: str 506 | image_type: str 507 | 508 | # --------------- Reimage --------------- 509 | 510 | 511 | @dataclass 512 | class ReimagineRequest(JSONe): 513 | image_file: str 514 | extra: Dict = field(default_factory=lambda: dict()) 515 | 516 | def set_image_type(self, image_type: str): 517 | self.extra['response_image_type'] = image_type 518 | 519 | def set_enterprise_plan(self, enterprise_plan: bool): 520 | self.extra.setdefault('enterprise_plan', {}) 521 | self.extra['enterprise_plan']['enabled'] = enterprise_plan 522 | 523 | 524 | @dataclass 525 | class ReimagineResponse(JSONe): 526 | image_file: str 527 | image_type: str 528 | 529 | # --------------- Doodle --------------- 530 | 531 | 532 | @dataclass 533 | class DoodleRequest(JSONe): 534 | image_file: str 535 | prompt: str 536 | similarity: float = None 537 | extra: Dict = field(default_factory=lambda: dict()) 538 | 539 | def set_image_type(self, image_type: str): 540 | self.extra['response_image_type'] = image_type 541 | 542 | def set_enterprise_plan(self, enterprise_plan: bool): 543 | self.extra.setdefault('enterprise_plan', {}) 544 | self.extra['enterprise_plan']['enabled'] = enterprise_plan 545 | 546 | @dataclass 547 | class DoodleResponse(JSONe): 548 | image_file: str 549 | image_type: str 550 | 551 | 552 | # --------------- Mix Pose --------------- 553 | 554 | @dataclass 555 | class MixPoseRequest(JSONe): 556 | image_file: str 557 | pose_image_file: str 558 | extra: Dict = field(default_factory=lambda: dict()) 559 | 560 | def set_image_type(self, image_type: str): 561 | self.extra['response_image_type'] = image_type 562 | 563 | def set_enterprise_plan(self, enterprise_plan: bool): 564 | self.extra.setdefault('enterprise_plan', {}) 565 | self.extra['enterprise_plan']['enabled'] = enterprise_plan 566 | 567 | 568 | @dataclass 569 | class MixPoseResponse(JSONe): 570 | image_file: str 571 | image_type: str 572 | 573 | 574 | # --------------- Replace Background --------------- 575 | 576 | @dataclass 577 | class ReplaceBackgroundRequest(JSONe): 578 | image_file: str 579 | prompt: str 580 | extra: Dict = field(default_factory=lambda: dict()) 581 | 582 | def set_image_type(self, image_type: str): 583 | self.extra['response_image_type'] = image_type 584 | 585 | def set_enterprise_plan(self, enterprise_plan: bool): 586 | self.extra.setdefault('enterprise_plan', {}) 587 | self.extra['enterprise_plan']['enabled'] = enterprise_plan 588 | 589 | @dataclass 590 | class ReplaceBackgroundResponse(JSONe): 591 | image_file: str 592 | image_type: str 593 | 594 | # --------------- Relight --------------- 595 | @dataclass 596 | class RelightRequest(JSONe): 597 | image_file: str 598 | prompt: str 599 | model_name: str 600 | lighting_preference: str 601 | steps: int 602 | sampler_name: str 603 | guidance_scale: float 604 | strength: float 605 | seed: int = -1 606 | background_image_file: Optional[str] = None 607 | negative_prompt: Optional[str] = None 608 | extra: Dict = field(default_factory=lambda: dict()) 609 | 610 | 611 | def set_image_type(self, image_type: str): 612 | self.extra['response_image_type'] = image_type 613 | 614 | def set_enterprise_plan(self, enterprise_plan: bool): 615 | self.extra.setdefault('enterprise_plan', {}) 616 | self.extra['enterprise_plan']['enabled'] = enterprise_plan 617 | 618 | @dataclass 619 | class RelightResponse(JSONe): 620 | image_file: str 621 | image_type: str 622 | 623 | 624 | 625 | # --------------- Replace Sky --------------- 626 | 627 | 628 | 
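# Sketch of the wire payload for the replace-sky request defined below (values
# are illustrative; `sky` appears to name a preset sky template on the API side,
# and "bluesky" here is only a hypothetical example):
#
#   ReplaceSkyRequest(image_file="<base64>", sky="bluesky").to_dict()
#   # -> {"image_file": "<base64>", "sky": "bluesky", "extra": {}}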
@dataclass 629 | class ReplaceSkyRequest(JSONe): 630 | image_file: str 631 | sky: str 632 | extra: Dict = field(default_factory=lambda: dict()) 633 | 634 | def set_image_type(self, image_type: str): 635 | self.extra['response_image_type'] = image_type 636 | 637 | def set_enterprise_plan(self, enterprise_plan: bool): 638 | self.extra.setdefault('enterprise_plan', {}) 639 | self.extra['enterprise_plan']['enabled'] = enterprise_plan 640 | 641 | 642 | @dataclass 643 | class ReplaceSkyResponse(JSONe): 644 | image_file: str 645 | image_type: str 646 | 647 | # --------------- Remove Watermark --------------- 648 | 649 | @dataclass 650 | class RemoveWatermarkRequest(JSONe): 651 | image_file: str 652 | extra: Dict = field(default_factory=lambda: dict()) 653 | 654 | def set_image_type(self, image_type: str): 655 | self.extra['response_image_type'] = image_type 656 | 657 | def set_enterprise_plan(self, enterprise_plan: bool): 658 | self.extra.setdefault('enterprise_plan', {}) 659 | self.extra['enterprise_plan']['enabled'] = enterprise_plan 660 | 661 | @dataclass 662 | class RemoveWatermarkResponse(JSONe): 663 | image_file: str 664 | image_type: str 665 | 666 | # --------------- Replace Object --------------- 667 | 668 | 669 | @dataclass 670 | class ReplaceObjectRequest(JSONe): 671 | image_file: str 672 | object_prompt: str 673 | prompt: str 674 | negative_prompt: Optional[str] = None 675 | extra: Dict = field(default_factory=lambda: dict()) 676 | 677 | def set_image_type(self, image_type: str): 678 | self.extra['response_image_type'] = image_type 679 | 680 | def set_enterprise_plan(self, enterprise_plan: bool): 681 | self.extra.setdefault('enterprise_plan', {}) 682 | self.extra['enterprise_plan']['enabled'] = enterprise_plan 683 | 684 | 685 | @dataclass 686 | class ReplaceObjectResponse(JSONe): 687 | image_file: str 688 | image_type: str 689 | 690 | 691 | # --------------- V3 Task Result --------------- 692 | # { 693 | # "task": { 694 | # "task_id": "a910c8f7-76ce-40bd-b805-f00f3ddd7dc1", 695 | # "status": "TASK_STATUS_SUCCEED" 696 | # }, 697 | # "images": [ 698 | # { 699 | # "image_url": "https://faas-output-image.s3.ap-southeast-1.amazonaws.com/dev/replace_object_a910c8f7-76ce-40bd-b805-f00f3ddd7dc1_0.png?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIASVPYCN6LRCW3SOUV%2F20231019%2Fap-southeast-1%2Fs3%2Faws4_request&X-Amz-Date=20231019T084537Z&X-Amz-Expires=3600&X-Amz-SignedHeaders=host&x-id=GetObject&X-Amz-Signature=b9ad40a5cb3aecf89602c15fe72d28be5d8a33e0bfe3656ce968295fde1aab31", 700 | # "image_type": "png", 701 | # "image_url_ttl": 3600 702 | # } 703 | # ] 704 | # } 705 | 706 | @dataclass 707 | class V3TaskImageNSFWDetectionResult(JSONe): 708 | valid: bool 709 | confidence: float 710 | 711 | 712 | @dataclass 713 | class V3TaskImage(JSONe): 714 | image_url: str 715 | image_type: str 716 | image_url_ttl: int 717 | nsfw_detection_result: Optional[V3TaskImageNSFWDetectionResult] = None 718 | 719 | 720 | @dataclass 721 | class V3TaskVideo(JSONe): 722 | video_url: str 723 | video_type: str 724 | video_url_ttl: int 725 | 726 | 727 | class V3TaskResponseStatus(Enum): 728 | TASK_STATUS_SUCCEED = "TASK_STATUS_SUCCEED" 729 | TASK_STATUS_PROCESSING = "TASK_STATUS_PROCESSING" 730 | TASK_STATUS_QUEUED = "TASK_STATUS_QUEUED" 731 | TASK_STATUS_FAILED = "TASK_STATUS_FAILED" 732 | 733 | 734 | @dataclass 735 | class V3AsyncSubmitResponse(JSONe): 736 | task_id: str 737 | 738 | 739 | @dataclass 740 | class V3TaskResponseTask(JSONe): 741 | task_id: str 742 | status: V3TaskResponseStatus 743 | reason: 
Optional[str] = None 744 | task_type: Optional[str] = None 745 | eta: Optional[int] = None 746 | progress_percent: Optional[int] = None 747 | 748 | 749 | @dataclass 750 | class V3TaskResponseDebugInfo(JSONe): 751 | submit_time_ms: int 752 | execute_time_ms: int 753 | complete_time_ms: int 754 | request_info: Optional[str] = None 755 | 756 | 757 | @dataclass 758 | class V3TaskResponseExtra(JSONe): 759 | seed: Optional[int] = None 760 | enable_nsfw_detection: Optional[bool] = False 761 | debug_info: Optional[V3TaskResponseDebugInfo] = None 762 | 763 | 764 | @dataclass 765 | class V3TaskResponse(JSONe): 766 | task: V3TaskResponseTask 767 | images: List[V3TaskImage] = None 768 | videos: List[V3TaskVideo] = None 769 | extra: V3TaskResponseExtra = None 770 | 771 | def finished(self): 772 | return self.task.status in (V3TaskResponseStatus.TASK_STATUS_SUCCEED, V3TaskResponseStatus.TASK_STATUS_FAILED) 773 | 774 | def get_image_urls(self): 775 | return [image.image_url for image in self.images] 776 | 777 | def get_video_urls(self): 778 | return [video.video_url for video in self.videos] 779 | 780 | def download_images(self): 781 | if self.images is not None and len(self.images) > 0: 782 | self.images_encoded = [base64.b64encode(img_bytes).decode('ascii') for img_bytes in batch_download_images(self.get_image_urls())]  # keep downloads as base64 strings 783 | 784 | def download_videos(self): 785 | if self.videos is not None and len(self.videos) > 0: 786 | self.video_bytes = batch_download_images(self.get_video_urls())  # the downloader fetches raw bytes, so it works for videos too 787 | 788 | # --------------- Restore Faces --------------- 789 | 790 | 791 | @dataclass 792 | class RestoreFaceRequest(JSONe): 793 | image_file: str 794 | fidelity: Optional[float] = 0.7 795 | extra: Dict = field(default_factory=lambda: dict()) 796 | 797 | def set_image_type(self, image_type: str): 798 | self.extra['response_image_type'] = image_type 799 | 800 | def set_enterprise_plan(self, enterprise_plan: bool): 801 | self.extra.setdefault('enterprise_plan', {}) 802 | self.extra['enterprise_plan']['enabled'] = enterprise_plan 803 | 804 | 805 | @dataclass 806 | class RestoreFaceResponse(JSONe): 807 | image_file: str 808 | image_type: str 809 | 810 | 811 | # --------------- Tile --------------- 812 | @dataclass 813 | class CreateTileRequest(JSONe): 814 | prompt: str 815 | negative_prompt: Optional[str] = None 816 | width: Optional[int] = 1024 817 | height: Optional[int] = 1024 818 | extra: Dict = field(default_factory=lambda: dict()) 819 | 820 | def set_image_type(self, image_type: str): 821 | self.extra['response_image_type'] = image_type 822 | 823 | def set_enterprise_plan(self, enterprise_plan: bool): 824 | self.extra.setdefault('enterprise_plan', {}) 825 | self.extra['enterprise_plan']['enabled'] = enterprise_plan 826 | 827 | 828 | @dataclass 829 | class CreateTileResponse(JSONe): 830 | image_file: str 831 | image_type: str 832 | 833 | # --------------- Image to Mask --------------- 834 | @dataclass 835 | class maskImage(JSONe): 836 | image_file: str 837 | image_type: str 838 | bbox: List[int] 839 | area: int 840 | 841 | 842 | 843 | @dataclass 844 | class Img2MaskRequest(JSONe): 845 | image_file: str 846 | extra: Dict = field(default_factory=lambda: dict()) 847 | 848 | def set_image_type(self, image_type: str): 849 | self.extra['response_image_type'] = image_type 850 | 851 | def set_enterprise_plan(self, enterprise_plan: bool): 852 | self.extra.setdefault('enterprise_plan', {}) 853 | self.extra['enterprise_plan']['enabled'] = enterprise_plan 854 | 855 | @dataclass 856 | class Img2MaskResponse(JSONe): 857 | mask: maskImage 
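# A minimal sketch of the Img2Mask round-trip above; to_dict()/from_dict()
# come from JSONWizard via JSONe, and `raw_json` stands for the parsed JSON
# dict returned by the API (an illustrative name, not a documented one):
#
#   req = Img2MaskRequest(image_file="<base64-encoded image>")
#   body = req.to_dict()
#   resp = Img2MaskResponse.from_dict(raw_json)
#   resp.mask.image_file  # base64 mask image; bbox and area describe the mask region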
858 | 859 | # --------------- Image to Prompt --------------- 860 | @dataclass 861 | class Img2PromptRequest(JSONe): 862 | image_file: str 863 | 864 | @dataclass 865 | class Img2PromptResponse(JSONe): 866 | prompt: str 867 | 868 | # --------------- Merge Face --------------- 869 | @dataclass 870 | class MergeFaceRequest(JSONe): 871 | image_file: str 872 | face_image_file: str 873 | extra: Dict = field(default_factory=lambda: dict()) 874 | 875 | def set_image_type(self, image_type: str): 876 | self.extra['response_image_type'] = image_type 877 | 878 | def set_enterprise_plan(self, enterprise_plan: bool): 879 | self.extra.setdefault('enterprise_plan', {}) 880 | self.extra['enterprise_plan']['enabled'] = enterprise_plan 881 | 882 | 883 | @dataclass 884 | class MergeFaceResponse(JSONe): 885 | image_file: str 886 | image_type: str 887 | 888 | # --------------- LCM Txt2Img --------------- 889 | 890 | 891 | @dataclass 892 | class LCMTxt2ImgRequest(JSONe): 893 | prompt: str 894 | height: Optional[int] = 512 895 | width: Optional[int] = 512 896 | image_num: Optional[int] = 4 897 | steps: Optional[int] = 4 898 | guidance_scale: Optional[float] = 7.5 899 | 900 | 901 | @dataclass 902 | class LCMTxt2ImgResponseImage(JSONe): 903 | image_file: str 904 | image_type: str 905 | 906 | 907 | @dataclass 908 | class LCMTxt2ImgResponse(JSONe): 909 | images: List[LCMTxt2ImgResponseImage] 910 | 911 | # --------------- ADEtailer --------------- 912 | 913 | 914 | @dataclass 915 | class ADETailerLoRA(JSONe): 916 | model_name: str 917 | strength: Optional[float] = 1.0 918 | 919 | 920 | @dataclass 921 | class ADETailerEmbedding(JSONe): 922 | model_name: str 923 | 924 | 925 | @dataclass 926 | class ADETailerRequest(JSONe): 927 | model_name: str 928 | prompt: str 929 | image_assets_ids: List[str] = None 930 | image_urls: List[str] = None 931 | loras: List[ADETailerLoRA] = None 932 | embeddings: List[ADETailerEmbedding] = None 933 | guidance_scale: Optional[float] = 7.5 934 | sampler_name: Optional[str] = Samplers.DPMPP_KARRAS 935 | steps: Optional[int] = 20 936 | strength: Optional[float] = 0.3 937 | negative_prompt: Optional[str] = None 938 | sd_vae: Optional[str] = None 939 | seed: Optional[int] = None 940 | clip_skip: Optional[int] = None 941 | 942 | 943 | @dataclass 944 | class ADETailerResponse(JSONe): 945 | task_id: str 946 | 947 | # --------------- Training --------------- 948 | 949 | 950 | @dataclass 951 | class UploadAssetRequest(JSONe): 952 | file_extension: str = "png" 953 | 954 | 955 | @dataclass 956 | class UploadAssetResponse(JSONe): 957 | upload_url: str 958 | method: str 959 | assets_id: str 960 | 961 | 962 | @dataclass 963 | class TrainingImageDatasetItem(JSONe): 964 | assets_id: str 965 | 966 | 967 | @dataclass 968 | class TrainingExpertSetting(JSONe): 969 | instance_prompt: str = None 970 | class_prompt: str = None 971 | max_train_steps: int = None 972 | learning_rate: str = None 973 | seed: int = None 974 | lr_scheduler: str = None 975 | with_prior_preservation: bool = None 976 | prior_loss_weight: float = None 977 | lora_r: int = None 978 | lora_alpha: int = None 979 | lora_text_encoder_r: int = None 980 | lora_text_encoder_alpha: int = None 981 | 982 | 983 | @dataclass 984 | class TrainingComponent(JSONe): 985 | name: str 986 | args: List[Dict[str, Any]] 987 | 988 | 989 | FACE_TRAINING_DEFAULT_COMPONENTS = [ 990 | TrainingComponent( 991 | name="face_crop_region", 992 | args=[{ 993 | "name": "ratio", 994 | "value": "1.4" 995 | }] 996 | ), 997 | TrainingComponent( 998 | name="resize", 999 | args=[ 
1000 | { 1001 | "name": "height", 1002 | "value": "512", 1003 | }, 1004 | { 1005 | "name": "width", 1006 | "value": "512", 1007 | } 1008 | ] 1009 | ), 1010 | TrainingComponent( 1011 | name="face_restore", 1012 | args=[ 1013 | { 1014 | "name": "method", 1015 | "value": "gfpgan_1.4" 1016 | }, 1017 | { 1018 | "name": "upscale", 1019 | "value": "1.0" 1020 | } 1021 | ] 1022 | ), 1023 | ] 1024 | 1025 | 1026 | @dataclass 1027 | class CreateTrainingSubjectRequest(JSONe): 1028 | name: str 1029 | base_model: str 1030 | image_dataset_items: List[TrainingImageDatasetItem] 1031 | width: int = 512 1032 | height: int = 512 1033 | expert_setting: TrainingExpertSetting = None 1034 | components: List[TrainingComponent] = None 1035 | 1036 | 1037 | @dataclass 1038 | class CreateTrainingSubjectResponse(JSONe): 1039 | task_id: str 1040 | 1041 | 1042 | @dataclass 1043 | class QueryTrainingSubjectModel(JSONe): 1044 | model_name: str 1045 | model_status: str 1046 | 1047 | 1048 | @dataclass 1049 | class QueryTrainingSubjectStatusResponse(JSONe): 1050 | task_id: str 1051 | task_status: str 1052 | model_type: str 1053 | models: List[QueryTrainingSubjectModel] 1054 | 1055 | # --------------- Training Style --------------- 1056 | 1057 | 1058 | @dataclass 1059 | class TrainingStyleImageDatasetItem(JSONe): 1060 | assets_id: str 1061 | caption: str 1062 | 1063 | 1064 | @dataclass 1065 | class CreateTrainingStyleRequest(JSONe): 1066 | name: str 1067 | base_model: str 1068 | image_dataset_items: List[TrainingStyleImageDatasetItem] 1069 | width: int = 512 1070 | height: int = 512 1071 | expert_setting: TrainingExpertSetting = None 1072 | components: List[TrainingComponent] = None 1073 | 1074 | 1075 | @dataclass 1076 | class CreateTrainingStyleResponse(JSONe): 1077 | task_id: str 1078 | 1079 | 1080 | @dataclass 1081 | class TrainingTaskInfoModel(JSONe): 1082 | model_name: str 1083 | model_status: str 1084 | 1085 | 1086 | @dataclass 1087 | class TrainingTaskInfo(JSONe): 1088 | task_name: str 1089 | task_id: str 1090 | task_type: str 1091 | task_status: str 1092 | created_at: int 1093 | models: List[TrainingTaskInfoModel] 1094 | 1095 | 1096 | @dataclass 1097 | class TrainingTaskPagination(JSONe): 1098 | next_cursor: Optional[str] = None 1099 | 1100 | 1101 | @dataclass 1102 | class TrainingTaskListResponse(JSONe): 1103 | tasks: List[TrainingTaskInfo] = field(default_factory=lambda: []) 1104 | pagination: TrainingTaskPagination = None 1105 | 1106 | 1107 | class TrainingTaskList(list): 1108 | def __init__(self, *args, **kwargs): 1109 | super().__init__(*args, **kwargs) 1110 | 1111 | def get_by_task_name(self, task_name: str): 1112 | for task in self: 1113 | if task.task_name == task_name: 1114 | return task 1115 | return None 1116 | 1117 | def filter_by_task_type(self, task_type: str): 1118 | return TrainingTaskList([task for task in self if task.task_type == task_type]) 1119 | 1120 | def filter_by_task_status(self, task_status: str): 1121 | return TrainingTaskList([task for task in self if task.task_status == task_status]) 1122 | 1123 | def filter_by_model_status(self, model_status: str): 1124 | return TrainingTaskList([task for task in self if any(model.model_status == model_status for model in task.models)]) 1125 | 1126 | def sort_by_created_at(self): 1127 | return TrainingTaskList(sorted(self, key=lambda x: x.created_at, reverse=True)) 1128 | 1129 | # --------------- Text to Video --------------- 1130 | @dataclass 1131 | class Txt2VideoLoRA(JSONe): 1132 | model_name: str 1133 | strength: float = 1.0 1134 | 1135 | @dataclass 
1136 | class Txt2VideoEmbedding(JSONe): 1137 | model_name: str 1138 | 1139 | @dataclass 1140 | class Txt2VideoPrompt(JSONe): 1141 | prompt: str 1142 | frames: int 1143 | 1144 | @dataclass 1145 | class Txt2VideoRequest(JSONe): 1146 | model_name: str 1147 | prompts: List[Txt2VideoPrompt] 1148 | height: int 1149 | width: int 1150 | steps: int 1151 | guidance_scale: float 1152 | negative_prompt: Optional[str] = None 1153 | loras: List[Txt2VideoLoRA] = None 1154 | embeddings: List[Txt2VideoEmbedding] = None 1155 | clip_skip: int = None 1156 | extra: Dict = field(default_factory=lambda: dict()) 1157 | 1158 | def set_video_type(self, video_type: str): 1159 | self.extra['response_video_type'] = video_type 1160 | 1161 | def set_enterprise_plan(self, enterprise_plan: bool): 1162 | self.extra.setdefault('enterprise_plan', {}) 1163 | self.extra['enterprise_plan']['enabled'] = enterprise_plan 1164 | 1165 | @dataclass 1166 | class Txt2VideoResponse(JSONe): 1167 | task_id: str 1168 | 1169 | 1170 | # --------------- Image to Video --------------- 1171 | 1172 | class Img2VideoResizeMode(Enum): 1173 | ORIGINAL_DIMENSION = "ORIGINAL_DIMENSION" 1174 | CROP_TO_ASPECT_RATIO = "CROP_TO_ASPECT_RATIO" 1175 | 1176 | 1177 | @dataclass 1178 | class Img2VideoRequest(JSONe): 1179 | model_name: str 1180 | image_file: str 1181 | steps: int 1182 | frames_num: int = 14 1183 | frames_per_second: int = 6 1184 | seed: Optional[int] = None 1185 | image_file_resize_mode: Optional[str] = Img2VideoResizeMode.CROP_TO_ASPECT_RATIO.value 1186 | motion_bucket_id: Optional[int] = 127 1187 | enable_frame_interpolation: Optional[bool] = False 1188 | cond_aug: Optional[float] = 0.02 1189 | extra: Dict = field(default_factory=lambda: dict()) 1190 | 1191 | def set_video_type(self, video_type: str): 1192 | self.extra['response_video_type'] = video_type 1193 | def set_enterprise_plan(self, enterprise_plan: bool): 1194 | self.extra.setdefault('enterprise_plan', {}) 1195 | self.extra['enterprise_plan']['enabled'] = enterprise_plan 1196 | 1197 | 1198 | @dataclass 1199 | class Img2VideoResponse(JSONe): 1200 | task_id: str 1201 | 1202 | # --------------- Image to Video Motion --------------- 1203 | @dataclass 1204 | class Img2VideoMotionRequest(JSONe): 1205 | image_assets_id: str 1206 | motion_video_assets_id: str 1207 | seed: Optional[int] = None 1208 | extra: Dict = field(default_factory=lambda: dict()) 1209 | 1210 | def set_video_type(self, video_type: str): 1211 | self.extra['response_video_type'] = video_type 1212 | 1213 | @dataclass 1214 | class Img2VideoMotionResponse(JSONe): 1215 | task_id: str 1216 | 1217 | # --------------- Animated Anyone --------------- 1218 | @dataclass 1219 | class AnimatedAnyoneRequest(JSONe): 1220 | image_assets_id: str 1221 | pose_video_assets_id: str 1222 | height: int 1223 | width: int 1224 | steps: int 1225 | seed: Optional[int] = None 1226 | extra: Dict = field(default_factory=lambda: dict()) 1227 | 1228 | def set_video_type(self, video_type: str): 1229 | self.extra['response_video_type'] = video_type 1230 | 1231 | @dataclass 1232 | class AnimatedAnyoneResponse(JSONe): 1233 | task_id: str 1234 | 1235 | 1236 | # --------------- LCM Image to Image --------------- 1237 | 1238 | 1239 | @dataclass 1240 | class LCMLoRA(JSONe): 1241 | model_name: str 1242 | strength: Optional[float] = 1.0 1243 | 1244 | 1245 | @dataclass 1246 | class LCMEmbedding(JSONe): 1247 | model_name: str 1248 | 1249 | 1250 | @dataclass 1251 | class LCMImg2ImgRequest(JSONe): 1252 | model_name: str 1253 | input_image: str 1254 | prompt: str 1255 | negative_prompt: 
Optional[str] = None 1256 | sd_vae: Optional[str] = None 1257 | loras: Optional[List[LCMLoRA]] = None 1258 | embeddings: Optional[List[LCMEmbedding]] = None 1259 | seed: Optional[int] = None 1260 | image_num: Optional[int] = 1 1261 | steps: Optional[int] = 8 1262 | clip_skip: Optional[int] = None 1263 | guidance_scale: Optional[float] = 0 1264 | 1265 | 1266 | @dataclass 1267 | class LCMImg2ImgResponseImage(JSONe): 1268 | image_url: str 1269 | image_type: str 1270 | image_url_ttl: int 1271 | 1272 | 1273 | @dataclass 1274 | class LCMImg2ImgResponse(JSONe): 1275 | images: List[LCMImg2ImgResponseImage] 1276 | 1277 | 1278 | # --------------- Make Photo --------------- 1279 | 1280 | @dataclass 1281 | class MakePhotoLoRA(JSONe): 1282 | model_name: str 1283 | strength: Optional[float] = 1.0 1284 | 1285 | 1286 | @dataclass 1287 | class MakePhotoRequest(JSONe): 1288 | image_assets_ids: List[str] 1289 | model_name: str 1290 | prompt: str 1291 | negative_prompt: Optional[str] = None 1292 | loras: List[MakePhotoLoRA] = None 1293 | height: Optional[int] = 1024 1294 | width: Optional[int] = 1024 1295 | image_num: Optional[int] = 1 1296 | steps: Optional[int] = 50 1297 | seed: Optional[int] = None 1298 | guidance_scale: Optional[float] = 7.5 1299 | sampler_name: Optional[str] = Samplers.EULER_A 1300 | strength: Optional[float] = 0.25 1301 | crop_face: Optional[bool] = True 1302 | extra: Dict = field(default_factory=lambda: dict()) 1303 | 1304 | def set_image_type(self, image_type: str): 1305 | self.extra['response_image_type'] = image_type 1306 | 1307 | def set_enterprise_plan(self, enterprise_plan: bool): 1308 | self.extra.setdefault('enterprise_plan', {}) 1309 | self.extra['enterprise_plan']['enabled'] = enterprise_plan 1310 | 1311 | 1312 | @dataclass 1313 | class MakePhotoResponse(JSONe): 1314 | task_id: str 1315 | 1316 | 1317 | @dataclass 1318 | class InstantIDControlnetUnit(JSONe): 1319 | model_name: str 1320 | strength: Optional[float] 1321 | preprocessor: Optional[ControlNetPreprocessor] 1322 | 1323 | 1324 | InstantIDLora = MakePhotoLoRA 1325 | 1326 | 1327 | @dataclass 1328 | class InstantIDRequestControlNet(JSONe): 1329 | units: List[InstantIDControlnetUnit] 1330 | 1331 | 1332 | @dataclass 1333 | class InstantIDRequest(JSONe): 1334 | face_image_assets_ids: List[str] 1335 | ref_image_assets_ids: List[str] 1336 | model_name: str = None 1337 | prompt: str = None 1338 | negative_prompt: str = None 1339 | width: int = None 1340 | height: int = None 1341 | id_strength: float = 1. 1342 | adapter_strength: float = 1. 1343 | steps: int = 20 1344 | seed: int = -1 1345 | image_num: int = 1 1346 | guidance_scale: float = 5. 
1347 | sampler_name: str = 'Euler' 1348 | controlnet: InstantIDRequestControlNet = None 1349 | loras: List[InstantIDLora] = None 1350 | extra: Dict = field(default_factory=lambda: dict()) 1351 | 1352 | def set_image_type(self, image_type: str): 1353 | self.extra['response_image_type'] = image_type 1354 | 1355 | def set_enterprise_plan(self, enterprise_plan: bool): 1356 | self.extra.setdefault('enterprise_plan', {}) 1357 | self.extra['enterprise_plan']['enabled'] = enterprise_plan 1358 | 1359 | # --------------- Common V3 --------------- 1360 | 1361 | @dataclass 1362 | class CommonV3Request(JSONe): 1363 | extra: Dict[str, Any] = field(default_factory=lambda: dict()) 1364 | request: Any = field(default_factory=lambda: dict()) 1365 | 1366 | 1367 | @dataclass 1368 | class CommonV3Extra(JSONe): 1369 | response_image_type: str = "jpeg" 1370 | enable_nsfw_detection: bool = False 1371 | nsfw_detection_level: int = 0 1372 | custom_storage: Dict[str, Any] = field(default_factory=lambda: dict()) 1373 | enterprise_plan: Dict[str, Any] = field(default_factory=lambda: dict()) 1374 | 1375 | # --------------- Img2ImgV3 --------------- 1376 | 1377 | 1378 | @dataclass 1379 | class Img2V3ImgLoRA(JSONe): 1380 | model_name: str 1381 | strength: Optional[float] = 1.0 1382 | 1383 | 1384 | @dataclass 1385 | class Img2ImgV3Embedding(JSONe): 1386 | model_name: str 1387 | 1388 | 1389 | @dataclass 1390 | class Img2ImgV3ControlNetUnit(JSONe): 1391 | model_name: str 1392 | image_base64: str 1393 | strength: Optional[float] = 1.0 1394 | preprocessor: Optional[ControlNetPreprocessor] = "canny" 1395 | guidance_start: Optional[float] = 0 1396 | guidance_end: Optional[float] = 1 1397 | 1398 | 1399 | @dataclass 1400 | class Img2ImgV3ControlNet(JSONe): 1401 | units: List[Img2ImgV3ControlNetUnit] 1402 | 1403 | 1404 | @dataclass 1405 | class Img2ImgV3Request(JSONe): 1406 | model_name: str 1407 | image_base64: str 1408 | prompt: str 1409 | width: Optional[int] = 512 1410 | height: Optional[int] = 512 1411 | negative_prompt: Optional[str] = None 1412 | sd_vae: Optional[str] = None 1413 | loras: Optional[List[Img2V3ImgLoRA]] = None 1414 | embeddings: Optional[List[Img2ImgV3Embedding]] = None 1415 | seed: Optional[int] = -1 1416 | image_num: Optional[int] = 1 1417 | steps: Optional[int] = 20 1418 | clip_skip: Optional[int] = None 1419 | guidance_scale: Optional[float] = 7.5 1420 | strength: Optional[float] = 0.5 1421 | sampler_name: Optional[str] = Samplers.EULER_A 1422 | extra: Dict = field(default_factory=lambda: dict()) 1423 | controlnet: Optional[Img2ImgV3ControlNet] = None 1424 | 1425 | def set_image_type(self, image_type: str): 1426 | self.extra['response_image_type'] = image_type 1427 | 1428 | def set_enterprise_plan(self, enterprise_plan: bool): 1429 | self.extra.setdefault('enterprise_plan', {}) 1430 | self.extra['enterprise_plan']['enabled'] = enterprise_plan 1431 | 1432 | 1433 | @dataclass 1434 | class Img2ImgV3Response(JSONe): 1435 | task_id: str 1436 | 1437 | # --------------- Txt2ImgV3 --------------- 1438 | 1439 | 1440 | @dataclass 1441 | class Txt2ImgV3Embedding(JSONe): 1442 | model_name: str 1443 | 1444 | 1445 | @dataclass 1446 | class Txt2ImgV3LoRA(JSONe): 1447 | model_name: str 1448 | strength: Optional[float] = 1.0 1449 | 1450 | 1451 | @dataclass 1452 | class Txt2ImgV3HiresFix(JSONe): 1453 | target_width: int 1454 | target_height: int 1455 | strength: float = 0.5 1456 | upscaler: str = "Latent" 1457 | 1458 | 1459 | @dataclass 1460 | class Txt2ImgV3Refiner(JSONe): 1461 | switch_at: float = 0.5 1462 | 1463 | 1464 
| @dataclass 1465 | class Txt2ImgV3Request(JSONe): 1466 | model_name: str 1467 | prompt: str 1468 | height: Optional[int] = 512 1469 | width: Optional[int] = 512 1470 | image_num: Optional[int] = 1 1471 | sd_vae: Optional[str] = None 1472 | steps: Optional[int] = 20 1473 | guidance_scale: Optional[float] = 7.5 1474 | sampler_name: Optional[str] = Samplers.EULER_A 1475 | seed: Optional[int] = None 1476 | negative_prompt: Optional[str] = None 1477 | loras: Optional[List[Txt2ImgV3LoRA]] = None 1478 | embeddings: Optional[List[Txt2ImgV3Embedding]] = None 1479 | refiner: Optional[Txt2ImgV3Refiner] = None 1480 | hires_fix: Optional[Txt2ImgV3HiresFix] = None 1481 | clip_skip: Optional[int] = None 1482 | 1483 | 1484 | @dataclass 1485 | class Txt2ImgV3Response(JSONe): 1486 | task_id: str 1487 | 1488 | 1489 | # --------------- Inpainting --------------- 1490 | @dataclass 1491 | class InpaintingLoRA(JSONe): 1492 | model_name: str 1493 | strength: Optional[float] = 1.0 1494 | 1495 | @dataclass 1496 | class InpaintingEmbedding(JSONe): 1497 | model_name: str 1498 | 1499 | 1500 | @dataclass 1501 | class InpaintingExtra(JSONe): 1502 | response_image_type: str = "png" 1503 | nsfw_detection_level: int = 0 1504 | custom_storage: Dict[str, Any] = field(default_factory=lambda: dict()) 1505 | enterprise_plan: Dict[str, Any] = field(default_factory=lambda: dict()) 1506 | enable_nsfw_detection: bool = False 1507 | 1508 | 1509 | 1510 | 1511 | @dataclass 1512 | class InpaintingRequest(JSONe): 1513 | model_name: str 1514 | image_base64: str 1515 | mask_image_base64: str 1516 | prompt: str 1517 | image_num: int 1518 | steps: int 1519 | guidance_scale: float 1520 | seed: int 1521 | sampler_name: str 1522 | negative_prompt: str = "" 1523 | mask_blur: int = None 1524 | sd_vae: str = "" 1525 | loras: Optional[List[InpaintingLoRA]] = None 1526 | embeddings: Optional[List[InpaintingEmbedding]] = None 1527 | clip_skip: int = 0 1528 | strength: float = 1.0 1529 | inpainting_full_res: bool = False 1530 | inpainting_full_res_padding: int = 8 1531 | inpainting_mask_invert: bool = False 1532 | initial_noise_multiplier: float = 0.5 1533 | 1534 | 1535 | 1536 | 1537 | 1538 | @dataclass 1539 | class InpaintingResponse(JSONe): 1540 | task_id: str 1541 | 1542 | 1543 | 1544 | 1545 | 1546 | # --------------- Model --------------- 1547 | 1548 | 1549 | class ModelType(Enum): 1550 | CHECKPOINT = "checkpoint" 1551 | LORA = "lora" 1552 | VAE = "vae" 1553 | CONTROLNET = "controlnet" 1554 | TEXT_INVERSION = "textualinversion" 1555 | UPSCALER = "upscaler" 1556 | 1557 | UNKNOWN = "unknown" 1558 | 1559 | @classmethod 1560 | def _missing_(cls, value): 1561 | return cls.UNKNOWN  # fall back for unrecognized model types 1562 | 1563 | 1564 | @dataclass 1565 | class CivitaiImageMeta(JSONe): 1566 | prompt: Optional[str] = None 1567 | negative_prompt: Optional[str] = None 1568 | sampler_name: Optional[str] = None 1569 | steps: Optional[int] = None 1570 | cfg_scale: Optional[int] = None 1571 | seed: Optional[int] = None 1572 | height: Optional[int] = None 1573 | width: Optional[int] = None 1574 | model_name: Optional[str] = None 1575 | 1576 | 1577 | @dataclass 1578 | class CivitaiImage(JSONe): 1579 | url: str 1580 | nsfw: str 1581 | meta: Optional[CivitaiImageMeta] = None 1582 | 1583 | 1584 | @dataclass 1585 | class ModelInfo(JSONe): 1586 | name: str 1587 | hash: str 1588 | civitai_version_id: int 1589 | sd_name: str 1590 | third_source: str 1591 | download_status: int 1592 | download_name: str 1593 | dependency_status: int 1594 | type: ModelType 1595 | civitai_nsfw: Optional[bool] = 
False 1596 | civitai_model_id: Optional[int] = 0 1597 | civitai_link: Optional[str] = None 1598 | civitai_images: Optional[List[CivitaiImage]] = field(default_factory=lambda: []) 1599 | civitai_download_url: Optional[str] = None 1600 | civitai_allow_commercial_use: Optional[bool] = True 1601 | civitai_allow_different_license: Optional[bool] = True 1602 | civitai_create_at: Optional[str] = None 1603 | civitai_update_at: Optional[str] = None 1604 | civitai_tags: Optional[str] = None 1605 | civitai_download_count: Optional[int] = 0 1606 | civitai_favorite_count: Optional[int] = 0 1607 | civitai_comment_count: Optional[int] = 0 1608 | civitai_rating_count: Optional[int] = 0 1609 | civitai_rating: Optional[float] = 0.0 1610 | novita_used_count: Optional[int] = None 1611 | civitai_image_url: Optional[str] = None 1612 | civitai_image_nsfw: Optional[bool] = False 1613 | civitai_origin_image_url: Optional[str] = None 1614 | civitai_image_prompt: Optional[str] = None 1615 | civitai_image_negative_prompt: Optional[str] = None 1616 | civitai_image_sampler_name: Optional[str] = None 1617 | civitai_image_height: Optional[int] = None 1618 | civitai_image_width: Optional[int] = None 1619 | civitai_image_steps: Optional[int] = None 1620 | civitai_image_cfg_scale: Optional[int] = None 1621 | civitai_image_seed: Optional[int] = None 1622 | 1623 | 1624 | @dataclass 1625 | class ModelData(JSONe): 1626 | models: List[ModelInfo] = None 1627 | 1628 | 1629 | @dataclass 1630 | class MoodelsResponse(JSONe): 1631 | code: int 1632 | msg: str 1633 | data: Optional[ModelData] = None 1634 | 1635 | 1636 | class ModelList(list): 1637 | """A list of ModelInfo""" 1638 | 1639 | def __init__(self, *args, **kwargs): 1640 | super().__init__(*args, **kwargs) 1641 | 1642 | def get_by_civitai_version_id(self, civitai_version_id: int): 1643 | for model in self: 1644 | if model.civitai_version_id == civitai_version_id: 1645 | return model 1646 | return None 1647 | 1648 | def get_by_name(self, name): 1649 | for model in self: 1650 | if model.name == name: 1651 | return model 1652 | return None 1653 | 1654 | def get_by_sd_name(self, sd_name): 1655 | for model in self: 1656 | if model.sd_name == sd_name: 1657 | return model 1658 | return None 1659 | 1660 | def list_civitai_tags(self) -> List[str]: 1661 | tags = set() 1662 | for model in self: 1663 | if model.civitai_tags: 1664 | tags.update(tag.strip() 1665 | for tag in model.civitai_tags.split(",") if tag.strip()) 1666 | return list(tags) 1667 | 1668 | def filter_by_civitai_tags(self, *tags): 1669 | ret = [] 1670 | for model in self: 1671 | if model.civitai_tags: 1672 | if set(tags).issubset(set(s.strip() for s in model.civitai_tags.split(","))): 1673 | ret.append(model) 1674 | return ModelList(ret) 1675 | 1676 | def filter_by_nsfw(self, nsfw: bool): 1677 | return ModelList([model for model in self if model.civitai_nsfw == nsfw]) 1678 | 1679 | def filter_by_type(self, type): 1680 | return ModelList([model for model in self if model.type == type]) 1681 | 1682 | def filter_by_civitai_model_id(self, civitai_model_id: int): 1683 | return ModelList([model for model in self if model.civitai_model_id == civitai_model_id]) 1684 | 1685 | def filter_by_civitai_model_name(self, name: str): 1686 | return ModelList([model for model in self if model.name == name]) 1687 | 1688 | def sort_by_civitai_download(self): 1689 | return ModelList(sorted(self, key=lambda x: x.civitai_download_count, reverse=True)) 1690 | 1691 | def sort_by_civitai_rating(self): 1692 | return 
ModelList(sorted(self, key=lambda x: x.civitai_rating, reverse=True)) 1693 | 1694 | def sort_by_civitai_favorite(self): 1695 | return ModelList(sorted(self, key=lambda x: x.civitai_favorite_count, reverse=True)) 1696 | 1697 | def sort_by_civitai_comment(self): 1698 | return ModelList(sorted(self, key=lambda x: x.civitai_comment_count, reverse=True)) 1699 | 1700 | 1701 | # --------------- Model V3 --------------- 1702 | 1703 | @dataclass 1704 | class ModelInfoTypeV3(JSONe): 1705 | name: str 1706 | display_name: str 1707 | 1708 | 1709 | class ModelInfoStatus(Enum): 1710 | UNAVAILABLE = 0 1711 | AVAILABLE = 1 1712 | 1713 | 1714 | @dataclass 1715 | class ModelInfoV3(JSONe): 1716 | id: int 1717 | name: str 1718 | sd_name: str 1719 | type: ModelInfoTypeV3 1720 | status: ModelInfoStatus 1721 | hash_sha256: Optional[str] = None 1722 | categories: Optional[List[str]] = None 1723 | download_url: Optional[str] = None 1724 | base_model: Optional[str] = None 1725 | source: Optional[str] = None 1726 | download_url_ttl: Optional[int] = None 1727 | sd_name_in_api: Optional[str] = None 1728 | is_nsfw: Optional[bool] = None 1729 | visibility: Optional[str] = None 1730 | cover_url: Optional[str] = None 1731 | 1732 | 1733 | class ModelListV3(list): 1734 | def __init__(self, *args, **kwargs): 1735 | super().__init__(*args, **kwargs) 1736 | 1737 | def get_by_name(self, name): 1738 | for model in self: 1739 | if model.name == name: 1740 | return model 1741 | return None 1742 | 1743 | def get_by_sd_name(self, sd_name): 1744 | for model in self: 1745 | if model.sd_name == sd_name: 1746 | return model 1747 | return None 1748 | 1749 | def filter_by_type(self, type): 1750 | return ModelListV3([model for model in self if model.type.name == type]) 1751 | 1752 | def filter_by_nsfw(self, nsfw: bool): 1753 | return ModelListV3([model for model in self if model.is_nsfw == nsfw]) 1754 | 1755 | def filter_by_status(self, status: ModelInfoStatus): 1756 | return ModelListV3([model for model in self if model.status == status]) 1757 | 1758 | def filter_by_source(self, source: str): 1759 | return ModelListV3([model for model in self if model.source == source]) 1760 | 1761 | def filter_by_visibility(self, visibility: str): 1762 | return ModelListV3([model for model in self if model.visibility == visibility]) 1763 | 1764 | def filter_by_available(self, available: bool): 1765 | return ModelListV3([model for model in self if (model.status == ModelInfoStatus.AVAILABLE) == available]) 1766 | 1767 | def sort_by_name(self): 1768 | return ModelListV3(sorted(self, key=lambda x: x.name)) 1769 | 1770 | 1771 | @dataclass 1772 | class ModelsPaginationV3(JSONe): 1773 | next_cursor: Optional[str] = None 1774 | 1775 | 1776 | @dataclass 1777 | class MoodelsResponseV3(JSONe): 1778 | models: List[ModelInfoV3] = None 1779 | pagination: ModelsPaginationV3 = None 1780 | 1781 | 1782 | 1783 | 1784 | # --------------- User Info --------------- 1785 | @dataclass 1786 | class UserInfoResponse(JSONe): 1787 | allow_features: List[str] = None 1788 | credit_balance: int = 0 1789 | free_trial: Dict[str, int] = field(default_factory=lambda: {}) 1790 | --------------------------------------------------------------------------------