├── .gitignore ├── requirements.txt ├── fiftyone.yml ├── local_t2i_models.py ├── .pre-commit-config.yaml ├── assets └── icon.svg ├── README.md └── __init__.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.cython 3 | *.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | openai>=1.1.0 2 | replicate 3 | diffusers>=0.24.0 -------------------------------------------------------------------------------- /fiftyone.yml: -------------------------------------------------------------------------------- 1 | fiftyone: 2 | version: ">=0.23.7" 3 | name: "@jacobmarks/text_to_image" 4 | version: "1.2.6" 5 | description: "Run Text-to-image models to add synthetic images directly to your dataset!" 6 | url: "https://github.com/jacobmarks/text-to-image" 7 | operators: 8 | - txt2img 9 | -------------------------------------------------------------------------------- /local_t2i_models.py: -------------------------------------------------------------------------------- 1 | from diffusers import DiffusionPipeline 2 | 3 | 4 | def get_cache(): 5 | g = globals() 6 | if "_local_t2i_models" not in g: 7 | g["_local_t2i_models"] = {} 8 | 9 | return g["_local_t2i_models"] 10 | 11 | 12 | def lcm( 13 | prompt, 14 | width, 15 | height, 16 | num_inference_steps, 17 | guide_scale, 18 | lcm_origin_steps, 19 | ): 20 | if "lcm" not in get_cache(): 21 | get_cache()["lcm"] = DiffusionPipeline.from_pretrained( 22 | "SimianLuo/LCM_Dreamshaper_v7" 23 | ) 24 | 25 | pipe = get_cache()["lcm"] 26 | images = pipe( 27 | prompt=prompt, 28 | num_inference_steps=num_inference_steps, 29 | guidance_scale=guide_scale, 30 | lcm_origin_steps=lcm_origin_steps, 31 | output_type="pil", 32 | width=width, 33 | height=height, 34 | ).images 35 | return images[0] 36 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/asottile/blacken-docs 3 | rev: v1.12.0 4 | hooks: 5 | - id: blacken-docs 6 | additional_dependencies: [black==21.12b0] 7 | args: ["-l 79"] 8 | exclude: index.umd.js 9 | - repo: https://github.com/ambv/black 10 | rev: 22.3.0 11 | hooks: 12 | - id: black 13 | language_version: python3 14 | args: ["-l 79"] 15 | exclude: index.umd.js 16 | - repo: local 17 | hooks: 18 | - id: pylint 19 | name: pylint 20 | language: system 21 | files: \.py$ 22 | entry: pylint 23 | args: ["--errors-only"] 24 | exclude: index.umd.js 25 | - repo: local 26 | hooks: 27 | - id: ipynb-strip 28 | name: ipynb-strip 29 | language: system 30 | files: \.ipynb$ 31 | entry: jupyter nbconvert --clear-output --ClearOutputPreprocessor.enabled=True 32 | args: ["--log-level=ERROR"] 33 | - repo: https://github.com/pre-commit/mirrors-prettier 34 | rev: v2.6.2 35 | hooks: 36 | - id: prettier 37 | exclude: index.umd.js 38 | language_version: system 39 | -------------------------------------------------------------------------------- /assets/icon.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Text-to-Image Plugin 2 | 3 | 
![ssd1b](https://github.com/jacobmarks/ai-art-gallery/assets/12500356/f5202d68-c5c1-44c7-b662-98d98e5c05aa) 4 | 5 | ### Updates 6 | 7 | - **2024-04-23**: Added support for Stable Diffusion 3 (Thanks [Dan Gural](https://github.com/danielgural/)) 8 | - **2023-12-19**: Added support for Kandinsky-2.2 and Playground V2 models 9 | - **2023-11-30**: Version 1.2.0 10 | - adds local model running via `diffusers` (>=0.24.0) 11 | - adds [calling from the Python SDK](#python-sdk)! 12 | - :warning: **BREAKING CHANGE**: the plugin and operator URIs have been changed from `ai_art_gallery` to `text_to_image`. If you have any saved pipelines that use the plugin, you will need to update the URIs. 13 | - **2023-11-08**: Version 1.1.0 adds support for DALLE-3 Model — upgrade to `openai>=1.1.0` to use 😄 14 | - **2023-10-30**: Added support for Segmind Stable Diffusion (SSD-1B) Model 15 | - **2023-10-23**: Added support for Latent Consistency Model 16 | - **2023-10-18**: Added support for SDXL, operator icon, and download location selection 17 | 18 | ### Plugin Overview 19 | 20 | This Python plugin allows you to generate images from text 21 | prompts and add them directly into your dataset. 22 | 23 | :warning: This plugin is only verified to work for local datasets. It may not 24 | work for remote datasets. 25 | 26 | ### Supported Models 27 | 28 | This version of the plugin supports the following models: 29 | 30 | - [DALL-E2](https://openai.com/dall-e-2) 31 | - [DALL-E3](https://openai.com/dall-e-3) 32 | - [Kandinsky-2.2](https://replicate.com/ai-forever/kandinsky-2.2) 33 | - [Latent Consistency Model](https://replicate.com/luosiallen/latent-consistency-model/) 34 | - [Playground V2](https://replicate.com/playgroundai/playground-v2-1024px-aesthetic) 35 | - [SDXL](https://replicate.com/stability-ai/sdxl) 36 | - [SDXL-Lightning](https://replicate.com/lucataco/sdxl-lightning-4step) 37 | - [Segmind Stable Diffusion (SSD-1B)](https://replicate.com/lucataco/ssd-1b/) 38 | - [Stable Diffusion](https://replicate.com/stability-ai/stable-diffusion) 39 | - [Stable Diffusion 3](https://stability.ai/news/stable-diffusion-3) 40 | - [VQGAN-CLIP](https://replicate.com/mehdidc/feed_forward_vqgan_clip) 41 | 42 | It is straightforward to add support for other models! 43 | 44 | ## Watch On YouTube 45 | 46 | [![Video Thumbnail](https://img.youtube.com/vi/qJNEyC_FqG0/0.jpg)](https://www.youtube.com/watch?v=qJNEyC_FqG0&list=PLuREAXoPgT0RZrUaT0UpX_HzwKkoB-S9j&index=2) 47 | 48 | ## Installation 49 | 50 | ```shell 51 | fiftyone plugins download https://github.com/jacobmarks/text-to-image 52 | ``` 53 | 54 | If you want to use Replicate models, you will 55 | need to `pip install replicate` and set the environment variable 56 | `REPLICATE_API_TOKEN` to your API token. 57 | 58 | If you want to use DALL-E2 or DALL-E3, you will need to `pip install openai` and set the 59 | environment variable `OPENAI_API_KEY` to your API key. 60 | 61 | To run the Latent Consistency Model locally with Hugging Face's diffusers library, 62 | you will need `diffusers>=0.24.0`. If needed, you can install it with 63 | `pip install "diffusers>=0.24.0"`. 64 | 65 | To run Stable Diffusion 3, you will need to set up a [Stability.ai](https://platform.stability.ai/) account to get an API key. Then set the environment variable 66 | `STABILITY_API_KEY` to this key.
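Since the plugin only surfaces models whose backend is actually configured, it can help to sanity-check your environment before launching the App. Here is a minimal sketch of such a check (the `available_backends` helper below is illustrative, not part of the plugin); it mirrors the availability logic in the plugin's `__init__.py`:

```python
import os
from importlib.util import find_spec


def available_backends():
    """Report which text-to-image backends this environment can use."""
    return {
        # Replicate-hosted models need the package and an API token
        "replicate": find_spec("replicate") is not None
        and "REPLICATE_API_TOKEN" in os.environ,
        # DALL-E2 / DALL-E3 need the openai package and an API key
        "openai": find_spec("openai") is not None
        and "OPENAI_API_KEY" in os.environ,
        # Stable Diffusion 3 only needs the Stability API key
        "stability": "STABILITY_API_KEY" in os.environ,
        # local Latent Consistency runs also require diffusers>=0.24.0
        # (the version itself is not checked in this sketch)
        "diffusers": find_spec("diffusers") is not None,
    }


print(available_backends())
```

Any backend that reports `False` here simply won't appear in the plugin's model dropdown.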
67 | 68 | Refer to the [main README](https://github.com/voxel51/fiftyone-plugins) for 69 | more information about managing downloaded plugins and developing plugins 70 | locally. 71 | 72 | ## Operators 73 | 74 | ### `txt2img` 75 | 76 | - Generates an image from a text prompt and adds it to the dataset 77 | 78 | ### Python SDK 79 | 80 | You can also use the `txt2img` operators from the Python SDK! 81 | 82 | ⚠️ If you're using `fiftyone<=0.23.6`, due to the way Jupyter Notebooks interact with asyncio, this will not work in a Jupyter Notebook. You will need to run this code in a Python script or in a Python console. 83 | 84 | ```python 85 | import fiftyone as fo 86 | import fiftyone.operators as foo 87 | import fiftyone.zoo as foz 88 | 89 | dataset = fo.load_dataset("quickstart") 90 | 91 | ## Access the operator via its URI (plugin name + operator name) 92 | t2i = foo.get_operator("@jacobmarks/text_to_image/txt2img") 93 | 94 | ## Run the operator 95 | 96 | prompt = "A dog sitting in a field" 97 | t2i(dataset, prompt=prompt, model_name="latent-consistency", delegate=False) 98 | 99 | ## Pass in model-specific arguments 100 | t2i( 101 | dataset, 102 | prompt=prompt, 103 | model_name="latent-consistency", 104 | delegate=False, 105 | width=768, 106 | height=768, 107 | num_inference_steps=8, 108 | ) 109 | ``` 110 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | """Text to Image plugin. 2 | 3 | | Copyright 2017-2023, Voxel51, Inc. 4 | | `voxel51.com `_ 5 | | 6 | """ 7 | 8 | from datetime import datetime 9 | import os 10 | import uuid 11 | from importlib.util import find_spec 12 | from importlib.metadata import version as mversion 13 | from packaging import version as pversion 14 | 15 | import fiftyone.operators as foo 16 | from fiftyone.operators import types 17 | import fiftyone as fo 18 | import fiftyone.core.utils as fou 19 | from fiftyone.core.utils import add_sys_path 20 | 21 | import requests 22 | 23 | openai = fou.lazy_import("openai") 24 | replicate = fou.lazy_import("replicate") 25 | 26 | SD_MODEL_URL = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478" 27 | SD_SCHEDULER_CHOICES = ( 28 | "DDIM", 29 | "K_EULER", 30 | "K_EULER_ANCESTRAL", 31 | "PNDM", 32 | "K-LMS", 33 | ) 34 | SD_SIZE_CHOICES = ( 35 | "128", 36 | "256", 37 | "384", 38 | "512", 39 | "576", 40 | "640", 41 | "704", 42 | "768", 43 | "832", 44 | "896", 45 | "960", 46 | "1024", 47 | ) 48 | 49 | SDXL_MODEL_URL = "stability-ai/sdxl:c221b2b8ef527988fb59bf24a8b97c4561f1c671f73bd389f866bfb27c061316" 50 | SDXL_SCHEDULER_CHOICES = ( 51 | "DDIM", 52 | "K_EULER", 53 | "K_EULER_ANCESTRAL", 54 | "PNDM", 55 | "DPMSolverMultistep", 56 | "KarrasDPM", 57 | "HeunDiscrete", 58 | ) 59 | 60 | SDXL_REFINE_CHOICES = ("None", "Expert Ensemble", "Base") 61 | 62 | SDXL_REFINE_MAP = { 63 | "None": "no_refiner", 64 | "Expert Ensemble": "expert_ensemble_refiner", 65 | "Base": "base_image_refiner", 66 | } 67 | 68 | SDXL_LIGHTNING_MODEL_URL = "lucataco/sdxl-lightning-4step:727e49a643e999d602a896c774a0658ffefea21465756a6ce24b7ea4165eba6a" 69 | SDXL_LIGHTNING_SIZE_CHOICES = ("1024", "1280") 70 | SDXL_LIGHTNING_NEGATIVE_PROMPT_DEFAULT = "worst quality, low quality" 71 | SDXL_LIGHTNING_SCHEDULER_CHOICES = ( 72 | "DDIM", 73 | "DPMSolverMultistep", 74 | "HeunDiscrete", 75 | "KarrasDPM", 76 | "K_EULER_ANCESTRAL", 77 | "K_EULER", 78 | "PNDM", 79 | "DPM+2MSDE", 80 | ) 81 | 
SDXL_LIGHTNING_NUM_INFERENCE_STEPS_DEFAULT = 4 82 | SDXL_LIGHTNING_GUIDANCE_SCALE_DEFAULT = 7.5 83 | 84 | 85 | SSD1B_MODEL_URL = "lucataco/ssd-1b:1ee85ef681d5ad3d6870b9da1a4543cb3ad702d036fa5b5210f133b83b05a780" 86 | SSD1B_SCHEDULER_CHOICES = ( 87 | "DDIM", 88 | "DDPMSolverMultistep", 89 | "HeunDiscrete", 90 | "KarrasDPM", 91 | "K_EULER_ANCESTRAL", 92 | "K_EULER", 93 | "PNDM", 94 | ) 95 | 96 | PLAYGROUND_V2_MODEL_URL = "playgroundai/playground-v2-1024px-aesthetic:42fe626e41cc811eaf02c94b892774839268ce1994ea778eba97103fe1ef51b8" 97 | PLAYGROUND_V2_SCHEDULER_CHOICES = ( 98 | "DDIM", 99 | "DPMSolverMultistep", 100 | "HeunDiscrete", 101 | "K_EULER_ANCESTRAL", 102 | "K_EULER", 103 | "PNDM", 104 | ) 105 | 106 | LC_MODEL_URL = "luosiallen/latent-consistency-model:553803fd018b3cf875a8bc774c99da9b33f36647badfd88a6eec90d61c5f62fc" 107 | 108 | VQGAN_MODEL_URL = "mehdidc/feed_forward_vqgan_clip:28b5242dadb5503688e17738aaee48f5f7f5c0b6e56493d7cf55f74d02f144d8" 109 | 110 | KANDINSKY_MODEL_URL = "ai-forever/kandinsky-2.2:ea1addaab376f4dc227f5368bbd8eff901820fd1cc14ed8cad63b29249e9d463" 111 | KANDINSDKY_SIZE_CHOICES = ( 112 | "384", 113 | "512", 114 | "576", 115 | "640", 116 | "704", 117 | "768", 118 | "960", 119 | "1024", 120 | "1152", 121 | "1280", 122 | "1536", 123 | "1792", 124 | "2048", 125 | ) 126 | 127 | 128 | DALLE2_SIZE_CHOICES = ("256x256", "512x512", "1024x1024") 129 | 130 | DALLE3_SIZE_CHOICES = ("1024x1024", "1024x1792", "1792x1024") 131 | 132 | DALLE3_QUALITY_CHOICES = ("standard", "hd") 133 | 134 | 135 | 136 | SD3_ASPECT_RATIO_CHOICES = ("1:1", "16:9", "21:9", "2:3", "3:2", "4:5", "5:4", "9:16", "9:21" ) 137 | 138 | SD3_MODEL_URL = "https://api.stability.ai/v2beta/stable-image/generate/sd3" 139 | 140 | 141 | def allows_replicate_models(): 142 | """Returns whether the current environment allows replicate models.""" 143 | return ( 144 | find_spec("replicate") is not None 145 | and "REPLICATE_API_TOKEN" in os.environ 146 | ) 147 | 148 | 149 | def allows_openai_models(): 150 | """Returns whether the current environment allows openai models.""" 151 | return find_spec("openai") is not None and "OPENAI_API_KEY" in os.environ 152 | 153 | def allows_stabilityai_models(): 154 | """Returns whether the current environment allows stabilityai models.""" 155 | return "STABILITY_API_KEY" in os.environ 156 | 157 | 158 | def allows_diffusers_models(): 159 | """Returns whether the current environment allows diffusers models.""" 160 | if find_spec("diffusers") is None: 161 | return False 162 | version = mversion("diffusers") 163 | return pversion.parse(version) >= pversion.parse("0.24.0") 164 | 165 | 166 | def download_image(image_url, filename): 167 | img_data = requests.get(image_url).content 168 | with open(filename, "wb") as handler: 169 | handler.write(img_data) 170 | 171 | def write_image(response, filename): 172 | if response.status_code == 200: 173 | with open(filename, 'wb') as file: 174 | file.write(response.content) 175 | else: 176 | raise Exception(str(response.json())) 177 | 178 | class Text2Image: 179 | """Wrapper for a Text2Image model.""" 180 | 181 | def __init__(self): 182 | self.name = None 183 | self.model_name = None 184 | 185 | def generate_image(self, ctx): 186 | pass 187 | 188 | 189 | class StableDiffusion(Text2Image): 190 | """Wrapper for a StableDiffusion model.""" 191 | 192 | def __init__(self): 193 | super().__init__() 194 | self.name = "stable-diffusion" 195 | self.model_name = SD_MODEL_URL 196 | 197 | def generate_image(self, ctx): 198 | prompt = ctx.params.get("prompt", "None 
provided") 199 | width = int(ctx.params.get("width_choices", "None provided")) 200 | height = int(ctx.params.get("height_choices", "None provided")) 201 | inference_steps = ctx.params.get("inference_steps", "None provided") 202 | scheduler = ctx.params.get("scheduler_choices", "None provided") 203 | 204 | response = replicate.run( 205 | self.model_name, 206 | input={ 207 | "prompt": prompt, 208 | "width": width, 209 | "height": height, 210 | "inference_steps": inference_steps, 211 | "scheduler": scheduler, 212 | }, 213 | ) 214 | if type(response) == list: 215 | response = response[0] 216 | return response 217 | 218 | 219 | class SDXL(Text2Image): 220 | """Wrapper for a StableDiffusion XL model.""" 221 | 222 | def __init__(self): 223 | super().__init__() 224 | self.name = "sdxl" 225 | self.model_name = SDXL_MODEL_URL 226 | 227 | def generate_image(self, ctx): 228 | prompt = ctx.params.get("prompt", "None provided") 229 | inference_steps = ctx.params.get("inference_steps", 50.0) 230 | scheduler = ctx.params.get("scheduler_choices", "None provided") 231 | guidance_scale = ctx.params.get("guidance_scale", 7.5) 232 | refiner = ctx.params.get("refine_choices", SDXL_REFINE_CHOICES[0]) 233 | refiner = SDXL_REFINE_MAP[refiner] 234 | refine_steps = ctx.params.get("refine_steps", None) 235 | negative_prompt = ctx.params.get("negative_prompt", None) 236 | high_noise_frac = ctx.params.get("high_noise_frac", None) 237 | 238 | _inputs = { 239 | "prompt": prompt, 240 | "inference_steps": inference_steps, 241 | "scheduler": scheduler, 242 | "refine": refiner, 243 | "guidance_scale": guidance_scale, 244 | } 245 | if negative_prompt is not None: 246 | _inputs["negative_prompt"] = negative_prompt 247 | if refine_steps is not None: 248 | _inputs["refine_steps"] = refine_steps 249 | if high_noise_frac is not None: 250 | _inputs["high_noise_frac"] = high_noise_frac 251 | 252 | response = replicate.run( 253 | self.model_name, 254 | input=_inputs, 255 | ) 256 | if type(response) == list: 257 | response = response[0] 258 | return response 259 | 260 | 261 | class SDXLLightning(Text2Image): 262 | """Wrapper for a StableDiffusion XL Lightning model.""" 263 | 264 | def __init__(self): 265 | super().__init__() 266 | self.name = "sdxl-lightning" 267 | self.model_name = SDXL_LIGHTNING_MODEL_URL 268 | 269 | def generate_image(self, ctx): 270 | prompt = ctx.params.get("prompt", "None provided") 271 | negative_prompt = ctx.params.get( 272 | "negative_prompt", SDXL_LIGHTNING_NEGATIVE_PROMPT_DEFAULT 273 | ) 274 | inference_steps = ctx.params.get( 275 | "inference_steps", SDXL_LIGHTNING_NUM_INFERENCE_STEPS_DEFAULT 276 | ) 277 | scheduler = ctx.params.get( 278 | "scheduler_choices", SDXL_LIGHTNING_SCHEDULER_CHOICES[0] 279 | ) 280 | guidance_scale = ctx.params.get( 281 | "guidance_scale", SDXL_LIGHTNING_GUIDANCE_SCALE_DEFAULT 282 | ) 283 | width = ctx.params.get("width", SDXL_LIGHTNING_SIZE_CHOICES[0]) 284 | height = ctx.params.get("height", SDXL_LIGHTNING_SIZE_CHOICES[0]) 285 | 286 | _inputs = { 287 | "prompt": prompt, 288 | "num_inference_steps": inference_steps, 289 | "scheduler": scheduler, 290 | "guidance_scale": guidance_scale, 291 | "negative_prompt": negative_prompt, 292 | "width": int(width), 293 | "height": int(height), 294 | } 295 | 296 | response = replicate.run( 297 | self.model_name, 298 | input=_inputs, 299 | ) 300 | if type(response) == list: 301 | response = response[0] 302 | return response 303 | 304 | class StableDiffusion3(Text2Image): 305 | """Wrapper for a StableDiffusion 3 model.""" 306 | 307 | def 
__init__(self): 308 | super().__init__() 309 | self.name = "stable-diffusion-3" 310 | self.model_name = SD3_MODEL_URL 311 | 312 | def generate_image(self, ctx): 313 | prompt = ctx.params.get("prompt", "None provided") 314 | aspect_ratio = ctx.params.get("aspect_ratio", SD3_ASPECT_RATIO_CHOICES[0]) 315 | seed = ctx.params.get("seed", 51) 316 | negative_prompt = ctx.params.get("negative_prompt", "") 317 | 318 | stability_key = os.environ["STABILITY_API_KEY"] 319 | 320 | response = requests.post( 321 | SD3_MODEL_URL, 322 | headers={ 323 | "authorization": f"Bearer {stability_key}", 324 | "accept": "image/*" 325 | }, 326 | files={"none": ''}, 327 | data={ 328 | "prompt": prompt, 329 | "aspect_ratio": aspect_ratio, 330 | "negative_prompt": negative_prompt, 331 | "seed": seed, 332 | "output_format": "jpeg", 333 | }, 334 | ) 335 | if type(response) == list: 336 | response = response[0] 337 | return response 338 | 339 | class SSD1B(Text2Image): 340 | """Wrapper for a SSD-1B model.""" 341 | 342 | def __init__(self): 343 | super().__init__() 344 | self.name = "ssd-1b" 345 | self.model_name = SSD1B_MODEL_URL 346 | 347 | def generate_image(self, ctx): 348 | prompt = ctx.params.get("prompt", "None provided") 349 | inference_steps = ctx.params.get("inference_steps", 25.0) 350 | scheduler = ctx.params.get("scheduler_choices", "None provided") 351 | guidance_scale = ctx.params.get("guidance_scale", 7.5) 352 | negative_prompt = ctx.params.get("negative_prompt", None) 353 | 354 | _inputs = { 355 | "prompt": prompt, 356 | "inference_steps": inference_steps, 357 | "scheduler": scheduler, 358 | "guidance_scale": guidance_scale, 359 | } 360 | if negative_prompt is not None: 361 | _inputs["negative_prompt"] = negative_prompt 362 | 363 | response = replicate.run( 364 | self.model_name, 365 | input=_inputs, 366 | ) 367 | if type(response) == list: 368 | response = response[0] 369 | return response 370 | 371 | 372 | class Kandinsky(Text2Image): 373 | """Wrapper for a Kandinsky model.""" 374 | 375 | def __init__(self): 376 | super().__init__() 377 | self.name = "kandinsky" 378 | self.model_name = KANDINSKY_MODEL_URL 379 | 380 | def generate_image(self, ctx): 381 | prompt = ctx.params.get("prompt", "None provided") 382 | negative_prompt = ctx.params.get("negative_prompt", None) 383 | width = int(ctx.params.get("width", 512)) 384 | height = int(ctx.params.get("height", 512)) 385 | inference_steps = ctx.params.get("inference_steps", 75) 386 | inference_steps_prior = ctx.params.get("inference_steps_prior", 25) 387 | 388 | input = { 389 | "prompt": prompt, 390 | "width": width, 391 | "height": height, 392 | "inference_steps": inference_steps, 393 | "inference_steps_prior": inference_steps_prior, 394 | } 395 | 396 | if negative_prompt is not None: 397 | input["negative_prompt"] = negative_prompt 398 | 399 | response = replicate.run( 400 | self.model_name, 401 | input=input, 402 | ) 403 | if type(response) == list: 404 | response = response[0] 405 | return response 406 | 407 | 408 | class PlaygroundV2(Text2Image): 409 | """Wrapper for Aesthetic Playground V2 model.""" 410 | 411 | def __init__(self): 412 | super().__init__() 413 | self.name = "playground-v2" 414 | self.model_name = PLAYGROUND_V2_MODEL_URL 415 | 416 | def generate_image(self, ctx): 417 | prompt = ctx.params.get("prompt", "None provided") 418 | negative_prompt = ctx.params.get("negative_prompt", None) 419 | width = int(ctx.params.get("width", 1024)) 420 | height = int(ctx.params.get("height", 1024)) 421 | inference_steps = ctx.params.get("inference_steps", 
50) 422 | guidance_scale = ctx.params.get("guide_scale", 3.0) 423 | scheduler = ctx.params.get("scheduler_choices", "None provided") 424 | 425 | input = { 426 | "prompt": prompt, 427 | "width": width, 428 | "height": height, 429 | "num_inference_steps": inference_steps, 430 | "guidance_scale": guidance_scale, 431 | "scheduler": scheduler, 432 | } 433 | 434 | if negative_prompt is not None: 435 | input["negative_prompt"] = negative_prompt 436 | 437 | response = replicate.run( 438 | self.model_name, 439 | input=input, 440 | ) 441 | if type(response) == list: 442 | response = response[0] 443 | return response 444 | 445 | 446 | class LatentConsistencyModel(Text2Image): 447 | """Wrapper for a Latent Consistency model.""" 448 | 449 | def __init__(self): 450 | super().__init__() 451 | self.name = "latent-consistency" 452 | self.model_name = LC_MODEL_URL 453 | 454 | def generate_image(self, ctx): 455 | prompt = ctx.params.get("prompt", "None provided") 456 | width = int(ctx.params.get("width", 512)) 457 | height = int(ctx.params.get("height", 512)) 458 | num_inf_steps = int(ctx.params.get("num_inference_steps", 4)) 459 | guide_scale = float(ctx.params.get("guidance_scale", 7.5)) 460 | lcm_origin_steps = int(ctx.params.get("lcm_origin_steps", 50)) 461 | distro = ctx.params.get("model_distribution", "None provided") 462 | 463 | if distro == "replicate": 464 | response = replicate.run( 465 | self.model_name, 466 | input={ 467 | "prompt": prompt, 468 | "width": width, 469 | "height": height, 470 | "num_inference_steps": num_inf_steps, 471 | "guidance_scale": guide_scale, 472 | "lcm_origin_steps": lcm_origin_steps, 473 | }, 474 | ) 475 | 476 | if type(response) == list: 477 | response = response[0] 478 | return response 479 | else: 480 | with add_sys_path(os.path.dirname(os.path.abspath(__file__))): 481 | # pylint: disable=no-name-in-module,import-error 482 | from local_t2i_models import lcm 483 | 484 | response = lcm( 485 | prompt, 486 | width, 487 | height, 488 | num_inf_steps, 489 | guide_scale, 490 | lcm_origin_steps, 491 | ) 492 | return response 493 | 494 | 495 | class DALLE2(Text2Image): 496 | """Wrapper for a DALL-E 2 model.""" 497 | 498 | def __init__(self): 499 | super().__init__() 500 | self.name = "dalle2" 501 | 502 | def generate_image(self, ctx): 503 | prompt = ctx.params.get("prompt", "None provided") 504 | size = ctx.params.get("size_choices", "None provided") 505 | 506 | response = openai.OpenAI().images.generate( 507 | model="dall-e-2", prompt=prompt, n=1, size=size 508 | ) 509 | return response.data[0].url 510 | 511 | 512 | class DALLE3(Text2Image): 513 | """Wrapper for a DALL-E 3 model.""" 514 | 515 | def __init__(self): 516 | super().__init__() 517 | self.name = "dalle3" 518 | 519 | def generate_image(self, ctx): 520 | prompt = ctx.params.get("prompt", "None provided") 521 | size = ctx.params.get("size_choices", "None provided") 522 | quality = ctx.params.get("quality_choices", "None provided") 523 | 524 | response = openai.OpenAI().images.generate( 525 | model="dall-e-3", prompt=prompt, n=1, quality=quality, size=size 526 | ) 527 | 528 | revised_prompt = response.data[0].revised_prompt 529 | ctx.params["revised_prompt"] = revised_prompt 530 | 531 | return response.data[0].url 532 | 533 | 534 | class VQGANCLIP(Text2Image): 535 | """Wrapper for a VQGAN-CLIP model.""" 536 | 537 | def __init__(self): 538 | super().__init__() 539 | self.name = "vqgan-clip" 540 | self.model_name = VQGAN_MODEL_URL 541 | 542 | def generate_image(self, ctx): 543 | prompt = ctx.params.get("prompt", "None 
provided") 544 | response = replicate.run(self.model_name, input={"prompt": prompt}) 545 | if type(response) == list: 546 | response = response[0] 547 | return response 548 | 549 | 550 | def get_model(model_name): 551 | mapping = { 552 | "sd": StableDiffusion, 553 | "sdxl": SDXL, 554 | "sdxl-lightning": SDXLLightning, 555 | "ssd-1b": SSD1B, 556 | "latent-consistency": LatentConsistencyModel, 557 | "kandinsky-2.2": Kandinsky, 558 | "playground-v2": PlaygroundV2, 559 | "dalle2": DALLE2, 560 | "dalle3": DALLE3, 561 | "vqgan-clip": VQGANCLIP, 562 | "stable-diffusion-3": StableDiffusion3, 563 | } 564 | return mapping[model_name]() 565 | 566 | 567 | def set_stable_diffusion_config(sample, ctx): 568 | sample["stable_diffusion_config"] = fo.DynamicEmbeddedDocument( 569 | inference_steps=ctx.params.get("inference_steps", "None provided"), 570 | scheduler=ctx.params.get("scheduler_choices", "None provided"), 571 | width=ctx.params.get("width_choices", "None provided"), 572 | height=ctx.params.get("height_choices", "None provided"), 573 | ) 574 | 575 | 576 | def set_sdxl_config(sample, ctx): 577 | sample["sdxl_config"] = fo.DynamicEmbeddedDocument( 578 | inference_steps=ctx.params.get("inference_steps", "None provided"), 579 | scheduler=ctx.params.get("scheduler_choices", "None provided"), 580 | guidance_scale=ctx.params.get("guidance_scale", 7.5), 581 | refiner=SDXL_REFINE_MAP[ 582 | ctx.params.get("refine_choices", "None provided") 583 | ], 584 | refine_steps=ctx.params.get("refine_steps", None), 585 | negative_prompt=ctx.params.get("negative_prompt", None), 586 | high_noise_frac=ctx.params.get("high_noise_frac", None), 587 | ) 588 | 589 | 590 | def set_sdxl_lightning_config(sample, ctx): 591 | sample["sdxl_lightning_config"] = fo.DynamicEmbeddedDocument( 592 | inference_steps=ctx.params.get( 593 | "inference_steps", SDXL_LIGHTNING_NUM_INFERENCE_STEPS_DEFAULT 594 | ), 595 | scheduler=ctx.params.get("scheduler_choices", "None provided"), 596 | guidance_scale=ctx.params.get( 597 | "guidance_scale", SDXL_LIGHTNING_GUIDANCE_SCALE_DEFAULT 598 | ), 599 | width=ctx.params.get("width", 1024), 600 | height=ctx.params.get("height", 1024), 601 | negative_prompt=ctx.params.get( 602 | "negative_prompt", SDXL_LIGHTNING_NEGATIVE_PROMPT_DEFAULT 603 | ), 604 | ) 605 | 606 | def set_sd3_config(sample, ctx): 607 | sample["sd3_config"] = fo.DynamicEmbeddedDocument( 608 | aspect_ratio = ctx.params.get("aspect_ratio", SD3_ASPECT_RATIO_CHOICES[0]), 609 | negative_prompt=ctx.params.get( 610 | "negative_prompt", "" 611 | ), 612 | seed=ctx.params.get( 613 | "seed", 51 614 | ), 615 | ) 616 | 617 | 618 | def set_ssd1b_config(sample, ctx): 619 | sample["ssd1b_config"] = fo.DynamicEmbeddedDocument( 620 | inference_steps=ctx.params.get("inference_steps", "None provided"), 621 | scheduler=ctx.params.get("scheduler_choices", "None provided"), 622 | guidance_scale=ctx.params.get("guidance_scale", 7.5), 623 | negative_prompt=ctx.params.get("negative_prompt", None), 624 | ) 625 | 626 | 627 | def set_kandinsky_config(sample, ctx): 628 | sample["kandinsky_config"] = fo.DynamicEmbeddedDocument( 629 | inference_steps=ctx.params.get("inference_steps", 75), 630 | inference_steps_prior=ctx.params.get("inference_steps_prior", 25), 631 | width=ctx.params.get("width", 512), 632 | height=ctx.params.get("height", 512), 633 | negative_prompt=ctx.params.get("negative_prompt", None), 634 | ) 635 | 636 | 637 | def set_playground_v2_config(sample, ctx): 638 | sample["playground_v2_config"] = fo.DynamicEmbeddedDocument( 639 | 
inference_steps=ctx.params.get("inference_steps", 50), 640 | width=ctx.params.get("width", 1024), 641 | height=ctx.params.get("height", 1024), 642 | guidance_scale=ctx.params.get("guidance_scale", 3.0), 643 | scheduler=ctx.params.get("scheduler_choices", "None provided"), 644 | negative_prompt=ctx.params.get("negative_prompt", None), 645 | ) 646 | 647 | 648 | def set_latent_consistency_config(sample, ctx): 649 | sample["latent_consistency_config"] = fo.DynamicEmbeddedDocument( 650 | inference_steps=ctx.params.get("num_inference_steps", 4), 651 | guidance_scale=ctx.params.get("guidance_scale", 7.5), 652 | lcm_origin_steps=ctx.params.get("lcm_origin_steps", 50), 653 | width=ctx.params.get("width", 512), 654 | height=ctx.params.get("height", 512), 655 | ) 656 | 657 | 658 | def set_vqgan_clip_config(sample, ctx): 659 | return 660 | 661 | 662 | def set_dalle2_config(sample, ctx): 663 | sample["dalle2_config"] = fo.DynamicEmbeddedDocument( 664 | size=ctx.params.get("size_choices", "None provided") 665 | ) 666 | 667 | 668 | def set_dalle3_config(sample, ctx): 669 | sample["dalle3_config"] = fo.DynamicEmbeddedDocument( 670 | size=ctx.params.get("size_choices", "None provided"), 671 | quality=ctx.params.get("quality_choices", "None provided"), 672 | revised_prompt=ctx.params.get("revised_prompt", "None provided"), 673 | ) 674 | 675 | 676 | def set_config(sample, ctx, model_name): 677 | mapping = { 678 | "sd": set_stable_diffusion_config, 679 | "sdxl": set_sdxl_config, 680 | "sdxl-lightning": set_sdxl_lightning_config, 681 | "ssd-1b": set_ssd1b_config, 682 | "latent-consistency": set_latent_consistency_config, 683 | "kandinsky-2.2": set_kandinsky_config, 684 | "playground-v2": set_playground_v2_config, 685 | "dalle2": set_dalle2_config, 686 | "dalle3": set_dalle3_config, 687 | "vqgan-clip": set_vqgan_clip_config, 688 | "stable-diffusion-3": set_sd3_config, 689 | } 690 | 691 | config_setter = mapping[model_name] 692 | config_setter(sample, ctx) 693 | 694 | 695 | def generate_filepath(ctx): 696 | download_dir = ctx.params.get("download_dir", {}) 697 | if type(download_dir) == dict: 698 | download_dir = download_dir.get("absolute_path", "/tmp") 699 | 700 | filename = str(uuid.uuid4())[:13].replace("-", "") + ".png" 701 | return os.path.join(download_dir, filename) 702 | 703 | 704 | #### MODEL CHOICES #### 705 | def _add_replicate_choices(model_choices): 706 | model_choices.add_choice("sd", label="Stable Diffusion") 707 | model_choices.add_choice("sdxl", label="SDXL") 708 | model_choices.add_choice("sdxl-lightning", label="SDXL Lightning") 709 | model_choices.add_choice("ssd-1b", label="SSD-1B") 710 | model_choices.add_choice("kandinsky-2.2", label="Kandinsky 2.2") 711 | model_choices.add_choice("playground-v2", label="Playground V2") 712 | if "latent-consistency" not in model_choices.values(): 713 | model_choices.add_choice( 714 | "latent-consistency", label="Latent Consistency" 715 | ) 716 | model_choices.add_choice("vqgan-clip", label="VQGAN-CLIP") 717 | 718 | 719 | def _add_openai_choices(model_choices): 720 | model_choices.add_choice("dalle2", label="DALL-E2") 721 | model_choices.add_choice("dalle3", label="DALL-E3") 722 | 723 | def _add_stability_choices(model_choices): 724 | model_choices.add_choice("stable-diffusion-3", label="Stable Diffusion 3") 725 | 726 | def _add_diffusers_choices(model_choices): 727 | if "latent-consistency" not in model_choices.values(): 728 | model_choices.add_choice( 729 | "latent-consistency", label="Latent Consistency" 730 | ) 731 | 732 | 733 | #### STABLE DIFFUSION 
INPUTS #### 734 | def _handle_stable_diffusion_input(ctx, inputs): 735 | size_choices = SD_SIZE_CHOICES 736 | width_choices = types.Dropdown(label="Width") 737 | for size in size_choices: 738 | width_choices.add_choice(size, label=size) 739 | 740 | inputs.enum( 741 | "width_choices", 742 | width_choices.values(), 743 | default="512", 744 | view=width_choices, 745 | ) 746 | 747 | height_choices = types.Dropdown(label="Height") 748 | for size in size_choices: 749 | height_choices.add_choice(size, label=size) 750 | 751 | inputs.enum( 752 | "height_choices", 753 | height_choices.values(), 754 | default="512", 755 | view=height_choices, 756 | ) 757 | 758 | inference_steps_slider = types.SliderView( 759 | label="Num Inference Steps", 760 | componentsProps={"slider": {"min": 1, "max": 500, "step": 1}}, 761 | ) 762 | inputs.int("inference_steps", default=50, view=inference_steps_slider) 763 | 764 | scheduler_choices_dropdown = types.Dropdown(label="Scheduler") 765 | for scheduler in SD_SCHEDULER_CHOICES: 766 | scheduler_choices_dropdown.add_choice(scheduler, label=scheduler) 767 | 768 | inputs.enum( 769 | "scheduler_choices", 770 | scheduler_choices_dropdown.values(), 771 | default="K_EULER", 772 | view=scheduler_choices_dropdown, 773 | ) 774 | 775 | 776 | #### SDXL INPUTS #### 777 | def _handle_sdxl_input(ctx, inputs): 778 | 779 | inputs.str("negative_prompt", label="Negative Prompt", required=False) 780 | 781 | scheduler_choices_dropdown = types.Dropdown(label="Scheduler") 782 | for scheduler in SDXL_SCHEDULER_CHOICES: 783 | scheduler_choices_dropdown.add_choice(scheduler, label=scheduler) 784 | 785 | inputs.enum( 786 | "scheduler_choices", 787 | scheduler_choices_dropdown.values(), 788 | default="K_EULER", 789 | view=scheduler_choices_dropdown, 790 | ) 791 | 792 | inference_steps_slider = types.SliderView( 793 | label="Num Inference Steps", 794 | componentsProps={"slider": {"min": 1, "max": 100, "step": 1}}, 795 | ) 796 | inputs.int("inference_steps", default=50, view=inference_steps_slider) 797 | 798 | guidance_scale_slider = types.SliderView( 799 | label="Guidance Scale", 800 | componentsProps={"slider": {"min": 0.0, "max": 10.0, "step": 0.1}}, 801 | ) 802 | inputs.float("guidance_scale", default=7.5, view=guidance_scale_slider) 803 | 804 | refiner_choices_dropdown = types.Dropdown( 805 | label="Refiner", 806 | description="Which refine style to use", 807 | ) 808 | for refiner in SDXL_REFINE_CHOICES: 809 | refiner_choices_dropdown.add_choice(refiner, label=refiner) 810 | 811 | inputs.enum( 812 | "refine_choices", 813 | refiner_choices_dropdown.values(), 814 | default="None", 815 | view=refiner_choices_dropdown, 816 | ) 817 | 818 | rfc = SDXL_REFINE_MAP[ctx.params.get("refine_choices", "None")] 819 | if rfc == "base_image_refiner": 820 | _default = ctx.params.get("inference_steps", 50) 821 | refine_steps_slider = types.SliderView( 822 | label="Num Refine Steps", 823 | componentsProps={"slider": {"min": 1, "max": _default, "step": 1}}, 824 | ) 825 | inputs.int( 826 | "refine_steps", 827 | label="Refine Steps", 828 | description="The number of steps to refine", 829 | default=_default, 830 | view=refine_steps_slider, 831 | ) 832 | elif rfc == "expert_ensemble_refiner": 833 | inputs.float( 834 | "high_noise_frac", 835 | label="High Noise Fraction", 836 | description="The fraction of noise to use", 837 | default=0.8, 838 | ) 839 | 840 | 841 | #### SDXL LIGHTNING INPUTS #### 842 | def _handle_sdxl_lightning_input(ctx, inputs): 843 | size_choices = SDXL_LIGHTNING_SIZE_CHOICES 844 | width_choices = 
types.Dropdown(label="Width") 845 | for size in size_choices: 846 | width_choices.add_choice(size, label=size) 847 | 848 | inputs.enum( 849 | "width", 850 | width_choices.values(), 851 | default=SDXL_LIGHTNING_SIZE_CHOICES[0], 852 | view=width_choices, 853 | ) 854 | 855 | height_choices = types.Dropdown(label="Height") 856 | for size in size_choices: 857 | height_choices.add_choice(size, label=size) 858 | 859 | inputs.enum( 860 | "height", 861 | height_choices.values(), 862 | default=SDXL_LIGHTNING_SIZE_CHOICES[0], 863 | view=height_choices, 864 | ) 865 | 866 | inputs.str( 867 | "negative_prompt", 868 | label="Negative Prompt", 869 | required=False, 870 | default=SDXL_LIGHTNING_NEGATIVE_PROMPT_DEFAULT, 871 | ) 872 | 873 | inference_steps_slider = types.SliderView( 874 | label="Num Inference Steps", 875 | componentsProps={"slider": {"min": 1, "max": 10, "step": 1}}, 876 | ) 877 | inputs.int( 878 | "inference_steps", 879 | default=SDXL_LIGHTNING_NUM_INFERENCE_STEPS_DEFAULT, 880 | view=inference_steps_slider, 881 | ) 882 | 883 | guidance_scale_slider = types.SliderView( 884 | label="Guidance Scale", 885 | componentsProps={"slider": {"min": 0.0, "max": 10.0, "step": 0.1}}, 886 | ) 887 | inputs.float( 888 | "guidance_scale", 889 | default=SDXL_LIGHTNING_GUIDANCE_SCALE_DEFAULT, 890 | view=guidance_scale_slider, 891 | ) 892 | 893 | scheduler_choices_dropdown = types.Dropdown(label="Scheduler") 894 | for scheduler in SDXL_LIGHTNING_SCHEDULER_CHOICES: 895 | scheduler_choices_dropdown.add_choice(scheduler, label=scheduler) 896 | 897 | inputs.enum( 898 | "scheduler_choices", 899 | scheduler_choices_dropdown.values(), 900 | default="K_EULER", 901 | view=scheduler_choices_dropdown, 902 | ) 903 | 904 | #### SD3 INPUTS #### 905 | def _handle_sd3_input(ctx, inputs): 906 | aspect_choices = SD3_ASPECT_RATIO_CHOICES 907 | aspect_choices_drop = types.Dropdown(label="Aspect Ratio") 908 | for aspect in aspect_choices: 909 | aspect_choices_drop.add_choice(aspect, label=aspect) 910 | 911 | inputs.enum( 912 | "aspect_ratio", 913 | aspect_choices_drop.values(), 914 | default=SD3_ASPECT_RATIO_CHOICES[0], 915 | view=aspect_choices_drop, 916 | ) 917 | 918 | 919 | inputs.str( 920 | "negative_prompt", 921 | label="Negative Prompt", 922 | required=False, 923 | default="", 924 | ) 925 | 926 | 927 | inputs.int("seed", 928 | label="seed", 929 | description="Enter the seed for the run", 930 | default=51, 931 | view=types.FieldView(componentsProps={'field': {'min': 1, 'max': 100}}) 932 | ) 933 | 934 | #### SSD-1B INPUTS #### 935 | def _handle_ssd1b_input(ctx, inputs): 936 | inputs.int("width", label="Width", default=768) 937 | inputs.int("height", label="Height", default=768) 938 | inputs.str("negative_prompt", label="Negative Prompt", required=False) 939 | 940 | scheduler_choices_dropdown = types.Dropdown(label="Scheduler") 941 | for scheduler in SSD1B_SCHEDULER_CHOICES: 942 | scheduler_choices_dropdown.add_choice(scheduler, label=scheduler) 943 | 944 | inputs.enum( 945 | "scheduler_choices", 946 | scheduler_choices_dropdown.values(), 947 | default="K_EULER", 948 | view=scheduler_choices_dropdown, 949 | ) 950 | 951 | inference_steps_slider = types.SliderView( 952 | label="Num Inference Steps", 953 | componentsProps={"slider": {"min": 1, "max": 100, "step": 1}}, 954 | ) 955 | inputs.int("inference_steps", default=25, view=inference_steps_slider) 956 | 957 | guidance_scale_slider = types.SliderView( 958 | label="Guidance Scale", 959 | componentsProps={"slider": {"min": 0.0, "max": 10.0, "step": 0.1}}, 960 | ) 961 | 
inputs.float("guidance_scale", default=7.5, view=guidance_scale_slider) 962 | 963 | 964 | #### KANDINSKY INPUTS #### 965 | def _handle_kandinsky_input(ctx, inputs): 966 | 967 | size_choices = KANDINSDKY_SIZE_CHOICES 968 | width_choices = types.Dropdown(label="Width") 969 | for size in size_choices: 970 | width_choices.add_choice(size, label=size) 971 | 972 | inputs.enum( 973 | "width", 974 | width_choices.values(), 975 | default="512", 976 | view=width_choices, 977 | ) 978 | 979 | height_choices = types.Dropdown(label="Height") 980 | for size in size_choices: 981 | height_choices.add_choice(size, label=size) 982 | 983 | inputs.enum( 984 | "height", 985 | height_choices.values(), 986 | default="512", 987 | view=height_choices, 988 | ) 989 | inputs.str("negative_prompt", label="Negative Prompt", required=False) 990 | 991 | inference_steps_slider = types.SliderView( 992 | label="Num Inference Steps", 993 | componentsProps={"slider": {"min": 1, "max": 100, "step": 1}}, 994 | ) 995 | inputs.int("inference_steps", default=75, view=inference_steps_slider) 996 | 997 | inference_steps_prior_slider = types.SliderView( 998 | label="Num Inference Steps Prior", 999 | componentsProps={"slider": {"min": 1, "max": 50, "step": 1}}, 1000 | ) 1001 | inputs.int( 1002 | "inference_steps_prior", 1003 | default=25, 1004 | view=inference_steps_prior_slider, 1005 | ) 1006 | 1007 | 1008 | #### PLAYGROUND V2 INPUTS #### 1009 | def _handle_playground_v2_input(ctx, inputs): 1010 | inputs.int("width", label="Width", default=1024) 1011 | inputs.int("height", label="Height", default=1024) 1012 | 1013 | inputs.str("negative_prompt", label="Negative Prompt", required=False) 1014 | 1015 | inference_steps_slider = types.SliderView( 1016 | label="Num Inference Steps", 1017 | componentsProps={"slider": {"min": 1, "max": 100, "step": 1}}, 1018 | ) 1019 | inputs.int("inference_steps", default=50, view=inference_steps_slider) 1020 | 1021 | guidance_scale_slider = types.SliderView( 1022 | label="Guidance Scale", 1023 | componentsProps={"slider": {"min": 0.0, "max": 10.0, "step": 0.1}}, 1024 | ) 1025 | inputs.float("guidance_scale", default=3.0, view=guidance_scale_slider) 1026 | 1027 | scheduler_choices_dropdown = types.Dropdown(label="Scheduler") 1028 | for scheduler in PLAYGROUND_V2_SCHEDULER_CHOICES: 1029 | scheduler_choices_dropdown.add_choice(scheduler, label=scheduler) 1030 | 1031 | inputs.enum( 1032 | "scheduler_choices", 1033 | scheduler_choices_dropdown.values(), 1034 | default="K_EULER_ANCESTRAL", 1035 | view=scheduler_choices_dropdown, 1036 | ) 1037 | 1038 | 1039 | #### LATENT CONSISTENCY INPUTS #### 1040 | def _handle_latent_consistency_input(ctx, inputs): 1041 | 1042 | replicate_flag = allows_replicate_models() 1043 | diffusers_flag = allows_diffusers_models() 1044 | 1045 | if not replicate_flag: 1046 | ctx.params["model_distribution"] = "diffusers" 1047 | elif not diffusers_flag: 1048 | ctx.params["model_distribution"] = "replicate" 1049 | else: 1050 | model_distribution_choices = types.Dropdown(label="Model Distribution") 1051 | model_distribution_choices.add_choice("diffusers", label="Diffusers") 1052 | model_distribution_choices.add_choice("replicate", label="Replicate") 1053 | inputs.enum( 1054 | "model_distribution", 1055 | model_distribution_choices.values(), 1056 | default="diffusers", 1057 | view=model_distribution_choices, 1058 | ) 1059 | 1060 | inputs.int("width", label="Width", default=512) 1061 | inputs.int("height", label="Height", default=512) 1062 | 1063 | inference_steps_slider = types.SliderView( 
1064 | label="Num Inference Steps", 1065 | componentsProps={"slider": {"min": 1, "max": 50, "step": 1}}, 1066 | ) 1067 | inputs.int("num_inference_steps", default=4, view=inference_steps_slider) 1068 | 1069 | lcm_origin_steps_slider = types.SliderView( 1070 | label="LCM Origin Steps", 1071 | componentsProps={"slider": {"min": 1, "max": 100, "step": 1}}, 1072 | ) 1073 | inputs.int("lcm_origin_steps", default=50, view=lcm_origin_steps_slider) 1074 | 1075 | guidance_scale_slider = types.SliderView( 1076 | label="Guidance Scale", 1077 | componentsProps={"slider": {"min": 0.0, "max": 10.0, "step": 0.1}}, 1078 | ) 1079 | inputs.float("guidance_scale", default=7.5, view=guidance_scale_slider) 1080 | 1081 | 1082 | #### DALLE2 INPUTS #### 1083 | def _handle_dalle2_input(ctx, inputs): 1084 | size_choices_dropdown = types.Dropdown(label="Size") 1085 | for size in DALLE2_SIZE_CHOICES: 1086 | size_choices_dropdown.add_choice(size, label=size) 1087 | 1088 | inputs.enum( 1089 | "size_choices", 1090 | size_choices_dropdown.values(), 1091 | default="512x512", 1092 | view=size_choices_dropdown, 1093 | ) 1094 | 1095 | 1096 | #### DALLE3 INPUTS #### 1097 | def _handle_dalle3_input(ctx, inputs): 1098 | size_choices_dropdown = types.Dropdown(label="Size") 1099 | for size in DALLE3_SIZE_CHOICES: 1100 | size_choices_dropdown.add_choice(size, label=size) 1101 | 1102 | inputs.enum( 1103 | "size_choices", 1104 | size_choices_dropdown.values(), 1105 | default="1024x1024", 1106 | view=size_choices_dropdown, 1107 | ) 1108 | 1109 | quality_choices_dropdown = types.Dropdown(label="Quality") 1110 | for quality in DALLE3_QUALITY_CHOICES: 1111 | quality_choices_dropdown.add_choice(quality, label=quality) 1112 | 1113 | inputs.enum( 1114 | "quality_choices", 1115 | quality_choices_dropdown.values(), 1116 | default="standard", 1117 | view=quality_choices_dropdown, 1118 | ) 1119 | 1120 | 1121 | #### VQGAN-CLIP INPUTS #### 1122 | def _handle_vqgan_clip_input(ctx, inputs): 1123 | return 1124 | 1125 | 1126 | INPUT_MAPPER = { 1127 | "sd": _handle_stable_diffusion_input, 1128 | "sdxl": _handle_sdxl_input, 1129 | "sdxl-lightning": _handle_sdxl_lightning_input, 1130 | "ssd-1b": _handle_ssd1b_input, 1131 | "kandinsky-2.2": _handle_kandinsky_input, 1132 | "playground-v2": _handle_playground_v2_input, 1133 | "latent-consistency": _handle_latent_consistency_input, 1134 | "dalle2": _handle_dalle2_input, 1135 | "dalle3": _handle_dalle3_input, 1136 | "vqgan-clip": _handle_vqgan_clip_input, 1137 | "stable-diffusion-3": _handle_sd3_input 1138 | } 1139 | 1140 | 1141 | def _handle_input(ctx, inputs): 1142 | model_name = ctx.params.get("model_choices", "sd") 1143 | model_input_handler = INPUT_MAPPER[model_name] 1144 | model_input_handler(ctx, inputs) 1145 | 1146 | 1147 | def _resolve_download_dir(ctx, inputs): 1148 | if len(ctx.dataset) == 0: 1149 | file_explorer = types.FileExplorerView( 1150 | choose_dir=True, 1151 | button_label="Choose a directory...", 1152 | ) 1153 | inputs.file( 1154 | "download_dir", 1155 | required=True, 1156 | description="Choose a location to store downloaded images", 1157 | view=file_explorer, 1158 | ) 1159 | else: 1160 | base_dir = os.path.dirname(ctx.dataset.first().filepath).split("/") 1161 | ctx.params["download_dir"] = "/".join(base_dir) 1162 | 1163 | 1164 | def _handle_calling(uri, sample_collection, prompt, model_name, **kwargs): 1165 | ctx = dict(view=sample_collection.view()) 1166 | params = dict(kwargs) 1167 | params["prompt"] = prompt 1168 | params["model_choices"] = model_name 1169 | 1170 | return 
foo.execute_operator(uri, ctx, params=params) 1171 | 1172 | 1173 | class Txt2Image(foo.Operator): 1174 | @property 1175 | def config(self): 1176 | _config = foo.OperatorConfig( 1177 | name="txt2img", 1178 | label="Text to Image: Generate Image from Text", 1179 | dynamic=True, 1180 | ) 1181 | _config.icon = "/assets/icon.svg" 1182 | return _config 1183 | 1184 | def resolve_input(self, ctx): 1185 | inputs = types.Object() 1186 | _resolve_download_dir(ctx, inputs) 1187 | 1188 | replicate_flag = allows_replicate_models() 1189 | openai_flag = allows_openai_models() 1190 | stabilityai_flag = allows_stabilityai_models() 1191 | diffusers_flag = allows_diffusers_models() 1192 | 1193 | 1194 | 1195 | any_flag = replicate_flag or openai_flag or diffusers_flag or stabilityai_flag 1196 | if not any_flag: 1197 | inputs.message( 1198 | "message", 1199 | label="No models available.", 1200 | descriptions=( 1201 | "You must install one of `replicate`, `openai`, or `diffusers` or define a STABILITY_API_KEY", 1202 | " to use this plugin. ", 1203 | ), 1204 | ) 1205 | return types.Property(inputs) 1206 | 1207 | model_choices = types.Dropdown() 1208 | if replicate_flag: 1209 | _add_replicate_choices(model_choices) 1210 | if openai_flag: 1211 | _add_openai_choices(model_choices) 1212 | if diffusers_flag: 1213 | _add_diffusers_choices(model_choices) 1214 | if stabilityai_flag: 1215 | _add_stability_choices(model_choices) 1216 | inputs.enum( 1217 | "model_choices", 1218 | model_choices.values(), 1219 | default=model_choices.choices[0].value, 1220 | label="Model", 1221 | description="Choose a model to generate images", 1222 | view=model_choices, 1223 | ) 1224 | 1225 | inputs.str( 1226 | "prompt", 1227 | label="Prompt", 1228 | description="The prompt to generate an image from", 1229 | required=True, 1230 | ) 1231 | _handle_input(ctx, inputs) 1232 | return types.Property(inputs) 1233 | 1234 | def execute(self, ctx): 1235 | model_name = ctx.params.get("model_choices", "None provided") 1236 | model = get_model(model_name) 1237 | prompt = ctx.params.get("prompt", "None provided") 1238 | 1239 | response = model.generate_image(ctx) 1240 | filepath = generate_filepath(ctx) 1241 | 1242 | if type(response) == str: 1243 | ## served models return a url 1244 | image_url = response 1245 | download_image(image_url, filepath) 1246 | elif type(response) == requests.models.Response: 1247 | ## served model return whole object 1248 | write_image(response,filepath) 1249 | else: 1250 | ## local models return a PIL image 1251 | response.save(filepath) 1252 | 1253 | sample = fo.Sample( 1254 | filepath=filepath, 1255 | tags=["generated"], 1256 | model=model.name, 1257 | prompt=prompt, 1258 | date_created=datetime.now(), 1259 | ) 1260 | set_config(sample, ctx, model_name) 1261 | 1262 | dataset = ctx.dataset 1263 | dataset.add_sample(sample, dynamic=True) 1264 | 1265 | if dataset.get_dynamic_field_schema() is not None: 1266 | dataset.add_dynamic_sample_fields() 1267 | ctx.ops.reload_dataset() 1268 | else: 1269 | ctx.ops.reload_samples() 1270 | 1271 | def list_models(self): 1272 | return list(INPUT_MAPPER.keys()) 1273 | 1274 | def __call__(self, sample_collection, prompt, model_name, **kwargs): 1275 | _handle_calling( 1276 | self.uri, sample_collection, prompt, model_name, **kwargs 1277 | ) 1278 | 1279 | 1280 | def register(plugin): 1281 | plugin.register(Txt2Image) 1282 | --------------------------------------------------------------------------------
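**Note:** As the README says, it is straightforward to add support for other models. The sketch below is a hypothetical example of what another Replicate-hosted wrapper in `__init__.py` could look like (the model name, URL, and version hash are placeholders, not a real model). After defining it, you would also register its key in `get_model()`, `set_config()`, `INPUT_MAPPER`, and the appropriate `_add_*_choices()` helper so that it shows up in the operator form.

```python
# Hypothetical example -- "acme/new-t2i-model:0123abcd" is a placeholder, not a real model
NEW_MODEL_URL = "acme/new-t2i-model:0123abcd"


class NewModel(Text2Image):
    """Wrapper for a hypothetical Replicate-hosted text-to-image model."""

    def __init__(self):
        super().__init__()
        self.name = "new-model"
        self.model_name = NEW_MODEL_URL

    def generate_image(self, ctx):
        # Replicate wrappers in this plugin return a URL string (or a list
        # containing one); execute() then downloads the image from that URL
        prompt = ctx.params.get("prompt", "None provided")
        response = replicate.run(self.model_name, input={"prompt": prompt})
        if type(response) == list:
            response = response[0]
        return response
```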