├── README.md
├── asset
│   └── teaser.png
├── demo.py
├── environment.yaml
├── example_images
│   ├── 009698.jpg
│   ├── Arknight.jpg
│   └── new_cat_3.jpeg
├── main.py
├── requirements.txt
└── src
    ├── config.py
    ├── eunms.py
    ├── pipes
    │   ├── sd_inversion_pipeline.py
    │   ├── sdxl_forward_pipeline.py
    │   └── sdxl_inversion_pipeline.py
    ├── renoise_inversion.py
    ├── schedulers
    │   ├── ddim_scheduler.py
    │   ├── euler_scheduler.py
    │   └── lcm_scheduler.py
    └── utils
        ├── enums_utils.py
        └── images_utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Scaling Concept With Text-Guided Diffusion Models
2 | > **Chao Huang, Susan Liang, Yunlong Tang, Yapeng Tian, Anurag Kumar, Chenliang Xu**
3 | >
4 | > Text-guided diffusion models have revolutionized generative tasks by producing high-fidelity content from text descriptions. They have also enabled an editing paradigm where concepts can be replaced through text conditioning (e.g., *a dog -> a tiger*). In this work, we explore a novel approach: instead of replacing a concept, can we enhance or suppress the concept itself? Through an empirical study, we identify a trend where concepts can be decomposed in text-guided diffusion models. Leveraging this insight, we introduce **ScalingConcept**, a simple yet effective method to scale decomposed concepts up or down in real input without introducing new elements.
5 | To systematically evaluate our approach, we present the *WeakConcept-10* dataset, where concepts are imperfect and need to be enhanced. More importantly, ScalingConcept enables a variety of novel zero-shot applications across image and audio domains, including tasks such as canonical pose generation and generative sound highlighting or removal.
6 |
7 | ![teaser](asset/teaser.png)
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | ## Environment Setup
16 | Our code builds on the `diffusers` library. To set up the environment, please run:
17 | ```
18 | conda env create -f environment.yaml
19 | conda activate ScalingConcept
20 | ```
21 | Alternatively, install the requirements directly:
22 | ```
23 | pip install -r requirements.txt
24 | ```
25 |
26 |
27 | ## Minimal Example
28 |
29 | We provide a minimal example to explore the effects of concept scaling. The `example_images/` directory contains three sample images demonstrating different applications: `canonical pose generation`, `face attribute editing`, and `anime sketch enhancement`. To get started, try running:
30 |
31 | ```bash
32 | python demo.py
33 | ```
34 |
35 | The default setting is configured for `canonical pose generation`. For optimal results with other applications, adjust the prompt and relevant hyperparameters as noted in the code comments.
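
For example, to switch the demo to the anime sketch enhancement sample, the relevant lines of `demo.py` change roughly as follows (the values mirror the comments in `demo.py`; treat this as a sketch rather than a fixed recipe):

```python
# Point demo.py at the anime sketch example (values taken from the in-code comments).
input_image = Image.open("example_images/Arknight.jpg").convert("RGB")
prompt = "anime"  # inversion prompt for this example

# ...and in the pipe_inference call, set:
#   omega=5, gamma=3, t_exit=25
```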
36 |
37 | ### Usage
38 |
39 | Our ScalingConcept method supports various applications, each customizable by adjusting scaling parameters within `pipe_inference`. Below are recommended configurations for each application:
40 |
41 | - **Canonical Pose Generation/Object Stitching**:
42 | ```python
43 | prompt = [object_name]
44 | omega = 5
45 | gamma = 3
46 | t_exit = 15
47 | ```
48 |
49 | - **Weather Manipulation**:
50 | ```python
51 |   prompt = '(heavy) fog'  # or '(heavy) rain'
52 | omega = 5
53 | gamma = 3
54 | t_exit = 15
55 | ```
56 |
57 | - **Creative Enhancement**:
58 | ```python
59 | prompt = [concept to enhance]
60 | omega = 3
61 | gamma = 3
62 | t_exit = 15
63 | ```
64 |
65 | - **Face Attribute Scaling**:
66 | ```python
67 | prompt = [face attribute, e.g., 'young face' or 'old face']
68 | omega = 3
69 | gamma = 3
70 | t_exit = 15
71 | ```
72 |
73 | - **Anime Sketch Enhancement**:
74 | ```python
75 | prompt = 'anime'
76 | omega = 5
77 | gamma = 3
78 | t_exit = 25
79 | ```
80 |
81 | In general, a larger `omega` increases the strength of concept scaling, while higher `gamma` and `t_exit` values better preserve fidelity to the input. Note that the choice of inversion `prompt` is crucial, as the model is sensitive to prompt wording.
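
The scaling parameters go straight into the `pipe_inference` call, while the concept `prompt` is used at inversion time. The following condensed sketch, adapted from `demo.py`, shows where `prompt`, `omega`, `gamma`, and `t_exit` plug in (the image path and values are just the canonical-pose example; swap in the configuration for your application):

```python
import torch
from PIL import Image

from src.eunms import Model_Type, Scheduler_Type
from src.utils.enums_utils import get_pipes
from src.config import RunConfig
from main import run as invert

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe_inversion, pipe_inference = get_pipes(Model_Type.SDXL, Scheduler_Type.DDIM, device=device)

input_image = Image.open("example_images/new_cat_3.jpeg").convert("RGB").resize((1024, 1024))
config = RunConfig(model_type=Model_Type.SDXL,
                   scheduler_type=Scheduler_Type.DDIM,
                   num_inference_steps=50,
                   num_inversion_steps=50,
                   num_renoise_steps=1,
                   perform_noise_correction=False,
                   seed=7865)

# Invert the input image conditioned on the concept prompt.
_, inv_latent, _, all_latents, other_kwargs = invert(input_image, "cat", config,
                                                     pipe_inversion=pipe_inversion,
                                                     pipe_inference=pipe_inference,
                                                     do_reconstruction=False)

# Regenerate while scaling the concept; omega, gamma, and t_exit follow the configurations above.
image = pipe_inference(image=inv_latent,
                       prompt="",
                       denoising_start=0.0,
                       num_inference_steps=config.num_inference_steps,
                       guidance_scale=1.0,
                       omega=5, gamma=3, t_exit=15,
                       inv_latents=all_latents,
                       prompt_embeds_ref=other_kwargs[0],
                       added_cond_kwargs_ref=other_kwargs[1]).images[0]
image.save("output.jpg")
```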
82 |
83 | ## Acknowledgements
84 |
85 | This code builds upon the [diffusers](https://github.com/huggingface/diffusers) library. Additionally, we borrow code from the following repositories:
86 |
87 | - [Pix2PixZero](https://github.com/pix2pixzero/pix2pix-zero) for noise regularization.
88 | - [sdxl_inversions](https://github.com/cloneofsimo/sdxl_inversions) for the initial implementation of DDIM inversion in SDXL.
89 | - [ReNoise-Inversion](https://github.com/garibida/ReNoise-Inversion) for a precise inversion technique.
90 |
91 | ## Citation
92 | If you use this code for your research, please cite the following work:
93 | ```
94 | @misc{huang2024scaling,
95 | title={Scaling Concept With Text-Guided Diffusion Models},
96 |       author={Chao Huang and Susan Liang and Yunlong Tang and Yapeng Tian and Anurag Kumar and Chenliang Xu},
97 | year={2024},
98 | eprint={2410.24151},
99 | archivePrefix={arXiv},
100 | primaryClass={cs.CV}
101 | }
102 | ```
103 |
--------------------------------------------------------------------------------
/asset/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WikiChao/ScalingConcept/e201f3ab5f460b08d4067464f7cedb26bc5b012b/asset/teaser.png
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 |
4 | from src.eunms import Model_Type, Scheduler_Type
5 | from src.utils.enums_utils import get_pipes
6 | from src.config import RunConfig
7 |
8 | from main import run as invert
9 |
10 |
11 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
12 |
13 | model_type = Model_Type.SDXL
14 | scheduler_type = Scheduler_Type.DDIM
15 | pipe_inversion, pipe_inference = get_pipes(model_type, scheduler_type, device=device)
16 |
17 | input_image = Image.open("example_images/new_cat_3.jpeg").convert("RGB") # "009698.jpg", "Arknight.jpg"
18 | original_shape = input_image.size
19 | input_image = input_image.resize((1024, 1024))
20 | prompt = "cat" # 'smile' for "009698.jpg", 'anime' for "Arknight.jpg"
21 |
22 | config = RunConfig(model_type = model_type,
23 | num_inference_steps = 50,
24 | num_inversion_steps = 50,
25 | num_renoise_steps = 1,
26 | scheduler_type = scheduler_type,
27 | perform_noise_correction = False,
28 | seed = 7865)
29 |
30 | _, inv_latent, _, all_latents, other_kwargs = invert(input_image,
31 | prompt,
32 | config,
33 | pipe_inversion=pipe_inversion,
34 | pipe_inference=pipe_inference,
35 | do_reconstruction=False)
36 |
37 | rec_image = pipe_inference(image = inv_latent,
38 | prompt = "",
39 | denoising_start=0.0,
40 | num_inference_steps = config.num_inference_steps,
41 | guidance_scale = 1.0,
42 | omega=5, # omega=3 for "009698.jpg", omega=5 for "Arknight.jpg"
43 | gamma=3, # gamma=3 for "009698.jpg", gamma=3 for "Arknight.jpg"
44 | inv_latents=all_latents,
45 | prompt_embeds_ref=other_kwargs[0],
46 | added_cond_kwargs_ref=other_kwargs[1],
47 | t_exit=15, # t_exit=15 for "009698.jpg", t_exit=25 for "Arknight.jpg"
48 | ).images[0]
49 |
50 | rec_image.resize(original_shape).save("new_cat_3.jpg")
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: ScalingConcept
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.11.8
7 | - pip=23.3.1
8 | - pip:
9 | - -r requirements.txt
--------------------------------------------------------------------------------
/example_images/009698.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WikiChao/ScalingConcept/e201f3ab5f460b08d4067464f7cedb26bc5b012b/example_images/009698.jpg
--------------------------------------------------------------------------------
/example_images/Arknight.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WikiChao/ScalingConcept/e201f3ab5f460b08d4067464f7cedb26bc5b012b/example_images/Arknight.jpg
--------------------------------------------------------------------------------
/example_images/new_cat_3.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WikiChao/ScalingConcept/e201f3ab5f460b08d4067464f7cedb26bc5b012b/example_images/new_cat_3.jpeg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import pyrallis
2 | import torch
3 | from PIL import Image
4 | from diffusers.utils.torch_utils import randn_tensor
5 |
6 | from src.config import RunConfig
7 | from src.utils.enums_utils import model_type_to_size, is_stochastic
8 |
9 | def create_noise_list(model_type, length, generator=None):
10 | img_size = model_type_to_size(model_type)
11 | VQAE_SCALE = 8
12 | latents_size = (1, 4, img_size[0] // VQAE_SCALE, img_size[1] // VQAE_SCALE)
13 | return [randn_tensor(latents_size, dtype=torch.float16, device=torch.device("cuda:0"), generator=generator) for i in range(length)]
14 |
15 | @pyrallis.wrap()
16 | def main(cfg: RunConfig):
17 | run(cfg)
18 |
19 | def run(init_image: Image.Image,
20 | prompt: str,
21 | cfg: RunConfig,
22 | pipe_inversion,
23 | pipe_inference,
24 | latents = None,
25 | edit_prompt = None,
26 | edit_cfg = 1.0,
27 | noise = None,
28 | do_reconstruction = True):
29 |
30 | generator = torch.Generator().manual_seed(cfg.seed)
31 |
32 | if is_stochastic(cfg.scheduler_type):
33 | if latents is None:
34 | noise = create_noise_list(cfg.model_type, cfg.num_inversion_steps, generator=generator)
35 | pipe_inversion.scheduler.set_noise_list(noise)
36 | pipe_inference.scheduler.set_noise_list(noise)
37 |
38 | pipe_inversion.cfg = cfg
39 | pipe_inference.cfg = cfg
40 |     all_latents, other_kwargs = None, None  # other_kwargs is only populated when inversion runs below
41 |
42 | if latents is None:
43 | print("Inverting...")
44 | res = pipe_inversion(prompt = prompt,
45 | num_inversion_steps = cfg.num_inversion_steps,
46 | num_inference_steps = cfg.num_inference_steps,
47 | generator = generator,
48 | image = init_image,
49 | guidance_scale = cfg.guidance_scale,
50 | strength = cfg.inversion_max_step,
51 | denoising_start = 1.0-cfg.inversion_max_step,
52 | num_renoise_steps = cfg.num_renoise_steps)
53 | latents = res[0][0]
54 | all_latents = res[1]
55 | other_kwargs = res[2]
56 |
57 | inv_latent = latents.clone()
58 |
59 | if do_reconstruction:
60 | print("Generating...")
61 | edit_prompt = prompt if edit_prompt is None else edit_prompt
62 | guidance_scale = edit_cfg
63 | img = pipe_inference(prompt = edit_prompt,
64 | num_inference_steps = cfg.num_inference_steps,
65 | negative_prompt = prompt,
66 | image = latents,
67 | strength = cfg.inversion_max_step,
68 | denoising_start = 1.0-cfg.inversion_max_step,
69 | guidance_scale = guidance_scale,
70 | omega=1,
71 | gamma=0,
72 | inv_latents=all_latents,
73 | prompt_embeds_ref=other_kwargs[0],
74 | added_cond_kwargs_ref=other_kwargs[1]).images[0]
75 | else:
76 | img = None
77 |
78 | return img, inv_latent, noise, all_latents, other_kwargs
79 |
80 | if __name__ == "__main__":
81 | main()
82 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.2.0
2 | torchvision==0.17.0
3 | diffusers==0.24.0
4 | transformers==4.32.1
5 | pyrallis==0.3.1
6 | accelerate==0.25.0
7 | bitsandbytes==0.43.0
--------------------------------------------------------------------------------
/src/config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | from src.eunms import Model_Type, Scheduler_Type
4 |
5 | @dataclass
6 | class RunConfig:
7 | model_type : Model_Type = Model_Type.SDXL_Turbo
8 |
9 | scheduler_type : Scheduler_Type = Scheduler_Type.EULER
10 |
11 | seed: int = 7865
12 |
13 | num_inference_steps: int = 4
14 |
15 | num_inversion_steps: int = 4
16 |
17 | guidance_scale: float = 0.0
18 |
19 | num_renoise_steps: int = 9
20 |
21 | max_num_renoise_steps_first_step: int = 5
22 |
23 | inversion_max_step: float = 1.0
24 |
25 | # Average Parameters
26 |
27 | average_latent_estimations: bool = True
28 |
29 | average_first_step_range: tuple = (0, 5)
30 |
31 | average_step_range: tuple = (8, 10)
32 |
33 | # Noise Regularization
34 |
35 | noise_regularization_lambda_ac: float = 20.0
36 |
37 | noise_regularization_lambda_kl: float = 0.065
38 |
39 | noise_regularization_num_reg_steps: int = 4
40 |
41 | noise_regularization_num_ac_rolls: int = 5
42 |
43 | # Noise Correction
44 |
45 | perform_noise_correction: bool = True
46 |
47 | def __post_init__(self):
48 | pass
--------------------------------------------------------------------------------
/src/eunms.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | class Scheduler_Type(Enum):
4 | DDIM = 1
5 | EULER = 2
6 | LCM = 3
7 |
8 | class Model_Type(Enum):
9 | SDXL = 1
10 | SDXL_Turbo = 2
11 | LCM_SDXL = 3
12 | SD15 = 4
13 | SD21 = 5
14 | SD21_Turbo = 6
15 | SD14 = 7
--------------------------------------------------------------------------------
/src/pipes/sd_inversion_pipeline.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3 | from diffusers import StableDiffusionImg2ImgPipeline
4 | from diffusers.utils.torch_utils import randn_tensor
5 | from diffusers.utils import deprecate
6 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
7 | StableDiffusionPipelineOutput,
8 | retrieve_timesteps,
9 | PipelineImageInput
10 | )
11 |
12 | from src.renoise_inversion import inversion_step
13 |
14 |
15 | class SDDDIMPipeline(StableDiffusionImg2ImgPipeline):
16 | # @torch.no_grad()
17 | def __call__(
18 | self,
19 | prompt: Union[str, List[str]] = None,
20 | image: PipelineImageInput = None,
21 | strength: float = 1.0,
22 | num_inversion_steps: Optional[int] = 50,
23 | timesteps: List[int] = None,
24 | guidance_scale: Optional[float] = 7.5,
25 | negative_prompt: Optional[Union[str, List[str]]] = None,
26 | num_images_per_prompt: Optional[int] = 1,
27 | eta: Optional[float] = 0.0,
28 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
29 | prompt_embeds: Optional[torch.FloatTensor] = None,
30 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
31 | ip_adapter_image: Optional[PipelineImageInput] = None,
32 | output_type: Optional[str] = "pil",
33 | return_dict: bool = True,
34 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
35 | clip_skip: int = None,
36 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
37 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
38 | num_renoise_steps: int = 100,
39 | **kwargs,
40 | ):
41 | callback = kwargs.pop("callback", None)
42 | callback_steps = kwargs.pop("callback_steps", None)
43 |
44 | if callback is not None:
45 | deprecate(
46 | "callback",
47 | "1.0.0",
48 | "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
49 | )
50 | if callback_steps is not None:
51 | deprecate(
52 | "callback_steps",
53 | "1.0.0",
54 | "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
55 | )
56 |
57 | # 1. Check inputs. Raise error if not correct
58 | self.check_inputs(
59 | prompt,
60 | strength,
61 | callback_steps,
62 | negative_prompt,
63 | prompt_embeds,
64 | negative_prompt_embeds,
65 | callback_on_step_end_tensor_inputs,
66 | )
67 |
68 | self._guidance_scale = guidance_scale
69 | self._clip_skip = clip_skip
70 | self._cross_attention_kwargs = cross_attention_kwargs
71 |
72 | # 2. Define call parameters
73 | if prompt is not None and isinstance(prompt, str):
74 | batch_size = 1
75 | elif prompt is not None and isinstance(prompt, list):
76 | batch_size = len(prompt)
77 | else:
78 | batch_size = prompt_embeds.shape[0]
79 |
80 | device = self._execution_device
81 |
82 | # 3. Encode input prompt
83 | text_encoder_lora_scale = (
84 | self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
85 | )
86 | prompt_embeds, negative_prompt_embeds = self.encode_prompt(
87 | prompt,
88 | device,
89 | num_images_per_prompt,
90 | self.do_classifier_free_guidance,
91 | negative_prompt,
92 | prompt_embeds=prompt_embeds,
93 | negative_prompt_embeds=negative_prompt_embeds,
94 | lora_scale=text_encoder_lora_scale,
95 | clip_skip=self.clip_skip,
96 | )
97 | # For classifier free guidance, we need to do two forward passes.
98 | # Here we concatenate the unconditional and text embeddings into a single batch
99 | # to avoid doing two forward passes
100 | if self.do_classifier_free_guidance:
101 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
102 |
103 | if ip_adapter_image is not None:
104 | image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
105 | if self.do_classifier_free_guidance:
106 | image_embeds = torch.cat([negative_image_embeds, image_embeds])
107 |
108 | # 4. Preprocess image
109 | image = self.image_processor.preprocess(image)
110 |
111 | # 5. set timesteps
112 | timesteps, num_inversion_steps = retrieve_timesteps(self.scheduler, num_inversion_steps, device, timesteps)
113 | timesteps, num_inversion_steps = self.get_timesteps(num_inversion_steps, strength, device)
114 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
115 |
116 | # 6. Prepare latent variables
117 | with torch.no_grad():
118 | latents = self.prepare_latents(
119 | image,
120 | latent_timestep,
121 | batch_size,
122 | num_images_per_prompt,
123 | prompt_embeds.dtype,
124 | device,
125 | generator,
126 | )
127 |
128 | # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
129 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
130 |
131 | # 7.1 Add image embeds for IP-Adapter
132 | added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
133 |
134 | # 7.2 Optionally get Guidance Scale Embedding
135 | timestep_cond = None
136 | if self.unet.config.time_cond_proj_dim is not None:
137 | guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
138 | timestep_cond = self.get_guidance_scale_embedding(
139 | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
140 | ).to(device=device, dtype=latents.dtype)
141 |
142 | # 8. Denoising loop
143 | num_warmup_steps = len(timesteps) - num_inversion_steps * self.scheduler.order
144 |
145 | self._num_timesteps = len(timesteps)
146 | self.z_0 = torch.clone(latents)
147 | self.noise = randn_tensor(self.z_0.shape, generator=generator, device=self.z_0.device, dtype=self.z_0.dtype)
148 |
149 | all_latents = [latents.clone()]
150 | with self.progress_bar(total=num_inversion_steps) as progress_bar:
151 | for i, t in enumerate(reversed(timesteps)):
152 |
153 | latents = inversion_step(self,
154 | latents,
155 | t,
156 | prompt_embeds,
157 | added_cond_kwargs,
158 | num_renoise_steps=num_renoise_steps,
159 | generator=generator)
160 |
161 | all_latents.append(latents.clone())
162 |
163 | if callback_on_step_end is not None:
164 | callback_kwargs = {}
165 | for k in callback_on_step_end_tensor_inputs:
166 | callback_kwargs[k] = locals()[k]
167 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
168 |
169 | latents = callback_outputs.pop("latents", latents)
170 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
171 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
172 |
173 | # call the callback, if provided
174 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
175 | progress_bar.update()
176 | if callback is not None and i % callback_steps == 0:
177 | step_idx = i // getattr(self.scheduler, "order", 1)
178 | callback(step_idx, t, latents)
179 |
180 | image = latents
181 |
182 | # Offload all models
183 | self.maybe_free_model_hooks()
184 |
185 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None), all_latents
186 |
--------------------------------------------------------------------------------
/src/pipes/sdxl_forward_pipeline.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
4 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
5 | StableDiffusionPipelineOutput,
6 | retrieve_timesteps,
7 | PipelineImageInput,
8 | rescale_noise_cfg,
9 | )
10 |
11 | from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
12 | from diffusers.utils import deprecate, is_torch_xla_available
13 | from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
14 |
15 | if is_torch_xla_available():
16 | import torch_xla.core.xla_model as xm
17 | XLA_AVAILABLE = True
18 | else:
19 | XLA_AVAILABLE = False
20 |
21 | def degrade_proportionally(i, max_value=1, num_inference_steps=49, gamma=0):
22 |     return max(0, max_value * (1 - i / num_inference_steps)**gamma)  # with gamma > 0, decays from max_value (omega) toward 0 over the denoising steps
23 |
24 | class StableDiffusionXLDecompositionPipeline(StableDiffusionXLImg2ImgPipeline):
25 | @torch.no_grad()
26 | def __call__(
27 | self,
28 | prompt: Union[str, List[str]] = None,
29 | prompt_2: Optional[Union[str, List[str]]] = None,
30 | image: PipelineImageInput = None,
31 | strength: float = 0.3,
32 | num_inference_steps: int = 50,
33 | timesteps: List[int] = None,
34 | denoising_start: Optional[float] = None,
35 | denoising_end: Optional[float] = None,
36 | guidance_scale: float = 5.0,
37 | negative_prompt: Optional[Union[str, List[str]]] = None,
38 | negative_prompt_2: Optional[Union[str, List[str]]] = None,
39 | num_images_per_prompt: Optional[int] = 1,
40 | eta: float = 0.0,
41 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
42 | latents: Optional[torch.FloatTensor] = None,
43 | prompt_embeds: Optional[torch.FloatTensor] = None,
44 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
45 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
46 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
47 | ip_adapter_image: Optional[PipelineImageInput] = None,
48 | output_type: Optional[str] = "pil",
49 | return_dict: bool = True,
50 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
51 | guidance_rescale: float = 0.0,
52 | original_size: Tuple[int, int] = None,
53 | crops_coords_top_left: Tuple[int, int] = (0, 0),
54 | target_size: Tuple[int, int] = None,
55 | negative_original_size: Optional[Tuple[int, int]] = None,
56 | negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
57 | negative_target_size: Optional[Tuple[int, int]] = None,
58 | aesthetic_score: float = 6.0,
59 | negative_aesthetic_score: float = 2.5,
60 | clip_skip: Optional[int] = None,
61 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
62 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
63 | omega = 1,
64 | gamma = 0,
65 | inv_latents = None,
66 | prompt_embeds_ref = None,
67 | added_cond_kwargs_ref = None,
68 | t_exit=15,
69 | **kwargs,
70 | ):
71 | r"""
72 | Function invoked when calling the pipeline for generation.
73 |
74 | Args:
75 | prompt (`str` or `List[str]`, *optional*):
76 |                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
77 |                 instead.
78 | prompt_2 (`str` or `List[str]`, *optional*):
79 | The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
80 | used in both text-encoders
81 | image (`torch.FloatTensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`):
82 | The image(s) to modify with the pipeline.
83 | strength (`float`, *optional*, defaults to 0.3):
84 | Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
85 | will be used as a starting point, adding more noise to it the larger the `strength`. The number of
86 | denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
87 | be maximum and the denoising process will run for the full number of iterations specified in
88 | `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of
89 | `denoising_start` being declared as an integer, the value of `strength` will be ignored.
90 | num_inference_steps (`int`, *optional*, defaults to 50):
91 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
92 | expense of slower inference.
93 | timesteps (`List[int]`, *optional*):
94 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
95 | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
96 | passed will be used. Must be in descending order.
97 | denoising_start (`float`, *optional*):
98 | When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
99 | bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
100 | it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
101 | strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
102 | is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image
103 | Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
104 | denoising_end (`float`, *optional*):
105 | When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
106 | completed before it is intentionally prematurely terminated. As a result, the returned sample will
107 | still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
108 | denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
109 | final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
110 | forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refine Image
111 | Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
112 |             guidance_scale (`float`, *optional*, defaults to 5.0):
113 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
114 | `guidance_scale` is defined as `w` of equation 2. of [Imagen
115 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
116 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
117 | usually at the expense of lower image quality.
118 | negative_prompt (`str` or `List[str]`, *optional*):
119 | The prompt or prompts not to guide the image generation. If not defined, one has to pass
120 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
121 | less than `1`).
122 | negative_prompt_2 (`str` or `List[str]`, *optional*):
123 | The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
124 | `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
125 | num_images_per_prompt (`int`, *optional*, defaults to 1):
126 | The number of images to generate per prompt.
127 | eta (`float`, *optional*, defaults to 0.0):
128 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
129 | [`schedulers.DDIMScheduler`], will be ignored for others.
130 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
131 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
132 | to make generation deterministic.
133 | latents (`torch.FloatTensor`, *optional*):
134 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
135 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
136 |                 tensor will be generated by sampling using the supplied random `generator`.
137 | prompt_embeds (`torch.FloatTensor`, *optional*):
138 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
139 | provided, text embeddings will be generated from `prompt` input argument.
140 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
141 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
142 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
143 | argument.
144 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
145 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
146 | If not provided, pooled text embeddings will be generated from `prompt` input argument.
147 | negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
148 | Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
149 | weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
150 | input argument.
151 | ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
152 | output_type (`str`, *optional*, defaults to `"pil"`):
153 | The output format of the generate image. Choose between
154 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
155 | return_dict (`bool`, *optional*, defaults to `True`):
156 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
157 | plain tuple.
158 | cross_attention_kwargs (`dict`, *optional*):
159 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
160 | `self.processor` in
161 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
162 | guidance_rescale (`float`, *optional*, defaults to 0.0):
163 | Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
164 | Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
165 | [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
166 | Guidance rescale factor should fix overexposure when using zero terminal SNR.
167 | original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
168 | If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
169 | `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
170 | explained in section 2.2 of
171 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
172 | crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
173 | `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
174 | `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
175 | `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
176 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
177 | target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
178 | For most cases, `target_size` should be set to the desired height and width of the generated image. If
179 | not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
180 | section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
181 | negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
182 | To negatively condition the generation process based on a specific image resolution. Part of SDXL's
183 | micro-conditioning as explained in section 2.2 of
184 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
185 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
186 | negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
187 | To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
188 | micro-conditioning as explained in section 2.2 of
189 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
190 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
191 | negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
192 |                 To negatively condition the generation process based on a target image resolution. It should be the same
193 | as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
194 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
195 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
196 | aesthetic_score (`float`, *optional*, defaults to 6.0):
197 | Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
198 | Part of SDXL's micro-conditioning as explained in section 2.2 of
199 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
200 | negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
201 | Part of SDXL's micro-conditioning as explained in section 2.2 of
202 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
203 | simulate an aesthetic score of the generated image by influencing the negative text condition.
204 | clip_skip (`int`, *optional*):
205 | Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
206 | the output of the pre-final layer will be used for computing the prompt embeddings.
207 | callback_on_step_end (`Callable`, *optional*):
208 | A function that calls at the end of each denoising steps during the inference. The function is called
209 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
210 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
211 | `callback_on_step_end_tensor_inputs`.
212 | callback_on_step_end_tensor_inputs (`List`, *optional*):
213 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
214 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
215 | `._callback_tensor_inputs` attribute of your pipeline class.
216 |
217 | Examples:
218 |
219 | Returns:
220 | [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
221 | [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
222 |             `tuple`. When returning a tuple, the first element is a list with the generated images.
223 | """
224 |
225 | callback = kwargs.pop("callback", None)
226 | callback_steps = kwargs.pop("callback_steps", None)
227 |
228 | if callback is not None:
229 | deprecate(
230 | "callback",
231 | "1.0.0",
232 | "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
233 | )
234 | if callback_steps is not None:
235 | deprecate(
236 | "callback_steps",
237 | "1.0.0",
238 | "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
239 | )
240 |
241 | # 1. Check inputs. Raise error if not correct
242 | self.check_inputs(
243 | prompt,
244 | prompt_2,
245 | strength,
246 | num_inference_steps,
247 | callback_steps,
248 | negative_prompt,
249 | negative_prompt_2,
250 | prompt_embeds,
251 | negative_prompt_embeds,
252 | callback_on_step_end_tensor_inputs,
253 | )
254 |
255 | self._guidance_scale = guidance_scale
256 | self._guidance_rescale = guidance_rescale
257 | self._clip_skip = clip_skip
258 | self._cross_attention_kwargs = cross_attention_kwargs
259 | self._denoising_end = denoising_end
260 | self._denoising_start = denoising_start
261 | self._interrupt = False
262 |
263 | # 2. Define call parameters
264 | if prompt is not None and isinstance(prompt, str):
265 | batch_size = 1
266 | elif prompt is not None and isinstance(prompt, list):
267 | batch_size = len(prompt)
268 | else:
269 | batch_size = prompt_embeds.shape[0]
270 |
271 | device = self._execution_device
272 |
273 | # 3. Encode input prompt
274 | text_encoder_lora_scale = (
275 | self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
276 | )
277 | (
278 | prompt_embeds,
279 | negative_prompt_embeds,
280 | pooled_prompt_embeds,
281 | negative_pooled_prompt_embeds,
282 | ) = self.encode_prompt(
283 | prompt=prompt,
284 | prompt_2=prompt_2,
285 | device=device,
286 | num_images_per_prompt=num_images_per_prompt,
287 | do_classifier_free_guidance=self.do_classifier_free_guidance,
288 | negative_prompt=negative_prompt,
289 | negative_prompt_2=negative_prompt_2,
290 | prompt_embeds=prompt_embeds,
291 | negative_prompt_embeds=negative_prompt_embeds,
292 | pooled_prompt_embeds=pooled_prompt_embeds,
293 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
294 | lora_scale=text_encoder_lora_scale,
295 | clip_skip=self.clip_skip,
296 | )
297 |
298 | # 4. Preprocess image
299 | image = self.image_processor.preprocess(image)
300 |
301 | # 5. Prepare timesteps
302 | def denoising_value_valid(dnv):
303 | return isinstance(self.denoising_end, float) and 0 < dnv < 1
304 |
305 | timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
306 | timesteps, num_inference_steps = self.get_timesteps(
307 | num_inference_steps,
308 | strength,
309 | device,
310 | denoising_start=self.denoising_start if denoising_value_valid else None,
311 | )
312 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
313 |
314 | add_noise = True if self.denoising_start is None else False
315 | # 6. Prepare latent variables
316 | latents = self.prepare_latents(
317 | image,
318 | latent_timestep,
319 | batch_size,
320 | num_images_per_prompt,
321 | prompt_embeds.dtype,
322 | device,
323 | generator,
324 | add_noise,
325 | )
326 | # 7. Prepare extra step kwargs.
327 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
328 |
329 | height, width = latents.shape[-2:]
330 | height = height * self.vae_scale_factor
331 | width = width * self.vae_scale_factor
332 |
333 | original_size = original_size or (height, width)
334 | target_size = target_size or (height, width)
335 |
336 | # 8. Prepare added time ids & embeddings
337 | if negative_original_size is None:
338 | negative_original_size = original_size
339 | if negative_target_size is None:
340 | negative_target_size = target_size
341 |
342 | add_text_embeds = pooled_prompt_embeds
343 | if self.text_encoder_2 is None:
344 | text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
345 | else:
346 | text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
347 |
348 | add_time_ids, add_neg_time_ids = self._get_add_time_ids(
349 | original_size,
350 | crops_coords_top_left,
351 | target_size,
352 | aesthetic_score,
353 | negative_aesthetic_score,
354 | negative_original_size,
355 | negative_crops_coords_top_left,
356 | negative_target_size,
357 | dtype=prompt_embeds.dtype,
358 | text_encoder_projection_dim=text_encoder_projection_dim,
359 | )
360 | add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
361 |
362 | if self.do_classifier_free_guidance:
363 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
364 | add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
365 | add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
366 | add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
367 |
368 | prompt_embeds = prompt_embeds.to(device)
369 | add_text_embeds = add_text_embeds.to(device)
370 | add_time_ids = add_time_ids.to(device)
371 |
372 | if ip_adapter_image is not None:
373 | image_embeds = self.prepare_ip_adapter_image_embeds(
374 | ip_adapter_image, device, batch_size * num_images_per_prompt
375 | )
376 |
377 | # 9. Denoising loop
378 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
379 |
380 | # 9.1 Apply denoising_end
381 | if (
382 | self.denoising_end is not None
383 | and self.denoising_start is not None
384 | and denoising_value_valid(self.denoising_end)
385 | and denoising_value_valid(self.denoising_start)
386 | and self.denoising_start >= self.denoising_end
387 | ):
388 | raise ValueError(
389 | f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
390 | + f" {self.denoising_end} when using type float."
391 | )
392 | elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
393 | discrete_timestep_cutoff = int(
394 | round(
395 | self.scheduler.config.num_train_timesteps
396 | - (self.denoising_end * self.scheduler.config.num_train_timesteps)
397 | )
398 | )
399 | num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
400 | timesteps = timesteps[:num_inference_steps]
401 |
402 | # 9.2 Optionally get Guidance Scale Embedding
403 | timestep_cond = None
404 | if self.unet.config.time_cond_proj_dim is not None:
405 | guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
406 | timestep_cond = self.get_guidance_scale_embedding(
407 | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
408 | ).to(device=device, dtype=latents.dtype)
409 |
410 | self._num_timesteps = len(timesteps)
411 | reference_latents = latents
412 | with self.progress_bar(total=num_inference_steps) as progress_bar:
413 | for i, t in enumerate(timesteps):
414 | if self.interrupt:
415 | continue
416 |
417 | reference_latents = inv_latents[num_inference_steps - i]
418 | # expand the latents if we are doing classifier free guidance
419 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
420 |
421 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
422 |
423 |                 # blend the current latents with the corresponding inversion latent
424 |                 reference_latents = (0.5 * latents + 0.5 * reference_latents)
425 |
426 | reference_model_input = torch.cat([reference_latents] * 2) if self.do_classifier_free_guidance else reference_latents
427 |
428 | reference_model_input = self.scheduler.scale_model_input(reference_model_input, t)
429 |
430 | # predict the noise residual
431 | added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
432 | if ip_adapter_image is not None:
433 | added_cond_kwargs["image_embeds"] = image_embeds
434 | noise_pred = self.unet(
435 | latent_model_input,
436 | t,
437 | encoder_hidden_states=prompt_embeds, # null prompt
438 | timestep_cond=timestep_cond,
439 | cross_attention_kwargs=self.cross_attention_kwargs,
440 | added_cond_kwargs=added_cond_kwargs,
441 | return_dict=False,
442 | )[0]
443 |
444 | noise_pred_fwd = self.unet(
445 | latent_model_input,
446 | t,
447 | encoder_hidden_states=prompt_embeds_ref, # c prompt
448 | timestep_cond=timestep_cond,
449 | cross_attention_kwargs=self.cross_attention_kwargs,
450 | added_cond_kwargs=added_cond_kwargs_ref,
451 | return_dict=False,
452 | )[0]
453 |
454 | if i < t_exit:
455 | noise_pred_recon = self.unet(
456 | reference_model_input,
457 | t,
458 | encoder_hidden_states=prompt_embeds_ref,
459 | timestep_cond=timestep_cond,
460 | cross_attention_kwargs=self.cross_attention_kwargs,
461 | added_cond_kwargs=added_cond_kwargs_ref,
462 | return_dict=False,
463 | )[0]
464 |
465 | scaling_factor = degrade_proportionally(i, omega, num_inference_steps-1, gamma)
466 |
467 |                 noise_pred = noise_pred + scaling_factor * (noise_pred_fwd - noise_pred)  # scale the concept component (concept-conditioned minus null-prompt prediction)
468 |
469 | if i < t_exit:
470 |                     noise_pred = (noise_pred + scaling_factor * (noise_pred_recon - noise_pred_fwd))  # reconstruction correction from the inverted latents, applied while i < t_exit
471 |
472 | # perform guidance
473 | if self.do_classifier_free_guidance:
474 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
475 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
476 | noise_pred_uncond, noise_pred_text_recon = noise_pred_recon.chunk(2)
477 | noise_pred_recon = noise_pred_uncond + self.guidance_scale * (noise_pred_text_recon - noise_pred_uncond)
478 |
479 | if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
480 | # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
481 | noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
482 | noise_pred_recon = rescale_noise_cfg(noise_pred_recon, noise_pred_text_recon, guidance_rescale=self.guidance_rescale)
483 |
484 | # compute the previous noisy sample x_t -> x_t-1
485 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
486 |
487 | if callback_on_step_end is not None:
488 | callback_kwargs = {}
489 | for k in callback_on_step_end_tensor_inputs:
490 | callback_kwargs[k] = locals()[k]
491 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
492 |
493 | latents = callback_outputs.pop("latents", latents)
494 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
495 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
496 | add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
497 | negative_pooled_prompt_embeds = callback_outputs.pop(
498 | "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
499 | )
500 | add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
501 | add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
502 |
503 | # call the callback, if provided
504 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
505 | progress_bar.update()
506 | if callback is not None and i % callback_steps == 0:
507 | step_idx = i // getattr(self.scheduler, "order", 1)
508 | callback(step_idx, t, latents)
509 |
510 | if XLA_AVAILABLE:
511 | xm.mark_step()
512 |
513 | if not output_type == "latent":
514 | # make sure the VAE is in float32 mode, as it overflows in float16
515 | needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
516 |
517 | if needs_upcasting:
518 | self.upcast_vae()
519 | latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
520 |
521 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
522 |
523 | # cast back to fp16 if needed
524 | if needs_upcasting:
525 | self.vae.to(dtype=torch.float16)
526 | else:
527 | image = latents
528 | return StableDiffusionXLPipelineOutput(images=image)
529 |
530 | # apply watermark if available
531 | if self.watermark is not None:
532 | image = self.watermark.apply_watermark(image)
533 |
534 | image = self.image_processor.postprocess(image, output_type=output_type)
535 |
536 | # Offload all models
537 | self.maybe_free_model_hooks()
538 |
539 | if not return_dict:
540 | return (image,)
541 |
542 | return StableDiffusionXLPipelineOutput(images=image)
--------------------------------------------------------------------------------
/src/pipes/sdxl_inversion_pipeline.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3 | from diffusers import StableDiffusionXLImg2ImgPipeline
4 | from diffusers.utils.torch_utils import randn_tensor
5 | from diffusers.utils import deprecate
6 | from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
7 | StableDiffusionXLPipelineOutput,
8 | retrieve_timesteps,
9 | PipelineImageInput
10 | )
11 |
12 | from src.renoise_inversion import inversion_step
13 |
14 |
15 | class SDXLDDIMPipeline(StableDiffusionXLImg2ImgPipeline):
16 | # @torch.no_grad()
17 | def __call__(
18 | self,
19 | prompt: Union[str, List[str]] = None,
20 | prompt_2: Optional[Union[str, List[str]]] = None,
21 | image: PipelineImageInput = None,
22 | strength: float = 0.3,
23 | num_inversion_steps: int = 50,
24 | timesteps: List[int] = None,
25 | denoising_start: Optional[float] = None,
26 | denoising_end: Optional[float] = None,
27 | guidance_scale: float = 1.0,
28 | negative_prompt: Optional[Union[str, List[str]]] = None,
29 | negative_prompt_2: Optional[Union[str, List[str]]] = None,
30 | num_images_per_prompt: Optional[int] = 1,
31 | eta: float = 0.0,
32 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
33 | latents: Optional[torch.FloatTensor] = None,
34 | prompt_embeds: Optional[torch.FloatTensor] = None,
35 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
36 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
37 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
38 | ip_adapter_image: Optional[PipelineImageInput] = None,
39 | output_type: Optional[str] = "pil",
40 | return_dict: bool = True,
41 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
42 | guidance_rescale: float = 0.0,
43 | original_size: Tuple[int, int] = None,
44 | crops_coords_top_left: Tuple[int, int] = (0, 0),
45 | target_size: Tuple[int, int] = None,
46 | negative_original_size: Optional[Tuple[int, int]] = None,
47 | negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
48 | negative_target_size: Optional[Tuple[int, int]] = None,
49 | aesthetic_score: float = 6.0,
50 | negative_aesthetic_score: float = 2.5,
51 | clip_skip: Optional[int] = None,
52 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
53 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
54 | num_renoise_steps: int = 100,
55 | **kwargs,
56 | ):
57 | callback = kwargs.pop("callback", None)
58 | callback_steps = kwargs.pop("callback_steps", None)
59 |
60 | if callback is not None:
61 | deprecate(
62 | "callback",
63 | "1.0.0",
64 | "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
65 | )
66 | if callback_steps is not None:
67 | deprecate(
68 | "callback_steps",
69 | "1.0.0",
70 | "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
71 | )
72 |
73 | # 1. Check inputs. Raise error if not correct
74 | self.check_inputs(
75 | prompt,
76 | prompt_2,
77 | strength,
78 | num_inversion_steps,
79 | callback_steps,
80 | negative_prompt,
81 | negative_prompt_2,
82 | prompt_embeds,
83 | negative_prompt_embeds,
84 | callback_on_step_end_tensor_inputs,
85 | )
86 |
87 | self._guidance_scale = guidance_scale
88 | self._guidance_rescale = guidance_rescale
89 | self._clip_skip = clip_skip
90 | self._cross_attention_kwargs = cross_attention_kwargs
91 | self._denoising_end = denoising_end
92 | self._denoising_start = denoising_start
93 |
94 | # 2. Define call parameters
95 | if prompt is not None and isinstance(prompt, str):
96 | batch_size = 1
97 | elif prompt is not None and isinstance(prompt, list):
98 | batch_size = len(prompt)
99 | else:
100 | batch_size = prompt_embeds.shape[0]
101 |
102 | device = self._execution_device
103 |
104 | # 3. Encode input prompt
105 | text_encoder_lora_scale = (
106 | self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
107 | )
108 | (
109 | prompt_embeds,
110 | negative_prompt_embeds,
111 | pooled_prompt_embeds,
112 | negative_pooled_prompt_embeds,
113 | ) = self.encode_prompt(
114 | prompt=prompt,
115 | prompt_2=prompt_2,
116 | device=device,
117 | num_images_per_prompt=num_images_per_prompt,
118 | do_classifier_free_guidance=self.do_classifier_free_guidance,
119 | negative_prompt=negative_prompt,
120 | negative_prompt_2=negative_prompt_2,
121 | prompt_embeds=prompt_embeds,
122 | negative_prompt_embeds=negative_prompt_embeds,
123 | pooled_prompt_embeds=pooled_prompt_embeds,
124 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
125 | lora_scale=text_encoder_lora_scale,
126 | clip_skip=self.clip_skip,
127 | )
128 |
129 | # 4. Preprocess image
130 | image = self.image_processor.preprocess(image)
131 |
132 | # 5. Prepare timesteps
133 | def denoising_value_valid(dnv):
134 | return isinstance(self.denoising_end, float) and 0 < dnv < 1
135 |
136 | timesteps, num_inversion_steps = retrieve_timesteps(self.scheduler, num_inversion_steps, device, timesteps)
137 |
138 | timesteps, num_inversion_steps = self.get_timesteps(
139 | num_inversion_steps,
140 | strength,
141 | device,
142 | denoising_start=self.denoising_start if denoising_value_valid else None,
143 | )
144 |
145 | # 6. Prepare latent variables
146 | with torch.no_grad():
147 | latents = self.prepare_latents(
148 | image,
149 | None,
150 | batch_size,
151 | num_images_per_prompt,
152 | prompt_embeds.dtype,
153 | device,
154 | generator,
155 | False,
156 | )
157 | # 7. Prepare extra step kwargs.
158 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
159 |
160 | height, width = latents.shape[-2:]
161 | height = height * self.vae_scale_factor
162 | width = width * self.vae_scale_factor
163 |
164 | original_size = original_size or (height, width)
165 | target_size = target_size or (height, width)
166 |
167 | # 8. Prepare added time ids & embeddings
168 | if negative_original_size is None:
169 | negative_original_size = original_size
170 | if negative_target_size is None:
171 | negative_target_size = target_size
172 |
173 | add_text_embeds = pooled_prompt_embeds
174 | if self.text_encoder_2 is None:
175 | text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
176 | else:
177 | text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
178 |
179 | add_time_ids, add_neg_time_ids = self._get_add_time_ids(
180 | original_size,
181 | crops_coords_top_left,
182 | target_size,
183 | aesthetic_score,
184 | negative_aesthetic_score,
185 | negative_original_size,
186 | negative_crops_coords_top_left,
187 | negative_target_size,
188 | dtype=prompt_embeds.dtype,
189 | text_encoder_projection_dim=text_encoder_projection_dim,
190 | )
191 | add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
192 |
193 | if self.do_classifier_free_guidance:
194 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
195 | add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
196 | add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
197 | add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
198 |
199 | prompt_embeds = prompt_embeds.to(device)
200 | add_text_embeds = add_text_embeds.to(device)
201 | add_time_ids = add_time_ids.to(device)
202 |
203 | if ip_adapter_image is not None:
204 | image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
205 | if self.do_classifier_free_guidance:
206 | image_embeds = torch.cat([negative_image_embeds, image_embeds])
207 | image_embeds = image_embeds.to(device)
208 |
209 | # 9. Denoising loop
210 | num_warmup_steps = max(len(timesteps) - num_inversion_steps * self.scheduler.order, 0)
211 |
212 | self._num_timesteps = len(timesteps)
213 | self.z_0 = torch.clone(latents)
214 | self.noise = randn_tensor(self.z_0.shape, generator=generator, device=self.z_0.device, dtype=self.z_0.dtype)
215 |
216 | all_latents = [latents.clone()]
217 | with self.progress_bar(total=num_inversion_steps) as progress_bar:
218 | for i, t in enumerate(reversed(timesteps)):
219 |
220 | added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
221 | if ip_adapter_image is not None:
222 | added_cond_kwargs["image_embeds"] = image_embeds
223 |
224 | latents = inversion_step(self,
225 | latents,
226 | t,
227 | prompt_embeds,
228 | added_cond_kwargs,
229 | num_renoise_steps=num_renoise_steps,
230 | generator=generator)
231 |
232 | all_latents.append(latents.clone())
233 |
234 | if callback_on_step_end is not None:
235 | callback_kwargs = {}
236 | for k in callback_on_step_end_tensor_inputs:
237 | callback_kwargs[k] = locals()[k]
238 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
239 |
240 | latents = callback_outputs.pop("latents", latents)
241 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
242 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
243 | add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
244 | negative_pooled_prompt_embeds = callback_outputs.pop(
245 | "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
246 | )
247 | add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
248 | add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
249 |
250 | # call the callback, if provided
251 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
252 | progress_bar.update()
253 | if callback is not None and i % callback_steps == 0:
254 | step_idx = i // getattr(self.scheduler, "order", 1)
255 | callback(step_idx, t, latents)
256 |
257 | image = latents
258 |
259 | # Offload all models
260 | self.maybe_free_model_hooks()
261 |
262 | return StableDiffusionXLPipelineOutput(images=image), all_latents, [prompt_embeds, added_cond_kwargs]
263 |
--------------------------------------------------------------------------------
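
The classifier-free-guidance batching used throughout this pipeline (and consumed by `unet_pass` in the next file) follows the usual convention: unconditional embeddings are stacked in front of the conditional ones so a single UNet forward can later be split with `.chunk(2)`. Below is a minimal, model-free sketch of that convention; the tensor shapes are illustrative stand-ins rather than values produced by the pipeline.

```python
import torch

# Illustrative SDXL-like shapes; a real run gets these from the text encoders.
negative_prompt_embeds = torch.randn(1, 77, 2048)
prompt_embeds = torch.randn(1, 77, 2048)

# Unconditional first, conditional second -- the same ordering used above.
cfg_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

# Stand-in for a UNet output on the doubled batch.
noise_pred = torch.randn(2, 4, 128, 128)
noise_uncond, noise_text = noise_pred.chunk(2)

guidance_scale = 7.5
noise_cfg = noise_uncond + guidance_scale * (noise_text - noise_uncond)
print(cfg_prompt_embeds.shape, noise_cfg.shape)  # (2, 77, 2048) and (1, 4, 128, 128)
```
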
/src/renoise_inversion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | # Based on code from https://github.com/pix2pixzero/pix2pix-zero
5 | def noise_regularization(
6 | e_t, noise_pred_optimal, lambda_kl, lambda_ac, num_reg_steps, num_ac_rolls, generator=None
7 | ):
8 | for _outer in range(num_reg_steps):
9 | if lambda_kl > 0:
10 | _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
11 | l_kld = patchify_latents_kl_divergence(_var, noise_pred_optimal)
12 | l_kld.backward()
13 | _grad = _var.grad.detach()
14 | _grad = torch.clip(_grad, -100, 100)
15 | e_t = e_t - lambda_kl * _grad
16 | if lambda_ac > 0:
17 | for _inner in range(num_ac_rolls):
18 | _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
19 | l_ac = auto_corr_loss(_var, generator=generator)
20 | l_ac.backward()
21 | _grad = _var.grad.detach() / num_ac_rolls
22 | e_t = e_t - lambda_ac * _grad
23 | e_t = e_t.detach()
24 |
25 | return e_t
26 |
27 | # Based on code from https://github.com/pix2pixzero/pix2pix-zero
28 | def auto_corr_loss(
29 | x, random_shift=True, generator=None
30 | ):
31 | B, C, H, W = x.shape
32 | assert B == 1
33 | x = x.squeeze(0)
34 | # x must be shape [C,H,W] now
35 | reg_loss = 0.0
36 | for ch_idx in range(x.shape[0]):
37 | noise = x[ch_idx][None, None, :, :]
38 | while True:
39 | if random_shift:
40 | roll_amount = torch.randint(0, noise.shape[2] // 2, (1,), generator=generator).item()
41 | else:
42 | roll_amount = 1
43 | reg_loss += (
44 | noise * torch.roll(noise, shifts=roll_amount, dims=2)
45 | ).mean() ** 2
46 | reg_loss += (
47 | noise * torch.roll(noise, shifts=roll_amount, dims=3)
48 | ).mean() ** 2
49 | if noise.shape[2] <= 8:
50 | break
51 | noise = F.avg_pool2d(noise, kernel_size=2)
52 | return reg_loss
53 |
54 |
55 | def patchify_latents_kl_divergence(x0, x1, patch_size=4, num_channels=4):
56 |
57 | def patchify_tensor(input_tensor):
58 | patches = (
59 | input_tensor.unfold(1, patch_size, patch_size)
60 | .unfold(2, patch_size, patch_size)
61 | .unfold(3, patch_size, patch_size)
62 | )
63 | patches = patches.contiguous().view(-1, num_channels, patch_size, patch_size)
64 | return patches
65 |
66 | x0 = patchify_tensor(x0)
67 | x1 = patchify_tensor(x1)
68 |
69 | kl = latents_kl_divergence(x0, x1).sum()
70 | return kl
71 |
72 |
73 | def latents_kl_divergence(x0, x1):
74 | EPSILON = 1e-6
75 | x0 = x0.view(x0.shape[0], x0.shape[1], -1)
76 | x1 = x1.view(x1.shape[0], x1.shape[1], -1)
77 | mu0 = x0.mean(dim=-1)
78 | mu1 = x1.mean(dim=-1)
79 | var0 = x0.var(dim=-1)
80 | var1 = x1.var(dim=-1)
81 | kl = (
82 | torch.log((var1 + EPSILON) / (var0 + EPSILON))
83 | + (var0 + (mu0 - mu1) ** 2) / (var1 + EPSILON)
84 | - 1
85 | )
86 | kl = torch.abs(kl).sum(dim=-1)
87 | return kl
88 |
89 | def inversion_step(
90 | pipe,
91 |     z_t: torch.Tensor,
92 |     t: torch.Tensor,
93 | prompt_embeds,
94 | added_cond_kwargs,
95 | num_renoise_steps: int = 100,
96 | first_step_max_timestep: int = 250,
97 | generator=None,
98 | ) -> torch.Tensor:
99 | extra_step_kwargs = {}
100 | avg_range = pipe.cfg.average_first_step_range if t.item() < first_step_max_timestep else pipe.cfg.average_step_range
101 | num_renoise_steps = min(pipe.cfg.max_num_renoise_steps_first_step, num_renoise_steps) if t.item() < first_step_max_timestep else num_renoise_steps
102 |
103 |     noise_pred_avg = None
104 | noise_pred_optimal = None
105 | z_tp1_forward = pipe.scheduler.add_noise(pipe.z_0, pipe.noise, t.view((1))).detach()
106 |
107 | approximated_z_tp1 = z_t.clone()
108 | for i in range(num_renoise_steps + 1):
109 |
110 | with torch.no_grad():
111 | # if noise regularization is enabled, we need to double the batch size for the first step
112 | if pipe.cfg.noise_regularization_num_reg_steps > 0 and i == 0:
113 | approximated_z_tp1 = torch.cat([z_tp1_forward, approximated_z_tp1])
114 | prompt_embeds_in = torch.cat([prompt_embeds, prompt_embeds])
115 | if added_cond_kwargs is not None:
116 | added_cond_kwargs_in = {}
117 | added_cond_kwargs_in['text_embeds'] = torch.cat([added_cond_kwargs['text_embeds'], added_cond_kwargs['text_embeds']])
118 | added_cond_kwargs_in['time_ids'] = torch.cat([added_cond_kwargs['time_ids'], added_cond_kwargs['time_ids']])
119 | else:
120 | added_cond_kwargs_in = None
121 | else:
122 | prompt_embeds_in = prompt_embeds
123 | added_cond_kwargs_in = added_cond_kwargs
124 |
125 | noise_pred = unet_pass(pipe, approximated_z_tp1, t, prompt_embeds_in, added_cond_kwargs_in)
126 |
127 | # if noise regularization is enabled, we need to split the batch size for the first step
128 | if pipe.cfg.noise_regularization_num_reg_steps > 0 and i == 0:
129 | noise_pred_optimal, noise_pred = noise_pred.chunk(2)
130 | if pipe.do_classifier_free_guidance:
131 | noise_pred_optimal_uncond, noise_pred_optimal_text = noise_pred_optimal.chunk(2)
132 | noise_pred_optimal = noise_pred_optimal_uncond + pipe.guidance_scale * (noise_pred_optimal_text - noise_pred_optimal_uncond)
133 | noise_pred_optimal = noise_pred_optimal.detach()
134 |
135 | # perform guidance
136 | if pipe.do_classifier_free_guidance:
137 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
138 | noise_pred = noise_pred_uncond + pipe.guidance_scale * (noise_pred_text - noise_pred_uncond)
139 |
140 | # Calculate average noise
141 | if i >= avg_range[0] and i < avg_range[1]:
142 | j = i - avg_range[0]
143 |                 if noise_pred_avg is None:
144 |                     noise_pred_avg = noise_pred.clone()
145 |                 else:
146 |                     noise_pred_avg = j * noise_pred_avg / (j + 1) + noise_pred / (j + 1)
147 |
148 | if i >= avg_range[0] or (not pipe.cfg.average_latent_estimations and i > 0):
149 | noise_pred = noise_regularization(noise_pred, noise_pred_optimal, lambda_kl=pipe.cfg.noise_regularization_lambda_kl, lambda_ac=pipe.cfg.noise_regularization_lambda_ac, num_reg_steps=pipe.cfg.noise_regularization_num_reg_steps, num_ac_rolls=pipe.cfg.noise_regularization_num_ac_rolls, generator=generator)
150 |
151 | approximated_z_tp1 = pipe.scheduler.inv_step(noise_pred, t, z_t, **extra_step_kwargs, return_dict=False)[0].detach()
152 |
153 | # if average latents is enabled, we need to perform an additional step with the average noise
154 |     if pipe.cfg.average_latent_estimations and noise_pred_avg is not None:
155 |         noise_pred_avg = noise_regularization(noise_pred_avg, noise_pred_optimal, lambda_kl=pipe.cfg.noise_regularization_lambda_kl, lambda_ac=pipe.cfg.noise_regularization_lambda_ac, num_reg_steps=pipe.cfg.noise_regularization_num_reg_steps, num_ac_rolls=pipe.cfg.noise_regularization_num_ac_rolls, generator=generator)
156 |         approximated_z_tp1 = pipe.scheduler.inv_step(noise_pred_avg, t, z_t, **extra_step_kwargs, return_dict=False)[0].detach()
157 |
158 | # perform noise correction
159 | if pipe.cfg.perform_noise_correction:
160 | noise_pred = unet_pass(pipe, approximated_z_tp1, t, prompt_embeds, added_cond_kwargs)
161 |
162 | # perform guidance
163 | if pipe.do_classifier_free_guidance:
164 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
165 | noise_pred = noise_pred_uncond + pipe.guidance_scale * (noise_pred_text - noise_pred_uncond)
166 |
167 | pipe.scheduler.step_and_update_noise(noise_pred, t, approximated_z_tp1, z_t, return_dict=False, optimize_epsilon_type=pipe.cfg.perform_noise_correction)
168 |
169 | return approximated_z_tp1
170 |
171 | @torch.no_grad()
172 | def unet_pass(pipe, z_t, t, prompt_embeds, added_cond_kwargs):
173 | latent_model_input = torch.cat([z_t] * 2) if pipe.do_classifier_free_guidance else z_t
174 | latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
175 | return pipe.unet(
176 | latent_model_input,
177 | t,
178 | encoder_hidden_states=prompt_embeds,
179 | timestep_cond=None,
180 | cross_attention_kwargs=pipe.cross_attention_kwargs,
181 | added_cond_kwargs=added_cond_kwargs,
182 | return_dict=False,
183 | )[0]
184 |
--------------------------------------------------------------------------------
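
At its core, `inversion_step` above is a fixed-point iteration: it repeatedly predicts the noise at the current estimate of the noisier latent and re-applies `inv_step` from the same `z_t`, with averaging and noise regularization layered on top as stabilizers. The toy below illustrates the same idea on a scalar update; `f` and `h` are made-up stand-ins, and nothing here touches the actual model or scheduler.

```python
import torch

f = lambda x: torch.sin(x)   # stand-in for the noise prediction
h = 0.1                      # stand-in for the scheduler step size

x_true = torch.tensor(1.3)
y = x_true - f(x_true) * h   # one "denoising" step; y plays the role of z_t

# Invert it: look for x with x = y + f(x) * h, iterating as the renoising loop does.
x = y.clone()                # initial guess, analogous to approximated_z_tp1 = z_t
for _ in range(9):           # analogous to num_renoise_steps
    x = y + f(x) * h

print(float((x - x_true).abs()))  # shrinks toward 0 as the iterations proceed
```
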
/src/schedulers/ddim_scheduler.py:
--------------------------------------------------------------------------------
1 | from diffusers import DDIMScheduler
2 | from diffusers.utils import BaseOutput
3 | from diffusers.utils.torch_utils import randn_tensor
4 | import torch
5 | from typing import List, Optional, Tuple, Union
6 | import numpy as np
7 |
8 | class DDIMSchedulerOutput(BaseOutput):
9 | """
10 | Output class for the scheduler's `step` function output.
11 |
12 | Args:
13 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
14 | Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
15 | denoising loop.
16 | pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
17 | The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
18 | `pred_original_sample` can be used to preview progress or for guidance.
19 | """
20 |
21 | prev_sample: torch.FloatTensor
22 | pred_original_sample: Optional[torch.FloatTensor] = None
23 |
24 | class MyDDIMScheduler(DDIMScheduler):
25 |
26 | def inv_step(
27 | self,
28 | model_output: torch.FloatTensor,
29 | timestep: int,
30 | sample: torch.FloatTensor,
31 | eta: float = 0.0,
32 | use_clipped_model_output: bool = False,
33 | generator=None,
34 | variance_noise: Optional[torch.FloatTensor] = None,
35 | return_dict: bool = True,
36 | ) -> Union[DDIMSchedulerOutput, Tuple]:
37 | """
38 |         Inverse DDIM step: given the sample at the current (less noisy) timestep and a noise prediction,
39 |         estimate the sample at the next (noisier) timestep. Called by `inversion_step` in `src/renoise_inversion.py`.
40 |
41 | Args:
42 | model_output (`torch.FloatTensor`):
43 | The direct output from learned diffusion model.
44 | timestep (`float`):
45 | The current discrete timestep in the diffusion chain.
46 | sample (`torch.FloatTensor`):
47 | A current instance of a sample created by the diffusion process.
48 | eta (`float`):
49 | The weight of noise for added noise in diffusion step.
50 | use_clipped_model_output (`bool`, defaults to `False`):
51 | If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
52 | because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
53 | clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
54 | `use_clipped_model_output` has no effect.
55 | generator (`torch.Generator`, *optional*):
56 | A random number generator.
57 | variance_noise (`torch.FloatTensor`):
58 | Alternative to generating noise with `generator` by directly providing the noise for the variance
59 | itself. Useful for methods such as [`CycleDiffusion`].
60 | return_dict (`bool`, *optional*, defaults to `True`):
61 | Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
62 |
63 | Returns:
64 | [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
65 | If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
66 | tuple is returned where the first element is the sample tensor.
67 |
68 | """
69 | if self.num_inference_steps is None:
70 | raise ValueError(
71 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
72 | )
73 |
74 | # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
75 |         # Ideally, read the DDIM paper in detail to understand the method
76 | 
77 |         # Notation (<variable name> -> <name in paper>)
78 | # - pred_noise_t -> e_theta(x_t, t)
79 | # - pred_original_sample -> f_theta(x_t, t) or x_0
80 | # - std_dev_t -> sigma_t
81 | # - eta -> η
82 | # - pred_sample_direction -> "direction pointing to x_t"
83 | # - pred_prev_sample -> "x_t-1"
84 |
85 | # 1. get previous step value (=t-1)
86 | prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
87 |
88 | # 2. compute alphas, betas
89 | alpha_prod_t = self.alphas_cumprod[timestep]
90 | alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
91 |
92 | beta_prod_t = 1 - alpha_prod_t
93 |
94 | # 3. compute predicted original sample from predicted noise also called
95 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
96 | assert self.config.prediction_type == "epsilon"
97 | if self.config.prediction_type == "epsilon":
98 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
99 | pred_epsilon = model_output
100 | elif self.config.prediction_type == "sample":
101 | pred_original_sample = model_output
102 | pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
103 | elif self.config.prediction_type == "v_prediction":
104 | pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
105 | pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
106 | else:
107 | raise ValueError(
108 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
109 | " `v_prediction`"
110 | )
111 |
112 | # 4. Clip or threshold "predicted x_0"
113 | if self.config.thresholding:
114 | pred_original_sample = self._threshold_sample(pred_original_sample)
115 | elif self.config.clip_sample:
116 | pred_original_sample = pred_original_sample.clamp(
117 | -self.config.clip_sample_range, self.config.clip_sample_range
118 | )
119 |
120 | # 5. compute variance: "sigma_t(η)" -> see formula (16)
121 | # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
122 | variance = self._get_variance(timestep, prev_timestep)
123 | std_dev_t = eta * variance ** (0.5)
124 |
125 | if use_clipped_model_output:
126 | # the pred_epsilon is always re-derived from the clipped x_0 in Glide
127 | pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
128 |
129 | # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
130 | pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
131 |
132 |         # 7. invert the deterministic DDIM update of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
133 |         # (forward: prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction), solving for the noisier sample
134 |
135 | prev_sample = (alpha_prod_t ** (0.5) * sample) / alpha_prod_t_prev ** (0.5) + (alpha_prod_t_prev ** (0.5) * beta_prod_t ** (0.5) * model_output) / alpha_prod_t_prev ** (0.5) - (alpha_prod_t ** (0.5) * pred_sample_direction) / alpha_prod_t_prev ** (0.5)
136 |
137 | if eta > 0:
138 | if variance_noise is not None and generator is not None:
139 | raise ValueError(
140 | "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
141 | " `variance_noise` stays `None`."
142 | )
143 |
144 | if variance_noise is None:
145 | variance_noise = randn_tensor(
146 | model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
147 | )
148 | variance = std_dev_t * variance_noise
149 |
150 | prev_sample = prev_sample + variance
151 |
152 | if not return_dict:
153 | return (prev_sample,)
154 |
155 | return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
--------------------------------------------------------------------------------
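
The long expression in step 7 of `inv_step` is the deterministic DDIM update solved for the noisier sample. A quick sanity check of that algebra, using arbitrary alpha values, random tensors, `eta = 0`, and a shared epsilon (it does not instantiate the scheduler class):

```python
import torch

a_t = torch.tensor(0.5, dtype=torch.float64)      # alpha_prod_t
a_prev = torch.tensor(0.7, dtype=torch.float64)   # alpha_prod_t_prev
x_t = torch.randn(4, dtype=torch.float64)
eps = torch.randn(4, dtype=torch.float64)

# Forward (denoising) DDIM step with eta = 0, formula (12) of the paper.
x0 = (x_t - (1 - a_t).sqrt() * eps) / a_t.sqrt()
x_prev = a_prev.sqrt() * x0 + (1 - a_prev).sqrt() * eps

# Inverse step, mirroring MyDDIMScheduler.inv_step (second term simplified).
x_t_rec = (a_t.sqrt() * x_prev) / a_prev.sqrt() \
    + (1 - a_t).sqrt() * eps \
    - (a_t.sqrt() * (1 - a_prev).sqrt() * eps) / a_prev.sqrt()

assert torch.allclose(x_t, x_t_rec)
```
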
/src/schedulers/euler_scheduler.py:
--------------------------------------------------------------------------------
1 | from diffusers import EulerAncestralDiscreteScheduler
2 | from diffusers.utils import BaseOutput
3 | import torch
4 | from typing import List, Optional, Tuple, Union
5 | import numpy as np
6 | from diffusers.utils import logging; logger = logging.get_logger(__name__)
7 | class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
8 | """
9 | Output class for the scheduler's `step` function output.
10 |
11 | Args:
12 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
13 | Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
14 | denoising loop.
15 | pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
16 | The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
17 | `pred_original_sample` can be used to preview progress or for guidance.
18 | """
19 |
20 | prev_sample: torch.FloatTensor
21 | pred_original_sample: Optional[torch.FloatTensor] = None
22 |
23 | class MyEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
24 | def set_noise_list(self, noise_list):
25 | self.noise_list = noise_list
26 |
27 | def get_noise_to_remove(self):
28 | sigma_from = self.sigmas[self.step_index]
29 | sigma_to = self.sigmas[self.step_index + 1]
30 | sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
31 |
32 |         return self.noise_list[self.step_index] * sigma_up
33 |
34 | def scale_model_input(
35 | self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
36 | ) -> torch.FloatTensor:
37 | """
38 | Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
39 | current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
40 |
41 | Args:
42 | sample (`torch.FloatTensor`):
43 | The input sample.
44 | timestep (`int`, *optional*):
45 | The current timestep in the diffusion chain.
46 |
47 | Returns:
48 | `torch.FloatTensor`:
49 | A scaled input sample.
50 | """
51 |
52 | self._init_step_index(timestep.view((1)))
53 | return EulerAncestralDiscreteScheduler.scale_model_input(self, sample, timestep)
54 |
55 |
56 | def step(
57 | self,
58 | model_output: torch.FloatTensor,
59 | timestep: Union[float, torch.FloatTensor],
60 | sample: torch.FloatTensor,
61 | generator: Optional[torch.Generator] = None,
62 | return_dict: bool = True,
63 | ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
64 | """
65 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
66 | process from the learned model outputs (most often the predicted noise).
67 |
68 | Args:
69 | model_output (`torch.FloatTensor`):
70 | The direct output from learned diffusion model.
71 | timestep (`float`):
72 | The current discrete timestep in the diffusion chain.
73 | sample (`torch.FloatTensor`):
74 | A current instance of a sample created by the diffusion process.
75 | generator (`torch.Generator`, *optional*):
76 | A random number generator.
77 | return_dict (`bool`):
78 | Whether or not to return a
79 | [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
80 |
81 | Returns:
82 | [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
83 | If return_dict is `True`,
84 | [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
85 | otherwise a tuple is returned where the first element is the sample tensor.
86 |
87 | """
88 |
89 | if (
90 | isinstance(timestep, int)
91 | or isinstance(timestep, torch.IntTensor)
92 | or isinstance(timestep, torch.LongTensor)
93 | ):
94 | raise ValueError(
95 | (
96 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
97 | " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
98 | " one of the `scheduler.timesteps` as a timestep."
99 | ),
100 | )
101 |
102 | if not self.is_scale_input_called:
103 | logger.warning(
104 | "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
105 | "See `StableDiffusionPipeline` for a usage example."
106 | )
107 |
108 | self._init_step_index(timestep.view((1)))
109 |
110 | sigma = self.sigmas[self.step_index]
111 |
112 | # Upcast to avoid precision issues when computing prev_sample
113 | sample = sample.to(torch.float32)
114 |
115 | # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
116 | if self.config.prediction_type == "epsilon":
117 | pred_original_sample = sample - sigma * model_output
118 | elif self.config.prediction_type == "v_prediction":
119 | # * c_out + input * c_skip
120 | pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
121 | elif self.config.prediction_type == "sample":
122 | raise NotImplementedError("prediction_type not implemented yet: sample")
123 | else:
124 | raise ValueError(
125 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
126 | )
127 |
128 | sigma_from = self.sigmas[self.step_index]
129 | sigma_to = self.sigmas[self.step_index + 1]
130 | sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
131 | sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
132 |
133 | # 2. Convert to an ODE derivative
134 | # derivative = (sample - pred_original_sample) / sigma
135 | derivative = model_output
136 |
137 | dt = sigma_down - sigma
138 |
139 | prev_sample = sample + derivative * dt
140 |
141 | device = model_output.device
142 | # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
143 | # prev_sample = prev_sample + noise * sigma_up
144 |
145 | prev_sample = prev_sample + self.noise_list[self.step_index] * sigma_up
146 |
147 | # Cast sample back to model compatible dtype
148 | prev_sample = prev_sample.to(model_output.dtype)
149 |
150 | # upon completion increase step index by one
151 | self._step_index += 1
152 |
153 | if not return_dict:
154 | return (prev_sample,)
155 |
156 | return EulerAncestralDiscreteSchedulerOutput(
157 | prev_sample=prev_sample, pred_original_sample=pred_original_sample
158 | )
159 |
160 | def step_and_update_noise(
161 | self,
162 | model_output: torch.FloatTensor,
163 | timestep: Union[float, torch.FloatTensor],
164 | sample: torch.FloatTensor,
165 | expected_prev_sample: torch.FloatTensor,
166 | optimize_epsilon_type: bool = False,
167 | generator: Optional[torch.Generator] = None,
168 | return_dict: bool = True,
169 | ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
170 | """
171 |         Same as `step`, but additionally solves for the per-step noise that maps `sample` to
172 |         `expected_prev_sample` and stores it in `self.noise_list` (refined iteratively when `optimize_epsilon_type` is set).
173 |
174 | Args:
175 | model_output (`torch.FloatTensor`):
176 | The direct output from learned diffusion model.
177 | timestep (`float`):
178 | The current discrete timestep in the diffusion chain.
179 | sample (`torch.FloatTensor`):
180 | A current instance of a sample created by the diffusion process.
181 | generator (`torch.Generator`, *optional*):
182 | A random number generator.
183 | return_dict (`bool`):
184 | Whether or not to return a
185 | [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
186 |
187 | Returns:
188 | [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
189 | If return_dict is `True`,
190 | [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
191 | otherwise a tuple is returned where the first element is the sample tensor.
192 |
193 | """
194 |
195 | if (
196 | isinstance(timestep, int)
197 | or isinstance(timestep, torch.IntTensor)
198 | or isinstance(timestep, torch.LongTensor)
199 | ):
200 | raise ValueError(
201 | (
202 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
203 | " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
204 | " one of the `scheduler.timesteps` as a timestep."
205 | ),
206 | )
207 |
208 | if not self.is_scale_input_called:
209 | logger.warning(
210 | "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
211 | "See `StableDiffusionPipeline` for a usage example."
212 | )
213 |
214 | self._init_step_index(timestep.view((1)))
215 |
216 | sigma = self.sigmas[self.step_index]
217 |
218 | # Upcast to avoid precision issues when computing prev_sample
219 | sample = sample.to(torch.float32)
220 |
221 | # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
222 | if self.config.prediction_type == "epsilon":
223 | pred_original_sample = sample - sigma * model_output
224 | elif self.config.prediction_type == "v_prediction":
225 | # * c_out + input * c_skip
226 | pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
227 | elif self.config.prediction_type == "sample":
228 | raise NotImplementedError("prediction_type not implemented yet: sample")
229 | else:
230 | raise ValueError(
231 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
232 | )
233 |
234 | sigma_from = self.sigmas[self.step_index]
235 | sigma_to = self.sigmas[self.step_index + 1]
236 | sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
237 | sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
238 |
239 | # 2. Convert to an ODE derivative
240 | # derivative = (sample - pred_original_sample) / sigma
241 | derivative = model_output
242 |
243 | dt = sigma_down - sigma
244 |
245 | prev_sample = sample + derivative * dt
246 |
247 | device = model_output.device
248 | # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
249 | # prev_sample = prev_sample + noise * sigma_up
250 |
251 | if sigma_up > 0:
252 | req_noise = (expected_prev_sample - prev_sample) / sigma_up
253 | if not optimize_epsilon_type:
254 | self.noise_list[self.step_index] = req_noise
255 | else:
256 | for i in range(10):
257 | n = torch.autograd.Variable(self.noise_list[self.step_index].detach().clone(), requires_grad=True)
258 | loss = torch.norm(n - req_noise.detach())
259 | loss.backward()
260 | self.noise_list[self.step_index] -= n.grad.detach() * 1.8
261 |
262 |
263 | prev_sample = prev_sample + self.noise_list[self.step_index] * sigma_up
264 |
265 | # Cast sample back to model compatible dtype
266 | prev_sample = prev_sample.to(model_output.dtype)
267 |
268 | # upon completion increase step index by one
269 | self._step_index += 1
270 |
271 | if not return_dict:
272 | return (prev_sample,)
273 |
274 | return EulerAncestralDiscreteSchedulerOutput(
275 | prev_sample=prev_sample, pred_original_sample=pred_original_sample
276 | )
277 |
278 | def inv_step(
279 | self,
280 | model_output: torch.FloatTensor,
281 | timestep: Union[float, torch.FloatTensor],
282 | sample: torch.FloatTensor,
283 | generator: Optional[torch.Generator] = None,
284 | return_dict: bool = True,
285 | ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
286 | """
287 |         Inverse Euler-ancestral step: given the sample at the current (less noisy) timestep, a noise prediction,
288 |         and the stored `noise_list`, estimate the sample at the next (noisier) timestep for the inversion loop.
289 |
290 | Args:
291 | model_output (`torch.FloatTensor`):
292 | The direct output from learned diffusion model.
293 | timestep (`float`):
294 | The current discrete timestep in the diffusion chain.
295 | sample (`torch.FloatTensor`):
296 | A current instance of a sample created by the diffusion process.
297 | generator (`torch.Generator`, *optional*):
298 | A random number generator.
299 | return_dict (`bool`):
300 | Whether or not to return a
301 | [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
302 |
303 | Returns:
304 | [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
305 | If return_dict is `True`,
306 | [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
307 | otherwise a tuple is returned where the first element is the sample tensor.
308 |
309 | """
310 |
311 | if (
312 | isinstance(timestep, int)
313 | or isinstance(timestep, torch.IntTensor)
314 | or isinstance(timestep, torch.LongTensor)
315 | ):
316 | raise ValueError(
317 | (
318 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
319 | " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
320 | " one of the `scheduler.timesteps` as a timestep."
321 | ),
322 | )
323 |
324 | if not self.is_scale_input_called:
325 | logger.warning(
326 | "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
327 | "See `StableDiffusionPipeline` for a usage example."
328 | )
329 |
330 | self._init_step_index(timestep.view((1)))
331 |
332 | sigma = self.sigmas[self.step_index]
333 |
334 | # Upcast to avoid precision issues when computing prev_sample
335 | sample = sample.to(torch.float32)
336 |
337 | # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
338 | if self.config.prediction_type == "epsilon":
339 | pred_original_sample = sample - sigma * model_output
340 | elif self.config.prediction_type == "v_prediction":
341 | # * c_out + input * c_skip
342 | pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
343 | elif self.config.prediction_type == "sample":
344 | raise NotImplementedError("prediction_type not implemented yet: sample")
345 | else:
346 | raise ValueError(
347 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
348 | )
349 |
350 | sigma_from = self.sigmas[self.step_index]
351 | sigma_to = self.sigmas[self.step_index+1]
352 | # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
353 | sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2).abs() / sigma_from**2) ** 0.5
354 | # sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
355 | sigma_down = sigma_to**2 / sigma_from
356 |
357 | # 2. Convert to an ODE derivative
358 | # derivative = (sample - pred_original_sample) / sigma
359 | derivative = model_output
360 |
361 | dt = sigma_down - sigma
362 | # dt = sigma_down - sigma_from
363 |
364 | prev_sample = sample - derivative * dt
365 |
366 | device = model_output.device
367 | # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
368 | # prev_sample = prev_sample + noise * sigma_up
369 |
370 | prev_sample = prev_sample - self.noise_list[self.step_index] * sigma_up
371 |
372 | # Cast sample back to model compatible dtype
373 | prev_sample = prev_sample.to(model_output.dtype)
374 |
375 | # upon completion increase step index by one
376 | self._step_index += 1
377 |
378 | if not return_dict:
379 | return (prev_sample,)
380 |
381 | return EulerAncestralDiscreteSchedulerOutput(
382 | prev_sample=prev_sample, pred_original_sample=pred_original_sample
383 | )
384 |
385 | def get_all_sigmas(self) -> torch.FloatTensor:
386 | sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
387 | sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
388 | return torch.from_numpy(sigmas)
389 |
390 | def add_noise_off_schedule(
391 | self,
392 | original_samples: torch.FloatTensor,
393 | noise: torch.FloatTensor,
394 | timesteps: torch.FloatTensor,
395 | ) -> torch.FloatTensor:
396 | # Make sure sigmas and timesteps have the same device and dtype as original_samples
397 | sigmas = self.get_all_sigmas()
398 | sigmas = sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
399 | if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
400 | # mps does not support float64
401 | timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
402 | else:
403 | timesteps = timesteps.to(original_samples.device)
404 |
405 | step_indices = 1000 - int(timesteps.item())
406 |
407 | sigma = sigmas[step_indices].flatten()
408 | while len(sigma.shape) < len(original_samples.shape):
409 | sigma = sigma.unsqueeze(-1)
410 |
411 | noisy_samples = original_samples + noise * sigma
412 | return noisy_samples
413 |
414 | # def update_noise_for_friendly_inversion(
415 | # self,
416 | # model_output: torch.FloatTensor,
417 | # timestep: Union[float, torch.FloatTensor],
418 | # z_t: torch.FloatTensor,
419 | # z_tp1: torch.FloatTensor,
420 | # return_dict: bool = True,
421 | # ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
422 | # if (
423 | # isinstance(timestep, int)
424 | # or isinstance(timestep, torch.IntTensor)
425 | # or isinstance(timestep, torch.LongTensor)
426 | # ):
427 | # raise ValueError(
428 | # (
429 | # "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
430 | # " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
431 | # " one of the `scheduler.timesteps` as a timestep."
432 | # ),
433 | # )
434 |
435 | # if not self.is_scale_input_called:
436 | # logger.warning(
437 | # "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
438 | # "See `StableDiffusionPipeline` for a usage example."
439 | # )
440 |
441 | # self._init_step_index(timestep.view((1)))
442 |
443 | # sigma = self.sigmas[self.step_index]
444 |
445 | # sigma_from = self.sigmas[self.step_index]
446 | # sigma_to = self.sigmas[self.step_index+1]
447 | # # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
448 | # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2).abs() / sigma_from**2) ** 0.5
449 | # # sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
450 | # sigma_down = sigma_to**2 / sigma_from
451 |
452 |     #     # 2. Convert to an ODE derivative: derivative = (sample - pred_original_sample) / sigma
453 | # derivative = model_output
454 |
455 | # dt = sigma_down - sigma
456 | # # dt = sigma_down - sigma_from
457 |
458 | # prev_sample = z_t - derivative * dt
459 |
460 | # if sigma_up > 0:
461 | # self.noise_list[self.step_index] = (prev_sample - z_tp1) / sigma_up
462 |
463 | # prev_sample = prev_sample - self.noise_list[self.step_index] * sigma_up
464 |
465 |
466 | # if not return_dict:
467 | # return (prev_sample,)
468 |
469 | # return EulerAncestralDiscreteSchedulerOutput(
470 | # prev_sample=prev_sample, pred_original_sample=None
471 | # )
472 |
473 |
474 | # def step_friendly_inversion(
475 | # self,
476 | # model_output: torch.FloatTensor,
477 | # timestep: Union[float, torch.FloatTensor],
478 | # sample: torch.FloatTensor,
479 | # generator: Optional[torch.Generator] = None,
480 | # return_dict: bool = True,
481 | # expected_next_sample: torch.FloatTensor = None,
482 | # ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
483 | # """
484 | # Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
485 | # process from the learned model outputs (most often the predicted noise).
486 |
487 | # Args:
488 | # model_output (`torch.FloatTensor`):
489 | # The direct output from learned diffusion model.
490 | # timestep (`float`):
491 | # The current discrete timestep in the diffusion chain.
492 | # sample (`torch.FloatTensor`):
493 | # A current instance of a sample created by the diffusion process.
494 | # generator (`torch.Generator`, *optional*):
495 | # A random number generator.
496 | # return_dict (`bool`):
497 | # Whether or not to return a
498 | # [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
499 |
500 | # Returns:
501 | # [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
502 | # If return_dict is `True`,
503 | # [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
504 | # otherwise a tuple is returned where the first element is the sample tensor.
505 |
506 | # """
507 |
508 | # if (
509 | # isinstance(timestep, int)
510 | # or isinstance(timestep, torch.IntTensor)
511 | # or isinstance(timestep, torch.LongTensor)
512 | # ):
513 | # raise ValueError(
514 | # (
515 | # "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
516 | # " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
517 | # " one of the `scheduler.timesteps` as a timestep."
518 | # ),
519 | # )
520 |
521 | # if not self.is_scale_input_called:
522 | # logger.warning(
523 | # "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
524 | # "See `StableDiffusionPipeline` for a usage example."
525 | # )
526 |
527 | # self._init_step_index(timestep.view((1)))
528 |
529 | # sigma = self.sigmas[self.step_index]
530 |
531 | # # Upcast to avoid precision issues when computing prev_sample
532 | # sample = sample.to(torch.float32)
533 |
534 | # # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
535 | # if self.config.prediction_type == "epsilon":
536 | # pred_original_sample = sample - sigma * model_output
537 | # elif self.config.prediction_type == "v_prediction":
538 | # # * c_out + input * c_skip
539 | # pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
540 | # elif self.config.prediction_type == "sample":
541 | # raise NotImplementedError("prediction_type not implemented yet: sample")
542 | # else:
543 | # raise ValueError(
544 | # f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
545 | # )
546 |
547 | # sigma_from = self.sigmas[self.step_index]
548 | # sigma_to = self.sigmas[self.step_index + 1]
549 | # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
550 | # sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
551 |
552 | # # 2. Convert to an ODE derivative
553 | # # derivative = (sample - pred_original_sample) / sigma
554 | # derivative = model_output
555 |
556 | # dt = sigma_down - sigma
557 |
558 | # prev_sample = sample + derivative * dt
559 |
560 | # device = model_output.device
561 | # # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
562 | # # prev_sample = prev_sample + noise * sigma_up
563 |
564 | # if sigma_up > 0:
565 | # self.noise_list[self.step_index] = (expected_next_sample - prev_sample) / sigma_up
566 |
567 | # prev_sample = prev_sample + self.noise_list[self.step_index] * sigma_up
568 |
569 | # # Cast sample back to model compatible dtype
570 | # prev_sample = prev_sample.to(model_output.dtype)
571 |
572 | # # upon completion increase step index by one
573 | # self._step_index += 1
574 |
575 | # if not return_dict:
576 | # return (prev_sample,)
577 |
578 | # return EulerAncestralDiscreteSchedulerOutput(
579 | # prev_sample=prev_sample, pred_original_sample=pred_original_sample
580 | # )
--------------------------------------------------------------------------------
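
Both `step` and `inv_step` above split each transition into a deterministic part (`sigma_down`) and a stochastic part (`sigma_up`); the stored `noise_list` simply replaces the fresh Gaussian noise that would normally be drawn at scale `sigma_up`. A small check with arbitrary sigma values confirms the split, including the closed form `sigma_to**2 / sigma_from` that `inv_step` uses for `sigma_down`:

```python
import torch

sigma_from = torch.tensor(14.6, dtype=torch.float64)
sigma_to = torch.tensor(10.8, dtype=torch.float64)

sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

# The deterministic and stochastic parts partition the next noise level ...
assert torch.allclose(sigma_down**2 + sigma_up**2, sigma_to**2)
# ... and sigma_down reduces to the closed form used in inv_step.
assert torch.allclose(sigma_down, sigma_to**2 / sigma_from)
```
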
/src/schedulers/lcm_scheduler.py:
--------------------------------------------------------------------------------
1 | from diffusers import LCMScheduler
2 | from diffusers.utils import BaseOutput
3 | from diffusers.utils.torch_utils import randn_tensor
4 | import torch
5 | from typing import List, Optional, Tuple, Union
6 | import numpy as np
7 |
8 | class LCMSchedulerOutput(BaseOutput):
9 | """
10 | Output class for the scheduler's `step` function output.
11 |
12 | Args:
13 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
14 | Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
15 | denoising loop.
16 | pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
17 | The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
18 | `pred_original_sample` can be used to preview progress or for guidance.
19 | """
20 |
21 | prev_sample: torch.FloatTensor
22 | denoised: Optional[torch.FloatTensor] = None
23 |
24 | class MyLCMScheduler(LCMScheduler):
25 |
26 | def set_noise_list(self, noise_list):
27 | self.noise_list = noise_list
28 |
29 | def step(
30 | self,
31 | model_output: torch.FloatTensor,
32 | timestep: int,
33 | sample: torch.FloatTensor,
34 | generator: Optional[torch.Generator] = None,
35 | return_dict: bool = True,
36 | ) -> Union[LCMSchedulerOutput, Tuple]:
37 | """
38 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
39 | process from the learned model outputs (most often the predicted noise).
40 |
41 | Args:
42 | model_output (`torch.FloatTensor`):
43 | The direct output from learned diffusion model.
44 | timestep (`float`):
45 | The current discrete timestep in the diffusion chain.
46 | sample (`torch.FloatTensor`):
47 | A current instance of a sample created by the diffusion process.
48 | generator (`torch.Generator`, *optional*):
49 | A random number generator.
50 | return_dict (`bool`, *optional*, defaults to `True`):
51 | Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
52 | Returns:
53 | [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
54 | If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
55 | tuple is returned where the first element is the sample tensor.
56 | """
57 | if self.num_inference_steps is None:
58 | raise ValueError(
59 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
60 | )
61 |
62 | self._init_step_index(timestep)
63 |
64 | # 1. get previous step value
65 | prev_step_index = self.step_index + 1
66 | if prev_step_index < len(self.timesteps):
67 | prev_timestep = self.timesteps[prev_step_index]
68 | else:
69 | prev_timestep = timestep
70 |
71 | # 2. compute alphas, betas
72 | alpha_prod_t = self.alphas_cumprod[timestep]
73 | alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
74 |
75 | beta_prod_t = 1 - alpha_prod_t
76 | beta_prod_t_prev = 1 - alpha_prod_t_prev
77 |
78 | # 3. Get scalings for boundary conditions
79 | c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
80 |
81 | # 4. Compute the predicted original sample x_0 based on the model parameterization
82 | if self.config.prediction_type == "epsilon": # noise-prediction
83 | predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
84 | elif self.config.prediction_type == "sample": # x-prediction
85 | predicted_original_sample = model_output
86 | elif self.config.prediction_type == "v_prediction": # v-prediction
87 | predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
88 | else:
89 | raise ValueError(
90 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
91 | " `v_prediction` for `LCMScheduler`."
92 | )
93 |
94 | # 5. Clip or threshold "predicted x_0"
95 | if self.config.thresholding:
96 | predicted_original_sample = self._threshold_sample(predicted_original_sample)
97 | elif self.config.clip_sample:
98 | predicted_original_sample = predicted_original_sample.clamp(
99 | -self.config.clip_sample_range, self.config.clip_sample_range
100 | )
101 |
102 | # 6. Denoise model output using boundary conditions
103 | denoised = c_out * predicted_original_sample + c_skip * sample
104 |
105 | # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
106 | # Noise is not used on the final timestep of the timestep schedule.
107 | # This also means that noise is not used for one-step sampling.
108 | if self.step_index != self.num_inference_steps - 1:
109 | noise = self.noise_list[self.step_index]
110 | prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
111 | else:
112 | prev_sample = denoised
113 |
114 | # upon completion increase step index by one
115 | self._step_index += 1
116 |
117 | if not return_dict:
118 | return (prev_sample, denoised)
119 |
120 | return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
121 |
122 |
123 | def inv_step(
124 | self,
125 | model_output: torch.FloatTensor,
126 | timestep: int,
127 | sample: torch.FloatTensor,
128 | generator: Optional[torch.Generator] = None,
129 | return_dict: bool = True,
130 | ) -> Union[LCMSchedulerOutput, Tuple]:
131 | """
132 |         Inverse LCM step: given the sample at the current (less noisy) timestep, a noise prediction, and the
133 |         stored `noise_list`, estimate the sample at the next (noisier) timestep for the inversion loop.
134 |
135 | Args:
136 | model_output (`torch.FloatTensor`):
137 | The direct output from learned diffusion model.
138 | timestep (`float`):
139 | The current discrete timestep in the diffusion chain.
140 | sample (`torch.FloatTensor`):
141 | A current instance of a sample created by the diffusion process.
142 | generator (`torch.Generator`, *optional*):
143 | A random number generator.
144 | return_dict (`bool`, *optional*, defaults to `True`):
145 | Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
146 | Returns:
147 | [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
148 | If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
149 | tuple is returned where the first element is the sample tensor.
150 | """
151 | if self.num_inference_steps is None:
152 | raise ValueError(
153 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
154 | )
155 |
156 | self._init_step_index(timestep)
157 |
158 | # 1. get previous step value
159 | prev_step_index = self.step_index + 1
160 | if prev_step_index < len(self.timesteps):
161 | prev_timestep = self.timesteps[prev_step_index]
162 | else:
163 | prev_timestep = timestep
164 |
165 | # 2. compute alphas, betas
166 | alpha_prod_t = self.alphas_cumprod[timestep]
167 | alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
168 |
169 | beta_prod_t = 1 - alpha_prod_t
170 | beta_prod_t_prev = 1 - alpha_prod_t_prev
171 |
172 | # 3. Get scalings for boundary conditions
173 | c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
174 |
175 | if self.step_index != self.num_inference_steps - 1:
176 | c_skip_actual = c_skip * alpha_prod_t_prev.sqrt()
177 | c_out_actual = c_out * alpha_prod_t_prev.sqrt()
178 | noise = self.noise_list[self.step_index] * beta_prod_t_prev.sqrt()
179 | else:
180 | c_skip_actual = c_skip
181 | c_out_actual = c_out
182 | noise = 0
183 |
184 |
185 | dem = c_out_actual / (alpha_prod_t.sqrt()) + c_skip
186 | eps_mul = beta_prod_t.sqrt() * c_out_actual / (alpha_prod_t.sqrt())
187 |
188 | prev_sample = (sample + eps_mul * model_output - noise) / dem
189 |
190 | # upon completion increase step index by one
191 | self._step_index += 1
192 |
193 | if not return_dict:
194 | return (prev_sample, prev_sample)
195 |
196 | return LCMSchedulerOutput(prev_sample=prev_sample, denoised=prev_sample)
--------------------------------------------------------------------------------
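
On the final timestep the LCM `step` injects no noise, which makes `inv_step` its exact algebraic inverse there; earlier steps additionally fold in the stored `noise_list` entry. A toy check of that final-step algebra, with placeholder scalars for the alphas and the boundary-condition scalings (not the values `get_scalings_for_boundary_condition_discrete` would return):

```python
import torch

a_t = torch.tensor(0.45, dtype=torch.float64)    # alpha_prod_t (placeholder)
b_t = 1 - a_t                                    # beta_prod_t
c_skip = torch.tensor(0.3, dtype=torch.float64)  # placeholder scaling
c_out = torch.tensor(0.7, dtype=torch.float64)   # placeholder scaling
x = torch.randn(4, dtype=torch.float64)
eps = torch.randn(4, dtype=torch.float64)

# `step` on the final timestep: denoise with the boundary-condition scalings.
x0_pred = (x - b_t.sqrt() * eps) / a_t.sqrt()
y = c_out * x0_pred + c_skip * x

# `inv_step` on the final timestep (the noise term is zero there).
dem = c_out / a_t.sqrt() + c_skip
eps_mul = b_t.sqrt() * c_out / a_t.sqrt()
x_rec = (y + eps_mul * eps) / dem

assert torch.allclose(x, x_rec)
```
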
/src/utils/enums_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionXLImg2ImgPipeline
4 |
5 | from src.eunms import Model_Type, Scheduler_Type
6 | from src.schedulers.euler_scheduler import MyEulerAncestralDiscreteScheduler
7 | from src.schedulers.lcm_scheduler import MyLCMScheduler
8 | from src.schedulers.ddim_scheduler import MyDDIMScheduler
9 | from src.pipes.sdxl_inversion_pipeline import SDXLDDIMPipeline
10 | from src.pipes.sdxl_forward_pipeline import StableDiffusionXLDecompositionPipeline
11 | from src.pipes.sd_inversion_pipeline import SDDDIMPipeline
12 |
13 | def scheduler_type_to_class(scheduler_type):
14 | if scheduler_type == Scheduler_Type.DDIM:
15 | return MyDDIMScheduler
16 | elif scheduler_type == Scheduler_Type.EULER:
17 | return MyEulerAncestralDiscreteScheduler
18 | elif scheduler_type == Scheduler_Type.LCM:
19 | return MyLCMScheduler
20 | else:
21 | raise ValueError("Unknown scheduler type")
22 |
23 | def is_stochastic(scheduler_type):
24 | if scheduler_type == Scheduler_Type.DDIM:
25 | return False
26 | elif scheduler_type == Scheduler_Type.EULER:
27 | return True
28 | elif scheduler_type == Scheduler_Type.LCM:
29 | return True
30 | else:
31 | raise ValueError("Unknown scheduler type")
32 |
33 | def model_type_to_class(model_type):
34 | if model_type == Model_Type.SDXL:
35 | return StableDiffusionXLDecompositionPipeline, SDXLDDIMPipeline
36 | elif model_type == Model_Type.SDXL_Turbo:
37 | return StableDiffusionXLImg2ImgPipeline, SDXLDDIMPipeline
38 | elif model_type == Model_Type.LCM_SDXL:
39 | return StableDiffusionXLImg2ImgPipeline, SDXLDDIMPipeline
40 | elif model_type == Model_Type.SD15:
41 | return StableDiffusionImg2ImgPipeline, SDDDIMPipeline
42 | elif model_type == Model_Type.SD14:
43 | return StableDiffusionImg2ImgPipeline, SDDDIMPipeline
44 | elif model_type == Model_Type.SD21:
45 | return StableDiffusionImg2ImgPipeline, SDDDIMPipeline
46 | elif model_type == Model_Type.SD21_Turbo:
47 | return StableDiffusionImg2ImgPipeline, SDDDIMPipeline
48 | else:
49 | raise ValueError("Unknown model type")
50 |
51 | def model_type_to_model_name(model_type):
52 | if model_type == Model_Type.SDXL:
53 | return "stabilityai/stable-diffusion-xl-base-1.0"
54 | elif model_type == Model_Type.SDXL_Turbo:
55 | return "stabilityai/sdxl-turbo"
56 | elif model_type == Model_Type.LCM_SDXL:
57 | return "stabilityai/stable-diffusion-xl-base-1.0"
58 | elif model_type == Model_Type.SD15:
59 | return "runwayml/stable-diffusion-v1-5"
60 | elif model_type == Model_Type.SD14:
61 | return "CompVis/stable-diffusion-v1-4"
62 | elif model_type == Model_Type.SD21:
63 | return "stabilityai/stable-diffusion-2-1"
64 | elif model_type == Model_Type.SD21_Turbo:
65 | return "stabilityai/sd-turbo"
66 | else:
67 | raise ValueError("Unknown model type")
68 |
69 |
70 | def model_type_to_size(model_type):
71 | if model_type == Model_Type.SDXL:
72 | return (1024, 1024)
73 | elif model_type == Model_Type.SDXL_Turbo:
74 | return (512, 512)
75 | elif model_type == Model_Type.LCM_SDXL:
76 | return (768, 768) #TODO: check
77 | elif model_type == Model_Type.SD15:
78 | return (512, 512)
79 | elif model_type == Model_Type.SD14:
80 | return (512, 512)
81 | elif model_type == Model_Type.SD21:
82 | return (512, 512)
83 | elif model_type == Model_Type.SD21_Turbo:
84 | return (512, 512)
85 | else:
86 | raise ValueError("Unknown model type")
87 |
88 | def is_float16(model_type):
89 | if model_type == Model_Type.SDXL:
90 | return True
91 | elif model_type == Model_Type.SDXL_Turbo:
92 | return True
93 | elif model_type == Model_Type.LCM_SDXL:
94 | return True
95 | elif model_type == Model_Type.SD15:
96 | return False
97 | elif model_type == Model_Type.SD14:
98 | return False
99 | elif model_type == Model_Type.SD21:
100 | return False
101 | elif model_type == Model_Type.SD21_Turbo:
102 | return False
103 | else:
104 | raise ValueError("Unknown model type")
105 |
106 | def is_sd(model_type):
107 | if model_type == Model_Type.SDXL:
108 | return False
109 | elif model_type == Model_Type.SDXL_Turbo:
110 | return False
111 | elif model_type == Model_Type.LCM_SDXL:
112 | return False
113 | elif model_type == Model_Type.SD15:
114 | return True
115 | elif model_type == Model_Type.SD14:
116 | return True
117 | elif model_type == Model_Type.SD21:
118 | return True
119 | elif model_type == Model_Type.SD21_Turbo:
120 | return True
121 | else:
122 | raise ValueError("Unknown model type")
123 |
124 | def _get_pipes(model_type, device):
125 | model_name = model_type_to_model_name(model_type)
126 | pipeline_inf, pipeline_inv = model_type_to_class(model_type)
127 |
128 | if is_float16(model_type):
129 | pipe_inference = pipeline_inf.from_pretrained(
130 | model_name,
131 | torch_dtype=torch.float16,
132 | use_safetensors=True,
133 | variant="fp16",
134 | safety_checker = None
135 | ).to(device)
136 | else:
137 | pipe_inference = pipeline_inf.from_pretrained(
138 | model_name,
139 | use_safetensors=True,
140 | safety_checker = None
141 | ).to(device)
142 |
143 | pipe_inversion = pipeline_inv(**pipe_inference.components)
144 |
145 | return pipe_inversion, pipe_inference
146 |
147 | def get_pipes(model_type, scheduler_type, device="cuda"):
148 | scheduler_class = scheduler_type_to_class(scheduler_type)
149 |
150 | pipe_inversion, pipe_inference = _get_pipes(model_type, device)
151 |
152 | pipe_inference.scheduler = scheduler_class.from_config(pipe_inference.scheduler.config)
153 | pipe_inversion.scheduler = scheduler_class.from_config(pipe_inversion.scheduler.config)
154 |
155 | if is_sd(model_type):
156 | pipe_inference.scheduler.add_noise = lambda init_latents, noise, timestep: init_latents
157 | pipe_inversion.scheduler.add_noise = lambda init_latents, noise, timestep: init_latents
158 |
159 | if model_type == Model_Type.LCM_SDXL:
160 | adapter_id = "latent-consistency/lcm-lora-sdxl"
161 | pipe_inversion.load_lora_weights(adapter_id)
162 | pipe_inference.load_lora_weights(adapter_id)
163 |
164 | return pipe_inversion, pipe_inference
--------------------------------------------------------------------------------
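
A minimal usage sketch of the helpers above (assumes a CUDA device and that the SDXL weights can be fetched from the Hugging Face Hub):

```python
from src.eunms import Model_Type, Scheduler_Type
from src.utils.enums_utils import get_pipes, model_type_to_size

model_type = Model_Type.SDXL
scheduler_type = Scheduler_Type.EULER

# Builds the inference pipeline, reuses its components for the inversion
# pipeline, and swaps the matching custom scheduler into both.
pipe_inversion, pipe_inference = get_pipes(model_type, scheduler_type, device="cuda")
print(model_type_to_size(model_type))  # (1024, 1024)
```
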
/src/utils/images_utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import os
3 | import torch
4 |
5 | def read_images_in_path(path, size = (512,512)):
6 | image_paths = []
7 | for filename in os.listdir(path):
8 |         if filename.endswith((".png", ".jpg", ".jpeg")):
9 | image_path = os.path.join(path, filename)
10 | image_paths.append(image_path)
11 | image_paths = sorted(image_paths)
12 | return [Image.open(image_path).convert("RGB").resize(size) for image_path in image_paths]
13 |
14 | def concatenate_images(image_lists, return_list = False):
15 | num_rows = len(image_lists[0])
16 | num_columns = len(image_lists)
17 | image_width = image_lists[0][0].width
18 | image_height = image_lists[0][0].height
19 |
20 | grid_width = num_columns * image_width
21 | grid_height = num_rows * image_height if not return_list else image_height
22 | if not return_list:
23 | grid_image = [Image.new('RGB', (grid_width, grid_height))]
24 | else:
25 | grid_image = [Image.new('RGB', (grid_width, grid_height)) for i in range(num_rows)]
26 |
27 | for i in range(num_rows):
28 | row_index = i if return_list else 0
29 | for j in range(num_columns):
30 | image = image_lists[j][i]
31 | x_offset = j * image_width
32 | y_offset = i * image_height if not return_list else 0
33 | grid_image[row_index].paste(image, (x_offset, y_offset))
34 |
35 | return grid_image if return_list else grid_image[0]
36 |
37 | def concatenate_images_single(image_lists):
38 | num_columns = len(image_lists)
39 | image_width = image_lists[0].width
40 | image_height = image_lists[0].height
41 |
42 | grid_width = num_columns * image_width
43 | grid_height = image_height
44 | grid_image = Image.new('RGB', (grid_width, grid_height))
45 |
46 | for j in range(num_columns):
47 | image = image_lists[j]
48 | x_offset = j * image_width
49 | y_offset = 0
50 | grid_image.paste(image, (x_offset, y_offset))
51 |
52 | return grid_image
53 |
54 | def get_captions_for_images(images, device):
55 | from transformers import Blip2Processor, Blip2ForConditionalGeneration
56 |
57 | processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
58 | model = Blip2ForConditionalGeneration.from_pretrained(
59 | "Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
60 |     )
61 |
62 | res = []
63 |
64 | for image in images:
65 | inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
66 |
67 | generated_ids = model.generate(**inputs)
68 | generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
69 | res.append(generated_text)
70 |
71 | del processor
72 | del model
73 |
74 | return res
--------------------------------------------------------------------------------
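
A small usage sketch for the grid helpers above (assumes the repository's `example_images/` directory; each inner list passed to `concatenate_images` becomes one column of the grid):

```python
from src.utils.images_utils import read_images_in_path, concatenate_images

images = read_images_in_path("example_images", size=(512, 512))
# Two identical columns: each input image next to itself, one row per image.
grid = concatenate_images([images, images])
grid.save("image_grid.png")
```
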