├── .gitignore
├── requirements.txt
├── fiftyone.yml
├── local_t2i_models.py
├── .pre-commit-config.yaml
├── assets
│   └── icon.svg
├── README.md
└── __init__.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | *.cython
3 | *.pyc
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | openai>=1.1.0
2 | replicate
3 | diffusers>=0.24.0
--------------------------------------------------------------------------------
/fiftyone.yml:
--------------------------------------------------------------------------------
1 | fiftyone:
2 |   version: ">=0.23.7"
3 | name: "@jacobmarks/text_to_image"
4 | version: "1.2.6"
5 | description: "Run Text-to-image models to add synthetic images directly to your dataset!"
6 | url: "https://github.com/jacobmarks/text-to-image"
7 | operators:
8 |   - txt2img
9 |
--------------------------------------------------------------------------------
/local_t2i_models.py:
--------------------------------------------------------------------------------
1 | from diffusers import DiffusionPipeline
2 |
3 |
4 | def get_cache():
5 |     # Module-level cache so each pipeline is only loaded once per process
6 |     g = globals()
7 |     if "_local_t2i_models" not in g:
8 |         g["_local_t2i_models"] = {}
9 |
10 |     return g["_local_t2i_models"]
11 |
12 |
13 | def lcm(
14 |     prompt,
15 |     width,
16 |     height,
17 |     num_inference_steps,
18 |     guide_scale,
19 |     lcm_origin_steps,
20 | ):
21 |     # Generate an image locally with the Latent Consistency Model pipeline
22 |     if "lcm" not in get_cache():
23 |         get_cache()["lcm"] = DiffusionPipeline.from_pretrained(
24 |             "SimianLuo/LCM_Dreamshaper_v7"
25 |         )
26 |
27 |     pipe = get_cache()["lcm"]
28 |     images = pipe(
29 |         prompt=prompt,
30 |         num_inference_steps=num_inference_steps,
31 |         guidance_scale=guide_scale,
32 |         lcm_origin_steps=lcm_origin_steps,
33 |         output_type="pil",
34 |         width=width,
35 |         height=height,
36 |     ).images
37 |     return images[0]
38 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   - repo: https://github.com/asottile/blacken-docs
3 |     rev: v1.12.0
4 |     hooks:
5 |       - id: blacken-docs
6 |         additional_dependencies: [black==21.12b0]
7 |         args: ["-l 79"]
8 |         exclude: index.umd.js
9 |   - repo: https://github.com/ambv/black
10 |     rev: 22.3.0
11 |     hooks:
12 |       - id: black
13 |         language_version: python3
14 |         args: ["-l 79"]
15 |         exclude: index.umd.js
16 |   - repo: local
17 |     hooks:
18 |       - id: pylint
19 |         name: pylint
20 |         language: system
21 |         files: \.py$
22 |         entry: pylint
23 |         args: ["--errors-only"]
24 |         exclude: index.umd.js
25 |   - repo: local
26 |     hooks:
27 |       - id: ipynb-strip
28 |         name: ipynb-strip
29 |         language: system
30 |         files: \.ipynb$
31 |         entry: jupyter nbconvert --clear-output --ClearOutputPreprocessor.enabled=True
32 |         args: ["--log-level=ERROR"]
33 |   - repo: https://github.com/pre-commit/mirrors-prettier
34 |     rev: v2.6.2
35 |     hooks:
36 |       - id: prettier
37 |         exclude: index.umd.js
38 |         language_version: system
39 |
--------------------------------------------------------------------------------
/assets/icon.svg:
--------------------------------------------------------------------------------
1 |
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Text-to-Image Plugin
2 |
3 | 
4 |
5 | ### Updates
6 |
7 | - **2024-04-23**: Added support for Stable Diffusion 3 (Thanks [Dan Gural](https://github.com/danielgural/))
8 | - **2023-12-19**: Added support for Kandinsky-2.2 and Playground V2 models
9 | - **2023-11-30**: Version 1.2.0
10 | - adds local model running via `diffusers` (>=0.24.0)
11 | - adds [calling from the Python SDK](#python-sdk)!
12 | - :warning: **BREAKING CHANGE**: the plugin and operator URIs have been changed from `ai_art_gallery` to `text_to_image`. If you have any saved pipelines that use the plugin, you will need to update the URIs.
13 | - **2023-11-08**: Version 1.1.0 adds support for DALLE-3 Model — upgrade to `openai>=1.1.0` to use 😄
14 | - **2023-10-30**: Added support for Segmind Stable Diffusion (SSD-1B) Model
15 | - **2023-10-23**: Added support for Latent Consistency Model
16 | - **2023-10-18**: Added support for SDXL, operator icon, and download location selection
17 |
18 | ### Plugin Overview
19 |
20 | This Python plugin allows you to generate images from text
21 | prompts and add them directly to your dataset.
22 |
23 | :warning: This plugin is only verified to work for local datasets. It may not
24 | work for remote datasets.
25 |
26 | ### Supported Models
27 |
28 | This version of the plugin supports the following models:
29 |
30 | - [DALL-E2](https://openai.com/dall-e-2)
31 | - [DALL-E3](https://openai.com/dall-e-3)
32 | - [Kandinsky-2.2](https://replicate.com/ai-forever/kandinsky-2.2)
33 | - [Latent Consistency Model](https://replicate.com/luosiallen/latent-consistency-model/)
34 | - [Playground V2](https://replicate.com/playgroundai/playground-v2-1024px-aesthetic)
35 | - [SDXL](https://replicate.com/stability-ai/sdxl)
36 | - [SDXL-Lightning](https://replicate.com/lucataco/sdxl-lightning-4step)
37 | - [Segmind Stable Diffusion (SSD-1B)](https://replicate.com/lucataco/ssd-1b/)
38 | - [Stable Diffusion](https://replicate.com/stability-ai/stable-diffusion)
39 | - [Stable Diffusion 3](https://stability.ai/news/stable-diffusion-3)
40 | - [VQGAN-CLIP](https://replicate.com/mehdidc/feed_forward_vqgan_clip)
41 |
42 | It is straightforward to add support for other models!
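
A quick way to see which model keys the installed plugin recognizes is the following minimal sketch (it assumes the plugin has already been downloaded and uses the `list_models()` helper defined in this plugin's `__init__.py`):

```python
import fiftyone.operators as foo

# Load the operator via its URI (plugin name + operator name)
t2i = foo.get_operator("@jacobmarks/text_to_image/txt2img")

# Keys accepted as `model_name` when calling the operator
print(t2i.list_models())
# ['sd', 'sdxl', 'sdxl-lightning', 'ssd-1b', 'kandinsky-2.2', 'playground-v2',
#  'latent-consistency', 'dalle2', 'dalle3', 'vqgan-clip', 'stable-diffusion-3']
```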
43 |
44 | ## Watch On YouTube
45 |
46 | [Watch the walkthrough on YouTube](https://www.youtube.com/watch?v=qJNEyC_FqG0&list=PLuREAXoPgT0RZrUaT0UpX_HzwKkoB-S9j&index=2)
47 |
48 | ## Installation
49 |
50 | ```shell
51 | fiftyone plugins download https://github.com/jacobmarks/text-to-image
52 | ```
53 |
54 | If you want to use Replicate models, you will
55 | need to `pip install replicate` and set the environment variable
56 | `REPLICATE_API_TOKEN` to your API token.
57 |
58 | If you want to use DALL-E2 or DALL-E3, you will need to `pip install openai` and set the
59 | environment variable `OPENAI_API_KEY` to your API key.
60 |
61 | To run the Latent Consistency model locally with Hugging Face's diffusers library,
62 | you will need `diffusers>=0.24.0`. If needed, you can install it with
63 | `pip install "diffusers>=0.24.0"`.
64 |
65 | To run Stable Diffusion 3, you will need a [Stability.ai](https://platform.stability.ai/) account to get an API key. Then set the environment variable
66 | `STABILITY_API_KEY` to your key.
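
As a minimal sketch (the values below are placeholders, not real keys), you can export these variables in the shell where you launch FiftyOne, or set them from Python before invoking the operator:

```python
import os

# Placeholder values -- substitute your own credentials
os.environ["REPLICATE_API_TOKEN"] = "<your-replicate-token>"
os.environ["OPENAI_API_KEY"] = "<your-openai-key>"
os.environ["STABILITY_API_KEY"] = "<your-stability-key>"
```

Only the variables for the backends you actually plan to use need to be set.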
67 |
68 | Refer to the [main README](https://github.com/voxel51/fiftyone-plugins) for
69 | more information about managing downloaded plugins and developing plugins
70 | locally.
71 |
72 | ## Operators
73 |
74 | ### `txt2img`
75 |
76 | - Generates an image from a text prompt and adds it to the dataset
77 |
78 | ### Python SDK
79 |
80 | You can also use the `txt2img` operator from the Python SDK!
81 |
82 | ⚠️ If you're using `fiftyone<=0.23.6`, this will not work in a Jupyter Notebook due to the way notebooks interact with asyncio. You will need to run the code in a Python script or a Python console instead.
83 |
84 | ```python
85 | import fiftyone as fo
86 | import fiftyone.operators as foo
87 | import fiftyone.zoo as foz
88 |
89 | dataset = fo.load_dataset("quickstart")
90 |
91 | ## Access the operator via its URI (plugin name + operator name)
92 | t2i = foo.get_operator("@jacobmarks/text_to_image/txt2img")
93 |
94 | ## Run the operator
95 |
96 | prompt = "A dog sitting in a field"
97 | t2i(dataset, prompt=prompt, model_name="latent-consistency", delegate=False)
98 |
99 | ## Pass in model-specific arguments
100 | t2i(
101 | dataset,
102 | prompt=prompt,
103 | model_name="latent-consistency",
104 | delegate=False,
105 | width=768,
106 | height=768,
107 | num_inference_steps=8,
108 | )
109 | ```
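
The keyword arguments map onto the same parameter names used in the operator's input forms (see `__init__.py`). For instance, here is a sketch for DALL-E 3 (requires `OPENAI_API_KEY`; the parameter values are illustrative):

```python
import fiftyone as fo
import fiftyone.operators as foo

dataset = fo.load_dataset("quickstart")
t2i = foo.get_operator("@jacobmarks/text_to_image/txt2img")

# DALL-E 3 reads `size_choices` and `quality_choices` from the call's kwargs
t2i(
    dataset,
    prompt="A dog sitting in a field",
    model_name="dalle3",
    delegate=False,
    size_choices="1024x1024",
    quality_choices="hd",
)
```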
110 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | """Text to Image plugin.
2 |
3 | | Copyright 2017-2023, Voxel51, Inc.
4 | | `voxel51.com <https://voxel51.com/>`_
5 | |
6 | """
7 |
8 | from datetime import datetime
9 | import os
10 | import uuid
11 | from importlib.util import find_spec
12 | from importlib.metadata import version as mversion
13 | from packaging import version as pversion
14 |
15 | import fiftyone.operators as foo
16 | from fiftyone.operators import types
17 | import fiftyone as fo
18 | import fiftyone.core.utils as fou
19 | from fiftyone.core.utils import add_sys_path
20 |
21 | import requests
22 |
23 | openai = fou.lazy_import("openai")
24 | replicate = fou.lazy_import("replicate")
25 |
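# The Replicate model identifiers below are pinned to specific version hashes;
# the *_CHOICES tuples list the options surfaced in the operator's input forms.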
26 | SD_MODEL_URL = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478"
27 | SD_SCHEDULER_CHOICES = (
28 | "DDIM",
29 | "K_EULER",
30 | "K_EULER_ANCESTRAL",
31 | "PNDM",
32 | "K-LMS",
33 | )
34 | SD_SIZE_CHOICES = (
35 | "128",
36 | "256",
37 | "384",
38 | "512",
39 | "576",
40 | "640",
41 | "704",
42 | "768",
43 | "832",
44 | "896",
45 | "960",
46 | "1024",
47 | )
48 |
49 | SDXL_MODEL_URL = "stability-ai/sdxl:c221b2b8ef527988fb59bf24a8b97c4561f1c671f73bd389f866bfb27c061316"
50 | SDXL_SCHEDULER_CHOICES = (
51 | "DDIM",
52 | "K_EULER",
53 | "K_EULER_ANCESTRAL",
54 | "PNDM",
55 | "DPMSolverMultistep",
56 | "KarrasDPM",
57 | "HeunDiscrete",
58 | )
59 |
60 | SDXL_REFINE_CHOICES = ("None", "Expert Ensemble", "Base")
61 |
62 | SDXL_REFINE_MAP = {
63 | "None": "no_refiner",
64 | "Expert Ensemble": "expert_ensemble_refiner",
65 | "Base": "base_image_refiner",
66 | }
67 |
68 | SDXL_LIGHTNING_MODEL_URL = "lucataco/sdxl-lightning-4step:727e49a643e999d602a896c774a0658ffefea21465756a6ce24b7ea4165eba6a"
69 | SDXL_LIGHTNING_SIZE_CHOICES = ("1024", "1280")
70 | SDXL_LIGHTNING_NEGATIVE_PROMPT_DEFAULT = "worst quality, low quality"
71 | SDXL_LIGHTNING_SCHEDULER_CHOICES = (
72 | "DDIM",
73 | "DPMSolverMultistep",
74 | "HeunDiscrete",
75 | "KarrasDPM",
76 | "K_EULER_ANCESTRAL",
77 | "K_EULER",
78 | "PNDM",
79 | "DPM+2MSDE",
80 | )
81 | SDXL_LIGHTNING_NUM_INFERENCE_STEPS_DEFAULT = 4
82 | SDXL_LIGHTNING_GUIDANCE_SCALE_DEFAULT = 7.5
83 |
84 |
85 | SSD1B_MODEL_URL = "lucataco/ssd-1b:1ee85ef681d5ad3d6870b9da1a4543cb3ad702d036fa5b5210f133b83b05a780"
86 | SSD1B_SCHEDULER_CHOICES = (
87 | "DDIM",
88 | "DDPMSolverMultistep",
89 | "HeunDiscrete",
90 | "KarrasDPM",
91 | "K_EULER_ANCESTRAL",
92 | "K_EULER",
93 | "PNDM",
94 | )
95 |
96 | PLAYGROUND_V2_MODEL_URL = "playgroundai/playground-v2-1024px-aesthetic:42fe626e41cc811eaf02c94b892774839268ce1994ea778eba97103fe1ef51b8"
97 | PLAYGROUND_V2_SCHEDULER_CHOICES = (
98 | "DDIM",
99 | "DPMSolverMultistep",
100 | "HeunDiscrete",
101 | "K_EULER_ANCESTRAL",
102 | "K_EULER",
103 | "PNDM",
104 | )
105 |
106 | LC_MODEL_URL = "luosiallen/latent-consistency-model:553803fd018b3cf875a8bc774c99da9b33f36647badfd88a6eec90d61c5f62fc"
107 |
108 | VQGAN_MODEL_URL = "mehdidc/feed_forward_vqgan_clip:28b5242dadb5503688e17738aaee48f5f7f5c0b6e56493d7cf55f74d02f144d8"
109 |
110 | KANDINSKY_MODEL_URL = "ai-forever/kandinsky-2.2:ea1addaab376f4dc227f5368bbd8eff901820fd1cc14ed8cad63b29249e9d463"
111 | KANDINSKY_SIZE_CHOICES = (
112 | "384",
113 | "512",
114 | "576",
115 | "640",
116 | "704",
117 | "768",
118 | "960",
119 | "1024",
120 | "1152",
121 | "1280",
122 | "1536",
123 | "1792",
124 | "2048",
125 | )
126 |
127 |
128 | DALLE2_SIZE_CHOICES = ("256x256", "512x512", "1024x1024")
129 |
130 | DALLE3_SIZE_CHOICES = ("1024x1024", "1024x1792", "1792x1024")
131 |
132 | DALLE3_QUALITY_CHOICES = ("standard", "hd")
133 |
134 |
135 |
136 | SD3_ASPECT_RATIO_CHOICES = ("1:1", "16:9", "21:9", "2:3", "3:2", "4:5", "5:4", "9:16", "9:21" )
137 |
138 | SD3_MODEL_URL = "https://api.stability.ai/v2beta/stable-image/generate/sd3"
139 |
140 |
141 | def allows_replicate_models():
142 | """Returns whether the current environment allows replicate models."""
143 | return (
144 | find_spec("replicate") is not None
145 | and "REPLICATE_API_TOKEN" in os.environ
146 | )
147 |
148 |
149 | def allows_openai_models():
150 | """Returns whether the current environment allows openai models."""
151 | return find_spec("openai") is not None and "OPENAI_API_KEY" in os.environ
152 |
153 | def allows_stabilityai_models():
154 | """Returns whether the current environment allows stabilityai models."""
155 | return "STABILITY_API_KEY" in os.environ
156 |
157 |
158 | def allows_diffusers_models():
159 | """Returns whether the current environment allows diffusers models."""
160 | if find_spec("diffusers") is None:
161 | return False
162 | version = mversion("diffusers")
163 | return pversion.parse(version) >= pversion.parse("0.24.0")
164 |
165 |
166 | def download_image(image_url, filename):
167 | img_data = requests.get(image_url).content
168 | with open(filename, "wb") as handler:
169 | handler.write(img_data)
170 |
171 | def write_image(response, filename):
172 | if response.status_code == 200:
173 | with open(filename, 'wb') as file:
174 | file.write(response.content)
175 | else:
176 | raise Exception(str(response.json()))
177 |
178 | class Text2Image:
179 | """Wrapper for a Text2Image model."""
180 |
181 | def __init__(self):
182 | self.name = None
183 | self.model_name = None
184 |
185 | def generate_image(self, ctx):
186 | pass
187 |
188 |
189 | class StableDiffusion(Text2Image):
190 | """Wrapper for a StableDiffusion model."""
191 |
192 | def __init__(self):
193 | super().__init__()
194 | self.name = "stable-diffusion"
195 | self.model_name = SD_MODEL_URL
196 |
197 | def generate_image(self, ctx):
198 | prompt = ctx.params.get("prompt", "None provided")
199 | width = int(ctx.params.get("width_choices", "None provided"))
200 | height = int(ctx.params.get("height_choices", "None provided"))
201 | inference_steps = ctx.params.get("inference_steps", "None provided")
202 | scheduler = ctx.params.get("scheduler_choices", "None provided")
203 |
204 | response = replicate.run(
205 | self.model_name,
206 | input={
207 | "prompt": prompt,
208 | "width": width,
209 | "height": height,
210 | "inference_steps": inference_steps,
211 | "scheduler": scheduler,
212 | },
213 | )
214 | if type(response) == list:
215 | response = response[0]
216 | return response
217 |
218 |
219 | class SDXL(Text2Image):
220 | """Wrapper for a StableDiffusion XL model."""
221 |
222 | def __init__(self):
223 | super().__init__()
224 | self.name = "sdxl"
225 | self.model_name = SDXL_MODEL_URL
226 |
227 | def generate_image(self, ctx):
228 | prompt = ctx.params.get("prompt", "None provided")
229 | inference_steps = ctx.params.get("inference_steps", 50.0)
230 | scheduler = ctx.params.get("scheduler_choices", "None provided")
231 | guidance_scale = ctx.params.get("guidance_scale", 7.5)
232 | refiner = ctx.params.get("refine_choices", SDXL_REFINE_CHOICES[0])
233 | refiner = SDXL_REFINE_MAP[refiner]
234 | refine_steps = ctx.params.get("refine_steps", None)
235 | negative_prompt = ctx.params.get("negative_prompt", None)
236 | high_noise_frac = ctx.params.get("high_noise_frac", None)
237 |
238 | _inputs = {
239 | "prompt": prompt,
240 | "inference_steps": inference_steps,
241 | "scheduler": scheduler,
242 | "refine": refiner,
243 | "guidance_scale": guidance_scale,
244 | }
245 | if negative_prompt is not None:
246 | _inputs["negative_prompt"] = negative_prompt
247 | if refine_steps is not None:
248 | _inputs["refine_steps"] = refine_steps
249 | if high_noise_frac is not None:
250 | _inputs["high_noise_frac"] = high_noise_frac
251 |
252 | response = replicate.run(
253 | self.model_name,
254 | input=_inputs,
255 | )
256 | if type(response) == list:
257 | response = response[0]
258 | return response
259 |
260 |
261 | class SDXLLightning(Text2Image):
262 | """Wrapper for a StableDiffusion XL Lightning model."""
263 |
264 | def __init__(self):
265 | super().__init__()
266 | self.name = "sdxl-lightning"
267 | self.model_name = SDXL_LIGHTNING_MODEL_URL
268 |
269 | def generate_image(self, ctx):
270 | prompt = ctx.params.get("prompt", "None provided")
271 | negative_prompt = ctx.params.get(
272 | "negative_prompt", SDXL_LIGHTNING_NEGATIVE_PROMPT_DEFAULT
273 | )
274 | inference_steps = ctx.params.get(
275 | "inference_steps", SDXL_LIGHTNING_NUM_INFERENCE_STEPS_DEFAULT
276 | )
277 | scheduler = ctx.params.get(
278 | "scheduler_choices", SDXL_LIGHTNING_SCHEDULER_CHOICES[0]
279 | )
280 | guidance_scale = ctx.params.get(
281 | "guidance_scale", SDXL_LIGHTNING_GUIDANCE_SCALE_DEFAULT
282 | )
283 | width = ctx.params.get("width", SDXL_LIGHTNING_SIZE_CHOICES[0])
284 | height = ctx.params.get("height", SDXL_LIGHTNING_SIZE_CHOICES[0])
285 |
286 | _inputs = {
287 | "prompt": prompt,
288 | "num_inference_steps": inference_steps,
289 | "scheduler": scheduler,
290 | "guidance_scale": guidance_scale,
291 | "negative_prompt": negative_prompt,
292 | "width": int(width),
293 | "height": int(height),
294 | }
295 |
296 | response = replicate.run(
297 | self.model_name,
298 | input=_inputs,
299 | )
300 | if type(response) == list:
301 | response = response[0]
302 | return response
303 |
304 | class StableDiffusion3(Text2Image):
305 | """Wrapper for a StableDiffusion 3 model."""
306 |
307 | def __init__(self):
308 | super().__init__()
309 | self.name = "stable-diffusion-3"
310 | self.model_name = SD3_MODEL_URL
311 |
312 | def generate_image(self, ctx):
313 | prompt = ctx.params.get("prompt", "None provided")
314 | aspect_ratio = ctx.params.get("aspect_ratio", SD3_ASPECT_RATIO_CHOICES[0])
315 | seed = ctx.params.get("seed", 51)
316 | negative_prompt = ctx.params.get("negative_prompt", "")
317 |
318 | stability_key = os.environ["STABILITY_API_KEY"]
319 |
320 | response = requests.post(
321 | SD3_MODEL_URL,
322 | headers={
323 | "authorization": f"Bearer {stability_key}",
324 | "accept": "image/*"
325 | },
326 | files={"none": ''},
327 | data={
328 | "prompt": prompt,
329 | "aspect_ratio": aspect_ratio,
330 | "negative_prompt": negative_prompt,
331 | "seed": seed,
332 | "output_format": "jpeg",
333 | },
334 | )
335 | if type(response) == list:
336 | response = response[0]
337 | return response
338 |
339 | class SSD1B(Text2Image):
340 | """Wrapper for a SSD-1B model."""
341 |
342 | def __init__(self):
343 | super().__init__()
344 | self.name = "ssd-1b"
345 | self.model_name = SSD1B_MODEL_URL
346 |
347 | def generate_image(self, ctx):
348 | prompt = ctx.params.get("prompt", "None provided")
349 | inference_steps = ctx.params.get("inference_steps", 25.0)
350 | scheduler = ctx.params.get("scheduler_choices", "None provided")
351 | guidance_scale = ctx.params.get("guidance_scale", 7.5)
352 | negative_prompt = ctx.params.get("negative_prompt", None)
353 |
354 | _inputs = {
355 | "prompt": prompt,
356 | "inference_steps": inference_steps,
357 | "scheduler": scheduler,
358 | "guidance_scale": guidance_scale,
359 | }
360 | if negative_prompt is not None:
361 | _inputs["negative_prompt"] = negative_prompt
362 |
363 | response = replicate.run(
364 | self.model_name,
365 | input=_inputs,
366 | )
367 | if type(response) == list:
368 | response = response[0]
369 | return response
370 |
371 |
372 | class Kandinsky(Text2Image):
373 | """Wrapper for a Kandinsky model."""
374 |
375 | def __init__(self):
376 | super().__init__()
377 | self.name = "kandinsky"
378 | self.model_name = KANDINSKY_MODEL_URL
379 |
380 | def generate_image(self, ctx):
381 | prompt = ctx.params.get("prompt", "None provided")
382 | negative_prompt = ctx.params.get("negative_prompt", None)
383 | width = int(ctx.params.get("width", 512))
384 | height = int(ctx.params.get("height", 512))
385 | inference_steps = ctx.params.get("inference_steps", 75)
386 | inference_steps_prior = ctx.params.get("inference_steps_prior", 25)
387 |
388 | input = {
389 | "prompt": prompt,
390 | "width": width,
391 | "height": height,
392 | "inference_steps": inference_steps,
393 | "inference_steps_prior": inference_steps_prior,
394 | }
395 |
396 | if negative_prompt is not None:
397 | input["negative_prompt"] = negative_prompt
398 |
399 | response = replicate.run(
400 | self.model_name,
401 | input=input,
402 | )
403 | if type(response) == list:
404 | response = response[0]
405 | return response
406 |
407 |
408 | class PlaygroundV2(Text2Image):
409 | """Wrapper for Aesthetic Playground V2 model."""
410 |
411 | def __init__(self):
412 | super().__init__()
413 | self.name = "playground-v2"
414 | self.model_name = PLAYGROUND_V2_MODEL_URL
415 |
416 | def generate_image(self, ctx):
417 | prompt = ctx.params.get("prompt", "None provided")
418 | negative_prompt = ctx.params.get("negative_prompt", None)
419 | width = int(ctx.params.get("width", 1024))
420 | height = int(ctx.params.get("height", 1024))
421 | inference_steps = ctx.params.get("inference_steps", 50)
422 | guidance_scale = ctx.params.get("guide_scale", 3.0)
423 | scheduler = ctx.params.get("scheduler_choices", "None provided")
424 |
425 | input = {
426 | "prompt": prompt,
427 | "width": width,
428 | "height": height,
429 | "num_inference_steps": inference_steps,
430 | "guidance_scale": guidance_scale,
431 | "scheduler": scheduler,
432 | }
433 |
434 | if negative_prompt is not None:
435 | input["negative_prompt"] = negative_prompt
436 |
437 | response = replicate.run(
438 | self.model_name,
439 | input=input,
440 | )
441 | if type(response) == list:
442 | response = response[0]
443 | return response
444 |
445 |
446 | class LatentConsistencyModel(Text2Image):
447 | """Wrapper for a Latent Consistency model."""
448 |
449 | def __init__(self):
450 | super().__init__()
451 | self.name = "latent-consistency"
452 | self.model_name = LC_MODEL_URL
453 |
454 | def generate_image(self, ctx):
455 | prompt = ctx.params.get("prompt", "None provided")
456 | width = int(ctx.params.get("width", 512))
457 | height = int(ctx.params.get("height", 512))
458 | num_inf_steps = int(ctx.params.get("num_inference_steps", 4))
459 | guide_scale = float(ctx.params.get("guidance_scale", 7.5))
460 | lcm_origin_steps = int(ctx.params.get("lcm_origin_steps", 50))
461 | distro = ctx.params.get("model_distribution", "None provided")
462 |
463 | if distro == "replicate":
464 | response = replicate.run(
465 | self.model_name,
466 | input={
467 | "prompt": prompt,
468 | "width": width,
469 | "height": height,
470 | "num_inference_steps": num_inf_steps,
471 | "guidance_scale": guide_scale,
472 | "lcm_origin_steps": lcm_origin_steps,
473 | },
474 | )
475 |
476 | if type(response) == list:
477 | response = response[0]
478 | return response
479 | else:
480 | with add_sys_path(os.path.dirname(os.path.abspath(__file__))):
481 | # pylint: disable=no-name-in-module,import-error
482 | from local_t2i_models import lcm
483 |
484 | response = lcm(
485 | prompt,
486 | width,
487 | height,
488 | num_inf_steps,
489 | guide_scale,
490 | lcm_origin_steps,
491 | )
492 | return response
493 |
494 |
495 | class DALLE2(Text2Image):
496 | """Wrapper for a DALL-E 2 model."""
497 |
498 | def __init__(self):
499 | super().__init__()
500 | self.name = "dalle2"
501 |
502 | def generate_image(self, ctx):
503 | prompt = ctx.params.get("prompt", "None provided")
504 | size = ctx.params.get("size_choices", "None provided")
505 |
506 | response = openai.OpenAI().images.generate(
507 | model="dall-e-2", prompt=prompt, n=1, size=size
508 | )
509 | return response.data[0].url
510 |
511 |
512 | class DALLE3(Text2Image):
513 | """Wrapper for a DALL-E 3 model."""
514 |
515 | def __init__(self):
516 | super().__init__()
517 | self.name = "dalle3"
518 |
519 | def generate_image(self, ctx):
520 | prompt = ctx.params.get("prompt", "None provided")
521 | size = ctx.params.get("size_choices", "None provided")
522 | quality = ctx.params.get("quality_choices", "None provided")
523 |
524 | response = openai.OpenAI().images.generate(
525 | model="dall-e-3", prompt=prompt, n=1, quality=quality, size=size
526 | )
527 |
528 | revised_prompt = response.data[0].revised_prompt
529 | ctx.params["revised_prompt"] = revised_prompt
530 |
531 | return response.data[0].url
532 |
533 |
534 | class VQGANCLIP(Text2Image):
535 | """Wrapper for a VQGAN-CLIP model."""
536 |
537 | def __init__(self):
538 | super().__init__()
539 | self.name = "vqgan-clip"
540 | self.model_name = VQGAN_MODEL_URL
541 |
542 | def generate_image(self, ctx):
543 | prompt = ctx.params.get("prompt", "None provided")
544 | response = replicate.run(self.model_name, input={"prompt": prompt})
545 | if type(response) == list:
546 | response = response[0]
547 | return response
548 |
549 |
550 | def get_model(model_name):
551 | mapping = {
552 | "sd": StableDiffusion,
553 | "sdxl": SDXL,
554 | "sdxl-lightning": SDXLLightning,
555 | "ssd-1b": SSD1B,
556 | "latent-consistency": LatentConsistencyModel,
557 | "kandinsky-2.2": Kandinsky,
558 | "playground-v2": PlaygroundV2,
559 | "dalle2": DALLE2,
560 | "dalle3": DALLE3,
561 | "vqgan-clip": VQGANCLIP,
562 | "stable-diffusion-3": StableDiffusion3,
563 | }
564 | return mapping[model_name]()
565 |
566 |
567 | def set_stable_diffusion_config(sample, ctx):
568 | sample["stable_diffusion_config"] = fo.DynamicEmbeddedDocument(
569 | inference_steps=ctx.params.get("inference_steps", "None provided"),
570 | scheduler=ctx.params.get("scheduler_choices", "None provided"),
571 | width=ctx.params.get("width_choices", "None provided"),
572 | height=ctx.params.get("height_choices", "None provided"),
573 | )
574 |
575 |
576 | def set_sdxl_config(sample, ctx):
577 | sample["sdxl_config"] = fo.DynamicEmbeddedDocument(
578 | inference_steps=ctx.params.get("inference_steps", "None provided"),
579 | scheduler=ctx.params.get("scheduler_choices", "None provided"),
580 | guidance_scale=ctx.params.get("guidance_scale", 7.5),
581 | refiner=SDXL_REFINE_MAP[
582 | ctx.params.get("refine_choices", "None provided")
583 | ],
584 | refine_steps=ctx.params.get("refine_steps", None),
585 | negative_prompt=ctx.params.get("negative_prompt", None),
586 | high_noise_frac=ctx.params.get("high_noise_frac", None),
587 | )
588 |
589 |
590 | def set_sdxl_lightning_config(sample, ctx):
591 | sample["sdxl_lightning_config"] = fo.DynamicEmbeddedDocument(
592 | inference_steps=ctx.params.get(
593 | "inference_steps", SDXL_LIGHTNING_NUM_INFERENCE_STEPS_DEFAULT
594 | ),
595 | scheduler=ctx.params.get("scheduler_choices", "None provided"),
596 | guidance_scale=ctx.params.get(
597 | "guidance_scale", SDXL_LIGHTNING_GUIDANCE_SCALE_DEFAULT
598 | ),
599 | width=ctx.params.get("width", 1024),
600 | height=ctx.params.get("height", 1024),
601 | negative_prompt=ctx.params.get(
602 | "negative_prompt", SDXL_LIGHTNING_NEGATIVE_PROMPT_DEFAULT
603 | ),
604 | )
605 |
606 | def set_sd3_config(sample, ctx):
607 | sample["sd3_config"] = fo.DynamicEmbeddedDocument(
608 | aspect_ratio = ctx.params.get("aspect_ratio", SD3_ASPECT_RATIO_CHOICES[0]),
609 | negative_prompt=ctx.params.get(
610 | "negative_prompt", ""
611 | ),
612 | seed=ctx.params.get(
613 | "seed", 51
614 | ),
615 | )
616 |
617 |
618 | def set_ssd1b_config(sample, ctx):
619 | sample["ssd1b_config"] = fo.DynamicEmbeddedDocument(
620 | inference_steps=ctx.params.get("inference_steps", "None provided"),
621 | scheduler=ctx.params.get("scheduler_choices", "None provided"),
622 | guidance_scale=ctx.params.get("guidance_scale", 7.5),
623 | negative_prompt=ctx.params.get("negative_prompt", None),
624 | )
625 |
626 |
627 | def set_kandinsky_config(sample, ctx):
628 | sample["kandinsky_config"] = fo.DynamicEmbeddedDocument(
629 | inference_steps=ctx.params.get("inference_steps", 75),
630 | inference_steps_prior=ctx.params.get("inference_steps_prior", 25),
631 | width=ctx.params.get("width", 512),
632 | height=ctx.params.get("height", 512),
633 | negative_prompt=ctx.params.get("negative_prompt", None),
634 | )
635 |
636 |
637 | def set_playground_v2_config(sample, ctx):
638 | sample["playground_v2_config"] = fo.DynamicEmbeddedDocument(
639 | inference_steps=ctx.params.get("inference_steps", 50),
640 | width=ctx.params.get("width", 1024),
641 | height=ctx.params.get("height", 1024),
642 | guidance_scale=ctx.params.get("guidance_scale", 3.0),
643 | scheduler=ctx.params.get("scheduler_choices", "None provided"),
644 | negative_prompt=ctx.params.get("negative_prompt", None),
645 | )
646 |
647 |
648 | def set_latent_consistency_config(sample, ctx):
649 | sample["latent_consistency_config"] = fo.DynamicEmbeddedDocument(
650 | inference_steps=ctx.params.get("num_inference_steps", 4),
651 | guidance_scale=ctx.params.get("guidance_scale", 7.5),
652 | lcm_origin_steps=ctx.params.get("lcm_origin_steps", 50),
653 | width=ctx.params.get("width", 512),
654 | height=ctx.params.get("height", 512),
655 | )
656 |
657 |
658 | def set_vqgan_clip_config(sample, ctx):
659 | return
660 |
661 |
662 | def set_dalle2_config(sample, ctx):
663 | sample["dalle2_config"] = fo.DynamicEmbeddedDocument(
664 | size=ctx.params.get("size_choices", "None provided")
665 | )
666 |
667 |
668 | def set_dalle3_config(sample, ctx):
669 | sample["dalle3_config"] = fo.DynamicEmbeddedDocument(
670 | size=ctx.params.get("size_choices", "None provided"),
671 | quality=ctx.params.get("quality_choices", "None provided"),
672 | revised_prompt=ctx.params.get("revised_prompt", "None provided"),
673 | )
674 |
675 |
676 | def set_config(sample, ctx, model_name):
677 | mapping = {
678 | "sd": set_stable_diffusion_config,
679 | "sdxl": set_sdxl_config,
680 | "sdxl-lightning": set_sdxl_lightning_config,
681 | "ssd-1b": set_ssd1b_config,
682 | "latent-consistency": set_latent_consistency_config,
683 | "kandinsky-2.2": set_kandinsky_config,
684 | "playground-v2": set_playground_v2_config,
685 | "dalle2": set_dalle2_config,
686 | "dalle3": set_dalle3_config,
687 | "vqgan-clip": set_vqgan_clip_config,
688 | "stable-diffusion-3": set_sd3_config,
689 | }
690 |
691 | config_setter = mapping[model_name]
692 | config_setter(sample, ctx)
693 |
694 |
695 | def generate_filepath(ctx):
696 | download_dir = ctx.params.get("download_dir", {})
697 | if type(download_dir) == dict:
698 | download_dir = download_dir.get("absolute_path", "/tmp")
699 |
700 | filename = str(uuid.uuid4())[:13].replace("-", "") + ".png"
701 | return os.path.join(download_dir, filename)
702 |
703 |
704 | #### MODEL CHOICES ####
705 | def _add_replicate_choices(model_choices):
706 | model_choices.add_choice("sd", label="Stable Diffusion")
707 | model_choices.add_choice("sdxl", label="SDXL")
708 | model_choices.add_choice("sdxl-lightning", label="SDXL Lightning")
709 | model_choices.add_choice("ssd-1b", label="SSD-1B")
710 | model_choices.add_choice("kandinsky-2.2", label="Kandinsky 2.2")
711 | model_choices.add_choice("playground-v2", label="Playground V2")
712 | if "latent-consistency" not in model_choices.values():
713 | model_choices.add_choice(
714 | "latent-consistency", label="Latent Consistency"
715 | )
716 | model_choices.add_choice("vqgan-clip", label="VQGAN-CLIP")
717 |
718 |
719 | def _add_openai_choices(model_choices):
720 | model_choices.add_choice("dalle2", label="DALL-E2")
721 | model_choices.add_choice("dalle3", label="DALL-E3")
722 |
723 | def _add_stability_choices(model_choices):
724 | model_choices.add_choice("stable-diffusion-3", label="Stable Diffusion 3")
725 |
726 | def _add_diffusers_choices(model_choices):
727 | if "latent-consistency" not in model_choices.values():
728 | model_choices.add_choice(
729 | "latent-consistency", label="Latent Consistency"
730 | )
731 |
732 |
733 | #### STABLE DIFFUSION INPUTS ####
734 | def _handle_stable_diffusion_input(ctx, inputs):
735 | size_choices = SD_SIZE_CHOICES
736 | width_choices = types.Dropdown(label="Width")
737 | for size in size_choices:
738 | width_choices.add_choice(size, label=size)
739 |
740 | inputs.enum(
741 | "width_choices",
742 | width_choices.values(),
743 | default="512",
744 | view=width_choices,
745 | )
746 |
747 | height_choices = types.Dropdown(label="Height")
748 | for size in size_choices:
749 | height_choices.add_choice(size, label=size)
750 |
751 | inputs.enum(
752 | "height_choices",
753 | height_choices.values(),
754 | default="512",
755 | view=height_choices,
756 | )
757 |
758 | inference_steps_slider = types.SliderView(
759 | label="Num Inference Steps",
760 | componentsProps={"slider": {"min": 1, "max": 500, "step": 1}},
761 | )
762 | inputs.int("inference_steps", default=50, view=inference_steps_slider)
763 |
764 | scheduler_choices_dropdown = types.Dropdown(label="Scheduler")
765 | for scheduler in SD_SCHEDULER_CHOICES:
766 | scheduler_choices_dropdown.add_choice(scheduler, label=scheduler)
767 |
768 | inputs.enum(
769 | "scheduler_choices",
770 | scheduler_choices_dropdown.values(),
771 | default="K_EULER",
772 | view=scheduler_choices_dropdown,
773 | )
774 |
775 |
776 | #### SDXL INPUTS ####
777 | def _handle_sdxl_input(ctx, inputs):
778 |
779 | inputs.str("negative_prompt", label="Negative Prompt", required=False)
780 |
781 | scheduler_choices_dropdown = types.Dropdown(label="Scheduler")
782 | for scheduler in SDXL_SCHEDULER_CHOICES:
783 | scheduler_choices_dropdown.add_choice(scheduler, label=scheduler)
784 |
785 | inputs.enum(
786 | "scheduler_choices",
787 | scheduler_choices_dropdown.values(),
788 | default="K_EULER",
789 | view=scheduler_choices_dropdown,
790 | )
791 |
792 | inference_steps_slider = types.SliderView(
793 | label="Num Inference Steps",
794 | componentsProps={"slider": {"min": 1, "max": 100, "step": 1}},
795 | )
796 | inputs.int("inference_steps", default=50, view=inference_steps_slider)
797 |
798 | guidance_scale_slider = types.SliderView(
799 | label="Guidance Scale",
800 | componentsProps={"slider": {"min": 0.0, "max": 10.0, "step": 0.1}},
801 | )
802 | inputs.float("guidance_scale", default=7.5, view=guidance_scale_slider)
803 |
804 | refiner_choices_dropdown = types.Dropdown(
805 | label="Refiner",
806 | description="Which refine style to use",
807 | )
808 | for refiner in SDXL_REFINE_CHOICES:
809 | refiner_choices_dropdown.add_choice(refiner, label=refiner)
810 |
811 | inputs.enum(
812 | "refine_choices",
813 | refiner_choices_dropdown.values(),
814 | default="None",
815 | view=refiner_choices_dropdown,
816 | )
817 |
818 | rfc = SDXL_REFINE_MAP[ctx.params.get("refine_choices", "None")]
819 | if rfc == "base_image_refiner":
820 | _default = ctx.params.get("inference_steps", 50)
821 | refine_steps_slider = types.SliderView(
822 | label="Num Refine Steps",
823 | componentsProps={"slider": {"min": 1, "max": _default, "step": 1}},
824 | )
825 | inputs.int(
826 | "refine_steps",
827 | label="Refine Steps",
828 | description="The number of steps to refine",
829 | default=_default,
830 | view=refine_steps_slider,
831 | )
832 | elif rfc == "expert_ensemble_refiner":
833 | inputs.float(
834 | "high_noise_frac",
835 | label="High Noise Fraction",
836 | description="The fraction of noise to use",
837 | default=0.8,
838 | )
839 |
840 |
841 | #### SDXL LIGHTNING INPUTS ####
842 | def _handle_sdxl_lightning_input(ctx, inputs):
843 | size_choices = SDXL_LIGHTNING_SIZE_CHOICES
844 | width_choices = types.Dropdown(label="Width")
845 | for size in size_choices:
846 | width_choices.add_choice(size, label=size)
847 |
848 | inputs.enum(
849 | "width",
850 | width_choices.values(),
851 | default=SDXL_LIGHTNING_SIZE_CHOICES[0],
852 | view=width_choices,
853 | )
854 |
855 | height_choices = types.Dropdown(label="Height")
856 | for size in size_choices:
857 | height_choices.add_choice(size, label=size)
858 |
859 | inputs.enum(
860 | "height",
861 | height_choices.values(),
862 | default=SDXL_LIGHTNING_SIZE_CHOICES[0],
863 | view=height_choices,
864 | )
865 |
866 | inputs.str(
867 | "negative_prompt",
868 | label="Negative Prompt",
869 | required=False,
870 | default=SDXL_LIGHTNING_NEGATIVE_PROMPT_DEFAULT,
871 | )
872 |
873 | inference_steps_slider = types.SliderView(
874 | label="Num Inference Steps",
875 | componentsProps={"slider": {"min": 1, "max": 10, "step": 1}},
876 | )
877 | inputs.int(
878 | "inference_steps",
879 | default=SDXL_LIGHTNING_NUM_INFERENCE_STEPS_DEFAULT,
880 | view=inference_steps_slider,
881 | )
882 |
883 | guidance_scale_slider = types.SliderView(
884 | label="Guidance Scale",
885 | componentsProps={"slider": {"min": 0.0, "max": 10.0, "step": 0.1}},
886 | )
887 | inputs.float(
888 | "guidance_scale",
889 | default=SDXL_LIGHTNING_GUIDANCE_SCALE_DEFAULT,
890 | view=guidance_scale_slider,
891 | )
892 |
893 | scheduler_choices_dropdown = types.Dropdown(label="Scheduler")
894 | for scheduler in SDXL_LIGHTNING_SCHEDULER_CHOICES:
895 | scheduler_choices_dropdown.add_choice(scheduler, label=scheduler)
896 |
897 | inputs.enum(
898 | "scheduler_choices",
899 | scheduler_choices_dropdown.values(),
900 | default="K_EULER",
901 | view=scheduler_choices_dropdown,
902 | )
903 |
904 | #### SD3 INPUTS ####
905 | def _handle_sd3_input(ctx, inputs):
906 | aspect_choices = SD3_ASPECT_RATIO_CHOICES
907 | aspect_choices_drop = types.Dropdown(label="Aspect Ratio")
908 | for aspect in aspect_choices:
909 | aspect_choices_drop.add_choice(aspect, label=aspect)
910 |
911 | inputs.enum(
912 | "aspect_ratio",
913 | aspect_choices_drop.values(),
914 | default=SD3_ASPECT_RATIO_CHOICES[0],
915 | view=aspect_choices_drop,
916 | )
917 |
918 |
919 | inputs.str(
920 | "negative_prompt",
921 | label="Negative Prompt",
922 | required=False,
923 | default="",
924 | )
925 |
926 |
927 | inputs.int("seed",
928 | label="seed",
929 | description="Enter the seed for the run",
930 | default=51,
931 | view=types.FieldView(componentsProps={'field': {'min': 1, 'max': 100}})
932 | )
933 |
934 | #### SSD-1B INPUTS ####
935 | def _handle_ssd1b_input(ctx, inputs):
936 | inputs.int("width", label="Width", default=768)
937 | inputs.int("height", label="Height", default=768)
938 | inputs.str("negative_prompt", label="Negative Prompt", required=False)
939 |
940 | scheduler_choices_dropdown = types.Dropdown(label="Scheduler")
941 | for scheduler in SSD1B_SCHEDULER_CHOICES:
942 | scheduler_choices_dropdown.add_choice(scheduler, label=scheduler)
943 |
944 | inputs.enum(
945 | "scheduler_choices",
946 | scheduler_choices_dropdown.values(),
947 | default="K_EULER",
948 | view=scheduler_choices_dropdown,
949 | )
950 |
951 | inference_steps_slider = types.SliderView(
952 | label="Num Inference Steps",
953 | componentsProps={"slider": {"min": 1, "max": 100, "step": 1}},
954 | )
955 | inputs.int("inference_steps", default=25, view=inference_steps_slider)
956 |
957 | guidance_scale_slider = types.SliderView(
958 | label="Guidance Scale",
959 | componentsProps={"slider": {"min": 0.0, "max": 10.0, "step": 0.1}},
960 | )
961 | inputs.float("guidance_scale", default=7.5, view=guidance_scale_slider)
962 |
963 |
964 | #### KANDINSKY INPUTS ####
965 | def _handle_kandinsky_input(ctx, inputs):
966 |
967 | size_choices = KANDINSKY_SIZE_CHOICES
968 | width_choices = types.Dropdown(label="Width")
969 | for size in size_choices:
970 | width_choices.add_choice(size, label=size)
971 |
972 | inputs.enum(
973 | "width",
974 | width_choices.values(),
975 | default="512",
976 | view=width_choices,
977 | )
978 |
979 | height_choices = types.Dropdown(label="Height")
980 | for size in size_choices:
981 | height_choices.add_choice(size, label=size)
982 |
983 | inputs.enum(
984 | "height",
985 | height_choices.values(),
986 | default="512",
987 | view=height_choices,
988 | )
989 | inputs.str("negative_prompt", label="Negative Prompt", required=False)
990 |
991 | inference_steps_slider = types.SliderView(
992 | label="Num Inference Steps",
993 | componentsProps={"slider": {"min": 1, "max": 100, "step": 1}},
994 | )
995 | inputs.int("inference_steps", default=75, view=inference_steps_slider)
996 |
997 | inference_steps_prior_slider = types.SliderView(
998 | label="Num Inference Steps Prior",
999 | componentsProps={"slider": {"min": 1, "max": 50, "step": 1}},
1000 | )
1001 | inputs.int(
1002 | "inference_steps_prior",
1003 | default=25,
1004 | view=inference_steps_prior_slider,
1005 | )
1006 |
1007 |
1008 | #### PLAYGROUND V2 INPUTS ####
1009 | def _handle_playground_v2_input(ctx, inputs):
1010 | inputs.int("width", label="Width", default=1024)
1011 | inputs.int("height", label="Height", default=1024)
1012 |
1013 | inputs.str("negative_prompt", label="Negative Prompt", required=False)
1014 |
1015 | inference_steps_slider = types.SliderView(
1016 | label="Num Inference Steps",
1017 | componentsProps={"slider": {"min": 1, "max": 100, "step": 1}},
1018 | )
1019 | inputs.int("inference_steps", default=50, view=inference_steps_slider)
1020 |
1021 | guidance_scale_slider = types.SliderView(
1022 | label="Guidance Scale",
1023 | componentsProps={"slider": {"min": 0.0, "max": 10.0, "step": 0.1}},
1024 | )
1025 | inputs.float("guidance_scale", default=3.0, view=guidance_scale_slider)
1026 |
1027 | scheduler_choices_dropdown = types.Dropdown(label="Scheduler")
1028 | for scheduler in PLAYGROUND_V2_SCHEDULER_CHOICES:
1029 | scheduler_choices_dropdown.add_choice(scheduler, label=scheduler)
1030 |
1031 | inputs.enum(
1032 | "scheduler_choices",
1033 | scheduler_choices_dropdown.values(),
1034 | default="K_EULER_ANCESTRAL",
1035 | view=scheduler_choices_dropdown,
1036 | )
1037 |
1038 |
1039 | #### LATENT CONSISTENCY INPUTS ####
1040 | def _handle_latent_consistency_input(ctx, inputs):
1041 |
1042 | replicate_flag = allows_replicate_models()
1043 | diffusers_flag = allows_diffusers_models()
1044 |
1045 | if not replicate_flag:
1046 | ctx.params["model_distribution"] = "diffusers"
1047 | elif not diffusers_flag:
1048 | ctx.params["model_distribution"] = "replicate"
1049 | else:
1050 | model_distribution_choices = types.Dropdown(label="Model Distribution")
1051 | model_distribution_choices.add_choice("diffusers", label="Diffusers")
1052 | model_distribution_choices.add_choice("replicate", label="Replicate")
1053 | inputs.enum(
1054 | "model_distribution",
1055 | model_distribution_choices.values(),
1056 | default="diffusers",
1057 | view=model_distribution_choices,
1058 | )
1059 |
1060 | inputs.int("width", label="Width", default=512)
1061 | inputs.int("height", label="Height", default=512)
1062 |
1063 | inference_steps_slider = types.SliderView(
1064 | label="Num Inference Steps",
1065 | componentsProps={"slider": {"min": 1, "max": 50, "step": 1}},
1066 | )
1067 | inputs.int("num_inference_steps", default=4, view=inference_steps_slider)
1068 |
1069 | lcm_origin_steps_slider = types.SliderView(
1070 | label="LCM Origin Steps",
1071 | componentsProps={"slider": {"min": 1, "max": 100, "step": 1}},
1072 | )
1073 | inputs.int("lcm_origin_steps", default=50, view=lcm_origin_steps_slider)
1074 |
1075 | guidance_scale_slider = types.SliderView(
1076 | label="Guidance Scale",
1077 | componentsProps={"slider": {"min": 0.0, "max": 10.0, "step": 0.1}},
1078 | )
1079 | inputs.float("guidance_scale", default=7.5, view=guidance_scale_slider)
1080 |
1081 |
1082 | #### DALLE2 INPUTS ####
1083 | def _handle_dalle2_input(ctx, inputs):
1084 | size_choices_dropdown = types.Dropdown(label="Size")
1085 | for size in DALLE2_SIZE_CHOICES:
1086 | size_choices_dropdown.add_choice(size, label=size)
1087 |
1088 | inputs.enum(
1089 | "size_choices",
1090 | size_choices_dropdown.values(),
1091 | default="512x512",
1092 | view=size_choices_dropdown,
1093 | )
1094 |
1095 |
1096 | #### DALLE3 INPUTS ####
1097 | def _handle_dalle3_input(ctx, inputs):
1098 | size_choices_dropdown = types.Dropdown(label="Size")
1099 | for size in DALLE3_SIZE_CHOICES:
1100 | size_choices_dropdown.add_choice(size, label=size)
1101 |
1102 | inputs.enum(
1103 | "size_choices",
1104 | size_choices_dropdown.values(),
1105 | default="1024x1024",
1106 | view=size_choices_dropdown,
1107 | )
1108 |
1109 | quality_choices_dropdown = types.Dropdown(label="Quality")
1110 | for quality in DALLE3_QUALITY_CHOICES:
1111 | quality_choices_dropdown.add_choice(quality, label=quality)
1112 |
1113 | inputs.enum(
1114 | "quality_choices",
1115 | quality_choices_dropdown.values(),
1116 | default="standard",
1117 | view=quality_choices_dropdown,
1118 | )
1119 |
1120 |
1121 | #### VQGAN-CLIP INPUTS ####
1122 | def _handle_vqgan_clip_input(ctx, inputs):
1123 | return
1124 |
1125 |
1126 | INPUT_MAPPER = {
1127 | "sd": _handle_stable_diffusion_input,
1128 | "sdxl": _handle_sdxl_input,
1129 | "sdxl-lightning": _handle_sdxl_lightning_input,
1130 | "ssd-1b": _handle_ssd1b_input,
1131 | "kandinsky-2.2": _handle_kandinsky_input,
1132 | "playground-v2": _handle_playground_v2_input,
1133 | "latent-consistency": _handle_latent_consistency_input,
1134 | "dalle2": _handle_dalle2_input,
1135 | "dalle3": _handle_dalle3_input,
1136 | "vqgan-clip": _handle_vqgan_clip_input,
1137 | "stable-diffusion-3": _handle_sd3_input
1138 | }
1139 |
1140 |
1141 | def _handle_input(ctx, inputs):
1142 | model_name = ctx.params.get("model_choices", "sd")
1143 | model_input_handler = INPUT_MAPPER[model_name]
1144 | model_input_handler(ctx, inputs)
1145 |
1146 |
1147 | def _resolve_download_dir(ctx, inputs):
1148 | if len(ctx.dataset) == 0:
1149 | file_explorer = types.FileExplorerView(
1150 | choose_dir=True,
1151 | button_label="Choose a directory...",
1152 | )
1153 | inputs.file(
1154 | "download_dir",
1155 | required=True,
1156 | description="Choose a location to store downloaded images",
1157 | view=file_explorer,
1158 | )
1159 | else:
1160 | base_dir = os.path.dirname(ctx.dataset.first().filepath).split("/")
1161 | ctx.params["download_dir"] = "/".join(base_dir)
1162 |
1163 |
1164 | def _handle_calling(uri, sample_collection, prompt, model_name, **kwargs):
1165 | ctx = dict(view=sample_collection.view())
1166 | params = dict(kwargs)
1167 | params["prompt"] = prompt
1168 | params["model_choices"] = model_name
1169 |
1170 | return foo.execute_operator(uri, ctx, params=params)
1171 |
1172 |
1173 | class Txt2Image(foo.Operator):
1174 | @property
1175 | def config(self):
1176 | _config = foo.OperatorConfig(
1177 | name="txt2img",
1178 | label="Text to Image: Generate Image from Text",
1179 | dynamic=True,
1180 | )
1181 | _config.icon = "/assets/icon.svg"
1182 | return _config
1183 |
1184 | def resolve_input(self, ctx):
1185 | inputs = types.Object()
1186 | _resolve_download_dir(ctx, inputs)
1187 |
1188 | replicate_flag = allows_replicate_models()
1189 | openai_flag = allows_openai_models()
1190 | stabilityai_flag = allows_stabilityai_models()
1191 | diffusers_flag = allows_diffusers_models()
1192 |
1193 |
1194 |
1195 | any_flag = replicate_flag or openai_flag or diffusers_flag or stabilityai_flag
1196 | if not any_flag:
1197 | inputs.message(
1198 | "message",
1199 | label="No models available.",
1200 | description=(
1201 | "You must install one of `replicate`, `openai`, or `diffusers`, "
1202 | "or set the `STABILITY_API_KEY` environment variable, to use this plugin."
1203 | ),
1204 | )
1205 | return types.Property(inputs)
1206 |
1207 | model_choices = types.Dropdown()
1208 | if replicate_flag:
1209 | _add_replicate_choices(model_choices)
1210 | if openai_flag:
1211 | _add_openai_choices(model_choices)
1212 | if diffusers_flag:
1213 | _add_diffusers_choices(model_choices)
1214 | if stabilityai_flag:
1215 | _add_stability_choices(model_choices)
1216 | inputs.enum(
1217 | "model_choices",
1218 | model_choices.values(),
1219 | default=model_choices.choices[0].value,
1220 | label="Model",
1221 | description="Choose a model to generate images",
1222 | view=model_choices,
1223 | )
1224 |
1225 | inputs.str(
1226 | "prompt",
1227 | label="Prompt",
1228 | description="The prompt to generate an image from",
1229 | required=True,
1230 | )
1231 | _handle_input(ctx, inputs)
1232 | return types.Property(inputs)
1233 |
1234 | def execute(self, ctx):
1235 | model_name = ctx.params.get("model_choices", "None provided")
1236 | model = get_model(model_name)
1237 | prompt = ctx.params.get("prompt", "None provided")
1238 |
1239 | response = model.generate_image(ctx)
1240 | filepath = generate_filepath(ctx)
1241 |
1242 | if type(response) == str:
1243 | ## served models return a url
1244 | image_url = response
1245 | download_image(image_url, filepath)
1246 | elif type(response) == requests.models.Response:
1247 | ## served model return whole object
1248 | write_image(response,filepath)
1249 | else:
1250 | ## local models return a PIL image
1251 | response.save(filepath)
1252 |
1253 | sample = fo.Sample(
1254 | filepath=filepath,
1255 | tags=["generated"],
1256 | model=model.name,
1257 | prompt=prompt,
1258 | date_created=datetime.now(),
1259 | )
1260 | set_config(sample, ctx, model_name)
1261 |
1262 | dataset = ctx.dataset
1263 | dataset.add_sample(sample, dynamic=True)
1264 |
1265 | if dataset.get_dynamic_field_schema() is not None:
1266 | dataset.add_dynamic_sample_fields()
1267 | ctx.ops.reload_dataset()
1268 | else:
1269 | ctx.ops.reload_samples()
1270 |
1271 | def list_models(self):
1272 | return list(INPUT_MAPPER.keys())
1273 |
1274 | def __call__(self, sample_collection, prompt, model_name, **kwargs):
1275 | return _handle_calling(
1276 | self.uri, sample_collection, prompt, model_name, **kwargs
1277 | )
1278 |
1279 |
1280 | def register(plugin):
1281 | plugin.register(Txt2Image)
1282 |
--------------------------------------------------------------------------------