├── README.md ├── asset └── teaser.png ├── demo.py ├── environment.yaml ├── example_images ├── 009698.jpg ├── Arknight.jpg └── new_cat_3.jpeg ├── main.py ├── requirements.txt └── src ├── config.py ├── eunms.py ├── pipes ├── sd_inversion_pipeline.py ├── sdxl_forward_pipeline.py └── sdxl_inversion_pipeline.py ├── renoise_inversion.py ├── schedulers ├── ddim_scheduler.py ├── euler_scheduler.py └── lcm_scheduler.py └── utils ├── enums_utils.py └── images_utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Scaling Concept With Text-Guided Diffusion Models 2 | > **Chao Huang, Susan Liang, Yunlong Tang, Yapeng Tian, Anurag Kumar, Chenliang Xu** 3 | > 4 | > Text-guided diffusion models have revolutionized generative tasks by producing high-fidelity content from text descriptions. They have also enabled an editing paradigm where concepts can be replaced through text conditioning (e.g., *a dog -> a tiger*). In this work, we explore a novel approach: instead of replacing a concept, can we enhance or suppress the concept itself? Through an empirical study, we identify a trend where concepts can be decomposed in text-guided diffusion models. Leveraging this insight, we introduce **ScalingConcept**, a simple yet effective method to scale decomposed concepts up or down in real input without introducing new elements. 5 | To systematically evaluate our approach, we present the *WeakConcept-10* dataset, where concepts are imperfect and need to be enhanced. More importantly, ScalingConcept enables a variety of novel zero-shot applications across image and audio domains, including tasks such as canonical pose generation and generative sound highlighting or removal. 6 | 7 | 8 | 9 | 10 | 11 |

12 | 13 |

14 | 
15 | ## Environment Setup
16 | Our code builds on the `diffusers` library. To set up the environment, please run:
17 | ```
18 | conda env create -f environment.yaml
19 | conda activate ScalingConcept
20 | ```
21 | or install the requirements directly:
22 | ```
23 | pip install -r requirements.txt
24 | ```
25 | 
26 | 
27 | ## Minimal Example
28 | 
29 | We provide a minimal example to explore the effects of concept scaling. The `example_images/` directory contains three sample images demonstrating different applications: `canonical pose generation`, `face attribute editing`, and `anime sketch enhancement`. To get started, try running:
30 | 
31 | ```bash
32 | python demo.py
33 | ```
34 | 
35 | The default setting is configured for `canonical pose generation`. For optimal results with other applications, adjust the prompt and relevant hyperparameters as noted in the code comments.
36 | 
37 | ### Usage
38 | 
39 | Our ScalingConcept method supports various applications, each customizable by adjusting the scaling parameters passed to `pipe_inference`. Below are recommended configurations for each application:
40 | 
41 | - **Canonical Pose Generation/Object Stitching**:
42 | ```python
43 | prompt = [object_name]
44 | omega = 5
45 | gamma = 3
46 | t_exit = 15
47 | ```
48 | 
49 | - **Weather Manipulation**:
50 | ```python
51 | prompt = '(heavy) fog' or '(heavy) rain'
52 | omega = 5
53 | gamma = 3
54 | t_exit = 15
55 | ```
56 | 
57 | - **Creative Enhancement**:
58 | ```python
59 | prompt = [concept to enhance]
60 | omega = 3
61 | gamma = 3
62 | t_exit = 15
63 | ```
64 | 
65 | - **Face Attribute Scaling**:
66 | ```python
67 | prompt = [face attribute, e.g., 'young face' or 'old face']
68 | omega = 3
69 | gamma = 3
70 | t_exit = 15
71 | ```
72 | 
73 | - **Anime Sketch Enhancement**:
74 | ```python
75 | prompt = 'anime'
76 | omega = 5
77 | gamma = 3
78 | t_exit = 25
79 | ```
80 | 
81 | In general, a larger `omega` value strengthens the concept-scaling effect, while higher `gamma` and `t_exit` values better preserve fidelity. Note that the choice of inversion `prompt` is crucial, as the model is sensitive to prompt wording. For a consolidated view of where these settings plug in, see the end-to-end sketch after the Acknowledgements.
82 | 
83 | ## Acknowledgements
84 | 
85 | This code builds upon the [diffusers](https://github.com/huggingface/diffusers) library. Additionally, we borrow code from the following repositories:
86 | 
87 | - [Pix2PixZero](https://github.com/pix2pixzero/pix2pix-zero) for noise regularization.
88 | - [sdxl_inversions](https://github.com/cloneofsimo/sdxl_inversions) for the initial implementation of DDIM inversion in SDXL.
89 | - [ReNoise-Inversion](https://github.com/garibida/ReNoise-Inversion) for a precise inversion technique.
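## End-to-End Sketch

The settings listed under Usage all flow into the same two calls: one inversion pass and one `pipe_inference` pass. The snippet below is a minimal sketch condensed from `demo.py`; the image path, inversion prompt, and output filename are placeholders to replace with your own, and the `omega`/`gamma`/`t_exit` values follow the canonical-pose recipe above.

```python
import torch
from PIL import Image

from src.eunms import Model_Type, Scheduler_Type
from src.utils.enums_utils import get_pipes
from src.config import RunConfig
from main import run as invert

device = 'cuda' if torch.cuda.is_available() else 'cpu'
pipe_inversion, pipe_inference = get_pipes(Model_Type.SDXL, Scheduler_Type.DDIM, device=device)

# Placeholder input: any RGB image, resized to the SDXL resolution.
image = Image.open("example_images/new_cat_3.jpeg").convert("RGB").resize((1024, 1024))
prompt = "cat"  # inversion prompt naming the concept to scale

config = RunConfig(model_type=Model_Type.SDXL,
                   scheduler_type=Scheduler_Type.DDIM,
                   num_inference_steps=50,
                   num_inversion_steps=50,
                   num_renoise_steps=1,
                   perform_noise_correction=False,
                   seed=7865)

# Invert once; the resulting latents can be reused while sweeping omega/gamma/t_exit.
_, inv_latent, _, all_latents, other_kwargs = invert(image, prompt, config,
                                                     pipe_inversion=pipe_inversion,
                                                     pipe_inference=pipe_inference,
                                                     do_reconstruction=False)

out = pipe_inference(image=inv_latent,
                     prompt="",        # null prompt for the scaled branch
                     denoising_start=0.0,
                     num_inference_steps=config.num_inference_steps,
                     guidance_scale=1.0,
                     omega=5,          # concept-scaling strength
                     gamma=3,          # decay of the scaling factor over steps
                     t_exit=15,        # step after which the reconstruction term is dropped
                     inv_latents=all_latents,
                     prompt_embeds_ref=other_kwargs[0],
                     added_cond_kwargs_ref=other_kwargs[1]).images[0]
out.save("output.png")  # placeholder output path
```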
90 | 
91 | ## Citation
92 | If you use this code for your research, please cite the following work:
93 | ```
94 | @misc{huang2024scaling,
95 |       title={Scaling Concept With Text-Guided Diffusion Models}, 
96 |       author={Chao Huang and Susan Liang and Yunlong Tang and Yapeng Tian and Anurag Kumar and Chenliang Xu},
97 |       year={2024},
98 |       eprint={2410.24151},
99 |       archivePrefix={arXiv},
100 |       primaryClass={cs.CV}
101 | }
102 | ```
103 | 
--------------------------------------------------------------------------------
/asset/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WikiChao/ScalingConcept/e201f3ab5f460b08d4067464f7cedb26bc5b012b/asset/teaser.png
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | 
4 | from src.eunms import Model_Type, Scheduler_Type
5 | from src.utils.enums_utils import get_pipes
6 | from src.config import RunConfig
7 | 
8 | from main import run as invert
9 | 
10 | 
11 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
12 | 
13 | model_type = Model_Type.SDXL
14 | scheduler_type = Scheduler_Type.DDIM
15 | pipe_inversion, pipe_inference = get_pipes(model_type, scheduler_type, device=device)
16 | 
17 | input_image = Image.open("example_images/new_cat_3.jpeg").convert("RGB") # "009698.jpg", "Arknight.jpg"
18 | original_shape = input_image.size
19 | input_image = input_image.resize((1024, 1024))
20 | prompt = "cat" # 'smile' for "009698.jpg", 'anime' for "Arknight.jpg"
21 | 
22 | config = RunConfig(model_type = model_type,
23 |                    num_inference_steps = 50,
24 |                    num_inversion_steps = 50,
25 |                    num_renoise_steps = 1,
26 |                    scheduler_type = scheduler_type,
27 |                    perform_noise_correction = False,
28 |                    seed = 7865)
29 | 
30 | _, inv_latent, _, all_latents, other_kwargs = invert(input_image,
31 |                                                      prompt,
32 |                                                      config,
33 |                                                      pipe_inversion=pipe_inversion,
34 |                                                      pipe_inference=pipe_inference,
35 |                                                      do_reconstruction=False)
36 | 
37 | rec_image = pipe_inference(image = inv_latent,
38 |                            prompt = "",
39 |                            denoising_start=0.0,
40 |                            num_inference_steps = config.num_inference_steps,
41 |                            guidance_scale = 1.0,
42 |                            omega=5, # omega=3 for "009698.jpg", omega=5 for "Arknight.jpg"
43 |                            gamma=3, # gamma=3 for "009698.jpg", gamma=3 for "Arknight.jpg"
44 |                            inv_latents=all_latents,
45 |                            prompt_embeds_ref=other_kwargs[0],
46 |                            added_cond_kwargs_ref=other_kwargs[1],
47 |                            t_exit=15, # t_exit=15 for "009698.jpg", t_exit=25 for "Arknight.jpg"
48 |                            ).images[0]
49 | 
50 | rec_image.resize(original_shape).save("new_cat_3.jpg")
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: ScalingConcept
2 | channels:
3 |   - pytorch
4 |   - defaults
5 | dependencies:
6 |   - python=3.11.8
7 |   - pip=23.3.1
8 |   - pip:
9 |     - -r requirements.txt
--------------------------------------------------------------------------------
/example_images/009698.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WikiChao/ScalingConcept/e201f3ab5f460b08d4067464f7cedb26bc5b012b/example_images/009698.jpg
--------------------------------------------------------------------------------
/example_images/Arknight.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WikiChao/ScalingConcept/e201f3ab5f460b08d4067464f7cedb26bc5b012b/example_images/Arknight.jpg -------------------------------------------------------------------------------- /example_images/new_cat_3.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WikiChao/ScalingConcept/e201f3ab5f460b08d4067464f7cedb26bc5b012b/example_images/new_cat_3.jpeg -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import pyrallis 2 | import torch 3 | from PIL import Image 4 | from diffusers.utils.torch_utils import randn_tensor 5 | 6 | from src.config import RunConfig 7 | from src.utils.enums_utils import model_type_to_size, is_stochastic 8 | 9 | def create_noise_list(model_type, length, generator=None): 10 | img_size = model_type_to_size(model_type) 11 | VQAE_SCALE = 8 12 | latents_size = (1, 4, img_size[0] // VQAE_SCALE, img_size[1] // VQAE_SCALE) 13 | return [randn_tensor(latents_size, dtype=torch.float16, device=torch.device("cuda:0"), generator=generator) for i in range(length)] 14 | 15 | @pyrallis.wrap() 16 | def main(cfg: RunConfig): 17 | run(cfg) 18 | 19 | def run(init_image: Image, 20 | prompt: str, 21 | cfg: RunConfig, 22 | pipe_inversion, 23 | pipe_inference, 24 | latents = None, 25 | edit_prompt = None, 26 | edit_cfg = 1.0, 27 | noise = None, 28 | do_reconstruction = True): 29 | 30 | generator = torch.Generator().manual_seed(cfg.seed) 31 | 32 | if is_stochastic(cfg.scheduler_type): 33 | if latents is None: 34 | noise = create_noise_list(cfg.model_type, cfg.num_inversion_steps, generator=generator) 35 | pipe_inversion.scheduler.set_noise_list(noise) 36 | pipe_inference.scheduler.set_noise_list(noise) 37 | 38 | pipe_inversion.cfg = cfg 39 | pipe_inference.cfg = cfg 40 | all_latents = None 41 | 42 | if latents is None: 43 | print("Inverting...") 44 | res = pipe_inversion(prompt = prompt, 45 | num_inversion_steps = cfg.num_inversion_steps, 46 | num_inference_steps = cfg.num_inference_steps, 47 | generator = generator, 48 | image = init_image, 49 | guidance_scale = cfg.guidance_scale, 50 | strength = cfg.inversion_max_step, 51 | denoising_start = 1.0-cfg.inversion_max_step, 52 | num_renoise_steps = cfg.num_renoise_steps) 53 | latents = res[0][0] 54 | all_latents = res[1] 55 | other_kwargs = res[2] 56 | 57 | inv_latent = latents.clone() 58 | 59 | if do_reconstruction: 60 | print("Generating...") 61 | edit_prompt = prompt if edit_prompt is None else edit_prompt 62 | guidance_scale = edit_cfg 63 | img = pipe_inference(prompt = edit_prompt, 64 | num_inference_steps = cfg.num_inference_steps, 65 | negative_prompt = prompt, 66 | image = latents, 67 | strength = cfg.inversion_max_step, 68 | denoising_start = 1.0-cfg.inversion_max_step, 69 | guidance_scale = guidance_scale, 70 | omega=1, 71 | gamma=0, 72 | inv_latents=all_latents, 73 | prompt_embeds_ref=other_kwargs[0], 74 | added_cond_kwargs_ref=other_kwargs[1]).images[0] 75 | else: 76 | img = None 77 | 78 | return img, inv_latent, noise, all_latents, other_kwargs 79 | 80 | if __name__ == "__main__": 81 | main() 82 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.2.0 2 | torchvision==0.17.0 3 | diffusers==0.24.0 4 | transformers==4.32.1 5 | pyrallis==0.3.1 6 | accelerate==0.25.0 7 | 
bitsandbytes==0.43.0 -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from src.eunms import Model_Type, Scheduler_Type 4 | 5 | @dataclass 6 | class RunConfig: 7 | model_type : Model_Type = Model_Type.SDXL_Turbo 8 | 9 | scheduler_type : Scheduler_Type = Scheduler_Type.EULER 10 | 11 | seed: int = 7865 12 | 13 | num_inference_steps: int = 4 14 | 15 | num_inversion_steps: int = 4 16 | 17 | guidance_scale: float = 0.0 18 | 19 | num_renoise_steps: int = 9 20 | 21 | max_num_renoise_steps_first_step: int = 5 22 | 23 | inversion_max_step: float = 1.0 24 | 25 | # Average Parameters 26 | 27 | average_latent_estimations: bool = True 28 | 29 | average_first_step_range: tuple = (0, 5) 30 | 31 | average_step_range: tuple = (8, 10) 32 | 33 | # Noise Regularization 34 | 35 | noise_regularization_lambda_ac: float = 20.0 36 | 37 | noise_regularization_lambda_kl: float = 0.065 38 | 39 | noise_regularization_num_reg_steps: int = 4 40 | 41 | noise_regularization_num_ac_rolls: int = 5 42 | 43 | # Noise Correction 44 | 45 | perform_noise_correction: bool = True 46 | 47 | def __post_init__(self): 48 | pass -------------------------------------------------------------------------------- /src/eunms.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | class Scheduler_Type(Enum): 4 | DDIM = 1 5 | EULER = 2 6 | LCM = 3 7 | 8 | class Model_Type(Enum): 9 | SDXL = 1 10 | SDXL_Turbo = 2 11 | LCM_SDXL = 3 12 | SD15 = 4 13 | SD21 = 5 14 | SD21_Turbo = 6 15 | SD14 = 7 -------------------------------------------------------------------------------- /src/pipes/sd_inversion_pipeline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 3 | from diffusers import StableDiffusionImg2ImgPipeline 4 | from diffusers.utils.torch_utils import randn_tensor 5 | 6 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( 7 | StableDiffusionPipelineOutput, 8 | retrieve_timesteps, 9 | PipelineImageInput 10 | ) 11 | 12 | from src.renoise_inversion import inversion_step 13 | 14 | 15 | class SDDDIMPipeline(StableDiffusionImg2ImgPipeline): 16 | # @torch.no_grad() 17 | def __call__( 18 | self, 19 | prompt: Union[str, List[str]] = None, 20 | image: PipelineImageInput = None, 21 | strength: float = 1.0, 22 | num_inversion_steps: Optional[int] = 50, 23 | timesteps: List[int] = None, 24 | guidance_scale: Optional[float] = 7.5, 25 | negative_prompt: Optional[Union[str, List[str]]] = None, 26 | num_images_per_prompt: Optional[int] = 1, 27 | eta: Optional[float] = 0.0, 28 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 29 | prompt_embeds: Optional[torch.FloatTensor] = None, 30 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 31 | ip_adapter_image: Optional[PipelineImageInput] = None, 32 | output_type: Optional[str] = "pil", 33 | return_dict: bool = True, 34 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 35 | clip_skip: int = None, 36 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 37 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 38 | num_renoise_steps: int = 100, 39 | **kwargs, 40 | ): 41 | callback = kwargs.pop("callback", None) 42 | callback_steps = kwargs.pop("callback_steps", 
None) 43 | 44 | if callback is not None: 45 | deprecate( 46 | "callback", 47 | "1.0.0", 48 | "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", 49 | ) 50 | if callback_steps is not None: 51 | deprecate( 52 | "callback_steps", 53 | "1.0.0", 54 | "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", 55 | ) 56 | 57 | # 1. Check inputs. Raise error if not correct 58 | self.check_inputs( 59 | prompt, 60 | strength, 61 | callback_steps, 62 | negative_prompt, 63 | prompt_embeds, 64 | negative_prompt_embeds, 65 | callback_on_step_end_tensor_inputs, 66 | ) 67 | 68 | self._guidance_scale = guidance_scale 69 | self._clip_skip = clip_skip 70 | self._cross_attention_kwargs = cross_attention_kwargs 71 | 72 | # 2. Define call parameters 73 | if prompt is not None and isinstance(prompt, str): 74 | batch_size = 1 75 | elif prompt is not None and isinstance(prompt, list): 76 | batch_size = len(prompt) 77 | else: 78 | batch_size = prompt_embeds.shape[0] 79 | 80 | device = self._execution_device 81 | 82 | # 3. Encode input prompt 83 | text_encoder_lora_scale = ( 84 | self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None 85 | ) 86 | prompt_embeds, negative_prompt_embeds = self.encode_prompt( 87 | prompt, 88 | device, 89 | num_images_per_prompt, 90 | self.do_classifier_free_guidance, 91 | negative_prompt, 92 | prompt_embeds=prompt_embeds, 93 | negative_prompt_embeds=negative_prompt_embeds, 94 | lora_scale=text_encoder_lora_scale, 95 | clip_skip=self.clip_skip, 96 | ) 97 | # For classifier free guidance, we need to do two forward passes. 98 | # Here we concatenate the unconditional and text embeddings into a single batch 99 | # to avoid doing two forward passes 100 | if self.do_classifier_free_guidance: 101 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 102 | 103 | if ip_adapter_image is not None: 104 | image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) 105 | if self.do_classifier_free_guidance: 106 | image_embeds = torch.cat([negative_image_embeds, image_embeds]) 107 | 108 | # 4. Preprocess image 109 | image = self.image_processor.preprocess(image) 110 | 111 | # 5. set timesteps 112 | timesteps, num_inversion_steps = retrieve_timesteps(self.scheduler, num_inversion_steps, device, timesteps) 113 | timesteps, num_inversion_steps = self.get_timesteps(num_inversion_steps, strength, device) 114 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) 115 | 116 | # 6. Prepare latent variables 117 | with torch.no_grad(): 118 | latents = self.prepare_latents( 119 | image, 120 | latent_timestep, 121 | batch_size, 122 | num_images_per_prompt, 123 | prompt_embeds.dtype, 124 | device, 125 | generator, 126 | ) 127 | 128 | # 7. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline 129 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 130 | 131 | # 7.1 Add image embeds for IP-Adapter 132 | added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None 133 | 134 | # 7.2 Optionally get Guidance Scale Embedding 135 | timestep_cond = None 136 | if self.unet.config.time_cond_proj_dim is not None: 137 | guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) 138 | timestep_cond = self.get_guidance_scale_embedding( 139 | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim 140 | ).to(device=device, dtype=latents.dtype) 141 | 142 | # 8. Denoising loop 143 | num_warmup_steps = len(timesteps) - num_inversion_steps * self.scheduler.order 144 | 145 | self._num_timesteps = len(timesteps) 146 | self.z_0 = torch.clone(latents) 147 | self.noise = randn_tensor(self.z_0.shape, generator=generator, device=self.z_0.device, dtype=self.z_0.dtype) 148 | 149 | all_latents = [latents.clone()] 150 | with self.progress_bar(total=num_inversion_steps) as progress_bar: 151 | for i, t in enumerate(reversed(timesteps)): 152 | 153 | latents = inversion_step(self, 154 | latents, 155 | t, 156 | prompt_embeds, 157 | added_cond_kwargs, 158 | num_renoise_steps=num_renoise_steps, 159 | generator=generator) 160 | 161 | all_latents.append(latents.clone()) 162 | 163 | if callback_on_step_end is not None: 164 | callback_kwargs = {} 165 | for k in callback_on_step_end_tensor_inputs: 166 | callback_kwargs[k] = locals()[k] 167 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 168 | 169 | latents = callback_outputs.pop("latents", latents) 170 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 171 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) 172 | 173 | # call the callback, if provided 174 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 175 | progress_bar.update() 176 | if callback is not None and i % callback_steps == 0: 177 | step_idx = i // getattr(self.scheduler, "order", 1) 178 | callback(step_idx, t, latents) 179 | 180 | image = latents 181 | 182 | # Offload all models 183 | self.maybe_free_model_hooks() 184 | 185 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None), all_latents 186 | -------------------------------------------------------------------------------- /src/pipes/sdxl_forward_pipeline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 4 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( 5 | StableDiffusionPipelineOutput, 6 | retrieve_timesteps, 7 | PipelineImageInput, 8 | rescale_noise_cfg, 9 | ) 10 | 11 | from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline 12 | from diffusers.utils import is_torch_xla_available 13 | from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput 14 | 15 | if is_torch_xla_available(): 16 | import torch_xla.core.xla_model as xm 17 | XLA_AVAILABLE = True 18 | else: 19 | XLA_AVAILABLE = False 20 | 21 | def degrade_proportionally(i, max_value=1, num_inference_steps=49, gamma=0): 22 | return max(0, max_value * (1 - i / num_inference_steps)**gamma) 23 | 24 | class 
StableDiffusionXLDecompositionPipeline(StableDiffusionXLImg2ImgPipeline): 25 | @torch.no_grad() 26 | def __call__( 27 | self, 28 | prompt: Union[str, List[str]] = None, 29 | prompt_2: Optional[Union[str, List[str]]] = None, 30 | image: PipelineImageInput = None, 31 | strength: float = 0.3, 32 | num_inference_steps: int = 50, 33 | timesteps: List[int] = None, 34 | denoising_start: Optional[float] = None, 35 | denoising_end: Optional[float] = None, 36 | guidance_scale: float = 5.0, 37 | negative_prompt: Optional[Union[str, List[str]]] = None, 38 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 39 | num_images_per_prompt: Optional[int] = 1, 40 | eta: float = 0.0, 41 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 42 | latents: Optional[torch.FloatTensor] = None, 43 | prompt_embeds: Optional[torch.FloatTensor] = None, 44 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 45 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 46 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 47 | ip_adapter_image: Optional[PipelineImageInput] = None, 48 | output_type: Optional[str] = "pil", 49 | return_dict: bool = True, 50 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 51 | guidance_rescale: float = 0.0, 52 | original_size: Tuple[int, int] = None, 53 | crops_coords_top_left: Tuple[int, int] = (0, 0), 54 | target_size: Tuple[int, int] = None, 55 | negative_original_size: Optional[Tuple[int, int]] = None, 56 | negative_crops_coords_top_left: Tuple[int, int] = (0, 0), 57 | negative_target_size: Optional[Tuple[int, int]] = None, 58 | aesthetic_score: float = 6.0, 59 | negative_aesthetic_score: float = 2.5, 60 | clip_skip: Optional[int] = None, 61 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 62 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 63 | omega = 1, 64 | gamma = 0, 65 | inv_latents = None, 66 | prompt_embeds_ref = None, 67 | added_cond_kwargs_ref = None, 68 | t_exit=15, 69 | **kwargs, 70 | ): 71 | r""" 72 | Function invoked when calling the pipeline for generation. 73 | 74 | Args: 75 | prompt (`str` or `List[str]`, *optional*): 76 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 77 | instead. 78 | prompt_2 (`str` or `List[str]`, *optional*): 79 | The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is 80 | used in both text-encoders 81 | image (`torch.FloatTensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`): 82 | The image(s) to modify with the pipeline. 83 | strength (`float`, *optional*, defaults to 0.3): 84 | Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` 85 | will be used as a starting point, adding more noise to it the larger the `strength`. The number of 86 | denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will 87 | be maximum and the denoising process will run for the full number of iterations specified in 88 | `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of 89 | `denoising_start` being declared as an integer, the value of `strength` will be ignored. 90 | num_inference_steps (`int`, *optional*, defaults to 50): 91 | The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the 92 | expense of slower inference. 93 | timesteps (`List[int]`, *optional*): 94 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument 95 | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is 96 | passed will be used. Must be in descending order. 97 | denoising_start (`float`, *optional*): 98 | When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be 99 | bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and 100 | it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, 101 | strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline 102 | is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image 103 | Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality). 104 | denoising_end (`float`, *optional*): 105 | When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be 106 | completed before it is intentionally prematurely terminated. As a result, the returned sample will 107 | still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be 108 | denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the 109 | final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline 110 | forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refine Image 111 | Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality). 112 | guidance_scale (`float`, *optional*, defaults to 7.5): 113 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 114 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 115 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 116 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 117 | usually at the expense of lower image quality. 118 | negative_prompt (`str` or `List[str]`, *optional*): 119 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 120 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 121 | less than `1`). 122 | negative_prompt_2 (`str` or `List[str]`, *optional*): 123 | The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and 124 | `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders 125 | num_images_per_prompt (`int`, *optional*, defaults to 1): 126 | The number of images to generate per prompt. 127 | eta (`float`, *optional*, defaults to 0.0): 128 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to 129 | [`schedulers.DDIMScheduler`], will be ignored for others. 130 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 131 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 132 | to make generation deterministic. 
133 | latents (`torch.FloatTensor`, *optional*): 134 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image 135 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 136 | tensor will ge generated by sampling using the supplied random `generator`. 137 | prompt_embeds (`torch.FloatTensor`, *optional*): 138 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 139 | provided, text embeddings will be generated from `prompt` input argument. 140 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 141 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 142 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 143 | argument. 144 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 145 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 146 | If not provided, pooled text embeddings will be generated from `prompt` input argument. 147 | negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 148 | Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 149 | weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` 150 | input argument. 151 | ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. 152 | output_type (`str`, *optional*, defaults to `"pil"`): 153 | The output format of the generate image. Choose between 154 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 155 | return_dict (`bool`, *optional*, defaults to `True`): 156 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a 157 | plain tuple. 158 | cross_attention_kwargs (`dict`, *optional*): 159 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 160 | `self.processor` in 161 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 162 | guidance_rescale (`float`, *optional*, defaults to 0.0): 163 | Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are 164 | Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of 165 | [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 166 | Guidance rescale factor should fix overexposure when using zero terminal SNR. 167 | original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): 168 | If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. 169 | `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as 170 | explained in section 2.2 of 171 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 172 | crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): 173 | `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position 174 | `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting 175 | `crops_coords_top_left` to (0, 0). 
Part of SDXL's micro-conditioning as explained in section 2.2 of 176 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 177 | target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): 178 | For most cases, `target_size` should be set to the desired height and width of the generated image. If 179 | not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in 180 | section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 181 | negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): 182 | To negatively condition the generation process based on a specific image resolution. Part of SDXL's 183 | micro-conditioning as explained in section 2.2 of 184 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more 185 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. 186 | negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): 187 | To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's 188 | micro-conditioning as explained in section 2.2 of 189 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more 190 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. 191 | negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): 192 | To negatively condition the generation process based on a target image resolution. It should be as same 193 | as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of 194 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more 195 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. 196 | aesthetic_score (`float`, *optional*, defaults to 6.0): 197 | Used to simulate an aesthetic score of the generated image by influencing the positive text condition. 198 | Part of SDXL's micro-conditioning as explained in section 2.2 of 199 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 200 | negative_aesthetic_score (`float`, *optional*, defaults to 2.5): 201 | Part of SDXL's micro-conditioning as explained in section 2.2 of 202 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to 203 | simulate an aesthetic score of the generated image by influencing the negative text condition. 204 | clip_skip (`int`, *optional*): 205 | Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that 206 | the output of the pre-final layer will be used for computing the prompt embeddings. 207 | callback_on_step_end (`Callable`, *optional*): 208 | A function that calls at the end of each denoising steps during the inference. The function is called 209 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, 210 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by 211 | `callback_on_step_end_tensor_inputs`. 212 | callback_on_step_end_tensor_inputs (`List`, *optional*): 213 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 214 | will be passed as `callback_kwargs` argument. 
You will only be able to include variables listed in the 215 | `._callback_tensor_inputs` attribute of your pipeline class. 216 | 217 | Examples: 218 | 219 | Returns: 220 | [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: 221 | [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a 222 | `tuple. When returning a tuple, the first element is a list with the generated images. 223 | """ 224 | 225 | callback = kwargs.pop("callback", None) 226 | callback_steps = kwargs.pop("callback_steps", None) 227 | 228 | if callback is not None: 229 | deprecate( 230 | "callback", 231 | "1.0.0", 232 | "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", 233 | ) 234 | if callback_steps is not None: 235 | deprecate( 236 | "callback_steps", 237 | "1.0.0", 238 | "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", 239 | ) 240 | 241 | # 1. Check inputs. Raise error if not correct 242 | self.check_inputs( 243 | prompt, 244 | prompt_2, 245 | strength, 246 | num_inference_steps, 247 | callback_steps, 248 | negative_prompt, 249 | negative_prompt_2, 250 | prompt_embeds, 251 | negative_prompt_embeds, 252 | callback_on_step_end_tensor_inputs, 253 | ) 254 | 255 | self._guidance_scale = guidance_scale 256 | self._guidance_rescale = guidance_rescale 257 | self._clip_skip = clip_skip 258 | self._cross_attention_kwargs = cross_attention_kwargs 259 | self._denoising_end = denoising_end 260 | self._denoising_start = denoising_start 261 | self._interrupt = False 262 | 263 | # 2. Define call parameters 264 | if prompt is not None and isinstance(prompt, str): 265 | batch_size = 1 266 | elif prompt is not None and isinstance(prompt, list): 267 | batch_size = len(prompt) 268 | else: 269 | batch_size = prompt_embeds.shape[0] 270 | 271 | device = self._execution_device 272 | 273 | # 3. Encode input prompt 274 | text_encoder_lora_scale = ( 275 | self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None 276 | ) 277 | ( 278 | prompt_embeds, 279 | negative_prompt_embeds, 280 | pooled_prompt_embeds, 281 | negative_pooled_prompt_embeds, 282 | ) = self.encode_prompt( 283 | prompt=prompt, 284 | prompt_2=prompt_2, 285 | device=device, 286 | num_images_per_prompt=num_images_per_prompt, 287 | do_classifier_free_guidance=self.do_classifier_free_guidance, 288 | negative_prompt=negative_prompt, 289 | negative_prompt_2=negative_prompt_2, 290 | prompt_embeds=prompt_embeds, 291 | negative_prompt_embeds=negative_prompt_embeds, 292 | pooled_prompt_embeds=pooled_prompt_embeds, 293 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 294 | lora_scale=text_encoder_lora_scale, 295 | clip_skip=self.clip_skip, 296 | ) 297 | 298 | # 4. Preprocess image 299 | image = self.image_processor.preprocess(image) 300 | 301 | # 5. Prepare timesteps 302 | def denoising_value_valid(dnv): 303 | return isinstance(self.denoising_end, float) and 0 < dnv < 1 304 | 305 | timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) 306 | timesteps, num_inference_steps = self.get_timesteps( 307 | num_inference_steps, 308 | strength, 309 | device, 310 | denoising_start=self.denoising_start if denoising_value_valid else None, 311 | ) 312 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) 313 | 314 | add_noise = True if self.denoising_start is None else False 315 | # 6. 
Prepare latent variables 316 | latents = self.prepare_latents( 317 | image, 318 | latent_timestep, 319 | batch_size, 320 | num_images_per_prompt, 321 | prompt_embeds.dtype, 322 | device, 323 | generator, 324 | add_noise, 325 | ) 326 | # 7. Prepare extra step kwargs. 327 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 328 | 329 | height, width = latents.shape[-2:] 330 | height = height * self.vae_scale_factor 331 | width = width * self.vae_scale_factor 332 | 333 | original_size = original_size or (height, width) 334 | target_size = target_size or (height, width) 335 | 336 | # 8. Prepare added time ids & embeddings 337 | if negative_original_size is None: 338 | negative_original_size = original_size 339 | if negative_target_size is None: 340 | negative_target_size = target_size 341 | 342 | add_text_embeds = pooled_prompt_embeds 343 | if self.text_encoder_2 is None: 344 | text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) 345 | else: 346 | text_encoder_projection_dim = self.text_encoder_2.config.projection_dim 347 | 348 | add_time_ids, add_neg_time_ids = self._get_add_time_ids( 349 | original_size, 350 | crops_coords_top_left, 351 | target_size, 352 | aesthetic_score, 353 | negative_aesthetic_score, 354 | negative_original_size, 355 | negative_crops_coords_top_left, 356 | negative_target_size, 357 | dtype=prompt_embeds.dtype, 358 | text_encoder_projection_dim=text_encoder_projection_dim, 359 | ) 360 | add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) 361 | 362 | if self.do_classifier_free_guidance: 363 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) 364 | add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) 365 | add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) 366 | add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) 367 | 368 | prompt_embeds = prompt_embeds.to(device) 369 | add_text_embeds = add_text_embeds.to(device) 370 | add_time_ids = add_time_ids.to(device) 371 | 372 | if ip_adapter_image is not None: 373 | image_embeds = self.prepare_ip_adapter_image_embeds( 374 | ip_adapter_image, device, batch_size * num_images_per_prompt 375 | ) 376 | 377 | # 9. Denoising loop 378 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) 379 | 380 | # 9.1 Apply denoising_end 381 | if ( 382 | self.denoising_end is not None 383 | and self.denoising_start is not None 384 | and denoising_value_valid(self.denoising_end) 385 | and denoising_value_valid(self.denoising_start) 386 | and self.denoising_start >= self.denoising_end 387 | ): 388 | raise ValueError( 389 | f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: " 390 | + f" {self.denoising_end} when using type float." 
391 | ) 392 | elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): 393 | discrete_timestep_cutoff = int( 394 | round( 395 | self.scheduler.config.num_train_timesteps 396 | - (self.denoising_end * self.scheduler.config.num_train_timesteps) 397 | ) 398 | ) 399 | num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) 400 | timesteps = timesteps[:num_inference_steps] 401 | 402 | # 9.2 Optionally get Guidance Scale Embedding 403 | timestep_cond = None 404 | if self.unet.config.time_cond_proj_dim is not None: 405 | guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) 406 | timestep_cond = self.get_guidance_scale_embedding( 407 | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim 408 | ).to(device=device, dtype=latents.dtype) 409 | 410 | self._num_timesteps = len(timesteps) 411 | reference_latents = latents 412 | with self.progress_bar(total=num_inference_steps) as progress_bar: 413 | for i, t in enumerate(timesteps): 414 | if self.interrupt: 415 | continue 416 | 417 | reference_latents = inv_latents[num_inference_steps - i] 418 | # expand the latents if we are doing classifier free guidance 419 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents 420 | 421 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 422 | 423 | # get the inversion latents from list 424 | reference_latents = (0.5 * latents + 0.5 * reference_latents) 425 | 426 | reference_model_input = torch.cat([reference_latents] * 2) if self.do_classifier_free_guidance else reference_latents 427 | 428 | reference_model_input = self.scheduler.scale_model_input(reference_model_input, t) 429 | 430 | # predict the noise residual 431 | added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} 432 | if ip_adapter_image is not None: 433 | added_cond_kwargs["image_embeds"] = image_embeds 434 | noise_pred = self.unet( 435 | latent_model_input, 436 | t, 437 | encoder_hidden_states=prompt_embeds, # null prompt 438 | timestep_cond=timestep_cond, 439 | cross_attention_kwargs=self.cross_attention_kwargs, 440 | added_cond_kwargs=added_cond_kwargs, 441 | return_dict=False, 442 | )[0] 443 | 444 | noise_pred_fwd = self.unet( 445 | latent_model_input, 446 | t, 447 | encoder_hidden_states=prompt_embeds_ref, # c prompt 448 | timestep_cond=timestep_cond, 449 | cross_attention_kwargs=self.cross_attention_kwargs, 450 | added_cond_kwargs=added_cond_kwargs_ref, 451 | return_dict=False, 452 | )[0] 453 | 454 | if i < t_exit: 455 | noise_pred_recon = self.unet( 456 | reference_model_input, 457 | t, 458 | encoder_hidden_states=prompt_embeds_ref, 459 | timestep_cond=timestep_cond, 460 | cross_attention_kwargs=self.cross_attention_kwargs, 461 | added_cond_kwargs=added_cond_kwargs_ref, 462 | return_dict=False, 463 | )[0] 464 | 465 | scaling_factor = degrade_proportionally(i, omega, num_inference_steps-1, gamma) 466 | 467 | noise_pred = noise_pred + scaling_factor * (noise_pred_fwd - noise_pred) 468 | 469 | if i < t_exit: 470 | noise_pred = (noise_pred + scaling_factor * (noise_pred_recon - noise_pred_fwd)) 471 | 472 | # perform guidance 473 | if self.do_classifier_free_guidance: 474 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 475 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) 476 | noise_pred_uncond, noise_pred_text_recon = noise_pred_recon.chunk(2) 477 | noise_pred_recon = 
noise_pred_uncond + self.guidance_scale * (noise_pred_text_recon - noise_pred_uncond) 478 | 479 | if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: 480 | # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf 481 | noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) 482 | noise_pred_recon = rescale_noise_cfg(noise_pred_recon, noise_pred_text_recon, guidance_rescale=self.guidance_rescale) 483 | 484 | # compute the previous noisy sample x_t -> x_t-1 485 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] 486 | 487 | if callback_on_step_end is not None: 488 | callback_kwargs = {} 489 | for k in callback_on_step_end_tensor_inputs: 490 | callback_kwargs[k] = locals()[k] 491 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 492 | 493 | latents = callback_outputs.pop("latents", latents) 494 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 495 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) 496 | add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) 497 | negative_pooled_prompt_embeds = callback_outputs.pop( 498 | "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds 499 | ) 500 | add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) 501 | add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) 502 | 503 | # call the callback, if provided 504 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 505 | progress_bar.update() 506 | if callback is not None and i % callback_steps == 0: 507 | step_idx = i // getattr(self.scheduler, "order", 1) 508 | callback(step_idx, t, latents) 509 | 510 | if XLA_AVAILABLE: 511 | xm.mark_step() 512 | 513 | if not output_type == "latent": 514 | # make sure the VAE is in float32 mode, as it overflows in float16 515 | needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast 516 | 517 | if needs_upcasting: 518 | self.upcast_vae() 519 | latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) 520 | 521 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] 522 | 523 | # cast back to fp16 if needed 524 | if needs_upcasting: 525 | self.vae.to(dtype=torch.float16) 526 | else: 527 | image = latents 528 | return StableDiffusionXLPipelineOutput(images=image) 529 | 530 | # apply watermark if available 531 | if self.watermark is not None: 532 | image = self.watermark.apply_watermark(image) 533 | 534 | image = self.image_processor.postprocess(image, output_type=output_type) 535 | 536 | # Offload all models 537 | self.maybe_free_model_hooks() 538 | 539 | if not return_dict: 540 | return (image,) 541 | 542 | return StableDiffusionXLPipelineOutput(images=image) -------------------------------------------------------------------------------- /src/pipes/sdxl_inversion_pipeline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 3 | from diffusers import StableDiffusionXLImg2ImgPipeline 4 | from diffusers.utils.torch_utils import randn_tensor 5 | 6 | from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import ( 7 | StableDiffusionXLPipelineOutput, 8 | retrieve_timesteps, 9 | PipelineImageInput 10 | ) 11 | 12 | from src.renoise_inversion import 
inversion_step 13 | 14 | 15 | class SDXLDDIMPipeline(StableDiffusionXLImg2ImgPipeline): 16 | # @torch.no_grad() 17 | def __call__( 18 | self, 19 | prompt: Union[str, List[str]] = None, 20 | prompt_2: Optional[Union[str, List[str]]] = None, 21 | image: PipelineImageInput = None, 22 | strength: float = 0.3, 23 | num_inversion_steps: int = 50, 24 | timesteps: List[int] = None, 25 | denoising_start: Optional[float] = None, 26 | denoising_end: Optional[float] = None, 27 | guidance_scale: float = 1.0, 28 | negative_prompt: Optional[Union[str, List[str]]] = None, 29 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 30 | num_images_per_prompt: Optional[int] = 1, 31 | eta: float = 0.0, 32 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 33 | latents: Optional[torch.FloatTensor] = None, 34 | prompt_embeds: Optional[torch.FloatTensor] = None, 35 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 36 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 37 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 38 | ip_adapter_image: Optional[PipelineImageInput] = None, 39 | output_type: Optional[str] = "pil", 40 | return_dict: bool = True, 41 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 42 | guidance_rescale: float = 0.0, 43 | original_size: Tuple[int, int] = None, 44 | crops_coords_top_left: Tuple[int, int] = (0, 0), 45 | target_size: Tuple[int, int] = None, 46 | negative_original_size: Optional[Tuple[int, int]] = None, 47 | negative_crops_coords_top_left: Tuple[int, int] = (0, 0), 48 | negative_target_size: Optional[Tuple[int, int]] = None, 49 | aesthetic_score: float = 6.0, 50 | negative_aesthetic_score: float = 2.5, 51 | clip_skip: Optional[int] = None, 52 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 53 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 54 | num_renoise_steps: int = 100, 55 | **kwargs, 56 | ): 57 | callback = kwargs.pop("callback", None) 58 | callback_steps = kwargs.pop("callback_steps", None) 59 | 60 | if callback is not None: 61 | deprecate( 62 | "callback", 63 | "1.0.0", 64 | "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", 65 | ) 66 | if callback_steps is not None: 67 | deprecate( 68 | "callback_steps", 69 | "1.0.0", 70 | "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", 71 | ) 72 | 73 | # 1. Check inputs. Raise error if not correct 74 | self.check_inputs( 75 | prompt, 76 | prompt_2, 77 | strength, 78 | num_inversion_steps, 79 | callback_steps, 80 | negative_prompt, 81 | negative_prompt_2, 82 | prompt_embeds, 83 | negative_prompt_embeds, 84 | callback_on_step_end_tensor_inputs, 85 | ) 86 | 87 | self._guidance_scale = guidance_scale 88 | self._guidance_rescale = guidance_rescale 89 | self._clip_skip = clip_skip 90 | self._cross_attention_kwargs = cross_attention_kwargs 91 | self._denoising_end = denoising_end 92 | self._denoising_start = denoising_start 93 | 94 | # 2. Define call parameters 95 | if prompt is not None and isinstance(prompt, str): 96 | batch_size = 1 97 | elif prompt is not None and isinstance(prompt, list): 98 | batch_size = len(prompt) 99 | else: 100 | batch_size = prompt_embeds.shape[0] 101 | 102 | device = self._execution_device 103 | 104 | # 3. 
Encode input prompt 105 | text_encoder_lora_scale = ( 106 | self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None 107 | ) 108 | ( 109 | prompt_embeds, 110 | negative_prompt_embeds, 111 | pooled_prompt_embeds, 112 | negative_pooled_prompt_embeds, 113 | ) = self.encode_prompt( 114 | prompt=prompt, 115 | prompt_2=prompt_2, 116 | device=device, 117 | num_images_per_prompt=num_images_per_prompt, 118 | do_classifier_free_guidance=self.do_classifier_free_guidance, 119 | negative_prompt=negative_prompt, 120 | negative_prompt_2=negative_prompt_2, 121 | prompt_embeds=prompt_embeds, 122 | negative_prompt_embeds=negative_prompt_embeds, 123 | pooled_prompt_embeds=pooled_prompt_embeds, 124 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 125 | lora_scale=text_encoder_lora_scale, 126 | clip_skip=self.clip_skip, 127 | ) 128 | 129 | # 4. Preprocess image 130 | image = self.image_processor.preprocess(image) 131 | 132 | # 5. Prepare timesteps 133 | def denoising_value_valid(dnv): 134 | return isinstance(self.denoising_end, float) and 0 < dnv < 1 135 | 136 | timesteps, num_inversion_steps = retrieve_timesteps(self.scheduler, num_inversion_steps, device, timesteps) 137 | 138 | timesteps, num_inversion_steps = self.get_timesteps( 139 | num_inversion_steps, 140 | strength, 141 | device, 142 | denoising_start=self.denoising_start if denoising_value_valid else None, 143 | ) 144 | 145 | # 6. Prepare latent variables 146 | with torch.no_grad(): 147 | latents = self.prepare_latents( 148 | image, 149 | None, 150 | batch_size, 151 | num_images_per_prompt, 152 | prompt_embeds.dtype, 153 | device, 154 | generator, 155 | False, 156 | ) 157 | # 7. Prepare extra step kwargs. 158 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 159 | 160 | height, width = latents.shape[-2:] 161 | height = height * self.vae_scale_factor 162 | width = width * self.vae_scale_factor 163 | 164 | original_size = original_size or (height, width) 165 | target_size = target_size or (height, width) 166 | 167 | # 8. 
Prepare added time ids & embeddings 168 | if negative_original_size is None: 169 | negative_original_size = original_size 170 | if negative_target_size is None: 171 | negative_target_size = target_size 172 | 173 | add_text_embeds = pooled_prompt_embeds 174 | if self.text_encoder_2 is None: 175 | text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) 176 | else: 177 | text_encoder_projection_dim = self.text_encoder_2.config.projection_dim 178 | 179 | add_time_ids, add_neg_time_ids = self._get_add_time_ids( 180 | original_size, 181 | crops_coords_top_left, 182 | target_size, 183 | aesthetic_score, 184 | negative_aesthetic_score, 185 | negative_original_size, 186 | negative_crops_coords_top_left, 187 | negative_target_size, 188 | dtype=prompt_embeds.dtype, 189 | text_encoder_projection_dim=text_encoder_projection_dim, 190 | ) 191 | add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) 192 | 193 | if self.do_classifier_free_guidance: 194 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) 195 | add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) 196 | add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) 197 | add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) 198 | 199 | prompt_embeds = prompt_embeds.to(device) 200 | add_text_embeds = add_text_embeds.to(device) 201 | add_time_ids = add_time_ids.to(device) 202 | 203 | if ip_adapter_image is not None: 204 | image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) 205 | if self.do_classifier_free_guidance: 206 | image_embeds = torch.cat([negative_image_embeds, image_embeds]) 207 | image_embeds = image_embeds.to(device) 208 | 209 | # 9. 
Denoising loop 210 | num_warmup_steps = max(len(timesteps) - num_inversion_steps * self.scheduler.order, 0) 211 | 212 | self._num_timesteps = len(timesteps) 213 | self.z_0 = torch.clone(latents) 214 | self.noise = randn_tensor(self.z_0.shape, generator=generator, device=self.z_0.device, dtype=self.z_0.dtype) 215 | 216 | all_latents = [latents.clone()] 217 | with self.progress_bar(total=num_inversion_steps) as progress_bar: 218 | for i, t in enumerate(reversed(timesteps)): 219 | 220 | added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} 221 | if ip_adapter_image is not None: 222 | added_cond_kwargs["image_embeds"] = image_embeds 223 | 224 | latents = inversion_step(self, 225 | latents, 226 | t, 227 | prompt_embeds, 228 | added_cond_kwargs, 229 | num_renoise_steps=num_renoise_steps, 230 | generator=generator) 231 | 232 | all_latents.append(latents.clone()) 233 | 234 | if callback_on_step_end is not None: 235 | callback_kwargs = {} 236 | for k in callback_on_step_end_tensor_inputs: 237 | callback_kwargs[k] = locals()[k] 238 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 239 | 240 | latents = callback_outputs.pop("latents", latents) 241 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 242 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) 243 | add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) 244 | negative_pooled_prompt_embeds = callback_outputs.pop( 245 | "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds 246 | ) 247 | add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) 248 | add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) 249 | 250 | # call the callback, if provided 251 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 252 | progress_bar.update() 253 | if callback is not None and i % callback_steps == 0: 254 | step_idx = i // getattr(self.scheduler, "order", 1) 255 | callback(step_idx, t, latents) 256 | 257 | image = latents 258 | 259 | # Offload all models 260 | self.maybe_free_model_hooks() 261 | 262 | return StableDiffusionXLPipelineOutput(images=image), all_latents, [prompt_embeds, added_cond_kwargs] 263 | -------------------------------------------------------------------------------- /src/renoise_inversion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | # Based on code from https://github.com/pix2pixzero/pix2pix-zero 5 | def noise_regularization( 6 | e_t, noise_pred_optimal, lambda_kl, lambda_ac, num_reg_steps, num_ac_rolls, generator=None 7 | ): 8 | for _outer in range(num_reg_steps): 9 | if lambda_kl > 0: 10 | _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True) 11 | l_kld = patchify_latents_kl_divergence(_var, noise_pred_optimal) 12 | l_kld.backward() 13 | _grad = _var.grad.detach() 14 | _grad = torch.clip(_grad, -100, 100) 15 | e_t = e_t - lambda_kl * _grad 16 | if lambda_ac > 0: 17 | for _inner in range(num_ac_rolls): 18 | _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True) 19 | l_ac = auto_corr_loss(_var, generator=generator) 20 | l_ac.backward() 21 | _grad = _var.grad.detach() / num_ac_rolls 22 | e_t = e_t - lambda_ac * _grad 23 | e_t = e_t.detach() 24 | 25 | return e_t 26 | 27 | # Based on code from https://github.com/pix2pixzero/pix2pix-zero 28 | def auto_corr_loss( 29 | x, 
random_shift=True, generator=None 30 | ): 31 | B, C, H, W = x.shape 32 | assert B == 1 33 | x = x.squeeze(0) 34 | # x must be shape [C,H,W] now 35 | reg_loss = 0.0 36 | for ch_idx in range(x.shape[0]): 37 | noise = x[ch_idx][None, None, :, :] 38 | while True: 39 | if random_shift: 40 | roll_amount = torch.randint(0, noise.shape[2] // 2, (1,), generator=generator).item() 41 | else: 42 | roll_amount = 1 43 | reg_loss += ( 44 | noise * torch.roll(noise, shifts=roll_amount, dims=2) 45 | ).mean() ** 2 46 | reg_loss += ( 47 | noise * torch.roll(noise, shifts=roll_amount, dims=3) 48 | ).mean() ** 2 49 | if noise.shape[2] <= 8: 50 | break 51 | noise = F.avg_pool2d(noise, kernel_size=2) 52 | return reg_loss 53 | 54 | 55 | def patchify_latents_kl_divergence(x0, x1, patch_size=4, num_channels=4): 56 | 57 | def patchify_tensor(input_tensor): 58 | patches = ( 59 | input_tensor.unfold(1, patch_size, patch_size) 60 | .unfold(2, patch_size, patch_size) 61 | .unfold(3, patch_size, patch_size) 62 | ) 63 | patches = patches.contiguous().view(-1, num_channels, patch_size, patch_size) 64 | return patches 65 | 66 | x0 = patchify_tensor(x0) 67 | x1 = patchify_tensor(x1) 68 | 69 | kl = latents_kl_divergence(x0, x1).sum() 70 | return kl 71 | 72 | 73 | def latents_kl_divergence(x0, x1): 74 | EPSILON = 1e-6 75 | x0 = x0.view(x0.shape[0], x0.shape[1], -1) 76 | x1 = x1.view(x1.shape[0], x1.shape[1], -1) 77 | mu0 = x0.mean(dim=-1) 78 | mu1 = x1.mean(dim=-1) 79 | var0 = x0.var(dim=-1) 80 | var1 = x1.var(dim=-1) 81 | kl = ( 82 | torch.log((var1 + EPSILON) / (var0 + EPSILON)) 83 | + (var0 + (mu0 - mu1) ** 2) / (var1 + EPSILON) 84 | - 1 85 | ) 86 | kl = torch.abs(kl).sum(dim=-1) 87 | return kl 88 | 89 | def inversion_step( 90 | pipe, 91 | z_t: torch.tensor, 92 | t: torch.tensor, 93 | prompt_embeds, 94 | added_cond_kwargs, 95 | num_renoise_steps: int = 100, 96 | first_step_max_timestep: int = 250, 97 | generator=None, 98 | ) -> torch.tensor: 99 | extra_step_kwargs = {} 100 | avg_range = pipe.cfg.average_first_step_range if t.item() < first_step_max_timestep else pipe.cfg.average_step_range 101 | num_renoise_steps = min(pipe.cfg.max_num_renoise_steps_first_step, num_renoise_steps) if t.item() < first_step_max_timestep else num_renoise_steps 102 | 103 | nosie_pred_avg = None 104 | noise_pred_optimal = None 105 | z_tp1_forward = pipe.scheduler.add_noise(pipe.z_0, pipe.noise, t.view((1))).detach() 106 | 107 | approximated_z_tp1 = z_t.clone() 108 | for i in range(num_renoise_steps + 1): 109 | 110 | with torch.no_grad(): 111 | # if noise regularization is enabled, we need to double the batch size for the first step 112 | if pipe.cfg.noise_regularization_num_reg_steps > 0 and i == 0: 113 | approximated_z_tp1 = torch.cat([z_tp1_forward, approximated_z_tp1]) 114 | prompt_embeds_in = torch.cat([prompt_embeds, prompt_embeds]) 115 | if added_cond_kwargs is not None: 116 | added_cond_kwargs_in = {} 117 | added_cond_kwargs_in['text_embeds'] = torch.cat([added_cond_kwargs['text_embeds'], added_cond_kwargs['text_embeds']]) 118 | added_cond_kwargs_in['time_ids'] = torch.cat([added_cond_kwargs['time_ids'], added_cond_kwargs['time_ids']]) 119 | else: 120 | added_cond_kwargs_in = None 121 | else: 122 | prompt_embeds_in = prompt_embeds 123 | added_cond_kwargs_in = added_cond_kwargs 124 | 125 | noise_pred = unet_pass(pipe, approximated_z_tp1, t, prompt_embeds_in, added_cond_kwargs_in) 126 | 127 | # if noise regularization is enabled, we need to split the batch size for the first step 128 | if pipe.cfg.noise_regularization_num_reg_steps > 0 and i 
== 0: 129 | noise_pred_optimal, noise_pred = noise_pred.chunk(2) 130 | if pipe.do_classifier_free_guidance: 131 | noise_pred_optimal_uncond, noise_pred_optimal_text = noise_pred_optimal.chunk(2) 132 | noise_pred_optimal = noise_pred_optimal_uncond + pipe.guidance_scale * (noise_pred_optimal_text - noise_pred_optimal_uncond) 133 | noise_pred_optimal = noise_pred_optimal.detach() 134 | 135 | # perform guidance 136 | if pipe.do_classifier_free_guidance: 137 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 138 | noise_pred = noise_pred_uncond + pipe.guidance_scale * (noise_pred_text - noise_pred_uncond) 139 | 140 | # Calculate average noise 141 | if i >= avg_range[0] and i < avg_range[1]: 142 | j = i - avg_range[0] 143 | if nosie_pred_avg is None: 144 | nosie_pred_avg = noise_pred.clone() 145 | else: 146 | nosie_pred_avg = j * nosie_pred_avg / (j + 1) + noise_pred / (j + 1) 147 | 148 | if i >= avg_range[0] or (not pipe.cfg.average_latent_estimations and i > 0): 149 | noise_pred = noise_regularization(noise_pred, noise_pred_optimal, lambda_kl=pipe.cfg.noise_regularization_lambda_kl, lambda_ac=pipe.cfg.noise_regularization_lambda_ac, num_reg_steps=pipe.cfg.noise_regularization_num_reg_steps, num_ac_rolls=pipe.cfg.noise_regularization_num_ac_rolls, generator=generator) 150 | 151 | approximated_z_tp1 = pipe.scheduler.inv_step(noise_pred, t, z_t, **extra_step_kwargs, return_dict=False)[0].detach() 152 | 153 | # if average latents is enabled, we need to perform an additional step with the average noise 154 | if pipe.cfg.average_latent_estimations and nosie_pred_avg is not None: 155 | nosie_pred_avg = noise_regularization(nosie_pred_avg, noise_pred_optimal, lambda_kl=pipe.cfg.noise_regularization_lambda_kl, lambda_ac=pipe.cfg.noise_regularization_lambda_ac, num_reg_steps=pipe.cfg.noise_regularization_num_reg_steps, num_ac_rolls=pipe.cfg.noise_regularization_num_ac_rolls, generator=generator) 156 | approximated_z_tp1 = pipe.scheduler.inv_step(nosie_pred_avg, t, z_t, **extra_step_kwargs, return_dict=False)[0].detach() 157 | 158 | # perform noise correction 159 | if pipe.cfg.perform_noise_correction: 160 | noise_pred = unet_pass(pipe, approximated_z_tp1, t, prompt_embeds, added_cond_kwargs) 161 | 162 | # perform guidance 163 | if pipe.do_classifier_free_guidance: 164 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 165 | noise_pred = noise_pred_uncond + pipe.guidance_scale * (noise_pred_text - noise_pred_uncond) 166 | 167 | pipe.scheduler.step_and_update_noise(noise_pred, t, approximated_z_tp1, z_t, return_dict=False, optimize_epsilon_type=pipe.cfg.perform_noise_correction) 168 | 169 | return approximated_z_tp1 170 | 171 | @torch.no_grad() 172 | def unet_pass(pipe, z_t, t, prompt_embeds, added_cond_kwargs): 173 | latent_model_input = torch.cat([z_t] * 2) if pipe.do_classifier_free_guidance else z_t 174 | latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t) 175 | return pipe.unet( 176 | latent_model_input, 177 | t, 178 | encoder_hidden_states=prompt_embeds, 179 | timestep_cond=None, 180 | cross_attention_kwargs=pipe.cross_attention_kwargs, 181 | added_cond_kwargs=added_cond_kwargs, 182 | return_dict=False, 183 | )[0] 184 | -------------------------------------------------------------------------------- /src/schedulers/ddim_scheduler.py: -------------------------------------------------------------------------------- 1 | from diffusers import DDIMScheduler 2 | from diffusers.utils import BaseOutput 3 | from diffusers.utils.torch_utils import randn_tensor 4 | 
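# `MyDDIMScheduler.inv_step` below rearranges the deterministic DDIM update (eta = 0,
# Eqs. (12)/(16) of https://arxiv.org/pdf/2010.02502.pdf) so that it runs in the opposite
# direction: instead of denoising, it returns the latent from which an ordinary DDIM step
# with the same predicted noise would have produced the given `sample`. The ReNoise loop in
# `inversion_step` (src/renoise_inversion.py) re-estimates the noise at that latent and
# calls `inv_step` again, refining the approximation over `num_renoise_steps` iterations.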
import torch 5 | from typing import List, Optional, Tuple, Union 6 | import numpy as np 7 | 8 | class DDIMSchedulerOutput(BaseOutput): 9 | """ 10 | Output class for the scheduler's `step` function output. 11 | 12 | Args: 13 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 14 | Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the 15 | denoising loop. 16 | pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 17 | The predicted denoised sample `(x_{0})` based on the model output from the current timestep. 18 | `pred_original_sample` can be used to preview progress or for guidance. 19 | """ 20 | 21 | prev_sample: torch.FloatTensor 22 | pred_original_sample: Optional[torch.FloatTensor] = None 23 | 24 | class MyDDIMScheduler(DDIMScheduler): 25 | 26 | def inv_step( 27 | self, 28 | model_output: torch.FloatTensor, 29 | timestep: int, 30 | sample: torch.FloatTensor, 31 | eta: float = 0.0, 32 | use_clipped_model_output: bool = False, 33 | generator=None, 34 | variance_noise: Optional[torch.FloatTensor] = None, 35 | return_dict: bool = True, 36 | ) -> Union[DDIMSchedulerOutput, Tuple]: 37 | """ 38 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 39 | process from the learned model outputs (most often the predicted noise). 40 | 41 | Args: 42 | model_output (`torch.FloatTensor`): 43 | The direct output from learned diffusion model. 44 | timestep (`float`): 45 | The current discrete timestep in the diffusion chain. 46 | sample (`torch.FloatTensor`): 47 | A current instance of a sample created by the diffusion process. 48 | eta (`float`): 49 | The weight of noise for added noise in diffusion step. 50 | use_clipped_model_output (`bool`, defaults to `False`): 51 | If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary 52 | because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no 53 | clipping has happened, "corrected" `model_output` would coincide with the one provided as input and 54 | `use_clipped_model_output` has no effect. 55 | generator (`torch.Generator`, *optional*): 56 | A random number generator. 57 | variance_noise (`torch.FloatTensor`): 58 | Alternative to generating noise with `generator` by directly providing the noise for the variance 59 | itself. Useful for methods such as [`CycleDiffusion`]. 60 | return_dict (`bool`, *optional*, defaults to `True`): 61 | Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`. 62 | 63 | Returns: 64 | [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`: 65 | If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a 66 | tuple is returned where the first element is the sample tensor. 
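        Example:
            An illustrative sketch of how this method is typically driven during inversion
            (the names `scheduler`, `unet`, `prompt_embeds` and the starting latent `z` are
            assumed to exist; `inversion_step` in src/renoise_inversion.py wraps the same
            call in its ReNoise fixed-point loop):

            ```py
            scheduler.set_timesteps(50)
            for t in reversed(scheduler.timesteps):
                noise_pred = unet(z, t, encoder_hidden_states=prompt_embeds).sample
                z = scheduler.inv_step(noise_pred, t, z, return_dict=False)[0]
            ```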
67 | 68 | """ 69 | if self.num_inference_steps is None: 70 | raise ValueError( 71 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" 72 | ) 73 | 74 | # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf 75 | # Ideally, read DDIM paper in-detail understanding 76 | 77 | # Notation ( -> 78 | # - pred_noise_t -> e_theta(x_t, t) 79 | # - pred_original_sample -> f_theta(x_t, t) or x_0 80 | # - std_dev_t -> sigma_t 81 | # - eta -> η 82 | # - pred_sample_direction -> "direction pointing to x_t" 83 | # - pred_prev_sample -> "x_t-1" 84 | 85 | # 1. get previous step value (=t-1) 86 | prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps 87 | 88 | # 2. compute alphas, betas 89 | alpha_prod_t = self.alphas_cumprod[timestep] 90 | alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod 91 | 92 | beta_prod_t = 1 - alpha_prod_t 93 | 94 | # 3. compute predicted original sample from predicted noise also called 95 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 96 | assert self.config.prediction_type == "epsilon" 97 | if self.config.prediction_type == "epsilon": 98 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) 99 | pred_epsilon = model_output 100 | elif self.config.prediction_type == "sample": 101 | pred_original_sample = model_output 102 | pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) 103 | elif self.config.prediction_type == "v_prediction": 104 | pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output 105 | pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample 106 | else: 107 | raise ValueError( 108 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" 109 | " `v_prediction`" 110 | ) 111 | 112 | # 4. Clip or threshold "predicted x_0" 113 | if self.config.thresholding: 114 | pred_original_sample = self._threshold_sample(pred_original_sample) 115 | elif self.config.clip_sample: 116 | pred_original_sample = pred_original_sample.clamp( 117 | -self.config.clip_sample_range, self.config.clip_sample_range 118 | ) 119 | 120 | # 5. compute variance: "sigma_t(η)" -> see formula (16) 121 | # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) 122 | variance = self._get_variance(timestep, prev_timestep) 123 | std_dev_t = eta * variance ** (0.5) 124 | 125 | if use_clipped_model_output: 126 | # the pred_epsilon is always re-derived from the clipped x_0 in Glide 127 | pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) 128 | 129 | # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 130 | pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon 131 | 132 | # 7. 
compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 133 | # prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction 134 | 135 | prev_sample = (alpha_prod_t ** (0.5) * sample) / alpha_prod_t_prev ** (0.5) + (alpha_prod_t_prev ** (0.5) * beta_prod_t ** (0.5) * model_output) / alpha_prod_t_prev ** (0.5) - (alpha_prod_t ** (0.5) * pred_sample_direction) / alpha_prod_t_prev ** (0.5) 136 | 137 | if eta > 0: 138 | if variance_noise is not None and generator is not None: 139 | raise ValueError( 140 | "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" 141 | " `variance_noise` stays `None`." 142 | ) 143 | 144 | if variance_noise is None: 145 | variance_noise = randn_tensor( 146 | model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype 147 | ) 148 | variance = std_dev_t * variance_noise 149 | 150 | prev_sample = prev_sample + variance 151 | 152 | if not return_dict: 153 | return (prev_sample,) 154 | 155 | return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) -------------------------------------------------------------------------------- /src/schedulers/euler_scheduler.py: -------------------------------------------------------------------------------- 1 | from diffusers import EulerAncestralDiscreteScheduler 2 | from diffusers.utils import BaseOutput 3 | import torch 4 | from typing import List, Optional, Tuple, Union 5 | import numpy as np 6 | 7 | class EulerAncestralDiscreteSchedulerOutput(BaseOutput): 8 | """ 9 | Output class for the scheduler's `step` function output. 10 | 11 | Args: 12 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 13 | Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the 14 | denoising loop. 15 | pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 16 | The predicted denoised sample `(x_{0})` based on the model output from the current timestep. 17 | `pred_original_sample` can be used to preview progress or for guidance. 18 | """ 19 | 20 | prev_sample: torch.FloatTensor 21 | pred_original_sample: Optional[torch.FloatTensor] = None 22 | 23 | class MyEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler): 24 | def set_noise_list(self, noise_list): 25 | self.noise_list = noise_list 26 | 27 | def get_noise_to_remove(self): 28 | sigma_from = self.sigmas[self.step_index] 29 | sigma_to = self.sigmas[self.step_index + 1] 30 | sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 31 | 32 | return self.noise_list[self.step_index] * sigma_up\ 33 | 34 | def scale_model_input( 35 | self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] 36 | ) -> torch.FloatTensor: 37 | """ 38 | Ensures interchangeability with schedulers that need to scale the denoising model input depending on the 39 | current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. 40 | 41 | Args: 42 | sample (`torch.FloatTensor`): 43 | The input sample. 44 | timestep (`int`, *optional*): 45 | The current timestep in the diffusion chain. 46 | 47 | Returns: 48 | `torch.FloatTensor`: 49 | A scaled input sample. 
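        Note:
            Unlike the parent implementation, this override re-initializes the internal
            step index from `timestep` on every call (rather than only when it is unset)
            before deferring to `EulerAncestralDiscreteScheduler.scale_model_input`, so the
            sigma used for scaling stays in sync with the current timestep when the same
            timestep is visited repeatedly during the ReNoise iterations.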
50 | """ 51 | 52 | self._init_step_index(timestep.view((1))) 53 | return EulerAncestralDiscreteScheduler.scale_model_input(self, sample, timestep) 54 | 55 | 56 | def step( 57 | self, 58 | model_output: torch.FloatTensor, 59 | timestep: Union[float, torch.FloatTensor], 60 | sample: torch.FloatTensor, 61 | generator: Optional[torch.Generator] = None, 62 | return_dict: bool = True, 63 | ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: 64 | """ 65 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 66 | process from the learned model outputs (most often the predicted noise). 67 | 68 | Args: 69 | model_output (`torch.FloatTensor`): 70 | The direct output from learned diffusion model. 71 | timestep (`float`): 72 | The current discrete timestep in the diffusion chain. 73 | sample (`torch.FloatTensor`): 74 | A current instance of a sample created by the diffusion process. 75 | generator (`torch.Generator`, *optional*): 76 | A random number generator. 77 | return_dict (`bool`): 78 | Whether or not to return a 79 | [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple. 80 | 81 | Returns: 82 | [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: 83 | If return_dict is `True`, 84 | [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned, 85 | otherwise a tuple is returned where the first element is the sample tensor. 86 | 87 | """ 88 | 89 | if ( 90 | isinstance(timestep, int) 91 | or isinstance(timestep, torch.IntTensor) 92 | or isinstance(timestep, torch.LongTensor) 93 | ): 94 | raise ValueError( 95 | ( 96 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" 97 | " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" 98 | " one of the `scheduler.timesteps` as a timestep." 99 | ), 100 | ) 101 | 102 | if not self.is_scale_input_called: 103 | logger.warning( 104 | "The `scale_model_input` function should be called before `step` to ensure correct denoising. " 105 | "See `StableDiffusionPipeline` for a usage example." 106 | ) 107 | 108 | self._init_step_index(timestep.view((1))) 109 | 110 | sigma = self.sigmas[self.step_index] 111 | 112 | # Upcast to avoid precision issues when computing prev_sample 113 | sample = sample.to(torch.float32) 114 | 115 | # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise 116 | if self.config.prediction_type == "epsilon": 117 | pred_original_sample = sample - sigma * model_output 118 | elif self.config.prediction_type == "v_prediction": 119 | # * c_out + input * c_skip 120 | pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) 121 | elif self.config.prediction_type == "sample": 122 | raise NotImplementedError("prediction_type not implemented yet: sample") 123 | else: 124 | raise ValueError( 125 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" 126 | ) 127 | 128 | sigma_from = self.sigmas[self.step_index] 129 | sigma_to = self.sigmas[self.step_index + 1] 130 | sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 131 | sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 132 | 133 | # 2. 
Convert to an ODE derivative 134 | # derivative = (sample - pred_original_sample) / sigma 135 | derivative = model_output 136 | 137 | dt = sigma_down - sigma 138 | 139 | prev_sample = sample + derivative * dt 140 | 141 | device = model_output.device 142 | # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator) 143 | # prev_sample = prev_sample + noise * sigma_up 144 | 145 | prev_sample = prev_sample + self.noise_list[self.step_index] * sigma_up 146 | 147 | # Cast sample back to model compatible dtype 148 | prev_sample = prev_sample.to(model_output.dtype) 149 | 150 | # upon completion increase step index by one 151 | self._step_index += 1 152 | 153 | if not return_dict: 154 | return (prev_sample,) 155 | 156 | return EulerAncestralDiscreteSchedulerOutput( 157 | prev_sample=prev_sample, pred_original_sample=pred_original_sample 158 | ) 159 | 160 | def step_and_update_noise( 161 | self, 162 | model_output: torch.FloatTensor, 163 | timestep: Union[float, torch.FloatTensor], 164 | sample: torch.FloatTensor, 165 | expected_prev_sample: torch.FloatTensor, 166 | optimize_epsilon_type: bool = False, 167 | generator: Optional[torch.Generator] = None, 168 | return_dict: bool = True, 169 | ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: 170 | """ 171 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 172 | process from the learned model outputs (most often the predicted noise). 173 | 174 | Args: 175 | model_output (`torch.FloatTensor`): 176 | The direct output from learned diffusion model. 177 | timestep (`float`): 178 | The current discrete timestep in the diffusion chain. 179 | sample (`torch.FloatTensor`): 180 | A current instance of a sample created by the diffusion process. 181 | generator (`torch.Generator`, *optional*): 182 | A random number generator. 183 | return_dict (`bool`): 184 | Whether or not to return a 185 | [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple. 186 | 187 | Returns: 188 | [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: 189 | If return_dict is `True`, 190 | [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned, 191 | otherwise a tuple is returned where the first element is the sample tensor. 192 | 193 | """ 194 | 195 | if ( 196 | isinstance(timestep, int) 197 | or isinstance(timestep, torch.IntTensor) 198 | or isinstance(timestep, torch.LongTensor) 199 | ): 200 | raise ValueError( 201 | ( 202 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" 203 | " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" 204 | " one of the `scheduler.timesteps` as a timestep." 205 | ), 206 | ) 207 | 208 | if not self.is_scale_input_called: 209 | logger.warning( 210 | "The `scale_model_input` function should be called before `step` to ensure correct denoising. " 211 | "See `StableDiffusionPipeline` for a usage example." 212 | ) 213 | 214 | self._init_step_index(timestep.view((1))) 215 | 216 | sigma = self.sigmas[self.step_index] 217 | 218 | # Upcast to avoid precision issues when computing prev_sample 219 | sample = sample.to(torch.float32) 220 | 221 | # 1. 
compute predicted original sample (x_0) from sigma-scaled predicted noise 222 | if self.config.prediction_type == "epsilon": 223 | pred_original_sample = sample - sigma * model_output 224 | elif self.config.prediction_type == "v_prediction": 225 | # * c_out + input * c_skip 226 | pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) 227 | elif self.config.prediction_type == "sample": 228 | raise NotImplementedError("prediction_type not implemented yet: sample") 229 | else: 230 | raise ValueError( 231 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" 232 | ) 233 | 234 | sigma_from = self.sigmas[self.step_index] 235 | sigma_to = self.sigmas[self.step_index + 1] 236 | sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 237 | sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 238 | 239 | # 2. Convert to an ODE derivative 240 | # derivative = (sample - pred_original_sample) / sigma 241 | derivative = model_output 242 | 243 | dt = sigma_down - sigma 244 | 245 | prev_sample = sample + derivative * dt 246 | 247 | device = model_output.device 248 | # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator) 249 | # prev_sample = prev_sample + noise * sigma_up 250 | 251 | if sigma_up > 0: 252 | req_noise = (expected_prev_sample - prev_sample) / sigma_up 253 | if not optimize_epsilon_type: 254 | self.noise_list[self.step_index] = req_noise 255 | else: 256 | for i in range(10): 257 | n = torch.autograd.Variable(self.noise_list[self.step_index].detach().clone(), requires_grad=True) 258 | loss = torch.norm(n - req_noise.detach()) 259 | loss.backward() 260 | self.noise_list[self.step_index] -= n.grad.detach() * 1.8 261 | 262 | 263 | prev_sample = prev_sample + self.noise_list[self.step_index] * sigma_up 264 | 265 | # Cast sample back to model compatible dtype 266 | prev_sample = prev_sample.to(model_output.dtype) 267 | 268 | # upon completion increase step index by one 269 | self._step_index += 1 270 | 271 | if not return_dict: 272 | return (prev_sample,) 273 | 274 | return EulerAncestralDiscreteSchedulerOutput( 275 | prev_sample=prev_sample, pred_original_sample=pred_original_sample 276 | ) 277 | 278 | def inv_step( 279 | self, 280 | model_output: torch.FloatTensor, 281 | timestep: Union[float, torch.FloatTensor], 282 | sample: torch.FloatTensor, 283 | generator: Optional[torch.Generator] = None, 284 | return_dict: bool = True, 285 | ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: 286 | """ 287 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 288 | process from the learned model outputs (most often the predicted noise). 289 | 290 | Args: 291 | model_output (`torch.FloatTensor`): 292 | The direct output from learned diffusion model. 293 | timestep (`float`): 294 | The current discrete timestep in the diffusion chain. 295 | sample (`torch.FloatTensor`): 296 | A current instance of a sample created by the diffusion process. 297 | generator (`torch.Generator`, *optional*): 298 | A random number generator. 299 | return_dict (`bool`): 300 | Whether or not to return a 301 | [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple. 
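        Note:
            This inverse step mirrors `step` above: it subtracts (rather than adds) the
            per-step ancestral noise stored in `self.noise_list` (populated via
            `set_noise_list` and updated by `step_and_update_noise`), and it uses the
            modified `sigma_up` / `sigma_down` defined in the body (note the `.abs()`
            guard) instead of the ancestral formulas used in `step`.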
302 | 303 | Returns: 304 | [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: 305 | If return_dict is `True`, 306 | [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned, 307 | otherwise a tuple is returned where the first element is the sample tensor. 308 | 309 | """ 310 | 311 | if ( 312 | isinstance(timestep, int) 313 | or isinstance(timestep, torch.IntTensor) 314 | or isinstance(timestep, torch.LongTensor) 315 | ): 316 | raise ValueError( 317 | ( 318 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" 319 | " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" 320 | " one of the `scheduler.timesteps` as a timestep." 321 | ), 322 | ) 323 | 324 | if not self.is_scale_input_called: 325 | logger.warning( 326 | "The `scale_model_input` function should be called before `step` to ensure correct denoising. " 327 | "See `StableDiffusionPipeline` for a usage example." 328 | ) 329 | 330 | self._init_step_index(timestep.view((1))) 331 | 332 | sigma = self.sigmas[self.step_index] 333 | 334 | # Upcast to avoid precision issues when computing prev_sample 335 | sample = sample.to(torch.float32) 336 | 337 | # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise 338 | if self.config.prediction_type == "epsilon": 339 | pred_original_sample = sample - sigma * model_output 340 | elif self.config.prediction_type == "v_prediction": 341 | # * c_out + input * c_skip 342 | pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) 343 | elif self.config.prediction_type == "sample": 344 | raise NotImplementedError("prediction_type not implemented yet: sample") 345 | else: 346 | raise ValueError( 347 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" 348 | ) 349 | 350 | sigma_from = self.sigmas[self.step_index] 351 | sigma_to = self.sigmas[self.step_index+1] 352 | # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 353 | sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2).abs() / sigma_from**2) ** 0.5 354 | # sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 355 | sigma_down = sigma_to**2 / sigma_from 356 | 357 | # 2. 
Convert to an ODE derivative 358 | # derivative = (sample - pred_original_sample) / sigma 359 | derivative = model_output 360 | 361 | dt = sigma_down - sigma 362 | # dt = sigma_down - sigma_from 363 | 364 | prev_sample = sample - derivative * dt 365 | 366 | device = model_output.device 367 | # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator) 368 | # prev_sample = prev_sample + noise * sigma_up 369 | 370 | prev_sample = prev_sample - self.noise_list[self.step_index] * sigma_up 371 | 372 | # Cast sample back to model compatible dtype 373 | prev_sample = prev_sample.to(model_output.dtype) 374 | 375 | # upon completion increase step index by one 376 | self._step_index += 1 377 | 378 | if not return_dict: 379 | return (prev_sample,) 380 | 381 | return EulerAncestralDiscreteSchedulerOutput( 382 | prev_sample=prev_sample, pred_original_sample=pred_original_sample 383 | ) 384 | 385 | def get_all_sigmas(self) -> torch.FloatTensor: 386 | sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) 387 | sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) 388 | return torch.from_numpy(sigmas) 389 | 390 | def add_noise_off_schedule( 391 | self, 392 | original_samples: torch.FloatTensor, 393 | noise: torch.FloatTensor, 394 | timesteps: torch.FloatTensor, 395 | ) -> torch.FloatTensor: 396 | # Make sure sigmas and timesteps have the same device and dtype as original_samples 397 | sigmas = self.get_all_sigmas() 398 | sigmas = sigmas.to(device=original_samples.device, dtype=original_samples.dtype) 399 | if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): 400 | # mps does not support float64 401 | timesteps = timesteps.to(original_samples.device, dtype=torch.float32) 402 | else: 403 | timesteps = timesteps.to(original_samples.device) 404 | 405 | step_indices = 1000 - int(timesteps.item()) 406 | 407 | sigma = sigmas[step_indices].flatten() 408 | while len(sigma.shape) < len(original_samples.shape): 409 | sigma = sigma.unsqueeze(-1) 410 | 411 | noisy_samples = original_samples + noise * sigma 412 | return noisy_samples 413 | 414 | # def update_noise_for_friendly_inversion( 415 | # self, 416 | # model_output: torch.FloatTensor, 417 | # timestep: Union[float, torch.FloatTensor], 418 | # z_t: torch.FloatTensor, 419 | # z_tp1: torch.FloatTensor, 420 | # return_dict: bool = True, 421 | # ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: 422 | # if ( 423 | # isinstance(timestep, int) 424 | # or isinstance(timestep, torch.IntTensor) 425 | # or isinstance(timestep, torch.LongTensor) 426 | # ): 427 | # raise ValueError( 428 | # ( 429 | # "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" 430 | # " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" 431 | # " one of the `scheduler.timesteps` as a timestep." 432 | # ), 433 | # ) 434 | 435 | # if not self.is_scale_input_called: 436 | # logger.warning( 437 | # "The `scale_model_input` function should be called before `step` to ensure correct denoising. " 438 | # "See `StableDiffusionPipeline` for a usage example." 
439 | # ) 440 | 441 | # self._init_step_index(timestep.view((1))) 442 | 443 | # sigma = self.sigmas[self.step_index] 444 | 445 | # sigma_from = self.sigmas[self.step_index] 446 | # sigma_to = self.sigmas[self.step_index+1] 447 | # # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 448 | # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2).abs() / sigma_from**2) ** 0.5 449 | # # sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 450 | # sigma_down = sigma_to**2 / sigma_from 451 | 452 | # # 2. Conv = (sample - pred_original_sample) / sigma 453 | # derivative = model_output 454 | 455 | # dt = sigma_down - sigma 456 | # # dt = sigma_down - sigma_from 457 | 458 | # prev_sample = z_t - derivative * dt 459 | 460 | # if sigma_up > 0: 461 | # self.noise_list[self.step_index] = (prev_sample - z_tp1) / sigma_up 462 | 463 | # prev_sample = prev_sample - self.noise_list[self.step_index] * sigma_up 464 | 465 | 466 | # if not return_dict: 467 | # return (prev_sample,) 468 | 469 | # return EulerAncestralDiscreteSchedulerOutput( 470 | # prev_sample=prev_sample, pred_original_sample=None 471 | # ) 472 | 473 | 474 | # def step_friendly_inversion( 475 | # self, 476 | # model_output: torch.FloatTensor, 477 | # timestep: Union[float, torch.FloatTensor], 478 | # sample: torch.FloatTensor, 479 | # generator: Optional[torch.Generator] = None, 480 | # return_dict: bool = True, 481 | # expected_next_sample: torch.FloatTensor = None, 482 | # ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: 483 | # """ 484 | # Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 485 | # process from the learned model outputs (most often the predicted noise). 486 | 487 | # Args: 488 | # model_output (`torch.FloatTensor`): 489 | # The direct output from learned diffusion model. 490 | # timestep (`float`): 491 | # The current discrete timestep in the diffusion chain. 492 | # sample (`torch.FloatTensor`): 493 | # A current instance of a sample created by the diffusion process. 494 | # generator (`torch.Generator`, *optional*): 495 | # A random number generator. 496 | # return_dict (`bool`): 497 | # Whether or not to return a 498 | # [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple. 499 | 500 | # Returns: 501 | # [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: 502 | # If return_dict is `True`, 503 | # [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned, 504 | # otherwise a tuple is returned where the first element is the sample tensor. 505 | 506 | # """ 507 | 508 | # if ( 509 | # isinstance(timestep, int) 510 | # or isinstance(timestep, torch.IntTensor) 511 | # or isinstance(timestep, torch.LongTensor) 512 | # ): 513 | # raise ValueError( 514 | # ( 515 | # "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" 516 | # " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" 517 | # " one of the `scheduler.timesteps` as a timestep." 518 | # ), 519 | # ) 520 | 521 | # if not self.is_scale_input_called: 522 | # logger.warning( 523 | # "The `scale_model_input` function should be called before `step` to ensure correct denoising. " 524 | # "See `StableDiffusionPipeline` for a usage example." 
525 | # ) 526 | 527 | # self._init_step_index(timestep.view((1))) 528 | 529 | # sigma = self.sigmas[self.step_index] 530 | 531 | # # Upcast to avoid precision issues when computing prev_sample 532 | # sample = sample.to(torch.float32) 533 | 534 | # # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise 535 | # if self.config.prediction_type == "epsilon": 536 | # pred_original_sample = sample - sigma * model_output 537 | # elif self.config.prediction_type == "v_prediction": 538 | # # * c_out + input * c_skip 539 | # pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) 540 | # elif self.config.prediction_type == "sample": 541 | # raise NotImplementedError("prediction_type not implemented yet: sample") 542 | # else: 543 | # raise ValueError( 544 | # f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" 545 | # ) 546 | 547 | # sigma_from = self.sigmas[self.step_index] 548 | # sigma_to = self.sigmas[self.step_index + 1] 549 | # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 550 | # sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 551 | 552 | # # 2. Convert to an ODE derivative 553 | # # derivative = (sample - pred_original_sample) / sigma 554 | # derivative = model_output 555 | 556 | # dt = sigma_down - sigma 557 | 558 | # prev_sample = sample + derivative * dt 559 | 560 | # device = model_output.device 561 | # # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator) 562 | # # prev_sample = prev_sample + noise * sigma_up 563 | 564 | # if sigma_up > 0: 565 | # self.noise_list[self.step_index] = (expected_next_sample - prev_sample) / sigma_up 566 | 567 | # prev_sample = prev_sample + self.noise_list[self.step_index] * sigma_up 568 | 569 | # # Cast sample back to model compatible dtype 570 | # prev_sample = prev_sample.to(model_output.dtype) 571 | 572 | # # upon completion increase step index by one 573 | # self._step_index += 1 574 | 575 | # if not return_dict: 576 | # return (prev_sample,) 577 | 578 | # return EulerAncestralDiscreteSchedulerOutput( 579 | # prev_sample=prev_sample, pred_original_sample=pred_original_sample 580 | # ) -------------------------------------------------------------------------------- /src/schedulers/lcm_scheduler.py: -------------------------------------------------------------------------------- 1 | from diffusers import LCMScheduler 2 | from diffusers.utils import BaseOutput 3 | from diffusers.utils.torch_utils import randn_tensor 4 | import torch 5 | from typing import List, Optional, Tuple, Union 6 | import numpy as np 7 | 8 | class LCMSchedulerOutput(BaseOutput): 9 | """ 10 | Output class for the scheduler's `step` function output. 11 | 12 | Args: 13 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 14 | Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the 15 | denoising loop. 16 | pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 17 | The predicted denoised sample `(x_{0})` based on the model output from the current timestep. 18 | `pred_original_sample` can be used to preview progress or for guidance. 
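    Note:
        In this variant the denoised prediction is exposed under the attribute `denoised`
        (see below) rather than `pred_original_sample`; `MyLCMScheduler.step` returns it as
        the second tuple element when `return_dict=False`.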
19 | """ 20 | 21 | prev_sample: torch.FloatTensor 22 | denoised: Optional[torch.FloatTensor] = None 23 | 24 | class MyLCMScheduler(LCMScheduler): 25 | 26 | def set_noise_list(self, noise_list): 27 | self.noise_list = noise_list 28 | 29 | def step( 30 | self, 31 | model_output: torch.FloatTensor, 32 | timestep: int, 33 | sample: torch.FloatTensor, 34 | generator: Optional[torch.Generator] = None, 35 | return_dict: bool = True, 36 | ) -> Union[LCMSchedulerOutput, Tuple]: 37 | """ 38 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 39 | process from the learned model outputs (most often the predicted noise). 40 | 41 | Args: 42 | model_output (`torch.FloatTensor`): 43 | The direct output from learned diffusion model. 44 | timestep (`float`): 45 | The current discrete timestep in the diffusion chain. 46 | sample (`torch.FloatTensor`): 47 | A current instance of a sample created by the diffusion process. 48 | generator (`torch.Generator`, *optional*): 49 | A random number generator. 50 | return_dict (`bool`, *optional*, defaults to `True`): 51 | Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`. 52 | Returns: 53 | [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`: 54 | If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a 55 | tuple is returned where the first element is the sample tensor. 56 | """ 57 | if self.num_inference_steps is None: 58 | raise ValueError( 59 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" 60 | ) 61 | 62 | self._init_step_index(timestep) 63 | 64 | # 1. get previous step value 65 | prev_step_index = self.step_index + 1 66 | if prev_step_index < len(self.timesteps): 67 | prev_timestep = self.timesteps[prev_step_index] 68 | else: 69 | prev_timestep = timestep 70 | 71 | # 2. compute alphas, betas 72 | alpha_prod_t = self.alphas_cumprod[timestep] 73 | alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod 74 | 75 | beta_prod_t = 1 - alpha_prod_t 76 | beta_prod_t_prev = 1 - alpha_prod_t_prev 77 | 78 | # 3. Get scalings for boundary conditions 79 | c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep) 80 | 81 | # 4. Compute the predicted original sample x_0 based on the model parameterization 82 | if self.config.prediction_type == "epsilon": # noise-prediction 83 | predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt() 84 | elif self.config.prediction_type == "sample": # x-prediction 85 | predicted_original_sample = model_output 86 | elif self.config.prediction_type == "v_prediction": # v-prediction 87 | predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output 88 | else: 89 | raise ValueError( 90 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" 91 | " `v_prediction` for `LCMScheduler`." 92 | ) 93 | 94 | # 5. Clip or threshold "predicted x_0" 95 | if self.config.thresholding: 96 | predicted_original_sample = self._threshold_sample(predicted_original_sample) 97 | elif self.config.clip_sample: 98 | predicted_original_sample = predicted_original_sample.clamp( 99 | -self.config.clip_sample_range, self.config.clip_sample_range 100 | ) 101 | 102 | # 6. Denoise model output using boundary conditions 103 | denoised = c_out * predicted_original_sample + c_skip * sample 104 | 105 | # 7. 
Sample and inject noise z ~ N(0, I) for MultiStep Inference 106 | # Noise is not used on the final timestep of the timestep schedule. 107 | # This also means that noise is not used for one-step sampling. 108 | if self.step_index != self.num_inference_steps - 1: 109 | noise = self.noise_list[self.step_index] 110 | prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise 111 | else: 112 | prev_sample = denoised 113 | 114 | # upon completion increase step index by one 115 | self._step_index += 1 116 | 117 | if not return_dict: 118 | return (prev_sample, denoised) 119 | 120 | return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised) 121 | 122 | 123 | def inv_step( 124 | self, 125 | model_output: torch.FloatTensor, 126 | timestep: int, 127 | sample: torch.FloatTensor, 128 | generator: Optional[torch.Generator] = None, 129 | return_dict: bool = True, 130 | ) -> Union[LCMSchedulerOutput, Tuple]: 131 | """ 132 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 133 | process from the learned model outputs (most often the predicted noise). 134 | 135 | Args: 136 | model_output (`torch.FloatTensor`): 137 | The direct output from learned diffusion model. 138 | timestep (`float`): 139 | The current discrete timestep in the diffusion chain. 140 | sample (`torch.FloatTensor`): 141 | A current instance of a sample created by the diffusion process. 142 | generator (`torch.Generator`, *optional*): 143 | A random number generator. 144 | return_dict (`bool`, *optional*, defaults to `True`): 145 | Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`. 146 | Returns: 147 | [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`: 148 | If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a 149 | tuple is returned where the first element is the sample tensor. 150 | """ 151 | if self.num_inference_steps is None: 152 | raise ValueError( 153 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" 154 | ) 155 | 156 | self._init_step_index(timestep) 157 | 158 | # 1. get previous step value 159 | prev_step_index = self.step_index + 1 160 | if prev_step_index < len(self.timesteps): 161 | prev_timestep = self.timesteps[prev_step_index] 162 | else: 163 | prev_timestep = timestep 164 | 165 | # 2. compute alphas, betas 166 | alpha_prod_t = self.alphas_cumprod[timestep] 167 | alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod 168 | 169 | beta_prod_t = 1 - alpha_prod_t 170 | beta_prod_t_prev = 1 - alpha_prod_t_prev 171 | 172 | # 3. 
Get scalings for boundary conditions 173 | c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep) 174 | 175 | if self.step_index != self.num_inference_steps - 1: 176 | c_skip_actual = c_skip * alpha_prod_t_prev.sqrt() 177 | c_out_actual = c_out * alpha_prod_t_prev.sqrt() 178 | noise = self.noise_list[self.step_index] * beta_prod_t_prev.sqrt() 179 | else: 180 | c_skip_actual = c_skip 181 | c_out_actual = c_out 182 | noise = 0 183 | 184 | 185 | dem = c_out_actual / (alpha_prod_t.sqrt()) + c_skip 186 | eps_mul = beta_prod_t.sqrt() * c_out_actual / (alpha_prod_t.sqrt()) 187 | 188 | prev_sample = (sample + eps_mul * model_output - noise) / dem 189 | 190 | # upon completion increase step index by one 191 | self._step_index += 1 192 | 193 | if not return_dict: 194 | return (prev_sample, prev_sample) 195 | 196 | return LCMSchedulerOutput(prev_sample=prev_sample, denoised=prev_sample) -------------------------------------------------------------------------------- /src/utils/enums_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionXLImg2ImgPipeline 4 | 5 | from src.eunms import Model_Type, Scheduler_Type 6 | from src.schedulers.euler_scheduler import MyEulerAncestralDiscreteScheduler 7 | from src.schedulers.lcm_scheduler import MyLCMScheduler 8 | from src.schedulers.ddim_scheduler import MyDDIMScheduler 9 | from src.pipes.sdxl_inversion_pipeline import SDXLDDIMPipeline 10 | from src.pipes.sdxl_forward_pipeline import StableDiffusionXLDecompositionPipeline 11 | from src.pipes.sd_inversion_pipeline import SDDDIMPipeline 12 | 13 | def scheduler_type_to_class(scheduler_type): 14 | if scheduler_type == Scheduler_Type.DDIM: 15 | return MyDDIMScheduler 16 | elif scheduler_type == Scheduler_Type.EULER: 17 | return MyEulerAncestralDiscreteScheduler 18 | elif scheduler_type == Scheduler_Type.LCM: 19 | return MyLCMScheduler 20 | else: 21 | raise ValueError("Unknown scheduler type") 22 | 23 | def is_stochastic(scheduler_type): 24 | if scheduler_type == Scheduler_Type.DDIM: 25 | return False 26 | elif scheduler_type == Scheduler_Type.EULER: 27 | return True 28 | elif scheduler_type == Scheduler_Type.LCM: 29 | return True 30 | else: 31 | raise ValueError("Unknown scheduler type") 32 | 33 | def model_type_to_class(model_type): 34 | if model_type == Model_Type.SDXL: 35 | return StableDiffusionXLDecompositionPipeline, SDXLDDIMPipeline 36 | elif model_type == Model_Type.SDXL_Turbo: 37 | return StableDiffusionXLImg2ImgPipeline, SDXLDDIMPipeline 38 | elif model_type == Model_Type.LCM_SDXL: 39 | return StableDiffusionXLImg2ImgPipeline, SDXLDDIMPipeline 40 | elif model_type == Model_Type.SD15: 41 | return StableDiffusionImg2ImgPipeline, SDDDIMPipeline 42 | elif model_type == Model_Type.SD14: 43 | return StableDiffusionImg2ImgPipeline, SDDDIMPipeline 44 | elif model_type == Model_Type.SD21: 45 | return StableDiffusionImg2ImgPipeline, SDDDIMPipeline 46 | elif model_type == Model_Type.SD21_Turbo: 47 | return StableDiffusionImg2ImgPipeline, SDDDIMPipeline 48 | else: 49 | raise ValueError("Unknown model type") 50 | 51 | def model_type_to_model_name(model_type): 52 | if model_type == Model_Type.SDXL: 53 | return "stabilityai/stable-diffusion-xl-base-1.0" 54 | elif model_type == Model_Type.SDXL_Turbo: 55 | return "stabilityai/sdxl-turbo" 56 | elif model_type == Model_Type.LCM_SDXL: 57 | return "stabilityai/stable-diffusion-xl-base-1.0" 58 | elif model_type == 
Model_Type.SD15: 59 | return "runwayml/stable-diffusion-v1-5" 60 | elif model_type == Model_Type.SD14: 61 | return "CompVis/stable-diffusion-v1-4" 62 | elif model_type == Model_Type.SD21: 63 | return "stabilityai/stable-diffusion-2-1" 64 | elif model_type == Model_Type.SD21_Turbo: 65 | return "stabilityai/sd-turbo" 66 | else: 67 | raise ValueError("Unknown model type") 68 | 69 | 70 | def model_type_to_size(model_type): 71 | if model_type == Model_Type.SDXL: 72 | return (1024, 1024) 73 | elif model_type == Model_Type.SDXL_Turbo: 74 | return (512, 512) 75 | elif model_type == Model_Type.LCM_SDXL: 76 | return (768, 768) #TODO: check 77 | elif model_type == Model_Type.SD15: 78 | return (512, 512) 79 | elif model_type == Model_Type.SD14: 80 | return (512, 512) 81 | elif model_type == Model_Type.SD21: 82 | return (512, 512) 83 | elif model_type == Model_Type.SD21_Turbo: 84 | return (512, 512) 85 | else: 86 | raise ValueError("Unknown model type") 87 | 88 | def is_float16(model_type): 89 | if model_type == Model_Type.SDXL: 90 | return True 91 | elif model_type == Model_Type.SDXL_Turbo: 92 | return True 93 | elif model_type == Model_Type.LCM_SDXL: 94 | return True 95 | elif model_type == Model_Type.SD15: 96 | return False 97 | elif model_type == Model_Type.SD14: 98 | return False 99 | elif model_type == Model_Type.SD21: 100 | return False 101 | elif model_type == Model_Type.SD21_Turbo: 102 | return False 103 | else: 104 | raise ValueError("Unknown model type") 105 | 106 | def is_sd(model_type): 107 | if model_type == Model_Type.SDXL: 108 | return False 109 | elif model_type == Model_Type.SDXL_Turbo: 110 | return False 111 | elif model_type == Model_Type.LCM_SDXL: 112 | return False 113 | elif model_type == Model_Type.SD15: 114 | return True 115 | elif model_type == Model_Type.SD14: 116 | return True 117 | elif model_type == Model_Type.SD21: 118 | return True 119 | elif model_type == Model_Type.SD21_Turbo: 120 | return True 121 | else: 122 | raise ValueError("Unknown model type") 123 | 124 | def _get_pipes(model_type, device): 125 | model_name = model_type_to_model_name(model_type) 126 | pipeline_inf, pipeline_inv = model_type_to_class(model_type) 127 | 128 | if is_float16(model_type): 129 | pipe_inference = pipeline_inf.from_pretrained( 130 | model_name, 131 | torch_dtype=torch.float16, 132 | use_safetensors=True, 133 | variant="fp16", 134 | safety_checker = None 135 | ).to(device) 136 | else: 137 | pipe_inference = pipeline_inf.from_pretrained( 138 | model_name, 139 | use_safetensors=True, 140 | safety_checker = None 141 | ).to(device) 142 | 143 | pipe_inversion = pipeline_inv(**pipe_inference.components) 144 | 145 | return pipe_inversion, pipe_inference 146 | 147 | def get_pipes(model_type, scheduler_type, device="cuda"): 148 | scheduler_class = scheduler_type_to_class(scheduler_type) 149 | 150 | pipe_inversion, pipe_inference = _get_pipes(model_type, device) 151 | 152 | pipe_inference.scheduler = scheduler_class.from_config(pipe_inference.scheduler.config) 153 | pipe_inversion.scheduler = scheduler_class.from_config(pipe_inversion.scheduler.config) 154 | 155 | if is_sd(model_type): 156 | pipe_inference.scheduler.add_noise = lambda init_latents, noise, timestep: init_latents 157 | pipe_inversion.scheduler.add_noise = lambda init_latents, noise, timestep: init_latents 158 | 159 | if model_type == Model_Type.LCM_SDXL: 160 | adapter_id = "latent-consistency/lcm-lora-sdxl" 161 | pipe_inversion.load_lora_weights(adapter_id) 162 | pipe_inference.load_lora_weights(adapter_id) 163 | 164 | return 
pipe_inversion, pipe_inference -------------------------------------------------------------------------------- /src/utils/images_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | import torch 4 | 5 | def read_images_in_path(path, size = (512,512)): 6 | image_paths = [] 7 | for filename in os.listdir(path): 8 | if filename.endswith(".png") or filename.endswith(".jpg") or filename.endswith(".jpeg"): 9 | image_path = os.path.join(path, filename) 10 | image_paths.append(image_path) 11 | image_paths = sorted(image_paths) 12 | return [Image.open(image_path).convert("RGB").resize(size) for image_path in image_paths] 13 | 14 | def concatenate_images(image_lists, return_list = False): 15 | num_rows = len(image_lists[0]) 16 | num_columns = len(image_lists) 17 | image_width = image_lists[0][0].width 18 | image_height = image_lists[0][0].height 19 | 20 | grid_width = num_columns * image_width 21 | grid_height = num_rows * image_height if not return_list else image_height 22 | if not return_list: 23 | grid_image = [Image.new('RGB', (grid_width, grid_height))] 24 | else: 25 | grid_image = [Image.new('RGB', (grid_width, grid_height)) for i in range(num_rows)] 26 | 27 | for i in range(num_rows): 28 | row_index = i if return_list else 0 29 | for j in range(num_columns): 30 | image = image_lists[j][i] 31 | x_offset = j * image_width 32 | y_offset = i * image_height if not return_list else 0 33 | grid_image[row_index].paste(image, (x_offset, y_offset)) 34 | 35 | return grid_image if return_list else grid_image[0] 36 | 37 | def concatenate_images_single(image_lists): 38 | num_columns = len(image_lists) 39 | image_width = image_lists[0].width 40 | image_height = image_lists[0].height 41 | 42 | grid_width = num_columns * image_width 43 | grid_height = image_height 44 | grid_image = Image.new('RGB', (grid_width, grid_height)) 45 | 46 | for j in range(num_columns): 47 | image = image_lists[j] 48 | x_offset = j * image_width 49 | y_offset = 0 50 | grid_image.paste(image, (x_offset, y_offset)) 51 | 52 | return grid_image 53 | 54 | def get_captions_for_images(images, device): 55 | from transformers import Blip2Processor, Blip2ForConditionalGeneration 56 | 57 | processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") 58 | model = Blip2ForConditionalGeneration.from_pretrained( 59 | "Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16 60 | ) # doctest: +IGNORE_RESULT 61 | 62 | res = [] 63 | 64 | for image in images: 65 | inputs = processor(images=image, return_tensors="pt").to(device, torch.float16) 66 | 67 | generated_ids = model.generate(**inputs) 68 | generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() 69 | res.append(generated_text) 70 | 71 | del processor 72 | del model 73 | 74 | return res --------------------------------------------------------------------------------
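As an illustrative sketch (the image folder path, output filename, and CUDA device below are assumptions, not part of the repository), the helpers in `src/utils/images_utils.py` can be combined like this:

```python
from src.utils.images_utils import (
    read_images_in_path,
    concatenate_images_single,
    get_captions_for_images,
)

# Load and resize every .png/.jpg/.jpeg in the folder.
images = read_images_in_path("example_images", size=(512, 512))

# Stitch the inputs into a single horizontal strip for quick inspection.
strip = concatenate_images_single(images)
strip.save("inputs_strip.png")

# Caption each image with BLIP-2 (needs a CUDA device; 8-bit loading requires bitsandbytes).
captions = get_captions_for_images(images, device="cuda")
for caption in captions:
    print(caption)
```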